Compare commits
24 Commits
62c671e07b
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 15b2efe52a | |||
| 6c96c43bea | |||
| 9185f26b83 | |||
| f4052d229a | |||
| bc2182720d | |||
| eddf38d316 | |||
| 03e768d232 | |||
| bcb0a9b322 | |||
| 743a2c9d65 | |||
| 3639a31163 | |||
| 7200df6d87 | |||
| 40b6a31c98 | |||
| 19837b4817 | |||
| b5d09009e3 | |||
| 0b521ae8ab | |||
| eac0766e79 | |||
| 53b757de7a | |||
| 69f852a4ce | |||
| 800c85bf8d | |||
| 729773cae3 | |||
| d831c9c0d6 | |||
| 4c1b7a8aad | |||
| a99091ea7a | |||
| 209cb24ab4 |
@@ -3,7 +3,10 @@
|
||||
"allow": [
|
||||
"Bash(git rev-parse:*)",
|
||||
"Bash(findstr:*)",
|
||||
"Bash(find:*)"
|
||||
"Bash(find:*)",
|
||||
"Bash(yarn tsc:*)",
|
||||
"Bash(dir:*)",
|
||||
"mcp__server-notify__notify"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
37
.dockerignore
Normal file
37
.dockerignore
Normal file
@@ -0,0 +1,37 @@
|
||||
.git/
|
||||
.gitignore
|
||||
.github/
|
||||
.idea/
|
||||
.vscode/
|
||||
.venv/
|
||||
.env
|
||||
.env.*
|
||||
.run/
|
||||
.claude/
|
||||
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
|
||||
tests/
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
coverage.xml
|
||||
|
||||
*.db
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
*.log
|
||||
logs/
|
||||
data/
|
||||
|
||||
Dockerfile
|
||||
.dockerignore
|
||||
|
||||
# Cython 编译产物
|
||||
*.c
|
||||
build/
|
||||
|
||||
# 许可证私钥和工具脚本
|
||||
license_private.pem
|
||||
scripts/
|
||||
31
.gitea/workflows/test.yml
Normal file
31
.gitea/workflows/test.yml
Normal file
@@ -0,0 +1,31 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
image: ghcr.io/catthehacker/ubuntu:act-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.13"
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest tests/ -v --tb=short
|
||||
14
.github/copilot-instructions.md
vendored
14
.github/copilot-instructions.md
vendored
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
||||
```python
|
||||
import asyncio
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
class CharacterBase(SQLModelBase):
|
||||
name: str
|
||||
"""角色名称"""
|
||||
|
||||
class Character(CharacterBase, UUIDTableBase):
|
||||
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||
"""充血模型:包含数据和业务逻辑"""
|
||||
|
||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import (
|
||||
from sqlmodel_ext import (
|
||||
SQLModelBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
||||
# 2. 定义抽象父类(有表)
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='__polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
||||
# 3. 本地应用导入(从项目根目录的包开始)
|
||||
from dependencies import SessionDep
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.base import UUIDTableBase
|
||||
from sqlmodel_ext import UUIDTableBaseMixin
|
||||
|
||||
# 4. 相对导入(同包内的模块)
|
||||
from .base import BaseClass
|
||||
|
||||
29
.github/workflows/test.yml
vendored
Normal file
29
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
name: Test
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.13"
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync
|
||||
|
||||
- name: Run tests
|
||||
run: uv run pytest tests/ -v --tb=short
|
||||
15
.gitignore
vendored
15
.gitignore
vendored
@@ -1,8 +1,6 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.so
|
||||
*.egg
|
||||
*.egg-info/
|
||||
@@ -69,3 +67,16 @@ data/
|
||||
# JB 的运行配置(换设备用不了)
|
||||
.run/
|
||||
.xml
|
||||
|
||||
# 前端构建产物(Docker 构建时复制)
|
||||
statics/
|
||||
|
||||
# Cython 编译产物
|
||||
*.c
|
||||
|
||||
# 许可证密钥(保密)
|
||||
license_private.pem
|
||||
license.key
|
||||
|
||||
avatar/
|
||||
.dev/
|
||||
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "ee"]
|
||||
path = ee
|
||||
url = https://git.yxqi.cn/Yuerchu/disknext-ee.git
|
||||
14
AGENTS.md
14
AGENTS.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
||||
```python
|
||||
import asyncio
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
class CharacterBase(SQLModelBase):
|
||||
name: str
|
||||
"""角色名称"""
|
||||
|
||||
class Character(CharacterBase, UUIDTableBase):
|
||||
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||
"""充血模型:包含数据和业务逻辑"""
|
||||
|
||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import (
|
||||
from sqlmodel_ext import (
|
||||
SQLModelBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
||||
# 2. 定义抽象父类(有表)
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='__polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
||||
# 3. 本地应用导入(从项目根目录的包开始)
|
||||
from dependencies import SessionDep
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.base import UUIDTableBase
|
||||
from sqlmodel_ext import UUIDTableBaseMixin
|
||||
|
||||
# 4. 相对导入(同包内的模块)
|
||||
from .base import BaseClass
|
||||
|
||||
14
CLAUDE.md
14
CLAUDE.md
@@ -449,13 +449,13 @@ return device # 此时device已过期
|
||||
```python
|
||||
import asyncio
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import UUIDTableBase, SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin
|
||||
|
||||
class CharacterBase(SQLModelBase):
|
||||
name: str
|
||||
"""角色名称"""
|
||||
|
||||
class Character(CharacterBase, UUIDTableBase):
|
||||
class Character(CharacterBase, UUIDTableBaseMixin):
|
||||
"""充血模型:包含数据和业务逻辑"""
|
||||
|
||||
# ==================== 运行时属性(在model_post_init初始化) ====================
|
||||
@@ -570,11 +570,11 @@ async with character.init(session):
|
||||
from abc import ABC, abstractmethod
|
||||
from uuid import UUID
|
||||
from sqlmodel import Field
|
||||
from sqlmodels.base import (
|
||||
from sqlmodel_ext import (
|
||||
SQLModelBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
|
||||
# 2. 定义抽象父类(有表)
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBase,
|
||||
UUIDTableBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='__polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
|
||||
# 3. 本地应用导入(从项目根目录的包开始)
|
||||
from dependencies import SessionDep
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.base import UUIDTableBase
|
||||
from sqlmodel_ext import UUIDTableBaseMixin
|
||||
|
||||
# 4. 相对导入(同包内的模块)
|
||||
from .base import BaseClass
|
||||
|
||||
45
Dockerfile
45
Dockerfile
@@ -1,13 +1,52 @@
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
|
||||
# ============================================================
|
||||
# 基础层:安装运行时依赖
|
||||
# ============================================================
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS base
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
COPY . .
|
||||
|
||||
EXPOSE 5213
|
||||
# ============================================================
|
||||
# Community 版本:删除 ee/ 目录
|
||||
# ============================================================
|
||||
FROM base AS community
|
||||
|
||||
RUN rm -rf ee/
|
||||
COPY statics/ /app/statics/
|
||||
|
||||
EXPOSE 5213
|
||||
CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
|
||||
|
||||
# ============================================================
|
||||
# Pro 编译层:Cython 编译 ee/ 模块
|
||||
# ============================================================
|
||||
FROM base AS pro-builder
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends gcc libc6-dev && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN uv sync --frozen --no-dev --extra build
|
||||
|
||||
RUN uv run python setup_cython.py build_ext --inplace && \
|
||||
uv run python setup_cython.py clean_artifacts
|
||||
|
||||
# ============================================================
|
||||
# Pro 版本:包含编译后的 ee/ 模块(仅 __init__.py + .so)
|
||||
# ============================================================
|
||||
FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS pro
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY pyproject.toml uv.lock ./
|
||||
RUN uv sync --frozen --no-dev
|
||||
|
||||
COPY --from=pro-builder /app/ /app/
|
||||
COPY statics/ /app/statics/
|
||||
|
||||
EXPOSE 5213
|
||||
CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
|
||||
@@ -229,6 +229,12 @@ pytest tests/integration
|
||||
pytest --cov
|
||||
```
|
||||
|
||||
## 忘记密码
|
||||
|
||||
将密码字段设置为 `$argon2id$v=19$m=65536,t=3,p=4$09YTQpkw7eS4qW732OazkQ$Szzbi3VIaJXBJ02rkVKrSFCAKHjRTl+EQWk4PNxCYFI`
|
||||
|
||||
密码即可重设为 `11223344`
|
||||
|
||||
## 开发规范
|
||||
|
||||
详细的开发规范请参阅 [CLAUDE.md](CLAUDE.md),主要包括:
|
||||
|
||||
92
docs/CLA.md
Normal file
92
docs/CLA.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# DiskNext Contributor License Agreement
|
||||
|
||||
Thank you for your interest in contributing to the DiskNext project ("We", "Us", or "Our"). This Contributor License Agreement ("Agreement") is for our mutual protection. It clarifies the intellectual property rights You grant to Us for Your Contributions.
|
||||
|
||||
By signing this Agreement, You accept its terms and conditions.
|
||||
|
||||
## 1. The Purpose of This Agreement
|
||||
|
||||
The DiskNext project is developed with a dual-licensing strategy. We maintain a free, open-source community edition alongside a commercial Pro edition. This model allows Us to support a vibrant community while also funding the project's sustainable development.
|
||||
|
||||
To make this model work, We require broad rights to use the code You contribute. This Agreement ensures that We can include Your Contributions in all editions of DiskNext under their respective licenses. By signing this Agreement, You grant Us the rights needed to manage the project effectively, including the right to incorporate Your Contribution into Our commercial products and to transfer the project to another entity.
|
||||
|
||||
## 2. Definitions
|
||||
|
||||
**"You"** means the individual copyright owner who Submits a Contribution to Us.
|
||||
|
||||
**"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that you intentionally Submit to Us for inclusion in the Material.
|
||||
|
||||
**"Material"** means the software and documentation We make available to third parties. Your Contribution may be included in the Material.
|
||||
|
||||
**"Submit"** means any form of communication sent to Us (e.g., via a pull request, issue tracker, or email) that is managed by Us for the purpose of discussing and improving the Material, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
|
||||
|
||||
**"Copyright"** means all rights protecting works of authorship, including copyright, moral rights, and neighboring rights, for the full term of their existence.
|
||||
|
||||
## 3. Copyright License Grant
|
||||
|
||||
Subject to the terms and conditions of this Agreement, You hereby grant to Us a worldwide, royalty-free, **non-exclusive**, perpetual, and irrevocable license under the Copyright covering your Contribution. This license includes the right to sublicense and to assign Your Contribution.
|
||||
|
||||
This license allows Us to use, reproduce, prepare derivative works of, publicly display, publicly perform, distribute, and publish your Contribution and such derivative works in any form. This includes, without limitation, the right to sell and distribute the Contribution as part of a commercial product under a proprietary license.
|
||||
|
||||
You retain full ownership of the Copyright in Your Contribution. Nothing in this Agreement shall be construed to restrict or transfer Your rights to use Your own Contribution for any purpose.
|
||||
|
||||
## 4. Patent License Grant
|
||||
|
||||
You hereby grant to Us and to recipients of the Material a worldwide, royalty-free, non-exclusive, perpetual, and irrevocable patent license to make, have made, use, sell, offer for sale, import, and otherwise transfer Your Contribution. This license applies to all patents owned or controlled by You, now or in the future, that would be infringed by Your Contribution alone or in combination with the Material.
|
||||
|
||||
## 5. Your Representations
|
||||
|
||||
You represent and warrant that:
|
||||
|
||||
1. The Contribution is Your original work.
|
||||
2. You are legally entitled to grant the licenses in this Agreement.
|
||||
3. If Your employer has rights to intellectual property that You create, You have either (i) received permission from Your employer to make the Contribution on behalf of that employer, or (ii) Your employer has waived such rights for the Contribution.
|
||||
4. To the best of Your knowledge, the Contribution does not violate any third-party rights, including copyright, patent, trademark, or trade secret.
|
||||
|
||||
You agree to notify Us of any facts or circumstances of which you become aware that would make these representations inaccurate in any respect.
|
||||
|
||||
## 6. Our Licensing Rights
|
||||
|
||||
You acknowledge that We may license the Material, including Your Contribution, under different license terms. We intend to distribute a community edition of DiskNext under a free and open-source license. We also reserve the right to distribute a Pro edition and other commercial versions of the Material, including Your Contribution, under a proprietary license at Our sole discretion.
|
||||
|
||||
## 7. Disclaimer of Warranty
|
||||
|
||||
THE CONTRIBUTION IS PROVIDED "AS IS" AND WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
|
||||
## 8. Limitation of Liability
|
||||
|
||||
TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL YOU OR WE BE LIABLE FOR ANY LOSS OF PROFITS, LOSS OF ANTICIPATED SAVINGS, LOSS OF DATA, OR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THIS AGREEMENT, REGARDLESS OF THE LEGAL THEORY UPON WHICH THE CLAIM IS BASED.
|
||||
|
||||
## 9. Term
|
||||
|
||||
This Agreement is effective on the date You accept it and shall continue for the full term of the copyrights and patents licensed herein. This Agreement is irrevocable.
|
||||
|
||||
## 10. Miscellaneous
|
||||
|
||||
**10.1 Governing Law:** This Agreement shall be governed by the laws of the People's Republic of China, excluding its conflict of law provisions.
|
||||
|
||||
**10.2 Entire Agreement:** This Agreement sets out the entire agreement between You and Us for Your Contributions and supersedes all prior communications and understandings.
|
||||
|
||||
**10.3 Assignment:** We may assign Our rights and obligations under this Agreement at Our sole discretion. This Agreement will be binding upon and will inure to the benefit of the parties, their successors, and permitted assigns.
|
||||
|
||||
**10.4 Severability:** If any provision of this Agreement is found to be void or unenforceable, it will be replaced with a provision that comes closest to the meaning of the original and is enforceable.
|
||||
|
||||
---
|
||||
|
||||
## To Accept This Agreement
|
||||
|
||||
Please provide the following information to signify your acceptance.
|
||||
|
||||
### Contributor ("You"):
|
||||
|
||||
- **Date:**
|
||||
- **Full Name:**
|
||||
- **Address:**
|
||||
- **Email:**
|
||||
- **GitHub Username (if applicable):**
|
||||
|
||||
### For DiskNext ("Us"):
|
||||
|
||||
- **Date:**
|
||||
- **[NAME]**
|
||||
- **Owner of DiskNext Org**
|
||||
594
docs/file-viewer-api.md
Normal file
594
docs/file-viewer-api.md
Normal file
@@ -0,0 +1,594 @@
|
||||
# 文件预览应用选择器 — 前端适配文档
|
||||
|
||||
## 概述
|
||||
|
||||
文件预览系统类似 Android 的"使用什么应用打开"机制:用户点击文件时,前端根据扩展名查询可用查看器列表,展示选择弹窗,用户可选"仅此一次"或"始终使用"。
|
||||
|
||||
### 应用类型
|
||||
|
||||
| type | 说明 | 前端处理方式 |
|
||||
|------|------|-------------|
|
||||
| `builtin` | 前端内置组件 | 根据 `app_key` 路由到内置组件(如 `pdfjs`、`monaco`) |
|
||||
| `iframe` | iframe 内嵌 | 将 `iframe_url_template` 中的 `{file_url}` 替换为文件下载 URL,嵌入 iframe |
|
||||
| `wopi` | WOPI 协议 | 调用 `/file/{id}/wopi-session` 获取 `editor_url`,嵌入 iframe |
|
||||
|
||||
### 内置 app_key 映射
|
||||
|
||||
前端需要为以下 `app_key` 实现对应的内置预览组件:
|
||||
|
||||
| app_key | 组件 | 说明 |
|
||||
|---------|------|------|
|
||||
| `pdfjs` | PDF.js 阅读器 | pdf |
|
||||
| `monaco` | Monaco Editor | txt, md, json, py, js, ts, html, css, ... |
|
||||
| `markdown` | Markdown 渲染器 | md, markdown, mdx |
|
||||
| `image_viewer` | 图片查看器 | jpg, png, gif, webp, svg, ... |
|
||||
| `video_player` | HTML5 Video | mp4, webm, ogg, mov, mkv, m3u8 |
|
||||
| `audio_player` | HTML5 Audio | mp3, wav, flac, aac, m4a, opus |
|
||||
|
||||
> `office_viewer`(iframe)、`collabora`(wopi)、`onlyoffice`(wopi)默认禁用,需管理员在后台启用和配置。
|
||||
|
||||
---
|
||||
|
||||
## 文件下载 URL 与 iframe 预览
|
||||
|
||||
### 现有下载流程(两步式)
|
||||
|
||||
```
|
||||
步骤1: POST /api/v1/file/download/{file_id} → { access_token, expires_in }
|
||||
步骤2: GET /api/v1/file/download/{access_token} → 文件二进制流
|
||||
```
|
||||
|
||||
- 步骤 1 需要 JWT 认证,返回一个下载令牌(有效期 1 小时)
|
||||
- 步骤 2 **不需要认证**,用令牌直接下载,**令牌为一次性**,下载后失效
|
||||
|
||||
### 各类型查看器获取文件内容的方式
|
||||
|
||||
| type | 获取文件方式 | 说明 |
|
||||
|------|-------------|------|
|
||||
| `builtin` | 前端自行获取 | 前端用 JS 调用下载接口拿到 Blob/ArrayBuffer,传给内置组件渲染 |
|
||||
| `iframe` | 需要公开可访问的 URL | 第三方服务(如 Office Online)会**从服务端拉取文件** |
|
||||
| `wopi` | WOPI 协议自动处理 | 编辑器通过 `/wopi/files/{id}/contents` 获取,前端只需嵌入 `editor_url` |
|
||||
|
||||
### builtin 类型 — 前端自行获取
|
||||
|
||||
内置组件(pdfjs、monaco 等)运行在前端,直接用 JS 获取文件内容即可:
|
||||
|
||||
```typescript
|
||||
// 方式 A:用下载令牌拼 URL(适用于 PDF.js 等需要 URL 的组件)
|
||||
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||
const fileUrl = `${baseUrl}/api/v1/file/download/${access_token}`
|
||||
// 传给 PDF.js: pdfjsLib.getDocument(fileUrl)
|
||||
|
||||
// 方式 B:用 fetch + Authorization 头获取 Blob(适用于需要 ArrayBuffer 的组件)
|
||||
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||
const blob = await fetch(`${baseUrl}/api/v1/file/download/${access_token}`).then(r => r.blob())
|
||||
// 传给 Monaco: monaco.editor.create(el, { value: await blob.text() })
|
||||
```
|
||||
|
||||
### iframe 类型 — `{file_url}` 替换规则
|
||||
|
||||
`iframe_url_template` 中的 `{file_url}` 需要替换为一个**外部可访问的文件直链**。
|
||||
|
||||
**问题**:当前下载令牌是一次性的,而 Office Online 等服务可能多次请求该 URL。
|
||||
|
||||
**当前可行方案**:
|
||||
|
||||
```typescript
|
||||
// 1. 创建下载令牌
|
||||
const { access_token } = await api.post(`/file/download/${fileId}`)
|
||||
|
||||
// 2. 拼出完整的文件 URL(必须是公网可达的地址)
|
||||
const fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||
|
||||
// 3. 替换模板
|
||||
const iframeSrc = viewer.iframe_url_template.replace(
|
||||
'{file_url}',
|
||||
encodeURIComponent(fileUrl)
|
||||
)
|
||||
|
||||
// 4. 嵌入 iframe
|
||||
// <iframe src={iframeSrc} />
|
||||
```
|
||||
|
||||
> **已知限制**:下载令牌为一次性使用。如果第三方服务多次拉取文件(如 Office Online 可能重试),
|
||||
> 第二次请求会 404。后续版本将实现 `/file/get/{id}/{name}` 外链端点(多次可用),届时
|
||||
> iframe 应改用外链 URL。目前建议:
|
||||
>
|
||||
> 1. **优先使用 WOPI 类型**(Collabora/OnlyOffice),不存在此限制
|
||||
> 2. Office Online 预览在**文件较小**时通常只拉取一次,大多数场景可用
|
||||
> 3. 如需稳定方案,可等待外链端点实现后再启用 iframe 类型应用
|
||||
|
||||
### wopi 类型 — 无需关心文件 URL
|
||||
|
||||
WOPI 类型的查看器完全由后端处理文件传输,前端只需:
|
||||
|
||||
```typescript
|
||||
// 1. 创建 WOPI 会话
|
||||
const session = await api.post(`/file/${fileId}/wopi-session`)
|
||||
|
||||
// 2. 直接嵌入编辑器
|
||||
// <iframe src={session.editor_url} />
|
||||
```
|
||||
|
||||
编辑器(Collabora/OnlyOffice)会通过 WOPI 协议自动从 `/wopi/files/{id}/contents` 获取文件内容,使用 `access_token` 认证,前端无需干预。
|
||||
|
||||
---
|
||||
|
||||
## 用户端 API
|
||||
|
||||
### 1. 查询可用查看器
|
||||
|
||||
用户点击文件时调用,获取该扩展名的可用查看器列表。
|
||||
|
||||
```
|
||||
GET /api/v1/file/viewers?ext={extension}
|
||||
Authorization: Bearer {token}
|
||||
```
|
||||
|
||||
**Query 参数**
|
||||
|
||||
| 参数 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| ext | string | 是 | 文件扩展名,最长 20 字符。支持带点号(`.pdf`)、大写(`PDF`),后端会自动规范化 |
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
{
|
||||
"viewers": [
|
||||
{
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": "基于 pdf.js 的 PDF 在线阅读器",
|
||||
"iframe_url_template": null,
|
||||
"wopi_editor_url_template": null
|
||||
}
|
||||
],
|
||||
"default_viewer_id": null
|
||||
}
|
||||
```
|
||||
|
||||
**字段说明**
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| viewers | FileAppSummary[] | 可用查看器列表,已按优先级排序 |
|
||||
| default_viewer_id | string \| null | 用户设置的"始终使用"查看器 UUID,未设置则为 null |
|
||||
|
||||
**FileAppSummary**
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| id | UUID | 应用 UUID |
|
||||
| name | string | 应用显示名称 |
|
||||
| app_key | string | 应用唯一标识,前端路由用 |
|
||||
| type | `"builtin"` \| `"iframe"` \| `"wopi"` | 应用类型 |
|
||||
| icon | string \| null | 图标名称(可映射到 icon library) |
|
||||
| description | string \| null | 应用描述 |
|
||||
| iframe_url_template | string \| null | iframe 类型专用,URL 模板含 `{file_url}` 占位符 |
|
||||
| wopi_editor_url_template | string \| null | wopi 类型专用,编辑器 URL 模板 |
|
||||
|
||||
---
|
||||
|
||||
### 2. 设置默认查看器("始终使用")
|
||||
|
||||
用户在选择弹窗中勾选"始终使用此应用"时调用。
|
||||
|
||||
```
|
||||
PUT /api/v1/user/settings/file-viewers/default
|
||||
Authorization: Bearer {token}
|
||||
Content-Type: application/json
|
||||
```
|
||||
|
||||
**请求体**
|
||||
|
||||
```json
|
||||
{
|
||||
"extension": "pdf",
|
||||
"app_id": "550e8400-e29b-41d4-a716-446655440000"
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 必填 | 说明 |
|
||||
|------|------|------|------|
|
||||
| extension | string | 是 | 文件扩展名(小写,无点号) |
|
||||
| app_id | UUID | 是 | 选择的查看器应用 UUID |
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "660e8400-e29b-41d4-a716-446655440001",
|
||||
"extension": "pdf",
|
||||
"app": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": "基于 pdf.js 的 PDF 在线阅读器",
|
||||
"iframe_url_template": null,
|
||||
"wopi_editor_url_template": null
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**错误码**
|
||||
|
||||
| 状态码 | 说明 |
|
||||
|--------|------|
|
||||
| 400 | 该应用不支持此扩展名 |
|
||||
| 404 | 应用不存在 |
|
||||
|
||||
> 同一扩展名只允许一个默认值。重复 PUT 同一 extension 会更新(upsert),不会冲突。
|
||||
|
||||
---
|
||||
|
||||
### 3. 列出所有默认查看器设置
|
||||
|
||||
用于用户设置页展示"已设为始终使用"的列表。
|
||||
|
||||
```
|
||||
GET /api/v1/user/settings/file-viewers/defaults
|
||||
Authorization: Bearer {token}
|
||||
```
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"id": "660e8400-e29b-41d4-a716-446655440001",
|
||||
"extension": "pdf",
|
||||
"app": {
|
||||
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": null,
|
||||
"iframe_url_template": null,
|
||||
"wopi_editor_url_template": null
|
||||
}
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 4. 撤销默认查看器设置
|
||||
|
||||
用户在设置页点击"取消始终使用"时调用。
|
||||
|
||||
```
|
||||
DELETE /api/v1/user/settings/file-viewers/default/{id}
|
||||
Authorization: Bearer {token}
|
||||
```
|
||||
|
||||
**响应** 204 No Content
|
||||
|
||||
**错误码**
|
||||
|
||||
| 状态码 | 说明 |
|
||||
|--------|------|
|
||||
| 404 | 记录不存在或不属于当前用户 |
|
||||
|
||||
---
|
||||
|
||||
### 5. 创建 WOPI 会话
|
||||
|
||||
打开 WOPI 类型应用(如 Collabora、OnlyOffice)时调用。
|
||||
|
||||
```
|
||||
POST /api/v1/file/{file_id}/wopi-session
|
||||
Authorization: Bearer {token}
|
||||
```
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
{
|
||||
"wopi_src": "http://localhost:8000/wopi/files/770e8400-e29b-41d4-a716-446655440002",
|
||||
"access_token": "eyJhbGciOiJIUzI1NiIs...",
|
||||
"access_token_ttl": 1739577600000,
|
||||
"editor_url": "http://collabora:9980/loleaflet/dist/loleaflet.html?WOPISrc=...&access_token=...&access_token_ttl=..."
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| wopi_src | string | WOPI 源 URL(传给编辑器) |
|
||||
| access_token | string | WOPI 访问令牌 |
|
||||
| access_token_ttl | int | 令牌过期毫秒时间戳 |
|
||||
| editor_url | string | 完整的编辑器 URL,**直接嵌入 iframe 即可** |
|
||||
|
||||
**错误码**
|
||||
|
||||
| 状态码 | 说明 |
|
||||
|--------|------|
|
||||
| 400 | 文件无扩展名 / WOPI 应用未配置编辑器 URL |
|
||||
| 403 | 用户组无权限 |
|
||||
| 404 | 文件不存在 / 无可用 WOPI 查看器 |
|
||||
|
||||
---
|
||||
|
||||
## 前端交互流程
|
||||
|
||||
### 打开文件预览
|
||||
|
||||
```
|
||||
用户点击文件
|
||||
│
|
||||
▼
|
||||
GET /file/viewers?ext={扩展名}
|
||||
│
|
||||
├── viewers 为空 → 提示"暂无可用的预览方式"
|
||||
│
|
||||
├── default_viewer_id 不为空 → 直接用对应 viewer 打开(跳过选择弹窗)
|
||||
│
|
||||
└── viewers.length == 1 → 直接用唯一 viewer 打开(可选策略)
|
||||
│
|
||||
└── viewers.length > 1 → 展示选择弹窗
|
||||
│
|
||||
├── 用户选择 + 不勾选"始终使用" → 仅此一次打开
|
||||
│
|
||||
└── 用户选择 + 勾选"始终使用" → PUT /user/settings/file-viewers/default
|
||||
│
|
||||
└── 然后打开
|
||||
```
|
||||
|
||||
### 根据 type 打开查看器
|
||||
|
||||
```
|
||||
获取到 viewer 对象
|
||||
│
|
||||
├── type == "builtin"
|
||||
│ └── 根据 app_key 路由到内置组件
|
||||
│ switch(app_key):
|
||||
│ "pdfjs" → <PdfViewer />
|
||||
│ "monaco" → <CodeEditor />
|
||||
│ "markdown" → <MarkdownPreview />
|
||||
│ "image_viewer" → <ImageViewer />
|
||||
│ "video_player" → <VideoPlayer />
|
||||
│ "audio_player" → <AudioPlayer />
|
||||
│
|
||||
│ 获取文件内容:
|
||||
│ POST /file/download/{file_id} → { access_token }
|
||||
│ fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||
│ → 传 URL 或 fetch Blob 给内置组件
|
||||
│
|
||||
├── type == "iframe"
|
||||
│ └── 1. POST /file/download/{file_id} → { access_token }
|
||||
│ 2. fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
|
||||
│ 3. iframeSrc = viewer.iframe_url_template
|
||||
│ .replace("{file_url}", encodeURIComponent(fileUrl))
|
||||
│ 4. <iframe src={iframeSrc} />
|
||||
│
|
||||
└── type == "wopi"
|
||||
└── 1. POST /file/{file_id}/wopi-session → { editor_url }
|
||||
2. <iframe src={editor_url} />
|
||||
(编辑器自动通过 WOPI 协议获取文件,前端无需处理)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 管理员 API
|
||||
|
||||
所有管理端点需要管理员身份(JWT 中 group.admin == true)。
|
||||
|
||||
### 1. 列出所有文件应用
|
||||
|
||||
```
|
||||
GET /api/v1/admin/file-app/list?page=1&page_size=20
|
||||
Authorization: Bearer {admin_token}
|
||||
```
|
||||
|
||||
**响应 200**
|
||||
|
||||
```json
|
||||
{
|
||||
"apps": [
|
||||
{
|
||||
"id": "...",
|
||||
"name": "PDF 阅读器",
|
||||
"app_key": "pdfjs",
|
||||
"type": "builtin",
|
||||
"icon": "file-pdf",
|
||||
"description": "...",
|
||||
"is_enabled": true,
|
||||
"is_restricted": false,
|
||||
"iframe_url_template": null,
|
||||
"wopi_discovery_url": null,
|
||||
"wopi_editor_url_template": null,
|
||||
"extensions": ["pdf"],
|
||||
"allowed_group_ids": []
|
||||
}
|
||||
],
|
||||
"total": 9
|
||||
}
|
||||
```
|
||||
|
||||
### 2. 创建文件应用
|
||||
|
||||
```
|
||||
POST /api/v1/admin/file-app/
|
||||
Authorization: Bearer {admin_token}
|
||||
Content-Type: application/json
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "自定义查看器",
|
||||
"app_key": "my_viewer",
|
||||
"type": "iframe",
|
||||
"description": "自定义 iframe 查看器",
|
||||
"is_enabled": true,
|
||||
"is_restricted": false,
|
||||
"iframe_url_template": "https://example.com/view?url={file_url}",
|
||||
"extensions": ["pdf", "docx"],
|
||||
"allowed_group_ids": []
|
||||
}
|
||||
```
|
||||
|
||||
**响应** 201 — 返回 FileAppResponse(同列表中的单项)
|
||||
|
||||
**错误码**: 409 — app_key 已存在
|
||||
|
||||
### 3. 获取应用详情
|
||||
|
||||
```
|
||||
GET /api/v1/admin/file-app/{id}
|
||||
```
|
||||
|
||||
**响应** 200 — FileAppResponse
|
||||
|
||||
### 4. 更新应用
|
||||
|
||||
```
|
||||
PATCH /api/v1/admin/file-app/{id}
|
||||
```
|
||||
|
||||
只传需要更新的字段:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "新名称",
|
||||
"is_enabled": false
|
||||
}
|
||||
```
|
||||
|
||||
**响应** 200 — FileAppResponse
|
||||
|
||||
### 5. 删除应用
|
||||
|
||||
```
|
||||
DELETE /api/v1/admin/file-app/{id}
|
||||
```
|
||||
|
||||
**响应** 204 No Content(级联删除扩展名关联、用户偏好、用户组关联)
|
||||
|
||||
### 6. 全量替换扩展名列表
|
||||
|
||||
```
|
||||
PUT /api/v1/admin/file-app/{id}/extensions
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"extensions": ["doc", "docx", "odt"]
|
||||
}
|
||||
```
|
||||
|
||||
**响应** 200 — FileAppResponse
|
||||
|
||||
### 7. 全量替换允许的用户组
|
||||
|
||||
```
|
||||
PUT /api/v1/admin/file-app/{id}/groups
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"group_ids": ["uuid-1", "uuid-2"]
|
||||
}
|
||||
```
|
||||
|
||||
**响应** 200 — FileAppResponse
|
||||
|
||||
> `is_restricted` 为 `true` 时,只有 `allowed_group_ids` 中的用户组成员能看到此应用。`is_restricted` 为 `false` 时所有用户可见,`allowed_group_ids` 不生效。
|
||||
|
||||
---
|
||||
|
||||
## TypeScript 类型参考
|
||||
|
||||
```typescript
|
||||
type FileAppType = 'builtin' | 'iframe' | 'wopi'
|
||||
|
||||
interface FileAppSummary {
|
||||
id: string
|
||||
name: string
|
||||
app_key: string
|
||||
type: FileAppType
|
||||
icon: string | null
|
||||
description: string | null
|
||||
iframe_url_template: string | null
|
||||
wopi_editor_url_template: string | null
|
||||
}
|
||||
|
||||
interface FileViewersResponse {
|
||||
viewers: FileAppSummary[]
|
||||
default_viewer_id: string | null
|
||||
}
|
||||
|
||||
interface SetDefaultViewerRequest {
|
||||
extension: string
|
||||
app_id: string
|
||||
}
|
||||
|
||||
interface UserFileAppDefaultResponse {
|
||||
id: string
|
||||
extension: string
|
||||
app: FileAppSummary
|
||||
}
|
||||
|
||||
interface WopiSessionResponse {
|
||||
wopi_src: string
|
||||
access_token: string
|
||||
access_token_ttl: number
|
||||
editor_url: string
|
||||
}
|
||||
|
||||
// ========== 管理员类型 ==========
|
||||
|
||||
interface FileAppResponse {
|
||||
id: string
|
||||
name: string
|
||||
app_key: string
|
||||
type: FileAppType
|
||||
icon: string | null
|
||||
description: string | null
|
||||
is_enabled: boolean
|
||||
is_restricted: boolean
|
||||
iframe_url_template: string | null
|
||||
wopi_discovery_url: string | null
|
||||
wopi_editor_url_template: string | null
|
||||
extensions: string[]
|
||||
allowed_group_ids: string[]
|
||||
}
|
||||
|
||||
interface FileAppListResponse {
|
||||
apps: FileAppResponse[]
|
||||
total: number
|
||||
}
|
||||
|
||||
interface FileAppCreateRequest {
|
||||
name: string
|
||||
app_key: string
|
||||
type: FileAppType
|
||||
icon?: string
|
||||
description?: string
|
||||
is_enabled?: boolean // default: true
|
||||
is_restricted?: boolean // default: false
|
||||
iframe_url_template?: string
|
||||
wopi_discovery_url?: string
|
||||
wopi_editor_url_template?: string
|
||||
extensions?: string[] // default: []
|
||||
allowed_group_ids?: string[] // default: []
|
||||
}
|
||||
|
||||
interface FileAppUpdateRequest {
|
||||
name?: string
|
||||
app_key?: string
|
||||
type?: FileAppType
|
||||
icon?: string
|
||||
description?: string
|
||||
is_enabled?: boolean
|
||||
is_restricted?: boolean
|
||||
iframe_url_template?: string
|
||||
wopi_discovery_url?: string
|
||||
wopi_editor_url_template?: string
|
||||
}
|
||||
```
|
||||
242
docs/text-editor-api.md
Normal file
242
docs/text-editor-api.md
Normal file
@@ -0,0 +1,242 @@
|
||||
# 文本文件在线编辑 — 前端适配文档
|
||||
|
||||
## 概述
|
||||
|
||||
Monaco Editor 打开文本文件时,通过 GET 获取内容和哈希作为编辑基线;保存时用 jsdiff 计算 unified diff,仅发送差异部分,后端验证无并发冲突后应用 patch。
|
||||
|
||||
```
|
||||
打开文件: GET /api/v1/file/content/{file_id} → { content, hash, size }
|
||||
保存文件: PATCH /api/v1/file/content/{file_id} ← { patch, base_hash }
|
||||
→ { new_hash, new_size }
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 约定
|
||||
|
||||
| 项目 | 约定 |
|
||||
|------|------|
|
||||
| 编码 | 全程 UTF-8 |
|
||||
| 换行符 | 后端 GET 时统一规范化为 `\n`,前端无需处理 `\r\n` |
|
||||
| hash 算法 | SHA-256,hex 编码(64 字符),基于 UTF-8 bytes 计算 |
|
||||
| diff 格式 | jsdiff `createPatch()` 输出的标准 unified diff |
|
||||
| 空 diff | 前端自行判断,内容未变时不发请求 |
|
||||
|
||||
---
|
||||
|
||||
## GET /api/v1/file/content/{file_id}
|
||||
|
||||
获取文本文件内容。
|
||||
|
||||
### 请求
|
||||
|
||||
```
|
||||
GET /api/v1/file/content/{file_id}
|
||||
Authorization: Bearer <token>
|
||||
```
|
||||
|
||||
### 响应 200
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "line1\nline2\nline3\n",
|
||||
"hash": "a1b2c3d4...(64字符 SHA-256 hex)",
|
||||
"size": 18
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| `content` | string | 文件文本内容,换行符已规范化为 `\n` |
|
||||
| `hash` | string | 基于规范化内容 UTF-8 bytes 的 SHA-256 hex |
|
||||
| `size` | number | 规范化后的字节大小 |
|
||||
|
||||
### 错误
|
||||
|
||||
| 状态码 | 说明 |
|
||||
|--------|------|
|
||||
| 400 | 文件不是有效的 UTF-8 文本(二进制文件) |
|
||||
| 401 | 未认证 |
|
||||
| 404 | 文件不存在 |
|
||||
|
||||
---
|
||||
|
||||
## PATCH /api/v1/file/content/{file_id}
|
||||
|
||||
增量保存文本文件。
|
||||
|
||||
### 请求
|
||||
|
||||
```
|
||||
PATCH /api/v1/file/content/{file_id}
|
||||
Authorization: Bearer <token>
|
||||
Content-Type: application/json
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"patch": "--- a\n+++ b\n@@ -1,3 +1,3 @@\n line1\n-line2\n+LINE2\n line3\n",
|
||||
"base_hash": "a1b2c3d4...(GET 返回的 hash)"
|
||||
}
|
||||
```
|
||||
|
||||
| 字段 | 类型 | 说明 |
|
||||
|------|------|------|
|
||||
| `patch` | string | jsdiff `createPatch()` 生成的 unified diff |
|
||||
| `base_hash` | string | 编辑前 GET 返回的 `hash` 值 |
|
||||
|
||||
### 响应 200
|
||||
|
||||
```json
|
||||
{
|
||||
"new_hash": "e5f6a7b8...(64字符)",
|
||||
"new_size": 18
|
||||
}
|
||||
```
|
||||
|
||||
保存成功后,前端应将 `new_hash` 作为新的 `base_hash`,用于下次保存。
|
||||
|
||||
### 错误
|
||||
|
||||
| 状态码 | 说明 | 前端处理 |
|
||||
|--------|------|----------|
|
||||
| 401 | 未认证 | — |
|
||||
| 404 | 文件不存在 | — |
|
||||
| 409 | `base_hash` 不匹配(并发冲突) | 提示用户刷新,重新加载内容 |
|
||||
| 422 | patch 格式无效或应用失败 | 回退到全量保存或提示用户 |
|
||||
|
||||
---
|
||||
|
||||
## 前端实现参考
|
||||
|
||||
### 依赖
|
||||
|
||||
```bash
|
||||
npm install jsdiff
|
||||
```
|
||||
|
||||
### 计算 hash
|
||||
|
||||
```typescript
|
||||
async function sha256(text: string): Promise<string> {
|
||||
const bytes = new TextEncoder().encode(text);
|
||||
const hashBuffer = await crypto.subtle.digest("SHA-256", bytes);
|
||||
const hashArray = Array.from(new Uint8Array(hashBuffer));
|
||||
return hashArray.map(b => b.toString(16).padStart(2, "0")).join("");
|
||||
}
|
||||
```
|
||||
|
||||
### 打开文件
|
||||
|
||||
```typescript
|
||||
interface TextContent {
|
||||
content: string;
|
||||
hash: string;
|
||||
size: number;
|
||||
}
|
||||
|
||||
async function openFile(fileId: string): Promise<TextContent> {
|
||||
const resp = await fetch(`/api/v1/file/content/${fileId}`, {
|
||||
headers: { Authorization: `Bearer ${token}` },
|
||||
});
|
||||
|
||||
if (!resp.ok) {
|
||||
if (resp.status === 400) throw new Error("该文件不是文本文件");
|
||||
throw new Error("获取文件内容失败");
|
||||
}
|
||||
|
||||
return resp.json();
|
||||
}
|
||||
```
|
||||
|
||||
### 保存文件
|
||||
|
||||
```typescript
|
||||
import { createPatch } from "diff";
|
||||
|
||||
interface PatchResult {
|
||||
new_hash: string;
|
||||
new_size: number;
|
||||
}
|
||||
|
||||
async function saveFile(
|
||||
fileId: string,
|
||||
originalContent: string,
|
||||
currentContent: string,
|
||||
baseHash: string,
|
||||
): Promise<PatchResult | null> {
|
||||
// 内容未变,不发请求
|
||||
if (originalContent === currentContent) return null;
|
||||
|
||||
const patch = createPatch("file", originalContent, currentContent);
|
||||
|
||||
const resp = await fetch(`/api/v1/file/content/${fileId}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
Authorization: `Bearer ${token}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ patch, base_hash: baseHash }),
|
||||
});
|
||||
|
||||
if (resp.status === 409) {
|
||||
// 并发冲突,需要用户决策
|
||||
throw new Error("CONFLICT");
|
||||
}
|
||||
|
||||
if (!resp.ok) throw new Error("保存失败");
|
||||
|
||||
return resp.json();
|
||||
}
|
||||
```
|
||||
|
||||
### 完整编辑流程
|
||||
|
||||
```typescript
|
||||
// 1. 打开
|
||||
const file = await openFile(fileId);
|
||||
let baseContent = file.content;
|
||||
let baseHash = file.hash;
|
||||
|
||||
// 2. 用户在 Monaco 中编辑...
|
||||
editor.setValue(baseContent);
|
||||
|
||||
// 3. 保存(Ctrl+S)
|
||||
const currentContent = editor.getValue();
|
||||
const result = await saveFile(fileId, baseContent, currentContent, baseHash);
|
||||
|
||||
if (result) {
|
||||
// 更新基线
|
||||
baseContent = currentContent;
|
||||
baseHash = result.new_hash;
|
||||
}
|
||||
```
|
||||
|
||||
### 冲突处理建议
|
||||
|
||||
当 PATCH 返回 409 时,说明文件已被其他会话修改:
|
||||
|
||||
```typescript
|
||||
try {
|
||||
await saveFile(fileId, baseContent, currentContent, baseHash);
|
||||
} catch (e) {
|
||||
if (e.message === "CONFLICT") {
|
||||
// 方案 A:提示用户,提供"覆盖"和"放弃"选项
|
||||
// 方案 B:重新 GET 最新内容,展示 diff 让用户合并
|
||||
const latest = await openFile(fileId);
|
||||
// 展示合并 UI...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## hash 一致性验证
|
||||
|
||||
前端可以在 GET 后本地验证 hash,确保传输无误:
|
||||
|
||||
```typescript
|
||||
const file = await openFile(fileId);
|
||||
const localHash = await sha256(file.content);
|
||||
console.assert(localHash === file.hash, "hash 不一致,内容可能损坏");
|
||||
```
|
||||
1
ee
Submodule
1
ee
Submodule
Submodule ee added at cc32d8db91
14
license_public.pem
Normal file
14
license_public.pem
Normal file
@@ -0,0 +1,14 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAyNltXQ/Nuechx3kjj3T5
|
||||
oR6pZvTmpsDowqqxXJy7FXUI8d7XprhV+HrBQPsrT/Ngo9FwW3XyiK10m1WrzpGW
|
||||
eaf9990Z5Z2naEn5TzGrh71p/D7mZcNGVumo9uAuhtNEemm6xB3FoyGYZj7X0cwA
|
||||
VDvIiKAwYyRJX2LqVh1/tZM6tTO3oaGZXRMZzCNUPFSo4ZZudU3Boa5oQg08evu4
|
||||
vaOqeFrMX47R3MSUmO9hOh+NS53XNqO0f0zw5sv95CtyR5qvJ4gpkgYaRCSQFd19
|
||||
TnHU5saFVrH9jdADz1tdkMYcyYE+uJActZBapxCHSYB2tSCKWjDxeUFl/oY/ZFtY
|
||||
l4MNz1ovkjNhpmR3g+I5fbvN0cxDIjnZ9vJ84ozGqTGT9s1jHaLbpLri/vhuT4F2
|
||||
7kifXk8ImwtMZpZvzhmucH9/5VgcWKNuMATzEMif+YjFpuOGx8gc1XL1W/3q+dH0
|
||||
EFESp+/knjcVIfwpAkIKyV7XvDgFHsif1SeI0zZMW4utowVvGocP1ZzK5BGNTk2z
|
||||
CEtQDO7Rqo+UDckOJSG66VW3c2QO8o6uuy6fzx7q0MFEmUMwGf2iMVtR/KnXe99C
|
||||
enOT0BpU1EQvqssErUqivDss7jm98iD8M/TCE7pFboqZ+SC9G+QAqNIQNFWh8bWA
|
||||
R9hyXM/x5ysHd6MC4eEQnhMCAwEAAQ==
|
||||
-----END PUBLIC KEY-----
|
||||
84
main.py
84
main.py
@@ -1,25 +1,63 @@
|
||||
from pathlib import Path
|
||||
from typing import NoReturn
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger as l
|
||||
|
||||
from routers import router
|
||||
from routers.dav import dav_app
|
||||
from routers.dav.provider import EventLoopRef
|
||||
from service.redis import RedisManager
|
||||
from service.storage import S3StorageService
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.migration import migration
|
||||
from utils import JWT
|
||||
from utils.conf import appmeta
|
||||
from utils.http.http_exceptions import raise_internal_error
|
||||
from utils.lifespan import lifespan
|
||||
from models.database import init_db
|
||||
from models.migration import migration
|
||||
from utils import JWT
|
||||
from routers import router
|
||||
from service.redis import RedisManager
|
||||
from loguru import logger as l
|
||||
|
||||
# 尝试加载企业版功能
|
||||
_has_ee: bool = False
|
||||
try:
|
||||
from ee import init_ee
|
||||
from ee.license import LicenseError
|
||||
from ee.routers import ee_router
|
||||
|
||||
_has_ee = True
|
||||
|
||||
async def _init_ee() -> None:
|
||||
"""启动时验证许可证,路由由 license_valid_required 依赖保护"""
|
||||
try:
|
||||
await init_ee()
|
||||
except LicenseError as exc:
|
||||
l.critical(f"许可证验证失败: {exc}")
|
||||
raise SystemExit(1) from exc
|
||||
|
||||
lifespan.add_startup(_init_ee)
|
||||
except ImportError as exc:
|
||||
ee_router = None
|
||||
l.info(f"以 Community 版本运行 (原因: {exc})")
|
||||
|
||||
STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
|
||||
"""前端静态文件目录(由 Docker 构建时复制)"""
|
||||
|
||||
async def _init_db() -> None:
|
||||
"""初始化数据库连接引擎"""
|
||||
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
||||
|
||||
# 捕获事件循环引用(供 WSGI 线程桥接使用)
|
||||
lifespan.add_startup(EventLoopRef.capture)
|
||||
|
||||
# 添加初始化数据库启动项
|
||||
lifespan.add_startup(init_db)
|
||||
lifespan.add_startup(_init_db)
|
||||
lifespan.add_startup(migration)
|
||||
lifespan.add_startup(JWT.load_secret_key)
|
||||
lifespan.add_startup(RedisManager.connect)
|
||||
lifespan.add_startup(S3StorageService.initialize_session)
|
||||
|
||||
# 添加关闭项
|
||||
lifespan.add_shutdown(S3StorageService.close_session)
|
||||
lifespan.add_shutdown(DatabaseManager.close)
|
||||
lifespan.add_shutdown(RedisManager.disconnect)
|
||||
|
||||
# 创建应用实例并设置元数据
|
||||
@@ -59,6 +97,36 @@ async def handle_unexpected_exceptions(
|
||||
|
||||
# 挂载路由
|
||||
app.include_router(router)
|
||||
if _has_ee:
|
||||
app.include_router(ee_router, prefix="/api/v1")
|
||||
|
||||
# 挂载 WebDAV 协议端点(优先于 SPA catch-all)
|
||||
app.mount("/dav", dav_app)
|
||||
|
||||
# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
|
||||
if STATICS_DIR.is_dir():
|
||||
from starlette.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
_assets_dir: Path = STATICS_DIR / "assets"
|
||||
if _assets_dir.is_dir():
|
||||
app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")
|
||||
|
||||
@app.get("/{path:path}")
|
||||
async def spa_fallback(path: str) -> FileResponse:
|
||||
"""
|
||||
SPA fallback 路由
|
||||
|
||||
优先级:API 路由 > /assets 静态挂载 > 此 catch-all 路由。
|
||||
若请求路径对应 statics/ 下的真实文件则直接返回,否则返回 index.html。
|
||||
"""
|
||||
file_path: Path = (STATICS_DIR / path).resolve()
|
||||
# 防止路径穿越
|
||||
if file_path.is_relative_to(STATICS_DIR) and path and file_path.is_file():
|
||||
return FileResponse(file_path)
|
||||
return FileResponse(STATICS_DIR / "index.html")
|
||||
|
||||
l.info(f"前端静态文件已挂载: {STATICS_DIR}")
|
||||
|
||||
# 防止直接运行 main.py
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -4,50 +4,79 @@ from uuid import UUID
|
||||
from fastapi import Depends
|
||||
import jwt
|
||||
|
||||
from models.user import User
|
||||
from sqlmodels.user import JWTPayload, User, UserStatus
|
||||
from utils import JWT
|
||||
from .dependencies import SessionDep
|
||||
from utils import http_exceptions
|
||||
from service.redis import RedisManager
|
||||
from service.redis.user_ban_store import UserBanStore
|
||||
|
||||
async def auth_required(
|
||||
|
||||
async def jwt_required(
|
||||
session: SessionDep,
|
||||
token: Annotated[str, Depends(JWT.oauth2_scheme)],
|
||||
) -> User:
|
||||
) -> JWTPayload:
|
||||
"""
|
||||
AuthRequired 需要登录
|
||||
验证 JWT 并返回 claims。
|
||||
|
||||
封禁检查策略:
|
||||
1. JWT 内嵌 status 检查(签发时快照)
|
||||
2. Redis 黑名单检查(即时封禁,如果 Redis 可用)
|
||||
3. Redis 不可用时查库检查 status(降级方案)
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if user_id is None:
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
user_id = UUID(user_id)
|
||||
|
||||
# 从数据库获取用户信息
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
return user
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
claims = JWTPayload(
|
||||
sub=payload["sub"],
|
||||
jti=payload["jti"],
|
||||
status=payload["status"],
|
||||
group=payload["group"],
|
||||
)
|
||||
except (jwt.InvalidTokenError, KeyError, ValueError):
|
||||
http_exceptions.raise_unauthorized("凭据过期或无效")
|
||||
|
||||
# 1. JWT 内嵌 status 检查
|
||||
if claims.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 2. 即时封禁检查
|
||||
user_id_str = str(claims.sub)
|
||||
if RedisManager.is_available():
|
||||
# Redis 可用:查黑名单
|
||||
if await UserBanStore.is_banned(user_id_str):
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
else:
|
||||
# Redis 不可用:查库(仅 status 字段,不加载关系)
|
||||
user = await User.get(session, User.id == claims.sub)
|
||||
if not user or user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
return claims
|
||||
|
||||
|
||||
async def admin_required(
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> User:
|
||||
claims: Annotated[JWTPayload, Depends(jwt_required)],
|
||||
) -> JWTPayload:
|
||||
"""
|
||||
验证是否为管理员。
|
||||
验证管理员权限(仅读取 JWT claims,不查库)。
|
||||
|
||||
使用方法:
|
||||
>>> APIRouter(dependencies=[Depends(admin_required)])
|
||||
"""
|
||||
group = await user.awaitable_attrs.group
|
||||
if group.admin:
|
||||
return user
|
||||
raise http_exceptions.raise_forbidden("Admin Required")
|
||||
if not claims.group.admin:
|
||||
http_exceptions.raise_forbidden("Admin Required")
|
||||
return claims
|
||||
|
||||
|
||||
async def auth_required(
|
||||
session: SessionDep,
|
||||
claims: Annotated[JWTPayload, Depends(jwt_required)],
|
||||
) -> User:
|
||||
"""验证 JWT + 从数据库加载完整 User(含 group 关系)"""
|
||||
user = await User.get(session, User.id == claims.sub, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
return user
|
||||
|
||||
|
||||
def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
|
||||
|
||||
@@ -6,22 +6,24 @@ FastAPI 依赖注入
|
||||
- TimeFilterRequestDep: 时间筛选查询依赖(用于 count 等统计接口)
|
||||
- TableViewRequestDep: 分页排序查询依赖(包含时间筛选 + 分页排序)
|
||||
- UserFilterParamsDep: 用户筛选参数依赖(用于管理员用户列表)
|
||||
- require_captcha: 验证码校验依赖注入工厂
|
||||
"""
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Query
|
||||
from fastapi import Depends, Form, Query
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.database import get_session
|
||||
from models.mixin import TimeFilterRequest, TableViewRequest
|
||||
from models.user import UserFilterParams, UserStatus
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodel_ext import TimeFilterRequest, TableViewRequest
|
||||
from sqlmodels.user import UserFilterParams, UserStatus
|
||||
|
||||
|
||||
# --- 数据库会话依赖 ---
|
||||
|
||||
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(get_session)]
|
||||
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(DatabaseManager.get_session)]
|
||||
"""数据库会话依赖,用于路由函数中获取数据库会话"""
|
||||
|
||||
|
||||
@@ -79,14 +81,14 @@ TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends(_get_table_
|
||||
|
||||
async def _get_user_filter_params(
|
||||
group_id: Annotated[UUID | None, Query(description="按用户组UUID筛选")] = None,
|
||||
username: Annotated[str | None, Query(max_length=50, description="按用户名模糊搜索")] = None,
|
||||
email: Annotated[str | None, Query(max_length=50, description="按邮箱模糊搜索")] = None,
|
||||
nickname: Annotated[str | None, Query(max_length=50, description="按昵称模糊搜索")] = None,
|
||||
status: Annotated[UserStatus | None, Query(description="按用户状态筛选")] = None,
|
||||
) -> UserFilterParams:
|
||||
"""解析用户过滤查询参数"""
|
||||
return UserFilterParams(
|
||||
group_id=group_id,
|
||||
username_contains=username,
|
||||
email_contains=email,
|
||||
nickname_contains=nickname,
|
||||
status=status,
|
||||
)
|
||||
@@ -94,3 +96,30 @@ async def _get_user_filter_params(
|
||||
|
||||
UserFilterParamsDep: TypeAlias = Annotated[UserFilterParams, Depends(_get_user_filter_params)]
|
||||
"""获取用户筛选参数的依赖(用于管理员用户列表)"""
|
||||
|
||||
|
||||
# --- 验证码校验依赖 ---
|
||||
|
||||
def require_captcha(scene: 'CaptchaScene') -> Callable[..., Awaitable[None]]:
|
||||
"""
|
||||
验证码校验依赖注入工厂。
|
||||
|
||||
根据场景查询数据库设置,判断是否需要验证码。
|
||||
需要则校验前端提交的 captcha_code,失败则抛出异常。
|
||||
|
||||
使用方式::
|
||||
|
||||
@router.post('/session', dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))])
|
||||
async def login(...): ...
|
||||
|
||||
:param scene: 验证码使用场景(LOGIN / REGISTER / FORGET)
|
||||
"""
|
||||
from service.captcha import CaptchaScene, verify_captcha_if_needed
|
||||
|
||||
async def _verify_captcha(
|
||||
session: SessionDep,
|
||||
captcha_code: Annotated[str | None, Form()] = None,
|
||||
) -> None:
|
||||
await verify_captcha_if_needed(session, scene, captcha_code)
|
||||
|
||||
return _verify_captcha
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
from .user import (
|
||||
LoginRequest,
|
||||
RegisterRequest,
|
||||
AccessTokenBase,
|
||||
RefreshTokenBase,
|
||||
TokenResponse,
|
||||
User,
|
||||
UserBase,
|
||||
UserPublic,
|
||||
UserResponse,
|
||||
UserSettingResponse,
|
||||
WebAuthnInfo,
|
||||
# 管理员DTO
|
||||
UserAdminUpdateRequest,
|
||||
UserCalibrateResponse,
|
||||
UserAdminDetailResponse,
|
||||
)
|
||||
from .user_authn import AuthnResponse, UserAuthn
|
||||
from .color import ThemeResponse
|
||||
|
||||
from .download import (
|
||||
Download,
|
||||
DownloadAria2File,
|
||||
DownloadAria2Info,
|
||||
DownloadAria2InfoBase,
|
||||
DownloadStatus,
|
||||
DownloadType,
|
||||
)
|
||||
from .node import (
|
||||
Aria2Configuration,
|
||||
Aria2ConfigurationBase,
|
||||
Node,
|
||||
NodeStatus,
|
||||
NodeType,
|
||||
)
|
||||
from .group import (
|
||||
Group, GroupBase, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
|
||||
# 管理员DTO
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, GroupListResponse,
|
||||
)
|
||||
from .object import (
|
||||
CreateFileRequest,
|
||||
CreateUploadSessionRequest,
|
||||
DirectoryCreateRequest,
|
||||
DirectoryResponse,
|
||||
FileMetadata,
|
||||
FileMetadataBase,
|
||||
Object,
|
||||
ObjectBase,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
ObjectMoveRequest,
|
||||
ObjectPropertyDetailResponse,
|
||||
ObjectPropertyResponse,
|
||||
ObjectRenameRequest,
|
||||
ObjectResponse,
|
||||
ObjectType,
|
||||
PolicyResponse,
|
||||
UploadChunkResponse,
|
||||
UploadSession,
|
||||
UploadSessionBase,
|
||||
UploadSessionResponse,
|
||||
# 管理员DTO
|
||||
AdminFileResponse,
|
||||
AdminFileListResponse,
|
||||
FileBanRequest,
|
||||
)
|
||||
from .physical_file import PhysicalFile, PhysicalFileBase
|
||||
from .order import Order, OrderStatus, OrderType
|
||||
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
|
||||
from .redeem import Redeem, RedeemType
|
||||
from .report import Report, ReportReason
|
||||
from .setting import (
|
||||
Setting, SettingsType, SiteConfigResponse,
|
||||
# 管理员DTO
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
from .share import Share, ShareBase, ShareCreateRequest, ShareResponse, AdminShareListItem
|
||||
from .source_link import SourceLink
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag, TagType
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
||||
from .webdav import WebDAV
|
||||
|
||||
from .database import engine, get_session
|
||||
|
||||
from .model_base import (
|
||||
MCPBase,
|
||||
MCPMethod,
|
||||
MCPRequestBase,
|
||||
MCPResponseBase,
|
||||
ResponseBase,
|
||||
# Admin Summary DTO
|
||||
MetricsSummary,
|
||||
LicenseInfo,
|
||||
VersionInfo,
|
||||
AdminSummaryResponse,
|
||||
)
|
||||
|
||||
# mixin 中的通用分页模型
|
||||
from .mixin import ListResponse
|
||||
@@ -1,657 +0,0 @@
|
||||
# SQLModels Base Module
|
||||
|
||||
This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility.
|
||||
|
||||
**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Migration Notice](#migration-notice)
|
||||
- [Python 3.14 Compatibility](#python-314-compatibility)
|
||||
- [Core Component](#core-component)
|
||||
- [SQLModelBase](#sqlmodelbase)
|
||||
- [Metaclass Features](#metaclass-features)
|
||||
- [Automatic sa_type Injection](#automatic-sa_type-injection)
|
||||
- [Table Configuration](#table-configuration)
|
||||
- [Polymorphic Support](#polymorphic-support)
|
||||
- [Custom Types Integration](#custom-types-integration)
|
||||
- [Best Practices](#best-practices)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features:
|
||||
|
||||
- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations
|
||||
- **Python 3.14 compatibility** through comprehensive PEP 649/749 support
|
||||
- **Flexible configuration** through class parameters and automatic docstring support
|
||||
- **Type-safe annotations** with automatic validation
|
||||
|
||||
All models in this project should directly or indirectly inherit from `SQLModelBase`.
|
||||
|
||||
---
|
||||
|
||||
## Migration Notice
|
||||
|
||||
As of the recent refactoring, the following components have been moved:
|
||||
|
||||
| Component | Old Location | New Location |
|
||||
|-----------|-------------|--------------|
|
||||
| `TableBase` → `TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `UUIDTableBase` → `UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
|
||||
**Update your imports**:
|
||||
|
||||
```python
|
||||
# ❌ Old (deprecated)
|
||||
from sqlmodels.base import TableBase, UUIDTableBase
|
||||
|
||||
# ✅ New (correct)
|
||||
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
```
|
||||
|
||||
For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Python 3.14 Compatibility
|
||||
|
||||
### Overview
|
||||
|
||||
This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default).
|
||||
|
||||
**Key Changes in Python 3.14**:
|
||||
- Annotations are no longer evaluated at class definition time
|
||||
- Type hints are stored as deferred code objects
|
||||
- `__annotate__` function generates annotations on demand
|
||||
- Forward references become `ForwardRef` objects
|
||||
|
||||
### Implementation Strategy
|
||||
|
||||
We use **`typing.get_type_hints()`** as the universal annotations resolver:
|
||||
|
||||
```python
|
||||
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]:
|
||||
# Create temporary proxy class
|
||||
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
|
||||
|
||||
# Use get_type_hints with include_extras=True
|
||||
evaluated = get_type_hints(
|
||||
temp_cls,
|
||||
globalns=module_globals,
|
||||
localns=localns,
|
||||
include_extras=True # Preserve Annotated metadata
|
||||
)
|
||||
|
||||
return dict(evaluated), {}, module_globals, localns
|
||||
```
|
||||
|
||||
**Why `get_type_hints()`?**
|
||||
- ✅ Works across Python 3.10-3.14+
|
||||
- ✅ Handles PEP 649 automatically
|
||||
- ✅ Preserves `Annotated` metadata (with `include_extras=True`)
|
||||
- ✅ Resolves forward references
|
||||
- ✅ Recommended by Python documentation
|
||||
|
||||
### SQLModel Compatibility Patch
|
||||
|
||||
**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute.
|
||||
|
||||
**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic:
|
||||
|
||||
```python
|
||||
if sys.version_info >= (3, 14):
|
||||
def _patched_get_sqlalchemy_type(field):
|
||||
annotation = getattr(field, 'annotation', None)
|
||||
if annotation is not None:
|
||||
# Priority 1: Check __sqlmodel_sa_type__ attribute
|
||||
# Handles NumpyVector[dims, dtype] and similar custom types
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
|
||||
# Priority 2: Check Annotated metadata
|
||||
if get_origin(annotation) is Annotated:
|
||||
for metadata in get_args(annotation)[1:]:
|
||||
if hasattr(metadata, '__sqlmodel_sa_type__'):
|
||||
return metadata.__sqlmodel_sa_type__
|
||||
|
||||
# ... handle ForwardRef, ClassVar, etc.
|
||||
|
||||
return _original_get_sqlalchemy_type(field)
|
||||
```
|
||||
|
||||
### Supported Patterns
|
||||
|
||||
#### Pattern 1: Direct Custom Type Usage
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Voice embedding - sa_type automatically extracted"""
|
||||
```
|
||||
|
||||
#### Pattern 2: Annotated Wrapper
|
||||
```python
|
||||
from typing import Annotated
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: EmbeddingVector
|
||||
```
|
||||
|
||||
#### Pattern 3: Array Type
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import Array
|
||||
from sqlmodels.mixin import TableBaseMixin
|
||||
|
||||
class ServerConfig(TableBaseMixin, table=True):
|
||||
protocols: Array[ProtocolEnum]
|
||||
"""Allowed protocols - sa_type from Array handler"""
|
||||
```
|
||||
|
||||
### Migration from Python 3.13
|
||||
|
||||
**No code changes required!** The implementation is transparent:
|
||||
|
||||
- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14
|
||||
- Custom types already use `__sqlmodel_sa_type__` attribute
|
||||
- Monkey-patch only activates for Python 3.14+
|
||||
|
||||
---
|
||||
|
||||
## Core Component
|
||||
|
||||
### SQLModelBase
|
||||
|
||||
`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities.
|
||||
|
||||
**Key Features**:
|
||||
- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`)
|
||||
- Automatic `validate_by_name` configuration
|
||||
- Custom metaclass for sa_type injection and polymorphic setup
|
||||
- Integration with Pydantic v2
|
||||
- Python 3.14 PEP 649 compatibility
|
||||
|
||||
**Usage**:
|
||||
|
||||
```python
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class UserBase(SQLModelBase):
|
||||
name: str
|
||||
"""User's display name"""
|
||||
|
||||
email: str
|
||||
"""User's email address"""
|
||||
```
|
||||
|
||||
**Important Notes**:
|
||||
- Use **docstrings** for field descriptions, not `Field(description=...)`
|
||||
- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase)
|
||||
- This class should be used for non-table models (DTOs, request/response models)
|
||||
|
||||
**For table models**, use mixins from `sqlmodels.mixin`:
|
||||
- `TableBaseMixin` - Integer primary key with timestamps
|
||||
- `UUIDTableBaseMixin` - UUID primary key with timestamps
|
||||
|
||||
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation.
|
||||
|
||||
---
|
||||
|
||||
## Metaclass Features
|
||||
|
||||
### Automatic sa_type Injection
|
||||
|
||||
The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types.
|
||||
|
||||
**Before** (verbose):
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: np.ndarray = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
|
||||
)
|
||||
```
|
||||
|
||||
**After** (clean):
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Speaker voice embedding"""
|
||||
```
|
||||
|
||||
**How It Works**:
|
||||
|
||||
The metaclass uses a three-tier detection strategy:
|
||||
|
||||
1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1)
|
||||
```python
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
```
|
||||
|
||||
2. **Annotated metadata** (Priority 2)
|
||||
```python
|
||||
# For Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
if get_origin(annotation) is typing.Annotated:
|
||||
for item in metadata_items:
|
||||
if hasattr(item, '__sqlmodel_sa_type__'):
|
||||
return item.__sqlmodel_sa_type__
|
||||
```
|
||||
|
||||
3. **Pydantic Core Schema metadata** (Priority 3)
|
||||
```python
|
||||
schema = annotation.__get_pydantic_core_schema__(...)
|
||||
if schema['metadata'].get('sa_type'):
|
||||
return schema['metadata']['sa_type']
|
||||
```
|
||||
|
||||
After extracting `sa_type`, the metaclass:
|
||||
- Creates `Field(sa_type=sa_type)` if no Field is defined
|
||||
- Injects `sa_type` into existing Field if not already set
|
||||
- Respects explicit `Field(sa_type=...)` (no override)
|
||||
|
||||
**Supported Patterns**:
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# Pattern 1: Direct usage (recommended)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
|
||||
# Pattern 2: With Field constraints
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32] = Field(nullable=False)
|
||||
|
||||
# Pattern 3: Annotated wrapper
|
||||
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: EmbeddingVector
|
||||
|
||||
# Pattern 4: Explicit sa_type (override)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32] = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(128, np.float16)
|
||||
)
|
||||
```
|
||||
|
||||
### Table Configuration
|
||||
|
||||
The metaclass provides smart defaults and flexible configuration:
|
||||
|
||||
**Automatic `table=True`**:
|
||||
```python
|
||||
# Classes inheriting from TableBaseMixin automatically get table=True
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(UUIDTableBaseMixin): # table=True is automatic
|
||||
pass
|
||||
```
|
||||
|
||||
**Convenient mapper arguments**:
|
||||
```python
|
||||
# Instead of verbose __mapper_args__
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(
|
||||
UUIDTableBaseMixin,
|
||||
polymorphic_on='_polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
pass
|
||||
|
||||
# Equivalent to:
|
||||
class MyModel(UUIDTableBaseMixin):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': '_polymorphic_name',
|
||||
'polymorphic_abstract': True
|
||||
}
|
||||
```
|
||||
|
||||
**Smart merging**:
|
||||
```python
|
||||
# Dictionary and keyword arguments are merged
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(
|
||||
UUIDTableBaseMixin,
|
||||
mapper_args={'version_id_col': 'version'},
|
||||
polymorphic_on='type' # Merged into __mapper_args__
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
### Polymorphic Support
|
||||
|
||||
The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters:
|
||||
|
||||
**Supported parameters**:
|
||||
- `polymorphic_on`: Discriminator column name
|
||||
- `polymorphic_identity`: Identity value for this class
|
||||
- `polymorphic_abstract`: Whether this is an abstract base
|
||||
- `table_args`: SQLAlchemy table arguments
|
||||
- `table_name`: Override table name (becomes `__tablename__`)
|
||||
|
||||
**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Custom Types Integration
|
||||
|
||||
### Using NumpyVector
|
||||
|
||||
The `NumpyVector` type demonstrates automatic sa_type injection:
|
||||
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
import numpy as np
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Speaker voice embedding - sa_type automatically injected"""
|
||||
```
|
||||
|
||||
**How NumpyVector works**:
|
||||
|
||||
```python
|
||||
# NumpyVector[dims, dtype] returns a class with:
|
||||
class _NumpyVectorType:
|
||||
__sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
return handler.generate_schema(np.ndarray)
|
||||
```
|
||||
|
||||
This dual approach ensures:
|
||||
1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__`
|
||||
2. Pydantic can validate as `np.ndarray`
|
||||
|
||||
### Creating Custom SQLAlchemy Types
|
||||
|
||||
To create types that work with automatic injection, provide one of:
|
||||
|
||||
**Option 1: `__sqlmodel_sa_type__` attribute** (preferred):
|
||||
|
||||
```python
|
||||
from sqlalchemy import TypeDecorator, String
|
||||
|
||||
class UpperCaseString(TypeDecorator):
|
||||
impl = String
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return value.upper() if value else value
|
||||
|
||||
class UpperCaseType:
|
||||
__sqlmodel_sa_type__ = UpperCaseString()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
return core_schema.str_schema()
|
||||
|
||||
# Usage
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(UUIDTableBaseMixin, table=True):
|
||||
code: UpperCaseType # Automatically uses UpperCaseString()
|
||||
```
|
||||
|
||||
**Option 2: Pydantic metadata with sa_type**:
|
||||
|
||||
```python
|
||||
def __get_pydantic_core_schema__(self, source_type, handler):
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.str_schema(),
|
||||
python_schema=core_schema.str_schema(),
|
||||
metadata={'sa_type': UpperCaseString()}
|
||||
)
|
||||
```
|
||||
|
||||
**Option 3: Using Annotated**:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
UpperCase = Annotated[str, UpperCaseType()]
|
||||
|
||||
class MyModel(UUIDTableBaseMixin, table=True):
|
||||
code: UpperCase
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Inherit from correct base classes
|
||||
|
||||
```python
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
# ✅ For non-table models (DTOs, requests, responses)
|
||||
class UserBase(SQLModelBase):
|
||||
name: str
|
||||
|
||||
# ✅ For table models with UUID primary key
|
||||
class User(UserBase, UUIDTableBaseMixin, table=True):
|
||||
email: str
|
||||
|
||||
# ✅ For table models with custom primary key
|
||||
class LegacyUser(TableBaseMixin, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
username: str
|
||||
```
|
||||
|
||||
### 2. Use docstrings for field descriptions
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Recommended
|
||||
class User(UUIDTableBaseMixin, table=True):
|
||||
name: str
|
||||
"""User's display name"""
|
||||
|
||||
# ❌ Avoid
|
||||
class User(UUIDTableBaseMixin, table=True):
|
||||
name: str = Field(description="User's display name")
|
||||
```
|
||||
|
||||
**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs.
|
||||
|
||||
### 3. Leverage automatic sa_type injection
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Clean and recommended
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Voice embedding"""
|
||||
|
||||
# ❌ Verbose and unnecessary
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: np.ndarray = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Follow polymorphic naming conventions
|
||||
|
||||
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`.
|
||||
|
||||
### 5. Separate Base, Parent, and Implementation classes
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin
|
||||
|
||||
# ✅ Recommended structure
|
||||
class ASRBase(SQLModelBase):
|
||||
"""Pure data fields, no table"""
|
||||
name: str
|
||||
base_url: str
|
||||
|
||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
"""Abstract parent with table"""
|
||||
@abstractmethod
|
||||
async def transcribe(self, audio: bytes) -> str:
|
||||
pass
|
||||
|
||||
class WhisperASR(ASR, table=True):
|
||||
"""Concrete implementation"""
|
||||
model_size: str
|
||||
|
||||
async def transcribe(self, audio: bytes) -> str:
|
||||
# Implementation
|
||||
pass
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Base class can be reused for DTOs
|
||||
- Parent class defines the polymorphic hierarchy
|
||||
- Implementation classes are clean and focused
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: ValueError: X has no matching SQLAlchemy type
|
||||
|
||||
**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`.
|
||||
|
||||
```python
|
||||
# ✅ Provide __sqlmodel_sa_type__
|
||||
class MyType:
|
||||
__sqlmodel_sa_type__ = MyCustomSQLAlchemyType()
|
||||
```
|
||||
|
||||
### Issue: Can't generate DDL for NullType()
|
||||
|
||||
**Symptoms**: Error during table creation saying a column has `NullType`.
|
||||
|
||||
**Root Cause**: Custom type's `sa_type` not detected by SQLModel.
|
||||
|
||||
**Solution**:
|
||||
1. Ensure your type has `__sqlmodel_sa_type__` class attribute
|
||||
2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`)
|
||||
3. Verify type annotation is correct (not a string forward reference)
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Correct
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected
|
||||
|
||||
# ❌ Wrong (string annotation)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
data: 'NumpyVector[256, np.float32]' # sa_type lost
|
||||
```
|
||||
|
||||
### Issue: Polymorphic identity conflicts
|
||||
|
||||
**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities.
|
||||
|
||||
**Solution**:
|
||||
1. Check that each concrete class has a unique identity
|
||||
2. Use `AutoPolymorphicIdentityMixin` for automatic naming
|
||||
3. Manually specify identity if needed:
|
||||
```python
|
||||
class MyClass(Parent, polymorphic_identity='unique.name', table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
### Issue: Python 3.14 annotation errors
|
||||
|
||||
**Symptoms**: Errors related to `__annotations__` or type resolution.
|
||||
|
||||
**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist:
|
||||
1. Check for manual `__annotations__` manipulation (avoid it)
|
||||
2. Ensure all types are properly imported
|
||||
3. Avoid `from __future__ import annotations` (can cause SQLModel issues)
|
||||
|
||||
### Issue: Polymorphic and CRUD-related errors
|
||||
|
||||
For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Implementation Details
|
||||
|
||||
For developers modifying this module:
|
||||
|
||||
**Core files**:
|
||||
- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase`
|
||||
- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin`
|
||||
- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`
|
||||
|
||||
**Key functions in this module**:
|
||||
|
||||
1. **`_resolve_annotations(attrs: dict[str, Any])`**
|
||||
- Uses `typing.get_type_hints()` for Python 3.14 compatibility
|
||||
- Returns tuple: `(annotations, annotation_strings, globalns, localns)`
|
||||
- Preserves `Annotated` metadata with `include_extras=True`
|
||||
|
||||
2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`**
|
||||
- Extracts SQLAlchemy type from type annotations
|
||||
- Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema
|
||||
- Called by metaclass during class creation
|
||||
|
||||
3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+)
|
||||
- Global monkey-patch for SQLModel
|
||||
- Checks `__sqlmodel_sa_type__` before falling back to original logic
|
||||
- Handles custom types like `NumpyVector` and `Array`
|
||||
|
||||
4. **`__DeclarativeMeta.__new__()`**
|
||||
- Processes class definition parameters
|
||||
- Injects `sa_type` into field definitions
|
||||
- Sets up `__mapper_args__`, `__table_args__`, etc.
|
||||
- Handles Python 3.14 annotations via `get_type_hints()`
|
||||
|
||||
**Metaclass processing order**:
|
||||
1. Check if class should be a table (`_is_table_mixin`)
|
||||
2. Collect `__mapper_args__` from kwargs and explicit dict
|
||||
3. Process `table_args`, `table_name`, `abstract` parameters
|
||||
4. Resolve annotations using `get_type_hints()`
|
||||
5. For each field, try to extract `sa_type` and inject into Field
|
||||
6. Call parent metaclass with cleaned kwargs
|
||||
|
||||
For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## See Also
|
||||
|
||||
**Project Documentation**:
|
||||
- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns
|
||||
- [Project Coding Standards (CLAUDE.md)](../../CLAUDE.md)
|
||||
- [Custom SQLModel Types Guide](../sqlmodel_types/README.md)
|
||||
|
||||
**External References**:
|
||||
- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
||||
- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/)
|
||||
- [SQLModel Documentation](https://sqlmodel.tiangolo.com/)
|
||||
- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/)
|
||||
- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/)
|
||||
- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html)
|
||||
@@ -1,12 +0,0 @@
|
||||
"""
|
||||
SQLModel 基础模块
|
||||
|
||||
包含:
|
||||
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
|
||||
|
||||
注意:
|
||||
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 models.mixin
|
||||
为了避免循环导入,此处不再重新导出它们
|
||||
请直接从 models.mixin 导入这些类
|
||||
"""
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
@@ -1,846 +0,0 @@
|
||||
import sys
|
||||
import typing
|
||||
from typing import Any, Mapping, get_args, get_origin, get_type_hints
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined as Undefined
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlmodel import Field, SQLModel
|
||||
from sqlmodel.main import SQLModelMetaclass
|
||||
|
||||
# Python 3.14+ PEP 649支持
|
||||
if sys.version_info >= (3, 14):
|
||||
import annotationlib
|
||||
|
||||
# 全局Monkey-patch: 修复SQLModel在Python 3.14上的兼容性问题
|
||||
import sqlmodel.main
|
||||
_original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type
|
||||
|
||||
def _patched_get_sqlalchemy_type(field):
    """
    Patched replacement for SQLModel's ``get_sqlalchemy_type`` that fixes
    Python 3.14 compatibility issues.

    Problems addressed:
    1. ForwardRef objects (from Relationship fields) trigger ``issubclass`` errors.
    2. ``typing._GenericAlias`` objects (e.g. ``ClassVar[T]``) trigger the same error.
    3. Builtin generics (list/dict/...) without an explicit Field/Relationship may error.
    4. ``Mapped`` types can show up in annotations under Python 3.14.
    5. ``Annotated`` types may carry ``sa_type`` metadata (e.g. ``Array[T]``).
    6. Custom types (e.g. ``NumpyVector``) expose a ``__sqlmodel_sa_type__`` attribute.
    7. Annotated metadata already consumed by Pydantic ends up in ``field.metadata``.

    Resolution strategy:
    - First inspect ``field.metadata`` for ``__get_pydantic_core_schema__`` providers
      (the case where Pydantic has already processed the Annotated type).
    - Honor ``__sqlmodel_sa_type__`` (NumpyVector and similar custom types).
    - Return ``None`` for Relationship/ClassVar-style annotations.
    - For Annotated types, try to extract ``sa_type`` metadata.
    - Otherwise fall back to the original implementation.
    """
    # Check field.metadata first (Annotated types already processed by Pydantic).
    # When using Array[T] or Annotated[T, metadata], Pydantic stores the metadata here.
    metadata = getattr(field, 'metadata', None)
    if metadata:
        # metadata is a list holding every Annotated metadata item.
        for metadata_item in metadata:
            # Probe the item for the __get_pydantic_core_schema__ protocol.
            if hasattr(metadata_item, '__get_pydantic_core_schema__'):
                try:
                    # Obtain the core schema from the metadata item.
                    schema = metadata_item.__get_pydantic_core_schema__(None, None)
                    # Look for an sa_type entry in the schema's metadata.
                    if isinstance(schema, dict) and 'metadata' in schema:
                        sa_type = schema['metadata'].get('sa_type')
                        if sa_type is not None:
                            return sa_type
                except (TypeError, AttributeError, KeyError):
                    # Schema retrieval can legitimately fail (signature mismatch,
                    # missing attributes, ...); move on to the next metadata item.
                    pass

    annotation = getattr(field, 'annotation', None)
    if annotation is not None:
        # Prefer the __sqlmodel_sa_type__ attribute.
        # This handles custom types such as NumpyVector[dims, dtype].
        if hasattr(annotation, '__sqlmodel_sa_type__'):
            return annotation.__sqlmodel_sa_type__

        # Check custom types (e.g. JSON100K) implementing __get_pydantic_core_schema__;
        # such types declare sa_type inside the schema's metadata.
        if hasattr(annotation, '__get_pydantic_core_schema__'):
            try:
                # Fetch the schema (a dummy handler suffices — we only need metadata).
                schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None)
                # Look for an sa_type entry in the schema's metadata.
                if isinstance(schema, dict) and 'metadata' in schema:
                    sa_type = schema['metadata'].get('sa_type')
                    if sa_type is not None:
                        return sa_type
            except (TypeError, AttributeError, KeyError):
                # Schema retrieval failed; continue with the other checks.
                pass

        anno_type_name = type(annotation).__name__

        # ForwardRef: the annotation of a Relationship field — not a column type.
        if anno_type_name == 'ForwardRef':
            return None

        # AnnotatedAlias: check for sa_type metadata (e.g. Array[T]).
        if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias':
            from typing import get_origin, get_args
            import typing

            # Try to extract the Annotated metadata items.
            if hasattr(typing, 'get_args'):
                args = get_args(annotation)
                # args[0] is the actual type; args[1:] are metadata items.
                for metadata in args[1:]:
                    # Probe each metadata item for __get_pydantic_core_schema__.
                    if hasattr(metadata, '__get_pydantic_core_schema__'):
                        try:
                            # Obtain the core schema.
                            schema = metadata.__get_pydantic_core_schema__(None, None)
                            # Look for an sa_type entry in the schema.
                            if isinstance(schema, dict) and 'metadata' in schema:
                                sa_type = schema['metadata'].get('sa_type')
                                if sa_type is not None:
                                    return sa_type
                        except (TypeError, AttributeError, KeyError):
                            # Normal probing failure; try the next metadata item.
                            pass

        # _GenericAlias / GenericAlias: typing generics.
        if anno_type_name in ('_GenericAlias', 'GenericAlias'):
            from typing import get_origin
            import typing
            origin = get_origin(annotation)

            # ClassVar must be skipped — it is never a column.
            if origin is typing.ClassVar:
                return None

            # Builtin generics (list/dict/tuple/set) with no meaningful Field info
            # are skipped too; these are usually Relationship fields or class vars.
            if origin in (list, dict, tuple, set):
                # Relationship fields carry a distinctive field_info;
                # a missing field_info means there is nothing to map.
                field_info = getattr(field, 'field_info', None)
                if field_info is None:
                    return None

        # Mapped: SQLAlchemy 2.0 declarative type — SQLModel must not touch it.
        # May appear via inherited fields or Python 3.14 annotation handling.
        # Checked both by type name and by the annotation's string form.
        if 'Mapped' in anno_type_name or 'Mapped' in str(annotation):
            return None

        # Also detect Mapped as a class, instance, or generic origin.
        try:
            from sqlalchemy.orm import Mapped as SAMapped
            # Generic form: Mapped[T].
            from typing import get_origin
            if get_origin(annotation) is SAMapped:
                return None
            # Bare class / subclass form.
            if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped):
                return None
        except (ImportError, TypeError):
            # SQLAlchemy lacks Mapped or the check failed — carry on.
            pass

    # Everything else: defer to the original implementation.
    return _original_get_sqlalchemy_type(field)

sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type
|
||||
|
||||
# 第二个Monkey-patch: 修复继承表类中InstrumentedAttribute作为默认值的问题
|
||||
# 在Python 3.14 + SQLModel组合下,当子类(如SMSBaoProvider)继承父类(如VerificationCodeProvider)时,
|
||||
# 父类的关系字段(如server_config)会在子类的model_fields中出现,
|
||||
# 但其default值错误地设置为InstrumentedAttribute对象,而不是None
|
||||
# 这导致实例化时尝试设置InstrumentedAttribute为字段值,触发SQLAlchemy内部错误
|
||||
import sqlmodel._compat as _compat
|
||||
from sqlalchemy.orm import attributes as _sa_attributes
|
||||
|
||||
_original_sqlmodel_table_construct = _compat.sqlmodel_table_construct
|
||||
|
||||
def _patched_sqlmodel_table_construct(self_instance, values):
    """
    Patched ``sqlmodel_table_construct`` that skips ``InstrumentedAttribute`` defaults.

    Problem:
    - Table classes inheriting from a polymorphic base (e.g. FishAudioTTS,
      SMSBaoProvider) end up with inherited fields whose ``default`` is an
      ``InstrumentedAttribute`` object instead of ``None``.
    - The original function tries to assign that object as a field value and
      SQLAlchemy fails with a ``'_sa_instance_state'`` error.

    Fix:
    - Only assign user-supplied values and defaults that are NOT
      ``InstrumentedAttribute``; skipped defaults are left for SQLAlchemy
      to handle itself.
    """
    cls = type(self_instance)

    # Collect the field values to assign.
    fields_to_set = {}

    for name, field in cls.model_fields.items():
        # User-supplied values win unconditionally.
        if name in values:
            fields_to_set[name] = values[name]
            continue

        # Otherwise fall back to the declared default.
        # Skip InstrumentedAttribute defaults — bogus defaults on inherited fields.
        if isinstance(field.default, _sa_attributes.InstrumentedAttribute):
            continue

        # Use the regular default / default_factory.
        if field.default is not Undefined:
            fields_to_set[name] = field.default
        elif field.default_factory is not None:
            fields_to_set[name] = field.get_default(call_default_factory=True)

    # Assign attributes — guard once more against InstrumentedAttribute values.
    for key, value in fields_to_set.items():
        if not isinstance(value, _sa_attributes.InstrumentedAttribute):
            setattr(self_instance, key, value)

    # Set Pydantic internals (mirrors the original implementation).
    object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys()))
    if not cls.__pydantic_root_model__:
        _extra = None
        if cls.model_config.get('extra') == 'allow':
            _extra = {}
            for k, v in values.items():
                if k not in cls.model_fields:
                    _extra[k] = v
        object.__setattr__(self_instance, '__pydantic_extra__', _extra)

    if cls.__pydantic_post_init__:
        self_instance.model_post_init(None)
    elif not cls.__pydantic_root_model__:
        object.__setattr__(self_instance, '__pydantic_private__', None)

    # Assign relationships supplied by the caller.
    for key in self_instance.__sqlmodel_relationships__:
        value = values.get(key, Undefined)
        if value is not Undefined:
            setattr(self_instance, key, value)

    return self_instance

_compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct
|
||||
else:
|
||||
annotationlib = None
|
||||
|
||||
|
||||
def _extract_sa_type_from_annotation(annotation: Any) -> Any | None:
|
||||
"""
|
||||
从类型注解中提取SQLAlchemy类型。
|
||||
|
||||
支持以下形式:
|
||||
1. NumpyVector[256, np.float32] - 直接使用类型(有__sqlmodel_sa_type__属性)
|
||||
2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装
|
||||
3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型
|
||||
|
||||
Args:
|
||||
annotation: 字段的类型注解
|
||||
|
||||
Returns:
|
||||
提取到的SQLAlchemy类型,如果没有则返回None
|
||||
"""
|
||||
# 方法1:直接检查类型本身是否有__sqlmodel_sa_type__属性
|
||||
# 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
|
||||
# 方法2:检查是否为Annotated类型
|
||||
if get_origin(annotation) is typing.Annotated:
|
||||
# 获取元数据项(跳过第一个实际类型参数)
|
||||
args = get_args(annotation)
|
||||
if len(args) >= 2:
|
||||
metadata_items = args[1:] # 第一个是实际类型,后面都是元数据
|
||||
|
||||
# 遍历元数据,查找包含sa_type的项
|
||||
for item in metadata_items:
|
||||
# 检查元数据项是否有__sqlmodel_sa_type__属性
|
||||
if hasattr(item, '__sqlmodel_sa_type__'):
|
||||
return item.__sqlmodel_sa_type__
|
||||
|
||||
# 检查是否有__get_pydantic_core_schema__方法
|
||||
if hasattr(item, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用该方法获取core schema
|
||||
schema = item.__get_pydantic_core_schema__(
|
||||
annotation,
|
||||
lambda x: None # 虚拟handler
|
||||
)
|
||||
# 检查schema的metadata中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError, ValueError):
|
||||
# Pydantic core schema获取可能失败:
|
||||
# - TypeError: 参数不匹配
|
||||
# - AttributeError: metadata不存在
|
||||
# - KeyError: schema结构不符合预期
|
||||
# - ValueError: 无效的类型定义
|
||||
# 这是正常的类型探测过程,继续检查下一个metadata项
|
||||
pass
|
||||
|
||||
# 方法3:检查类型本身是否有__get_pydantic_core_schema__
|
||||
# (虽然NumpyVector已经在方法1处理,但这是通用的fallback)
|
||||
if hasattr(annotation, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
schema = annotation.__get_pydantic_core_schema__(
|
||||
annotation,
|
||||
lambda x: None # 虚拟handler
|
||||
)
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError, ValueError):
|
||||
# 类型本身的schema获取失败
|
||||
# 这是正常的fallback机制,annotation可能不支持此协议
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[
    dict[str, Any],
    dict[str, str],
    Mapping[str, Any],
    Mapping[str, Any],
]:
    """
    Resolve annotations from a class namespace with Python 3.14 (PEP 649) support.

    Evaluated annotations (Format.VALUE) are preferred so that `typing.Annotated`
    metadata and custom types stay accessible. When evaluation fails (unresolvable
    forward references, invalid annotations, cyclic definitions) the raw annotation
    mapping is returned unchanged instead of aborting class creation.
    """
    raw = attrs.get('__annotations__') or {}
    try:
        fallback = dict(raw)
    except TypeError:
        fallback = {}

    # Build the global namespace from the defining module, when importable.
    mod_name = attrs.get('__module__')
    globalns: dict[str, Any] = (
        dict(sys.modules[mod_name].__dict__)
        if mod_name and mod_name in sys.modules
        else {}
    )
    globalns.setdefault('__builtins__', __builtins__)

    # The class body itself acts as the local namespace.
    localns: dict[str, Any] = dict(attrs)

    try:
        # Materialize a throwaway class so get_type_hints() walks an MRO.
        proxy = type('AnnotationProxy', (object,), dict(attrs))
        proxy.__module__ = mod_name
        extras = {'include_extras': True} if sys.version_info >= (3, 10) else {}
        resolved = get_type_hints(proxy, globalns=globalns, localns=localns, **extras)
    except (NameError, AttributeError, TypeError, RecursionError):
        # - NameError: unresolvable forward reference
        # - AttributeError: missing module or type
        # - TypeError: invalid annotation
        # - RecursionError: cyclic type definitions
        # Fall back to the raw annotations.
        resolved = fallback

    return dict(resolved), {}, globalns, localns
|
||||
|
||||
|
||||
def _evaluate_annotation_from_string(
    field_name: str,
    annotation_strings: dict[str, str],
    current_type: Any,
    globalns: Mapping[str, Any],
    localns: Mapping[str, Any],
) -> Any:
    """
    Attempt to re-evaluate the original annotation string for a field.

    Used as a fallback when the resolved annotation lost its metadata
    (e.g., Annotated wrappers) and custom sa_type data must be recovered.
    On any failure the already-resolved *current_type* is returned unchanged.
    """
    if not annotation_strings:
        return current_type

    source = annotation_strings.get(field_name)
    if not source or not isinstance(source, str):
        return current_type

    try:
        return eval(source, globalns, localns)
    except (NameError, SyntaxError, AttributeError, TypeError):
        # - NameError: type name missing from the namespaces
        # - SyntaxError: malformed annotation string
        # - AttributeError: nonexistent module attribute
        # - TypeError: invalid type expression
        # Best-effort recovery only — keep the resolved type.
        return current_type
|
||||
|
||||
|
||||
class __DeclarativeMeta(SQLModelMetaclass):
    """
    A smart hybrid metaclass providing both flexibility and clarity:

    1. **Automatic ``table=True``**: applied automatically when a class
       inherits from ``TableBaseMixin``.
    2. **Explicit dict parameters**: supports ``mapper_args={...}``,
       ``table_args={...}``, ``table_name='...'``.
    3. **Convenient keyword parameters**: the most common mapper options are
       accepted as top-level keywords (e.g. ``polymorphic_on``).
    4. **Smart merging**: when both dict and keyword forms are given they are
       merged, with keyword parameters taking precedence.
    """

    # Mapper options accepted as top-level class keywords and folded
    # into __mapper_args__ by __new__.
    _KNOWN_MAPPER_KEYS = {
        "polymorphic_on",
        "polymorphic_identity",
        "polymorphic_abstract",
        "version_id_col",
        "concrete",
    }
|
||||
|
||||
    def __new__(cls, name, bases, attrs, **kwargs):
        # 1. Convention over configuration: auto-apply table=True when any
        #    base advertises itself as a table mixin.
        is_intended_as_table = any(getattr(b, '_is_table_mixin', False) for b in bases)
        if is_intended_as_table and 'table' not in kwargs:
            kwargs['table'] = True

        # 2. Smart merge of __mapper_args__.
        collected_mapper_args = {}

        # First, the explicit mapper_args dict (lower precedence).
        if 'mapper_args' in kwargs:
            collected_mapper_args.update(kwargs.pop('mapper_args'))

        # Then the convenience keywords (higher precedence).
        for key in cls._KNOWN_MAPPER_KEYS:
            if key in kwargs:
                # .pop() consumes the value so it is not forwarded to the parent.
                collected_mapper_args[key] = kwargs.pop(key)

        # If any mapper options were collected, merge them into the class attrs.
        if collected_mapper_args:
            existing = attrs.get('__mapper_args__', {}).copy()
            existing.update(collected_mapper_args)
            attrs['__mapper_args__'] = existing

        # 3. Other explicit parameters.
        if 'table_args' in kwargs:
            attrs['__table_args__'] = kwargs.pop('table_args')
        if 'table_name' in kwargs:
            attrs['__tablename__'] = kwargs.pop('table_name')
        if 'abstract' in kwargs:
            attrs['__abstract__'] = kwargs.pop('abstract')

        # 4. Extract sa_type from Annotated metadata and inject it into Field.
        #    Important: must run BEFORE the parent __new__, because SQLModel
        #    consumes the annotations.
        #
        # Python 3.14 compatibility:
        # - SQLModel crashes on Python 3.14 for ClassVar[T] types (issubclass error),
        #   so ClassVar fields must be filtered out before SQLModel sees them.
        # - PEP 749 advises against mutating __annotations__, but it is necessary
        #   here to work around the SQLModel bug.
        #
        # Annotation retrieval strategy:
        # - Python 3.14+: prefer __annotate__ when present
        # - fallback: read __annotations__ when present
        # - final fallback: empty dict
        annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs)

        if annotations:
            attrs['__annotations__'] = annotations
            if annotationlib is not None:
                # On Python 3.14 disable the annotate descriptor; use a plain dict.
                attrs['__annotate__'] = None

        for field_name, field_type in annotations.items():
            field_type = _evaluate_annotation_from_string(
                field_name,
                annotation_strings,
                field_type,
                eval_globals,
                eval_locals,
            )

            # Skip string / ForwardRef annotations; SQLModel handles those itself.
            if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef):
                continue

            # Skip special annotation kinds.
            origin = get_origin(field_type)

            # ClassVar fields are not database columns.
            if origin is typing.ClassVar:
                continue

            # Mapped fields are SQLAlchemy 2.0+ declarative columns — already mapped.
            if origin is Mapped:
                continue

            # Try to extract an sa_type from the annotation.
            sa_type = _extract_sa_type_from_annotation(field_type)

            if sa_type is not None:
                # Does the field already carry a Field definition?
                field_value = attrs.get(field_name, Undefined)

                if field_value is Undefined:
                    # No Field defined — create one with the injected sa_type.
                    attrs[field_name] = Field(sa_type=sa_type)
                elif isinstance(field_value, FieldInfo):
                    # A Field exists; inject sa_type only when unset, respecting
                    # explicit configuration. SQLModel uses Undefined as the
                    # "not set" marker.
                    if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined:
                        field_value.sa_type = sa_type
                # Any other value (e.g. a plain default) is left alone;
                # SQLModel converts it to a Field later.

        # 5. Call the parent __new__ with the cleaned kwargs.
        result = super().__new__(cls, name, bases, attrs, **kwargs)

        # 6. Fix: inherit the parent's __sqlmodel_relationships__ under joined-table
        #    inheritance. SQLModel creates a fresh empty __sqlmodel_relationships__
        #    for every table=True class, losing the parent's relationship definitions
        #    and triggering bogus Column creation. Must run AFTER super().__new__()
        #    because SQLModel overwrites any preset value.
        if kwargs.get('table', False):
            for base in bases:
                if hasattr(base, '__sqlmodel_relationships__'):
                    for rel_name, rel_info in base.__sqlmodel_relationships__.items():
                        # Only inherit relationships the subclass did not redefine.
                        if rel_name not in result.__sqlmodel_relationships__:
                            result.__sqlmodel_relationships__[rel_name] = rel_info
                            # Also repair the wrongly created Column — restore the
                            # parent's relationship attribute.
                            if hasattr(base, rel_name):
                                base_attr = getattr(base, rel_name)
                                setattr(result, rel_name, base_attr)

        # 7. Guard: forbid subclasses from redefining a parent Relationship field.
        #    Redefining a same-named Relationship scrambles SQLAlchemy's relationship
        #    mapping; fail fast at class definition time instead of at runtime.
        for base in bases:
            parent_relationships = getattr(base, '__sqlmodel_relationships__', {})
            for rel_name in parent_relationships:
                # Did the current class redefine this relationship in its body?
                if rel_name in attrs:
                    raise TypeError(
                        f"类 {name} 不允许重定义父类 {base.__name__} 的 Relationship 字段 '{rel_name}'。"
                        f"如需修改关系配置,请在父类中修改。"
                    )

        # 8. Fix: strip Relationship fields out of model_fields/__pydantic_fields__.
        #    SQLModel 0.0.27 bug: subclasses wrongly inherit parent Relationship
        #    fields into model_fields, making Pydantic try to build a schema for
        #    forward-referenced types like Mapped[list['Character']] which it
        #    cannot resolve — leaving __pydantic_complete__ = False.
        #
        # Fix strategy:
        # - read the class's __sqlmodel_relationships__
        # - delete those names from model_fields and __pydantic_fields__
        # - Relationship fields belong to SQLAlchemy; Pydantic must not manage them
        relationships = getattr(result, '__sqlmodel_relationships__', {})
        if relationships:
            model_fields = getattr(result, 'model_fields', {})
            pydantic_fields = getattr(result, '__pydantic_fields__', {})

            fields_removed = False
            for rel_name in relationships:
                if rel_name in model_fields:
                    del model_fields[rel_name]
                    fields_removed = True
                if rel_name in pydantic_fields:
                    del pydantic_fields[rel_name]
                    fields_removed = True

            # Rebuild the Pydantic schema only when something was actually removed,
            # avoiding needless overhead.
            if fields_removed and hasattr(result, 'model_rebuild'):
                result.model_rebuild(force=True)

        return result
|
||||
|
||||
def __init__(
    cls,
    classname: str,
    bases: tuple[type, ...],
    dict_: dict[str, typing.Any],
    **kw: typing.Any,
) -> None:
    """
    Override SQLModel's ``__init__`` to support joined table inheritance.

    SQLModel's original behavior:
    - If any base class is a table model, ``DeclarativeMeta.__init__`` is never called.
    - That prevents a subclass from creating its own table.

    Fix applied here:
    - Detect the joined-table-inheritance scenario (the subclass has its own
      ``__tablename__`` and a foreign key pointing at the parent table).
    - Force the ``DeclarativeMeta.__init__`` call so the child table is created.
    """
    from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass

    # Non-table models take the plain Pydantic metaclass path.
    if not is_table_model_class(cls):
        ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
        return

    # Is any direct base class itself a table model?
    base_is_table = any(is_table_model_class(base) for base in bases)

    if not base_is_table:
        # No table-model base: follow the normal SQLModel flow.
        # Install relationship attributes, then let DeclarativeMeta map the class.
        cls._setup_relationships()
        DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
        return

    # Key step: detect the joined-table-inheritance scenario.
    # Conditions:
    # 1. the class's __tablename__ differs from the parent's (a new table is wanted)
    # 2. the class has a field whose foreign_key points at the parent table
    current_tablename = getattr(cls, '__tablename__', None)

    # Locate the parent table's name among the direct bases.
    parent_table = None  # NOTE(review): unused below — the nested helper's parameter shadows it
    parent_tablename = None
    for base in bases:
        if is_table_model_class(base) and hasattr(base, '__tablename__'):
            parent_tablename = base.__tablename__
            break

    # Does the subclass declare a table name different from the parent's?
    has_different_tablename = (
        current_tablename is not None
        and parent_tablename is not None
        and current_tablename != parent_tablename
    )

    # Does any field carry a foreign key targeting the parent table's primary key?
    # Note: because of field merging we must also inspect the direct bases'
    # model_fields, not only the merged model_fields of the current class.
    has_fk_to_parent = False

    def _normalize_tablename(name: str) -> str:
        """Normalize a table name for comparison (strip underscores, lowercase)."""
        return name.replace('_', '').lower()

    def _fk_matches_parent(fk_str: str, parent_table: str) -> bool:
        """Return True when the FK string points at the parent table."""
        if not fk_str or not parent_table:
            return False
        # FK format: "tablename.column" or "schema.tablename.column"
        parts = fk_str.split('.')
        if len(parts) >= 2:
            fk_table = parts[-2]  # second-to-last segment is the table name
            # Compare normalized names (tolerates underscore differences).
            return _normalize_tablename(fk_table) == _normalize_tablename(parent_table)
        return False

    if has_different_tablename and parent_tablename:
        # First look at the current class's own (merged) model_fields.
        for field_name, field_info in cls.model_fields.items():
            fk = getattr(field_info, 'foreign_key', None)
            if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
                has_fk_to_parent = True
                break

        # Not found there: check each direct base's model_fields
        # (covers mixin fields clobbered during field merging).
        if not has_fk_to_parent:
            for base in bases:
                if hasattr(base, 'model_fields'):
                    for field_name, field_info in base.model_fields.items():
                        fk = getattr(field_info, 'foreign_key', None)
                        if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
                            has_fk_to_parent = True
                            break
                    if has_fk_to_parent:
                        break

    is_joined_inheritance = has_different_tablename and has_fk_to_parent

    if is_joined_inheritance:
        # Joined table inheritance: the child table must be created.

        # Repair the FK field: field merging may have dropped the
        # foreign-key info, so recover it from the base mixin and rebuild the column.
        from sqlalchemy import Column, ForeignKey, inspect as sa_inspect
        from sqlalchemy.dialects.postgresql import UUID as SA_UUID
        from sqlalchemy.exc import NoInspectionAvailable
        from sqlalchemy.orm.attributes import InstrumentedAttribute

        # The child table should only hold the id (FK to the parent) plus the
        # child's own fields. Columns inherited from any ancestor table must
        # not be re-created on the child table.

        # Collect every ancestor table's column names across the whole
        # inheritance chain. Walk the full MRO because multi-level
        # inheritance (e.g. Tool -> Function -> GetWeatherFunction) is possible.
        ancestor_column_names: set[str] = set()
        for ancestor in cls.__mro__:
            if ancestor is cls:
                continue  # skip the class being defined
            if is_table_model_class(ancestor):
                try:
                    # Use inspect() to reach the mapper's public attributes;
                    # mapper.local_table is public API (mapper.py:979-998).
                    mapper = sa_inspect(ancestor)
                    for col in mapper.local_table.columns:
                        # Skip _polymorphic_name columns (the discriminator,
                        # owned by the root parent table).
                        if col.name.startswith('_polymorphic'):
                            continue
                        ancestor_column_names.add(col.name)
                except NoInspectionAvailable:
                    continue

        # Determine the fields defined by the child itself (absent from parents).
        child_own_fields: set[str] = set()
        for field_name in cls.model_fields:
            # A field counts as the child's own when no direct base declares
            # it, i.e. it was defined on this class rather than inherited.
            is_inherited = False
            for base in bases:
                if hasattr(base, 'model_fields') and field_name in base.model_fields:
                    is_inherited = True
                    break
            if not is_inherited:
                child_own_fields.add(field_name)

        # Remove parent-table column definitions from the child's class body
        # so SQLAlchemy does not duplicate those columns on the child table.
        fk_field_name = None
        for base in bases:
            if hasattr(base, 'model_fields'):
                for field_name, field_info in base.model_fields.items():
                    fk = getattr(field_info, 'foreign_key', None)
                    pk = getattr(field_info, 'primary_key', False)
                    if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
                        fk_field_name = field_name
                        # Found the FK field: rebuild it as a fresh Column
                        # object that carries the foreign-key constraint.
                        new_col = Column(
                            field_name,
                            SA_UUID(as_uuid=True),
                            ForeignKey(fk),
                            primary_key=pk if pk else False
                        )
                        setattr(cls, field_name, new_col)
                        break
                else:
                    continue  # inner loop exhausted without a match: next base
                break  # inner loop broke: FK field handled, stop scanning bases

        # Drop attributes for columns inherited from ancestor tables (except
        # the FK/PK and the child's own fields). This prevents SQLAlchemy
        # from creating duplicate columns on the child table.
        # Note: at __init__ time these attributes are Column objects, not
        # InstrumentedAttribute instances.
        for col_name in ancestor_column_names:
            if col_name == fk_field_name:
                continue  # keep the FK/PK column (child PK, FK to the parent)
            if col_name == 'id':
                continue  # id is superseded by the FK field
            if col_name in child_own_fields:
                continue  # keep fields the child defined itself

            # Only remove class attributes that are actual mapped columns.
            if col_name in cls.__dict__:
                attr = cls.__dict__[col_name]
                # Both Column objects and InstrumentedAttribute need removal.
                if isinstance(attr, (Column, InstrumentedAttribute)):
                    try:
                        delattr(cls, col_name)
                    except AttributeError:
                        pass

        # Relationships defined by the child itself (absent from parents).
        # Inherited relationships come from the parent automatically; only
        # the child's newly added relationships must be set up here.
        child_own_relationships: set[str] = set()
        for rel_name in cls.__sqlmodel_relationships__:
            is_inherited = False
            for base in bases:
                if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__:
                    is_inherited = True
                    break
            if not is_inherited:
                child_own_relationships.add(rel_name)

        # Wire up only the relationships newly introduced by the child.
        if child_own_relationships:
            cls._setup_relationships(only_these=child_own_relationships)

        # Force the DeclarativeMeta.__init__ call: this creates the child table.
        DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
    else:
        # Not joined table inheritance: single-table inheritance or a plain
        # Pydantic model — use the stock metaclass initialization.
        ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
|
||||
def _setup_relationships(cls, only_these: set[str] | None = None) -> None:
    """
    Install SQLAlchemy relationship attributes (ported from SQLModel's source).

    Args:
        only_these: When provided, only these relationship names are set up
            (used for joined-table-inheritance subclasses). When ``None``,
            every relationship is processed (default behavior).
    """
    from sqlalchemy.orm import relationship, Mapped
    from sqlalchemy import inspect
    from sqlmodel.main import get_relationship_to
    from typing import get_origin

    for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
        # Honor the only_these filter when one was given.
        if only_these is not None and rel_name not in only_these:
            continue
        # A pre-built sa_relationship is attached verbatim.
        if rel_info.sa_relationship:
            setattr(cls, rel_name, rel_info.sa_relationship)
            continue

        # Unwrap a Mapped[...] annotation; otherwise wrap the raw annotation
        # in Mapped so SQLAlchemy's declarative machinery recognizes it.
        raw_ann = cls.__annotations__[rel_name]
        origin: typing.Any = get_origin(raw_ann)
        if origin is Mapped:
            ann = raw_ann.__args__[0]
        else:
            ann = raw_ann
            cls.__annotations__[rel_name] = Mapped[ann]

        # Resolve the relationship target from the annotation via SQLModel's helper.
        relationship_to = get_relationship_to(
            name=rel_name, rel_info=rel_info, annotation=ann
        )
        rel_kwargs: dict[str, typing.Any] = {}
        if rel_info.back_populates:
            rel_kwargs["back_populates"] = rel_info.back_populates
        if rel_info.cascade_delete:
            rel_kwargs["cascade"] = "all, delete-orphan"
        if rel_info.passive_deletes:
            rel_kwargs["passive_deletes"] = rel_info.passive_deletes
        if rel_info.link_model:
            # Many-to-many: the link model's table becomes `secondary`.
            ins = inspect(rel_info.link_model)
            local_table = getattr(ins, "local_table")
            if local_table is None:
                raise RuntimeError(
                    f"Couldn't find secondary table for {rel_info.link_model}"
                )
            rel_kwargs["secondary"] = local_table

        # Explicit sa_relationship args/kwargs are applied last so they can
        # override the derived defaults above.
        rel_args: list[typing.Any] = []
        if rel_info.sa_relationship_args:
            rel_args.extend(rel_info.sa_relationship_args)
        if rel_info.sa_relationship_kwargs:
            rel_kwargs.update(rel_info.sa_relationship_kwargs)

        rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
        setattr(cls, rel_name, rel_value)
|
||||
|
||||
|
||||
class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta):
    """This class must be used together with the TableBase family of classes."""

    # use_attribute_docstrings: field docstrings become field descriptions;
    # validate_by_name: allow population by field name as well as alias.
    model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)
|
||||
@@ -1,7 +0,0 @@
|
||||
from .base import SQLModelBase
|
||||
|
||||
class ThemeResponse(SQLModelBase):
    """Theme response DTO — carries no fields of its own."""
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
from sqlmodel import SQLModel
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from utils.conf import appmeta
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import AsyncGenerator
|
||||
|
||||
# Async database URL taken from the application configuration.
ASYNC_DATABASE_URL = appmeta.database_url

# Shared module-level async engine for the whole application.
engine: AsyncEngine = create_async_engine(
    ASYNC_DATABASE_URL,
    echo=appmeta.debug,  # echo SQL statements when running in debug mode
    # SQLite's driver rejects cross-thread use by default; relax that check
    # so pooled connections can be handed between threads. Other backends
    # need no extra connect args.
    connect_args={
        "check_same_thread": False
    } if ASYNC_DATABASE_URL.startswith("sqlite") else {},
    future=True,
    # pool_size=POOL_SIZE,
    # max_overflow=64,
)

# Factory producing AsyncSession instances bound to the shared engine.
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield one AsyncSession and close it when the consumer finishes."""
    async with _async_session_factory() as db_session:
        yield db_session
|
||||
|
||||
async def init_db(
    url: str = ASYNC_DATABASE_URL
):
    """Create the database schema (all registered SQLModel tables).

    Args:
        url: Database URL to run the DDL against. Defaults to the shared
            module-level engine's URL; a different URL gets a temporary
            engine that is disposed afterwards.

    Bug fix: previously the ``url`` parameter was accepted but ignored —
    DDL always ran on the module-level engine regardless of the argument.
    """
    if url == ASYNC_DATABASE_URL:
        # Default case: reuse the shared engine (original behavior).
        target_engine = engine
    else:
        # Honor a caller-supplied URL with a short-lived engine.
        target_engine = create_async_engine(url, future=True)
    try:
        async with target_engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)
    finally:
        # Only dispose engines we created here; the shared one stays open.
        if target_engine is not engine:
            await target_engine.dispose()
|
||||
|
||||
@@ -1,311 +0,0 @@
|
||||
|
||||
from .setting import Setting, SettingsType
|
||||
from .color import ThemeResponse
|
||||
from utils.conf.appmeta import BackendVersion
|
||||
from utils.password.pwd import Password
|
||||
from loguru import logger as log
|
||||
|
||||
async def migration() -> None:
    """
    Database migration entry point: seeds default settings, storage
    policies, user groups, and the default user.

    :return: None
    """

    log.info('开始进行数据库初始化...')

    # Policies are seeded before groups because group creation looks up the
    # default storage policy to link it.
    await init_default_settings()
    await init_default_policy()
    await init_default_group()
    await init_default_user()

    log.info('数据库初始化结束')
|
||||
|
||||
default_settings: list[Setting] = [
|
||||
Setting(name="siteURL", value="http://localhost", type=SettingsType.BASIC),
|
||||
Setting(name="siteName", value="DiskNext", type=SettingsType.BASIC),
|
||||
Setting(name="register_enabled", value="1", type=SettingsType.REGISTER),
|
||||
Setting(name="default_group", value="", type=SettingsType.REGISTER),
|
||||
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
|
||||
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
|
||||
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
|
||||
Setting(name="fromName", value="DiskNext", type=SettingsType.MAIL),
|
||||
Setting(name="mail_keepalive", value="30", type=SettingsType.MAIL),
|
||||
Setting(name="fromAdress", value="no-reply@yxqi.cn", type=SettingsType.MAIL),
|
||||
Setting(name="smtpHost", value="smtp.yxqi.cn", type=SettingsType.MAIL),
|
||||
Setting(name="smtpPort", value="25", type=SettingsType.MAIL),
|
||||
Setting(name="replyTo", value="feedback@yxqi.cn", type=SettingsType.MAIL),
|
||||
Setting(name="smtpUser", value="no-reply@yxqi.cn", type=SettingsType.MAIL),
|
||||
Setting(name="smtpPass", value="", type=SettingsType.MAIL),
|
||||
Setting(name="maxEditSize", value="4194304", type=SettingsType.FILE_EDIT),
|
||||
Setting(name="archive_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="download_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="preview_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="doc_preview_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="upload_credential_timeout", value="1800", type=SettingsType.TIMEOUT),
|
||||
Setting(name="upload_session_timeout", value="86400", type=SettingsType.TIMEOUT),
|
||||
Setting(name="slave_api_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="onedrive_monitor_timeout", value="600", type=SettingsType.TIMEOUT),
|
||||
Setting(name="share_download_session_timeout", value="2073600", type=SettingsType.TIMEOUT),
|
||||
Setting(name="onedrive_callback_check", value="20", type=SettingsType.TIMEOUT),
|
||||
Setting(name="aria2_call_timeout", value="5", type=SettingsType.TIMEOUT),
|
||||
Setting(name="onedrive_chunk_retries", value="1", type=SettingsType.RETRY),
|
||||
Setting(name="onedrive_source_timeout", value="1800", type=SettingsType.TIMEOUT),
|
||||
Setting(name="reset_after_upload_failed", value="0", type=SettingsType.UPLOAD),
|
||||
Setting(name="login_captcha", value="0", type=SettingsType.LOGIN),
|
||||
Setting(name="reg_captcha", value="0", type=SettingsType.LOGIN),
|
||||
Setting(name="email_active", value="0", type=SettingsType.REGISTER),
|
||||
Setting(name="mail_activation_template", value="""<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
|
||||
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>激活您的账户</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
|
||||
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
|
||||
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
|
||||
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
|
||||
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #009688; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">激活{siteTitle}账户</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
|
||||
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong>:</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您注册{siteTitle},请点击下方按钮完成账户激活。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{activationUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #009688; margin: 0; border-color: #009688; border-style: solid; border-width: 10px 20px;">激活账户</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table 
width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>""", type=SettingsType.MAIL_TEMPLATE),
|
||||
Setting(name="forget_captcha", value="0", type=SettingsType.LOGIN),
|
||||
Setting(name="mail_reset_pwd_template", value="""<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
|
||||
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>重设密码</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
|
||||
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
|
||||
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
|
||||
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
|
||||
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #2196F3; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">重设{siteTitle}密码</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
|
||||
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong>:</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">请点击下方按钮完成密码重设。如果非你本人操作,请忽略此邮件。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{resetUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #2196F3; margin: 0; border-color: #2196F3; border-style: solid; border-width: 10px 20px;">重设密码</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table 
width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>""", type=SettingsType.MAIL_TEMPLATE),
|
||||
Setting(name=f"db_version_{BackendVersion}", value="installed", type=SettingsType.VERSION),
|
||||
Setting(name="hot_share_num", value="10", type=SettingsType.SHARE),
|
||||
Setting(name="gravatar_server", value="https://www.gravatar.com/", type=SettingsType.AVATAR),
|
||||
Setting(name="defaultTheme", value="#3f51b5", type=SettingsType.BASIC),
|
||||
Setting(name="themes", value=ThemeResponse().model_dump_json(), type=SettingsType.BASIC),
|
||||
Setting(name="aria2_token", value="", type=SettingsType.ARIA2),
|
||||
Setting(name="aria2_rpcurl", value="", type=SettingsType.ARIA2),
|
||||
Setting(name="aria2_temp_path", value="", type=SettingsType.ARIA2),
|
||||
Setting(name="aria2_options", value="{}", type=SettingsType.ARIA2),
|
||||
Setting(name="aria2_interval", value="60", type=SettingsType.ARIA2),
|
||||
Setting(name="max_worker_num", value="10", type=SettingsType.TASK),
|
||||
Setting(name="max_parallel_transfer", value="4", type=SettingsType.TASK),
|
||||
Setting(name="secret_key", value=Password.generate(256), type=SettingsType.AUTH),
|
||||
Setting(name="temp_path", value="temp", type=SettingsType.PATH),
|
||||
Setting(name="avatar_path", value="avatar", type=SettingsType.PATH),
|
||||
Setting(name="avatar_size", value="2097152", type=SettingsType.AVATAR),
|
||||
Setting(name="avatar_size_l", value="200", type=SettingsType.AVATAR),
|
||||
Setting(name="avatar_size_m", value="130", type=SettingsType.AVATAR),
|
||||
Setting(name="avatar_size_s", value="50", type=SettingsType.AVATAR),
|
||||
Setting(name="home_view_method", value="icon", type=SettingsType.VIEW),
|
||||
Setting(name="share_view_method", value="list", type=SettingsType.VIEW),
|
||||
Setting(name="cron_garbage_collect", value="@hourly", type=SettingsType.CRON),
|
||||
Setting(name="authn_enabled", value="0", type=SettingsType.AUTHN),
|
||||
Setting(name="captcha_height", value="60", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_width", value="240", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_mode", value="3", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ComplexOfNoiseText", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ComplexOfNoiseDot", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowHollowLine", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowNoiseDot", value="1", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowNoiseText", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowSlimeLine", value="1", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowSineLine", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CaptchaLen", value="6", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsUseReCaptcha", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaKey", value="defaultKey", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaSecret", value="defaultSecret", type=SettingsType.CAPTCHA),
|
||||
Setting(name="thumb_width", value="400", type=SettingsType.THUMB),
|
||||
Setting(name="thumb_height", value="300", type=SettingsType.THUMB),
|
||||
Setting(name="pwa_small_icon", value="/static/img/favicon.ico", type=SettingsType.PWA),
|
||||
Setting(name="pwa_medium_icon", value="/static/img/logo192.png", type=SettingsType.PWA),
|
||||
Setting(name="pwa_large_icon", value="/static/img/logo512.png", type=SettingsType.PWA),
|
||||
Setting(name="pwa_display", value="standalone", type=SettingsType.PWA),
|
||||
Setting(name="pwa_theme_color", value="#000000", type=SettingsType.PWA),
|
||||
Setting(name="pwa_background_color", value="#ffffff", type=SettingsType.PWA),
|
||||
]
|
||||
|
||||
async def init_default_settings() -> None:
    """Seed the settings table with defaults, once per backend version."""
    # Imported locally to avoid circular imports at module load time.
    from .setting import Setting
    from .database import get_session

    log.info('初始化设置...')

    async for session in get_session():
        # A version marker row means the defaults were already installed for
        # this backend version; bail out to keep the operation idempotent.
        ver = await Setting.get(
            session,
            (Setting.type == SettingsType.VERSION) & (Setting.name == f"db_version_{BackendVersion}")
        )
        if ver and ver.value == "installed":
            return

        # Insert all default settings in one batch.
        await Setting.add(session, default_settings)
|
||||
|
||||
async def init_default_group() -> None:
    """Create the built-in user groups: admin, registered member, guest.

    Idempotent: each group is created only when missing. Admin and member
    groups are linked to the default local storage policy when it exists;
    the guest group gets no policy link (guests cannot upload).
    """
    from .group import Group, GroupOptions
    from .policy import Policy, GroupPolicyLink
    from .setting import Setting
    from .database import get_session

    log.info('初始化用户组...')

    async for session in get_session():
        # Resolve the default storage policy up front; it may not exist yet.
        base_policy = await Policy.get(session, Policy.name == "本地存储")
        base_policy_id = base_policy.id if base_policy else None

        # ---- Administrator group -------------------------------------
        if not await Group.get(session, Group.name == "管理员"):
            admins = Group(
                name="管理员",
                max_storage=1 * 1024 * 1024 * 1024,  # 1GB
                share_enabled=True,
                web_dav_enabled=True,
                admin=True,
            )
            # Capture the UUID before save(): the instance is expired
            # after commit and its attributes become inaccessible.
            admins_id = admins.id
            await admins.save(session)

            await GroupOptions(
                group_id=admins_id,
                archive_download=True,
                archive_task=True,
                share_download=True,
                share_free=True,
                aria2=True,
                select_node=True,
                advance_delete=True,
            ).save(session)

            # Attach the default storage policy when available.
            if base_policy_id:
                session.add(GroupPolicyLink(
                    group_id=admins_id,
                    policy_id=base_policy_id,
                ))
                await session.commit()

        # ---- Registered member group ---------------------------------
        if not await Group.get(session, Group.name == "注册会员"):
            members = Group(
                name="注册会员",
                max_storage=1 * 1024 * 1024 * 1024,  # 1GB
                share_enabled=True,
                web_dav_enabled=True,
            )
            members_id = members.id  # capture UUID before save()
            await members.save(session)

            await GroupOptions(
                group_id=members_id,
                share_download=True,
            ).save(session)

            # Attach the default storage policy when available.
            if base_policy_id:
                session.add(GroupPolicyLink(
                    group_id=members_id,
                    policy_id=base_policy_id,
                ))
                await session.commit()

            # Point the default_group setting at the member group's UUID.
            default_group_setting = await Setting.get(session, Setting.name == "default_group")
            if default_group_setting:
                default_group_setting.value = str(members_id)
                await default_group_setting.save(session)

        # ---- Guest group ---------------------------------------------
        if not await Group.get(session, Group.name == "游客"):
            guests = Group(
                name="游客",
                share_enabled=False,
                web_dav_enabled=False,
            )
            guests_id = guests.id  # capture UUID before save()
            await guests.save(session)

            await GroupOptions(
                group_id=guests_id,
                share_download=True,
            ).save(session)

            # Guests get no storage policy link — they cannot upload.
|
||||
|
||||
async def init_default_user() -> None:
    """Create the initial "admin" account with a randomly generated password.

    Idempotent: does nothing when the admin user already exists. Requires the
    admin group and the default local storage policy to exist (raises
    RuntimeError otherwise). The generated password is logged exactly once.

    Raises:
        RuntimeError: when the admin group or the default policy is missing.
    """
    from .user import User
    from .group import Group
    from .object import Object, ObjectType
    from .policy import Policy
    from .database import get_session

    log.info('初始化管理员用户...')

    async for session in get_session():
        # Guard clause: nothing to do when the account already exists.
        if await User.get(session, User.username == "admin"):
            continue

        # The admin account must belong to the admin group.
        admin_group = await Group.get(session, Group.name == "管理员")
        if not admin_group:
            raise RuntimeError("管理员用户组不存在,无法创建管理员用户")

        # A storage policy is required for the user's root folder.
        local_policy = await Policy.get(session, Policy.name == "本地存储")
        if not local_policy:
            raise RuntimeError("默认存储策略不存在,无法创建管理员用户")
        local_policy_id = local_policy.id  # capture UUID before later save()

        # Generate a one-time random password and store only its hash.
        plain_password = Password.generate(8)
        password_hash = Password.hash(plain_password)

        new_admin = User(
            username="admin",
            nickname="admin",
            group_id=admin_group.id,
            password=password_hash,
        )
        # Capture values before save(): the instance is expired after commit.
        new_admin_id = new_admin.id
        new_admin_name = new_admin.username
        await new_admin.save(session)

        # Create the admin's root folder (named after the username).
        await Object(
            name=new_admin_name,
            type=ObjectType.FOLDER,
            owner_id=new_admin_id,
            parent_id=None,
            policy_id=local_policy_id,
        ).save(session)

        log.warning('请注意,账号密码仅显示一次,请妥善保管')
        log.info('初始管理员账号: admin')
        log.info(f'初始管理员密码: {plain_password}')
|
||||
|
||||
|
||||
async def init_default_policy() -> None:
    """Create the default local storage policy and its physical directory.

    Idempotent: skips creation when a policy named "本地存储" already exists.
    """
    from .policy import Policy, PolicyType
    from .database import get_session
    from service.storage import LocalStorageService

    log.info('初始化默认存储策略...')

    async for session in get_session():
        # Guard clause: policy already present, nothing to do.
        if await Policy.get(session, Policy.name == "本地存储"):
            continue

        new_policy = Policy(
            name="本地存储",
            type=PolicyType.LOCAL,
            server="./data",
            is_private=True,
            max_size=0,
            auto_rename=True,
            dir_name_rule="{date}/{randomkey16}",
            file_name_rule="{randomkey16}_{originname}",
        )
        # save() returns the refreshed instance; always use the return value.
        new_policy = await new_policy.save(session)

        # Create the physical base directory on disk.
        await LocalStorageService(new_policy).ensure_base_directory()

        log.info('已创建默认本地存储策略,存储目录:./data')
|
||||
@@ -1,543 +0,0 @@
|
||||
# SQLModel Mixin Module
|
||||
|
||||
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
|
||||
|
||||
## Module Overview
|
||||
|
||||
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
|
||||
|
||||
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
|
||||
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
|
||||
- **JWT Authentication**: Token generation and validation
|
||||
- **Pagination & Sorting**: Standardized table view parameters
|
||||
- **Response DTOs**: Consistent id/timestamp fields for API responses
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
sqlmodels/mixin/
|
||||
├── __init__.py # Module exports
|
||||
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
|
||||
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
|
||||
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
|
||||
└── jwt/ # JWT authentication
|
||||
├── __init__.py
|
||||
├── key.py # JWTKey database model
|
||||
├── payload.py # JWTPayloadBase
|
||||
├── manager.py # JWTManager singleton
|
||||
├── auth.py # JWTAuthMixin
|
||||
├── exceptions.py # JWT-related exceptions
|
||||
└── responses.py # TokenResponse DTO
|
||||
```
|
||||
|
||||
## Dependency Hierarchy
|
||||
|
||||
The module has a strict import order to avoid circular dependencies:
|
||||
|
||||
1. **polymorphic.py** - Only depends on `SQLModelBase`
|
||||
2. **table.py** - Depends on `polymorphic.py`
|
||||
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
|
||||
4. **info_response.py** - Only depends on `SQLModelBase`
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. TableBaseMixin
|
||||
|
||||
Base mixin for database table models with integer primary keys.
|
||||
|
||||
**Features:**
|
||||
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
|
||||
- Automatic timestamp management (`created_at`, `updated_at`)
|
||||
- Async relationship loading support (via `AsyncAttrs`)
|
||||
- Pagination and sorting via `TableViewRequest`
|
||||
- Polymorphic subclass loading support
|
||||
|
||||
**Fields:**
|
||||
- `id: int | None` - Integer primary key (auto-increment)
|
||||
- `created_at: datetime` - Record creation timestamp
|
||||
- `updated_at: datetime` - Record update timestamp (auto-updated)
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import TableBaseMixin
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class User(SQLModelBase, TableBaseMixin, table=True):
|
||||
name: str
|
||||
email: str
|
||||
"""User email"""
|
||||
|
||||
# CRUD operations
|
||||
async def example(session: AsyncSession):
|
||||
# Add
|
||||
user = User(name="Alice", email="alice@example.com")
|
||||
user = await user.save(session)
|
||||
|
||||
# Get
|
||||
user = await User.get(session, User.id == 1)
|
||||
|
||||
# Update
|
||||
update_data = UserUpdateRequest(name="Alice Smith")
|
||||
user = await user.update(session, update_data)
|
||||
|
||||
# Delete
|
||||
await User.delete(session, user)
|
||||
|
||||
# Count
|
||||
count = await User.count(session, User.is_active == True)
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- `save()` and `update()` return refreshed instances - **always use the return value**:
|
||||
```python
|
||||
# ✅ Correct
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device
|
||||
```
|
||||
|
||||
### 2. UUIDTableBaseMixin
|
||||
|
||||
Extends `TableBaseMixin` with UUID primary keys instead of integers.
|
||||
|
||||
**Differences from TableBaseMixin:**
|
||||
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
|
||||
- `get_exist_one()` accepts `UUID` instead of `int`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
|
||||
name: str
|
||||
description: str | None = None
|
||||
"""Character description"""
|
||||
```
|
||||
|
||||
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
|
||||
|
||||
### 3. TableViewRequest
|
||||
|
||||
Standardized pagination and sorting parameters for LIST endpoints.
|
||||
|
||||
**Fields:**
|
||||
- `offset: int | None` - Skip first N records (default: 0)
|
||||
- `limit: int | None` - Return max N records (default: 50, max: 100)
|
||||
- `desc: bool | None` - Sort descending (default: True)
|
||||
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
|
||||
|
||||
**Usage with TableBaseMixin.get():**
|
||||
```python
|
||||
from dependencies import TableViewRequestDep
|
||||
|
||||
@router.get("/list")
|
||||
async def list_characters(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
) -> list[Character]:
|
||||
"""List characters with pagination and sorting"""
|
||||
return await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view # Automatically handles pagination and sorting
|
||||
)
|
||||
```
|
||||
|
||||
**Manual usage:**
|
||||
```python
|
||||
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
|
||||
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
|
||||
|
||||
### 4. PolymorphicBaseMixin
|
||||
|
||||
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
|
||||
|
||||
**Automatic Configuration:**
|
||||
- Defines `_polymorphic_name: str` field (indexed)
|
||||
- Sets `polymorphic_on='_polymorphic_name'`
|
||||
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
|
||||
|
||||
**Methods:**
|
||||
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
|
||||
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
|
||||
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
|
||||
"""Abstract base class for all tools"""
|
||||
name: str
|
||||
description: str
|
||||
"""Tool description"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, params: dict) -> dict:
|
||||
"""Execute the tool"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Why Single Underscore Prefix?**
|
||||
- SQLAlchemy maps single-underscore fields to database columns
|
||||
- Pydantic treats them as private (excluded from serialization)
|
||||
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
|
||||
|
||||
### 5. create_subclass_id_mixin()
|
||||
|
||||
Factory function to create ID mixins for subclasses in joined table inheritance.
|
||||
|
||||
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
|
||||
"""
|
||||
Args:
|
||||
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
A mixin class containing id field (foreign key + primary key)
|
||||
"""
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import create_subclass_id_mixin
|
||||
|
||||
# Create mixin for ASR subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
|
||||
"""FunASR implementation"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
|
||||
|
||||
### 6. AutoPolymorphicIdentityMixin
|
||||
|
||||
Automatically generates `polymorphic_identity` based on class name.
|
||||
|
||||
**Naming Convention:**
|
||||
- Format: `{parent_identity}.{classname_lowercase}`
|
||||
- If no parent identity exists, uses `{classname_lowercase}`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
|
||||
|
||||
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
"""Base class for function-type tools"""
|
||||
pass
|
||||
# polymorphic_identity = 'function'
|
||||
|
||||
class GetWeatherFunction(Function, table=True):
|
||||
"""Weather query function"""
|
||||
pass
|
||||
# polymorphic_identity = 'function.getweatherfunction'
|
||||
```
|
||||
|
||||
**Manual Override:**
|
||||
```python
|
||||
class CustomTool(
|
||||
Tool,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_identity='custom_name', # Override auto-generated name
|
||||
table=True
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
### 7. JWTAuthMixin
|
||||
|
||||
Provides JWT token generation and validation for entity classes (User, Client).
|
||||
|
||||
**Methods:**
|
||||
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
|
||||
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
|
||||
|
||||
**Requirements:**
|
||||
Subclasses must define:
|
||||
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
|
||||
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
|
||||
|
||||
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
|
||||
JWTPayload = UserJWTPayload # Define payload model
|
||||
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
|
||||
|
||||
email: str
|
||||
is_admin: bool = False
|
||||
is_active: bool = True
|
||||
"""User active status"""
|
||||
|
||||
# Generate token
|
||||
async def login(session: AsyncSession, user: User) -> str:
|
||||
token = await user.issue_jwt(session)
|
||||
return token
|
||||
|
||||
# Validate token
|
||||
async def verify(session: AsyncSession, token: str) -> User:
|
||||
user = await User.from_jwt(session, token)
|
||||
return user
|
||||
```
|
||||
|
||||
### 8. Response DTO Mixins
|
||||
|
||||
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
|
||||
|
||||
**Available Mixins:**
|
||||
- `IntIdInfoMixin` - Integer ID field
|
||||
- `UUIDIdInfoMixin` - UUID ID field
|
||||
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
|
||||
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
|
||||
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
|
||||
|
||||
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
|
||||
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
"""Character response DTO with id and timestamps"""
|
||||
pass # Inherits id, created_at, updated_at from mixin
|
||||
```
|
||||
|
||||
## Complete Joined Table Inheritance Example
|
||||
|
||||
Here's a complete example demonstrating polymorphic inheritance:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import (
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. Define Base class (fields only, no table)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
"""Configuration name"""
|
||||
|
||||
base_url: str
|
||||
"""Service URL"""
|
||||
|
||||
# 2. Define abstract parent class (with table)
|
||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
"""Abstract base class for ASR configurations"""
|
||||
# PolymorphicBaseMixin automatically provides:
|
||||
# - _polymorphic_name field
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True (when ABC with abstract methods)
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
"""Transcribe audio to text"""
|
||||
pass
|
||||
|
||||
# 3. Create ID Mixin for second-level subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. Create second-level abstract class (if needed)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
"""FunASR abstract base (may have multiple implementations)"""
|
||||
pass
|
||||
# polymorphic_identity = 'funasr'
|
||||
|
||||
# 5. Create concrete implementation classes
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
"""FunASR local deployment"""
|
||||
# polymorphic_identity = 'funasr.funasrlocal'
|
||||
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
# Implementation...
|
||||
return "transcribed text"
|
||||
|
||||
# 6. Get all concrete subclasses (for selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# Returns: [FunASRLocal, ...]
|
||||
```
|
||||
|
||||
## Import Guidelines
|
||||
|
||||
**Standard Import:**
|
||||
```python
|
||||
from sqlmodels.mixin import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
TableViewRequest,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
JWTAuthMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
Some exports are also available from `sqlmodels.base` for backward compatibility:
|
||||
```python
|
||||
# Legacy import path (still works)
|
||||
from sqlmodels.base import UUIDTableBase, TableViewRequest
|
||||
|
||||
# Recommended new import path
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Mixin Order Matters
|
||||
|
||||
**Correct Order:**
|
||||
```python
|
||||
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
|
||||
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
**Wrong Order:**
|
||||
```python
|
||||
# ❌ ID Mixin not first - won't override parent's id field
|
||||
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
### 2. Always Use Return Values from save() and update()
|
||||
|
||||
```python
|
||||
# ✅ Correct - use returned instance
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device # AttributeError when accessing fields
|
||||
```
|
||||
|
||||
### 3. Prefer table_view Over Manual Pagination
|
||||
|
||||
```python
|
||||
# ✅ Recommended - consistent across all endpoints
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view
|
||||
)
|
||||
|
||||
# ⚠️ Works but not recommended - manual parameter management
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
offset=0,
|
||||
limit=20,
|
||||
order_by=[desc(Character.created_at)]
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Polymorphic Loading for Many Subclasses
|
||||
|
||||
```python
|
||||
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
|
||||
)
|
||||
|
||||
# For fewer subclasses, specify the list explicitly
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Response DTOs Should Inherit Base Classes
|
||||
|
||||
```python
|
||||
# ✅ Correct - inherits from CharacterBase
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
pass
|
||||
|
||||
# ❌ Wrong - doesn't inherit from CharacterBase
|
||||
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
|
||||
name: str # Duplicated field definition
|
||||
description: str | None = None
|
||||
```
|
||||
|
||||
**Reason:** Inheriting from Base classes ensures:
|
||||
- Type checking via `isinstance(obj, XxxBase)`
|
||||
- Consistency across related DTOs
|
||||
- Future field additions automatically propagate
|
||||
|
||||
### 6. Use Specific Types, Not Containers
|
||||
|
||||
```python
|
||||
# ✅ Correct - specific DTO for config updates
|
||||
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
|
||||
weather_api_key: str | None = None
|
||||
default_location: str | None = None
|
||||
"""Default location"""
|
||||
|
||||
# ❌ Wrong - lose type safety
|
||||
class ToolUpdateRequest(SQLModelBase):
|
||||
config: dict[str, Any] # No field validation
|
||||
```
|
||||
|
||||
## Type Variables
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import T, M
|
||||
|
||||
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
|
||||
M = TypeVar("M", bound="SQLModel") # For update() method
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import now, now_date
|
||||
|
||||
# Lambda functions for default factories
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
```
|
||||
|
||||
## Related Modules
|
||||
|
||||
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
|
||||
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
|
||||
- **sqlmodels.user** - User model with JWT authentication
|
||||
- **sqlmodels.client** - Client model with JWT authentication
|
||||
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
|
||||
- `CLAUDE.md` - Project coding standards and design philosophy
|
||||
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
||||
@@ -1,46 +0,0 @@
|
||||
"""
|
||||
SQLModel Mixin模块
|
||||
|
||||
提供各种Mixin类供SQLModel实体使用。
|
||||
|
||||
包含:
|
||||
- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
|
||||
- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
|
||||
- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
|
||||
- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
|
||||
- info_response: InfoResponse DTO的id/时间戳Mixin
|
||||
|
||||
导入顺序很重要,避免循环导入:
|
||||
1. polymorphic(只依赖 SQLModelBase)
|
||||
2. table(依赖 polymorphic)
|
||||
|
||||
注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
|
||||
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
|
||||
"""
|
||||
# polymorphic 必须先导入
|
||||
from .polymorphic import (
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
PolymorphicBaseMixin,
|
||||
)
|
||||
# table 依赖 polymorphic
|
||||
from .table import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
TimeFilterRequest,
|
||||
PaginationRequest,
|
||||
TableViewRequest,
|
||||
ListResponse,
|
||||
T,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
|
||||
# 需要时直接从 sqlmodels.mixin.jwt 导入
|
||||
from .info_response import (
|
||||
IntIdInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
DatetimeInfoMixin,
|
||||
IntIdDatetimeInfoMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
)
|
||||
@@ -1,46 +0,0 @@
|
||||
"""
|
||||
InfoResponse DTO Mixin模块
|
||||
|
||||
提供用于InfoResponse类型DTO的Mixin,统一定义id/created_at/updated_at字段。
|
||||
|
||||
设计说明:
|
||||
- 这些Mixin用于**响应DTO**,不是数据库表
|
||||
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
|
||||
- TableBase中的id=None和default_factory=now是正确的(入库前为None,数据库生成)
|
||||
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
|
||||
"""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from models.base import SQLModelBase
|
||||
|
||||
|
||||
class IntIdInfoMixin(SQLModelBase):
    """Integer-ID mixin for InfoResponse DTOs.

    The id is required (non-nullable) because records returned from the
    database always carry a primary key.
    """
    id: int
    """Record ID"""
|
||||
|
||||
|
||||
class UUIDIdInfoMixin(SQLModelBase):
    """UUID-ID mixin for InfoResponse DTOs.

    The id is required (non-nullable) because records returned from the
    database always carry a primary key.
    """
    id: UUID
    """Record ID"""
|
||||
|
||||
|
||||
class DatetimeInfoMixin(SQLModelBase):
    """Timestamp mixin for InfoResponse DTOs.

    Both timestamps are required (non-nullable) because records returned
    from the database always carry them.
    """
    created_at: datetime
    """Creation time"""

    updated_at: datetime
    """Last update time"""
|
||||
|
||||
|
||||
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
    """Integer ID + timestamps mixin for InfoResponse DTOs."""
    pass
|
||||
|
||||
|
||||
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
    """UUID ID + timestamps mixin for InfoResponse DTOs."""
    pass
|
||||
@@ -1,456 +0,0 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
        # polymorphic_identity 会自动设置为 'funasr.funasrlocal'(格式:{父类identity}.{类名小写})
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import String, inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
|
||||
from models.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
    """
    Dynamically create a SubclassIdMixin class.

    In joined-table inheritance each subclass table needs a foreign key that
    points at the parent table's primary key.  This factory builds a mixin that
    supplies that foreign-key ``id`` field, auto-generating a UUID via
    ``default_factory=uuid.uuid4``.

    Args:
        parent_table_name: name of the parent table (e.g. 'asr', 'tts',
            'tool', 'function')

    Returns:
        A mixin class exposing an ``id`` field (foreign key + primary key +
        ``default_factory=uuid.uuid4``).

    Example:
        >>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
        >>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
        ...     pass

    Note:
        - Put the generated mixin first in the bases list so its ``id``
          shadows ``UUIDTableBaseMixin.id`` through the MRO.
        - The generated class is named ``{ParentTableName}SubclassIdMixin``
          (PascalCase).
        - Every joined-table hierarchy in this project uses UUID primary keys
          (``UUIDTableBaseMixin``).
    """
    if not parent_table_name:
        raise ValueError("parent_table_name 不能为空")

    # Convert the snake_case table name to PascalCase for the class name.
    class_name_parts = parent_table_name.split('_')
    class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'

    # Captured by the class body below (closure over the factory argument).
    _parent_table_name = parent_table_name

    # The mixin carries a __pydantic_init_subclass__ hook that repairs
    # model_fields once each concrete subclass has been defined.
    class SubclassIdMixin(SQLModelBase):
        # Foreign key to the parent table's primary key; doubles as this
        # table's own primary key in joined-table inheritance.
        id: UUID = Field(
            default_factory=uuid.uuid4,
            foreign_key=f'{_parent_table_name}.id',
            primary_key=True,
        )

        @classmethod
        def __pydantic_init_subclass__(cls, **kwargs):
            """
            Pydantic v2 subclass-initialisation hook, run after the model is
            fully built.

            Repairs ``default_factory`` losses on fields inherited through
            joined-table inheritance: SQLAlchemy's ``InstrumentedAttribute``
            leaks into the defaults of fields inherited from the parent class,
            which would otherwise put a ``table.column`` reference into INSERT
            statements instead of an actual value.

            The correct ``default_factory`` is recovered from the original
            field definition found in the MRO (single source of truth — no
            hard-coded factories).

            Fields this typically repairs:
            - id: primary key (factory taken from the parent class)
            - created_at: creation timestamp (factory from the parent class)
            - updated_at: update timestamp (factory from the parent class)
            """
            super().__pydantic_init_subclass__(**kwargs)

            if not hasattr(cls, 'model_fields'):
                return

            def find_original_field_info(field_name: str) -> FieldInfo | None:
                """Locate the field's original (uncontaminated) definition in the MRO."""
                for base in cls.__mro__[1:]:  # skip cls itself
                    if hasattr(base, 'model_fields') and field_name in base.model_fields:
                        field_info = base.model_fields[field_name]
                        # Skip definitions already polluted by an InstrumentedAttribute.
                        if not isinstance(field_info.default, InstrumentedAttribute):
                            return field_info
                return None

            # Detect every field that needs repair instead of hard-coding a
            # list (single source of truth).  A field qualifies when:
            #   1. its default is an InstrumentedAttribute (SQLAlchemy leak)
            #   2. its original definition has a default_factory or an
            #      explicit default value
            #
            # Scenarios covered:
            #   - UUID pk (UUIDTableBaseMixin): default_factory=uuid.uuid4 -> fixed
            #   - int pk (TableBaseMixin): default=None -> left to DB autoincrement
            #   - created_at/updated_at: default_factory=now -> fixed
            #   - FK fields (created_by_id, ...): default=None -> fixed
            #   - plain fields (name, temperature, ...): no factory -> untouched
            #
            # MRO guarantees:
            #   - ordering is deterministic even under multiple inheritance
            #   - find_original_field_info returns the first clean definition
            for field_name, current_field in cls.model_fields.items():
                # Only touch fields whose default was clobbered.
                if not isinstance(current_field.default, InstrumentedAttribute):
                    continue  # not polluted, skip

                # Look up the original definition in the parent classes.
                original = find_original_field_info(field_name)
                if original is None:
                    continue  # no clean definition found anywhere, skip

                # Rebuild the FieldInfo from the original default(s).
                if original.default_factory:
                    # Has a factory (e.g. uuid.uuid4, now).
                    new_field = FieldInfo(
                        default_factory=original.default_factory,
                        annotation=current_field.annotation,
                        json_schema_extra=current_field.json_schema_extra,
                    )
                elif original.default is not PydanticUndefined:
                    # Has a concrete default value (e.g. None, 0, "").
                    # PydanticUndefined would mean "required, no default".
                    new_field = FieldInfo(
                        default=original.default,
                        annotation=current_field.annotation,
                        json_schema_extra=current_field.json_schema_extra,
                    )
                else:
                    continue  # neither a factory nor a usable default, skip

                # Preserve SQLModel-specific attributes on the replacement.
                if hasattr(current_field, 'foreign_key'):
                    new_field.foreign_key = current_field.foreign_key
                if hasattr(current_field, 'primary_key'):
                    new_field.primary_key = current_field.primary_key

                cls.model_fields[field_name] = new_field

    # Give the generated class a meaningful name and docstring.
    SubclassIdMixin.__name__ = class_name
    SubclassIdMixin.__qualname__ = class_name
    SubclassIdMixin.__doc__ = f"""
    {parent_table_name}子类的ID Mixin

    用于{parent_table_name}的子类,提供外键指向父表。
    通过MRO确保此id字段覆盖继承的id字段。
    """

    return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
Example:
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
    """
    Mixin that auto-configures polymorphic settings for the base class of a
    joined-table inheritance chain.

    This mixin automatically provides:
    - ``polymorphic_on='_polymorphic_name'``: uses the ``_polymorphic_name``
      column as the polymorphic discriminator
    - ``_polymorphic_name: str``: the (indexed) discriminator column itself
    - ``polymorphic_abstract=True``: applied automatically when the class
      inherits from ABC and still has unimplemented abstract methods

    Intended use:
        Base classes of joined-table inheritance hierarchies, e.g. Tool, ASR,
        TTS.

    Usage example:
        ```python
        from abc import ABC
        from sqlmodels.mixin import UUIDTableBaseMixin
        from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin

        # Define the base class
        class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
            __tablename__ = 'mytool'

            # no manual _polymorphic_name needed
            # no manual polymorphic_on needed
            # no manual polymorphic_abstract needed

        # Define a subclass
        class SpecificTool(MyTool):
            __tablename__ = 'specifictool'

            # inherits the polymorphic configuration automatically
        ```

    Automatic behaviour:
        1. defines the indexed ``_polymorphic_name: str`` column
        2. sets ``__mapper_args__['polymorphic_on'] = '_polymorphic_name'``
        3. detects abstractness:
           - polymorphic_abstract=True when the class inherits ABC and has
             unimplemented abstract methods
           - False otherwise

    Manual override (class keywords):
        ```python
        class MyTool(
            UUIDTableBaseMixin,
            PolymorphicBaseMixin,
            ABC,
            polymorphic_on='custom_field',  # replace the default column
            polymorphic_abstract=False      # force non-abstract
        ):
            pass
        ```

    Notes:
        - use together with UUIDTableBaseMixin or TableBaseMixin
        - designed for joined table inheritance scenarios
        - subclasses inherit the ``_polymorphic_name`` column definition
        - the single-underscore prefix matters because:
          * SQLAlchemy still maps single-underscore attributes to columns
          * Pydantic treats them as private, excluding them from serialization
          * double-underscore attributes would not be mapped to columns at all
    """

    # Discriminator column shared by every class that uses this mixin.
    #
    # Design choice: single-underscore prefix + Mapped[str] + mapped_column.
    #
    # Rationale:
    # 1. the underscore marks it as an "internal detail", shielding it from
    #    direct modification through the public API
    # 2. Mapped + mapped_column bypasses Pydantic v2's restriction on
    #    underscore-prefixed field names
    # 3. SQLAlchemy still maps the column, so polymorphic queries work
    # 4. the field is absent from Pydantic serialization (model_dump() and
    #    the JSON schema)
    # 5. internal code can still read and write it normally
    #
    # Full rationale: sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
    _polymorphic_name: Mapped[str] = mapped_column(String, index=True)
    """
    Polymorphic discriminator identifying the concrete subclass type.

    Single-underscore prefix = internal use only:
    - stored in the database
    - excluded from API serialization
    - not directly writable from outside
    """

    def __init_subclass__(
        cls,
        polymorphic_on: str | None = None,
        polymorphic_abstract: bool | None = None,
        **kwargs
    ):
        """
        Configure polymorphic settings automatically when a subclass is defined.

        Args:
            polymorphic_on: name of the polymorphic_on column, defaulting to
                '_polymorphic_name'.  Pass another value to use a different
                discriminator column.
            polymorphic_abstract: abstractness flag.
                - None: auto-detect (default)
                - True: force abstract
                - False: force concrete
            **kwargs: forwarded to the parent hook
        """
        super().__init_subclass__(**kwargs)

        # Ensure the class owns its own __mapper_args__ dict.
        if '__mapper_args__' not in cls.__dict__:
            cls.__mapper_args__ = {}

        # Default the discriminator to _polymorphic_name.
        if 'polymorphic_on' not in cls.__mapper_args__:
            cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'

        # Auto-detect or apply the requested polymorphic_abstract flag.
        if 'polymorphic_abstract' not in cls.__mapper_args__:
            if polymorphic_abstract is None:
                # Abstract iff the class inherits ABC and still has
                # unimplemented abstract methods.
                has_abc = ABC in cls.__mro__
                has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
                polymorphic_abstract = has_abc and has_abstract_methods

            cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract

    @classmethod
    def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
        """
        Recursively collect every concrete (non-abstract) subclass of this class.

        Used by the selectin_polymorphic loading strategy to discover all
        concrete subclasses of a joined-table hierarchy.  Can be called on any
        polymorphic base class.

        :return: all concrete subclasses (classes mapped with
            polymorphic_abstract=True are excluded)
        """
        result: list[type[PolymorphicBaseMixin]] = []
        for subclass in cls.__subclasses__():
            # inspect() exposes the mapper's public attributes;
            # source-verified: mapper.polymorphic_abstract is public (mapper.py:811)
            mapper = inspect(subclass)
            if not mapper.polymorphic_abstract:
                result.append(subclass)
            # Recurse regardless of abstractness: an abstract class can still
            # have concrete descendants.
            if hasattr(subclass, 'get_concrete_subclasses'):
                result.extend(subclass.get_concrete_subclasses())
        return result

    @classmethod
    def get_polymorphic_discriminator(cls) -> str:
        """
        Return the name of the polymorphic discriminator column.

        Resolved through SQLAlchemy ``inspect`` on the mapper, so it also
        works when called on a subclass.

        :return: discriminator column name (e.g. '_polymorphic_name')
        :raises ValueError: if the class has no polymorphic_on configured
        """
        polymorphic_on = inspect(cls).polymorphic_on
        if polymorphic_on is None:
            raise ValueError(
                f"{cls.__name__} 未配置 polymorphic_on,"
                f"请确保正确继承 PolymorphicBaseMixin"
            )
        return polymorphic_on.key

    @classmethod
    def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
        """
        Map each polymorphic_identity to its concrete subclass.

        Includes concrete subclasses at every level of the hierarchy (e.g.
        both Function and ModelSwitchFunction appear).

        :return: identity -> subclass mapping
        """
        result: dict[str, type[PolymorphicBaseMixin]] = {}
        for subclass in cls.get_concrete_subclasses():
            identity = inspect(subclass).polymorphic_identity
            if identity:
                result[identity] = subclass
        return result
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,66 +0,0 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class OrderStatus(StrEnum):
    """Order status enumeration."""
    PENDING = "pending"
    """Awaiting payment."""
    COMPLETED = "completed"
    """Completed."""
    CANCELLED = "cancelled"
    """Cancelled."""
|
||||
|
||||
|
||||
class OrderType(StrEnum):
    """Order type enumeration."""
    # [TODO] add the concrete order types
    pass
|
||||
|
||||
|
||||
class Order(SQLModelBase, TableBaseMixin):
    """Order model: one purchase record belonging to a user."""

    order_no: str = Field(max_length=255, unique=True, index=True)
    """Order number, unique."""

    type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
    """Order type. [TODO] replace with an OrderType enum once its members are defined."""

    method: str | None = Field(default=None, max_length=255)
    """Payment method."""

    product_id: int | None = Field(default=None)
    """Product ID."""

    num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
    """Purchase quantity."""

    name: str = Field(max_length=255)
    """Product name."""

    price: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
    """Order price, in cents."""

    status: OrderStatus = Field(default=OrderStatus.PENDING)
    """Order status."""

    # Foreign keys
    user_id: UUID = Field(
        foreign_key="user.id",
        index=True,
        ondelete="CASCADE"
    )
    """Owning user's UUID."""

    # Relationships
    user: "User" = Relationship(back_populates="orders")
|
||||
@@ -1,23 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlmodel import Field, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
|
||||
class RedeemType(StrEnum):
    """Redemption-code type enumeration."""
    # [TODO] add the concrete redemption-code types
    pass
|
||||
|
||||
|
||||
class Redeem(SQLModelBase, TableBaseMixin):
    """Redemption-code model."""

    # Code type; [TODO] replace with the RedeemType enum once defined.
    type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
    """Redemption-code type. [TODO] enum pending."""
    # ID of the product/benefit this code grants.
    product_id: int | None = Field(default=None, description="关联的商品/权益ID")
    # Redeemable quantity / duration etc.
    num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
    # The unique, indexed redemption code string.
    code: str = Field(unique=True, index=True, description="兑换码,唯一")
    # Whether the code has already been consumed.
    used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")
|
||||
@@ -1,31 +0,0 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class StoragePack(SQLModelBase, TableBaseMixin):
    """Storage-capacity pack model owned by a user."""

    # Display name of the pack.
    name: str = Field(max_length=255, description="容量包名称")
    # Activation timestamp; None means not yet activated.
    active_time: datetime | None = Field(default=None, description="激活时间")
    # Expiry timestamp (indexed); None means no expiry recorded.
    expired_time: datetime | None = Field(default=None, index=True, description="过期时间")
    # Pack capacity in bytes.
    size: int = Field(description="容量包大小(字节)")

    # Foreign keys
    user_id: UUID = Field(
        foreign_key="user.id",
        index=True,
        ondelete="CASCADE"
    )
    """Owning user's UUID."""

    # Relationships
    user: "User" = Relationship(back_populates="storage_packs")
|
||||
@@ -1,61 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Column, Text
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class AuthnResponse(SQLModelBase):
    """WebAuthn response DTO."""

    id: str
    """Credential ID."""

    fingerprint: str
    """Credential fingerprint."""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class UserAuthn(SQLModelBase, TableBaseMixin):
    """User WebAuthn credential model; many-to-one relationship with User."""

    credential_id: str = Field(max_length=255, unique=True, index=True)
    """Credential ID, Base64-encoded."""

    credential_public_key: str = Field(sa_column=Column(Text))
    """Credential public key, Base64-encoded."""

    sign_count: int = Field(default=0, ge=0)
    """Signature counter, used to defend against replay attacks."""

    credential_device_type: str = Field(max_length=32)
    """Credential device type: 'single_device' or 'multi_device'."""

    credential_backed_up: bool = Field(default=False)
    """Whether the credential has been backed up."""

    transports: str | None = Field(default=None, max_length=255)
    """Supported transports, comma-separated, e.g. 'usb,nfc,ble,internal'."""

    name: str | None = Field(default=None, max_length=100)
    """User-chosen credential name for easy identification."""

    # Foreign keys
    user_id: UUID = Field(
        foreign_key="user.id",
        index=True,
        ondelete="CASCADE"
    )
    """Owning user's UUID."""

    # Relationships
    user: "User" = Relationship(back_populates="authns")
|
||||
@@ -1,33 +0,0 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class WebDAV(SQLModelBase, TableBaseMixin):
    """WebDAV account model."""

    # An account name must be unique per user, not globally.
    __table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)

    # WebDAV account name.
    name: str = Field(max_length=255, description="WebDAV账户名")
    # WebDAV password.
    # NOTE(review): stored as a plain string here — presumably hashed or
    # encrypted by the service layer; confirm before relying on it.
    password: str = Field(max_length=255, description="WebDAV密码")
    # Root directory path the account is jailed to.
    root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
    # Whether the account is read-only.
    readonly: bool = Field(default=False, description="是否只读")
    # Whether downloads go through the proxy.
    use_proxy: bool = Field(default=False, description="是否使用代理下载")

    # Foreign keys
    user_id: UUID = Field(
        foreign_key="user.id",
        index=True,
        ondelete="CASCADE"
    )
    """Owning user's UUID."""

    # Relationships
    user: "User" = Relationship(back_populates="webdavs")
|
||||
@@ -11,10 +11,13 @@ dependencies = [
|
||||
"argon2-cffi>=25.1.0",
|
||||
"asyncpg>=0.31.0",
|
||||
"cachetools>=6.2.4",
|
||||
"captcha>=0.7.1",
|
||||
"cryptography>=46.0.3",
|
||||
"fastapi[standard]>=0.122.0",
|
||||
"httpx>=0.27.0",
|
||||
"itsdangerous>=2.2.0",
|
||||
"loguru>=0.7.3",
|
||||
"orjson>=3.11.7",
|
||||
"pyjwt>=2.10.1",
|
||||
"pyotp>=2.9.0",
|
||||
"pytest>=9.0.2",
|
||||
@@ -26,8 +29,18 @@ dependencies = [
|
||||
"redis[hiredis]>=7.1.0",
|
||||
"sqlalchemy>=2.0.44",
|
||||
"sqlmodel>=0.0.27",
|
||||
"sqlmodel-ext[pgvector]>=0.1.1",
|
||||
"uvicorn>=0.38.0",
|
||||
"webauthn>=2.7.0",
|
||||
"whatthepatch>=1.0.6",
|
||||
"wsgidav>=4.3.0",
|
||||
"a2wsgi>=1.10.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
build = [
|
||||
"cython>=3.0.11",
|
||||
"setuptools>=75.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .api import router as api_router
|
||||
from .wopi import wopi_router
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(api_router)
|
||||
router.include_router(wopi_router)
|
||||
@@ -5,15 +5,16 @@ from utils.conf import appmeta
|
||||
from .admin import admin_router
|
||||
|
||||
from .callback import callback_router
|
||||
from .category import category_router
|
||||
from .directory import directory_router
|
||||
from .download import download_router
|
||||
from .file import router as file_router
|
||||
from .object import object_router
|
||||
from .share import share_router
|
||||
from .trash import trash_router
|
||||
from .site import site_router
|
||||
from .slave import slave_router
|
||||
from .user import user_router
|
||||
from .vas import vas_router
|
||||
from .webdav import webdav_router
|
||||
|
||||
router = APIRouter(prefix="/v1")
|
||||
@@ -23,14 +24,15 @@ router = APIRouter(prefix="/v1")
|
||||
if appmeta.mode == "master":
|
||||
router.include_router(admin_router)
|
||||
router.include_router(callback_router)
|
||||
router.include_router(category_router)
|
||||
router.include_router(directory_router)
|
||||
router.include_router(download_router)
|
||||
router.include_router(file_router)
|
||||
router.include_router(object_router)
|
||||
router.include_router(share_router)
|
||||
router.include_router(site_router)
|
||||
router.include_router(trash_router)
|
||||
router.include_router(user_router)
|
||||
router.include_router(vas_router)
|
||||
router.include_router(webdav_router)
|
||||
elif appmeta.mode == "slave":
|
||||
router.include_router(slave_router)
|
||||
|
||||
@@ -5,24 +5,31 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase,
|
||||
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
||||
)
|
||||
from models.base import SQLModelBase
|
||||
from models.setting import (
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from sqlmodels.setting import (
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
from models.setting import SettingsType
|
||||
from sqlmodels.setting import SettingsType
|
||||
from utils import http_exceptions
|
||||
from utils.conf import appmeta
|
||||
|
||||
try:
|
||||
from ee.service import get_cached_license
|
||||
except ImportError:
|
||||
get_cached_license = None
|
||||
|
||||
from .file import admin_file_router
|
||||
from .file_app import admin_file_app_router
|
||||
from .group import admin_group_router
|
||||
from .policy import admin_policy_router
|
||||
from .share import admin_share_router
|
||||
from .task import admin_task_router
|
||||
from .user import admin_user_router
|
||||
from .vas import admin_vas_router
|
||||
from .theme import admin_theme_router
|
||||
|
||||
|
||||
class Aria2TestRequest(SQLModelBase):
|
||||
@@ -43,10 +50,11 @@ admin_router = APIRouter(
|
||||
admin_router.include_router(admin_group_router)
|
||||
admin_router.include_router(admin_user_router)
|
||||
admin_router.include_router(admin_file_router)
|
||||
admin_router.include_router(admin_file_app_router)
|
||||
admin_router.include_router(admin_policy_router)
|
||||
admin_router.include_router(admin_share_router)
|
||||
admin_router.include_router(admin_task_router)
|
||||
admin_router.include_router(admin_vas_router)
|
||||
admin_router.include_router(admin_theme_router)
|
||||
|
||||
# 离线下载 /api/admin/aria2
|
||||
admin_aria2_router = APIRouter(
|
||||
@@ -155,14 +163,24 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
||||
if site_url_setting and site_url_setting.value:
|
||||
site_urls.append(site_url_setting.value)
|
||||
|
||||
# 许可证信息(从设置读取或使用默认值)
|
||||
license_info = LicenseInfo(
|
||||
expired_at=now + timedelta(days=365),
|
||||
signed_at=now,
|
||||
root_domains=[],
|
||||
domains=[],
|
||||
vol_domains=[],
|
||||
)
|
||||
# 许可证信息(Pro 版本从缓存读取,CE 版本永不过期)
|
||||
if appmeta.IsPro and get_cached_license:
|
||||
payload = get_cached_license()
|
||||
license_info = LicenseInfo(
|
||||
expired_at=payload.expires_at,
|
||||
signed_at=payload.issued_at,
|
||||
root_domains=[],
|
||||
domains=[payload.domain],
|
||||
vol_domains=[],
|
||||
)
|
||||
else:
|
||||
license_info = LicenseInfo(
|
||||
expired_at=datetime.max,
|
||||
signed_at=now,
|
||||
root_domains=[],
|
||||
domains=[],
|
||||
vol_domains=[],
|
||||
)
|
||||
|
||||
# 版本信息
|
||||
version_info = VersionInfo(
|
||||
@@ -221,11 +239,11 @@ async def router_admin_update_settings(
|
||||
|
||||
if existing:
|
||||
existing.value = item.value
|
||||
await existing.save(session)
|
||||
existing = await existing.save(session)
|
||||
updated_count += 1
|
||||
else:
|
||||
new_setting = Setting(type=item.type, name=item.name, value=item.value)
|
||||
await new_setting.save(session)
|
||||
new_setting = await new_setting.save(session)
|
||||
created_count += 1
|
||||
|
||||
l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")
|
||||
@@ -279,16 +297,17 @@ async def router_admin_get_settings(
|
||||
path='/test',
|
||||
summary='测试 Aria2 连接',
|
||||
description='Test Aria2 RPC connection',
|
||||
dependencies=[Depends(admin_required)]
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_aira2_test(
|
||||
request: Aria2TestRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
测试 Aria2 RPC 连接。
|
||||
|
||||
:param request: 测试请求
|
||||
:return: 测试结果
|
||||
:raises HTTPException: 连接失败时抛出 400
|
||||
"""
|
||||
import aiohttp
|
||||
|
||||
@@ -303,22 +322,18 @@ async def router_admin_aira2_test(
|
||||
async with aiohttp.ClientSession() as client:
|
||||
async with client.post(request.rpc_url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
|
||||
if resp.status != 200:
|
||||
return ResponseBase(
|
||||
code=400,
|
||||
msg=f"连接失败,HTTP {resp.status}"
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"连接失败,HTTP {resp.status}",
|
||||
)
|
||||
|
||||
result = await resp.json()
|
||||
if "error" in result:
|
||||
return ResponseBase(
|
||||
code=400,
|
||||
msg=f"Aria2 错误: {result['error']['message']}"
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Aria2 错误: {result['error']['message']}",
|
||||
)
|
||||
|
||||
version = result.get("result", {}).get("version", "unknown")
|
||||
return ResponseBase(data={
|
||||
"connected": True,
|
||||
"version": version,
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return ResponseBase(code=400, msg=f"连接失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")
|
||||
@@ -5,14 +5,60 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
Policy, PolicyType, User, ResponseBase, ListResponse,
|
||||
from sqlmodels import (
|
||||
JWTPayload, Policy, PolicyType, User, ListResponse,
|
||||
Object, ObjectType, AdminFileResponse, FileBanRequest, )
|
||||
from service.storage import LocalStorageService
|
||||
|
||||
async def _set_ban_recursive(
|
||||
session: AsyncSession,
|
||||
obj: Object,
|
||||
ban: bool,
|
||||
admin_id: UUID,
|
||||
reason: str | None,
|
||||
) -> int:
|
||||
"""
|
||||
递归设置封禁状态,返回受影响对象数量。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param obj: 要封禁/解禁的对象
|
||||
:param ban: True=封禁, False=解禁
|
||||
:param admin_id: 管理员UUID
|
||||
:param reason: 封禁原因
|
||||
:return: 受影响的对象数量
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# 如果是文件夹,先递归处理子对象
|
||||
if obj.is_folder:
|
||||
children = await Object.get(
|
||||
session,
|
||||
Object.parent_id == obj.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
for child in children:
|
||||
count += await _set_ban_recursive(session, child, ban, admin_id, reason)
|
||||
|
||||
# 设置当前对象
|
||||
obj.is_banned = ban
|
||||
if ban:
|
||||
obj.banned_at = datetime.now()
|
||||
obj.banned_by = admin_id
|
||||
obj.ban_reason = reason
|
||||
else:
|
||||
obj.banned_at = None
|
||||
obj.banned_by = None
|
||||
obj.ban_reason = None
|
||||
|
||||
obj = await obj.save(session)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
admin_file_router = APIRouter(
|
||||
prefix="/file",
|
||||
tags=["admin", "admin_file"],
|
||||
@@ -85,9 +131,7 @@ async def router_admin_preview_file(
|
||||
:param file_id: 文件UUID
|
||||
:return: 文件内容
|
||||
"""
|
||||
file_obj = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
file_obj = await Object.get_exist_one(session, file_id)
|
||||
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
@@ -118,45 +162,30 @@ async def router_admin_preview_file(
|
||||
path='/ban/{file_id}',
|
||||
summary='封禁/解禁文件',
|
||||
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_ban_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
request: FileBanRequest,
|
||||
admin: Annotated[User, Depends(admin_required)],
|
||||
) -> ResponseBase:
|
||||
claims: Annotated[JWTPayload, Depends(admin_required)],
|
||||
) -> None:
|
||||
"""
|
||||
封禁或解禁文件。封禁后用户无法访问该文件。
|
||||
封禁或解禁文件/文件夹。封禁后用户无法访问该文件。
|
||||
封禁文件夹时会级联封禁所有子对象。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param file_id: 文件UUID
|
||||
:param request: 封禁请求
|
||||
:param admin: 当前管理员
|
||||
:param claims: 当前管理员 JWT claims
|
||||
:return: 封禁结果
|
||||
"""
|
||||
file_obj = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
file_obj = await Object.get_exist_one(session, file_id)
|
||||
|
||||
file_obj.is_banned = request.is_banned
|
||||
if request.is_banned:
|
||||
file_obj.banned_at = datetime.now()
|
||||
file_obj.banned_by = admin.id
|
||||
file_obj.ban_reason = request.reason
|
||||
else:
|
||||
file_obj.banned_at = None
|
||||
file_obj.banned_by = None
|
||||
file_obj.ban_reason = None
|
||||
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
|
||||
|
||||
file_obj = await file_obj.save(session)
|
||||
|
||||
action = "封禁" if request.is_banned else "解禁"
|
||||
l.info(f"管理员{action}了文件: {file_obj.name}")
|
||||
return ResponseBase(data={
|
||||
"id": str(file_obj.id),
|
||||
"is_banned": file_obj.is_banned,
|
||||
})
|
||||
action = "封禁" if request.ban else "解禁"
|
||||
l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象")
|
||||
|
||||
|
||||
@admin_file_router.delete(
|
||||
@@ -164,12 +193,13 @@ async def router_admin_ban_file(
|
||||
summary='删除文件',
|
||||
description='Delete file by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
delete_physical: bool = True,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
删除文件。
|
||||
|
||||
@@ -178,9 +208,7 @@ async def router_admin_delete_file(
|
||||
:param delete_physical: 是否同时删除物理文件
|
||||
:return: 删除结果
|
||||
"""
|
||||
file_obj = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
file_obj = await Object.get_exist_one(session, file_id)
|
||||
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
@@ -212,4 +240,3 @@ async def router_admin_delete_file(
|
||||
await Object.delete(session, condition=Object.id == file_obj.id)
|
||||
|
||||
l.info(f"管理员删除了文件: {file_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
450
routers/api/v1/admin/file_app/__init__.py
Normal file
450
routers/api/v1/admin/file_app/__init__.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
管理员文件应用管理端点
|
||||
|
||||
提供文件查看器应用的 CRUD、扩展名管理、用户组权限管理和 WOPI Discovery。
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import select
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from service.wopi import parse_wopi_discovery_xml
|
||||
from sqlmodels import (
|
||||
FileApp,
|
||||
FileAppCreateRequest,
|
||||
FileAppExtension,
|
||||
FileAppGroupLink,
|
||||
FileAppListResponse,
|
||||
FileAppResponse,
|
||||
FileAppUpdateRequest,
|
||||
ExtensionUpdateRequest,
|
||||
GroupAccessUpdateRequest,
|
||||
WopiDiscoveredExtension,
|
||||
WopiDiscoveryResponse,
|
||||
)
|
||||
from sqlmodels.file_app import FileAppType
|
||||
from utils import http_exceptions
|
||||
|
||||
admin_file_app_router = APIRouter(
|
||||
prefix="/file-app",
|
||||
tags=["admin", "file_app"],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
|
||||
|
||||
@admin_file_app_router.get(
    path='/list',
    summary='列出所有文件应用',
)
async def list_file_apps(
    session: SessionDep,
    table_view: TableViewRequestDep,
) -> FileAppListResponse:
    """
    List all file-viewer apps (paginated).

    Auth: admin required (router-level dependency).

    :param session: database session
    :param table_view: pagination / sorting parameters
    :return: apps with their extensions and group links, plus a total count
    """
    result = await FileApp.get_with_count(
        session,
        table_view=table_view,
    )

    apps: list[FileAppResponse] = []
    for app in result.items:
        # Load the app's extension mappings.
        extensions = await FileAppExtension.get(
            session,
            FileAppExtension.app_id == app.id,
            fetch_mode="all",
        )
        # Load the user-group restriction links.
        group_links_result = await session.exec(
            select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
        )
        group_links: list[FileAppGroupLink] = list(group_links_result.all())
        apps.append(FileAppResponse.from_app(app, extensions, group_links))

    return FileAppListResponse(apps=apps, total=result.count)
|
||||
|
||||
|
||||
@admin_file_app_router.post(
    path='/',
    summary='创建文件应用',
    status_code=status.HTTP_201_CREATED,
)
async def create_file_app(
    session: SessionDep,
    request: FileAppCreateRequest,
) -> FileAppResponse:
    """
    Create a file app with its extension and group-access links.

    Auth: admin required.

    Errors:
    - 409: app_key already exists.
    """
    # Enforce app_key uniqueness before creating anything.
    existing = await FileApp.get(session, FileApp.app_key == request.app_key)
    if existing:
        http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")

    app = FileApp(
        name=request.name,
        app_key=request.app_key,
        type=request.type,
        icon=request.icon,
        description=request.description,
        is_enabled=request.is_enabled,
        is_restricted=request.is_restricted,
        iframe_url_template=request.iframe_url_template,
        wopi_discovery_url=request.wopi_discovery_url,
        wopi_editor_url_template=request.wopi_editor_url_template,
    )
    app = await app.save(session)
    app_id = app.id

    # Create extension links. Extensions are normalized (lowercase, no
    # leading dot) and de-duplicated while preserving request order;
    # duplicates or empty strings would otherwise create redundant or
    # conflicting rows.
    extensions: list[FileAppExtension] = []
    seen: set[str] = set()
    priority = 0
    for ext in request.extensions:
        normalized = ext.lower().strip().lstrip('.')
        if not normalized or normalized in seen:
            continue
        seen.add(normalized)
        ext_record = FileAppExtension(
            app_id=app_id,
            extension=normalized,
            priority=priority,
        )
        ext_record = await ext_record.save(session)
        extensions.append(ext_record)
        priority += 1

    # Create group links; duplicate ids are collapsed (order preserved).
    group_links: list[FileAppGroupLink] = []
    for group_id in dict.fromkeys(request.allowed_group_ids):
        link = FileAppGroupLink(app_id=app_id, group_id=group_id)
        session.add(link)
        group_links.append(link)
    if group_links:
        await session.commit()
        await session.refresh(app)

    l.info(f"创建文件应用: {app.name} ({app.app_key})")

    return FileAppResponse.from_app(app, extensions, group_links)
|
||||
|
||||
|
||||
@admin_file_app_router.get(
    path='/{app_id}',
    summary='获取文件应用详情',
)
async def get_file_app(
    session: SessionDep,
    app_id: UUID,
) -> FileAppResponse:
    """
    Fetch one file app together with its extensions and group links.

    Auth: admin required.

    Errors:
    - 404: app not found.
    """
    # Raises 404 when the app does not exist.
    app = await FileApp.get_exist_one(session, app_id)

    ext_records = await FileAppExtension.get(
        session,
        FileAppExtension.app_id == app.id,
        fetch_mode="all",
    )

    link_rows = await session.exec(
        select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
    )
    links: list[FileAppGroupLink] = list(link_rows.all())

    return FileAppResponse.from_app(app, ext_records, links)
|
||||
|
||||
|
||||
@admin_file_app_router.patch(
    path='/{app_id}',
    summary='更新文件应用',
)
async def update_file_app(
    session: SessionDep,
    app_id: UUID,
    request: FileAppUpdateRequest,
) -> FileAppResponse:
    """
    Partially update a file app; only fields the client sent are applied.

    Auth: admin required.

    Errors:
    - 404: app not found.
    - 409: the new app_key is already used by another app.
    """
    app = await FileApp.get_exist_one(session, app_id)

    # Reject an app_key that collides with a different app.
    new_key = request.app_key
    if new_key is not None and new_key != app.app_key:
        if await FileApp.get(session, FileApp.app_key == new_key):
            http_exceptions.raise_conflict(f"应用标识 '{new_key}' 已存在")

    # Apply only the fields explicitly present in the request payload.
    for field, value in request.model_dump(exclude_unset=True).items():
        setattr(app, field, value)
    app = await app.save(session)

    ext_records = await FileAppExtension.get(
        session,
        FileAppExtension.app_id == app.id,
        fetch_mode="all",
    )
    link_rows = await session.exec(
        select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
    )
    links: list[FileAppGroupLink] = list(link_rows.all())

    l.info(f"更新文件应用: {app.name} ({app.app_key})")

    return FileAppResponse.from_app(app, ext_records, links)
|
||||
|
||||
|
||||
@admin_file_app_router.delete(
    path='/{app_id}',
    summary='删除文件应用',
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_file_app(
    session: SessionDep,
    app_id: UUID,
) -> None:
    """
    Delete a file app (extensions, user preferences and group links are
    removed via cascade).

    Auth: admin required.

    Errors:
    - 404: app not found.
    """
    app = await FileApp.get_exist_one(session, app_id)

    # Capture the key before deletion invalidates the ORM object.
    key = app.app_key
    await FileApp.delete(session, app)
    l.info(f"删除文件应用: {key}")
|
||||
|
||||
|
||||
@admin_file_app_router.put(
    path='/{app_id}/extensions',
    summary='全量替换扩展名列表',
)
async def update_extensions(
    session: SessionDep,
    app_id: UUID,
    request: ExtensionUpdateRequest,
) -> FileAppResponse:
    """
    Replace the app's extension list wholesale.

    Old records are deleted first, then new ones are created in request
    order. `wopi_action_url` values previously filled in by WOPI
    Discovery are carried over to matching extensions so a manual edit
    does not wipe them.

    Auth: admin required.

    Errors:
    - 404: app not found.
    """
    app = await FileApp.get_exist_one(session, app_id)

    # Preserve Discovery-provided action URLs keyed by extension.
    old_extensions: list[FileAppExtension] = await FileAppExtension.get(
        session,
        FileAppExtension.app_id == app_id,
        fetch_mode="all",
    )
    old_url_map: dict[str, str] = {
        ext.extension: ext.wopi_action_url
        for ext in old_extensions
        if ext.wopi_action_url
    }
    for old_ext in old_extensions:
        await FileAppExtension.delete(session, old_ext, commit=False)
    # Flush deletes before inserting so re-used extension names do not
    # trip a uniqueness constraint within the same transaction.
    await session.flush()

    # Create new records. Extensions are normalized (lowercase, no
    # leading dot) and de-duplicated while preserving request order;
    # duplicates or empty strings would otherwise create conflicting rows.
    new_extensions: list[FileAppExtension] = []
    seen: set[str] = set()
    priority = 0
    for ext in request.extensions:
        normalized = ext.lower().strip().lstrip('.')
        if not normalized or normalized in seen:
            continue
        seen.add(normalized)
        ext_record = FileAppExtension(
            app_id=app_id,
            extension=normalized,
            priority=priority,
            wopi_action_url=old_url_map.get(normalized),
        )
        session.add(ext_record)
        new_extensions.append(ext_record)
        priority += 1

    await session.commit()
    # Objects expire on commit; refresh before reading attributes.
    await session.refresh(app)
    for ext_record in new_extensions:
        await session.refresh(ext_record)

    group_links_result = await session.exec(
        select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
    )
    group_links: list[FileAppGroupLink] = list(group_links_result.all())

    l.info(f"更新文件应用 {app.app_key} 的扩展名: {request.extensions}")

    return FileAppResponse.from_app(app, new_extensions, group_links)
|
||||
|
||||
|
||||
@admin_file_app_router.put(
    path='/{app_id}/groups',
    summary='全量替换允许的用户组',
)
async def update_group_access(
    session: SessionDep,
    app_id: UUID,
    request: GroupAccessUpdateRequest,
) -> FileAppResponse:
    """
    Replace the set of user groups allowed to use this app.

    Old links are removed first, then new ones are created. Duplicate
    group ids in the request are collapsed (order preserved) so they
    cannot create duplicate link rows.

    Auth: admin required.

    Errors:
    - 404: app not found.
    """
    app = await FileApp.get_exist_one(session, app_id)

    # Remove existing group links.
    old_links_result = await session.exec(
        select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
    )
    for old_link in old_links_result.all():
        await session.delete(old_link)

    # Create new links from the de-duplicated id list.
    new_links: list[FileAppGroupLink] = []
    for group_id in dict.fromkeys(request.group_ids):
        link = FileAppGroupLink(app_id=app_id, group_id=group_id)
        session.add(link)
        new_links.append(link)

    await session.commit()
    await session.refresh(app)

    extensions = await FileAppExtension.get(
        session,
        FileAppExtension.app_id == app_id,
        fetch_mode="all",
    )

    l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")

    return FileAppResponse.from_app(app, extensions, new_links)
|
||||
|
||||
|
||||
async def _fetch_discovery_xml(discovery_url: str) -> str:
    """Fetch the WOPI Discovery XML body, raising 502 on any failure."""
    try:
        async with aiohttp.ClientSession() as client:
            async with client.get(
                discovery_url,
                timeout=aiohttp.ClientTimeout(total=15),
            ) as resp:
                if resp.status != 200:
                    http_exceptions.raise_bad_gateway(
                        f"WOPI 服务端返回 HTTP {resp.status}"
                    )
                return await resp.text()
    except aiohttp.ClientError as e:
        # raise_bad_gateway raises; this line does not return.
        http_exceptions.raise_bad_gateway(f"无法连接 WOPI 服务端: {e}")


@admin_file_app_router.post(
    path='/{app_id}/discover',
    summary='执行 WOPI Discovery',
)
async def discover_wopi(
    session: SessionDep,
    app_id: UUID,
) -> WopiDiscoveryResponse:
    """
    Fetch the WOPI server's Discovery XML and auto-configure the app's
    extensions and action URLs.

    Flow:
    1. Verify the FileApp exists and is of WOPI type.
    2. Fetch the Discovery XML from FileApp.wopi_discovery_url.
    3. Parse the XML into extension -> action-URL mappings.
    4. Replace all FileAppExtension records (with wopi_action_url set).

    Auth: admin required.

    Errors:
    - 404: app not found.
    - 400: not a WOPI app / discovery URL unset / XML parse failure.
    - 502: WOPI server unreachable or returned an invalid response.
    """
    app = await FileApp.get_exist_one(session, app_id)

    if app.type != FileAppType.WOPI:
        http_exceptions.raise_bad_request("仅 WOPI 类型应用支持自动发现")

    if not app.wopi_discovery_url:
        http_exceptions.raise_bad_request("未配置 WOPI Discovery URL")

    # Objects expire after commit; grab the values needed later now.
    discovery_url = app.wopi_discovery_url
    app_key = app.app_key

    xml_content = await _fetch_discovery_xml(discovery_url)

    # Parse the XML; parser failures map to 400.
    try:
        action_urls, app_names = parse_wopi_discovery_xml(xml_content)
    except ValueError as e:
        http_exceptions.raise_bad_request(str(e))

    # Nothing discovered: report names only, leave existing records alone.
    if not action_urls:
        return WopiDiscoveryResponse(app_names=app_names)

    # Replace all extension records with the discovered set.
    old_extensions: list[FileAppExtension] = await FileAppExtension.get(
        session,
        FileAppExtension.app_id == app_id,
        fetch_mode="all",
    )
    for old_ext in old_extensions:
        await FileAppExtension.delete(session, old_ext, commit=False)
    # Flush deletes before inserting to avoid same-transaction conflicts.
    await session.flush()

    new_extensions: list[FileAppExtension] = []
    discovered: list[WopiDiscoveredExtension] = []
    # Sorted for a stable, deterministic priority order.
    for i, (ext, action_url) in enumerate(sorted(action_urls.items())):
        ext_record = FileAppExtension(
            app_id=app_id,
            extension=ext,
            priority=i,
            wopi_action_url=action_url,
        )
        session.add(ext_record)
        new_extensions.append(ext_record)
        discovered.append(WopiDiscoveredExtension(extension=ext, action_url=action_url))

    await session.commit()

    l.info(
        f"WOPI Discovery 完成: 应用 {app_key}, "
        f"发现 {len(discovered)} 个扩展名"
    )

    return WopiDiscoveryResponse(
        discovered_extensions=discovered,
        app_names=app_names,
        applied_count=len(discovered),
    )
|
||||
@@ -5,12 +5,12 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, GroupOptions, )
|
||||
from models.group import (
|
||||
from sqlmodels.group import (
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, )
|
||||
from models.policy import GroupPolicyLink
|
||||
from sqlmodels.policy import GroupPolicyLink
|
||||
|
||||
admin_group_router = APIRouter(
|
||||
prefix="/group",
|
||||
@@ -55,7 +55,7 @@ async def router_admin_get_groups(
|
||||
async def router_admin_get_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> GroupDetailResponse:
|
||||
"""
|
||||
根据用户组ID获取用户组详细信息。
|
||||
|
||||
@@ -63,17 +63,12 @@ async def router_admin_get_group(
|
||||
:param group_id: 用户组UUID
|
||||
:return: 用户组详情
|
||||
"""
|
||||
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
group = await Group.get_exist_one(session, group_id, load=[Group.options, Group.policies])
|
||||
|
||||
# 直接访问已加载的关系,无需额外查询
|
||||
policies = group.policies
|
||||
user_count = await User.count(session, User.group_id == group_id)
|
||||
response = GroupDetailResponse.from_group(group, user_count, policies)
|
||||
|
||||
return ResponseBase(data=response.model_dump())
|
||||
return GroupDetailResponse.from_group(group, user_count, policies)
|
||||
|
||||
|
||||
@admin_group_router.get(
|
||||
@@ -96,9 +91,7 @@ async def router_admin_get_group_members(
|
||||
:return: 分页成员列表
|
||||
"""
|
||||
# 验证组存在
|
||||
group = await Group.get(session, Group.id == group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
await Group.get_exist_one(session, group_id)
|
||||
|
||||
result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
|
||||
|
||||
@@ -113,11 +106,12 @@ async def router_admin_get_group_members(
|
||||
summary='创建用户组',
|
||||
description='Create a new user group',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_create_group(
|
||||
session: SessionDep,
|
||||
request: GroupCreateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
创建新的用户组。
|
||||
|
||||
@@ -139,10 +133,11 @@ async def router_admin_create_group(
|
||||
speed_limit=request.speed_limit,
|
||||
)
|
||||
group = await group.save(session)
|
||||
group_id_val: UUID = group.id
|
||||
|
||||
# 创建选项
|
||||
options = GroupOptions(
|
||||
group_id=group.id,
|
||||
group_id=group_id_val,
|
||||
share_download=request.share_download,
|
||||
share_free=request.share_free,
|
||||
relocate=request.relocate,
|
||||
@@ -155,16 +150,15 @@ async def router_admin_create_group(
|
||||
aria2=request.aria2,
|
||||
redirected_source=request.redirected_source,
|
||||
)
|
||||
await options.save(session)
|
||||
options = await options.save(session)
|
||||
|
||||
# 关联存储策略
|
||||
for policy_id in request.policy_ids:
|
||||
link = GroupPolicyLink(group_id=group.id, policy_id=policy_id)
|
||||
link = GroupPolicyLink(group_id=group_id_val, policy_id=policy_id)
|
||||
session.add(link)
|
||||
await session.commit()
|
||||
|
||||
l.info(f"管理员创建了用户组: {group.name}")
|
||||
return ResponseBase(data={"id": str(group.id), "name": group.name})
|
||||
|
||||
|
||||
@admin_group_router.patch(
|
||||
@@ -172,12 +166,13 @@ async def router_admin_create_group(
|
||||
summary='更新用户组信息',
|
||||
description='Update user group information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_update_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
request: GroupUpdateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户组ID更新用户组信息。
|
||||
|
||||
@@ -186,9 +181,7 @@ async def router_admin_update_group(
|
||||
:param request: 更新请求
|
||||
:return: 更新结果
|
||||
"""
|
||||
group = await Group.get(session, Group.id == group_id, load=Group.options)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
group = await Group.get_exist_one(session, group_id, load=Group.options)
|
||||
|
||||
# 检查名称唯一性(如果要更新名称)
|
||||
if request.name and request.name != group.name:
|
||||
@@ -218,7 +211,7 @@ async def router_admin_update_group(
|
||||
if options_data:
|
||||
for key, value in options_data.items():
|
||||
setattr(group.options, key, value)
|
||||
await group.options.save(session)
|
||||
group.options = await group.options.save(session)
|
||||
|
||||
# 更新策略关联
|
||||
if request.policy_ids is not None:
|
||||
@@ -233,8 +226,7 @@ async def router_admin_update_group(
|
||||
session.add(link)
|
||||
await session.commit()
|
||||
|
||||
l.info(f"管理员更新了用户组: {group.name}")
|
||||
return ResponseBase(data={"id": str(group.id)})
|
||||
l.info(f"管理员更新了用户组: {group_id}")
|
||||
|
||||
|
||||
@admin_group_router.delete(
|
||||
@@ -242,11 +234,12 @@ async def router_admin_update_group(
|
||||
summary='删除用户组',
|
||||
description='Delete user group by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户组ID删除用户组。
|
||||
|
||||
@@ -256,9 +249,7 @@ async def router_admin_delete_group(
|
||||
:param group_id: 用户组UUID
|
||||
:return: 删除结果
|
||||
"""
|
||||
group = await Group.get(session, Group.id == group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
group = await Group.get_exist_one(session, group_id)
|
||||
|
||||
# 检查是否有用户属于该组
|
||||
user_count = await User.count(session, User.group_id == group_id)
|
||||
@@ -271,5 +262,4 @@ async def router_admin_delete_group(
|
||||
group_name = group.name
|
||||
await Group.delete(session, group)
|
||||
|
||||
l.info(f"管理员删除了用户组: {group_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
l.info(f"管理员删除了用户组: {group_id}")
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
@@ -6,17 +7,96 @@ from sqlmodel import Field
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||
ListResponse, Object, )
|
||||
from models.base import SQLModelBase
|
||||
from service.storage import DirectoryCreationError, LocalStorageService
|
||||
from sqlmodels import (
|
||||
Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
|
||||
PolicyUpdateRequest, ResponseBase, ListResponse, Object,
|
||||
)
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService
|
||||
|
||||
admin_policy_router = APIRouter(
|
||||
prefix='/policy',
|
||||
tags=['admin', 'admin_policy']
|
||||
)
|
||||
|
||||
|
||||
class PathTestResponse(SQLModelBase):
|
||||
"""路径测试响应"""
|
||||
|
||||
path: str
|
||||
"""解析后的路径"""
|
||||
|
||||
is_exists: bool
|
||||
"""路径是否存在"""
|
||||
|
||||
is_writable: bool
|
||||
"""路径是否可写"""
|
||||
|
||||
|
||||
class PolicyGroupInfo(SQLModelBase):
|
||||
"""策略关联的用户组信息"""
|
||||
|
||||
id: str
|
||||
"""用户组UUID"""
|
||||
|
||||
name: str
|
||||
"""用户组名称"""
|
||||
|
||||
|
||||
class PolicyDetailResponse(SQLModelBase):
|
||||
"""存储策略详情响应"""
|
||||
|
||||
id: str
|
||||
"""策略UUID"""
|
||||
|
||||
name: str
|
||||
"""策略名称"""
|
||||
|
||||
type: str
|
||||
"""策略类型"""
|
||||
|
||||
server: str | None
|
||||
"""服务器地址"""
|
||||
|
||||
bucket_name: str | None
|
||||
"""存储桶名称"""
|
||||
|
||||
is_private: bool
|
||||
"""是否私有"""
|
||||
|
||||
base_url: str | None
|
||||
"""基础URL"""
|
||||
|
||||
access_key: str | None
|
||||
"""Access Key"""
|
||||
|
||||
secret_key: str | None
|
||||
"""Secret Key"""
|
||||
|
||||
max_size: int
|
||||
"""最大文件尺寸"""
|
||||
|
||||
auto_rename: bool
|
||||
"""是否自动重命名"""
|
||||
|
||||
dir_name_rule: str | None
|
||||
"""目录命名规则"""
|
||||
|
||||
file_name_rule: str | None
|
||||
"""文件命名规则"""
|
||||
|
||||
is_origin_link_enable: bool
|
||||
"""是否启用外链"""
|
||||
|
||||
options: dict[str, Any] | None
|
||||
"""策略选项"""
|
||||
|
||||
groups: list[PolicyGroupInfo]
|
||||
"""关联的用户组"""
|
||||
|
||||
object_count: int
|
||||
"""使用此策略的对象数量"""
|
||||
|
||||
class PolicyTestPathRequest(SQLModelBase):
|
||||
"""测试本地路径请求 DTO"""
|
||||
|
||||
@@ -33,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
|
||||
secret: str
|
||||
"""从机通信密钥"""
|
||||
|
||||
class PolicyCreateRequest(PolicyBase):
|
||||
"""创建存储策略请求 DTO,继承 PolicyBase 中的所有字段"""
|
||||
pass
|
||||
class PolicyTestS3Request(SQLModelBase):
|
||||
"""测试 S3 连接请求 DTO"""
|
||||
|
||||
server: str = Field(max_length=255)
|
||||
"""S3 端点地址"""
|
||||
|
||||
bucket_name: str = Field(max_length=255)
|
||||
"""存储桶名称"""
|
||||
|
||||
access_key: str
|
||||
"""Access Key"""
|
||||
|
||||
secret_key: str
|
||||
"""Secret Key"""
|
||||
|
||||
s3_region: str = Field(default='us-east-1', max_length=64)
|
||||
"""S3 区域"""
|
||||
|
||||
s3_path_style: bool = False
|
||||
"""是否使用路径风格"""
|
||||
|
||||
|
||||
class PolicyTestS3Response(SQLModelBase):
|
||||
"""S3 连接测试响应"""
|
||||
|
||||
is_connected: bool
|
||||
"""连接是否成功"""
|
||||
|
||||
message: str
|
||||
"""测试结果消息"""
|
||||
|
||||
|
||||
# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
|
||||
|
||||
_OPTIONS_FIELDS: set[str] = {
|
||||
'token', 'file_type', 'mimetype', 'od_redirect',
|
||||
'chunk_size', 's3_path_style', 's3_region',
|
||||
}
|
||||
|
||||
|
||||
@admin_policy_router.get(
|
||||
path='/list',
|
||||
@@ -70,7 +186,7 @@ async def router_policy_list(
|
||||
)
|
||||
async def router_policy_test_path(
|
||||
request: PolicyTestPathRequest,
|
||||
) -> ResponseBase:
|
||||
) -> PathTestResponse:
|
||||
"""
|
||||
测试本地存储路径是否可用。
|
||||
|
||||
@@ -97,22 +213,23 @@ async def router_policy_test_path(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ResponseBase(data={
|
||||
"path": str(path),
|
||||
"exists": is_exists,
|
||||
"writable": is_writable,
|
||||
})
|
||||
return PathTestResponse(
|
||||
path=str(path),
|
||||
is_exists=is_exists,
|
||||
is_writable=is_writable,
|
||||
)
|
||||
|
||||
|
||||
@admin_policy_router.post(
|
||||
path='/test/slave',
|
||||
summary='测试从机通信',
|
||||
description='Test slave node communication',
|
||||
dependencies=[Depends(admin_required)]
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_policy_test_slave(
|
||||
request: PolicyTestSlaveRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
测试从机RPC通信。
|
||||
|
||||
@@ -129,25 +246,28 @@ async def router_policy_test_slave(
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
return ResponseBase(data={"connected": True})
|
||||
return
|
||||
else:
|
||||
return ResponseBase(
|
||||
code=400,
|
||||
msg=f"从机响应错误,HTTP {resp.status}"
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"从机响应错误,HTTP {resp.status}",
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return ResponseBase(code=400, msg=f"连接失败: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")
|
||||
|
||||
@admin_policy_router.post(
|
||||
path='/',
|
||||
summary='创建存储策略',
|
||||
description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
|
||||
dependencies=[Depends(admin_required)]
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_policy_add_policy(
|
||||
session: SessionDep,
|
||||
request: PolicyCreateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
创建存储策略端点
|
||||
|
||||
@@ -201,12 +321,18 @@ async def router_policy_add_policy(
|
||||
# 保存到数据库
|
||||
policy = await policy.save(session)
|
||||
|
||||
return ResponseBase(data={
|
||||
"id": str(policy.id),
|
||||
"name": policy.name,
|
||||
"type": policy.type.value,
|
||||
"server": policy.server,
|
||||
})
|
||||
# 创建策略选项
|
||||
options = PolicyOptions(
|
||||
policy_id=policy.id,
|
||||
token=request.token,
|
||||
file_type=request.file_type,
|
||||
mimetype=request.mimetype,
|
||||
od_redirect=request.od_redirect,
|
||||
chunk_size=request.chunk_size,
|
||||
s3_path_style=request.s3_path_style,
|
||||
s3_region=request.s3_region,
|
||||
)
|
||||
options = await options.save(session)
|
||||
|
||||
@admin_policy_router.post(
|
||||
path='/cors',
|
||||
@@ -257,9 +383,7 @@ async def router_policy_onddrive_oauth(
|
||||
:param policy_id: 存储策略UUID
|
||||
:return: OAuth URL
|
||||
"""
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id)
|
||||
|
||||
# TODO: 实现OneDrive OAuth
|
||||
raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
|
||||
@@ -274,7 +398,7 @@ async def router_policy_onddrive_oauth(
|
||||
async def router_policy_get_policy(
|
||||
session: SessionDep,
|
||||
policy_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> PolicyDetailResponse:
|
||||
"""
|
||||
获取存储策略详情。
|
||||
|
||||
@@ -282,9 +406,7 @@ async def router_policy_get_policy(
|
||||
:param policy_id: 存储策略UUID
|
||||
:return: 策略详情
|
||||
"""
|
||||
policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
|
||||
|
||||
# 获取使用此策略的用户组
|
||||
groups = await policy.awaitable_attrs.groups
|
||||
@@ -292,35 +414,38 @@ async def router_policy_get_policy(
|
||||
# 统计使用此策略的对象数量
|
||||
object_count = await Object.count(session, Object.policy_id == policy_id)
|
||||
|
||||
return ResponseBase(data={
|
||||
"id": str(policy.id),
|
||||
"name": policy.name,
|
||||
"type": policy.type.value,
|
||||
"server": policy.server,
|
||||
"bucket_name": policy.bucket_name,
|
||||
"is_private": policy.is_private,
|
||||
"base_url": policy.base_url,
|
||||
"max_size": policy.max_size,
|
||||
"auto_rename": policy.auto_rename,
|
||||
"dir_name_rule": policy.dir_name_rule,
|
||||
"file_name_rule": policy.file_name_rule,
|
||||
"is_origin_link_enable": policy.is_origin_link_enable,
|
||||
"options": policy.options.model_dump() if policy.options else None,
|
||||
"groups": [{"id": str(g.id), "name": g.name} for g in groups],
|
||||
"object_count": object_count,
|
||||
})
|
||||
return PolicyDetailResponse(
|
||||
id=str(policy.id),
|
||||
name=policy.name,
|
||||
type=policy.type.value,
|
||||
server=policy.server,
|
||||
bucket_name=policy.bucket_name,
|
||||
is_private=policy.is_private,
|
||||
base_url=policy.base_url,
|
||||
access_key=policy.access_key,
|
||||
secret_key=policy.secret_key,
|
||||
max_size=policy.max_size,
|
||||
auto_rename=policy.auto_rename,
|
||||
dir_name_rule=policy.dir_name_rule,
|
||||
file_name_rule=policy.file_name_rule,
|
||||
is_origin_link_enable=policy.is_origin_link_enable,
|
||||
options=policy.options.model_dump() if policy.options else None,
|
||||
groups=[PolicyGroupInfo(id=str(g.id), name=g.name) for g in groups],
|
||||
object_count=object_count,
|
||||
)
|
||||
|
||||
|
||||
@admin_policy_router.delete(
|
||||
path='/{policy_id}',
|
||||
summary='删除存储策略',
|
||||
description='Delete storage policy by ID',
|
||||
dependencies=[Depends(admin_required)]
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_policy_delete_policy(
|
||||
session: SessionDep,
|
||||
policy_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
删除存储策略。
|
||||
|
||||
@@ -330,9 +455,7 @@ async def router_policy_delete_policy(
|
||||
:param policy_id: 存储策略UUID
|
||||
:return: 删除结果
|
||||
"""
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
policy = await Policy.get_exist_one(session, policy_id)
|
||||
|
||||
# 检查是否有文件使用此策略
|
||||
file_count = await Object.count(session, Object.policy_id == policy_id)
|
||||
@@ -346,4 +469,105 @@ async def router_policy_delete_policy(
|
||||
await Policy.delete(session, policy)
|
||||
|
||||
l.info(f"管理员删除了存储策略: {policy_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
|
||||
|
||||
@admin_policy_router.patch(
|
||||
path='/{policy_id}',
|
||||
summary='更新存储策略',
|
||||
description='更新存储策略配置。策略类型创建后不可更改。',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_policy_update_policy(
|
||||
session: SessionDep,
|
||||
policy_id: UUID,
|
||||
request: PolicyUpdateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
更新存储策略端点
|
||||
|
||||
功能:
|
||||
- 更新策略基础字段和扩展选项
|
||||
- 策略类型(type)不可更改
|
||||
|
||||
认证:
|
||||
- 需要管理员权限
|
||||
|
||||
:param session: 数据库会话
|
||||
:param policy_id: 存储策略UUID
|
||||
:param request: 更新请求
|
||||
"""
|
||||
policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
|
||||
|
||||
# 检查名称唯一性(如果要更新名称)
|
||||
if request.name and request.name != policy.name:
|
||||
existing = await Policy.get(session, Policy.name == request.name)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="策略名称已存在")
|
||||
|
||||
# 分离 Policy 字段和 Options 字段
|
||||
all_data = request.model_dump(exclude_unset=True)
|
||||
policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
|
||||
options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
|
||||
|
||||
# 更新 Policy 基础字段
|
||||
if policy_data:
|
||||
for key, value in policy_data.items():
|
||||
setattr(policy, key, value)
|
||||
policy = await policy.save(session)
|
||||
|
||||
# 更新或创建 PolicyOptions
|
||||
if options_data:
|
||||
if policy.options:
|
||||
for key, value in options_data.items():
|
||||
setattr(policy.options, key, value)
|
||||
policy.options = await policy.options.save(session)
|
||||
else:
|
||||
options = PolicyOptions(policy_id=policy.id, **options_data)
|
||||
options = await options.save(session)
|
||||
|
||||
l.info(f"管理员更新了存储策略: {policy_id}")
|
||||
|
||||
|
||||
@admin_policy_router.post(
|
||||
path='/test/s3',
|
||||
summary='测试 S3 连接',
|
||||
description='测试 S3 存储端点的连通性和凭据有效性。',
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_policy_test_s3(
|
||||
request: PolicyTestS3Request,
|
||||
) -> PolicyTestS3Response:
|
||||
"""
|
||||
测试 S3 连接端点
|
||||
|
||||
通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
|
||||
|
||||
:param request: 测试请求
|
||||
:return: 测试结果
|
||||
"""
|
||||
from service.storage import S3APIError
|
||||
|
||||
# 构造临时 Policy 对象用于创建 S3StorageService
|
||||
temp_policy = Policy(
|
||||
name="__test__",
|
||||
type=PolicyType.S3,
|
||||
server=request.server,
|
||||
bucket_name=request.bucket_name,
|
||||
access_key=request.access_key,
|
||||
secret_key=request.secret_key,
|
||||
)
|
||||
s3_service = S3StorageService(
|
||||
temp_policy,
|
||||
region=request.s3_region,
|
||||
is_path_style=request.s3_path_style,
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用 file_exists 发送 HEAD 请求来验证连通性
|
||||
await s3_service.file_exists("__connection_test__")
|
||||
return PolicyTestS3Response(is_connected=True, message="连接成功")
|
||||
except S3APIError as e:
|
||||
return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
|
||||
except Exception as e:
|
||||
return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
@@ -5,9 +6,54 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
ResponseBase, ListResponse,
|
||||
Share, AdminShareListItem, )
|
||||
from sqlmodels import (
|
||||
ListResponse,
|
||||
Share, AdminShareListItem,
|
||||
)
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class ShareDetailResponse(SQLModelBase):
|
||||
"""分享详情响应"""
|
||||
|
||||
id: UUID
|
||||
"""分享UUID"""
|
||||
|
||||
code: str
|
||||
"""分享码"""
|
||||
|
||||
views: int
|
||||
"""浏览次数"""
|
||||
|
||||
downloads: int
|
||||
"""下载次数"""
|
||||
|
||||
remain_downloads: int | None
|
||||
"""剩余下载次数"""
|
||||
|
||||
expires: datetime | None
|
||||
"""过期时间"""
|
||||
|
||||
preview_enabled: bool
|
||||
"""是否启用预览"""
|
||||
|
||||
score: int
|
||||
"""评分"""
|
||||
|
||||
has_password: bool
|
||||
"""是否有密码"""
|
||||
|
||||
user_id: str
|
||||
"""用户UUID"""
|
||||
|
||||
username: str | None
|
||||
"""用户名"""
|
||||
|
||||
object: dict | None
|
||||
"""关联对象信息"""
|
||||
|
||||
created_at: str
|
||||
"""创建时间"""
|
||||
|
||||
admin_share_router = APIRouter(
|
||||
prefix='/share',
|
||||
@@ -53,8 +99,8 @@ async def router_admin_get_share_list(
|
||||
)
|
||||
async def router_admin_get_share(
|
||||
session: SessionDep,
|
||||
share_id: int,
|
||||
) -> ResponseBase:
|
||||
share_id: UUID,
|
||||
) -> ShareDetailResponse:
|
||||
"""
|
||||
获取分享详情。
|
||||
|
||||
@@ -69,38 +115,39 @@ async def router_admin_get_share(
|
||||
obj = await share.awaitable_attrs.object
|
||||
user = await share.awaitable_attrs.user
|
||||
|
||||
return ResponseBase(data={
|
||||
"id": share.id,
|
||||
"code": share.code,
|
||||
"views": share.views,
|
||||
"downloads": share.downloads,
|
||||
"remain_downloads": share.remain_downloads,
|
||||
"expires": share.expires.isoformat() if share.expires else None,
|
||||
"preview_enabled": share.preview_enabled,
|
||||
"score": share.score,
|
||||
"has_password": bool(share.password),
|
||||
"user_id": str(share.user_id),
|
||||
"username": user.username if user else None,
|
||||
"object": {
|
||||
return ShareDetailResponse(
|
||||
id=share.id,
|
||||
code=share.code,
|
||||
views=share.views,
|
||||
downloads=share.downloads,
|
||||
remain_downloads=share.remain_downloads,
|
||||
expires=share.expires,
|
||||
preview_enabled=share.preview_enabled,
|
||||
score=share.score,
|
||||
has_password=bool(share.password),
|
||||
user_id=str(share.user_id),
|
||||
username=user.email if user else None,
|
||||
object={
|
||||
"id": str(obj.id),
|
||||
"name": obj.name,
|
||||
"type": obj.type.value,
|
||||
"size": obj.size,
|
||||
} if obj else None,
|
||||
"created_at": share.created_at.isoformat(),
|
||||
})
|
||||
created_at=share.created_at.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@admin_share_router.delete(
|
||||
path='/{share_id}',
|
||||
summary='删除分享',
|
||||
description='Delete share by ID',
|
||||
dependencies=[Depends(admin_required)]
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_share(
|
||||
session: SessionDep,
|
||||
share_id: int,
|
||||
) -> ResponseBase:
|
||||
share_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
删除分享。
|
||||
|
||||
@@ -108,11 +155,8 @@ async def router_admin_delete_share(
|
||||
:param share_id: 分享ID
|
||||
:return: 删除结果
|
||||
"""
|
||||
share = await Share.get(session, Share.id == share_id)
|
||||
if not share:
|
||||
raise HTTPException(status_code=404, detail="分享不存在")
|
||||
share = await Share.get_exist_one(session, share_id)
|
||||
|
||||
await Share.delete(session, share)
|
||||
|
||||
l.info(f"管理员删除了分享: {share.code}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
@@ -5,10 +6,45 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
ResponseBase, ListResponse,
|
||||
Task, TaskSummary,
|
||||
from sqlmodels import (
|
||||
ListResponse,
|
||||
Task, TaskSummary, TaskStatus, TaskType,
|
||||
)
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class TaskDetailResponse(SQLModelBase):
|
||||
"""任务详情响应"""
|
||||
|
||||
id: int
|
||||
"""任务ID"""
|
||||
|
||||
status: TaskStatus
|
||||
"""任务状态"""
|
||||
|
||||
type: TaskType
|
||||
"""任务类型"""
|
||||
|
||||
progress: int
|
||||
"""任务进度"""
|
||||
|
||||
error: str | None
|
||||
"""错误信息"""
|
||||
|
||||
user_id: str
|
||||
"""用户UUID"""
|
||||
|
||||
username: str | None
|
||||
"""用户名"""
|
||||
|
||||
props: dict[str, Any] | None
|
||||
"""任务属性"""
|
||||
|
||||
created_at: str
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: str
|
||||
"""更新时间"""
|
||||
|
||||
admin_task_router = APIRouter(
|
||||
prefix='/task',
|
||||
@@ -67,7 +103,7 @@ async def router_admin_get_task_list(
|
||||
async def router_admin_get_task(
|
||||
session: SessionDep,
|
||||
task_id: int,
|
||||
) -> ResponseBase:
|
||||
) -> TaskDetailResponse:
|
||||
"""
|
||||
获取任务详情。
|
||||
|
||||
@@ -82,30 +118,31 @@ async def router_admin_get_task(
|
||||
user = await task.awaitable_attrs.user
|
||||
props = await task.awaitable_attrs.props
|
||||
|
||||
return ResponseBase(data={
|
||||
"id": task.id,
|
||||
"status": task.status,
|
||||
"type": task.type,
|
||||
"progress": task.progress,
|
||||
"error": task.error,
|
||||
"user_id": str(task.user_id),
|
||||
"username": user.username if user else None,
|
||||
"props": props.model_dump() if props else None,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
"updated_at": task.updated_at.isoformat(),
|
||||
})
|
||||
return TaskDetailResponse(
|
||||
id=task.id,
|
||||
status=task.status,
|
||||
type=task.type,
|
||||
progress=task.progress,
|
||||
error=task.error,
|
||||
user_id=str(task.user_id),
|
||||
username=user.email if user else None,
|
||||
props=props.model_dump() if props else None,
|
||||
created_at=task.created_at.isoformat(),
|
||||
updated_at=task.updated_at.isoformat(),
|
||||
)
|
||||
|
||||
|
||||
@admin_task_router.delete(
|
||||
path='/{task_id}',
|
||||
summary='删除任务',
|
||||
description='Delete task by ID',
|
||||
dependencies=[Depends(admin_required)]
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_task(
|
||||
session: SessionDep,
|
||||
task_id: int,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
删除任务。
|
||||
|
||||
@@ -113,11 +150,8 @@ async def router_admin_delete_task(
|
||||
:param task_id: 任务ID
|
||||
:return: 删除结果
|
||||
"""
|
||||
task = await Task.get(session, Task.id == task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail="任务不存在")
|
||||
task = await Task.get_exist_one(session, task_id)
|
||||
|
||||
await Task.delete(session, task)
|
||||
|
||||
l.info(f"管理员删除了任务: {task_id}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
187
routers/api/v1/admin/theme/__init__.py
Normal file
187
routers/api/v1/admin/theme/__init__.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import update as sql_update
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
ThemePreset,
|
||||
ThemePresetCreateRequest,
|
||||
ThemePresetUpdateRequest,
|
||||
ThemePresetResponse,
|
||||
ThemePresetListResponse,
|
||||
)
|
||||
from utils import http_exceptions
|
||||
|
||||
admin_theme_router = APIRouter(
|
||||
prefix="/theme",
|
||||
tags=["admin", "admin_theme"],
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
|
||||
|
||||
@admin_theme_router.get(
|
||||
path='/',
|
||||
summary='获取主题预设列表',
|
||||
)
|
||||
async def router_admin_theme_list(session: SessionDep) -> ThemePresetListResponse:
|
||||
"""
|
||||
获取所有主题预设列表
|
||||
|
||||
认证:需要管理员权限
|
||||
|
||||
响应:
|
||||
- ThemePresetListResponse: 包含所有主题预设的列表
|
||||
"""
|
||||
presets: list[ThemePreset] = await ThemePreset.get(session, fetch_mode="all")
|
||||
return ThemePresetListResponse(
|
||||
themes=[ThemePresetResponse.from_preset(p) for p in presets]
|
||||
)
|
||||
|
||||
|
||||
@admin_theme_router.post(
|
||||
path='/',
|
||||
summary='创建主题预设',
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def router_admin_theme_create(
|
||||
session: SessionDep,
|
||||
request: ThemePresetCreateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
创建新的主题预设
|
||||
|
||||
认证:需要管理员权限
|
||||
|
||||
请求体:
|
||||
- name: 预设名称(唯一)
|
||||
- colors: 颜色配置对象
|
||||
|
||||
错误处理:
|
||||
- 409: 名称已存在
|
||||
"""
|
||||
# 检查名称唯一性
|
||||
existing = await ThemePreset.get(session, ThemePreset.name == request.name)
|
||||
if existing:
|
||||
http_exceptions.raise_conflict(f"主题预设名称 '{request.name}' 已存在")
|
||||
|
||||
preset = ThemePreset(
|
||||
name=request.name,
|
||||
**request.colors.model_dump(),
|
||||
)
|
||||
preset = await preset.save(session)
|
||||
l.info(f"管理员创建了主题预设: {request.name}")
|
||||
|
||||
|
||||
@admin_theme_router.patch(
|
||||
path='/{preset_id}',
|
||||
summary='更新主题预设',
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def router_admin_theme_update(
|
||||
session: SessionDep,
|
||||
preset_id: UUID,
|
||||
request: ThemePresetUpdateRequest,
|
||||
) -> None:
|
||||
"""
|
||||
部分更新主题预设
|
||||
|
||||
认证:需要管理员权限
|
||||
|
||||
路径参数:
|
||||
- preset_id: 预设UUID
|
||||
|
||||
请求体(均可选):
|
||||
- name: 预设名称
|
||||
- colors: 颜色配置对象
|
||||
|
||||
错误处理:
|
||||
- 404: 预设不存在
|
||||
- 409: 名称已被其他预设使用
|
||||
"""
|
||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||
|
||||
# 检查名称唯一性(排除自身)
|
||||
if request.name is not None and request.name != preset.name:
|
||||
existing = await ThemePreset.get(session, ThemePreset.name == request.name)
|
||||
if existing:
|
||||
http_exceptions.raise_conflict(f"主题预设名称 '{request.name}' 已存在")
|
||||
preset.name = request.name
|
||||
|
||||
# 更新颜色字段
|
||||
if request.colors is not None:
|
||||
color_data = request.colors.model_dump()
|
||||
for key, value in color_data.items():
|
||||
setattr(preset, key, value)
|
||||
|
||||
preset = await preset.save(session)
|
||||
l.info(f"管理员更新了主题预设: {preset.name}")
|
||||
|
||||
|
||||
@admin_theme_router.delete(
|
||||
path='/{preset_id}',
|
||||
summary='删除主题预设',
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def router_admin_theme_delete(
|
||||
session: SessionDep,
|
||||
preset_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
删除主题预设
|
||||
|
||||
认证:需要管理员权限
|
||||
|
||||
路径参数:
|
||||
- preset_id: 预设UUID
|
||||
|
||||
错误处理:
|
||||
- 404: 预设不存在
|
||||
|
||||
副作用:
|
||||
- 关联用户的 theme_preset_id 会被数据库 SET NULL
|
||||
"""
|
||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||
|
||||
await preset.delete(session)
|
||||
l.info(f"管理员删除了主题预设: {preset.name}")
|
||||
|
||||
|
||||
@admin_theme_router.patch(
|
||||
path='/{preset_id}/default',
|
||||
summary='设为默认主题预设',
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def router_admin_theme_set_default(
|
||||
session: SessionDep,
|
||||
preset_id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
将指定预设设为默认主题
|
||||
|
||||
认证:需要管理员权限
|
||||
|
||||
路径参数:
|
||||
- preset_id: 预设UUID
|
||||
|
||||
错误处理:
|
||||
- 404: 预设不存在
|
||||
|
||||
逻辑:
|
||||
- 事务中先清除所有旧默认,再设新默认
|
||||
"""
|
||||
preset = await ThemePreset.get_exist_one(session, preset_id)
|
||||
|
||||
# 清除所有旧默认
|
||||
await session.execute(
|
||||
sql_update(ThemePreset)
|
||||
.where(ThemePreset.is_default == True) # noqa: E712
|
||||
.values(is_default=False)
|
||||
)
|
||||
|
||||
# 设新默认
|
||||
preset.is_default = True
|
||||
preset = await preset.save(session)
|
||||
l.info(f"管理员将主题预设 '{preset.name}' 设为默认")
|
||||
@@ -6,11 +6,15 @@ from sqlalchemy import func
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep
|
||||
from models import (
|
||||
from service.redis.user_ban_store import UserBanStore
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, Object, ObjectType, )
|
||||
from models.user import (
|
||||
UserAdminUpdateRequest, UserCalibrateResponse,
|
||||
Group, Object, ObjectType, Setting, SettingsType,
|
||||
BatchDeleteRequest,
|
||||
)
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.user import (
|
||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
|
||||
)
|
||||
from utils import Password, http_exceptions
|
||||
|
||||
@@ -26,19 +30,19 @@ admin_user_router = APIRouter(
|
||||
description='Get user information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase:
|
||||
async def router_admin_get_user(session: SessionDep, user_id: UUID) -> UserPublic:
|
||||
"""
|
||||
根据用户ID获取用户信息,包括用户名、邮箱、注册时间等。
|
||||
|
||||
Args:
|
||||
session(SessionDep): 数据库会话依赖项。
|
||||
user_id (int): 用户ID。
|
||||
user_id (UUID): 用户ID。
|
||||
|
||||
Returns:
|
||||
ResponseBase: 包含用户信息的响应模型。
|
||||
"""
|
||||
user = await User.get_exist_one(session, user_id)
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@admin_user_router.get(
|
||||
@@ -60,7 +64,7 @@ async def router_admin_get_users(
|
||||
:param filter_params: 用户筛选参数(用户组、用户名、昵称、状态)
|
||||
:return: 分页用户列表
|
||||
"""
|
||||
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view)
|
||||
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view, load=User.group)
|
||||
return ListResponse(
|
||||
items=[user.to_public() for user in result.items],
|
||||
count=result.count,
|
||||
@@ -75,22 +79,59 @@ async def router_admin_get_users(
|
||||
)
|
||||
async def router_admin_create_user(
|
||||
session: SessionDep,
|
||||
user: User,
|
||||
) -> ResponseBase:
|
||||
request: UserAdminCreateRequest,
|
||||
) -> UserPublic:
|
||||
"""
|
||||
创建一个新的用户,设置用户名、密码等信息。
|
||||
创建一个新的用户,设置邮箱、密码、用户组等信息。
|
||||
|
||||
Returns:
|
||||
ResponseBase: 包含创建结果的响应模型。
|
||||
管理员创建用户时,若提供了 email + password,
|
||||
会同时创建 AuthIdentity(provider=email_password)。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 创建用户请求 DTO
|
||||
:return: 创建结果
|
||||
"""
|
||||
existing_user = await User.get(session, User.username == user.username)
|
||||
if existing_user:
|
||||
return ResponseBase(
|
||||
code=400,
|
||||
msg="User with this username already exists."
|
||||
# 如果提供了邮箱,检查唯一性(User 表和 AuthIdentity 表)
|
||||
if request.email:
|
||||
existing_user = await User.get(session, User.email == request.email)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
existing_identity = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
& (AuthIdentity.identifier == request.email),
|
||||
)
|
||||
if existing_identity:
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被绑定")
|
||||
|
||||
# 验证用户组存在
|
||||
group = await Group.get(session, Group.id == request.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
||||
|
||||
user = User(
|
||||
email=request.email,
|
||||
nickname=request.nickname,
|
||||
group_id=request.group_id,
|
||||
status=request.status,
|
||||
)
|
||||
user = await user.save(session)
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
|
||||
# 如果提供了邮箱和密码,创建邮箱密码认证身份
|
||||
if request.email and request.password:
|
||||
identity = AuthIdentity(
|
||||
provider=AuthProviderType.EMAIL_PASSWORD,
|
||||
identifier=request.email,
|
||||
credential=Password.hash(request.password),
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=user.id,
|
||||
)
|
||||
identity = await identity.save(session)
|
||||
|
||||
user = await User.get(session, User.id == user.id, load=User.group)
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@admin_user_router.patch(
|
||||
@@ -98,12 +139,13 @@ async def router_admin_create_user(
|
||||
summary='更新用户信息',
|
||||
description='Update user information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204
|
||||
)
|
||||
async def router_admin_update_user(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
request: UserAdminUpdateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户ID更新用户信息。
|
||||
|
||||
@@ -112,12 +154,17 @@ async def router_admin_update_user(
|
||||
:param request: 更新请求
|
||||
:return: 更新结果
|
||||
"""
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
user = await User.get_exist_one(session, user_id)
|
||||
|
||||
# 默认管理员(用户名为 admin)不允许更改用户组
|
||||
if request.group_id and user.username == "admin" and request.group_id != user.group_id:
|
||||
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
||||
default_admin_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||
)
|
||||
if (request.group_id
|
||||
and default_admin_setting
|
||||
and default_admin_setting.value == str(user_id)
|
||||
and request.group_id != user.group_id):
|
||||
http_exceptions.raise_forbidden("默认管理员不允许更改用户组")
|
||||
|
||||
# 如果更新用户组,验证新组存在
|
||||
@@ -126,55 +173,60 @@ async def router_admin_update_user(
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
||||
|
||||
# 如果更新密码,需要加密
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if 'password' in update_data and update_data['password']:
|
||||
update_data['password'] = Password.hash(update_data['password'])
|
||||
elif 'password' in update_data:
|
||||
del update_data['password'] # 空密码不更新
|
||||
|
||||
# 验证两步验证密钥格式(如果提供了值且不为 None,长度必须为 32)
|
||||
if 'two_factor' in update_data and update_data['two_factor'] is not None:
|
||||
if len(update_data['two_factor']) != 32:
|
||||
raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")
|
||||
# 记录旧 status 以便检测变更
|
||||
old_status = user.status
|
||||
|
||||
# 更新字段
|
||||
for key, value in update_data.items():
|
||||
setattr(user, key, value)
|
||||
user = await user.save(session)
|
||||
|
||||
l.info(f"管理员更新了用户: {user.username}")
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
# 封禁状态变更 → 更新 BanStore
|
||||
new_status = user.status
|
||||
if old_status == UserStatus.ACTIVE and new_status != UserStatus.ACTIVE:
|
||||
await UserBanStore.ban(str(user_id))
|
||||
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
|
||||
await UserBanStore.unban(str(user_id))
|
||||
|
||||
l.info(f"管理员更新了用户: {user.email}")
|
||||
|
||||
|
||||
@admin_user_router.delete(
|
||||
path='/{user_id}',
|
||||
summary='删除用户',
|
||||
description='Delete user by ID',
|
||||
path='/',
|
||||
summary='删除用户(支持批量)',
|
||||
description='Delete users by ID list',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_user(
|
||||
async def router_admin_delete_users(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
) -> ResponseBase:
|
||||
request: BatchDeleteRequest,
|
||||
) -> None:
|
||||
"""
|
||||
根据用户ID删除用户及其所有数据。
|
||||
批量删除用户及其所有数据。
|
||||
|
||||
注意: 这是一个危险操作,会级联删除用户的所有文件、分享、任务等。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:return: 删除结果
|
||||
:param request: 批量删除请求,包含待删除用户的 UUID 列表
|
||||
:return: 删除结果(已删除数 / 总请求数)
|
||||
"""
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
for uid in request.ids:
|
||||
user = await User.get(session, User.id == uid, load=User.group)
|
||||
|
||||
username = user.username
|
||||
await User.delete(session, user)
|
||||
# 安全检查:默认管理员不允许被删除(通过 Setting 中的 default_admin_id 识别)
|
||||
default_admin_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||
)
|
||||
if user and default_admin_setting and default_admin_setting.value == str(uid):
|
||||
raise HTTPException(status_code=403, detail=f"默认管理员不允许被删除")
|
||||
|
||||
l.info(f"管理员删除了用户: {username}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
if user:
|
||||
await User.delete(session, user)
|
||||
l.info(f"管理员删除了用户: {user.email}")
|
||||
|
||||
|
||||
@admin_user_router.post(
|
||||
@@ -186,7 +238,7 @@ async def router_admin_delete_user(
|
||||
async def router_admin_calibrate_storage(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> UserCalibrateResponse:
|
||||
"""
|
||||
重新计算用户的已用存储空间。
|
||||
|
||||
@@ -199,13 +251,12 @@ async def router_admin_calibrate_storage(
|
||||
:param user_id: 用户UUID
|
||||
:return: 校准结果
|
||||
"""
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
user = await User.get_exist_one(session, user_id)
|
||||
|
||||
previous_storage = user.storage
|
||||
|
||||
# 计算实际存储量 - 使用 SQL 聚合
|
||||
# [TODO] 不应这么计算,看看 SQLModel_Ext 库怎么解决
|
||||
from sqlmodel import select
|
||||
result = await session.execute(
|
||||
select(func.sum(Object.size), func.count(Object.id)).where(
|
||||
@@ -228,5 +279,5 @@ async def router_admin_calibrate_storage(
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
l.info(f"管理员校准了用户存储: {user.username}, 差值: {actual_storage - previous_storage}")
|
||||
return ResponseBase(data=response.model_dump())
|
||||
l.info(f"管理员校准了用户存储: {user.email}, 差值: {actual_storage - previous_storage}")
|
||||
return response
|
||||
@@ -1,81 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
ResponseBase,
|
||||
)
|
||||
|
||||
admin_vas_router = APIRouter(
|
||||
prefix='/vas',
|
||||
tags=['admin', 'admin_vas']
|
||||
)
|
||||
|
||||
@admin_vas_router.get(
|
||||
path='/list',
|
||||
summary='获取增值服务列表',
|
||||
description='Get VAS list (orders and storage packs)',
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_vas_list(
|
||||
session: SessionDep,
|
||||
user_id: UUID | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
获取增值服务列表(订单和存储包)。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 按用户筛选
|
||||
:param page: 页码
|
||||
:param page_size: 每页数量
|
||||
:return: 增值服务列表
|
||||
"""
|
||||
# TODO: 实现增值服务列表
|
||||
# 需要查询 Order 和 StoragePack 模型
|
||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||
|
||||
|
||||
@admin_vas_router.get(
|
||||
path='/{vas_id}',
|
||||
summary='获取增值服务详情',
|
||||
description='Get VAS detail by ID',
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_get_vas(
|
||||
session: SessionDep,
|
||||
vas_id: UUID,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
获取增值服务详情。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param vas_id: 增值服务UUID
|
||||
:return: 增值服务详情
|
||||
"""
|
||||
# TODO: 实现增值服务详情
|
||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||
|
||||
|
||||
@admin_vas_router.delete(
|
||||
path='/{vas_id}',
|
||||
summary='删除增值服务',
|
||||
description='Delete VAS by ID',
|
||||
dependencies=[Depends(admin_required)]
|
||||
)
|
||||
async def router_admin_delete_vas(
|
||||
session: SessionDep,
|
||||
vas_id: UUID,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
删除增值服务。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param vas_id: 增值服务UUID
|
||||
:return: 删除结果
|
||||
"""
|
||||
# TODO: 实现增值服务删除
|
||||
raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from loguru import logger as l
|
||||
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
import service.oauth
|
||||
from utils import http_exceptions
|
||||
|
||||
@@ -15,18 +16,12 @@ oauth_router = APIRouter(
|
||||
tags=["callback", "oauth"],
|
||||
)
|
||||
|
||||
pay_router = APIRouter(
|
||||
prefix='/callback/pay',
|
||||
tags=["callback", "pay"],
|
||||
)
|
||||
|
||||
upload_router = APIRouter(
|
||||
prefix='/callback/upload',
|
||||
tags=["callback", "upload"],
|
||||
)
|
||||
|
||||
callback_router.include_router(oauth_router)
|
||||
callback_router.include_router(pay_router)
|
||||
callback_router.include_router(upload_router)
|
||||
|
||||
@oauth_router.post(
|
||||
@@ -64,91 +59,17 @@ async def router_callback_github(
|
||||
"""
|
||||
try:
|
||||
access_token = await service.oauth.github.get_access_token(code)
|
||||
# [TODO] 把access_token写数据库里
|
||||
if not access_token:
|
||||
return PlainTextResponse("Failed to retrieve access token from GitHub.", status_code=400)
|
||||
return PlainTextResponse("GitHub 认证失败", status_code=400)
|
||||
|
||||
user_data = await service.oauth.github.get_user_info(access_token.access_token)
|
||||
# [TODO] 把user_data写数据库里
|
||||
# [TODO] 把 access_token 和 user_data 写数据库,生成 JWT,重定向到前端
|
||||
l.info(f"GitHub OAuth 回调成功: user={user_data.user_data.login}")
|
||||
|
||||
return PlainTextResponse(f"User information processed successfully, code: {code}, user_data: {user_data.json_dump()}", status_code=200)
|
||||
return PlainTextResponse("认证成功,功能开发中", status_code=200)
|
||||
except Exception as e:
|
||||
return PlainTextResponse(f"An error occurred: {str(e)}", status_code=500)
|
||||
|
||||
@pay_router.post(
|
||||
path='/alipay',
|
||||
summary='支付宝支付回调',
|
||||
description='Handle Alipay payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_alipay() -> ResponseBase:
|
||||
"""
|
||||
Handle Alipay payment callback and return payment status.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Alipay payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.post(
|
||||
path='/wechat',
|
||||
summary='微信支付回调',
|
||||
description='Handle WeChat Pay payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_wechat() -> ResponseBase:
|
||||
"""
|
||||
Handle WeChat Pay payment callback and return payment status.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the WeChat Pay payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.post(
|
||||
path='/stripe',
|
||||
summary='Stripe支付回调',
|
||||
description='Handle Stripe payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_stripe() -> ResponseBase:
|
||||
"""
|
||||
Handle Stripe payment callback and return payment status.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the Stripe payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@pay_router.get(
|
||||
path='/easypay',
|
||||
summary='易支付回调',
|
||||
description='Handle EasyPay payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_easypay() -> PlainTextResponse:
|
||||
"""
|
||||
Handle EasyPay payment callback and return payment status.
|
||||
|
||||
Returns:
|
||||
PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
# return PlainTextResponse("success", status_code=200)
|
||||
|
||||
@pay_router.get(
|
||||
path='/custom/{order_no}/{id}',
|
||||
summary='自定义支付回调',
|
||||
description='Handle custom payment callback and return payment status.',
|
||||
)
|
||||
def router_callback_custom(order_no: str, id: str) -> ResponseBase:
|
||||
"""
|
||||
Handle custom payment callback and return payment status.
|
||||
|
||||
Args:
|
||||
order_no (str): The order number for the payment.
|
||||
id (str): The ID associated with the payment.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the custom payment callback.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
l.error(f"GitHub OAuth 回调异常: {e}")
|
||||
return PlainTextResponse("认证过程中发生错误,请重试", status_code=500)
|
||||
|
||||
@upload_router.post(
|
||||
path='/remote/{session_id}/{key}',
|
||||
|
||||
100
routers/api/v1/category/__init__.py
Normal file
100
routers/api/v1/category/__init__.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
文件分类筛选端点
|
||||
|
||||
按文件类型分类(图片/视频/音频/文档)查询用户的所有文件,
|
||||
跨目录搜索,支持分页。扩展名映射从数据库 Setting 表读取。
|
||||
"""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from sqlmodels import (
|
||||
FileCategory,
|
||||
ListResponse,
|
||||
Object,
|
||||
ObjectResponse,
|
||||
ObjectType,
|
||||
Setting,
|
||||
SettingsType,
|
||||
User,
|
||||
)
|
||||
|
||||
category_router = APIRouter(
|
||||
prefix="/category",
|
||||
tags=["category"],
|
||||
)
|
||||
|
||||
|
||||
@category_router.get(
|
||||
path="/{category}",
|
||||
summary="按分类获取文件列表",
|
||||
)
|
||||
async def router_category_list(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
category: FileCategory,
|
||||
table_view: TableViewRequestDep,
|
||||
) -> ListResponse[ObjectResponse]:
|
||||
"""
|
||||
按文件类型分类查询用户的所有文件
|
||||
|
||||
跨所有目录搜索,返回分页结果。
|
||||
扩展名配置从数据库 Setting 表读取(type=file_category)。
|
||||
|
||||
认证:
|
||||
- JWT token in Authorization header
|
||||
|
||||
路径参数:
|
||||
- category: 文件分类(image / video / audio / document)
|
||||
|
||||
查询参数:
|
||||
- offset: 分页偏移量(默认0)
|
||||
- limit: 每页数量(默认20,最大100)
|
||||
- desc: 是否降序(默认true)
|
||||
- order: 排序字段(created_at / updated_at)
|
||||
|
||||
响应:
|
||||
- ListResponse[ObjectResponse]: 分页文件列表
|
||||
|
||||
错误处理:
|
||||
- HTTPException 422: category 参数无效
|
||||
- HTTPException 404: 该分类未配置扩展名
|
||||
"""
|
||||
# 从数据库读取该分类的扩展名配置
|
||||
setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.FILE_CATEGORY) & (Setting.name == category.value),
|
||||
)
|
||||
if not setting or not setting.value:
|
||||
raise HTTPException(status_code=404, detail=f"分类 {category.value} 未配置扩展名")
|
||||
|
||||
extensions = [ext.strip() for ext in setting.value.split(",") if ext.strip()]
|
||||
if not extensions:
|
||||
raise HTTPException(status_code=404, detail=f"分类 {category.value} 扩展名列表为空")
|
||||
|
||||
result = await Object.get_by_category(
|
||||
session,
|
||||
user.id,
|
||||
extensions,
|
||||
table_view=table_view,
|
||||
)
|
||||
|
||||
items = [
|
||||
ObjectResponse(
|
||||
id=obj.id,
|
||||
name=obj.name,
|
||||
type=ObjectType.FILE,
|
||||
size=obj.size,
|
||||
mime_type=obj.mime_type,
|
||||
thumb=False,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
source_enabled=False,
|
||||
)
|
||||
for obj in result.items
|
||||
]
|
||||
|
||||
return ListResponse(count=result.count, items=items)
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
DirectoryCreateRequest,
|
||||
DirectoryResponse,
|
||||
Object,
|
||||
@@ -14,12 +16,84 @@ from models import (
|
||||
User,
|
||||
ResponseBase,
|
||||
)
|
||||
from utils import http_exceptions
|
||||
|
||||
directory_router = APIRouter(
|
||||
prefix="/directory",
|
||||
tags=["directory"]
|
||||
)
|
||||
|
||||
|
||||
async def _get_directory_response(
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
folder: Object,
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
构建目录响应 DTO
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:param folder: 目录对象
|
||||
:return: DirectoryResponse
|
||||
"""
|
||||
children = await Object.get_children(session, user_id, folder.id)
|
||||
policy = await folder.awaitable_attrs.policy
|
||||
|
||||
objects = [
|
||||
ObjectResponse(
|
||||
id=child.id,
|
||||
name=child.name,
|
||||
thumb=False,
|
||||
size=child.size,
|
||||
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
|
||||
created_at=child.created_at,
|
||||
updated_at=child.updated_at,
|
||||
source_enabled=False,
|
||||
)
|
||||
for child in children
|
||||
]
|
||||
|
||||
policy_response = PolicyResponse(
|
||||
id=policy.id,
|
||||
name=policy.name,
|
||||
type=policy.type,
|
||||
max_size=policy.max_size,
|
||||
)
|
||||
|
||||
return DirectoryResponse(
|
||||
id=folder.id,
|
||||
parent=folder.parent_id,
|
||||
objects=objects,
|
||||
policy=policy_response,
|
||||
)
|
||||
|
||||
|
||||
@directory_router.get(
|
||||
path="/",
|
||||
summary="获取根目录内容",
|
||||
)
|
||||
async def router_directory_root(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取当前用户的根目录内容
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:return: 根目录内容
|
||||
"""
|
||||
root = await Object.get_root(session, user.id)
|
||||
if not root:
|
||||
raise HTTPException(status_code=404, detail="根目录不存在")
|
||||
|
||||
if root.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
return await _get_directory_response(session, user.id, root)
|
||||
|
||||
|
||||
@directory_router.get(
|
||||
path="/{path:path}",
|
||||
summary="获取目录内容",
|
||||
@@ -32,24 +106,23 @@ async def router_directory_get(
|
||||
"""
|
||||
获取目录内容
|
||||
|
||||
路径必须以用户名或 `.crash` 开头,如 /api/directory/admin 或 /api/directory/admin/docs
|
||||
`.crash` 代表回收站,也就意味着用户名禁止为 `.crash`
|
||||
路径从用户根目录开始,不包含用户名前缀。
|
||||
如 /api/v1/directory/docs 表示根目录下的 docs 目录。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param path: 目录路径(必须以用户名开头)
|
||||
:param path: 目录路径(从根目录开始的相对路径)
|
||||
:return: 目录内容
|
||||
"""
|
||||
# 路径必须以用户名开头
|
||||
path = path.strip("/")
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空,请使用 /{username} 格式")
|
||||
# 空路径交给根目录端点处理(理论上不会到达这里)
|
||||
root = await Object.get_root(session, user.id)
|
||||
if not root:
|
||||
raise HTTPException(status_code=404, detail="根目录不存在")
|
||||
return await _get_directory_response(session, user.id, root)
|
||||
|
||||
path_parts = path.split("/")
|
||||
if path_parts[0] != user.username:
|
||||
raise HTTPException(status_code=403, detail="无权访问其他用户的目录")
|
||||
|
||||
folder = await Object.get_by_path(session, user.id, "/" + path, user.username)
|
||||
folder = await Object.get_by_path(session, user.id, "/" + path)
|
||||
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="目录不存在")
|
||||
@@ -57,47 +130,22 @@ async def router_directory_get(
|
||||
if not folder.is_folder:
|
||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||
|
||||
children = await Object.get_children(session, user.id, folder.id)
|
||||
policy = await folder.awaitable_attrs.policy
|
||||
if folder.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
objects = [
|
||||
ObjectResponse(
|
||||
id=child.id,
|
||||
name=child.name,
|
||||
thumb=False,
|
||||
size=child.size,
|
||||
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
|
||||
date=child.updated_at,
|
||||
create_date=child.created_at,
|
||||
source_enabled=False,
|
||||
)
|
||||
for child in children
|
||||
]
|
||||
|
||||
policy_response = PolicyResponse(
|
||||
id=policy.id,
|
||||
name=policy.name,
|
||||
type=policy.type.value,
|
||||
max_size=policy.max_size,
|
||||
)
|
||||
|
||||
return DirectoryResponse(
|
||||
id=folder.id,
|
||||
parent=folder.parent_id,
|
||||
objects=objects,
|
||||
policy=policy_response,
|
||||
)
|
||||
return await _get_directory_response(session, user.id, folder)
|
||||
|
||||
|
||||
@directory_router.put(
|
||||
@directory_router.post(
|
||||
path="/",
|
||||
summary="创建目录",
|
||||
status_code=204,
|
||||
)
|
||||
async def router_directory_create(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: DirectoryCreateRequest
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
创建目录
|
||||
|
||||
@@ -115,25 +163,40 @@ async def router_directory_create(
|
||||
if "/" in name or "\\" in name:
|
||||
raise HTTPException(status_code=400, detail="目录名称不能包含斜杠")
|
||||
|
||||
# 通过 UUID 获取父目录
|
||||
parent = await Object.get(session, Object.id == request.parent_id)
|
||||
# 通过 UUID 获取父目录(排除已删除的)
|
||||
parent = await Object.get(
|
||||
session,
|
||||
(Object.id == request.parent_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not parent or parent.owner_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="父目录不存在")
|
||||
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父路径不是目录")
|
||||
|
||||
# 检查是否已存在同名对象
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名对象(仅检查未删除的)
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user.id) &
|
||||
(Object.parent_id == parent.id) &
|
||||
(Object.name == name)
|
||||
(Object.name == name) &
|
||||
(Object.deleted_at == None)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
|
||||
|
||||
policy_id = request.policy_id if request.policy_id else parent.policy_id
|
||||
|
||||
# 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
|
||||
if request.policy_id:
|
||||
group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
if request.policy_id not in {p.id for p in group.policies}:
|
||||
raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
|
||||
|
||||
parent_id = parent.id # 在 save 前保存
|
||||
|
||||
new_folder = Object(
|
||||
@@ -143,14 +206,4 @@ async def router_directory_create(
|
||||
parent_id=parent_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
new_folder_id = new_folder.id # 在 save 前保存 UUID
|
||||
new_folder_name = new_folder.name
|
||||
await new_folder.save(session)
|
||||
|
||||
return ResponseBase(
|
||||
data={
|
||||
"id": new_folder_id,
|
||||
"name": new_folder_name,
|
||||
"parent_id": parent_id,
|
||||
}
|
||||
)
|
||||
new_folder = await new_folder.save(session)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
download_router = APIRouter(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
106
routers/api/v1/file/viewers/__init__.py
Normal file
106
routers/api/v1/file/viewers/__init__.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
文件查看器查询端点
|
||||
|
||||
提供按文件扩展名查询可用查看器的功能,包含用户组访问控制过滤。
|
||||
"""
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import and_
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
FileApp,
|
||||
FileAppExtension,
|
||||
FileAppGroupLink,
|
||||
FileAppSummary,
|
||||
FileViewersResponse,
|
||||
User,
|
||||
UserFileAppDefault,
|
||||
)
|
||||
|
||||
viewers_router = APIRouter(prefix="/viewers", tags=["file", "viewers"])
|
||||
|
||||
|
||||
@viewers_router.get(
|
||||
path='',
|
||||
summary='查询可用文件查看器',
|
||||
description='根据文件扩展名查询可用的查看器应用列表。',
|
||||
)
|
||||
async def get_viewers(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
ext: Annotated[str, Query(max_length=20, description="文件扩展名")],
|
||||
) -> FileViewersResponse:
|
||||
"""
|
||||
查询可用文件查看器端点
|
||||
|
||||
流程:
|
||||
1. 规范化扩展名(小写,去点号)
|
||||
2. 查询匹配的已启用应用
|
||||
3. 按用户组权限过滤
|
||||
4. 按 priority 排序
|
||||
5. 查询用户默认偏好
|
||||
|
||||
认证:JWT token 必填
|
||||
|
||||
错误处理:
|
||||
- 401: 未授权
|
||||
"""
|
||||
# 规范化扩展名
|
||||
normalized_ext = ext.lower().strip().lstrip('.')
|
||||
|
||||
# 查询匹配扩展名的应用(已启用的)
|
||||
ext_records: list[FileAppExtension] = await FileAppExtension.get(
|
||||
session,
|
||||
and_(
|
||||
FileAppExtension.extension == normalized_ext,
|
||||
),
|
||||
fetch_mode="all",
|
||||
load=FileAppExtension.app,
|
||||
)
|
||||
|
||||
# 过滤和收集可用应用
|
||||
user_group_id = user.group_id
|
||||
viewers: list[tuple[FileAppSummary, int]] = []
|
||||
|
||||
for ext_record in ext_records:
|
||||
app: FileApp = ext_record.app
|
||||
if not app.is_enabled:
|
||||
continue
|
||||
|
||||
if app.is_restricted:
|
||||
# 检查用户组权限(FileAppGroupLink 是纯关联表,使用 session 查询)
|
||||
stmt = select(FileAppGroupLink).where(
|
||||
and_(
|
||||
FileAppGroupLink.app_id == app.id,
|
||||
FileAppGroupLink.group_id == user_group_id,
|
||||
)
|
||||
)
|
||||
result = await session.exec(stmt)
|
||||
group_link = result.first()
|
||||
if not group_link:
|
||||
continue
|
||||
|
||||
viewers.append((app.to_summary(), ext_record.priority))
|
||||
|
||||
# 按 priority 排序
|
||||
viewers.sort(key=lambda x: x[1])
|
||||
|
||||
# 查询用户默认偏好
|
||||
user_default: UserFileAppDefault | None = await UserFileAppDefault.get(
|
||||
session,
|
||||
and_(
|
||||
UserFileAppDefault.user_id == user.id,
|
||||
UserFileAppDefault.extension == normalized_ext,
|
||||
),
|
||||
)
|
||||
|
||||
return FileViewersResponse(
|
||||
viewers=[v[0] for v in viewers],
|
||||
default_viewer_id=user_default.app_id if user_default else None,
|
||||
)
|
||||
@@ -1,19 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from models import MCPRequestBase, MCPResponseBase, MCPMethod
|
||||
|
||||
# MCP 路由
|
||||
MCP_router = APIRouter(
|
||||
prefix='/mcp',
|
||||
tags=["mcp"],
|
||||
)
|
||||
|
||||
@MCP_router.get(
|
||||
"/",
|
||||
)
|
||||
async def mcp_root(
|
||||
param: MCPRequestBase
|
||||
):
|
||||
match param.method:
|
||||
case MCPMethod.PING:
|
||||
return MCPResponseBase(result="pong", **param.model_dump())
|
||||
@@ -8,13 +8,14 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
Group,
|
||||
Object,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
@@ -22,186 +23,169 @@ from models import (
|
||||
ObjectPropertyDetailResponse,
|
||||
ObjectPropertyResponse,
|
||||
ObjectRenameRequest,
|
||||
ObjectSwitchPolicyRequest,
|
||||
ObjectType,
|
||||
PhysicalFile,
|
||||
Policy,
|
||||
PolicyType,
|
||||
Task,
|
||||
TaskProps,
|
||||
TaskStatus,
|
||||
TaskSummaryBase,
|
||||
TaskType,
|
||||
User,
|
||||
# 元数据相关
|
||||
ObjectMetadata,
|
||||
MetadataResponse,
|
||||
MetadataPatchRequest,
|
||||
INTERNAL_NAMESPACES,
|
||||
USER_WRITABLE_NAMESPACES,
|
||||
)
|
||||
from models import ResponseBase
|
||||
from service.storage import LocalStorageService
|
||||
from service.storage import (
|
||||
LocalStorageService,
|
||||
adjust_user_storage,
|
||||
copy_object_recursive,
|
||||
migrate_file_with_task,
|
||||
migrate_directory_files,
|
||||
)
|
||||
from service.storage.object import soft_delete_objects
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from utils import http_exceptions
|
||||
|
||||
from .custom_property import router as custom_property_router
|
||||
|
||||
object_router = APIRouter(
|
||||
prefix="/object",
|
||||
tags=["object"]
|
||||
)
|
||||
object_router.include_router(custom_property_router)
|
||||
|
||||
|
||||
async def _delete_object_recursive(
|
||||
session: AsyncSession,
|
||||
obj: Object,
|
||||
user_id: UUID,
|
||||
) -> int:
|
||||
@object_router.post(
|
||||
path='/',
|
||||
summary='创建空白文件',
|
||||
description='在指定目录下创建空白文件。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_object_create(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: CreateFileRequest,
|
||||
) -> None:
|
||||
"""
|
||||
递归删除对象(软删除)
|
||||
|
||||
对于文件:
|
||||
- 减少 PhysicalFile 引用计数
|
||||
- 只有引用计数为0时才移动物理文件到回收站
|
||||
|
||||
对于目录:
|
||||
- 递归处理所有子对象
|
||||
创建空白文件端点
|
||||
|
||||
:param session: 数据库会话
|
||||
:param obj: 要删除的对象
|
||||
:param user_id: 用户UUID
|
||||
:return: 删除的对象数量
|
||||
:param user: 当前登录用户
|
||||
:param request: 创建文件请求(parent_id, name)
|
||||
:return: 创建结果
|
||||
"""
|
||||
deleted_count = 0
|
||||
user_id = user.id
|
||||
|
||||
if obj.is_folder:
|
||||
# 递归删除子对象
|
||||
children = await Object.get_children(session, user_id, obj.id)
|
||||
for child in children:
|
||||
deleted_count += await _delete_object_recursive(session, child, user_id)
|
||||
# 验证文件名
|
||||
if not request.name or '/' in request.name or '\\' in request.name:
|
||||
raise HTTPException(status_code=400, detail="无效的文件名")
|
||||
|
||||
# 如果是文件,处理物理文件引用
|
||||
if obj.is_file and obj.physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
||||
if physical_file:
|
||||
# 减少引用计数
|
||||
new_count = physical_file.decrement_reference()
|
||||
|
||||
if physical_file.can_be_deleted:
|
||||
# 引用计数为0,移动物理文件到回收站
|
||||
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
|
||||
if policy and policy.type == PolicyType.LOCAL:
|
||||
try:
|
||||
storage_service = LocalStorageService(policy)
|
||||
await storage_service.move_to_trash(
|
||||
source_path=physical_file.storage_path,
|
||||
user_id=user_id,
|
||||
object_id=obj.id,
|
||||
)
|
||||
l.debug(f"物理文件已移动到回收站: {obj.name}")
|
||||
except Exception as e:
|
||||
l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}")
|
||||
|
||||
# 删除 PhysicalFile 记录
|
||||
await PhysicalFile.delete(session, physical_file)
|
||||
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
|
||||
else:
|
||||
# 还有其他引用,只更新引用计数
|
||||
await physical_file.save(session)
|
||||
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
|
||||
|
||||
# 删除数据库记录
|
||||
await Object.delete(session, obj)
|
||||
deleted_count += 1
|
||||
|
||||
return deleted_count
|
||||
|
||||
|
||||
async def _copy_object_recursive(
|
||||
session: AsyncSession,
|
||||
src: Object,
|
||||
dst_parent_id: UUID,
|
||||
user_id: UUID,
|
||||
) -> tuple[int, list[UUID]]:
|
||||
"""
|
||||
递归复制对象
|
||||
|
||||
对于文件:
|
||||
- 增加 PhysicalFile 引用计数
|
||||
- 创建新的 Object 记录指向同一 PhysicalFile
|
||||
|
||||
对于目录:
|
||||
- 创建新目录
|
||||
- 递归复制所有子对象
|
||||
|
||||
:param session: 数据库会话
|
||||
:param src: 源对象
|
||||
:param dst_parent_id: 目标父目录UUID
|
||||
:param user_id: 用户UUID
|
||||
:return: (复制数量, 新对象UUID列表)
|
||||
"""
|
||||
copied_count = 0
|
||||
new_ids: list[UUID] = []
|
||||
|
||||
# 在 save() 之前保存需要的属性值,避免 commit 后对象过期导致懒加载失败
|
||||
src_is_folder = src.is_folder
|
||||
src_id = src.id
|
||||
|
||||
# 创建新的 Object 记录
|
||||
new_obj = Object(
|
||||
name=src.name,
|
||||
type=src.type,
|
||||
size=src.size,
|
||||
password=src.password,
|
||||
parent_id=dst_parent_id,
|
||||
owner_id=user_id,
|
||||
policy_id=src.policy_id,
|
||||
physical_file_id=src.physical_file_id,
|
||||
# 验证父目录(排除已删除的)
|
||||
parent = await Object.get(
|
||||
session,
|
||||
(Object.id == request.parent_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not parent or parent.owner_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="父目录不存在")
|
||||
|
||||
# 如果是文件,增加物理文件引用计数
|
||||
if src.is_file and src.physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src.physical_file_id)
|
||||
if physical_file:
|
||||
physical_file.increment_reference()
|
||||
await physical_file.save(session)
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
new_obj = await new_obj.save(session)
|
||||
copied_count += 1
|
||||
new_ids.append(new_obj.id)
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 如果是目录,递归复制子对象
|
||||
if src_is_folder:
|
||||
children = await Object.get_children(session, user_id, src_id)
|
||||
for child in children:
|
||||
child_count, child_ids = await _copy_object_recursive(
|
||||
session, child, new_obj.id, user_id
|
||||
)
|
||||
copied_count += child_count
|
||||
new_ids.extend(child_ids)
|
||||
# 检查是否已存在同名文件(仅检查未删除的)
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user_id) &
|
||||
(Object.parent_id == parent.id) &
|
||||
(Object.name == request.name) &
|
||||
(Object.deleted_at == None)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="同名文件已存在")
|
||||
|
||||
return copied_count, new_ids
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get_exist_one(session, policy_id)
|
||||
|
||||
parent_id = parent.id
|
||||
|
||||
# 生成存储路径并创建空文件
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
||||
user_id=user_id,
|
||||
original_filename=request.name,
|
||||
)
|
||||
await storage_service.create_empty_file(full_path)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
size=0,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
physical_file = await physical_file.save(session)
|
||||
|
||||
# 创建 Object 记录
|
||||
file_object = Object(
|
||||
name=request.name,
|
||||
type=ObjectType.FILE,
|
||||
size=0,
|
||||
physical_file_id=physical_file.id,
|
||||
parent_id=parent_id,
|
||||
owner_id=user_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
file_object = await file_object.save(session)
|
||||
|
||||
l.info(f"创建空白文件: {request.name}")
|
||||
|
||||
|
||||
@object_router.delete(
|
||||
path='/',
|
||||
summary='删除对象',
|
||||
description='删除一个或多个对象(文件或目录),文件会移动到用户回收站。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_object_delete(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ObjectDeleteRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
删除对象端点(软删除)
|
||||
删除对象端点(软删除到回收站)
|
||||
|
||||
流程:
|
||||
1. 验证对象存在且属于当前用户
|
||||
2. 对于文件,减少物理文件引用计数
|
||||
3. 如果引用计数为0,移动物理文件到 .trash 目录
|
||||
4. 对于目录,递归处理子对象
|
||||
5. 从数据库中删除记录
|
||||
2. 设置 deleted_at 时间戳
|
||||
3. 保存原 parent_id 到 deleted_original_parent_id
|
||||
4. 将 parent_id 置 NULL 脱离文件树
|
||||
5. 子对象和物理文件不做任何变更
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param request: 删除请求(包含待删除对象的UUID列表)
|
||||
:return: 删除结果
|
||||
"""
|
||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
||||
user_id = user.id
|
||||
deleted_count = 0
|
||||
objects_to_delete: list[Object] = []
|
||||
|
||||
# 处理单个 UUID 或 UUID 列表
|
||||
ids = request.ids if isinstance(request.ids, list) else [request.ids]
|
||||
|
||||
for obj_id in ids:
|
||||
obj = await Object.get(session, Object.id == obj_id)
|
||||
for obj_id in request.ids:
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == obj_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not obj or obj.owner_id != user_id:
|
||||
continue
|
||||
|
||||
@@ -210,30 +194,24 @@ async def router_object_delete(
|
||||
l.warning(f"尝试删除根目录被阻止: {obj.name}")
|
||||
continue
|
||||
|
||||
# 递归删除(包含引用计数逻辑)
|
||||
count = await _delete_object_recursive(session, obj, user_id)
|
||||
deleted_count += count
|
||||
objects_to_delete.append(obj)
|
||||
|
||||
l.info(f"用户 {user_id} 删除了 {deleted_count} 个对象")
|
||||
|
||||
return ResponseBase(
|
||||
data={
|
||||
"deleted": deleted_count,
|
||||
"total": len(ids),
|
||||
}
|
||||
)
|
||||
if objects_to_delete:
|
||||
deleted_count = await soft_delete_objects(session, objects_to_delete)
|
||||
l.info(f"用户 {user_id} 软删除了 {deleted_count} 个对象到回收站")
|
||||
|
||||
|
||||
@object_router.patch(
|
||||
path='/',
|
||||
summary='移动对象',
|
||||
description='移动一个或多个对象到目标目录',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_object_move(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ObjectMoveRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
移动对象端点
|
||||
|
||||
@@ -245,14 +223,20 @@ async def router_object_move(
|
||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
||||
user_id = user.id
|
||||
|
||||
# 验证目标目录
|
||||
dst = await Object.get(session, Object.id == request.dst_id)
|
||||
# 验证目标目录(排除已删除的)
|
||||
dst = await Object.get(
|
||||
session,
|
||||
(Object.id == request.dst_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not dst or dst.owner_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="目标目录不存在")
|
||||
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
if dst.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 存储 dst 的属性,避免后续数据库操作导致 dst 过期后无法访问
|
||||
dst_id = dst.id
|
||||
dst_parent_id = dst.parent_id
|
||||
@@ -260,10 +244,16 @@ async def router_object_move(
|
||||
moved_count = 0
|
||||
|
||||
for src_id in request.src_ids:
|
||||
src = await Object.get(session, Object.id == src_id)
|
||||
src = await Object.get(
|
||||
session,
|
||||
(Object.id == src_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not src or src.owner_id != user_id:
|
||||
continue
|
||||
|
||||
if src.is_banned:
|
||||
continue
|
||||
|
||||
# 不能移动根目录
|
||||
if src.parent_id is None:
|
||||
continue
|
||||
@@ -285,12 +275,13 @@ async def router_object_move(
|
||||
if is_cycle:
|
||||
continue
|
||||
|
||||
# 检查目标目录下是否存在同名对象
|
||||
# 检查目标目录下是否存在同名对象(仅检查未删除的)
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user_id) &
|
||||
(Object.parent_id == dst_id) &
|
||||
(Object.name == src.name)
|
||||
(Object.name == src.name) &
|
||||
(Object.deleted_at == None)
|
||||
)
|
||||
if existing:
|
||||
continue # 跳过重名对象
|
||||
@@ -302,24 +293,18 @@ async def router_object_move(
|
||||
# 统一提交所有更改
|
||||
await session.commit()
|
||||
|
||||
return ResponseBase(
|
||||
data={
|
||||
"moved": moved_count,
|
||||
"total": len(request.src_ids),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@object_router.post(
|
||||
path='/copy',
|
||||
summary='复制对象',
|
||||
description='复制一个或多个对象到目标目录。文件复制仅增加物理文件引用计数,不复制物理文件。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_object_copy(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ObjectCopyRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
复制对象端点
|
||||
|
||||
@@ -340,27 +325,41 @@ async def router_object_copy(
|
||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
||||
user_id = user.id
|
||||
|
||||
# 验证目标目录
|
||||
dst = await Object.get(session, Object.id == request.dst_id)
|
||||
# 验证目标目录(排除已删除的)
|
||||
dst = await Object.get(
|
||||
session,
|
||||
(Object.id == request.dst_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not dst or dst.owner_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="目标目录不存在")
|
||||
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
if dst.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
copied_count = 0
|
||||
new_ids: list[UUID] = []
|
||||
total_copied_size = 0
|
||||
|
||||
for src_id in request.src_ids:
|
||||
src = await Object.get(session, Object.id == src_id)
|
||||
src = await Object.get(
|
||||
session,
|
||||
(Object.id == src_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not src or src.owner_id != user_id:
|
||||
continue
|
||||
|
||||
if src.is_banned:
|
||||
http_exceptions.raise_banned("源对象已被封禁,无法执行此操作")
|
||||
|
||||
# 不能复制根目录
|
||||
if src.parent_id is None:
|
||||
continue
|
||||
http_exceptions.raise_banned("无法复制根目录")
|
||||
|
||||
# 不能复制到自身
|
||||
# [TODO] 视为创建副本
|
||||
if src.id == dst.id:
|
||||
continue
|
||||
|
||||
@@ -376,42 +375,42 @@ async def router_object_copy(
|
||||
if is_cycle:
|
||||
continue
|
||||
|
||||
# 检查目标目录下是否存在同名对象
|
||||
# 检查目标目录下是否存在同名对象(仅检查未删除的)
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user_id) &
|
||||
(Object.parent_id == dst.id) &
|
||||
(Object.name == src.name)
|
||||
(Object.name == src.name) &
|
||||
(Object.deleted_at == None)
|
||||
)
|
||||
if existing:
|
||||
continue # 跳过重名对象
|
||||
# [TODO] 应当询问用户是否覆盖、跳过或创建副本
|
||||
continue
|
||||
|
||||
# 递归复制
|
||||
count, ids = await _copy_object_recursive(session, src, dst.id, user_id)
|
||||
count, ids, copied_size = await copy_object_recursive(session, src, dst.id, user_id)
|
||||
copied_count += count
|
||||
new_ids.extend(ids)
|
||||
total_copied_size += copied_size
|
||||
|
||||
# 更新用户存储配额
|
||||
if total_copied_size > 0:
|
||||
await adjust_user_storage(session, user_id, total_copied_size)
|
||||
|
||||
l.info(f"用户 {user_id} 复制了 {copied_count} 个对象")
|
||||
|
||||
return ResponseBase(
|
||||
data={
|
||||
"copied": copied_count,
|
||||
"total": len(request.src_ids),
|
||||
"new_ids": new_ids,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@object_router.post(
|
||||
path='/rename',
|
||||
summary='重命名对象',
|
||||
description='重命名对象(文件或目录)。',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_object_rename(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ObjectRenameRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
重命名对象端点
|
||||
|
||||
@@ -430,14 +429,20 @@ async def router_object_rename(
|
||||
# 存储 user.id,避免后续 save() 导致 user 过期后无法访问
|
||||
user_id = user.id
|
||||
|
||||
# 验证对象存在
|
||||
obj = await Object.get(session, Object.id == request.id)
|
||||
# 验证对象存在(排除已删除的)
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == request.id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail="对象不存在")
|
||||
|
||||
if obj.owner_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此对象")
|
||||
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 不能重命名根目录
|
||||
if obj.parent_id is None:
|
||||
raise HTTPException(status_code=400, detail="无法重命名根目录")
|
||||
@@ -450,28 +455,27 @@ async def router_object_rename(
|
||||
if '/' in new_name or '\\' in new_name:
|
||||
raise HTTPException(status_code=400, detail="名称不能包含斜杠")
|
||||
|
||||
# 如果名称没有变化,直接返回成功
|
||||
# 如果名称没有变化,直接返回
|
||||
if obj.name == new_name:
|
||||
return ResponseBase(data={"success": True})
|
||||
return # noqa: already 204
|
||||
|
||||
# 检查同目录下是否存在同名对象
|
||||
# 检查同目录下是否存在同名对象(仅检查未删除的)
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user_id) &
|
||||
(Object.parent_id == obj.parent_id) &
|
||||
(Object.name == new_name)
|
||||
(Object.name == new_name) &
|
||||
(Object.deleted_at == None)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="同名对象已存在")
|
||||
|
||||
# 更新名称
|
||||
obj.name = new_name
|
||||
await obj.save(session)
|
||||
obj = await obj.save(session)
|
||||
|
||||
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
|
||||
|
||||
return ResponseBase(data={"success": True})
|
||||
|
||||
|
||||
@object_router.get(
|
||||
path='/property/{id}',
|
||||
@@ -491,7 +495,10 @@ async def router_object_property(
|
||||
:param id: 对象UUID
|
||||
:return: 对象基本属性
|
||||
"""
|
||||
obj = await Object.get(session, Object.id == id)
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail="对象不存在")
|
||||
|
||||
@@ -503,6 +510,7 @@ async def router_object_property(
|
||||
name=obj.name,
|
||||
type=obj.type,
|
||||
size=obj.size,
|
||||
mime_type=obj.mime_type,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
parent_id=obj.parent_id,
|
||||
@@ -529,8 +537,8 @@ async def router_object_property_detail(
|
||||
"""
|
||||
obj = await Object.get(
|
||||
session,
|
||||
Object.id == id,
|
||||
load=Object.file_metadata,
|
||||
(Object.id == id) & (Object.deleted_at == None),
|
||||
load=Object.metadata_entries,
|
||||
)
|
||||
if not obj:
|
||||
raise HTTPException(status_code=404, detail="对象不存在")
|
||||
@@ -543,7 +551,7 @@ async def router_object_property_detail(
|
||||
policy_name = policy.name if policy else None
|
||||
|
||||
# 获取分享统计
|
||||
from models import Share
|
||||
from sqlmodels import Share
|
||||
shares = await Share.get(
|
||||
session,
|
||||
Share.object_id == obj.id,
|
||||
@@ -553,35 +561,301 @@ async def router_object_property_detail(
|
||||
total_views = sum(s.views for s in shares)
|
||||
total_downloads = sum(s.downloads for s in shares)
|
||||
|
||||
# 获取物理文件引用计数
|
||||
# 获取物理文件信息(引用计数、校验和)
|
||||
reference_count = 1
|
||||
checksum_md5: str | None = None
|
||||
checksum_sha256: str | None = None
|
||||
if obj.physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
||||
if physical_file:
|
||||
reference_count = physical_file.reference_count
|
||||
checksum_md5 = physical_file.checksum_md5
|
||||
checksum_sha256 = physical_file.checksum_sha256
|
||||
|
||||
# 构建响应
|
||||
response = ObjectPropertyDetailResponse(
|
||||
# 构建元数据字典(排除内部命名空间)
|
||||
metadata: dict[str, str] = {}
|
||||
for entry in obj.metadata_entries:
|
||||
ns = entry.name.split(":")[0] if ":" in entry.name else ""
|
||||
if ns not in INTERNAL_NAMESPACES:
|
||||
metadata[entry.name] = entry.value
|
||||
|
||||
return ObjectPropertyDetailResponse(
|
||||
id=obj.id,
|
||||
name=obj.name,
|
||||
type=obj.type,
|
||||
size=obj.size,
|
||||
mime_type=obj.mime_type,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
parent_id=obj.parent_id,
|
||||
checksum_md5=checksum_md5,
|
||||
checksum_sha256=checksum_sha256,
|
||||
policy_name=policy_name,
|
||||
share_count=share_count,
|
||||
total_views=total_views,
|
||||
total_downloads=total_downloads,
|
||||
reference_count=reference_count,
|
||||
metadatas=metadata,
|
||||
)
|
||||
|
||||
# 添加文件元数据
|
||||
if obj.file_metadata:
|
||||
response.mime_type = obj.file_metadata.mime_type
|
||||
response.width = obj.file_metadata.width
|
||||
response.height = obj.file_metadata.height
|
||||
response.duration = obj.file_metadata.duration
|
||||
response.checksum_md5 = obj.file_metadata.checksum_md5
|
||||
|
||||
return response
|
||||
@object_router.patch(
|
||||
path='/{object_id}/policy',
|
||||
summary='切换对象存储策略',
|
||||
)
|
||||
async def router_object_switch_policy(
|
||||
session: SessionDep,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
object_id: UUID,
|
||||
request: ObjectSwitchPolicyRequest,
|
||||
) -> TaskSummaryBase:
|
||||
"""
|
||||
切换对象的存储策略
|
||||
|
||||
文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
|
||||
目录:更新目录 policy_id(新文件使用新策略);
|
||||
若 is_migrate_existing=True,额外创建后台任务迁移所有已有文件。
|
||||
|
||||
认证:JWT Bearer Token
|
||||
|
||||
错误处理:
|
||||
- 404: 对象不存在
|
||||
- 403: 无权操作此对象 / 用户组无权使用目标策略
|
||||
- 400: 目标策略与当前相同 / 不能对根目录操作
|
||||
"""
|
||||
user_id = user.id
|
||||
|
||||
# 查找对象
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == object_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not obj:
|
||||
http_exceptions.raise_not_found("对象不存在")
|
||||
if obj.owner_id != user_id:
|
||||
http_exceptions.raise_forbidden("无权操作此对象")
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 根目录不能直接切换策略(应通过子对象或子目录操作)
|
||||
if obj.parent_id is None:
|
||||
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
|
||||
|
||||
# 校验目标策略存在
|
||||
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
|
||||
if not dest_policy:
|
||||
http_exceptions.raise_not_found("目标存储策略不存在")
|
||||
|
||||
# 校验用户组权限
|
||||
group: Group = await user.awaitable_attrs.group
|
||||
await session.refresh(group, ['policies'])
|
||||
allowed_ids = {p.id for p in group.policies}
|
||||
if request.policy_id not in allowed_ids:
|
||||
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
|
||||
|
||||
# 不能切换到相同策略
|
||||
if obj.policy_id == request.policy_id:
|
||||
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
|
||||
|
||||
# 保存必要的属性,避免 save 后对象过期
|
||||
src_policy_id = obj.policy_id
|
||||
obj_id = obj.id
|
||||
obj_is_file = obj.type == ObjectType.FILE
|
||||
dest_policy_id = request.policy_id
|
||||
dest_policy_name = dest_policy.name
|
||||
|
||||
# 创建任务记录
|
||||
task = Task(
|
||||
type=TaskType.POLICY_MIGRATE,
|
||||
status=TaskStatus.QUEUED,
|
||||
user_id=user_id,
|
||||
)
|
||||
task = await task.save(session)
|
||||
task_id = task.id
|
||||
|
||||
task_props = TaskProps(
|
||||
task_id=task_id,
|
||||
source_policy_id=src_policy_id,
|
||||
dest_policy_id=dest_policy_id,
|
||||
object_id=obj_id,
|
||||
)
|
||||
task_props = await task_props.save(session)
|
||||
|
||||
if obj_is_file:
|
||||
# 文件:后台迁移
|
||||
async def _run_file_migration() -> None:
|
||||
async with DatabaseManager.session() as bg_session:
|
||||
bg_obj = await Object.get(bg_session, Object.id == obj_id)
|
||||
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
||||
bg_task = await Task.get(bg_session, Task.id == task_id)
|
||||
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
|
||||
|
||||
background_tasks.add_task(_run_file_migration)
|
||||
else:
|
||||
# 目录:先更新目录自身的 policy_id
|
||||
obj = await Object.get(session, Object.id == obj_id)
|
||||
obj.policy_id = dest_policy_id
|
||||
obj = await obj.save(session)
|
||||
|
||||
if request.is_migrate_existing:
|
||||
# 后台迁移所有已有文件
|
||||
async def _run_dir_migration() -> None:
|
||||
async with DatabaseManager.session() as bg_session:
|
||||
bg_folder = await Object.get(bg_session, Object.id == obj_id)
|
||||
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
|
||||
bg_task = await Task.get(bg_session, Task.id == task_id)
|
||||
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
|
||||
|
||||
background_tasks.add_task(_run_dir_migration)
|
||||
else:
|
||||
# 不迁移已有文件,直接完成任务
|
||||
task = await Task.get(session, Task.id == task_id)
|
||||
task.status = TaskStatus.COMPLETED
|
||||
task.progress = 100
|
||||
task = await task.save(session)
|
||||
|
||||
# 重新获取 task 以读取最新状态
|
||||
task = await Task.get(session, Task.id == task_id)
|
||||
|
||||
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
|
||||
|
||||
return TaskSummaryBase(
|
||||
id=task.id,
|
||||
type=task.type,
|
||||
status=task.status,
|
||||
progress=task.progress,
|
||||
error=task.error,
|
||||
user_id=task.user_id,
|
||||
created_at=task.created_at,
|
||||
updated_at=task.updated_at,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 元数据端点 ====================
|
||||
|
||||
@object_router.get(
    path='/{object_id}/metadata',
    summary='获取对象元数据',
    description='获取对象的元数据键值对,可按命名空间过滤。',
)
async def router_get_object_metadata(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    object_id: UUID,
    ns: str | None = None,
) -> MetadataResponse:
    """
    Fetch the metadata key/value pairs of an object.

    Auth: JWT token required.

    Query parameters:
        ns: comma-separated namespace list (e.g. ``exif,stream``); when
            omitted, every non-internal namespace is returned.

    Errors:
        404: object does not exist (or is soft-deleted)
        403: caller is not the owner of the object
    """
    obj = await Object.get(
        session,
        (Object.id == object_id) & (Object.deleted_at == None),
        load=Object.metadata_entries,
    )
    if not obj:
        raise HTTPException(status_code=404, detail="对象不存在")

    if obj.owner_id != user.id:
        raise HTTPException(status_code=403, detail="无权查看此对象")

    # Optional caller-supplied namespace filter; internal namespaces are
    # stripped from it so they can never be requested explicitly.
    requested: set[str] | None = None
    if ns:
        requested = {part.strip() for part in ns.split(",") if part.strip()}
        requested -= INTERNAL_NAMESPACES

    def _visible(name: str) -> bool:
        # The namespace is the prefix before the first ':', or "" when absent.
        space = name.split(":")[0] if ":" in name else ""
        if space in INTERNAL_NAMESPACES:
            return False
        return requested is None or space in requested

    collected = {
        entry.name: entry.value
        for entry in obj.metadata_entries
        if _visible(entry.name)
    }

    return MetadataResponse(metadatas=collected)
|
||||
|
||||
|
||||
@object_router.patch(
    path='/{object_id}/metadata',
    summary='批量更新对象元数据',
    description='批量设置或删除对象的元数据条目。仅允许修改 custom: 命名空间。',
    status_code=204,
)
async def router_patch_object_metadata(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    object_id: UUID,
    request: MetadataPatchRequest,
) -> None:
    """
    Batch-update metadata entries of an object.

    Keys whose value is ``None`` are deleted; all other keys are created or
    updated. Users may only touch entries in the user-writable (``custom:``)
    namespaces.

    Auth: JWT token required.

    Errors:
        400: attempt to modify an entry outside the writable namespaces
        404: object does not exist (or is soft-deleted)
        403: caller is not the owner of the object
    """
    obj = await Object.get(
        session,
        (Object.id == object_id) & (Object.deleted_at == None),
    )
    if not obj:
        raise HTTPException(status_code=404, detail="对象不存在")

    if obj.owner_id != user.id:
        raise HTTPException(status_code=403, detail="无权操作此对象")

    for patch in request.patches:
        # Validate the namespace prefix (text before the first ':').
        patch_ns = patch.key.split(":")[0] if ":" in patch.key else ""
        if patch_ns not in USER_WRITABLE_NAMESPACES:
            raise HTTPException(
                status_code=400,
                detail=f"不允许修改命名空间 '{patch_ns}' 的元数据,仅允许 custom: 命名空间",
            )

        # Both the delete path and the upsert path need the current entry,
        # so perform the lookup once instead of duplicating the query in
        # each branch (the original repeated it verbatim).
        existing = await ObjectMetadata.get(
            session,
            (ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
        )

        if patch.value is None:
            # Delete the entry; deleting a missing key is a silent no-op.
            if existing:
                await ObjectMetadata.delete(session, instances=existing)
        elif existing:
            # Update the existing entry in place.
            existing.value = patch.value
            existing = await existing.save(session)
        else:
            # Create a new entry.
            entry = ObjectMetadata(
                object_id=object_id,
                name=patch.key,
                value=patch.value,
                is_public=True,
            )
            entry = await entry.save(session)

    l.info(f"用户 {user.id} 更新了对象 {object_id} 的 {len(request.patches)} 条元数据")
|
||||
|
||||
168
routers/api/v1/object/custom_property/__init__.py
Normal file
168
routers/api/v1/object/custom_property/__init__.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
用户自定义属性定义路由
|
||||
|
||||
提供自定义属性模板的增删改查功能。
|
||||
用户可以定义类型化的属性模板(如标签、评分、分类等),
|
||||
然后通过元数据 PATCH 端点为对象设置属性值。
|
||||
|
||||
路由前缀:/custom_property
|
||||
"""
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
CustomPropertyDefinition,
|
||||
CustomPropertyCreateRequest,
|
||||
CustomPropertyUpdateRequest,
|
||||
CustomPropertyResponse,
|
||||
User,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/custom_property",
|
||||
tags=["custom_property"],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
    path='',
    summary='获取自定义属性定义列表',
    description='获取当前用户的所有自定义属性定义,按 sort_order 排序。',
)
async def router_list_custom_properties(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
) -> list[CustomPropertyResponse]:
    """
    List the current user's custom property definitions.

    Auth: JWT token required.

    Returns every property template owned by the caller, ordered by
    ``sort_order``.
    """
    definitions = await CustomPropertyDefinition.get(
        session,
        CustomPropertyDefinition.owner_id == user.id,
        fetch_mode="all",
    )

    ordered = sorted(definitions, key=lambda definition: definition.sort_order)

    responses: list[CustomPropertyResponse] = []
    for definition in ordered:
        responses.append(
            CustomPropertyResponse(
                id=definition.id,
                name=definition.name,
                type=definition.type,
                icon=definition.icon,
                options=definition.options,
                default_value=definition.default_value,
                sort_order=definition.sort_order,
            )
        )
    return responses
|
||||
|
||||
|
||||
@router.post(
    path='',
    summary='创建自定义属性定义',
    description='创建一个新的自定义属性模板。',
    status_code=204,
)
async def router_create_custom_property(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    request: CustomPropertyCreateRequest,
) -> None:
    """
    Create a new custom property template.

    Auth: JWT token required.

    Errors:
        400: invalid request payload
        409: a property with the same name already exists
    """
    # Enforce uniqueness per (owner, name) before inserting.
    duplicate = await CustomPropertyDefinition.get(
        session,
        (CustomPropertyDefinition.owner_id == user.id) &
        (CustomPropertyDefinition.name == request.name),
    )
    if duplicate:
        raise HTTPException(status_code=409, detail="同名自定义属性已存在")

    await CustomPropertyDefinition(
        owner_id=user.id,
        name=request.name,
        type=request.type,
        icon=request.icon,
        options=request.options,
        default_value=request.default_value,
    ).save(session)

    l.info(f"用户 {user.id} 创建了自定义属性: {request.name}")
|
||||
|
||||
|
||||
@router.patch(
    path='/{id}',
    summary='更新自定义属性定义',
    description='更新自定义属性模板的名称、图标、选项等。',
    status_code=204,
)
async def router_update_custom_property(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    id: UUID,
    request: CustomPropertyUpdateRequest,
) -> None:
    """
    Update an existing custom property template.

    Auth: JWT token required.

    Errors:
        404: definition does not exist
        403: caller does not own the definition
    """
    prop = await CustomPropertyDefinition.get_exist_one(session, id)

    # Only the owner may edit a definition.
    if prop.owner_id != user.id:
        raise HTTPException(status_code=403, detail="无权操作此属性")

    await prop.update(session, request)

    l.info(f"用户 {user.id} 更新了自定义属性: {id}")
|
||||
|
||||
|
||||
@router.delete(
    path='/{id}',
    summary='删除自定义属性定义',
    description='删除自定义属性模板。注意:不会自动清理已使用该属性的元数据条目。',
    status_code=204,
)
async def router_delete_custom_property(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    id: UUID,
) -> None:
    """
    Delete a custom property template.

    Metadata entries that already use the property are NOT cleaned up.

    Auth: JWT token required.

    Errors:
        404: definition does not exist
        403: caller does not own the definition
    """
    prop = await CustomPropertyDefinition.get_exist_one(session, id)

    # Only the owner may delete a definition.
    if prop.owner_id != user.id:
        raise HTTPException(status_code=403, detail="无权操作此属性")

    await CustomPropertyDefinition.delete(session, instances=prop)

    l.info(f"用户 {user.id} 删除了自定义属性: {id}")
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Annotated, Literal
|
||||
from uuid import uuid4
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException
|
||||
@@ -7,13 +7,16 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase
|
||||
from models.user import User
|
||||
from models.share import Share, ShareCreateRequest, ShareResponse
|
||||
from models.object import Object
|
||||
from models.mixin import ListResponse, TableViewRequest
|
||||
from sqlmodels import ResponseBase
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.share import (
|
||||
Share, ShareCreateRequest, CreateShareResponse, ShareResponse,
|
||||
ShareDetailResponse, ShareOwnerInfo, ShareObjectItem,
|
||||
)
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodel_ext import ListResponse, TableViewRequest
|
||||
from utils import http_exceptions
|
||||
from utils.password.pwd import Password
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
share_router = APIRouter(
|
||||
prefix='/share',
|
||||
@@ -22,21 +25,92 @@ share_router = APIRouter(
|
||||
|
||||
@share_router.get(
|
||||
path='/{id}',
|
||||
summary='获取分享',
|
||||
description='Get shared content by info type and ID.',
|
||||
summary='获取分享详情',
|
||||
description='Get share detail by share ID. No authentication required.',
|
||||
)
|
||||
def router_share_get(info: str, id: str) -> ResponseBase:
|
||||
async def router_share_get(
|
||||
session: SessionDep,
|
||||
id: UUID,
|
||||
password: str | None = Query(default=None),
|
||||
) -> ShareDetailResponse:
|
||||
"""
|
||||
Get shared content by info type and ID.
|
||||
获取分享详情
|
||||
|
||||
Args:
|
||||
info (str): The type of information being shared.
|
||||
id (str): The ID of the shared content.
|
||||
认证:无需登录
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing shared content information.
|
||||
流程:
|
||||
1. 通过分享ID查找分享
|
||||
2. 检查过期、封禁状态
|
||||
3. 验证提取码(如果有)
|
||||
4. 返回分享详情(含文件树和分享者信息)
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
# 1. 查询分享(预加载 user 和 object)
|
||||
share = await Share.get_exist_one(session, id, load=[Share.user, Share.object])
|
||||
|
||||
# 2. 检查过期
|
||||
now = datetime.now()
|
||||
if share.expires and share.expires < now:
|
||||
http_exceptions.raise_not_found(detail="分享已过期")
|
||||
|
||||
# 3. 获取关联对象
|
||||
obj = await share.awaitable_attrs.object
|
||||
user = await share.awaitable_attrs.user
|
||||
|
||||
# 4. 检查封禁和软删除
|
||||
if obj and obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
if obj and obj.deleted_at:
|
||||
http_exceptions.raise_not_found(detail="分享关联的文件已被删除")
|
||||
|
||||
# 5. 检查密码
|
||||
if share.password:
|
||||
if not password:
|
||||
http_exceptions.raise_precondition_required(detail="请输入提取码")
|
||||
if Password.verify(share.password, password) != PasswordStatus.VALID:
|
||||
http_exceptions.raise_forbidden(detail="提取码错误")
|
||||
|
||||
# 6. 加载子对象(目录分享)
|
||||
children_items: list[ShareObjectItem] = []
|
||||
if obj and obj.type == ObjectType.FOLDER:
|
||||
children = await Object.get_children(session, obj.owner_id, obj.id)
|
||||
children_items = [
|
||||
ShareObjectItem(
|
||||
id=child.id,
|
||||
name=child.name,
|
||||
type=child.type,
|
||||
size=child.size,
|
||||
created_at=child.created_at,
|
||||
updated_at=child.updated_at,
|
||||
)
|
||||
for child in children
|
||||
]
|
||||
|
||||
# 7. 构建响应(在 save 之前,避免 MissingGreenlet)
|
||||
response = ShareDetailResponse(
|
||||
expires=share.expires,
|
||||
preview_enabled=share.preview_enabled,
|
||||
score=share.score,
|
||||
created_at=share.created_at,
|
||||
owner=ShareOwnerInfo(
|
||||
nickname=user.nickname if user else None,
|
||||
avatar=user.avatar if user else "default",
|
||||
),
|
||||
object=ShareObjectItem(
|
||||
id=obj.id,
|
||||
name=obj.name,
|
||||
type=obj.type,
|
||||
size=obj.size,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
),
|
||||
children=children_items,
|
||||
)
|
||||
|
||||
# 8. 递增浏览次数(最后执行,避免 MissingGreenlet)
|
||||
share.views += 1
|
||||
await share.save(session, refresh=False)
|
||||
|
||||
return response
|
||||
|
||||
@share_router.put(
|
||||
path='/download/{id}',
|
||||
@@ -72,23 +146,6 @@ def router_share_preview(id: str) -> ResponseBase:
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/doc/{id}',
|
||||
summary='取得Office文档预览地址',
|
||||
description='Get Office document preview URL by ID.',
|
||||
)
|
||||
def router_share_doc(id: str) -> ResponseBase:
|
||||
"""
|
||||
Get Office document preview URL by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the Office document.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the document preview URL.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/content/{id}',
|
||||
summary='获取文本文件内容',
|
||||
@@ -243,7 +300,7 @@ async def router_share_create(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: ShareCreateRequest,
|
||||
) -> ShareResponse:
|
||||
) -> CreateShareResponse:
|
||||
"""
|
||||
创建新分享
|
||||
|
||||
@@ -254,13 +311,19 @@ async def router_share_create(
|
||||
2. 生成随机分享码(uuid4)
|
||||
3. 如果有密码则加密存储
|
||||
4. 创建 Share 记录并保存
|
||||
5. 返回分享信息
|
||||
5. 返回分享 ID
|
||||
"""
|
||||
# 验证对象存在且属于当前用户
|
||||
obj = await Object.get(session, Object.id == request.object_id)
|
||||
# 验证对象存在且属于当前用户(排除已删除的)
|
||||
obj = await Object.get(
|
||||
session,
|
||||
(Object.id == request.object_id) & (Object.deleted_at == None)
|
||||
)
|
||||
if not obj or obj.owner_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="对象不存在或无权限")
|
||||
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 生成分享码
|
||||
code = str(uuid4())
|
||||
|
||||
@@ -270,11 +333,12 @@ async def router_share_create(
|
||||
hashed_password = Password.hash(request.password)
|
||||
|
||||
# 创建分享记录
|
||||
user_id = user.id
|
||||
share = Share(
|
||||
code=code,
|
||||
password=hashed_password,
|
||||
object_id=request.object_id,
|
||||
user_id=user.id,
|
||||
user_id=user_id,
|
||||
expires=request.expires,
|
||||
remain_downloads=request.remain_downloads,
|
||||
preview_enabled=request.preview_enabled,
|
||||
@@ -283,24 +347,9 @@ async def router_share_create(
|
||||
)
|
||||
share = await share.save(session)
|
||||
|
||||
l.info(f"用户 {user.id} 创建分享: {share.code}")
|
||||
l.info(f"用户 {user_id} 创建分享: {share.code}")
|
||||
|
||||
# 返回响应
|
||||
return ShareResponse(
|
||||
id=share.id,
|
||||
code=share.code,
|
||||
object_id=share.object_id,
|
||||
source_name=share.source_name,
|
||||
views=share.views,
|
||||
downloads=share.downloads,
|
||||
remain_downloads=share.remain_downloads,
|
||||
expires=share.expires,
|
||||
preview_enabled=share.preview_enabled,
|
||||
score=share.score,
|
||||
created_at=share.created_at,
|
||||
is_expired=share.expires is not None and share.expires < datetime.now(),
|
||||
has_password=share.password is not None,
|
||||
)
|
||||
return CreateShareResponse(share_id=share.id)
|
||||
|
||||
@share_router.get(
|
||||
path='/',
|
||||
@@ -420,16 +469,29 @@ def router_share_update(id: str) -> ResponseBase:
|
||||
path='/{id}',
|
||||
summary='删除分享',
|
||||
description='Delete a share by ID.',
|
||||
dependencies=[Depends(auth_required)]
|
||||
status_code=204,
|
||||
)
|
||||
def router_share_delete(id: str) -> ResponseBase:
|
||||
async def router_share_delete(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
id: UUID,
|
||||
) -> None:
|
||||
"""
|
||||
Delete a share by ID.
|
||||
删除分享
|
||||
|
||||
Args:
|
||||
id (str): The ID of the share to be deleted.
|
||||
认证:需要 JWT token
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deleted share.
|
||||
流程:
|
||||
1. 通过分享ID查找分享
|
||||
2. 验证分享属于当前用户
|
||||
3. 删除分享记录
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
share = await Share.get_exist_one(session, id)
|
||||
if share.user_id != user.id:
|
||||
http_exceptions.raise_forbidden(detail="无权删除此分享")
|
||||
|
||||
user_id = user.id
|
||||
share_code = share.code
|
||||
await Share.delete(session, share)
|
||||
|
||||
l.info(f"用户 {user_id} 删除了分享: {share_code}")
|
||||
@@ -1,7 +1,13 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
from sqlmodels import (
|
||||
ResponseBase, Setting, SettingsType, SiteConfigResponse,
|
||||
ThemePreset, ThemePresetResponse, ThemePresetListResponse,
|
||||
AuthMethodConfig,
|
||||
)
|
||||
from sqlmodels.auth_identity import AuthProviderType
|
||||
from sqlmodels.setting import CaptchaType
|
||||
from utils import http_exceptions
|
||||
|
||||
site_router = APIRouter(
|
||||
@@ -40,19 +46,85 @@ def router_site_captcha():
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@site_router.get(
    path='/themes',
    summary='获取主题预设列表',
)
async def router_site_themes(session: SessionDep) -> ThemePresetListResponse:
    """
    List all theme presets.

    No authentication required; the frontend calls this during
    initialization.
    """
    # Fetch every stored preset and convert each one to its response form.
    presets: list[ThemePreset] = await ThemePreset.get(session, fetch_mode="all")
    return ThemePresetListResponse(
        themes=[ThemePresetResponse.from_preset(p) for p in presets]
    )
|
||||
|
||||
|
||||
@site_router.get(
|
||||
path='/config',
|
||||
summary='站点全局配置',
|
||||
description='Get the configuration file.',
|
||||
response_model=ResponseBase,
|
||||
description='获取站点全局配置,包括验证码设置、注册开关等。',
|
||||
)
|
||||
async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
"""
|
||||
Get the configuration file.
|
||||
获取站点全局配置
|
||||
|
||||
Returns:
|
||||
dict: The site configuration.
|
||||
无需认证。前端在初始化时调用此端点获取验证码类型、
|
||||
登录/注册/找回密码是否需要验证码、可用的认证方式等配置。
|
||||
"""
|
||||
return SiteConfigResponse(
|
||||
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
|
||||
# 批量查询所需设置
|
||||
settings: list[Setting] = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) |
|
||||
(Setting.type == SettingsType.LOGIN) |
|
||||
(Setting.type == SettingsType.REGISTER) |
|
||||
(Setting.type == SettingsType.CAPTCHA) |
|
||||
(Setting.type == SettingsType.AUTH) |
|
||||
(Setting.type == SettingsType.OAUTH) |
|
||||
(Setting.type == SettingsType.AVATAR),
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
# 构建 name→value 映射
|
||||
s: dict[str, str | None] = {item.name: item.value for item in settings}
|
||||
|
||||
# 根据 captcha_type 选择对应的 public key
|
||||
captcha_type_str = s.get("captcha_type", "default")
|
||||
captcha_type = CaptchaType(captcha_type_str) if captcha_type_str else CaptchaType.DEFAULT
|
||||
captcha_key: str | None = None
|
||||
if captcha_type == CaptchaType.GCAPTCHA:
|
||||
captcha_key = s.get("captcha_ReCaptchaKey") or None
|
||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||
captcha_key = s.get("captcha_CloudflareKey") or None
|
||||
|
||||
# 构建认证方式列表
|
||||
auth_methods: list[AuthMethodConfig] = [
|
||||
AuthMethodConfig(provider=AuthProviderType.EMAIL_PASSWORD, is_enabled=s.get("auth_email_password_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.PHONE_SMS, is_enabled=s.get("auth_phone_sms_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.GITHUB, is_enabled=s.get("github_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.QQ, is_enabled=s.get("qq_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.PASSKEY, is_enabled=s.get("auth_passkey_enabled") == "1"),
|
||||
AuthMethodConfig(provider=AuthProviderType.MAGIC_LINK, is_enabled=s.get("auth_magic_link_enabled") == "1"),
|
||||
]
|
||||
|
||||
return SiteConfigResponse(
|
||||
title=s.get("siteName") or "DiskNext",
|
||||
logo_light=s.get("logo_light") or None,
|
||||
logo_dark=s.get("logo_dark") or None,
|
||||
register_enabled=s.get("register_enabled") == "1",
|
||||
login_captcha=s.get("login_captcha") == "1",
|
||||
reg_captcha=s.get("reg_captcha") == "1",
|
||||
forget_captcha=s.get("forget_captcha") == "1",
|
||||
captcha_type=captcha_type,
|
||||
captcha_key=captcha_key,
|
||||
auth_methods=auth_methods,
|
||||
password_required=s.get("auth_password_required") == "1",
|
||||
phone_binding_required=s.get("auth_phone_binding_required") == "1",
|
||||
email_binding_required=s.get("auth_email_binding_required") == "1",
|
||||
avatar_max_size=int(s["avatar_size"]),
|
||||
footer_code=s.get("footer_code"),
|
||||
tos_url=s.get("tos_url"),
|
||||
privacy_url=s.get("privacy_url"),
|
||||
)
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
slave_router = APIRouter(
|
||||
@@ -20,15 +20,15 @@ slave_aria2_router = APIRouter(
|
||||
summary='测试用路由',
|
||||
description='Test route for checking connectivity.',
|
||||
)
|
||||
def router_slave_ping() -> ResponseBase:
|
||||
def router_slave_ping() -> str:
|
||||
"""
|
||||
Test route for checking connectivity.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A response model indicating success.
|
||||
str: 后端版本号
|
||||
"""
|
||||
from utils.conf.appmeta import BackendVersion
|
||||
return ResponseBase(data=BackendVersion)
|
||||
return BackendVersion
|
||||
|
||||
@slave_router.post(
|
||||
path='/post',
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import auth_required
|
||||
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
tag_router = APIRouter(
|
||||
|
||||
161
routers/api/v1/trash/__init__.py
Normal file
161
routers/api/v1/trash/__init__.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
回收站路由
|
||||
|
||||
提供回收站管理功能:列出、恢复、永久删除、清空。
|
||||
|
||||
路由前缀:/trash
|
||||
"""
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import Object, User
|
||||
from sqlmodels.object import TrashDeleteRequest, TrashItemResponse, TrashRestoreRequest
|
||||
from service.storage.object import (
|
||||
permanently_delete_objects,
|
||||
restore_objects,
|
||||
soft_delete_objects,
|
||||
)
|
||||
|
||||
trash_router = APIRouter(
|
||||
prefix="/trash",
|
||||
tags=["trash"],
|
||||
)
|
||||
|
||||
|
||||
@trash_router.get(
    path='/',
    summary='列出回收站内容',
    description='获取当前用户回收站中的所有顶层对象。',
)
async def router_trash_list(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
) -> list[TrashItemResponse]:
    """
    List the caller's trash bin.

    Auth: JWT token required.

    Returns only the top-level objects that were deleted directly; their
    descendants are handled together with the top-level object when it is
    restored or permanently deleted.
    """
    entries = await Object.get_trash_items(session, user.id)

    responses: list[TrashItemResponse] = []
    for entry in entries:
        responses.append(
            TrashItemResponse(
                id=entry.id,
                name=entry.name,
                type=entry.type,
                size=entry.size,
                deleted_at=entry.deleted_at,
                original_parent_id=entry.deleted_original_parent_id,
            )
        )
    return responses
|
||||
|
||||
|
||||
@trash_router.patch(
    path='/restore',
    summary='恢复对象',
    description='从回收站恢复一个或多个对象到原位置。如果原位置不存在则恢复到根目录。',
    status_code=204,
)
async def router_trash_restore(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    request: TrashRestoreRequest,
) -> None:
    """
    Restore objects from the trash bin.

    Auth: JWT token required.

    Flow:
        1. keep only the requested ids that belong to the caller and are
           currently trashed (``deleted_at`` set)
        2. delegate to the storage service, which restores each object to
           its original parent — falling back to the root directory when
           that parent is gone — and renames on name clashes
    """
    user_id = user.id

    candidates: list[Object] = []
    for target_id in request.ids:
        found = await Object.get(
            session,
            (Object.id == target_id) & (Object.owner_id == user_id) & (Object.deleted_at != None)
        )
        if found:
            candidates.append(found)

    if not candidates:
        return

    restored_count = await restore_objects(session, candidates, user_id)
    l.info(f"用户 {user_id} 从回收站恢复了 {restored_count} 个对象")
|
||||
|
||||
|
||||
@trash_router.delete(
    path='/',
    summary='永久删除对象',
    description='永久删除回收站中的指定对象,包括物理文件和数据库记录。',
    status_code=204,
)
async def router_trash_delete(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    request: TrashDeleteRequest,
) -> None:
    """
    Permanently delete trashed objects.

    Auth: JWT token required.

    Flow:
        1. keep only the requested ids that belong to the caller and are
           currently trashed
        2. delegate to the storage service, which walks descendants,
           handles physical-file reference counts, hard-deletes the
           records and updates the owner's storage quota
    """
    user_id = user.id

    victims: list[Object] = []
    for target_id in request.ids:
        found = await Object.get(
            session,
            (Object.id == target_id) & (Object.owner_id == user_id) & (Object.deleted_at != None)
        )
        if found:
            victims.append(found)

    if not victims:
        return

    deleted_count = await permanently_delete_objects(session, victims, user_id)
    l.info(f"用户 {user_id} 永久删除了 {deleted_count} 个对象")
|
||||
|
||||
|
||||
@trash_router.delete(
    path='/empty',
    summary='清空回收站',
    description='永久删除回收站中的所有对象。',
    status_code=204,
)
async def router_trash_empty(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
) -> None:
    """
    Empty the trash bin.

    Auth: JWT token required.

    Fetches every top-level trashed object of the caller and permanently
    deletes them in one service call.
    """
    user_id = user.id

    trash_items = await Object.get_trash_items(session, user_id)
    if not trash_items:
        return

    deleted_count = await permanently_delete_objects(session, trash_items, user_id)
    l.info(f"用户 {user_id} 清空回收站,共删除 {deleted_count} 个对象")
|
||||
File diff suppressed because it is too large
Load Diff
692
routers/api/v1/user/settings/__init__.py
Normal file
692
routers/api/v1/user/settings/__init__.py
Normal file
@@ -0,0 +1,692 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
|
||||
import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
BUILTIN_DEFAULT_COLORS, ThemePreset, UserThemeUpdateRequest,
|
||||
SettingOption, UserSettingUpdateRequest,
|
||||
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
|
||||
ChangePasswordRequest,
|
||||
AuthnDetailResponse, AuthnRenameRequest,
|
||||
PolicySummary,
|
||||
)
|
||||
from sqlmodels.color import ThemeColorsBase
|
||||
from sqlmodels.user_authn import UserAuthn
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from utils.password.pwd import PasswordStatus, TwoFactorResponse, TwoFactorVerifyRequest
|
||||
from .file_viewers import file_viewers_router
|
||||
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
user_settings_router.include_router(file_viewers_router)
|
||||
|
||||
|
||||
@user_settings_router.get(
    path='/policies',
    summary='获取用户可选存储策略',
)
async def router_user_settings_policies(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[PolicySummary]:
    """
    List the storage policies available to the caller's group.

    Returns a summary for every policy attached to the user's group.
    """
    group = await user.awaitable_attrs.group
    # Refresh the lazy relationship so `policies` is loaded on this session.
    await session.refresh(group, ['policies'])

    summaries: list[PolicySummary] = []
    for policy in group.policies:
        summaries.append(
            PolicySummary(
                id=policy.id, name=policy.name, type=policy.type,
                server=policy.server, max_size=policy.max_size, is_private=policy.is_private,
            )
        )
    return summaries
|
||||
|
||||
|
||||
@user_settings_router.get(
    path='/nodes',
    summary='获取用户可选节点',
    description='Get user selectable nodes.',
    dependencies=[Depends(auth_required)],
)
def router_user_settings_nodes() -> sqlmodels.ResponseBase:
    """
    Get user selectable nodes.

    Not implemented yet: always raises via the shared
    ``http_exceptions.raise_not_implemented`` helper.

    Returns:
        dict: A dictionary containing available nodes for the user.
    """
    http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
    path='/tasks',
    summary='任务队列',
    description='Get user task queue.',
    dependencies=[Depends(auth_required)],
)
def router_user_settings_tasks() -> sqlmodels.ResponseBase:
    """
    Get user task queue.

    Not implemented yet: always raises via the shared
    ``http_exceptions.raise_not_implemented`` helper.

    Returns:
        dict: A dictionary containing the user's task queue information.
    """
    http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
    path='/',
    summary='获取当前用户设定',
    description='Get current user settings.',
)
async def router_user_settings(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.UserSettingResponse:
    """
    Return the current user's settings.

    Theme color resolution order:
    1. The user has a complete color snapshot (all 7 fields set) -> use it.
    2. Otherwise, if a default theme preset exists -> use its colors.
    3. Otherwise -> fall back to the built-in defaults.
    """
    # A snapshot is usable only when every one of the seven color fields is set.
    has_snapshot = all([
        user.color_primary, user.color_secondary, user.color_success,
        user.color_info, user.color_warning, user.color_error, user.color_neutral,
    ])
    if has_snapshot:
        theme_colors = ThemeColorsBase(
            primary=user.color_primary,
            secondary=user.color_secondary,
            success=user.color_success,
            info=user.color_info,
            warning=user.color_warning,
            error=user.color_error,
            neutral=user.color_neutral,
        )
    else:
        default_preset: ThemePreset | None = await ThemePreset.get(
            session, ThemePreset.is_default == True  # noqa: E712
        )
        if default_preset:
            theme_colors = ThemeColorsBase(
                primary=default_preset.primary,
                secondary=default_preset.secondary,
                success=default_preset.success,
                info=default_preset.info,
                warning=default_preset.warning,
                error=default_preset.error,
                neutral=default_preset.neutral,
            )
        else:
            theme_colors = BUILTIN_DEFAULT_COLORS

    # Two-factor state lives in the email_password AuthIdentity's extra_data
    # JSON blob (key "two_factor").
    has_two_factor = False
    email_identity: AuthIdentity | None = await AuthIdentity.get(
        session,
        (AuthIdentity.user_id == user.id)
        & (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
    )
    if email_identity and email_identity.extra_data:
        import orjson
        extra: dict = orjson.loads(email_identity.extra_data)
        has_two_factor = bool(extra.get("two_factor"))

    # Load the group through the async attribute accessor instead of a plain
    # ``user.group`` access, which would trigger a lazy load outside the
    # greenlet context under SQLAlchemy's asyncio mode (same pattern as the
    # /policies endpoint above).
    group = await user.awaitable_attrs.group

    return sqlmodels.UserSettingResponse(
        id=user.id,
        email=user.email,
        phone=user.phone,
        nickname=user.nickname,
        created_at=user.created_at,
        group_name=group.name,
        language=user.language,
        timezone=user.timezone,
        group_expires=user.group_expires,
        two_factor=has_two_factor,
        theme_preset_id=user.theme_preset_id,
        theme_colors=theme_colors,
    )
|
||||
|
||||
|
||||
@user_settings_router.post(
    path='/avatar',
    summary='从文件上传头像',
    status_code=204,
)
async def router_user_settings_avatar(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    file: UploadFile = File(...),
) -> None:
    """
    Upload an avatar image for the current user.

    Auth: JWT token. Body: multipart/form-data with a ``file`` field.

    Steps:
    1. Check the MIME type (JPEG/PNG/GIF/WebP).
    2. Check the payload size against the avatar_size setting (default 2MB).
    3. Validate and process the image with Pillow (center crop, L/M/S sizes).
    4. Store the three WebP renditions.
    5. Set ``User.avatar`` to ``"file"``.

    Errors:
    - 400: unsupported type / image cannot be decoded
    - 413: file too large
    """
    from service.avatar import (
        ALLOWED_CONTENT_TYPES,
        get_avatar_settings,
        process_and_save_avatar,
    )

    # Reject unsupported MIME types up front.
    if file.content_type not in ALLOWED_CONTENT_TYPES:
        http_exceptions.raise_bad_request(
            f"不支持的图片格式,允许: {', '.join(ALLOWED_CONTENT_TYPES)}"
        )

    # Enforce the configured upload-size limit.
    _, max_upload_size, _, _, _ = await get_avatar_settings(session)
    payload = await file.read()
    if len(payload) > max_upload_size:
        raise HTTPException(
            status_code=413,
            detail=f"文件过大,最大允许 {max_upload_size} 字节",
        )

    # process_and_save_avatar validates the image internally and raises
    # ValueError when it cannot be decoded.
    try:
        await process_and_save_avatar(session, user.id, payload)
    except ValueError as e:
        http_exceptions.raise_bad_request(str(e))

    # Mark the user as using a file-based avatar.
    user.avatar = "file"
    user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.put(
    path='/avatar',
    summary='设定为 Gravatar 头像',
    status_code=204,
)
async def router_user_settings_avatar_gravatar(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> None:
    """
    Switch the user's avatar to Gravatar.

    Auth: JWT token.

    Steps:
    1. Require the user to have an email (Gravatar hashes the email).
    2. If the current avatar is file-based, remove the local files.
    3. Set ``User.avatar`` to ``"gravatar"``.

    Errors:
    - 400: user has no email
    """
    from service.avatar import delete_avatar_files

    # Gravatar is keyed on the account email; refuse without one.
    if not user.email:
        http_exceptions.raise_bad_request("Gravatar 需要邮箱,请先绑定邮箱")

    # Clean up any previously uploaded avatar files.
    if user.avatar == "file":
        await delete_avatar_files(session, user.id)

    user.avatar = "gravatar"
    user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.delete(
    path='/avatar',
    summary='重置头像为默认',
    status_code=204,
)
async def router_user_settings_avatar_delete(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> None:
    """
    Reset the user's avatar to the default one.

    Auth: JWT token.

    Steps:
    1. If the current avatar is file-based, remove the local files.
    2. Set ``User.avatar`` to ``"default"``.
    """
    from service.avatar import delete_avatar_files

    # Remove uploaded files before switching away from the file avatar.
    if user.avatar == "file":
        await delete_avatar_files(session, user.id)

    user.avatar = "default"
    user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.patch(
    path='/theme',
    summary='更新用户主题设置',
    status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_theme(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    request: UserThemeUpdateRequest,
) -> None:
    """
    Update the user's theme settings.

    Body (all optional):
    - theme_preset_id: UUID of a theme preset
    - theme_colors: color object, written to the per-user color snapshot

    Errors:
    - 404: the referenced theme preset does not exist
    """
    # Validate the preset reference before storing it.
    if request.theme_preset_id is not None:
        preset: ThemePreset | None = await ThemePreset.get(
            session, ThemePreset.id == request.theme_preset_id
        )
        if not preset:
            http_exceptions.raise_not_found("主题预设不存在")
        user.theme_preset_id = request.theme_preset_id

    # Fan the color object out into the snapshot columns.
    if request.theme_colors is not None:
        for channel in (
            'primary', 'secondary', 'success', 'info',
            'warning', 'error', 'neutral',
        ):
            setattr(user, f'color_{channel}', getattr(request.theme_colors, channel))

    user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.patch(
    path='/password',
    summary='修改密码',
    status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_change_password(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    request: ChangePasswordRequest,
) -> None:
    """
    Change the current user's password.

    Body:
    - old_password: current password
    - new_password: new password (minimum 8 characters)

    Errors:
    - 400: no email/password auth identity exists
    - 403: current password is wrong
    """
    email_identity: AuthIdentity | None = await AuthIdentity.get(
        session,
        (AuthIdentity.user_id == user.id)
        & (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
    )
    # Without a stored credential there is nothing to change against.
    if not email_identity or not email_identity.credential:
        http_exceptions.raise_bad_request("未找到邮箱密码认证身份")

    # The old password must verify before the new one is accepted.
    if Password.verify(email_identity.credential, request.old_password) == PasswordStatus.INVALID:
        http_exceptions.raise_forbidden("当前密码错误")

    email_identity.credential = Password.hash(request.new_password)
    email_identity = await email_identity.save(session)
|
||||
|
||||
|
||||
@user_settings_router.patch(
    path='/{option}',
    summary='更新用户设定',
    status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_patch(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    option: SettingOption,
    request: UserSettingUpdateRequest,
) -> None:
    """
    Update a single user setting.

    Path:
    - option: name of the setting (nickname / language / timezone)

    Body:
    - a field named after ``option`` holding the new value

    Errors:
    - 422: invalid option or value failing validation
    - 400: missing value for a required setting
    """
    new_value = getattr(request, option.value)

    # Only nickname may be cleared; language and timezone must stay set.
    if new_value is None and option != SettingOption.NICKNAME:
        http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空")

    setattr(user, option.value, new_value)
    user = await user.save(session)
|
||||
|
||||
|
||||
@user_settings_router.get(
    path='/2fa',
    summary='获取两步验证初始化信息',
    description='Get two-factor authentication initialization information.',
    dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa(
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> TwoFactorResponse:
    """
    Start two-factor setup for the current user.

    Returns a setup_token (for the follow-up verification request) and a
    provisioning URI (for rendering the QR code).
    """
    # Prefer the email as the TOTP account label; fall back to the user id.
    totp_label = user.email or str(user.id)
    return await Password.generate_totp(name=totp_label)
|
||||
|
||||
|
||||
@user_settings_router.post(
    path='/2fa',
    summary='启用两步验证',
    description='Enable two-factor authentication.',
    dependencies=[Depends(auth_required)],
    status_code=204,
)
async def router_user_settings_2fa_enable(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    request: TwoFactorVerifyRequest,
) -> None:
    """
    Enable two-factor authentication.

    Verifies the setup token and OTP code, then stores the TOTP secret in the
    extra_data of the user's email_password AuthIdentity.
    """
    token_serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)

    # Recover the secret from the signed, time-limited setup token.
    try:
        secret = token_serializer.loads(
            request.setup_token, salt="2fa-setup-salt", max_age=600
        )
    except SignatureExpired:
        raise HTTPException(status_code=400, detail="Setup session expired")
    except BadSignature:
        raise HTTPException(status_code=400, detail="Invalid token")

    # The user must prove they hold the secret by supplying a valid code.
    if Password.verify_totp(secret, request.code) != PasswordStatus.VALID:
        raise HTTPException(status_code=400, detail="Invalid OTP code")

    # Persist the secret under the "two_factor" key of extra_data.
    email_identity: AuthIdentity | None = await AuthIdentity.get(
        session,
        (AuthIdentity.user_id == user.id)
        & (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
    )
    if not email_identity:
        raise HTTPException(status_code=400, detail="未找到邮箱密码认证身份")

    import orjson
    existing_extra: dict = (
        orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
    )
    existing_extra["two_factor"] = secret
    email_identity.extra_data = orjson.dumps(existing_extra).decode('utf-8')
    email_identity = await email_identity.save(session)
|
||||
|
||||
|
||||
# ==================== 认证身份管理 ====================
|
||||
|
||||
@user_settings_router.get(
    path='/identities',
    summary='列出已绑定的认证身份',
)
async def router_user_settings_identities(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[AuthIdentityResponse]:
    """
    List every auth identity bound to the current user.

    Returns:
    - identity list including provider, identifier, display_name, etc.
    """
    bound: list[AuthIdentity] = await AuthIdentity.get(
        session,
        AuthIdentity.user_id == user.id,
        fetch_mode="all",
    )
    responses: list[AuthIdentityResponse] = []
    for identity in bound:
        responses.append(identity.to_response())
    return responses
|
||||
|
||||
|
||||
@user_settings_router.post(
    path='/identity',
    summary='绑定新的认证身份',
    status_code=status.HTTP_201_CREATED,
)
async def router_user_settings_bind_identity(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    request: BindIdentityRequest,
) -> AuthIdentityResponse:
    """
    Bind a new sign-in method to the current user.

    Body:
    - provider: provider type
    - identifier: identifier (email / phone / OAuth code)
    - credential: credential (password, verification code, ...)
    - redirect_uri: OAuth callback URL (optional)

    Errors:
    - 400: provider not enabled
    - 409: identity already bound to another user
    """
    # Refuse identifiers that are already bound anywhere.
    already_bound = await AuthIdentity.get(
        session,
        (AuthIdentity.provider == request.provider)
        & (AuthIdentity.identifier == request.identifier),
    )
    if already_bound:
        raise HTTPException(status_code=409, detail="该身份已被绑定")

    # Only password-style credentials are hashed; everything else stores None.
    hashed_credential: str | None = None
    if request.provider == AuthProviderType.EMAIL_PASSWORD and request.credential:
        hashed_credential = Password.hash(request.credential)

    new_identity = AuthIdentity(
        provider=request.provider,
        identifier=request.identifier,
        credential=hashed_credential,
        is_primary=False,
        is_verified=False,
        user_id=user.id,
    )
    new_identity = await new_identity.save(session)
    return new_identity.to_response()
|
||||
|
||||
|
||||
@user_settings_router.delete(
    path='/identity/{identity_id}',
    summary='解绑认证身份',
    status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_unbind_identity(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    identity_id: UUID,
) -> None:
    """
    Unbind one auth identity from the current user.

    Constraints:
    - the last remaining identity cannot be unbound
    - when the site enforces email/phone binding, the matching identity
      cannot be unbound

    Errors:
    - 404: identity not found or not owned by the current user
    - 400: last identity / identity required by site policy
    """
    # Locate the target identity and verify ownership.
    identity: AuthIdentity | None = await AuthIdentity.get(
        session,
        (AuthIdentity.id == identity_id) & (AuthIdentity.user_id == user.id),
    )
    if not identity:
        http_exceptions.raise_not_found("认证身份不存在")

    # The user must always keep at least one way to sign in.
    all_identities: list[AuthIdentity] = await AuthIdentity.get(
        session,
        AuthIdentity.user_id == user.id,
        fetch_mode="all",
    )
    if len(all_identities) <= 1:
        http_exceptions.raise_bad_request("不能解绑最后一个认证身份")

    # Site policy may force email / phone bindings; a provider matches at
    # most one of these rules.
    required_binding_rules = {
        AuthProviderType.EMAIL_PASSWORD: (
            "auth_email_binding_required",
            "站长要求必须绑定邮箱,不能解绑",
        ),
        AuthProviderType.PHONE_SMS: (
            "auth_phone_binding_required",
            "站长要求必须绑定手机号,不能解绑",
        ),
    }
    rule = required_binding_rules.get(identity.provider)
    if rule:
        setting_name, refusal_message = rule
        required_setting = await sqlmodels.Setting.get(
            session,
            (sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
            & (sqlmodels.Setting.name == setting_name),
        )
        if required_setting and required_setting.value == "1":
            http_exceptions.raise_bad_request(refusal_message)

    await AuthIdentity.delete(session, identity)
|
||||
|
||||
|
||||
# ==================== WebAuthn 凭证管理 ====================
|
||||
|
||||
@user_settings_router.get(
    path='/authns',
    summary='列出用户所有 WebAuthn 凭证',
)
async def router_user_settings_authns(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[AuthnDetailResponse]:
    """
    List every WebAuthn credential registered by the current user.

    Returns:
    - credential list including credential_id, name, device_type, etc.
    """
    credentials: list[UserAuthn] = await UserAuthn.get(
        session,
        UserAuthn.user_id == user.id,
        fetch_mode="all",
    )
    details: list[AuthnDetailResponse] = []
    for credential in credentials:
        details.append(credential.to_detail_response())
    return details
|
||||
|
||||
|
||||
@user_settings_router.patch(
    path='/authn/{authn_id}',
    summary='重命名 WebAuthn 凭证',
)
async def router_user_settings_rename_authn(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    authn_id: int,
    request: AuthnRenameRequest,
) -> AuthnDetailResponse:
    """
    Rename one WebAuthn credential.

    Errors:
    - 404: credential not found or not owned by the current user
    """
    credential: UserAuthn | None = await UserAuthn.get(
        session,
        (UserAuthn.id == authn_id) & (UserAuthn.user_id == user.id),
    )
    if not credential:
        http_exceptions.raise_not_found("WebAuthn 凭证不存在")

    credential.name = request.name
    credential = await credential.save(session)
    return credential.to_detail_response()
|
||||
|
||||
|
||||
@user_settings_router.delete(
    path='/authn/{authn_id}',
    summary='删除 WebAuthn 凭证',
    status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_delete_authn(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
    authn_id: int,
) -> None:
    """
    Delete one WebAuthn credential.

    Also removes the matching AuthIdentity (provider=passkey) record.
    Refuses when that identity is the user's last one.

    Errors:
    - 404: credential not found or not owned by the current user
    - 400: cannot delete the last auth identity
    """
    credential: UserAuthn | None = await UserAuthn.get(
        session,
        (UserAuthn.id == authn_id) & (UserAuthn.user_id == user.id),
    )
    if not credential:
        http_exceptions.raise_not_found("WebAuthn 凭证不存在")

    # The user must keep at least one sign-in method.
    all_identities: list[AuthIdentity] = await AuthIdentity.get(
        session,
        AuthIdentity.user_id == user.id,
        fetch_mode="all",
    )
    if len(all_identities) <= 1:
        http_exceptions.raise_bad_request("不能删除最后一个认证身份")

    # Remove the linked passkey identity first; defer the commit so both
    # deletions land together.
    passkey_identity: AuthIdentity | None = await AuthIdentity.get(
        session,
        (AuthIdentity.provider == AuthProviderType.PASSKEY)
        & (AuthIdentity.identifier == credential.credential_id)
        & (AuthIdentity.user_id == user.id),
    )
    if passkey_identity:
        await AuthIdentity.delete(session, passkey_identity, commit=False)

    # Finally remove the credential itself (this commit flushes both).
    await UserAuthn.delete(session, credential)
|
||||
146
routers/api/v1/user/settings/file_viewers/__init__.py
Normal file
146
routers/api/v1/user/settings/file_viewers/__init__.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""
|
||||
用户文件查看器偏好设置端点
|
||||
|
||||
提供用户"始终使用"默认查看器的增删查功能。
|
||||
"""
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from sqlalchemy import and_
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
FileApp,
|
||||
FileAppExtension,
|
||||
SetDefaultViewerRequest,
|
||||
User,
|
||||
UserFileAppDefault,
|
||||
UserFileAppDefaultResponse,
|
||||
)
|
||||
from utils import http_exceptions
|
||||
|
||||
file_viewers_router = APIRouter(
|
||||
prefix='/file-viewers',
|
||||
tags=["user", "user_settings", "file_viewers"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
|
||||
|
||||
@file_viewers_router.put(
    path='/default',
    summary='设置默认查看器',
    description='为指定扩展名设置"始终使用"的查看器。',
)
async def set_default_viewer(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    request: SetDefaultViewerRequest,
) -> UserFileAppDefaultResponse:
    """
    Set the default viewer for an extension.

    Updates the user's existing default for the extension when one exists,
    otherwise creates a new record.

    Auth: JWT token required.

    Errors:
    - 404: app does not exist
    - 400: app does not support the extension
    """
    # Normalize the extension: lowercase, trimmed, no leading dot.
    ext = request.extension.lower().strip().lstrip('.')

    # The target app must exist ...
    app: FileApp | None = await FileApp.get(session, FileApp.id == request.app_id)
    if not app:
        http_exceptions.raise_not_found("应用不存在")

    # ... and must declare support for the extension.
    supported: FileAppExtension | None = await FileAppExtension.get(
        session,
        and_(
            FileAppExtension.app_id == app.id,
            FileAppExtension.extension == ext,
        ),
    )
    if not supported:
        http_exceptions.raise_bad_request("该应用不支持此扩展名")

    # Upsert: reuse the user's existing record for this extension if any.
    record: UserFileAppDefault | None = await UserFileAppDefault.get(
        session,
        and_(
            UserFileAppDefault.user_id == user.id,
            UserFileAppDefault.extension == ext,
        ),
    )
    if record is None:
        record = UserFileAppDefault(
            user_id=user.id,
            extension=ext,
            app_id=request.app_id,
        )
    else:
        record.app_id = request.app_id

    record = await record.save(session, load=UserFileAppDefault.app)
    return record.to_response()
|
||||
|
||||
|
||||
@file_viewers_router.get(
    path='/defaults',
    summary='列出所有默认查看器设置',
    description='获取当前用户所有"始终使用"的查看器偏好。',
)
async def list_default_viewers(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
) -> list[UserFileAppDefaultResponse]:
    """
    List the current user's default-viewer preferences.

    Auth: JWT token required.
    """
    records: list[UserFileAppDefault] = await UserFileAppDefault.get(
        session,
        UserFileAppDefault.user_id == user.id,
        fetch_mode="all",
        load=UserFileAppDefault.app,
    )
    responses: list[UserFileAppDefaultResponse] = []
    for record in records:
        responses.append(record.to_response())
    return responses
|
||||
|
||||
|
||||
@file_viewers_router.delete(
    path='/default/{default_id}',
    summary='撤销默认查看器设置',
    description='删除指定的"始终使用"偏好。',
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_default_viewer(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    default_id: UUID,
) -> None:
    """
    Remove one default-viewer preference.

    Auth: JWT token required.

    Errors:
    - 404: record not found or not owned by the current user
    """
    record: UserFileAppDefault | None = await UserFileAppDefault.get(
        session,
        and_(
            UserFileAppDefault.id == default_id,
            UserFileAppDefault.user_id == user.id,
        ),
    )
    if not record:
        http_exceptions.raise_not_found("默认设置不存在")

    await UserFileAppDefault.delete(session, record)
|
||||
@@ -1,106 +0,0 @@
|
||||
from fastapi import APIRouter, Depends

from middleware.auth import auth_required
from models import ResponseBase
from utils import http_exceptions

# Value-added services router (storage packs, products, orders, redemption).
# NOTE(review): every endpoint below is an unimplemented stub (responds 501).
vas_router = APIRouter(
    prefix="/vas",
    tags=["vas"]
)
|
||||
|
||||
@vas_router.get(
    path='/pack',
    summary='获取容量包及配额信息',
    description='Get information about storage packs and quotas.',
    dependencies=[Depends(auth_required)]
)
def router_vas_pack() -> ResponseBase:
    """
    Stub: storage pack and quota information.

    Not implemented yet: always responds 501.
    """
    http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
    path='/product',
    summary='获取商品信息,同时返回支付信息',
    description='Get product information along with payment details.',
    dependencies=[Depends(auth_required)]
)
def router_vas_product() -> ResponseBase:
    """
    Stub: product information with payment details.

    Not implemented yet: always responds 501.
    """
    http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.post(
    path='/order',
    summary='新建支付订单',
    description='Create an order for a product.',
    dependencies=[Depends(auth_required)]
)
def router_vas_order() -> ResponseBase:
    """
    Stub: create a payment order.

    Not implemented yet: always responds 501.
    """
    http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
    path='/order/{id}',
    summary='查询订单状态',
    description='Get information about a specific payment order by ID.',
    dependencies=[Depends(auth_required)]
)
def router_vas_order_get(id: str) -> ResponseBase:
    """
    Stub: look up a payment order by its ID.

    Args:
        id: order identifier.

    Not implemented yet: always responds 501.
    """
    http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.get(
    path='/redeem',
    summary='获取兑换码信息',
    description='Get information about a specific redemption code.',
    dependencies=[Depends(auth_required)]
)
def router_vas_redeem(code: str) -> ResponseBase:
    """
    Stub: look up a redemption code.

    Args:
        code: the redemption code to inspect.

    Not implemented yet: always responds 501.
    """
    http_exceptions.raise_not_implemented()
|
||||
|
||||
@vas_router.post(
    path='/redeem',
    summary='执行兑换',
    description='Redeem a redemption code for a product or service.',
    dependencies=[Depends(auth_required)]
)
def router_vas_redeem_post() -> ResponseBase:
    """
    Stub: redeem a code for a product or service.

    Not implemented yet: always responds 501.
    """
    http_exceptions.raise_not_implemented()
|
||||
@@ -1,110 +1,207 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import (
|
||||
Object,
|
||||
User,
|
||||
WebDAV,
|
||||
WebDAVAccountResponse,
|
||||
WebDAVCreateRequest,
|
||||
WebDAVUpdateRequest,
|
||||
)
|
||||
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
||||
from utils import http_exceptions
|
||||
from utils.password.pwd import Password
|
||||
|
||||
# WebDAV 管理路由
|
||||
webdav_router = APIRouter(
|
||||
prefix='/webdav',
|
||||
tags=["webdav"],
|
||||
)
|
||||
|
||||
|
||||
def _check_webdav_enabled(user: User) -> None:
|
||||
"""检查用户组是否启用了 WebDAV 功能"""
|
||||
if not user.group.web_dav_enabled:
|
||||
http_exceptions.raise_forbidden("WebDAV 功能未启用")
|
||||
|
||||
|
||||
def _to_response(account: WebDAV) -> WebDAVAccountResponse:
|
||||
"""将 WebDAV 数据库模型转换为响应 DTO"""
|
||||
return WebDAVAccountResponse(
|
||||
id=account.id,
|
||||
name=account.name,
|
||||
root=account.root,
|
||||
readonly=account.readonly,
|
||||
use_proxy=account.use_proxy,
|
||||
created_at=str(account.created_at),
|
||||
updated_at=str(account.updated_at),
|
||||
)
|
||||
|
||||
|
||||
@webdav_router.get(
|
||||
path='/accounts',
|
||||
summary='获取账号信息',
|
||||
description='Get account information for WebDAV.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
summary='获取账号列表',
|
||||
)
|
||||
def router_webdav_accounts() -> ResponseBase:
|
||||
async def list_accounts(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> list[WebDAVAccountResponse]:
|
||||
"""
|
||||
Get account information for WebDAV.
|
||||
列出当前用户所有 WebDAV 账户
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the account information.
|
||||
认证:JWT Bearer Token
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
accounts: list[WebDAV] = await WebDAV.get(
|
||||
session,
|
||||
WebDAV.user_id == user_id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
return [_to_response(a) for a in accounts]
|
||||
|
||||
|
||||
@webdav_router.post(
|
||||
path='/accounts',
|
||||
summary='新建账号',
|
||||
description='Create a new WebDAV account.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
summary='创建账号',
|
||||
status_code=201,
|
||||
)
|
||||
def router_webdav_create_account() -> ResponseBase:
|
||||
async def create_account(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: WebDAVCreateRequest,
|
||||
) -> WebDAVAccountResponse:
|
||||
"""
|
||||
Create a new WebDAV account.
|
||||
创建 WebDAV 账户
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created account.
|
||||
认证:JWT Bearer Token
|
||||
|
||||
错误处理:
|
||||
- 403: WebDAV 功能未启用
|
||||
- 400: 根目录路径不存在或不是目录
|
||||
- 409: 账户名已存在
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
_check_webdav_enabled(user)
|
||||
user_id: UUID = user.id
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/accounts/{id}',
|
||||
summary='删除账号',
|
||||
description='Delete a WebDAV account by its ID.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_delete_account(id: str) -> ResponseBase:
|
||||
"""
|
||||
Delete a WebDAV account by its ID.
|
||||
# 验证账户名唯一
|
||||
existing = await WebDAV.get(
|
||||
session,
|
||||
(WebDAV.name == request.name) & (WebDAV.user_id == user_id),
|
||||
)
|
||||
if existing:
|
||||
http_exceptions.raise_conflict("账户名已存在")
|
||||
|
||||
Args:
|
||||
id (str): The ID of the account to be deleted.
|
||||
# 验证 root 路径存在且为目录
|
||||
root_obj = await Object.get_by_path(session, user_id, request.root)
|
||||
if not root_obj or not root_obj.is_folder:
|
||||
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
# 创建账户
|
||||
account = WebDAV(
|
||||
name=request.name,
|
||||
password=Password.hash(request.password),
|
||||
root=request.root,
|
||||
readonly=request.readonly,
|
||||
use_proxy=request.use_proxy,
|
||||
user_id=user_id,
|
||||
)
|
||||
account = await account.save(session)
|
||||
|
||||
@webdav_router.post(
|
||||
path='/mount',
|
||||
summary='新建目录挂载',
|
||||
description='Create a new WebDAV mount point.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_create_mount() -> ResponseBase:
|
||||
"""
|
||||
Create a new WebDAV mount point.
|
||||
l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}")
|
||||
return _to_response(account)
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the created mount point.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.delete(
|
||||
path='/mount/{id}',
|
||||
summary='删除目录挂载',
|
||||
description='Delete a WebDAV mount point by its ID.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_webdav_delete_mount(id: str) -> ResponseBase:
|
||||
"""
|
||||
Delete a WebDAV mount point by its ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the mount point to be deleted.
|
||||
|
||||
Returns:
|
||||
ResponseBase: A model containing the response data for the deletion operation.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@webdav_router.patch(
    path='/accounts/{account_id}',
    summary='更新账号',
)
async def update_account(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    account_id: int,
    request: WebDAVUpdateRequest,
) -> WebDAVAccountResponse:
    """
    Update a WebDAV account owned by the authenticated user.

    Auth: JWT Bearer Token.

    Errors:
        - 403: WebDAV feature not enabled for the user's group
        - 404: account does not exist
        - 400: root path does not exist or is not a folder
    """
    # Previous version was a botched merge: the decorator passed `path` and
    # `summary` twice (a TypeError at import time), an old stub signature was
    # interleaved with the new one, and a stray raise_not_implemented() made
    # the whole handler unreachable. Reconstructed the intended handler.
    _check_webdav_enabled(user)
    user_id: UUID = user.id

    account = await WebDAV.get(
        session,
        (WebDAV.id == account_id) & (WebDAV.user_id == user_id),
    )
    if not account:
        http_exceptions.raise_not_found("WebDAV 账户不存在")

    # Validate the new root path, if one was supplied.
    if request.root is not None:
        root_obj = await Object.get_by_path(session, user_id, request.root)
        if not root_obj or not root_obj.is_folder:
            http_exceptions.raise_bad_request("根目录路径不存在或不是目录")

    # Hash the password in place; update() uses model_dump(exclude_unset=True)
    # so only explicitly set fields are written.
    is_password_changed = request.password is not None
    if is_password_changed:
        request.password = Password.hash(request.password)

    account = await account.update(session, request)

    # A password change invalidates any cached Basic-Auth verification.
    if is_password_changed:
        await WebDAVAuthCache.invalidate_account(user_id, account.name)

    l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}")
    return _to_response(account)
|
||||
|
||||
|
||||
@webdav_router.delete(
    path='/accounts/{account_id}',
    summary='删除账号',
    status_code=204,
)
async def delete_account(
    session: SessionDep,
    user: Annotated[User, Depends(auth_required)],
    account_id: int,
) -> None:
    """
    Delete a WebDAV account owned by the authenticated user.

    Auth: JWT Bearer Token.

    Errors:
        - 403: WebDAV feature not enabled for the user's group
        - 404: account does not exist
    """
    _check_webdav_enabled(user)
    owner_id: UUID = user.id

    target = await WebDAV.get(
        session,
        (WebDAV.id == account_id) & (WebDAV.user_id == owner_id),
    )
    if not target:
        http_exceptions.raise_not_found("WebDAV 账户不存在")

    name_for_log = target.name
    await WebDAV.delete(session, target)

    # Evict cached credentials so the deleted account stops authenticating.
    await WebDAVAuthCache.invalidate_account(owner_id, name_for_log)

    l.info(f"用户 {owner_id} 删除 WebDAV 账户: {name_for_log}")
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# WebDAV 操作路由
|
||||
35
routers/dav/__init__.py
Normal file
35
routers/dav/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
WebDAV 协议入口
|
||||
|
||||
使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。
|
||||
WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。
|
||||
"""
|
||||
from a2wsgi import WSGIMiddleware
|
||||
from wsgidav.wsgidav_app import WsgiDAVApp
|
||||
|
||||
from .domain_controller import DiskNextDomainController
|
||||
from .provider import DiskNextDAVProvider
|
||||
|
||||
_wsgidav_config: dict[str, object] = {
|
||||
"provider_mapping": {
|
||||
"/": DiskNextDAVProvider(),
|
||||
},
|
||||
"http_authenticator": {
|
||||
"domain_controller": DiskNextDomainController,
|
||||
"accept_basic": True,
|
||||
"accept_digest": False,
|
||||
"default_to_digest": False,
|
||||
},
|
||||
"verbose": 1,
|
||||
# 使用 WsgiDAV 内置的内存锁管理器
|
||||
"lock_storage": True,
|
||||
# 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议)
|
||||
"dir_browser": {
|
||||
"enable": False,
|
||||
},
|
||||
}
|
||||
|
||||
_wsgidav_app = WsgiDAVApp(_wsgidav_config)
|
||||
|
||||
dav_app = WSGIMiddleware(_wsgidav_app, workers=10)
|
||||
"""ASGI 应用,挂载到 /dav 路径"""
|
||||
148
routers/dav/domain_controller.py
Normal file
148
routers/dav/domain_controller.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
WebDAV 认证控制器
|
||||
|
||||
实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth
|
||||
通过 DiskNext 的 WebDAV 账户模型进行认证。
|
||||
|
||||
用户名格式: {email}/{webdav_account_name}
|
||||
"""
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from wsgidav.dc.base_dc import BaseDomainController
|
||||
|
||||
from routers.dav.provider import EventLoopRef, _get_session
|
||||
from service.redis.webdav_auth_cache import WebDAVAuthCache
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from sqlmodels.webdav import WebDAV
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
|
||||
async def _authenticate(
    email: str,
    account_name: str,
    password: str,
) -> tuple[UUID, int] | None:
    """
    Authenticate a WebDAV login asynchronously.

    :param email: user email
    :param account_name: WebDAV account name
    :param password: plaintext password
    :return: (user_id, webdav_id) on success, otherwise None
    """
    # Fast path: previously verified credentials from the cache.
    cached = await WebDAVAuthCache.get(email, account_name, password)
    if cached is not None:
        return cached

    # Cache miss: verify against the database.
    async with _get_session() as session:
        user = await User.get(session, User.email == email, load=User.group)
        if (
            not user
            or user.status != UserStatus.ACTIVE
            or not user.group.web_dav_enabled
        ):
            return None

        account = await WebDAV.get(
            session,
            (WebDAV.name == account_name) & (WebDAV.user_id == user.id),
        )
        if not account:
            return None

        if Password.verify(account.password, password) == PasswordStatus.INVALID:
            return None

        uid: UUID = user.id
        dav_id: int = account.id

    # Remember the successful verification for next time.
    await WebDAVAuthCache.set(email, account_name, password, uid, dav_id)
    return uid, dav_id
|
||||
|
||||
|
||||
class DiskNextDomainController(BaseDomainController):
    """
    WebDAV authentication controller for DiskNext.

    User name format: {email}/{webdav_account_name}
    Password: the WebDAV account password (set when the account is created).
    """

    def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None:
        super().__init__(wsgidav_app, config)

    def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str:
        """Realm name presented to clients."""
        return "DiskNext WebDAV"

    def require_authentication(self, realm: str, environ: dict[str, object]) -> bool:
        """Every request must authenticate."""
        return True

    def is_share_anonymous(self, path_info: str) -> bool:
        """Anonymous access is never allowed."""
        return False

    def supports_http_digest_auth(self) -> bool:
        """Digest auth is impossible: passwords are stored as Argon2 hashes."""
        return False

    def basic_auth_user(
        self,
        realm: str,
        user_name: str,
        password: str,
        environ: dict[str, object],
    ) -> bool:
        """
        HTTP Basic Auth.

        User name format: {email}/{webdav_account_name}. The async
        verification runs on the main event loop via
        asyncio.run_coroutine_threadsafe, called from the WSGI thread.
        """
        # Parse the compound user name.
        if "/" not in user_name:
            l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
            return False

        email, account_name = user_name.split("/", 1)
        if not (email and account_name):
            l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
            return False

        # Bridge from the WSGI thread into the async world.
        outcome = asyncio.run_coroutine_threadsafe(
            _authenticate(email, account_name, password),
            EventLoopRef.get(),
        ).result()

        if outcome is None:
            l.debug(f"WebDAV 认证失败: {email}/{account_name}")
            return False

        user_id, webdav_id = outcome

        # Stash identity in the WSGI environ for the provider layer.
        environ["disknext.user_id"] = user_id
        environ["disknext.webdav_id"] = webdav_id
        environ["disknext.email"] = email
        environ["disknext.account_name"] = account_name
        return True

    def digest_auth_user(
        self,
        realm: str,
        user_name: str,
        environ: dict[str, object],
    ) -> bool:
        """Digest auth is not supported."""
        return False
|
||||
645
routers/dav/provider.py
Normal file
645
routers/dav/provider.py
Normal file
@@ -0,0 +1,645 @@
|
||||
"""
|
||||
DiskNext WebDAV 存储 Provider
|
||||
|
||||
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
|
||||
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
|
||||
"""
|
||||
import asyncio
|
||||
import io
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from wsgidav.dav_error import (
|
||||
DAVError,
|
||||
HTTP_FORBIDDEN,
|
||||
HTTP_INSUFFICIENT_STORAGE,
|
||||
HTTP_NOT_FOUND,
|
||||
)
|
||||
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
|
||||
|
||||
from service.storage import LocalStorageService, adjust_user_storage
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.physical_file import PhysicalFile
|
||||
from sqlmodels.policy import Policy
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.webdav import WebDAV
|
||||
|
||||
|
||||
class EventLoopRef:
|
||||
"""持有主线程事件循环引用,供 WSGI 线程使用"""
|
||||
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
|
||||
|
||||
@classmethod
|
||||
async def capture(cls) -> None:
|
||||
"""在 async 上下文中调用,捕获当前事件循环"""
|
||||
cls._loop = asyncio.get_running_loop()
|
||||
|
||||
@classmethod
|
||||
def get(cls) -> asyncio.AbstractEventLoop:
|
||||
if cls._loop is None:
|
||||
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
|
||||
return cls._loop
|
||||
|
||||
|
||||
def _run_async(coro):  # type: ignore[no-untyped-def]
    """Run a coroutine on the captured main loop and block for its result."""
    return asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get()).result()
|
||||
|
||||
|
||||
def _get_session():  # type: ignore[no-untyped-def]
    """Return an async database session context manager."""
    factory = DatabaseManager._async_session_factory
    return factory()
|
||||
|
||||
|
||||
# ==================== 异步辅助函数 ====================
|
||||
|
||||
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
    """Fetch a WebDAV account by primary key."""
    async with _get_session() as session:
        return await WebDAV.get(session, WebDAV.id == webdav_id)


async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
    """Fetch an object by its virtual path."""
    async with _get_session() as session:
        return await Object.get_by_path(session, user_id, path)


async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
    """Fetch the direct children of a folder."""
    async with _get_session() as session:
        return await Object.get_children(session, user_id, parent_id)


async def _get_object_by_id(object_id: UUID) -> Object | None:
    """Fetch an object by id, eagerly loading its physical file relation."""
    async with _get_session() as session:
        return await Object.get(session, Object.id == object_id, load=Object.physical_file)


async def _get_user(user_id: UUID) -> User | None:
    """Fetch a user with the group relation loaded."""
    async with _get_session() as session:
        return await User.get(session, User.id == user_id, load=User.group)


async def _get_policy(policy_id: UUID) -> Policy | None:
    """Fetch a storage policy by id."""
    async with _get_session() as session:
        return await Policy.get(session, Policy.id == policy_id)
|
||||
|
||||
|
||||
async def _create_object_record(
    name: str,
    object_type: ObjectType,
    parent_id: UUID,
    owner_id: UUID,
    policy_id: UUID,
) -> Object:
    """Create and persist a zero-size Object of the given type."""
    async with _get_session() as session:
        obj = Object(
            name=name,
            type=object_type,
            size=0,
            parent_id=parent_id,
            owner_id=owner_id,
            policy_id=policy_id,
        )
        return await obj.save(session)


async def _create_folder(
    name: str,
    parent_id: UUID,
    owner_id: UUID,
    policy_id: UUID,
) -> Object:
    """Create a folder object."""
    # Folder and file creation differed only in ObjectType; the shared
    # helper removes the duplication.
    return await _create_object_record(name, ObjectType.FOLDER, parent_id, owner_id, policy_id)


async def _create_file(
    name: str,
    parent_id: UUID,
    owner_id: UUID,
    policy_id: UUID,
) -> Object:
    """Create an empty file object."""
    return await _create_object_record(name, ObjectType.FILE, parent_id, owner_id, policy_id)
|
||||
|
||||
|
||||
async def _soft_delete_object(object_id: UUID) -> None:
    """Soft-delete an object (move it to the recycle bin)."""
    from service.storage import soft_delete_objects

    async with _get_session() as session:
        target = await Object.get(session, Object.id == object_id)
        if target is None:
            return
        await soft_delete_objects(session, [target])
|
||||
|
||||
|
||||
async def _finalize_upload(
    object_id: UUID,
    physical_path: str,
    size: int,
    owner_id: UUID,
    policy_id: UUID,
) -> None:
    """After an upload, record the physical file and update object metadata."""
    async with _get_session() as session:
        # Derive the storage path relative to the policy root.
        policy = await Policy.get(session, Policy.id == policy_id)
        if not policy or not policy.server:
            raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")

        relative = Path(physical_path).resolve().relative_to(Path(policy.server).resolve())
        storage_path = str(relative)

        # Persist the PhysicalFile row.
        pf = PhysicalFile(
            storage_path=storage_path,
            size=size,
            policy_id=policy_id,
            reference_count=1,
        )
        pf = await pf.save(session)

        # Point the logical object at the new physical file.
        obj = await Object.get(session, Object.id == object_id)
        if obj:
            obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
            obj = await obj.save(session)

        # Charge the upload against the user's storage usage.
        if size > 0:
            await adjust_user_storage(session, owner_id, size)
|
||||
|
||||
|
||||
async def _move_object(
    object_id: UUID,
    new_parent_id: UUID,
    new_name: str,
) -> None:
    """Move and/or rename an object."""
    async with _get_session() as session:
        target = await Object.get(session, Object.id == object_id)
        if target is None:
            return
        target.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
        await target.save(session)
|
||||
|
||||
|
||||
async def _copy_object_recursive(
    src_id: UUID,
    dst_parent_id: UUID,
    dst_name: str,
    owner_id: UUID,
) -> None:
    """Recursively copy an object tree under a new parent."""
    from service.storage import copy_object_recursive

    async with _get_session() as session:
        source = await Object.get(session, Object.id == src_id)
        if source is None:
            return
        await copy_object_recursive(session, source, dst_parent_id, owner_id, new_name=dst_name)
|
||||
|
||||
|
||||
# ==================== 辅助工具 ====================
|
||||
|
||||
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
    """Extract the authenticated (user_id, webdav_id) pair from the WSGI environ."""
    return (
        environ["disknext.user_id"],  # type: ignore[return-value]
        environ["disknext.webdav_id"],  # type: ignore[return-value]
    )
|
||||
|
||||
|
||||
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
    """
    Map a DAV-relative path onto a DiskNext absolute path.

    :param account_root: account mount root, e.g. "/" or "/docs"
    :param dav_path: DAV request path, e.g. "/" or "/photos/cat.jpg"
    :return: DiskNext internal path, e.g. "/docs/photos/cat.jpg"
    """
    # Normalize the mount root; "/" collapses to the empty prefix.
    # (The old `if not root: root = ""` branch was a no-op after rstrip.)
    root = account_root.rstrip("/")

    # Root of the mount: keep a trailing slash when a prefix exists.
    if not dav_path or dav_path == "/":
        return root + "/" if root else "/"

    if not dav_path.startswith("/"):
        dav_path = "/" + dav_path

    # root is "" or "/prefix" and dav_path starts with "/", so the
    # concatenation is always non-empty; the old `full or "/"` fallback
    # was unreachable.
    return root + dav_path
|
||||
|
||||
|
||||
def _check_readonly(environ: dict[str, object]) -> None:
    """Raise 403 when the WebDAV account attached to this request is read-only."""
    account = environ.get("disknext.webdav_account")
    if account is not None and getattr(account, 'readonly', False):
        raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
|
||||
|
||||
|
||||
def _check_storage_quota(user: User, additional_bytes: int) -> None:
    """Raise 507 when adding `additional_bytes` would exceed the user's quota."""
    limit = user.group.max_storage
    # Quota is enforced only when the group limit is positive.
    if limit > 0 and user.storage + additional_bytes > limit:
        raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||
|
||||
|
||||
class QuotaLimitedWriter(io.RawIOBase):
|
||||
"""带配额限制的写入流包装器"""
|
||||
|
||||
def __init__(self, stream: io.BufferedWriter, max_bytes: int) -> None:
|
||||
self._stream = stream
|
||||
self._max_bytes = max_bytes
|
||||
self._bytes_written = 0
|
||||
|
||||
def writable(self) -> bool:
|
||||
return True
|
||||
|
||||
def write(self, b: bytes | bytearray) -> int:
|
||||
if self._bytes_written + len(b) > self._max_bytes:
|
||||
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
|
||||
written = self._stream.write(b)
|
||||
self._bytes_written += written
|
||||
return written
|
||||
|
||||
def close(self) -> None:
|
||||
self._stream.close()
|
||||
super().close()
|
||||
|
||||
@property
|
||||
def bytes_written(self) -> int:
|
||||
return self._bytes_written
|
||||
|
||||
|
||||
# ==================== Provider ====================
|
||||
|
||||
class DiskNextDAVProvider(DAVProvider):
    """WebDAV storage provider backed by DiskNext objects."""

    def __init__(self) -> None:
        super().__init__()

    def get_resource_inst(
        self,
        path: str,
        environ: dict[str, object],
    ) -> 'DiskNextCollection | DiskNextFile | None':
        """
        Resolve a WebDAV path to a resource instance.

        The WebDAV account row is loaded once per request and cached in the
        WSGI environ.
        """
        user_id, webdav_id = _get_environ_info(environ)

        if "disknext.webdav_account" not in environ:
            loaded = _run_async(_get_webdav_account(webdav_id))
            if not loaded:
                return None
            environ["disknext.webdav_account"] = loaded

        account: WebDAV = environ["disknext.webdav_account"]  # type: ignore[no-redef]
        internal_path = _resolve_dav_path(account.root, path)

        target = _run_async(_get_object_by_path(user_id, internal_path))
        if not target:
            return None

        resource_cls = DiskNextCollection if target.is_folder else DiskNextFile
        return resource_cls(path, environ, target, user_id, account)

    def is_readonly(self) -> bool:
        """Read-only is enforced per account, not at the provider level."""
        return False
|
||||
|
||||
|
||||
# ==================== Collection(目录) ====================
|
||||
|
||||
class DiskNextCollection(DAVCollection):
    """A DiskNext folder exposed as a WebDAV collection."""

    def __init__(
        self,
        path: str,
        environ: dict[str, object],
        obj: Object,
        user_id: UUID,
        account: WebDAV,
    ) -> None:
        super().__init__(path, environ)
        self._obj = obj
        self._user_id = user_id
        self._account = account

    def get_display_info(self) -> dict[str, str]:
        return {"type": "Directory"}

    def get_member_names(self) -> list[str]:
        """Names of the folder's direct children."""
        children = _run_async(_get_children(self._user_id, self._obj.id))
        return [child.name for child in children]

    def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
        """Resolve a named child to a resource instance."""
        member_path = self.path.rstrip("/") + "/" + name
        internal_path = _resolve_dav_path(self._account.root, member_path)

        child = _run_async(_get_object_by_path(self._user_id, internal_path))
        if not child:
            return None

        resource_cls = DiskNextCollection if child.is_folder else DiskNextFile
        return resource_cls(member_path, self.environ, child, self._user_id, self._account)

    def get_creation_date(self) -> float | None:
        created = self._obj.created_at
        return created.timestamp() if created else None

    def get_last_modified(self) -> float | None:
        updated = self._obj.updated_at
        return updated.timestamp() if updated else None

    def create_empty_resource(self, name: str) -> 'DiskNextFile':
        """Create an empty child file (first step of a PUT)."""
        _check_readonly(self.environ)

        created = _run_async(_create_file(
            name=name,
            parent_id=self._obj.id,
            owner_id=self._user_id,
            policy_id=self._obj.policy_id,
        ))

        member_path = self.path.rstrip("/") + "/" + name
        return DiskNextFile(member_path, self.environ, created, self._user_id, self._account)

    def create_collection(self, name: str) -> 'DiskNextCollection':
        """Create a child folder (MKCOL)."""
        _check_readonly(self.environ)

        created = _run_async(_create_folder(
            name=name,
            parent_id=self._obj.id,
            owner_id=self._user_id,
            policy_id=self._obj.policy_id,
        ))

        member_path = self.path.rstrip("/") + "/" + name
        return DiskNextCollection(member_path, self.environ, created, self._user_id, self._account)

    def delete(self) -> None:
        """Soft-delete the folder."""
        _check_readonly(self.environ)
        _run_async(_soft_delete_object(self._obj.id))

    def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
        """Copy or move the folder to `dest_path`."""
        _check_readonly(self.environ)

        dest_internal = _resolve_dav_path(self._account.root, dest_path)

        # Split the destination into parent path and new name.
        if "/" in dest_internal.rstrip("/"):
            parent_path, new_name = dest_internal.rsplit("/", 1)
            parent_path = parent_path or "/"
        else:
            parent_path = "/"
            new_name = dest_internal.lstrip("/")

        parent = _run_async(_get_object_by_path(self._user_id, parent_path))
        if not parent:
            raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")

        if is_move:
            _run_async(_move_object(self._obj.id, parent.id, new_name))
        else:
            _run_async(_copy_object_recursive(
                self._obj.id, parent.id, new_name, self._user_id,
            ))

        return True

    def support_recursive_delete(self) -> bool:
        return True

    def support_recursive_move(self, dest_path: str) -> bool:
        return True
|
||||
|
||||
|
||||
# ==================== NonCollection(文件) ====================
|
||||
|
||||
class DiskNextFile(DAVNonCollection):
    """A DiskNext file exposed as a WebDAV resource."""

    def __init__(
        self,
        path: str,
        environ: dict[str, object],
        obj: Object,
        user_id: UUID,
        account: WebDAV,
    ) -> None:
        super().__init__(path, environ)
        self._obj = obj
        self._user_id = user_id
        self._account = account
        # State of an in-flight PUT.
        self._write_path: str | None = None
        self._write_stream: io.BufferedWriter | QuotaLimitedWriter | None = None

    def get_content_length(self) -> int | None:
        return self._obj.size or 0

    def get_content_type(self) -> str | None:
        # Infer the MIME type from the file name.
        guessed, _ = mimetypes.guess_type(self._obj.name)
        return guessed or "application/octet-stream"

    def get_creation_date(self) -> float | None:
        created = self._obj.created_at
        return created.timestamp() if created else None

    def get_last_modified(self) -> float | None:
        updated = self._obj.updated_at
        return updated.timestamp() if updated else None

    def get_display_info(self) -> dict[str, str]:
        return {"type": "File"}

    def get_content(self) -> io.BufferedReader | None:
        """
        Return a readable stream over the file's content.

        WsgiDAV runs in a worker thread, so a synchronous open() is safe.
        """
        loaded = _run_async(_get_object_by_id(self._obj.id))
        if not loaded or not loaded.physical_file:
            return None

        policy = _run_async(_get_policy(loaded.policy_id))
        if not policy or not policy.server:
            return None

        full_path = Path(policy.server).resolve() / loaded.physical_file.storage_path
        if not full_path.is_file():
            l.warning(f"WebDAV: 物理文件不存在: {full_path}")
            return None

        return open(full_path, "rb")  # noqa: SIM115

    def begin_write(
        self,
        *,
        content_type: str | None = None,
    ) -> io.BufferedWriter | QuotaLimitedWriter:
        """
        Start writing file content (PUT).

        Returns a writable stream that WsgiDAV fills with the request body.
        When the user has a storage quota, the stream is wrapped in
        QuotaLimitedWriter so the quota is enforced during the write.
        """
        _check_readonly(self.environ)

        # Quota pre-checks (only when the group defines a positive limit).
        remaining_quota: int = 0
        user = _run_async(_get_user(self._user_id))
        if user:
            max_storage = user.group.max_storage
            if max_storage > 0:
                remaining_quota = max_storage - user.storage
                if remaining_quota <= 0:
                    raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
                # Reject early when Content-Length (if present) already
                # exceeds the remaining quota.
                content_length = self.environ.get("CONTENT_LENGTH")
                if content_length and int(content_length) > remaining_quota:
                    raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")

        # Resolve the physical storage location from the policy.
        policy = _run_async(_get_policy(self._obj.policy_id))
        if not policy or not policy.server:
            raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")

        storage_service = LocalStorageService(policy)
        dir_path, storage_name, full_path = _run_async(
            storage_service.generate_file_path(
                user_id=self._user_id,
                original_filename=self._obj.name,
            )
        )

        self._write_path = full_path
        raw_stream = open(full_path, "wb")  # noqa: SIM115

        # Wrap with the quota enforcer only when a limit applies.
        self._write_stream = (
            QuotaLimitedWriter(raw_stream, remaining_quota)
            if remaining_quota > 0
            else raw_stream
        )
        return self._write_stream

    def end_write(self, *, with_errors: bool) -> None:
        """Finish a PUT: discard the temp file on failure, record it on success."""
        if self._write_stream:
            self._write_stream.close()
            self._write_stream = None

        if with_errors:
            # Abort: remove the partially written physical file.
            if self._write_path:
                partial = Path(self._write_path)
                if partial.exists():
                    partial.unlink()
            return

        if not self._write_path:
            return

        written = Path(self._write_path)
        if not written.exists():
            return

        size = written.stat().st_size

        # Record the physical file and update metadata / storage usage.
        _run_async(_finalize_upload(
            object_id=self._obj.id,
            physical_path=self._write_path,
            size=size,
            owner_id=self._user_id,
            policy_id=self._obj.policy_id,
        ))

        l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")

    def delete(self) -> None:
        """Soft-delete the file."""
        _check_readonly(self.environ)
        _run_async(_soft_delete_object(self._obj.id))

    def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
        """Copy or move the file to `dest_path`."""
        _check_readonly(self.environ)

        dest_internal = _resolve_dav_path(self._account.root, dest_path)

        # Split the destination into parent path and new name.
        if "/" in dest_internal.rstrip("/"):
            parent_path, new_name = dest_internal.rsplit("/", 1)
            parent_path = parent_path or "/"
        else:
            parent_path = "/"
            new_name = dest_internal.lstrip("/")

        parent = _run_async(_get_object_by_path(self._user_id, parent_path))
        if not parent:
            raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")

        if is_move:
            _run_async(_move_object(self._obj.id, parent.id, new_name))
        else:
            _run_async(_copy_object_recursive(
                self._obj.id, parent.id, new_name, self._user_id,
            ))

        return True

    def support_content_length(self) -> bool:
        return True

    def get_etag(self) -> str | None:
        """ETag derived from id + mtime; WsgiDAV adds the surrounding quotes."""
        updated = self._obj.updated_at
        if updated:
            return f"{self._obj.id}-{int(updated.timestamp())}"
        return None

    def support_etag(self) -> bool:
        return True

    def support_ranges(self) -> bool:
        return True
|
||||
11
routers/wopi/__init__.py
Normal file
11
routers/wopi/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
WOPI(Web Application Open Platform Interface)路由
|
||||
|
||||
挂载在根级别 /wopi(非 /api/v1 下),因为 WOPI 客户端要求标准路径。
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .files import wopi_files_router
|
||||
|
||||
wopi_router = APIRouter(prefix="/wopi", tags=["wopi"])
|
||||
wopi_router.include_router(wopi_files_router)
|
||||
203
routers/wopi/files/__init__.py
Normal file
203
routers/wopi/files/__init__.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
WOPI 文件操作端点
|
||||
|
||||
实现 WOPI 协议的核心文件操作接口:
|
||||
- CheckFileInfo: 获取文件元数据
|
||||
- GetFile: 下载文件内容
|
||||
- PutFile: 上传/更新文件内容
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Query, Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger as l
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from sqlmodels import Object, PhysicalFile, Policy, PolicyType, User, WopiFileInfo
|
||||
from service.storage import LocalStorageService
|
||||
from utils import http_exceptions
|
||||
from utils.JWT.wopi_token import verify_wopi_token
|
||||
|
||||
wopi_files_router = APIRouter(prefix="/files", tags=["wopi"])
|
||||
|
||||
|
||||
@wopi_files_router.get(
    path='/{file_id}',
    summary='WOPI CheckFileInfo',
    description='返回文件的元数据信息。',
)
async def check_file_info(
    session: SessionDep,
    file_id: UUID,
    access_token: str = Query(...),
) -> JSONResponse:
    """
    WOPI CheckFileInfo endpoint.

    Auth: WOPI access_token (query parameter).

    Returns WOPI-spec PascalCase JSON.
    """
    # Token must be valid and bound to this exact file.
    payload = verify_wopi_token(access_token)
    if not payload or payload.file_id != file_id:
        http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")

    file_obj: Object | None = await Object.get(
        session,
        Object.id == file_id,
    )
    if not file_obj or not file_obj.is_file:
        http_exceptions.raise_not_found("文件不存在")

    # Best-effort friendly name for the requesting user.
    user: User | None = await User.get(session, User.id == payload.user_id)
    if user:
        user_name = user.nickname or user.email or str(payload.user_id)
    else:
        user_name = str(payload.user_id)

    info = WopiFileInfo(
        base_file_name=file_obj.name,
        size=file_obj.size or 0,
        owner_id=str(file_obj.owner_id),
        user_id=str(payload.user_id),
        user_friendly_name=user_name,
        version=file_obj.updated_at.isoformat() if file_obj.updated_at else "",
        user_can_write=payload.can_write,
        read_only=not payload.can_write,
        supports_update=payload.can_write,
    )

    return JSONResponse(content=info.to_wopi_dict())
|
||||
|
||||
|
||||
@wopi_files_router.get(
    path='/{file_id}/contents',
    summary='WOPI GetFile',
    description='返回文件的二进制内容。',
)
async def get_file(
    session: SessionDep,
    file_id: UUID,
    access_token: str = Query(...),
) -> Response:
    """
    WOPI GetFile endpoint.

    Auth: WOPI access_token (query parameter).

    Returns the raw binary content of the file.
    """
    from pathlib import Path

    # Token must be valid and bound to this exact file.
    payload = verify_wopi_token(access_token)
    if not payload or payload.file_id != file_id:
        http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")

    file_obj: Object | None = await Object.get(session, Object.id == file_id)
    if not file_obj or not file_obj.is_file:
        http_exceptions.raise_not_found("文件不存在")

    physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
    if not physical_file or not physical_file.storage_path:
        http_exceptions.raise_internal_error("文件存储路径丢失")

    policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
    if not policy:
        http_exceptions.raise_internal_error("存储策略不存在")

    if policy.type == PolicyType.LOCAL:
        storage_service = LocalStorageService(policy)
        if not await storage_service.file_exists(physical_file.storage_path):
            http_exceptions.raise_not_found("物理文件不存在")

        # BUG FIX: storage_path is stored relative to the policy root
        # (uploads store `relative_to(policy.server)`), so join it with
        # policy.server instead of opening the bare relative path, which
        # silently depended on the process working directory.
        full_path = Path(policy.server) / physical_file.storage_path

        import aiofiles
        async with aiofiles.open(full_path, 'rb') as f:
            content = await f.read()

        return Response(
            content=content,
            media_type="application/octet-stream",
            headers={"X-WOPI-ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
        )
    else:
        http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
|
||||
|
||||
@wopi_files_router.post(
    path='/{file_id}/contents',
    summary='WOPI PutFile',
    description='更新文件内容。',
)
async def put_file(
    session: SessionDep,
    request: Request,
    file_id: UUID,
    access_token: str = Query(...),
) -> JSONResponse:
    """
    WOPI PutFile endpoint.

    Auth: WOPI access_token passed as a query parameter; write permission
    is required.

    Reads the raw request body and overwrites the stored file content.
    """
    # Validate the token and its binding to this file.
    payload = verify_wopi_token(access_token)
    if not payload or payload.file_id != file_id:
        http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")

    if not payload.can_write:
        http_exceptions.raise_forbidden("没有写入权限")

    # Load the logical file object.
    file_obj: Object | None = await Object.get(session, Object.id == file_id)
    if not file_obj or not file_obj.is_file:
        http_exceptions.raise_not_found("文件不存在")

    # Resolve the backing physical file record.
    physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
    if not physical_file or not physical_file.storage_path:
        http_exceptions.raise_internal_error("文件存储路径丢失")

    # Resolve the storage policy.
    policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
    if not policy:
        http_exceptions.raise_internal_error("存储策略不存在")

    # Read the new content — the entire body is buffered in memory.
    content = await request.body()

    if policy.type == PolicyType.LOCAL:
        # NOTE(review): overwrites the raw path in place without a lock or
        # temp-file swap; a concurrent GetFile may observe a partial write.
        # Also, if this physical file is shared via deduplication
        # (reference_count), the write mutates every referencing object —
        # confirm both are handled by WOPI locking elsewhere.
        import aiofiles
        async with aiofiles.open(physical_file.storage_path, 'wb') as f:
            await f.write(content)

        # Keep the logical object's size in sync with the new content.
        new_size = len(content)
        old_size = file_obj.size or 0
        file_obj.size = new_size
        file_obj = await file_obj.save(session, commit=False)

        # Keep the physical file record's size in sync as well.
        physical_file.size = new_size
        await physical_file.save(session, commit=False)

        # Adjust the owner's storage quota by the size delta; all writes
        # above use commit=False so the whole update commits atomically below.
        size_diff = new_size - old_size
        if size_diff != 0:
            from service.storage import adjust_user_storage
            await adjust_user_storage(session, file_obj.owner_id, size_diff, commit=False)

        await session.commit()

        l.info(f"WOPI PutFile: file_id={file_id}, new_size={new_size}")

        return JSONResponse(
            content={"ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
            status_code=200,
        )
    else:
        http_exceptions.raise_not_implemented("S3 存储暂未实现")
|
||||
@@ -1,18 +1,20 @@
|
||||
import abc
|
||||
from enum import StrEnum
|
||||
|
||||
import aiohttp
|
||||
|
||||
from loguru import logger as l
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .gcaptcha import GCaptcha
|
||||
from .turnstile import TurnstileCaptcha
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class CaptchaRequestBase(BaseModel):
|
||||
"""验证码验证请求"""
|
||||
token: str
|
||||
"""验证 token"""
|
||||
|
||||
response: str
|
||||
"""用户的验证码 response token"""
|
||||
|
||||
secret: str
|
||||
"""验证密钥"""
|
||||
"""服务端密钥"""
|
||||
|
||||
|
||||
class CaptchaBase(abc.ABC):
|
||||
@@ -30,10 +32,85 @@ class CaptchaBase(abc.ABC):
|
||||
"""
|
||||
payload = request.model_dump()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.verify_url, data=payload) as response:
|
||||
if response.status != 200:
|
||||
async with aiohttp.ClientSession() as client_session:
|
||||
async with client_session.post(self.verify_url, data=payload) as resp:
|
||||
if resp.status != 200:
|
||||
return False
|
||||
|
||||
result = await response.json()
|
||||
result = await resp.json()
|
||||
return result.get('success', False)
|
||||
|
||||
|
||||
# 子类导入必须在 CaptchaBase 定义之后(gcaptcha.py / turnstile.py 依赖 CaptchaBase)
|
||||
from .gcaptcha import GCaptcha # noqa: E402
|
||||
from .turnstile import TurnstileCaptcha # noqa: E402
|
||||
|
||||
|
||||
class CaptchaScene(StrEnum):
    """Captcha usage scene; each value matches a ``name`` row in the Setting table."""

    LOGIN = "login_captcha"
    REGISTER = "reg_captcha"
    FORGET = "forget_captcha"
|
||||
|
||||
|
||||
async def verify_captcha_if_needed(
    session: AsyncSession,
    scene: CaptchaScene,
    captcha_code: str,
) -> None:
    """
    Generic captcha check: consult settings to decide whether this scene
    requires a captcha, and validate it against the third-party API when it does.

    :param session: async database session
    :param scene: captcha usage scene
    :param captcha_code: captcha response token submitted by the user
    :raises HTTPException 400: captcha required but not provided
    :raises HTTPException 403: captcha validation failed
    :raises HTTPException 500: captcha secret not configured
    """
    # Imported lazily, presumably to avoid circular imports — TODO confirm.
    from sqlmodels import Setting, SettingsType
    from sqlmodels.setting import CaptchaType
    from utils import http_exceptions

    # 1. Does this scene require a captcha at all?
    scene_setting = await Setting.get(
        session,
        (Setting.type == SettingsType.LOGIN) & (Setting.name == scene.value),
    )
    if not scene_setting or scene_setting.value != "1":
        return

    # 2. Load the configured captcha type and secrets.
    captcha_settings: list[Setting] = await Setting.get(
        session, Setting.type == SettingsType.CAPTCHA, fetch_mode="all",
    )
    s: dict[str, str | None] = {item.name: item.value for item in captcha_settings}
    captcha_type = CaptchaType(s.get("captcha_type") or "default")

    # 3. DEFAULT image captcha is not implemented yet; skip verification.
    if captcha_type == CaptchaType.DEFAULT:
        l.warning("DEFAULT 图片验证码尚未实现,跳过验证")
        return

    # 4. Pick the verifier implementation and its matching secret.
    if captcha_type == CaptchaType.GCAPTCHA:
        secret = s.get("captcha_ReCaptchaSecret")
        verifier: CaptchaBase = GCaptcha()
    elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
        secret = s.get("captcha_CloudflareSecret")
        verifier = TurnstileCaptcha()
    else:
        l.error(f"未知的验证码类型: {captcha_type}")
        http_exceptions.raise_internal_error()

    if not secret:
        l.error(f"验证码密钥未配置: captcha_type={captcha_type}")
        http_exceptions.raise_internal_error()

    # 5. Validate against the third-party API.
    # NOTE(review): CaptchaRequestBase appears to declare a required ``token``
    # field that is not passed here — confirm this construction does not
    # raise a validation error.
    is_valid = await verifier.verify_captcha(
        CaptchaRequestBase(response=captcha_code, secret=secret)
    )
    if not is_valid:
        http_exceptions.raise_forbidden(detail="验证码验证失败")
|
||||
|
||||
5
service/captcha/default.py
Normal file
5
service/captcha/default.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Scratch/demo for the DEFAULT image-captcha type (not wired into the app).
from captcha.image import ImageCaptcha

captcha = ImageCaptcha()

# NOTE(review): runs at import time and prints the generated captcha object —
# presumably a placeholder; confirm before importing this module anywhere.
print(captcha.generate())
|
||||
68
service/redis/challenge_store.py
Normal file
68
service/redis/challenge_store.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
WebAuthn Challenge 一次性存储
|
||||
|
||||
支持 Redis(首选,使用 GETDEL 原子操作)和内存 TTLCache(降级)。
|
||||
Challenge 存储后 5 分钟过期,取出即删除(防重放)。
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from cachetools import TTLCache
|
||||
from loguru import logger as l
|
||||
|
||||
from . import RedisManager
|
||||
|
||||
# Challenge 过期时间(秒)
|
||||
_CHALLENGE_TTL: int = 300
|
||||
|
||||
|
||||
class ChallengeStore:
    """
    One-time storage manager for WebAuthn challenges.

    Picks a backend based on Redis availability:
    - Redis available: atomic GETDEL on the Redis client
    - Redis unavailable: in-process TTLCache (single-instance only)

    Key conventions:
    - registration: ``reg:{user_id}``
    - login: ``auth:{challenge_token}``
    """

    # In-memory fallback used when no Redis client is configured.
    _memory_cache: ClassVar[TTLCache[str, bytes]] = TTLCache(
        maxsize=10000,
        ttl=_CHALLENGE_TTL,
    )

    @classmethod
    async def store(cls, key: str, challenge: bytes) -> None:
        """
        Persist a challenge with a 5-minute TTL.

        :param key: storage key (``reg:{user_id}`` or ``auth:{token}``)
        :param challenge: raw challenge bytes
        """
        redis = RedisManager.get_client()
        if redis is None:
            cls._memory_cache[key] = challenge
        else:
            await redis.set(f"webauthn_challenge:{key}", challenge, ex=_CHALLENGE_TTL)

    @classmethod
    async def retrieve_and_delete(cls, key: str) -> bytes | None:
        """
        Atomically fetch and remove a challenge (replay protection).

        :param key: storage key
        :return: challenge bytes, or None when missing/expired
        """
        redis = RedisManager.get_client()
        if redis is None:
            return cls._memory_cache.pop(key, None)
        value: bytes | None = await redis.getdel(f"webauthn_challenge:{key}")
        return value
|
||||
72
service/redis/user_ban_store.py
Normal file
72
service/redis/user_ban_store.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
用户封禁状态存储
|
||||
|
||||
用于 JWT 模式下的即时封禁生效。
|
||||
支持 Redis(首选)和内存缓存(降级)两种存储后端。
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from cachetools import TTLCache
|
||||
from loguru import logger as l
|
||||
|
||||
from . import RedisManager
|
||||
|
||||
# access_token 有效期(秒)
|
||||
_BAN_TTL: int = 3600
|
||||
|
||||
|
||||
class UserBanStore:
    """
    Ban-state store for users.

    Admins call ``ban()`` when banning a user; ``jwt_required`` calls
    ``is_banned()`` on every request. The TTL matches the access_token
    lifetime (1h): once it elapses, old tokens have expired anyway and the
    entry no longer needs to be kept.
    """

    # In-memory fallback used when no Redis client is configured.
    _memory_cache: ClassVar[TTLCache[str, bool]] = TTLCache(maxsize=10000, ttl=_BAN_TTL)

    @classmethod
    async def ban(cls, user_id: str) -> None:
        """
        Mark a user as banned.

        :param user_id: user UUID string
        """
        redis = RedisManager.get_client()
        if redis is None:
            cls._memory_cache[user_id] = True
        else:
            await redis.set(f"user_ban:{user_id}", "1", ex=_BAN_TTL)
        l.info(f"用户 {user_id} 已加入封禁黑名单")

    @classmethod
    async def unban(cls, user_id: str) -> None:
        """
        Remove a user's ban marker (called when lifting a ban).

        :param user_id: user UUID string
        """
        redis = RedisManager.get_client()
        if redis is None:
            cls._memory_cache.pop(user_id, None)
        else:
            await redis.delete(f"user_ban:{user_id}")
        l.info(f"用户 {user_id} 已从封禁黑名单移除")

    @classmethod
    async def is_banned(cls, user_id: str) -> bool:
        """
        Check whether a user is on the ban blacklist.

        :param user_id: user UUID string
        :return: True when banned
        """
        redis = RedisManager.get_client()
        if redis is None:
            return user_id in cls._memory_cache
        return await redis.exists(f"user_ban:{user_id}") > 0
|
||||
128
service/redis/webdav_auth_cache.py
Normal file
128
service/redis/webdav_auth_cache.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
WebDAV 认证缓存
|
||||
|
||||
缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。
|
||||
支持 Redis(首选)和内存缓存(降级)两种存储后端。
|
||||
"""
|
||||
import hashlib
|
||||
from typing import ClassVar
|
||||
from uuid import UUID
|
||||
|
||||
from cachetools import TTLCache
|
||||
from loguru import logger as l
|
||||
|
||||
from . import RedisManager
|
||||
|
||||
_AUTH_TTL: int = 300
|
||||
"""认证缓存 TTL(秒),5 分钟"""
|
||||
|
||||
|
||||
class WebDAVAuthCache:
    """
    Cache for WebDAV HTTP Basic Auth results.

    Avoids a DB lookup + Argon2 verification on every request.

    Key format:   webdav_auth:{email}/{account_name}:{sha256(password)}
    Value format: {user_id}:{webdav_id}

    The password's SHA256 is part of the key, so a password change makes
    old cache entries miss naturally.
    """

    _memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL)
    """内存缓存降级方案"""

    @classmethod
    def _build_key(cls, email: str, account_name: str, password: str) -> str:
        """Build the cache key; only the first 16 hex chars of the hash are used."""
        pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16]
        return f"webdav_auth:{email}/{account_name}:{pwd_hash}"

    @classmethod
    async def get(
        cls,
        email: str,
        account_name: str,
        password: str,
    ) -> tuple[UUID, int] | None:
        """
        Look up a cached authentication result.

        :param email: user email
        :param account_name: WebDAV account name
        :param password: plaintext password supplied by the client
        :return: (user_id, webdav_id) on hit, None on miss
        """
        key = cls._build_key(email, account_name, password)

        client = RedisManager.get_client()
        if client is not None:
            value = await client.get(key)
            if value is not None:
                # Redis may return bytes or str depending on client config.
                raw = value.decode() if isinstance(value, bytes) else value
                # Split on the first ':' only — the UUID part contains no ':'.
                user_id_str, webdav_id_str = raw.split(":", 1)
                return UUID(user_id_str), int(webdav_id_str)
        else:
            raw = cls._memory_cache.get(key)
            if raw is not None:
                user_id_str, webdav_id_str = raw.split(":", 1)
                return UUID(user_id_str), int(webdav_id_str)

        return None

    @classmethod
    async def set(
        cls,
        email: str,
        account_name: str,
        password: str,
        user_id: UUID,
        webdav_id: int,
    ) -> None:
        """
        Store an authentication result in the cache.

        :param email: user email
        :param account_name: WebDAV account name
        :param password: plaintext password supplied by the client
        :param user_id: user UUID
        :param webdav_id: WebDAV account ID
        """
        key = cls._build_key(email, account_name, password)
        value = f"{user_id}:{webdav_id}"

        client = RedisManager.get_client()
        if client is not None:
            await client.set(key, value, ex=_AUTH_TTL)
        else:
            cls._memory_cache[key] = value

    @classmethod
    async def invalidate_account(cls, user_id: UUID, account_name: str) -> None:
        """
        Invalidate all cached entries for a given account.

        Because the cache key embeds a password hash, exact deletion is not
        possible: the Redis backend deletes by pattern scan, the memory
        backend clears every entry containing the account name.

        NOTE(review): the scan pattern wildcards the email segment, so this
        clears entries for ANY user whose WebDAV account shares this
        account_name; ``user_id`` is only used for logging. Confirm this
        over-invalidation is acceptable.

        :param user_id: user UUID
        :param account_name: WebDAV account name
        """
        client = RedisManager.get_client()
        if client is not None:
            pattern = f"webdav_auth:*/{account_name}:*"
            cursor: int = 0
            # Standard SCAN loop: iterate until the cursor returns to 0.
            while True:
                cursor, keys = await client.scan(cursor, match=pattern, count=100)
                if keys:
                    await client.delete(*keys)
                if cursor == 0:
                    break
        else:
            # Memory cache cannot match patterns; drop every entry whose key
            # contains this account name.
            keys_to_delete = [
                k for k in cls._memory_cache
                if f"/{account_name}:" in k
            ]
            for k in keys_to_delete:
                cls._memory_cache.pop(k, None)

        l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}")
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
提供文件存储相关的服务,包括:
|
||||
- 本地存储服务
|
||||
- S3 存储服务
|
||||
- 命名规则解析器
|
||||
- 存储异常定义
|
||||
"""
|
||||
@@ -11,6 +12,8 @@ from .exceptions import (
|
||||
FileReadError,
|
||||
FileWriteError,
|
||||
InvalidPathError,
|
||||
S3APIError,
|
||||
S3MultipartUploadError,
|
||||
StorageException,
|
||||
StorageFileNotFoundError,
|
||||
UploadSessionExpiredError,
|
||||
@@ -18,3 +21,13 @@ from .exceptions import (
|
||||
)
|
||||
from .local_storage import LocalStorageService
|
||||
from .naming_rule import NamingContext, NamingRuleParser
|
||||
from .object import (
|
||||
adjust_user_storage,
|
||||
copy_object_recursive,
|
||||
delete_object_recursive,
|
||||
permanently_delete_objects,
|
||||
restore_objects,
|
||||
soft_delete_objects,
|
||||
)
|
||||
from .migrate import migrate_file_with_task, migrate_directory_files
|
||||
from .s3_storage import S3StorageService
|
||||
@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
|
||||
class InvalidPathError(StorageException):
    """Invalid storage path."""


class S3APIError(StorageException):
    """Error returned by an S3 API request."""


class S3MultipartUploadError(S3APIError):
    """Error during an S3 multipart upload."""
|
||||
|
||||
@@ -15,7 +15,7 @@ import aiofiles
|
||||
import aiofiles.os
|
||||
from loguru import logger as l
|
||||
|
||||
from models.policy import Policy
|
||||
from sqlmodels.policy import Policy
|
||||
from .exceptions import (
|
||||
DirectoryCreationError,
|
||||
FileReadError,
|
||||
@@ -263,15 +263,49 @@ class LocalStorageService:
|
||||
"""
|
||||
删除文件(物理删除)
|
||||
|
||||
删除文件后会尝试清理因此变空的父目录。
|
||||
|
||||
:param path: 完整文件路径
|
||||
"""
|
||||
if await self.file_exists(path):
|
||||
try:
|
||||
await aiofiles.os.remove(path)
|
||||
l.debug(f"已删除文件: {path}")
|
||||
await self._cleanup_empty_parents(path)
|
||||
except OSError as e:
|
||||
l.warning(f"删除文件失败 {path}: {e}")
|
||||
|
||||
async def _cleanup_empty_parents(self, file_path: str) -> None:
|
||||
"""
|
||||
从被删文件的父目录开始,向上逐级删除空目录
|
||||
|
||||
在以下情况停止:
|
||||
|
||||
- 到达存储根目录(_base_path)
|
||||
- 遇到非空目录
|
||||
- 遇到 .trash 目录
|
||||
- 删除失败(权限、并发等)
|
||||
|
||||
:param file_path: 被删文件的完整路径
|
||||
"""
|
||||
current = Path(file_path).parent
|
||||
|
||||
while current != self._base_path and str(current).startswith(str(self._base_path)):
|
||||
if current.name == '.trash':
|
||||
break
|
||||
|
||||
try:
|
||||
entries = await aiofiles.os.listdir(str(current))
|
||||
if entries:
|
||||
break
|
||||
|
||||
await aiofiles.os.rmdir(str(current))
|
||||
l.debug(f"已清理空目录: {current}")
|
||||
current = current.parent
|
||||
except OSError as e:
|
||||
l.debug(f"清理空目录失败(忽略): {current}: {e}")
|
||||
break
|
||||
|
||||
async def move_to_trash(
|
||||
self,
|
||||
source_path: str,
|
||||
@@ -304,6 +338,7 @@ class LocalStorageService:
|
||||
try:
|
||||
await aiofiles.os.rename(source_path, str(trash_path))
|
||||
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
|
||||
await self._cleanup_empty_parents(source_path)
|
||||
return str(trash_path)
|
||||
except OSError as e:
|
||||
raise StorageException(f"移动文件到回收站失败: {e}")
|
||||
|
||||
291
service/storage/migrate.py
Normal file
291
service/storage/migrate.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
存储策略迁移服务
|
||||
|
||||
提供跨存储策略的文件迁移功能:
|
||||
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
|
||||
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
|
||||
"""
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.physical_file import PhysicalFile
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
from sqlmodels.task import Task, TaskStatus
|
||||
|
||||
from .local_storage import LocalStorageService
|
||||
from .s3_storage import S3StorageService
|
||||
|
||||
|
||||
async def _get_storage_service(
|
||||
policy: Policy,
|
||||
) -> LocalStorageService | S3StorageService:
|
||||
"""
|
||||
根据策略类型创建对应的存储服务实例
|
||||
|
||||
:param policy: 存储策略
|
||||
:return: 存储服务实例
|
||||
"""
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
return LocalStorageService(policy)
|
||||
elif policy.type == PolicyType.S3:
|
||||
return await S3StorageService.from_policy(policy)
|
||||
else:
|
||||
raise ValueError(f"不支持的存储策略类型: {policy.type}")
|
||||
|
||||
|
||||
async def _read_file_from_storage(
    service: LocalStorageService | S3StorageService,
    storage_path: str,
) -> bytes:
    """
    Read a file's bytes from the given storage backend.

    :param service: storage service instance
    :param storage_path: storage path of the file
    :return: raw file content
    """
    if not isinstance(service, LocalStorageService):
        # S3 backend exposes download_file instead of read_file.
        return await service.download_file(storage_path)
    return await service.read_file(storage_path)
|
||||
|
||||
|
||||
async def _write_file_to_storage(
    service: LocalStorageService | S3StorageService,
    storage_path: str,
    data: bytes,
) -> None:
    """
    Write file content to the given storage backend.

    :param service: storage service instance
    :param storage_path: storage path of the file
    :param data: raw file content
    """
    if isinstance(service, LocalStorageService):
        await service.write_file(storage_path, data)
        return
    # S3 backend exposes upload_file instead of write_file.
    await service.upload_file(storage_path, data)
|
||||
|
||||
|
||||
async def _delete_file_from_storage(
|
||||
service: LocalStorageService | S3StorageService,
|
||||
storage_path: str,
|
||||
) -> None:
|
||||
"""
|
||||
从存储服务删除文件
|
||||
|
||||
:param service: 存储服务实例
|
||||
:param storage_path: 文件存储路径
|
||||
"""
|
||||
if isinstance(service, LocalStorageService):
|
||||
await service.delete_file(storage_path)
|
||||
else:
|
||||
await service.delete_file(storage_path)
|
||||
|
||||
|
||||
async def migrate_single_file(
    session: AsyncSession,
    obj: Object,
    dest_policy: Policy,
) -> None:
    """
    Migrate one file object from its current storage policy to *dest_policy*.

    Flow:
    1. Resolve the source physical file and storage service
    2. Read the source file content (fully into memory)
    3. Generate a new path in the destination storage and write the content
    4. Create a new PhysicalFile record
    5. Point the Object at the new policy_id / physical_file_id
    6. Decrement the old PhysicalFile's refcount; delete the source physical
       file when it drops to zero

    :param session: database session
    :param obj: file object to migrate (must be of file type)
    :param dest_policy: destination storage policy
    :raises ValueError: when *obj* is not a file
    """
    if obj.type != ObjectType.FILE:
        raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")

    # Resolve source policy and physical file via lazy relationship loads.
    src_policy: Policy = await obj.awaitable_attrs.policy
    old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file

    if not old_physical:
        l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
        return

    # No-op when the file is already on the destination policy.
    if src_policy.id == dest_policy.id:
        l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
        return

    # 1. Read the file from the source storage.
    # NOTE(review): loads the entire file into memory — may be an issue for
    # very large files; confirm acceptable.
    src_service = await _get_storage_service(src_policy)
    data = await _read_file_from_storage(src_service, old_physical.storage_path)

    # 2. Generate a destination path and write the content there.
    dest_service = await _get_storage_service(dest_policy)
    _dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
        user_id=obj.owner_id,
        original_filename=obj.name,
    )
    await _write_file_to_storage(dest_service, new_storage_path, data)

    # 3. Create the new PhysicalFile record (size/checksum carried over).
    new_physical = PhysicalFile(
        storage_path=new_storage_path,
        size=old_physical.size,
        checksum_md5=old_physical.checksum_md5,
        policy_id=dest_policy.id,
        reference_count=1,
    )
    new_physical = await new_physical.save(session)

    # 4. Repoint the Object at the new policy and physical file.
    obj.policy_id = dest_policy.id
    obj.physical_file_id = new_physical.id
    obj = await obj.save(session)

    # 5. Old PhysicalFile: drop one reference; physically delete only when
    # no object references it anymore.
    old_physical.decrement_reference()
    if old_physical.can_be_deleted:
        # Delete the source-side physical file (best effort).
        try:
            await _delete_file_from_storage(src_service, old_physical.storage_path)
        except Exception as e:
            l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
        await PhysicalFile.delete(session, old_physical)
    else:
        old_physical = await old_physical.save(session)

    l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name} → {dest_policy.name}")
|
||||
|
||||
|
||||
async def migrate_file_with_task(
    session: AsyncSession,
    obj: Object,
    dest_policy: Policy,
    task: Task,
) -> None:
    """
    Migrate a single file while reflecting progress on the task record.

    :param session: database session
    :param obj: file object to migrate
    :param dest_policy: destination storage policy
    :param task: associated task record
    """
    try:
        task.status = TaskStatus.RUNNING
        task.progress = 0
        task = await task.save(session)

        await migrate_single_file(session, obj, dest_policy)
    except Exception as e:
        # Record the failure on the task; the error text is truncated to
        # fit the column.
        l.error(f"文件迁移任务失败: {obj.id}: {e}")
        task.status = TaskStatus.ERROR
        task.error = str(e)[:500]
        task = await task.save(session)
    else:
        task.status = TaskStatus.COMPLETED
        task.progress = 100
        task = await task.save(session)
|
||||
|
||||
|
||||
async def migrate_directory_files(
    session: AsyncSession,
    folder: Object,
    dest_policy: Policy,
    task: Task,
) -> None:
    """
    Migrate every file under a directory to the destination storage policy.

    Recursively walks the directory tree and migrates each file; sub-folders
    get their policy_id updated in sync. Task progress is updated
    proportionally to the number of files processed.

    :param session: database session
    :param folder: directory object
    :param dest_policy: destination storage policy
    :param task: associated task record
    """
    try:
        task.status = TaskStatus.RUNNING
        task.progress = 0
        task = await task.save(session)

        # Collect every file to migrate and every sub-folder to retag.
        files_to_migrate: list[Object] = []
        folders_to_update: list[Object] = []
        await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)

        total = len(files_to_migrate)
        migrated = 0
        errors: list[str] = []

        for file_obj in files_to_migrate:
            # Per-file errors are collected rather than aborting the batch.
            try:
                await migrate_single_file(session, file_obj, dest_policy)
                migrated += 1
            except Exception as e:
                error_msg = f"{file_obj.name}: {e}"
                l.error(f"迁移文件失败: {error_msg}")
                errors.append(error_msg)

            # Progress is capped at 99 until the batch fully completes.
            if total > 0:
                task.progress = min(99, int(migrated / total * 100))
                task = await task.save(session)

        # Retag all sub-folders with the destination policy.
        for sub_folder in folders_to_update:
            sub_folder.policy_id = dest_policy.id
            sub_folder = await sub_folder.save(session)

        # Finalize the task; partial failures mark the task as ERROR but the
        # successfully migrated files keep their new policy. Only the first
        # five error messages are recorded.
        if errors:
            task.status = TaskStatus.ERROR
            task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
        else:
            task.status = TaskStatus.COMPLETED

        task.progress = 100
        task = await task.save(session)

        l.info(
            f"目录迁移完成: {folder.name} ({folder.id}), "
            f"成功 {migrated}/{total}, 错误 {len(errors)}"
        )
    except Exception as e:
        l.error(f"目录迁移任务失败: {folder.id}: {e}")
        task.status = TaskStatus.ERROR
        task.error = str(e)[:500]
        task = await task.save(session)
|
||||
|
||||
|
||||
async def _collect_objects_recursive(
    session: AsyncSession,
    folder: Object,
    files: list[Object],
    folders: list[Object],
) -> None:
    """
    Depth-first walk collecting every file and sub-folder under *folder*.

    :param session: database session
    :param folder: folder being expanded
    :param files: output list of file objects
    :param folders: output list of folder objects
    """
    children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)

    for entry in children:
        if entry.type == ObjectType.FOLDER:
            folders.append(entry)
            await _collect_objects_recursive(session, entry, files, folders)
        elif entry.type == ObjectType.FILE:
            files.append(entry)
|
||||
@@ -23,7 +23,7 @@ import string
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class NamingContext(SQLModelBase):
|
||||
|
||||
505
service/storage/object.py
Normal file
505
service/storage/object.py
Normal file
@@ -0,0 +1,505 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import update as sql_update
|
||||
from sqlalchemy.sql.functions import func
|
||||
from middleware.dependencies import SessionDep
|
||||
|
||||
from .local_storage import LocalStorageService
|
||||
from .s3_storage import S3StorageService
|
||||
from sqlmodels import (
|
||||
Object,
|
||||
PhysicalFile,
|
||||
Policy,
|
||||
PolicyType,
|
||||
User,
|
||||
)
|
||||
|
||||
|
||||
async def adjust_user_storage(
    session: SessionDep,
    user_id: UUID,
    delta: int,
    commit: bool = True,
) -> None:
    """
    Atomically adjust a user's used-storage counter.

    Issues ``UPDATE ... SET storage = GREATEST(0, storage + delta)`` so
    concurrent adjustments cannot race each other.

    :param session: database session
    :param user_id: user UUID
    :param delta: change in bytes (positive grows, negative shrinks)
    :param commit: whether to commit immediately
    """
    if not delta:
        return

    await session.execute(
        sql_update(User)
        .where(User.id == user_id)
        .values(storage=func.greatest(0, User.storage + delta))
    )

    if commit:
        await session.commit()

    sign = '+' if delta > 0 else ''
    l.debug(f"用户 {user_id} 存储配额变更: {sign}{delta} bytes")
|
||||
|
||||
|
||||
# ==================== 软删除 ====================
|
||||
|
||||
async def soft_delete_objects(
    session: SessionDep,
    objects: list[Object],
) -> int:
    """
    Soft-delete the given top-level objects.

    Each object gets ``deleted_at`` stamped, its original parent recorded in
    ``deleted_original_parent_id``, and ``parent_id`` cleared so it leaves
    the file tree. Children stay untouched and physical files are not moved.

    :param session: database session
    :param objects: objects to soft-delete
    :return: number of objects marked deleted
    """
    timestamp = datetime.now()

    for entry in objects:
        entry.deleted_at = timestamp
        entry.deleted_original_parent_id = entry.parent_id
        entry.parent_id = None
        await entry.save(session, commit=False, refresh=False)

    await session.commit()
    return len(objects)
|
||||
|
||||
|
||||
# ==================== 恢复 ====================
|
||||
|
||||
async def _resolve_name_conflict(
    session: SessionDep,
    user_id: UUID,
    parent_id: UUID,
    name: str,
) -> str:
    """
    Resolve a duplicate-name conflict and return a non-conflicting name.

    Naming scheme: name -> name (1) -> name (2) -> ...
    For files with an extension: name.ext -> name (1).ext -> name (2).ext -> ...

    :param session: database session
    :param user_id: user UUID
    :param parent_id: parent directory UUID
    :param name: original name
    :return: a name that does not conflict in the target directory
    """
    # `== None` is the SQLAlchemy NULL-comparison idiom — intentional,
    # do not "fix" to `is None`.
    existing = await Object.get(
        session,
        (Object.owner_id == user_id) &
        (Object.parent_id == parent_id) &
        (Object.name == name) &
        (Object.deleted_at == None)
    )
    if not existing:
        return name

    # Split base name and extension on the LAST dot.
    # NOTE(review): a leading-dot name like ".bashrc" splits into base ""
    # and ext ".bashrc", yielding " (1).bashrc" — confirm acceptable.
    if '.' in name:
        base, ext = name.rsplit('.', 1)
        ext = f".{ext}"
    else:
        base = name
        ext = ""

    # Probe increasing counters until a free name is found.
    counter = 1
    while True:
        new_name = f"{base} ({counter}){ext}"
        existing = await Object.get(
            session,
            (Object.owner_id == user_id) &
            (Object.parent_id == parent_id) &
            (Object.name == new_name) &
            (Object.deleted_at == None)
        )
        if not existing:
            return new_name
        counter += 1
|
||||
|
||||
|
||||
async def restore_objects(
    session: SessionDep,
    objects: list[Object],
    user_id: UUID,
) -> int:
    """
    Restore objects from the recycle bin.

    Checks whether the original parent directory still exists and is not
    deleted:
    - exists -> restore to the original location
    - gone   -> restore to the user's root directory
    Name collisions are resolved by automatic renaming.

    :param session: database session
    :param objects: objects to restore (must be top-level recycle-bin entries)
    :param user_id: user UUID
    :return: number of restored objects
    """
    root = await Object.get_root(session, user_id)
    if not root:
        raise ValueError("用户根目录不存在")

    count = 0

    for item in objects:
        if not item.deleted_at:
            # Not in the recycle bin — nothing to restore.
            continue

        # Default destination is the user's root; prefer the original parent
        # when it still exists and is not itself deleted.
        destination = root.id
        original_id = item.deleted_original_parent_id
        if original_id:
            parent = await Object.get(
                session,
                (Object.id == original_id) & (Object.deleted_at == None)
            )
            if parent:
                destination = parent.id

        # Pick a sibling-unique name in the destination directory.
        final_name = await _resolve_name_conflict(
            session, user_id, destination, item.name
        )

        # Re-attach the object to the tree and clear the recycle-bin markers.
        item.parent_id = destination
        item.deleted_at = None
        item.deleted_original_parent_id = None
        if final_name != item.name:
            item.name = final_name
        await item.save(session, commit=False, refresh=False)
        count += 1

    await session.commit()
    return count
|
||||
|
||||
|
||||
# ==================== Permanent deletion ====================
|
||||
|
||||
async def _collect_file_entries_all(
    session: SessionDep,
    user_id: UUID,
    root: Object,
) -> tuple[list[tuple[UUID, str, UUID]], int, int]:
    """
    BFS-collect physical-file info for every file in the subtree, including
    both deleted and non-deleted children.

    Only SELECT queries are issued; no commit is triggered, so ORM objects
    stay valid throughout.

    :param session: database session
    :param user_id: owner UUID
    :param root: subtree root object
    :return: (file entries [(obj_id, name, physical_file_id)],
              total object count, total file size)
    """
    file_entries: list[tuple[UUID, str, UUID]] = []
    total_count = 1
    total_file_size = 0

    # The root itself may be a file.
    if root.is_file and root.physical_file_id:
        file_entries.append((root.id, root.name, root.physical_file_id))
        total_file_size += root.size

    # BFS over directories (get_all_children includes deleted children).
    # An index cursor replaces list.pop(0), which is O(n) per dequeue and
    # made the traversal quadratic on wide trees; order is preserved.
    if root.is_folder:
        queue: list[UUID] = [root.id]
        cursor = 0
        while cursor < len(queue):
            parent_id = queue[cursor]
            cursor += 1
            children = await Object.get_all_children(session, user_id, parent_id)
            for child in children:
                total_count += 1
                if child.is_file and child.physical_file_id:
                    file_entries.append((child.id, child.name, child.physical_file_id))
                    total_file_size += child.size
                elif child.is_folder:
                    queue.append(child.id)

    return file_entries, total_count, total_file_size
|
||||
|
||||
|
||||
async def permanently_delete_objects(
    session: SessionDep,
    objects: list[Object],
    user_id: UUID,
) -> int:
    """
    Permanently delete objects from the recycle bin.

    Verifies each object is in the recycle bin (deleted_at IS NOT NULL),
    BFS-collects the PhysicalFile info for every file in the subtree,
    handles reference counts (physically deleting files whose count drops
    to zero), then hard-deletes the root Object (CASCADE cleans up children).

    :param session: database session
    :param objects: objects to delete permanently
    :param user_id: user UUID
    :return: number of permanently deleted objects (including descendants)
    """
    total_deleted = 0

    for obj in objects:
        if not obj.deleted_at:
            # Only recycle-bin entries may be permanently deleted.
            l.warning(f"对象 {obj.id} 不在回收站中,跳过永久删除")
            continue

        # Snapshot the id before any flush can expire the ORM object.
        root_id = obj.id
        file_entries, obj_count, total_file_size = await _collect_file_entries_all(
            session, user_id, obj
        )

        # Handle PhysicalFile reference counts.
        for obj_id, obj_name, physical_file_id in file_entries:
            physical_file = await PhysicalFile.get(session, PhysicalFile.id == physical_file_id)
            if not physical_file:
                continue

            physical_file.decrement_reference()

            if physical_file.can_be_deleted:
                # Last reference gone: remove the bytes from backing storage.
                policy = await Policy.get(session, Policy.id == physical_file.policy_id)
                if policy:
                    try:
                        if policy.type == PolicyType.LOCAL:
                            storage_service = LocalStorageService(policy)
                            await storage_service.delete_file(physical_file.storage_path)
                        elif policy.type == PolicyType.S3:
                            s3_service = await S3StorageService.from_policy(policy)
                            await s3_service.delete_file(physical_file.storage_path)
                        l.debug(f"物理文件已删除: {obj_name}")
                    except Exception as e:
                        # Best-effort: a failed physical delete must not abort
                        # the database cleanup.
                        l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")

                await PhysicalFile.delete(session, physical_file, commit=False)
                l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
            else:
                physical_file = await physical_file.save(session, commit=False)
                l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")

        # Update the user's storage quota in the same transaction.
        if total_file_size > 0:
            await adjust_user_storage(session, user_id, -total_file_size, commit=False)

        # Hard-delete the root object; CASCADE removes all children.
        # commit=False so the remaining objects in the loop do not expire.
        await Object.delete(session, condition=Object.id == root_id, commit=False)

        total_deleted += obj_count

    # Commit every change at once.
    await session.commit()
    return total_deleted
|
||||
|
||||
|
||||
# ==================== Legacy interface (kept for backward compatibility) ====================
|
||||
|
||||
async def _collect_file_entries(
    session: SessionDep,
    user_id: UUID,
    root: Object,
) -> tuple[list[tuple[UUID, str, UUID]], int, int]:
    """
    BFS-collect physical-file info for every file in the subtree.

    Only SELECT queries are issued; no commit is triggered, so ORM objects
    stay valid throughout.

    :param session: database session
    :param user_id: owner UUID
    :param root: subtree root object
    :return: (file entries [(obj_id, name, physical_file_id)],
              total object count, total file size)
    """
    file_entries: list[tuple[UUID, str, UUID]] = []
    total_count = 1
    total_file_size = 0

    # The root itself may be a file.
    if root.is_file and root.physical_file_id:
        file_entries.append((root.id, root.name, root.physical_file_id))
        total_file_size += root.size

    # BFS over directories. An index cursor replaces list.pop(0), which is
    # O(n) per dequeue and made the traversal quadratic; order is preserved.
    if root.is_folder:
        queue: list[UUID] = [root.id]
        cursor = 0
        while cursor < len(queue):
            parent_id = queue[cursor]
            cursor += 1
            children = await Object.get_children(session, user_id, parent_id)
            for child in children:
                total_count += 1
                if child.is_file and child.physical_file_id:
                    file_entries.append((child.id, child.name, child.physical_file_id))
                    total_file_size += child.size
                elif child.is_folder:
                    queue.append(child.id)

    return file_entries, total_count, total_file_size
|
||||
|
||||
|
||||
async def delete_object_recursive(
    session: SessionDep,
    obj: Object,
    user_id: UUID,
) -> int:
    """
    Hard-delete an object and all of its descendants.

    Two-phase strategy:
    1. Read-only BFS collects PhysicalFile info for every file in the subtree.
    2. Reference counts are processed in batch (commit=False); deleting the
       root object last triggers the database CASCADE and commits everything.

    :param session: database session
    :param obj: object to delete
    :param user_id: user UUID
    :return: number of deleted objects
    """
    # Phase 1: read-only collection (no commits, ORM objects stay live).
    root_id = obj.id
    file_entries, total_count, total_file_size = await _collect_file_entries(session, user_id, obj)

    # Phase 2: batch PhysicalFile reference handling (all commit=False).
    for obj_id, obj_name, physical_file_id in file_entries:
        physical_file = await PhysicalFile.get(session, PhysicalFile.id == physical_file_id)
        if not physical_file:
            continue

        physical_file.decrement_reference()

        if physical_file.can_be_deleted:
            # Last reference gone: remove the bytes from backing storage.
            policy = await Policy.get(session, Policy.id == physical_file.policy_id)
            if policy:
                try:
                    if policy.type == PolicyType.LOCAL:
                        storage_service = LocalStorageService(policy)
                        await storage_service.delete_file(physical_file.storage_path)
                    elif policy.type == PolicyType.S3:
                        # Use the shared factory (loads the policy options
                        # itself) for consistency with
                        # permanently_delete_objects.
                        s3_service = await S3StorageService.from_policy(policy)
                        await s3_service.delete_file(physical_file.storage_path)
                    l.debug(f"物理文件已删除: {obj_name}")
                except Exception as e:
                    # Best-effort: a failed physical delete must not abort the
                    # database cleanup.
                    l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")

            await PhysicalFile.delete(session, physical_file, commit=False)
            l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
        else:
            physical_file = await physical_file.save(session, commit=False)
            l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")

    # Phase 3: adjust the user's storage quota in the same transaction.
    if total_file_size > 0:
        await adjust_user_storage(session, user_id, -total_file_size, commit=False)

    # Phase 4: delete the root object; database CASCADE removes all children.
    # commit=True (default) commits PhysicalFile changes + deletion + quota.
    await Object.delete(session, condition=Object.id == root_id)

    return total_count
|
||||
|
||||
|
||||
# ==================== Copy ====================
|
||||
|
||||
async def _copy_object_recursive(
    session: SessionDep,
    src: Object,
    dst_parent_id: UUID,
    user_id: UUID,
) -> tuple[int, list[UUID], int]:
    """
    Recursively copy an object (internal implementation).

    :param session: database session
    :param src: source object
    :param dst_parent_id: destination parent directory UUID
    :param user_id: user UUID
    :return: (copy count, new object UUIDs, total copied file size)
    """
    copied_count = 0
    new_ids: list[UUID] = []
    total_copied_size = 0

    # Snapshot attributes before save(): a commit expires the ORM object and
    # lazy reloads after that can fail.
    src_is_folder = src.is_folder
    src_is_file = src.is_file
    src_id = src.id
    src_size = src.size
    src_physical_file_id = src.physical_file_id

    # Create the new Object record (shares the source's PhysicalFile).
    new_obj = Object(
        name=src.name,
        type=src.type,
        size=src.size,
        password=src.password,
        parent_id=dst_parent_id,
        owner_id=user_id,
        policy_id=src.policy_id,
        physical_file_id=src.physical_file_id,
    )

    # For a file, bump the physical file's reference count.
    if src_is_file and src_physical_file_id:
        physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id)
        if physical_file:
            physical_file.increment_reference()
            physical_file = await physical_file.save(session)
            total_copied_size += src_size

    new_obj = await new_obj.save(session)
    copied_count += 1
    new_ids.append(new_obj.id)

    # For a folder, recursively copy every child into the new folder.
    if src_is_folder:
        children = await Object.get_children(session, user_id, src_id)
        for child in children:
            child_count, child_ids, child_size = await _copy_object_recursive(
                session, child, new_obj.id, user_id
            )
            copied_count += child_count
            new_ids.extend(child_ids)
            total_copied_size += child_size

    return copied_count, new_ids, total_copied_size
|
||||
|
||||
|
||||
async def copy_object_recursive(
    session: SessionDep,
    src: Object,
    dst_parent_id: UUID,
    user_id: UUID,
) -> tuple[int, list[UUID], int]:
    """
    Recursively copy an object.

    For files:
    - increments the PhysicalFile reference count
    - creates a new Object record pointing at the same PhysicalFile

    For folders:
    - creates the new folder
    - recursively copies all children

    :param session: database session
    :param src: source object
    :param dst_parent_id: destination parent directory UUID
    :param user_id: user UUID
    :return: (copy count, new object UUIDs, total copied file size)
    """
    # Thin public wrapper around the recursive implementation.
    return await _copy_object_recursive(session, src, dst_parent_id, user_id)
|
||||
709
service/storage/s3_storage.py
Normal file
709
service/storage/s3_storage.py
Normal file
@@ -0,0 +1,709 @@
|
||||
"""
|
||||
S3 存储服务
|
||||
|
||||
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
||||
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
|
||||
|
||||
移植自 foxline-pro-backend-server 项目的 S3APIClient,
|
||||
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import xml.etree.ElementTree as ET
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import ClassVar, Literal
|
||||
from urllib.parse import quote, urlencode
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from yarl import URL
|
||||
from loguru import logger as l
|
||||
|
||||
from sqlmodels.policy import Policy
|
||||
from .exceptions import S3APIError, S3MultipartUploadError
|
||||
from .naming_rule import NamingContext, NamingRuleParser
|
||||
|
||||
|
||||
def _sign(key: bytes, msg: str) -> bytes:
    """Compute the HMAC-SHA256 digest of *msg* (UTF-8 encoded) keyed by *key*."""
    mac = hmac.new(key, digestmod=hashlib.sha256)
    mac.update(msg.encode())
    return mac.digest()
|
||||
|
||||
|
||||
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
|
||||
|
||||
|
||||
class S3StorageService:
|
||||
"""
|
||||
S3 存储服务
|
||||
|
||||
使用 AWS Signature V4 签名的异步 S3 API 客户端。
|
||||
从 Policy 配置中读取 S3 连接信息。
|
||||
|
||||
使用示例::
|
||||
|
||||
service = S3StorageService(policy, region='us-east-1')
|
||||
await service.upload_file('path/to/file.txt', b'content')
|
||||
data = await service.download_file('path/to/file.txt')
|
||||
"""
|
||||
|
||||
_http_session: ClassVar[aiohttp.ClientSession | None] = None
|
||||
|
||||
    def __init__(
        self,
        policy: Policy,
        region: str = 'us-east-1',
        is_path_style: bool = False,
    ):
        """
        :param policy: storage policy (server=endpoint URL, bucket_name,
            access_key, secret_key)
        :param region: S3 region
        :param is_path_style: whether to use path-style URLs
        :raises S3APIError: if any required policy field is missing
        """
        # Fail fast on incomplete configuration; every request needs these.
        if not policy.server:
            raise S3APIError("S3 策略必须指定 server (endpoint URL)")
        if not policy.bucket_name:
            raise S3APIError("S3 策略必须指定 bucket_name")
        if not policy.access_key:
            raise S3APIError("S3 策略必须指定 access_key")
        if not policy.secret_key:
            raise S3APIError("S3 策略必须指定 secret_key")

        self._policy = policy
        self._endpoint_url = policy.server.rstrip("/")
        self._bucket_name = policy.bucket_name
        self._access_key = policy.access_key
        self._secret_key = policy.secret_key
        self._region = region
        self._is_path_style = is_path_style
        self._base_url = policy.base_url

        # Extract the bare host from the endpoint URL (scheme/path stripped).
        self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
|
||||
|
||||
# ==================== 工厂方法 ====================
|
||||
|
||||
    @classmethod
    async def from_policy(cls, policy: Policy) -> 'S3StorageService':
        """
        Asynchronously build an S3StorageService from a Policy (loads options).

        :param policy: storage policy
        :return: configured S3StorageService instance
        """
        # options is a lazy relationship; await it before reading S3 settings.
        options = await policy.awaitable_attrs.options
        region = options.s3_region if options else 'us-east-1'
        is_path_style = options.s3_path_style if options else False
        return cls(policy, region=region, is_path_style=is_path_style)
|
||||
|
||||
# ==================== HTTP Session 管理 ====================
|
||||
|
||||
    @classmethod
    async def initialize_session(cls) -> None:
        """Initialize the shared aiohttp ClientSession (idempotent)."""
        if cls._http_session is None or cls._http_session.closed:
            cls._http_session = aiohttp.ClientSession()
            l.info("S3StorageService HTTP session 已初始化")
||||
|
||||
    @classmethod
    async def close_session(cls) -> None:
        """Close and drop the shared aiohttp ClientSession, if open."""
        if cls._http_session and not cls._http_session.closed:
            await cls._http_session.close()
            cls._http_session = None
            l.info("S3StorageService HTTP session 已关闭")
|
||||
|
||||
    @classmethod
    def _get_session(cls) -> aiohttp.ClientSession:
        """Return the shared HTTP session, creating one lazily if needed."""
        if cls._http_session is None or cls._http_session.closed:
            # Lazy init in case initialize_session was never called.
            cls._http_session = aiohttp.ClientSession()
        return cls._http_session
|
||||
|
||||
# ==================== AWS Signature V4 签名 ====================
|
||||
|
||||
def _get_signature_key(self, date_stamp: str) -> bytes:
|
||||
"""生成 AWS Signature V4 签名密钥"""
|
||||
k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
|
||||
k_region = _sign(k_date, self._region)
|
||||
k_service = _sign(k_region, "s3")
|
||||
return _sign(k_service, "aws4_request")
|
||||
|
||||
    def _create_authorization_header(
        self,
        method: str,
        uri: str,
        query_string: str,
        headers: dict[str, str],
        payload_hash: str,
        amz_date: str,
        date_stamp: str,
    ) -> str:
        """
        Build the AWS Signature V4 Authorization header value.

        Follows the SigV4 canonical request -> string-to-sign -> signature
        pipeline; *headers* must already contain every header to be signed.

        :param method: HTTP method
        :param uri: canonical URI
        :param query_string: canonical query string (already sorted/encoded)
        :param headers: headers to sign
        :param payload_hash: SHA-256 hex of the payload, or "UNSIGNED-PAYLOAD"
        :param amz_date: timestamp in YYYYMMDD'T'HHMMSS'Z' form
        :param date_stamp: date in YYYYMMDD form
        :return: value for the Authorization header
        """
        # Header names, lower-cased and sorted, joined with ';' per SigV4.
        # NOTE(review): canonical_headers sorts by original-case keys while
        # signed_headers sorts lower-cased names — the orders coincide for the
        # header set used here; confirm if new mixed-case headers are added.
        signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
        canonical_headers = "".join(
            f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
        )
        canonical_request = (
            f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
            f"{signed_headers}\n{payload_hash}"
        )

        algorithm = "AWS4-HMAC-SHA256"
        credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
        string_to_sign = (
            f"{algorithm}\n{amz_date}\n{credential_scope}\n"
            f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
        )

        signing_key = self._get_signature_key(date_stamp)
        signature = hmac.new(
            signing_key, string_to_sign.encode(), hashlib.sha256
        ).hexdigest()

        return (
            f"{algorithm} Credential={self._access_key}/{credential_scope}, "
            f"SignedHeaders={signed_headers}, Signature={signature}"
        )
|
||||
|
||||
    def _build_headers(
        self,
        method: str,
        uri: str,
        query_string: str = "",
        payload: bytes = b"",
        content_type: str | None = None,
        extra_headers: dict[str, str] | None = None,
        payload_hash: str | None = None,
        host: str | None = None,
    ) -> dict[str, str]:
        """
        Build the complete request headers including the AWS V4 signature.

        :param method: HTTP method
        :param uri: request URI
        :param query_string: query string
        :param payload: request body bytes (used to compute the hash)
        :param content_type: Content-Type
        :param extra_headers: additional headers
        :param payload_hash: precomputed payload hash; pass "UNSIGNED-PAYLOAD"
            for streaming uploads
        :param host: Host header (defaults to self._host)
        """
        now_utc = datetime.now(timezone.utc)
        amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = now_utc.strftime("%Y%m%d")

        if payload_hash is None:
            payload_hash = hashlib.sha256(payload).hexdigest()

        effective_host = host or self._host

        # Every header added here becomes part of the signed set.
        headers: dict[str, str] = {
            "Host": effective_host,
            "X-Amz-Date": amz_date,
            "X-Amz-Content-Sha256": payload_hash,
        }
        if content_type:
            headers["Content-Type"] = content_type
        if extra_headers:
            headers.update(extra_headers)

        authorization = self._create_authorization_header(
            method, uri, query_string, headers, payload_hash, amz_date, date_stamp
        )
        headers["Authorization"] = authorization
        return headers
|
||||
|
||||
# ==================== 内部请求方法 ====================
|
||||
|
||||
def _build_uri(self, key: str | None = None) -> str:
|
||||
"""
|
||||
构建请求 URI
|
||||
|
||||
按 AWS S3 Signature V4 规范对路径进行 URI 编码(S3 仅需一次)。
|
||||
斜杠作为路径分隔符保留不编码。
|
||||
"""
|
||||
if self._is_path_style:
|
||||
if key:
|
||||
return f"/{self._bucket_name}/{quote(key, safe='/')}"
|
||||
return f"/{self._bucket_name}"
|
||||
else:
|
||||
if key:
|
||||
return f"/{quote(key, safe='/')}"
|
||||
return "/"
|
||||
|
||||
def _build_url(self, uri: str, query_string: str = "") -> str:
|
||||
"""构建完整请求 URL"""
|
||||
if self._is_path_style:
|
||||
base = self._endpoint_url
|
||||
else:
|
||||
# 虚拟主机风格:bucket.endpoint
|
||||
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
|
||||
base = f"{protocol}{self._bucket_name}.{self._host}"
|
||||
|
||||
url = f"{base}{uri}"
|
||||
if query_string:
|
||||
url = f"{url}?{query_string}"
|
||||
return url
|
||||
|
||||
def _get_effective_host(self) -> str:
|
||||
"""获取实际请求的 Host 头"""
|
||||
if self._is_path_style:
|
||||
return self._host
|
||||
return f"{self._bucket_name}.{self._host}"
|
||||
|
||||
    async def _request(
        self,
        method: str,
        key: str | None = None,
        query_params: dict[str, str] | None = None,
        payload: bytes = b"",
        content_type: str | None = None,
        extra_headers: dict[str, str] | None = None,
    ) -> aiohttp.ClientResponse:
        """
        Send a SigV4-signed request and return the raw response.

        :param method: HTTP method
        :param key: S3 object key (None targets the bucket itself)
        :param query_params: query parameters (sorted before signing)
        :param payload: request body
        :param content_type: Content-Type header
        :param extra_headers: additional headers to sign and send
        :raises S3APIError: if the HTTP request itself fails
        """
        uri = self._build_uri(key)
        # Sort params so the signed query string matches what is sent.
        query_string = urlencode(sorted(query_params.items())) if query_params else ""
        effective_host = self._get_effective_host()

        headers = self._build_headers(
            method, uri, query_string, payload, content_type,
            extra_headers, host=effective_host,
        )

        url = self._build_url(uri, query_string)

        try:
            # encoded=True stops yarl from re-quoting the already-signed URL.
            response = await self._get_session().request(
                method, URL(url, encoded=True),
                headers=headers, data=payload if payload else None,
            )
            return response
        except Exception as e:
            raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
|
||||
|
||||
    async def _request_streaming(
        self,
        method: str,
        key: str,
        data_stream: AsyncIterator[bytes],
        content_length: int,
        content_type: str | None = None,
    ) -> aiohttp.ClientResponse:
        """
        Send a signed streaming request (large file uploads).

        Signs with "UNSIGNED-PAYLOAD" as the payload hash, since the body is
        not available up front to hash.

        :param method: HTTP method
        :param key: S3 object key
        :param data_stream: async byte-chunk iterator
        :param content_length: exact total body length
        :param content_type: MIME type
        :raises S3APIError: if the HTTP request itself fails
        """
        uri = self._build_uri(key)
        effective_host = self._get_effective_host()

        headers = self._build_headers(
            method,
            uri,
            query_string="",
            content_type=content_type,
            extra_headers={"Content-Length": str(content_length)},
            payload_hash="UNSIGNED-PAYLOAD",
            host=effective_host,
        )

        url = self._build_url(uri)

        try:
            # encoded=True stops yarl from re-quoting the already-signed URL.
            response = await self._get_session().request(
                method, URL(url, encoded=True),
                headers=headers, data=data_stream,
            )
            return response
        except Exception as e:
            raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
|
||||
|
||||
# ==================== 文件操作 ====================
|
||||
|
||||
    async def upload_file(
        self,
        key: str,
        data: bytes,
        content_type: str = 'application/octet-stream',
    ) -> None:
        """
        Upload an object in a single PUT.

        :param key: S3 object key
        :param data: file content
        :param content_type: MIME type
        :raises S3APIError: if the server does not answer 200/201
        """
        async with await self._request(
            "PUT", key=key, payload=data, content_type=content_type,
        ) as response:
            if response.status not in (200, 201):
                body = await response.text()
                raise S3APIError(
                    f"上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
|
||||
|
||||
    async def upload_file_streaming(
        self,
        key: str,
        data_stream: AsyncIterator[bytes],
        content_length: int,
        content_type: str | None = None,
    ) -> None:
        """
        Stream-upload a large file without loading it fully into memory.

        :param key: S3 object key
        :param data_stream: async byte-chunk iterator
        :param content_length: exact total data length
        :param content_type: MIME type
        :raises S3APIError: if the server does not answer 200/201
        """
        async with await self._request_streaming(
            "PUT", key=key, data_stream=data_stream,
            content_length=content_length, content_type=content_type,
        ) as response:
            if response.status not in (200, 201):
                body = await response.text()
                raise S3APIError(
                    f"流式上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
|
||||
|
||||
    async def download_file(self, key: str) -> bytes:
        """
        Download an object's full content.

        :param key: S3 object key
        :return: file content
        :raises S3APIError: if the server does not answer 200
        """
        async with await self._request("GET", key=key) as response:
            if response.status != 200:
                body = await response.text()
                raise S3APIError(
                    f"下载失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            data = await response.read()
            l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
            return data
|
||||
|
||||
    async def delete_file(self, key: str) -> None:
        """
        Delete an object.

        :param key: S3 object key
        :raises S3APIError: if the server answers anything but 200/204
        """
        async with await self._request("DELETE", key=key) as response:
            if response.status in (200, 204):
                l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
            else:
                body = await response.text()
                raise S3APIError(
                    f"删除失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
|
||||
|
||||
async def file_exists(self, key: str) -> bool:
|
||||
"""
|
||||
检查文件是否存在
|
||||
|
||||
:param key: S3 对象键
|
||||
:return: 是否存在
|
||||
"""
|
||||
async with await self._request("HEAD", key=key) as response:
|
||||
if response.status == 200:
|
||||
return True
|
||||
elif response.status == 404:
|
||||
return False
|
||||
else:
|
||||
raise S3APIError(
|
||||
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
|
||||
)
|
||||
|
||||
    async def get_file_size(self, key: str) -> int:
        """
        Return an object's size in bytes via a HEAD request.

        :param key: S3 object key
        :return: size in bytes (0 if Content-Length is absent)
        :raises S3APIError: if the server does not answer 200
        """
        async with await self._request("HEAD", key=key) as response:
            if response.status != 200:
                raise S3APIError(
                    f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
                )
            return int(response.headers.get("Content-Length", 0))
|
||||
|
||||
# ==================== Multipart Upload ====================
|
||||
|
||||
    async def create_multipart_upload(
        self,
        key: str,
        content_type: str = 'application/octet-stream',
    ) -> str:
        """
        Start a multipart upload.

        :param key: S3 object key
        :param content_type: MIME type
        :return: upload ID
        :raises S3MultipartUploadError: on a non-200 response, or when the
            response XML carries no UploadId
        """
        async with await self._request(
            "POST",
            key=key,
            query_params={"uploads": ""},
            content_type=content_type,
        ) as response:
            if response.status != 200:
                body = await response.text()
                raise S3MultipartUploadError(
                    f"创建分片上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )

            body = await response.text()
            root = ET.fromstring(body)

            # Look up UploadId both without and with the AWS namespace —
            # S3-compatible vendors differ on namespacing.
            upload_id_elem = root.find("UploadId")
            if upload_id_elem is None:
                upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
            if upload_id_elem is None or not upload_id_elem.text:
                raise S3MultipartUploadError(
                    f"创建分片上传响应中未找到 UploadId: {body}"
                )

            upload_id = upload_id_elem.text
            l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
            return upload_id
|
||||
|
||||
    async def upload_part(
        self,
        key: str,
        upload_id: str,
        part_number: int,
        data: bytes,
    ) -> str:
        """
        Upload a single part of a multipart upload.

        :param key: S3 object key
        :param upload_id: multipart upload ID
        :param part_number: part number (starts at 1)
        :param data: part payload
        :return: ETag of the stored part (surrounding quotes stripped)
        :raises S3MultipartUploadError: on a non-200 response
        """
        async with await self._request(
            "PUT",
            key=key,
            query_params={
                "partNumber": str(part_number),
                "uploadId": upload_id,
            },
            payload=data,
        ) as response:
            if response.status != 200:
                body = await response.text()
                raise S3MultipartUploadError(
                    f"上传分片失败: {self._bucket_name}/{key}, "
                    f"part={part_number}, 状态: {response.status}, {body}"
                )

            etag = response.headers.get("ETag", "").strip('"')
            l.debug(
                f"S3 分片上传成功: {self._bucket_name}/{key}, "
                f"part={part_number}, etag={etag}"
            )
            return etag
|
||||
|
||||
    async def complete_multipart_upload(
        self,
        key: str,
        upload_id: str,
        parts: list[tuple[int, str]],
    ) -> None:
        """
        Finish a multipart upload.

        :param key: S3 object key
        :param upload_id: multipart upload ID
        :param parts: parts as [(part_number, etag)]
        :raises S3MultipartUploadError: on a non-200 response
        """
        # S3 requires parts listed in ascending part-number order.
        parts_sorted = sorted(parts, key=lambda p: p[0])

        # Build the CompleteMultipartUpload XML body.
        xml_parts = ''.join(
            f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
            for pn, etag in parts_sorted
        )
        payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
        payload_bytes = payload.encode('utf-8')

        async with await self._request(
            "POST",
            key=key,
            query_params={"uploadId": upload_id},
            payload=payload_bytes,
            content_type="application/xml",
        ) as response:
            if response.status != 200:
                body = await response.text()
                raise S3MultipartUploadError(
                    f"完成分片上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
            l.info(
                f"S3 分片上传已完成: {self._bucket_name}/{key}, "
                f"共 {len(parts)} 个分片"
            )
|
||||
|
||||
    async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
        """
        Abort a multipart upload.

        Failure is logged as a warning, not raised — aborting is best-effort
        cleanup.

        :param key: S3 object key
        :param upload_id: multipart upload ID
        """
        async with await self._request(
            "DELETE",
            key=key,
            query_params={"uploadId": upload_id},
        ) as response:
            if response.status in (200, 204):
                l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
            else:
                body = await response.text()
                l.warning(
                    f"取消分片上传失败: {self._bucket_name}/{key}, "
                    f"状态: {response.status}, {body}"
                )
|
||||
|
||||
# ==================== 预签名 URL ====================
|
||||
|
||||
    def generate_presigned_url(
        self,
        key: str,
        method: Literal['GET', 'PUT'] = 'GET',
        expires_in: int = 3600,
        filename: str | None = None,
    ) -> str:
        """
        Generate an S3 presigned URL (AWS Signature V4, query-string style).

        :param key: S3 object key
        :param method: HTTP method (GET for download, PUT for upload)
        :param expires_in: URL lifetime in seconds
        :param filename: filename for Content-Disposition on GET downloads
        :return: presigned URL
        """
        current_time = datetime.now(timezone.utc)
        amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
        date_stamp = current_time.strftime("%Y%m%d")

        credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
        credential = f"{self._access_key}/{credential_scope}"

        uri = self._build_uri(key)
        effective_host = self._get_effective_host()

        query_params: dict[str, str] = {
            'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
            'X-Amz-Credential': credential,
            'X-Amz-Date': amz_date,
            'X-Amz-Expires': str(expires_in),
            'X-Amz-SignedHeaders': 'host',
        }

        # GET downloads: force an attachment disposition with the given name.
        if method == "GET" and filename:
            encoded_filename = quote(filename, safe='')
            query_params['response-content-disposition'] = (
                f"attachment; filename*=UTF-8''{encoded_filename}"
            )

        # Parameters must be sorted and strictly percent-encoded for signing.
        canonical_query_string = "&".join(
            f"{quote(k, safe='')}={quote(v, safe='')}"
            for k, v in sorted(query_params.items())
        )

        canonical_headers = f"host:{effective_host}\n"
        signed_headers = "host"
        payload_hash = "UNSIGNED-PAYLOAD"

        canonical_request = (
            f"{method}\n"
            f"{uri}\n"
            f"{canonical_query_string}\n"
            f"{canonical_headers}\n"
            f"{signed_headers}\n"
            f"{payload_hash}"
        )

        algorithm = "AWS4-HMAC-SHA256"
        string_to_sign = (
            f"{algorithm}\n"
            f"{amz_date}\n"
            f"{credential_scope}\n"
            f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
        )

        signing_key = self._get_signature_key(date_stamp)
        signature = hmac.new(
            signing_key, string_to_sign.encode(), hashlib.sha256
        ).hexdigest()

        base_url = self._build_url(uri)
        return (
            f"{base_url}?"
            f"{canonical_query_string}&"
            f"X-Amz-Signature={signature}"
        )
|
||||
|
||||
# ==================== 路径生成 ====================
|
||||
|
||||
async def generate_file_path(
|
||||
self,
|
||||
user_id: UUID,
|
||||
original_filename: str,
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
根据命名规则生成 S3 文件存储路径
|
||||
|
||||
与 LocalStorageService.generate_file_path 接口一致。
|
||||
|
||||
:param user_id: 用户UUID
|
||||
:param original_filename: 原始文件名
|
||||
:return: (相对目录路径, 存储文件名, 完整存储路径)
|
||||
"""
|
||||
context = NamingContext(
|
||||
user_id=user_id,
|
||||
original_filename=original_filename,
|
||||
)
|
||||
|
||||
# 解析目录规则
|
||||
dir_path = ""
|
||||
if self._policy.dir_name_rule:
|
||||
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
|
||||
|
||||
# 解析文件名规则
|
||||
if self._policy.auto_rename and self._policy.file_name_rule:
|
||||
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
|
||||
# 确保有扩展名
|
||||
if '.' in original_filename and '.' not in storage_name:
|
||||
ext = original_filename.rsplit('.', 1)[1]
|
||||
storage_name = f"{storage_name}.{ext}"
|
||||
else:
|
||||
storage_name = original_filename
|
||||
|
||||
# S3 不需要创建目录,直接拼接路径
|
||||
if dir_path:
|
||||
storage_path = f"{dir_path}/{storage_name}"
|
||||
else:
|
||||
storage_path = storage_name
|
||||
|
||||
return dir_path, storage_name, storage_path
|
||||
@@ -1 +1 @@
|
||||
from .login import login
|
||||
from .login import unified_login
|
||||
|
||||
@@ -1,72 +1,429 @@
|
||||
from uuid import uuid4
|
||||
"""
|
||||
统一登录服务
|
||||
|
||||
from loguru import logger
|
||||
支持多种认证方式:邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信(预留)。
|
||||
"""
|
||||
import hashlib
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import LoginRequest, TokenResponse, User
|
||||
from utils import http_exceptions
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from service.redis.token_store import TokenStore
|
||||
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
|
||||
from sqlmodels.group import GroupClaims, GroupOptions
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.policy import Policy
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
from sqlmodels.user import TokenResponse, UnifiedLoginRequest, User, UserStatus
|
||||
from utils import JWT, http_exceptions
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
|
||||
|
||||
async def login(
|
||||
session: SessionDep,
|
||||
login_request: LoginRequest,
|
||||
async def unified_login(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> TokenResponse:
|
||||
"""
|
||||
根据账号密码进行登录。
|
||||
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
|
||||
统一登录入口,根据 provider 分发到不同的登录逻辑。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param login_request: 登录请求
|
||||
|
||||
:return: TokenResponse 对象或状态码或 None
|
||||
:param request: 统一登录请求
|
||||
:return: TokenResponse
|
||||
"""
|
||||
# TODO: 验证码校验
|
||||
# captcha_setting = await Setting.get(
|
||||
# session,
|
||||
# (Setting.type == "auth") & (Setting.name == "login_captcha")
|
||||
# )
|
||||
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
||||
await _check_provider_enabled(session, request.provider)
|
||||
|
||||
# 获取用户信息
|
||||
current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore
|
||||
match request.provider:
|
||||
case AuthProviderType.EMAIL_PASSWORD:
|
||||
user = await _login_email_password(session, request)
|
||||
case AuthProviderType.GITHUB:
|
||||
user = await _login_oauth(session, request, AuthProviderType.GITHUB)
|
||||
case AuthProviderType.QQ:
|
||||
user = await _login_oauth(session, request, AuthProviderType.QQ)
|
||||
case AuthProviderType.PASSKEY:
|
||||
user = await _login_passkey(session, request)
|
||||
case AuthProviderType.MAGIC_LINK:
|
||||
user = await _login_magic_link(session, request)
|
||||
case AuthProviderType.PHONE_SMS:
|
||||
http_exceptions.raise_not_implemented("短信登录暂未开放")
|
||||
case _:
|
||||
http_exceptions.raise_bad_request(f"不支持的登录方式: {request.provider}")
|
||||
|
||||
# 验证用户是否存在
|
||||
if not current_user:
|
||||
logger.debug(f"Cannot find user with username: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid username or password")
|
||||
return await _issue_tokens(session, user)
|
||||
|
||||
# 验证密码是否正确
|
||||
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
|
||||
logger.debug(f"Password verification failed for user: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid username or password")
|
||||
|
||||
# 验证用户是否可登录
|
||||
if not current_user.status:
|
||||
http_exceptions.raise_forbidden("Your account is disabled")
|
||||
async def _check_provider_enabled(session: AsyncSession, provider: AuthProviderType) -> None:
|
||||
"""检查认证方式是否已被站长启用"""
|
||||
# OAuth 类型从 OAUTH 设置中查询
|
||||
if provider in (AuthProviderType.GITHUB, AuthProviderType.QQ):
|
||||
setting_name = f"{provider.value}_enabled"
|
||||
setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == setting_name),
|
||||
)
|
||||
if not setting or setting.value != "1":
|
||||
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
|
||||
return
|
||||
|
||||
# 检查两步验证
|
||||
if current_user.two_factor:
|
||||
# 用户已启用两步验证
|
||||
if not login_request.two_fa_code:
|
||||
logger.debug(f"2FA required for user: {login_request.username}")
|
||||
http_exceptions.raise_precondition_required("2FA required")
|
||||
# 其他类型从 AUTH 设置中查询
|
||||
setting_name = f"auth_{provider.value}_enabled"
|
||||
setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == setting_name),
|
||||
)
|
||||
if not setting or setting.value != "1":
|
||||
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
|
||||
|
||||
# 验证 OTP 码
|
||||
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid 2FA code")
|
||||
|
||||
async def _login_email_password(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""邮箱+密码登录"""
|
||||
if not request.credential:
|
||||
http_exceptions.raise_bad_request("密码不能为空")
|
||||
|
||||
# 查找 AuthIdentity
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
& (AuthIdentity.identifier == request.identifier),
|
||||
)
|
||||
if not identity:
|
||||
l.debug(f"未找到邮箱密码身份: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
# 验证密码
|
||||
if not identity.credential:
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
if Password.verify(identity.credential, request.credential) != PasswordStatus.VALID:
|
||||
l.debug(f"密码验证失败: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("邮箱或密码错误")
|
||||
|
||||
# 加载用户
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
|
||||
# 验证用户状态
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 检查两步验证(从 AuthIdentity.extra_data 中读取 2FA secret)
|
||||
if identity.extra_data:
|
||||
import orjson
|
||||
extra: dict = orjson.loads(identity.extra_data)
|
||||
two_factor_secret: str | None = extra.get("two_factor")
|
||||
if two_factor_secret:
|
||||
if not request.two_fa_code:
|
||||
l.debug(f"需要两步验证: {request.identifier}")
|
||||
http_exceptions.raise_precondition_required("需要两步验证")
|
||||
if Password.verify_totp(two_factor_secret, request.two_fa_code) != PasswordStatus.VALID:
|
||||
l.debug(f"两步验证失败: {request.identifier}")
|
||||
http_exceptions.raise_unauthorized("两步验证码错误")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _login_oauth(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
provider: AuthProviderType,
|
||||
) -> User:
|
||||
"""
|
||||
OAuth 登录(GitHub / QQ)
|
||||
|
||||
identifier 为 OAuth authorization code,后端换取 access_token 再获取用户信息。
|
||||
"""
|
||||
# 读取 OAuth 配置
|
||||
client_id_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_id"),
|
||||
)
|
||||
client_secret_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_secret"),
|
||||
)
|
||||
if not client_id_setting or not client_secret_setting:
|
||||
http_exceptions.raise_bad_request(f"{provider.value} OAuth 未配置")
|
||||
|
||||
client_id = client_id_setting.value or ""
|
||||
client_secret = client_secret_setting.value or ""
|
||||
|
||||
# 根据 provider 创建对应的 OAuth 客户端
|
||||
if provider == AuthProviderType.GITHUB:
|
||||
from service.oauth import GithubOAuth
|
||||
oauth_client = GithubOAuth(client_id, client_secret)
|
||||
token_resp = await oauth_client.get_access_token(code=request.identifier)
|
||||
user_info_resp = await oauth_client.get_user_info(token_resp)
|
||||
openid = str(user_info_resp.user_data.id)
|
||||
nickname = user_info_resp.user_data.name or user_info_resp.user_data.login
|
||||
avatar_url = user_info_resp.user_data.avatar_url
|
||||
email = user_info_resp.user_data.email
|
||||
elif provider == AuthProviderType.QQ:
|
||||
from service.oauth import QQOAuth
|
||||
oauth_client = QQOAuth(client_id, client_secret)
|
||||
token_resp = await oauth_client.get_access_token(
|
||||
code=request.identifier,
|
||||
redirect_uri=request.redirect_uri or "",
|
||||
)
|
||||
openid_resp = await oauth_client.get_openid(token_resp.access_token)
|
||||
user_info_resp = await oauth_client.get_user_info(
|
||||
token_resp,
|
||||
app_id=client_id,
|
||||
openid=openid_resp.openid,
|
||||
)
|
||||
openid = openid_resp.openid
|
||||
nickname = user_info_resp.user_data.nickname
|
||||
avatar_url = user_info_resp.user_data.figureurl_qq_2 or user_info_resp.user_data.figureurl_2
|
||||
email = None
|
||||
else:
|
||||
http_exceptions.raise_bad_request(f"不支持的 OAuth 提供者: {provider.value}")
|
||||
|
||||
# 查找已有 AuthIdentity
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.provider == provider) & (AuthIdentity.identifier == openid),
|
||||
)
|
||||
|
||||
if identity:
|
||||
# 已绑定 → 更新 OAuth 信息并返回关联用户
|
||||
identity.display_name = nickname
|
||||
identity.avatar_url = avatar_url
|
||||
identity = await identity.save(session)
|
||||
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
return user
|
||||
|
||||
# 未绑定 → 自动注册
|
||||
user = await _auto_register_oauth_user(
|
||||
session,
|
||||
provider=provider,
|
||||
openid=openid,
|
||||
nickname=nickname,
|
||||
avatar_url=avatar_url,
|
||||
email=email,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def _auto_register_oauth_user(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
provider: AuthProviderType,
|
||||
openid: str,
|
||||
nickname: str | None,
|
||||
avatar_url: str | None,
|
||||
email: str | None,
|
||||
) -> User:
|
||||
"""OAuth 自动注册用户"""
|
||||
# 获取默认用户组
|
||||
default_group_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.REGISTER) & (Setting.name == "default_group"),
|
||||
)
|
||||
if not default_group_setting or not default_group_setting.value:
|
||||
l.error("默认用户组未配置")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
default_group_id = UUID(default_group_setting.value)
|
||||
|
||||
# 创建用户
|
||||
new_user = User(
|
||||
email=email,
|
||||
nickname=nickname,
|
||||
avatar=avatar_url or "default",
|
||||
group_id=default_group_id,
|
||||
)
|
||||
new_user_id = new_user.id
|
||||
new_user = await new_user.save(session)
|
||||
|
||||
# 创建 AuthIdentity
|
||||
identity = AuthIdentity(
|
||||
provider=provider,
|
||||
identifier=openid,
|
||||
display_name=nickname,
|
||||
avatar_url=avatar_url,
|
||||
is_primary=True,
|
||||
is_verified=True,
|
||||
user_id=new_user_id,
|
||||
)
|
||||
identity = await identity.save(session)
|
||||
|
||||
# 创建用户根目录
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
if default_policy:
|
||||
await Object(
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=new_user_id,
|
||||
parent_id=None,
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
# 重新加载用户(含 group 关系)
|
||||
user: User = await User.get(session, User.id == new_user_id, load=User.group)
|
||||
l.info(f"OAuth 自动注册用户: provider={provider.value}, openid={openid}")
|
||||
return user
|
||||
|
||||
|
||||
async def _login_passkey(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""
|
||||
Passkey/WebAuthn 登录(Discoverable Credentials 模式)
|
||||
|
||||
identifier 为 challenge_token(前端从 ``POST /authn/options`` 获取),
|
||||
credential 为 JSON 格式的 authenticator assertion response。
|
||||
"""
|
||||
from webauthn import verify_authentication_response
|
||||
from webauthn.helpers import base64url_to_bytes
|
||||
|
||||
from service.redis.challenge_store import ChallengeStore
|
||||
from service.webauthn import get_rp_config
|
||||
from sqlmodels.user_authn import UserAuthn
|
||||
|
||||
if not request.credential:
|
||||
http_exceptions.raise_bad_request("WebAuthn assertion response 不能为空")
|
||||
|
||||
if not request.identifier:
|
||||
http_exceptions.raise_bad_request("challenge_token 不能为空")
|
||||
|
||||
# 从 ChallengeStore 取出 challenge(一次性,防重放)
|
||||
challenge: bytes | None = await ChallengeStore.retrieve_and_delete(f"auth:{request.identifier}")
|
||||
if challenge is None:
|
||||
http_exceptions.raise_unauthorized("登录会话已过期,请重新获取 options")
|
||||
|
||||
# 从 assertion JSON 中解析 credential_id(Discoverable Credentials 模式)
|
||||
import orjson
|
||||
credential_dict: dict = orjson.loads(request.credential)
|
||||
credential_id_b64: str | None = credential_dict.get("id")
|
||||
if not credential_id_b64:
|
||||
http_exceptions.raise_bad_request("缺少凭证 ID")
|
||||
|
||||
# 查找 UserAuthn 记录
|
||||
authn: UserAuthn | None = await UserAuthn.get(
|
||||
session,
|
||||
UserAuthn.credential_id == credential_id_b64,
|
||||
)
|
||||
if not authn:
|
||||
http_exceptions.raise_unauthorized("Passkey 凭证未注册")
|
||||
|
||||
# 获取 RP 配置
|
||||
rp_id, _rp_name, origin = await get_rp_config(session)
|
||||
|
||||
# 验证 WebAuthn assertion
|
||||
try:
|
||||
verification = verify_authentication_response(
|
||||
credential=request.credential,
|
||||
expected_rp_id=rp_id,
|
||||
expected_origin=origin,
|
||||
expected_challenge=challenge,
|
||||
credential_public_key=base64url_to_bytes(authn.credential_public_key),
|
||||
credential_current_sign_count=authn.sign_count,
|
||||
)
|
||||
except Exception as e:
|
||||
l.warning(f"WebAuthn 验证失败: {e}")
|
||||
http_exceptions.raise_unauthorized("Passkey 验证失败")
|
||||
|
||||
# 更新签名计数
|
||||
authn.sign_count = verification.new_sign_count
|
||||
authn = await authn.save(session)
|
||||
|
||||
# 加载用户
|
||||
user: User = await User.get(session, User.id == authn.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _login_magic_link(
|
||||
session: AsyncSession,
|
||||
request: UnifiedLoginRequest,
|
||||
) -> User:
|
||||
"""
|
||||
Magic Link 登录
|
||||
|
||||
identifier 为签名 token,由 itsdangerous 生成。
|
||||
"""
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
|
||||
try:
|
||||
email = serializer.loads(request.identifier, salt="magic-link-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
http_exceptions.raise_unauthorized("Magic Link 已过期")
|
||||
except BadSignature:
|
||||
http_exceptions.raise_unauthorized("Magic Link 无效")
|
||||
|
||||
# 防重放:使用 token 哈希作为标识符
|
||||
token_hash = hashlib.sha256(request.identifier.encode()).hexdigest()
|
||||
is_first_use = await TokenStore.mark_used(f"magic_link:{token_hash}", ttl=600)
|
||||
if not is_first_use:
|
||||
http_exceptions.raise_unauthorized("Magic Link 已被使用")
|
||||
|
||||
# 查找绑定了该邮箱的 AuthIdentity(email_password 或 magic_link)
|
||||
identity: AuthIdentity | None = await AuthIdentity.get(
|
||||
session,
|
||||
(AuthIdentity.identifier == email)
|
||||
& (
|
||||
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
|
||||
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
|
||||
),
|
||||
)
|
||||
if not identity:
|
||||
http_exceptions.raise_unauthorized("该邮箱未注册")
|
||||
|
||||
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 标记邮箱已验证
|
||||
if not identity.is_verified:
|
||||
identity.is_verified = True
|
||||
identity = await identity.save(session)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def _issue_tokens(session: AsyncSession, user: User) -> TokenResponse:
|
||||
"""
|
||||
签发 JWT 双令牌(access + refresh)
|
||||
|
||||
提取自原 login.py 的签发逻辑,供所有 provider 共用。
|
||||
"""
|
||||
# 加载 GroupOptions
|
||||
group_options: GroupOptions | None = await GroupOptions.get(
|
||||
session,
|
||||
GroupOptions.group_id == user.group_id,
|
||||
)
|
||||
|
||||
# 构建权限快照
|
||||
user.group.options = group_options
|
||||
group_claims = GroupClaims.from_group(user.group)
|
||||
|
||||
# 创建令牌
|
||||
access_token = create_access_token(data={
|
||||
'sub': str(current_user.id),
|
||||
'jti': str(uuid4())
|
||||
})
|
||||
refresh_token = create_refresh_token(data={
|
||||
'sub': str(current_user.id),
|
||||
'jti': str(uuid4())
|
||||
})
|
||||
access_token = JWT.create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
refresh_token = JWT.create_refresh_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token.access_token,
|
||||
|
||||
41
service/webauthn.py
Normal file
41
service/webauthn.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
WebAuthn RP(Relying Party)配置辅助模块
|
||||
|
||||
从数据库 Setting 中读取 siteURL / siteTitle,
|
||||
解析出 rp_id、rp_name、origin,供注册/登录流程复用。
|
||||
"""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
|
||||
|
||||
async def get_rp_config(session: AsyncSession) -> tuple[str, str, str]:
|
||||
"""
|
||||
获取 WebAuthn RP 配置。
|
||||
|
||||
:param session: 数据库会话
|
||||
:return: ``(rp_id, rp_name, origin)`` 元组
|
||||
|
||||
- ``rp_id``: 站点域名(从 siteURL 解析,如 ``example.com``)
|
||||
- ``rp_name``: 站点标题
|
||||
- ``origin``: 完整 origin(如 ``https://example.com``)
|
||||
"""
|
||||
site_url_setting: Setting | None = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
|
||||
)
|
||||
site_title_setting: Setting | None = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteTitle"),
|
||||
)
|
||||
|
||||
site_url: str = site_url_setting.value if site_url_setting and site_url_setting.value else "https://localhost"
|
||||
rp_name: str = site_title_setting.value if site_title_setting and site_title_setting.value else "DiskNext"
|
||||
|
||||
parsed = urlparse(site_url)
|
||||
rp_id: str = parsed.hostname or "localhost"
|
||||
origin: str = f"{parsed.scheme}://{parsed.netloc}" if parsed.netloc else site_url
|
||||
|
||||
return rp_id, rp_name, origin
|
||||
185
service/wopi/__init__.py
Normal file
185
service/wopi/__init__.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
WOPI Discovery 服务模块
|
||||
|
||||
解析 WOPI 服务端(Collabora / OnlyOffice 等)的 Discovery XML,
|
||||
提取支持的文件扩展名及对应的编辑器 URL 模板。
|
||||
|
||||
参考:Cloudreve pkg/wopi/discovery.go 和 pkg/wopi/wopi.go
|
||||
"""
|
||||
import xml.etree.ElementTree as ET
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from loguru import logger as l
|
||||
|
||||
# WOPI URL 模板中已知的查询参数占位符及其替换值
|
||||
# 值为 None 表示删除该参数,非 None 表示替换为该值
|
||||
# 参考 Cloudreve pkg/wopi/wopi.go queryPlaceholders
|
||||
_WOPI_QUERY_PLACEHOLDERS: dict[str, str | None] = {
|
||||
'BUSINESS_USER': None,
|
||||
'DC_LLCC': 'lng',
|
||||
'DISABLE_ASYNC': None,
|
||||
'DISABLE_CHAT': None,
|
||||
'EMBEDDED': 'true',
|
||||
'FULLSCREEN': 'true',
|
||||
'HOST_SESSION_ID': None,
|
||||
'SESSION_CONTEXT': None,
|
||||
'RECORDING': None,
|
||||
'THEME_ID': 'darkmode',
|
||||
'UI_LLCC': 'lng',
|
||||
'VALIDATOR_TEST_CATEGORY': None,
|
||||
}
|
||||
|
||||
_WOPI_SRC_PLACEHOLDER = 'WOPI_SOURCE'
|
||||
|
||||
|
||||
def process_wopi_action_url(raw_urlsrc: str) -> str:
|
||||
"""
|
||||
将 WOPI Discovery 中的原始 urlsrc 转换为 DiskNext 可用的 URL 模板。
|
||||
|
||||
处理流程(参考 Cloudreve generateActionUrl):
|
||||
1. 去除 ``<>`` 占位符标记
|
||||
2. 解析查询参数,替换/删除已知占位符
|
||||
3. ``WOPI_SOURCE`` → ``{wopi_src}``
|
||||
|
||||
注意:access_token 和 access_token_ttl 不放在 URL 中,
|
||||
根据 WOPI 规范它们通过 POST 表单字段传递给编辑器。
|
||||
|
||||
:param raw_urlsrc: WOPI Discovery XML 中的 urlsrc 原始值
|
||||
:return: 处理后的 URL 模板字符串,包含 {wopi_src} 占位符
|
||||
"""
|
||||
# 去除 <> 标记
|
||||
cleaned = raw_urlsrc.replace('<', '').replace('>', '')
|
||||
parsed = urlparse(cleaned)
|
||||
raw_params = parse_qs(parsed.query, keep_blank_values=True)
|
||||
|
||||
new_params: list[tuple[str, str]] = []
|
||||
is_src_replaced = False
|
||||
|
||||
for key, values in raw_params.items():
|
||||
value = values[0] if values else ''
|
||||
|
||||
# WOPI_SOURCE 占位符 → {wopi_src}
|
||||
if value == _WOPI_SRC_PLACEHOLDER:
|
||||
new_params.append((key, '{wopi_src}'))
|
||||
is_src_replaced = True
|
||||
continue
|
||||
|
||||
# 已知占位符
|
||||
if value in _WOPI_QUERY_PLACEHOLDERS:
|
||||
replacement = _WOPI_QUERY_PLACEHOLDERS[value]
|
||||
if replacement is not None:
|
||||
new_params.append((key, replacement))
|
||||
# replacement 为 None 时删除该参数
|
||||
continue
|
||||
|
||||
# 其他参数保留原值
|
||||
new_params.append((key, value))
|
||||
|
||||
# 如果没有找到 WOPI_SOURCE 占位符,手动添加 WOPISrc
|
||||
if not is_src_replaced:
|
||||
new_params.append(('WOPISrc', '{wopi_src}'))
|
||||
|
||||
# LibreOffice/Collabora 需要 lang 参数(避免重复添加)
|
||||
existing_keys = {k for k, _ in new_params}
|
||||
if 'lang' not in existing_keys:
|
||||
new_params.append(('lang', 'lng'))
|
||||
|
||||
# 注意:access_token 和 access_token_ttl 不放在 URL 中
|
||||
# 根据 WOPI 规范,它们通过 POST 表单字段传递给编辑器
|
||||
|
||||
# 重建 URL
|
||||
new_query = urlencode(new_params, safe='{}')
|
||||
result = urlunparse((
|
||||
parsed.scheme,
|
||||
parsed.netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
new_query,
|
||||
'',
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parse_wopi_discovery_xml(xml_content: str) -> tuple[dict[str, str], list[str]]:
|
||||
"""
|
||||
解析 WOPI Discovery XML,提取扩展名到 URL 模板的映射。
|
||||
|
||||
XML 结构::
|
||||
|
||||
<wopi-discovery>
|
||||
<net-zone name="external-https">
|
||||
<app name="Writer" favIconUrl="...">
|
||||
<action name="edit" ext="docx" urlsrc="https://..."/>
|
||||
<action name="view" ext="docx" urlsrc="https://..."/>
|
||||
</app>
|
||||
</net-zone>
|
||||
</wopi-discovery>
|
||||
|
||||
动作优先级:edit > embedview > view(参考 Cloudreve discovery.go)
|
||||
|
||||
:param xml_content: WOPI Discovery 端点返回的 XML 字符串
|
||||
:return: (action_urls, app_names) 元组
|
||||
action_urls: {extension: processed_url_template}
|
||||
app_names: 发现的应用名称列表
|
||||
:raises ValueError: XML 解析失败或格式无效
|
||||
"""
|
||||
try:
|
||||
root = ET.fromstring(xml_content)
|
||||
except ET.ParseError as e:
|
||||
raise ValueError(f"WOPI Discovery XML 解析失败: {e}")
|
||||
|
||||
# 查找 net-zone(可能有多个,取第一个非空的)
|
||||
net_zones = root.findall('net-zone')
|
||||
if not net_zones:
|
||||
raise ValueError("WOPI Discovery XML 缺少 net-zone 节点")
|
||||
|
||||
# ext_actions: {extension: {action_name: urlsrc}}
|
||||
ext_actions: dict[str, dict[str, str]] = {}
|
||||
app_names: list[str] = []
|
||||
|
||||
for net_zone in net_zones:
|
||||
for app_elem in net_zone.findall('app'):
|
||||
app_name = app_elem.get('name', '')
|
||||
if app_name:
|
||||
app_names.append(app_name)
|
||||
|
||||
for action_elem in app_elem.findall('action'):
|
||||
action_name = action_elem.get('name', '')
|
||||
ext = action_elem.get('ext', '')
|
||||
urlsrc = action_elem.get('urlsrc', '')
|
||||
|
||||
if not ext or not urlsrc:
|
||||
continue
|
||||
|
||||
# 只关注 edit / embedview / view 三种动作
|
||||
if action_name not in ('edit', 'embedview', 'view'):
|
||||
continue
|
||||
|
||||
if ext not in ext_actions:
|
||||
ext_actions[ext] = {}
|
||||
ext_actions[ext][action_name] = urlsrc
|
||||
|
||||
# 为每个扩展名选择最佳 URL: edit > embedview > view
|
||||
action_urls: dict[str, str] = {}
|
||||
for ext, actions_map in ext_actions.items():
|
||||
selected_urlsrc: str | None = None
|
||||
for preferred in ('edit', 'embedview', 'view'):
|
||||
if preferred in actions_map:
|
||||
selected_urlsrc = actions_map[preferred]
|
||||
break
|
||||
|
||||
if selected_urlsrc:
|
||||
action_urls[ext] = process_wopi_action_url(selected_urlsrc)
|
||||
|
||||
# 去重 app_names
|
||||
seen: set[str] = set()
|
||||
unique_names: list[str] = []
|
||||
for name in app_names:
|
||||
if name not in seen:
|
||||
seen.add(name)
|
||||
unique_names.append(name)
|
||||
|
||||
l.info(f"WOPI Discovery 解析完成: {len(action_urls)} 个扩展名, 应用: {unique_names}")
|
||||
|
||||
return action_urls, unique_names
|
||||
92
setup_cython.py
Normal file
92
setup_cython.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
Cython 编译脚本 — 将 ee/ 下的纯逻辑文件编译为 .so
|
||||
|
||||
用法:
|
||||
uv run --extra build python setup_cython.py build_ext --inplace
|
||||
|
||||
编译规则:
|
||||
- 跳过 __init__.py(Python 包发现需要)
|
||||
- 只编译 .py 文件(纯函数 / 服务逻辑)
|
||||
|
||||
编译后清理(Pro Docker 构建用):
|
||||
uv run --extra build python setup_cython.py clean_artifacts
|
||||
"""
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
EE_DIR = Path("ee")
|
||||
|
||||
# 跳过 __init__.py —— 包发现需要原始 .py
|
||||
SKIP_NAMES = {"__init__.py"}
|
||||
|
||||
|
||||
def _collect_modules() -> list[str]:
|
||||
"""收集 ee/ 下需要编译的 .py 文件路径(点分模块名)。"""
|
||||
modules: list[str] = []
|
||||
for py_file in EE_DIR.rglob("*.py"):
|
||||
if py_file.name in SKIP_NAMES:
|
||||
continue
|
||||
# ee/license.py → ee.license
|
||||
module = str(py_file.with_suffix("")).replace("\\", "/").replace("/", ".")
|
||||
modules.append(module)
|
||||
return modules
|
||||
|
||||
|
||||
def clean_artifacts() -> None:
|
||||
"""删除已编译的 .py 源码、.c 中间文件和 build/ 目录。"""
|
||||
for py_file in EE_DIR.rglob("*.py"):
|
||||
if py_file.name in SKIP_NAMES:
|
||||
continue
|
||||
# 只删除有对应 .so / .pyd 的源文件
|
||||
parent = py_file.parent
|
||||
stem = py_file.stem
|
||||
has_compiled = (
|
||||
any(parent.glob(f"{stem}*.so")) or
|
||||
any(parent.glob(f"{stem}*.pyd"))
|
||||
)
|
||||
if has_compiled:
|
||||
py_file.unlink()
|
||||
print(f"已删除源码: {py_file}")
|
||||
|
||||
# 删除 .c 中间文件
|
||||
for c_file in EE_DIR.rglob("*.c"):
|
||||
c_file.unlink()
|
||||
print(f"已删除中间文件: {c_file}")
|
||||
|
||||
# 删除 build/ 目录
|
||||
build_dir = Path("build")
|
||||
if build_dir.exists():
|
||||
shutil.rmtree(build_dir)
|
||||
print(f"已删除: {build_dir}/")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "clean_artifacts":
|
||||
clean_artifacts()
|
||||
sys.exit(0)
|
||||
|
||||
# 动态导入(仅在编译时需要)
|
||||
from Cython.Build import cythonize
|
||||
from setuptools import Extension, setup
|
||||
|
||||
modules = _collect_modules()
|
||||
if not modules:
|
||||
print("未找到需要编译的模块")
|
||||
sys.exit(0)
|
||||
|
||||
print(f"即将编译以下模块: {modules}")
|
||||
|
||||
extensions = [
|
||||
Extension(mod, [mod.replace(".", "/") + ".py"])
|
||||
for mod in modules
|
||||
]
|
||||
|
||||
setup(
|
||||
name="disknext-ee",
|
||||
packages=[],
|
||||
ext_modules=cythonize(
|
||||
extensions,
|
||||
compiler_directives={'language_level': "3"},
|
||||
),
|
||||
)
|
||||
@@ -954,18 +954,11 @@ class PolicyType(StrEnum):
|
||||
S3 = "s3" # S3 兼容存储
|
||||
```
|
||||
|
||||
### StorageType
|
||||
### PolicyType
|
||||
```python
|
||||
class StorageType(StrEnum):
|
||||
class PolicyType(StrEnum):
|
||||
LOCAL = "local" # 本地存储
|
||||
QINIU = "qiniu" # 七牛云
|
||||
TENCENT = "tencent" # 腾讯云
|
||||
ALIYUN = "aliyun" # 阿里云
|
||||
ONEDRIVE = "onedrive" # OneDrive
|
||||
GOOGLE_DRIVE = "google_drive" # Google Drive
|
||||
DROPBOX = "dropbox" # Dropbox
|
||||
WEBDAV = "webdav" # WebDAV
|
||||
REMOTE = "remote" # 远程存储
|
||||
S3 = "s3" # S3 兼容存储
|
||||
```
|
||||
|
||||
### UserStatus
|
||||
181
sqlmodels/__init__.py
Normal file
181
sqlmodels/__init__.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from .auth_identity import (
|
||||
AuthIdentity,
|
||||
AuthIdentityResponse,
|
||||
AuthProviderType,
|
||||
BindIdentityRequest,
|
||||
ChangePasswordRequest,
|
||||
)
|
||||
from .user import (
|
||||
BatchDeleteRequest,
|
||||
JWTPayload,
|
||||
MagicLinkRequest,
|
||||
UnifiedLoginRequest,
|
||||
UnifiedRegisterRequest,
|
||||
RefreshTokenRequest,
|
||||
AccessTokenBase,
|
||||
RefreshTokenBase,
|
||||
TokenResponse,
|
||||
User,
|
||||
UserBase,
|
||||
UserStorageResponse,
|
||||
UserPublic,
|
||||
UserResponse,
|
||||
UserSettingResponse,
|
||||
UserThemeUpdateRequest,
|
||||
SettingOption,
|
||||
UserSettingUpdateRequest,
|
||||
WebAuthnInfo,
|
||||
UserTwoFactorResponse,
|
||||
# 管理员DTO
|
||||
UserAdminUpdateRequest,
|
||||
UserCalibrateResponse,
|
||||
UserAdminDetailResponse,
|
||||
)
|
||||
from .user_authn import (
|
||||
AuthnDetailResponse,
|
||||
AuthnFinishRequest,
|
||||
AuthnRenameRequest,
|
||||
UserAuthn,
|
||||
)
|
||||
from .color import ChromaticColor, NeutralColor, ThemeColorsBase, BUILTIN_DEFAULT_COLORS
|
||||
from .theme_preset import (
|
||||
ThemePreset, ThemePresetBase,
|
||||
ThemePresetCreateRequest, ThemePresetUpdateRequest,
|
||||
ThemePresetResponse, ThemePresetListResponse,
|
||||
)
|
||||
|
||||
from .download import (
|
||||
Download,
|
||||
DownloadAria2File,
|
||||
DownloadAria2Info,
|
||||
DownloadAria2InfoBase,
|
||||
DownloadStatus,
|
||||
DownloadType,
|
||||
)
|
||||
from .node import (
|
||||
Aria2Configuration,
|
||||
Aria2ConfigurationBase,
|
||||
Node,
|
||||
NodeStatus,
|
||||
NodeType,
|
||||
)
|
||||
from .group import (
|
||||
Group, GroupBase, GroupClaims, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
|
||||
# 管理员DTO
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, GroupListResponse,
|
||||
)
|
||||
from .object import (
|
||||
CreateFileRequest,
|
||||
CreateUploadSessionRequest,
|
||||
DirectoryCreateRequest,
|
||||
DirectoryResponse,
|
||||
Object,
|
||||
ObjectBase,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
ObjectFileFinalize,
|
||||
ObjectMoveRequest,
|
||||
ObjectMoveUpdate,
|
||||
ObjectPropertyDetailResponse,
|
||||
ObjectPropertyResponse,
|
||||
ObjectRenameRequest,
|
||||
ObjectResponse,
|
||||
ObjectSwitchPolicyRequest,
|
||||
ObjectType,
|
||||
FileCategory,
|
||||
PolicyResponse,
|
||||
UploadChunkResponse,
|
||||
UploadSession,
|
||||
UploadSessionBase,
|
||||
UploadSessionResponse,
|
||||
# 管理员DTO
|
||||
AdminFileResponse,
|
||||
AdminFileListResponse,
|
||||
FileBanRequest,
|
||||
# 回收站DTO
|
||||
TrashItemResponse,
|
||||
TrashRestoreRequest,
|
||||
TrashDeleteRequest,
|
||||
)
|
||||
from .object_metadata import (
|
||||
ObjectMetadata,
|
||||
ObjectMetadataBase,
|
||||
MetadataNamespace,
|
||||
MetadataResponse,
|
||||
MetadataPatchItem,
|
||||
MetadataPatchRequest,
|
||||
INTERNAL_NAMESPACES,
|
||||
USER_WRITABLE_NAMESPACES,
|
||||
)
|
||||
from .custom_property import (
|
||||
CustomPropertyDefinition,
|
||||
CustomPropertyDefinitionBase,
|
||||
CustomPropertyType,
|
||||
CustomPropertyCreateRequest,
|
||||
CustomPropertyUpdateRequest,
|
||||
CustomPropertyResponse,
|
||||
)
|
||||
from .physical_file import PhysicalFile, PhysicalFileBase
|
||||
from .uri import DiskNextURI, FileSystemNamespace
|
||||
from .order import (
|
||||
Order, OrderStatus, OrderType,
|
||||
CreateOrderRequest, OrderResponse,
|
||||
)
|
||||
from .policy import (
|
||||
Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
|
||||
PolicyType, PolicySummary, PolicyUpdateRequest,
|
||||
)
|
||||
from .product import (
|
||||
Product, ProductBase, ProductType, PaymentMethod,
|
||||
ProductCreateRequest, ProductUpdateRequest, ProductResponse,
|
||||
)
|
||||
from .redeem import (
|
||||
Redeem, RedeemType,
|
||||
RedeemCreateRequest, RedeemUseRequest, RedeemInfoResponse, RedeemAdminResponse,
|
||||
)
|
||||
from .report import Report, ReportReason
|
||||
from .setting import (
|
||||
Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
|
||||
# 管理员DTO
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
from .share import (
|
||||
Share, ShareBase, ShareCreateRequest, CreateShareResponse, ShareResponse,
|
||||
ShareOwnerInfo, ShareObjectItem, ShareDetailResponse,
|
||||
AdminShareListItem,
|
||||
)
|
||||
from .source_link import SourceLink
|
||||
from .storage_pack import StoragePack, StoragePackResponse
|
||||
from .tag import Tag, TagType
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
|
||||
from .webdav import (
|
||||
WebDAV, WebDAVBase,
|
||||
WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
|
||||
)
|
||||
from .file_app import (
|
||||
FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
|
||||
# DTO
|
||||
FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
|
||||
FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
|
||||
ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
|
||||
WopiDiscoveredExtension, WopiDiscoveryResponse,
|
||||
)
|
||||
from .wopi import WopiFileInfo, WopiAccessTokenPayload
|
||||
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
from .model_base import (
|
||||
MCPBase,
|
||||
MCPMethod,
|
||||
MCPRequestBase,
|
||||
MCPResponseBase,
|
||||
ResponseBase,
|
||||
# Admin Summary DTO
|
||||
MetricsSummary,
|
||||
LicenseInfo,
|
||||
VersionInfo,
|
||||
AdminSummaryResponse,
|
||||
)
|
||||
|
||||
# 通用分页模型
|
||||
from sqlmodel_ext import ListResponse
|
||||
148
sqlmodels/auth_identity.py
Normal file
148
sqlmodels/auth_identity.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
认证身份模块
|
||||
|
||||
一个用户可拥有多种登录方式(邮箱密码、OAuth、Passkey、Magic Link 等)。
|
||||
AuthIdentity 表存储每种认证方式的凭证信息。
|
||||
"""
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100, Str128, Str255, Text1024
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class AuthProviderType(StrEnum):
    """Authentication provider type — every login method an AuthIdentity row can represent."""

    EMAIL_PASSWORD = "email_password"
    """Email + password"""

    PHONE_SMS = "phone_sms"
    """Phone number + SMS verification code (reserved)"""

    GITHUB = "github"
    """GitHub OAuth"""

    QQ = "qq"
    """QQ OAuth"""

    PASSKEY = "passkey"
    """Passkey / WebAuthn"""

    MAGIC_LINK = "magic_link"
    """Email magic link"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class AuthIdentityResponse(SQLModelBase):
    """Authentication identity response DTO (for list display).

    Deliberately excludes ``credential`` and ``extra_data`` so secrets
    never leave the server.
    """

    id: UUID
    """Identity UUID"""

    provider: AuthProviderType
    """Provider type"""

    identifier: str
    """Identifier (email / phone number / OAuth openid)"""

    display_name: str | None = None
    """Display name (e.g. OAuth nickname)"""

    avatar_url: str | None = None
    """Avatar URL"""

    is_primary: bool = False
    """Whether this is the user's primary identity"""

    is_verified: bool = False
    """Whether the identity has been verified"""
|
||||
|
||||
|
||||
class BindIdentityRequest(SQLModelBase):
    """Request DTO for binding an authentication identity to the current user."""

    provider: AuthProviderType
    """Provider type"""

    identifier: str
    """Identifier (email / phone number / OAuth code)"""

    credential: str | None = None
    """Credential (password, verification code, etc.)"""

    redirect_uri: str | None = None
    """OAuth callback address"""
|
||||
|
||||
|
||||
class ChangePasswordRequest(SQLModelBase):
    """Change-password request DTO."""

    old_password: str = Field(min_length=1)
    """Current password"""

    new_password: Str128 = Field(min_length=8)
    """New password (at least 8 characters)"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
    """User authentication identity — one user may have multiple login methods."""

    # A (provider, identifier) pair may exist only once across all users.
    __table_args__ = (
        UniqueConstraint("provider", "identifier", name="uq_auth_identity_provider_identifier"),
    )

    provider: AuthProviderType = Field(index=True)
    """Provider type"""

    identifier: Str255 = Field(index=True)
    """Identifier (email / phone number / OAuth openid)"""

    credential: Text1024 | None = None
    """Credential (Argon2 password hash, or null for credential-less providers)"""

    display_name: Str100 | None = None
    """OAuth nickname"""

    avatar_url: str | None = Field(default=None, max_length=512)
    """OAuth avatar URL"""

    extra_data: str | None = None
    """Extra JSON payload (2FA secret, OAuth refresh_token, etc.)"""

    is_primary: bool = False
    """Whether this is the user's primary identity"""

    is_verified: bool = False
    """Whether the identity has been verified"""

    # Foreign key — the row is removed together with its user (ON DELETE CASCADE).
    user_id: UUID = Field(
        foreign_key="user.id",
        index=True,
        ondelete="CASCADE",
    )
    """UUID of the owning user"""

    # Relationship
    user: "User" = Relationship(back_populates="auth_identities")

    def to_response(self) -> AuthIdentityResponse:
        """Convert this row to the list-display DTO (omits credential / extra_data)."""
        return AuthIdentityResponse(
            id=self.id,
            provider=self.provider,
            identifier=self.identifier,
            display_name=self.display_name,
            avatar_url=self.avatar_url,
            is_primary=self.is_primary,
            is_verified=self.is_verified,
        )
|
||||
71
sqlmodels/color.py
Normal file
71
sqlmodels/color.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlmodel_ext import SQLModelBase
|
||||
|
||||
|
||||
class ChromaticColor(StrEnum):
    """Chromatic color enum (17 Tailwind palette colors)."""

    RED = "red"
    ORANGE = "orange"
    AMBER = "amber"
    YELLOW = "yellow"
    LIME = "lime"
    GREEN = "green"
    EMERALD = "emerald"
    TEAL = "teal"
    CYAN = "cyan"
    SKY = "sky"
    BLUE = "blue"
    INDIGO = "indigo"
    VIOLET = "violet"
    PURPLE = "purple"
    FUCHSIA = "fuchsia"
    PINK = "pink"
    ROSE = "rose"
|
||||
|
||||
|
||||
class NeutralColor(StrEnum):
    """Neutral (achromatic) color enum (5 gray tones)."""

    SLATE = "slate"
    GRAY = "gray"
    ZINC = "zinc"
    NEUTRAL = "neutral"
    STONE = "stone"
|
||||
|
||||
|
||||
class ThemeColorsBase(SQLModelBase):
    """Nested color DTO used by the API request/response layer."""

    primary: ChromaticColor
    """Primary color"""

    secondary: ChromaticColor
    """Secondary color"""

    success: ChromaticColor
    """Success color"""

    info: ChromaticColor
    """Info color"""

    warning: ChromaticColor
    """Warning color"""

    error: ChromaticColor
    """Error color"""

    neutral: NeutralColor
    """Neutral color"""
|
||||
|
||||
|
||||
# Built-in default theme palette, used when no custom preset is configured.
BUILTIN_DEFAULT_COLORS = ThemeColorsBase(
    primary=ChromaticColor.GREEN,
    secondary=ChromaticColor.BLUE,
    success=ChromaticColor.GREEN,
    info=ChromaticColor.BLUE,
    warning=ChromaticColor.YELLOW,
    error=ChromaticColor.RED,
    neutral=NeutralColor.ZINC,
)
|
||||
135
sqlmodels/custom_property.py
Normal file
135
sqlmodels/custom_property.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
用户自定义属性定义模型
|
||||
|
||||
允许用户定义类型化的自定义属性模板(如标签、评分、分类等),
|
||||
实际值通过 ObjectMetadata KV 表存储,键名格式:custom:{property_definition_id}。
|
||||
|
||||
支持的属性类型:text, number, boolean, select, multi_select, rating, link
|
||||
"""
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import JSON
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
# ==================== 枚举 ====================
|
||||
|
||||
class CustomPropertyType(StrEnum):
    """Value-type enum for user-defined custom properties."""
    TEXT = "text"
    """Free text"""
    NUMBER = "number"
    """Number"""
    BOOLEAN = "boolean"
    """Boolean"""
    SELECT = "select"
    """Single choice"""
    MULTI_SELECT = "multi_select"
    """Multiple choice"""
    RATING = "rating"
    """Rating (1-5)"""
    LINK = "link"
    """Link"""
|
||||
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class CustomPropertyDefinitionBase(SQLModelBase):
    """Base model for custom property definitions (shared by table and DTOs)."""

    name: Str100
    """Display name of the property"""

    type: CustomPropertyType
    """Value type of the property"""

    icon: Str100 | None = None
    """Icon identifier (iconify name)"""

    options: list[str] | None = Field(default=None, sa_type=JSON)
    """Allowed values (select / multi_select types only)"""

    default_value: str | None = Field(default=None, max_length=500)
    """Default value"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class CustomPropertyDefinition(CustomPropertyDefinitionBase, UUIDTableBaseMixin):
    """
    User-defined custom property definition.

    Each user manages their own property templates independently.
    Actual values live in the ObjectMetadata table under keys of the
    form ``custom:{id}``.
    """

    # Definitions are removed together with their owner (ON DELETE CASCADE).
    owner_id: UUID = Field(
        foreign_key="user.id",
        ondelete="CASCADE",
        index=True,
    )
    """Owning user's UUID"""

    sort_order: int = 0
    """Sort order"""

    # Relationship
    owner: "User" = Relationship()
    """Owner"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class CustomPropertyCreateRequest(SQLModelBase):
    """Create-custom-property request DTO."""

    name: Str100
    """Display name of the property"""

    type: CustomPropertyType
    """Value type of the property"""

    icon: str | None = None
    """Icon identifier"""

    options: list[str] | None = None
    """Allowed values (select / multi_select types only)"""

    default_value: str | None = None
    """Default value"""
|
||||
|
||||
|
||||
class CustomPropertyUpdateRequest(SQLModelBase):
    """Update-custom-property request DTO (all fields optional; ``type`` is immutable)."""

    name: str | None = None
    """Display name of the property"""

    icon: str | None = None
    """Icon identifier"""

    options: list[str] | None = None
    """Allowed values"""

    default_value: str | None = None
    """Default value"""

    sort_order: int | None = None
    """Sort order"""
|
||||
|
||||
|
||||
class CustomPropertyResponse(CustomPropertyDefinitionBase):
    """Custom property response DTO."""

    id: UUID
    """Property definition UUID"""

    sort_order: int
    """Sort order"""
|
||||
78
sqlmodels/database_connection.py
Normal file
78
sqlmodels/database_connection.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import AsyncGenerator, ClassVar
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import NullPool, AsyncAdaptedQueuePool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class DatabaseManager:
    """Process-wide holder for the async SQLAlchemy engine and session factory.

    Lifecycle: await ``init()`` once at application startup, ``close()``
    once at shutdown; use ``get_session()`` as an async-generator
    dependency in between.
    """

    engine: ClassVar[AsyncEngine | None] = None
    _async_session_factory: ClassVar[sessionmaker | None] = None

    @classmethod
    async def get_session(cls) -> AsyncGenerator[AsyncSession, None]:
        """Yield one fresh session per call; the context manager closes it.

        Fix: ``typing.AsyncGenerator`` requires BOTH type arguments before
        Python 3.13 — the one-argument form raised ``TypeError`` at import
        time; the explicit ``None`` send type works on all versions.
        """
        assert cls._async_session_factory is not None, "数据库引擎未初始化,请先调用 DatabaseManager.init()"
        async with cls._async_session_factory() as session:
            yield session

    @classmethod
    async def init(
        cls,
        database_url: str,
        debug: bool = False,
    ):
        """
        Initialise the database engine and session factory.

        :param database_url: database connection URL
        :param debug: debug mode (SQL echo, no connection pooling)
        """
        # Build the engine arguments.
        engine_kwargs: dict = {
            'echo': debug,
            'future': True,
        }

        if debug:
            # Debug mode: NullPool — no pooling, a new connection every time.
            engine_kwargs['poolclass'] = NullPool
        else:
            # Production mode: AsyncAdaptedQueuePool connection pool.
            engine_kwargs.update({
                'poolclass': AsyncAdaptedQueuePool,
                'pool_size': 40,
                'max_overflow': 80,
                'pool_timeout': 30,
                'pool_recycle': 1800,
                'pool_pre_ping': True,
            })

        # Only SQLite needs connect_args (async drivers run off the main thread).
        if database_url.startswith("sqlite"):
            engine_kwargs['connect_args'] = {'check_same_thread': False}

        cls.engine = create_async_engine(database_url, **engine_kwargs)

        cls._async_session_factory = sessionmaker(cls.engine, class_=AsyncSession)

        # Development shortcut: create_all builds the schema directly.
        # NOTE(review): replace with migrations before production use.
        async with cls.engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)

        logger.info("数据库引擎初始化完成")

    @classmethod
    async def close(cls):
        """
        Dispose of the database engine gracefully.
        Should only be called at application shutdown.
        """
        if cls.engine:
            logger.info("正在关闭数据库连接引擎...")
            await cls.engine.dispose()
            logger.info("数据库连接引擎已成功关闭。")
        else:
            logger.info("数据库连接引擎未初始化,无需关闭。")
|
||||
@@ -4,8 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin, TableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -142,7 +141,7 @@ class Download(DownloadBase, UUIDTableBaseMixin):
|
||||
speed: int = Field(default=0)
|
||||
"""下载速度(bytes/s)"""
|
||||
|
||||
parent: str | None = Field(default=None, max_length=255)
|
||||
parent: Str255 | None = None
|
||||
"""父任务标识"""
|
||||
|
||||
error: str | None = Field(default=None)
|
||||
435
sqlmodels/file_app.py
Normal file
435
sqlmodels/file_app.py
Normal file
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
文件查看器应用模块
|
||||
|
||||
提供文件预览应用选择器系统的数据模型和 DTO。
|
||||
类似 Android 的"使用什么应用打开"机制:
|
||||
- 管理员注册应用(内置/iframe/WOPI)
|
||||
- 用户按扩展名查询可用查看器
|
||||
- 用户可设置"始终使用"偏好
|
||||
- 支持用户组级别的访问控制
|
||||
|
||||
架构:
|
||||
FileApp (应用注册表)
|
||||
├── FileAppExtension (扩展名关联)
|
||||
├── FileAppGroupLink (用户组访问控制)
|
||||
└── UserFileAppDefault (用户默认偏好)
|
||||
"""
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str100, Str255, Text1024
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .group import Group
|
||||
|
||||
|
||||
# ==================== 枚举 ====================
|
||||
|
||||
class FileAppType(StrEnum):
    """File application type."""

    BUILTIN = "builtin"
    """Frontend built-in viewer (e.g. pdf.js, Monaco)"""

    IFRAME = "iframe"
    """Third-party service embedded in an iframe"""

    WOPI = "wopi"
    """WOPI protocol (OnlyOffice / Collabora)"""
|
||||
|
||||
|
||||
# ==================== Link 表 ====================
|
||||
|
||||
class FileAppGroupLink(SQLModelBase, table=True):
    """App ↔ group access-control link table (composite primary key)."""

    app_id: UUID = Field(foreign_key="fileapp.id", primary_key=True, ondelete="CASCADE")
    """Linked application UUID"""

    group_id: UUID = Field(foreign_key="group.id", primary_key=True, ondelete="CASCADE")
    """Linked group UUID"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class FileAppSummary(SQLModelBase):
    """Viewer list-item DTO, used by the "open with" picker dialog."""

    id: UUID
    """Application UUID"""

    name: str
    """Application name"""

    app_key: str
    """Unique application key"""

    type: FileAppType
    """Application type"""

    icon: str | None = None
    """Icon name / URL"""

    description: str | None = None
    """Application description"""

    iframe_url_template: str | None = None
    """iframe URL template"""

    wopi_editor_url_template: str | None = None
    """WOPI editor URL template"""
|
||||
|
||||
|
||||
class FileViewersResponse(SQLModelBase):
    """Viewer query response DTO."""

    viewers: list[FileAppSummary] = []
    """Available viewers (already sorted by priority)"""

    default_viewer_id: UUID | None = None
    """User's default viewer UUID (if an "always use" preference is set)"""
|
||||
|
||||
|
||||
class SetDefaultViewerRequest(SQLModelBase):
    """Set-default-viewer request DTO."""

    extension: str = Field(max_length=20)
    """File extension (lowercase, no dot)"""

    app_id: UUID
    """Application UUID"""
|
||||
|
||||
|
||||
class UserFileAppDefaultResponse(SQLModelBase):
    """User default-viewer response DTO."""

    id: UUID
    """Record UUID"""

    extension: str
    """File extension"""

    app: FileAppSummary
    """Summary of the linked application"""
|
||||
|
||||
|
||||
class FileAppCreateRequest(SQLModelBase):
    """Admin create-application request DTO."""

    name: Str100
    """Application name"""

    app_key: str = Field(max_length=50)
    """Unique application key"""

    type: FileAppType
    """Application type"""

    icon: Str255 | None = None
    """Icon name / URL"""

    description: str | None = Field(default=None, max_length=500)
    """Application description"""

    is_enabled: bool = True
    """Whether the app is enabled"""

    is_restricted: bool = False
    """Whether access is restricted to specific groups"""

    iframe_url_template: Text1024 | None = None
    """iframe URL template"""

    wopi_discovery_url: str | None = Field(default=None, max_length=512)
    """WOPI discovery endpoint URL"""

    wopi_editor_url_template: Text1024 | None = None
    """WOPI editor URL template"""

    extensions: list[str] = []
    """Linked file extensions"""

    allowed_group_ids: list[UUID] = []
    """UUIDs of groups allowed to access the app"""
|
||||
|
||||
|
||||
class FileAppUpdateRequest(SQLModelBase):
    """Admin update-application request DTO (all fields optional)."""

    name: Str100 | None = None
    """Application name"""

    app_key: str | None = Field(default=None, max_length=50)
    """Unique application key"""

    type: FileAppType | None = None
    """Application type"""

    icon: Str255 | None = None
    """Icon name / URL"""

    description: str | None = Field(default=None, max_length=500)
    """Application description"""

    is_enabled: bool | None = None
    """Whether the app is enabled"""

    is_restricted: bool | None = None
    """Whether access is restricted to specific groups"""

    iframe_url_template: Text1024 | None = None
    """iframe URL template"""

    wopi_discovery_url: str | None = Field(default=None, max_length=512)
    """WOPI discovery endpoint URL"""

    wopi_editor_url_template: Text1024 | None = None
    """WOPI editor URL template"""
|
||||
|
||||
|
||||
class FileAppResponse(SQLModelBase):
    """Admin application-detail response DTO."""

    id: UUID
    """Application UUID"""

    name: str
    """Application name"""

    app_key: str
    """Unique application key"""

    type: FileAppType
    """Application type"""

    icon: str | None = None
    """Icon name / URL"""

    description: str | None = None
    """Application description"""

    is_enabled: bool = True
    """Whether the app is enabled"""

    is_restricted: bool = False
    """Whether access is restricted to specific groups"""

    iframe_url_template: str | None = None
    """iframe URL template"""

    wopi_discovery_url: str | None = None
    """WOPI discovery endpoint URL"""

    wopi_editor_url_template: str | None = None
    """WOPI editor URL template"""

    extensions: list[str] = []
    """Linked file extensions"""

    allowed_group_ids: list[UUID] = []
    """UUIDs of groups allowed to access the app"""

    @classmethod
    def from_app(
        cls,
        app: "FileApp",
        extensions: list["FileAppExtension"],
        group_links: list[FileAppGroupLink],
    ) -> "FileAppResponse":
        """Build the DTO from ORM objects.

        ``extensions`` and ``group_links`` are passed in separately so the
        caller controls loading; the app's own relationships are not touched.
        """
        return cls(
            id=app.id,
            name=app.name,
            app_key=app.app_key,
            type=app.type,
            icon=app.icon,
            description=app.description,
            is_enabled=app.is_enabled,
            is_restricted=app.is_restricted,
            iframe_url_template=app.iframe_url_template,
            wopi_discovery_url=app.wopi_discovery_url,
            wopi_editor_url_template=app.wopi_editor_url_template,
            extensions=[ext.extension for ext in extensions],
            allowed_group_ids=[link.group_id for link in group_links],
        )
|
||||
|
||||
|
||||
class FileAppListResponse(SQLModelBase):
    """Admin application-list response DTO."""

    apps: list[FileAppResponse] = []
    """Applications"""

    total: int = 0
    """Total count"""
|
||||
|
||||
|
||||
class ExtensionUpdateRequest(SQLModelBase):
    """Full-replacement request DTO for an app's extension list."""

    extensions: list[str]
    """File extensions (lowercase, no dot)"""
|
||||
|
||||
|
||||
class GroupAccessUpdateRequest(SQLModelBase):
    """Full-replacement request DTO for an app's group access list."""

    group_ids: list[UUID]
    """UUIDs of groups allowed to access the app"""
|
||||
|
||||
|
||||
class WopiSessionResponse(SQLModelBase):
    """WOPI session response DTO."""

    wopi_src: str
    """WOPI source URL"""

    access_token: str
    """WOPI access token"""

    access_token_ttl: int
    """Token expiry timestamp (milliseconds, as the WOPI spec requires)"""

    editor_url: str
    """Fully assembled editor URL"""
|
||||
|
||||
|
||||
class WopiDiscoveredExtension(SQLModelBase):
    """A single extension discovered via WOPI Discovery."""

    extension: str
    """File extension"""

    action_url: str
    """Processed action URL template"""
|
||||
|
||||
|
||||
class WopiDiscoveryResponse(SQLModelBase):
    """WOPI Discovery result response DTO."""

    discovered_extensions: list[WopiDiscoveredExtension] = []
    """Discovered extensions and their URL templates"""

    app_names: list[str] = []
    """Application names reported by the WOPI server (e.g. Writer, Calc, Impress)"""

    applied_count: int = 0
    """Number of entries applied to FileAppExtension"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class FileApp(SQLModelBase, UUIDTableBaseMixin):
    """File viewer application registry."""

    name: Str100
    """Application name"""

    app_key: str = Field(max_length=50, unique=True, index=True)
    """Unique application key, used by frontend routing"""

    type: FileAppType
    """Application type"""

    icon: Str255 | None = None
    """Icon name / URL"""

    description: str | None = Field(default=None, max_length=500)
    """Application description"""

    is_enabled: bool = True
    """Whether the app is enabled"""

    is_restricted: bool = False
    """Whether access is restricted to specific groups"""

    iframe_url_template: Text1024 | None = None
    """iframe URL template; supports the {file_url} placeholder"""

    wopi_discovery_url: str | None = Field(default=None, max_length=512)
    """WOPI client discovery endpoint URL"""

    wopi_editor_url_template: Text1024 | None = None
    """WOPI editor URL template; supports {wopi_src} {access_token} {access_token_ttl}"""

    # Relationships — child rows are deleted with the app (delete-orphan cascade).
    extensions: list["FileAppExtension"] = Relationship(
        back_populates="app",
        sa_relationship_kwargs={"cascade": "all, delete-orphan"}
    )

    user_defaults: list["UserFileAppDefault"] = Relationship(
        back_populates="app",
        sa_relationship_kwargs={"cascade": "all, delete-orphan"}
    )

    allowed_groups: list["Group"] = Relationship(
        link_model=FileAppGroupLink,
    )

    def to_summary(self) -> FileAppSummary:
        """Convert this row to the picker-dialog summary DTO."""
        return FileAppSummary(
            id=self.id,
            name=self.name,
            app_key=self.app_key,
            type=self.type,
            icon=self.icon,
            description=self.description,
            iframe_url_template=self.iframe_url_template,
            wopi_editor_url_template=self.wopi_editor_url_template,
        )
|
||||
|
||||
|
||||
class FileAppExtension(SQLModelBase, TableBaseMixin):
    """Extension link table — which extensions an app can open."""

    # Each extension is linked to a given app at most once.
    __table_args__ = (
        UniqueConstraint("app_id", "extension", name="uq_fileappextension_app_extension"),
    )

    app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
    """Linked application UUID"""

    extension: str = Field(max_length=20, index=True)
    """File extension (lowercase, no dot)"""

    priority: int = Field(default=0, ge=0)
    """Sort priority (lower sorts first)"""

    wopi_action_url: str | None = Field(default=None, max_length=2048)
    """WOPI action URL template (auto-filled by Discovery); supports {wopi_src} {access_token} {access_token_ttl}"""

    # Relationship
    app: FileApp = Relationship(back_populates="extensions")
|
||||
|
||||
|
||||
class UserFileAppDefault(SQLModelBase, UUIDTableBaseMixin):
    """Per-user 'always open with' preference for one file extension."""

    # Each user has at most one default app per extension.
    __table_args__ = (
        UniqueConstraint("user_id", "extension", name="uq_userfileappdefault_user_extension"),
    )

    user_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE")
    """User UUID"""

    extension: str = Field(max_length=20)
    """File extension (lowercase, no dot)"""

    app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
    """Linked application UUID"""

    # Relationship
    app: FileApp = Relationship(back_populates="user_defaults")

    def to_response(self) -> UserFileAppDefaultResponse:
        """Convert to the response DTO (the ``app`` relationship must be pre-loaded)."""
        return UserFileAppDefaultResponse(
            id=self.id,
            extension=self.extension,
            app=self.app.to_summary(),
        )
|
||||
@@ -2,10 +2,10 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str255
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -67,7 +67,7 @@ class GroupAllOptionsBase(GroupOptionsBase):
|
||||
class GroupCreateRequest(GroupAllOptionsBase):
|
||||
"""创建用户组请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
name: Str255
|
||||
"""用户组名称"""
|
||||
|
||||
max_storage: int = Field(default=0, ge=0)
|
||||
@@ -92,7 +92,7 @@ class GroupCreateRequest(GroupAllOptionsBase):
|
||||
class GroupUpdateRequest(SQLModelBase):
|
||||
"""更新用户组请求 DTO(所有字段可选)"""
|
||||
|
||||
name: str | None = Field(default=None, max_length=255)
|
||||
name: Str255 | None = None
|
||||
"""用户组名称"""
|
||||
|
||||
max_storage: int | None = Field(default=None, ge=0)
|
||||
@@ -188,6 +188,28 @@ class GroupListResponse(SQLModelBase):
|
||||
"""总数"""
|
||||
|
||||
|
||||
class GroupClaims(GroupCoreBase, GroupAllOptionsBase):
    """
    Snapshot of group permissions embedded in a JWT.

    Reuses GroupCoreBase (id, name, max_storage, share_enabled,
    web_dav_enabled, admin, speed_limit) and GroupAllOptionsBase
    (share_download, share_free, ... — 11 feature switches in total).
    """

    @classmethod
    def from_group(cls, group: "Group") -> "GroupClaims":
        """
        Build a permission snapshot from a Group ORM object
        (its ``options`` relationship must be pre-loaded).

        :param group: Group object with ``options`` already loaded
        """
        opts = group.options
        # Merge core fields with option switches; a missing options row
        # contributes nothing, so the declared field defaults apply.
        return cls(
            **GroupCoreBase.model_validate(group, from_attributes=True).model_dump(),
            **(GroupAllOptionsBase.model_validate(opts, from_attributes=True).model_dump() if opts else {}),
        )
|
||||
|
||||
|
||||
class GroupResponse(GroupBase, GroupOptionsBase):
|
||||
"""用户组响应 DTO"""
|
||||
|
||||
@@ -236,10 +258,10 @@ class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
|
||||
class Group(GroupBase, UUIDTableBaseMixin):
|
||||
"""用户组模型"""
|
||||
|
||||
name: str = Field(max_length=255, unique=True)
|
||||
name: Str255 = Field(unique=True)
|
||||
"""用户组名"""
|
||||
|
||||
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||
"""最大存储空间(字节)"""
|
||||
|
||||
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user