Compare commits

22 Commits

15b2efe52a fix: fix the unassigned app variable in update_group_access
All checks were successful
Test / test (push) Successful in 2m34s
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 11:31:02 +08:00
6c96c43bea refactor: unify sqlmodel_ext usage to the officially recommended patterns
Some checks failed
Test / test (push) Failing after 3m47s
- Replace Field(max_length=X) with StrX/TextX type aliases (21 sqlmodels files)
- Replace get + 404 checks with get_exist_one() (17 route files, ~50 call sites)
- Replace save + session.refresh with save(load=...)
- Replace session.add + commit with save() (dav/provider.py)
- Update all dependencies to their latest versions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 11:13:16 +08:00
9185f26b83 feat: add EPUB reader, 3D model preview, and font viewer apps; enable Office online preview
All checks were successful
Test / test (push) Successful in 2m31s
2026-02-26 12:50:24 +08:00
f4052d229a fix: clean up empty parent directories after file deletion
All checks were successful
Test / test (push) Successful in 2m32s
Prevent local storage fragmentation by removing empty directories
left behind when files are permanently deleted or moved to trash.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 15:56:44 +08:00
bc2182720d feat: implement avatar upload, Gravatar support, and avatar settings
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 15:56:24 +08:00
eddf38d316 chore: remove applied migration script
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 15:56:07 +08:00
03e768d232 chore: update .gitignore for avatar and dev directories
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-25 15:56:00 +08:00
bcb0a9b322 feat: redesign metadata as KV store, add custom properties and WOPI Discovery
Some checks failed
Test / test (push) Failing after 2m32s
Replace one-to-one FileMetadata table with flexible ObjectMetadata KV pairs,
add custom property definitions, WOPI Discovery auto-configuration, and
per-extension action URL support.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 17:21:22 +08:00
743a2c9d65 fix: use TaskStatus/TaskType enums in TaskDetailResponse
Some checks failed
Test / test (push) Failing after 2m17s
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 16:44:01 +08:00
3639a31163 feat: add S3 storage support, policy migration, and quota enforcement
Some checks failed
Test / test (push) Failing after 2m21s
- Add S3StorageService with AWS Signature V4 signing (URI-encoded for non-ASCII keys)
- Add PATCH /object/{id}/policy endpoint for switching storage policies with background migration
- Implement cross-storage file migration service (local <-> S3)
- Replace deprecated StorageType enum with PolicyType (local/s3)
- Implement GET /user/settings/policies endpoint (was 501 stub)
- Add storage quota pre-allocation on upload session creation to prevent concurrent bypass
- Fix BigInteger for max_storage and user.storage to support >2GB values
- Add policy permission validation on upload and directory creation
- Use group's first policy as default on registration instead of hardcoded name
- Define TaskType.POLICY_MIGRATE and extend TaskProps with migration fields

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-23 13:38:20 +08:00
7200df6d87 fix: patch storage quota bypass and harden auth security
All checks were successful
Test / test (push) Successful in 2m11s
- Fix WebDAV chunked PUT bypassing storage quota when remaining_quota <= 0
- Add QuotaLimitedWriter to enforce quota during streaming writes
- Clean up residual files on write failure in end_write()
- Add Magic Link replay attack prevention via TokenStore
- Reject startup when JWT SECRET_KEY is not configured
- Sanitize OAuth callback and Magic Link log output

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 22:20:43 +08:00
40b6a31c98 feat: implement WebDAV protocol support with WsgiDAV + account management API
All checks were successful
Test / test (push) Successful in 2m14s
Add complete WebDAV support: management REST API (CRUD accounts at /api/v1/webdav/accounts)
and DAV protocol endpoint (/dav) using WsgiDAV + a2wsgi bridge for client access via
HTTP Basic Auth. Includes Redis+TTLCache auth caching and integration tests (24 cases).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 15:19:29 +08:00
19837b4817 refactor: extract ee/ into private submodule (disknext-ee)
All checks were successful
Test / test (push) Successful in 1m54s
Enterprise Edition code is now hosted in a separate private repository
and linked as a git submodule. Community Edition runs without it via
ImportError fallback in main.py.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-17 11:14:45 +08:00
b5d09009e3 feat: implement source link endpoints and enforce policy rules
- Add POST/GET source link endpoints for file sharing via permanent URLs
- Enforce max_size check in PATCH /file/content to prevent size limit bypass
- Support is_private (proxy) vs public (302 redirect) storage modes
- Replace all ResponseBase(data=...) with proper DTOs or 204 responses
- Add 18 integration tests for source link and policy rule enforcement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 17:07:20 +08:00
0b521ae8ab feat: add PATCH /user/settings/password endpoint for changing password
Register the fixed /password route before the wildcard /{option} to
prevent FastAPI from matching it as a path parameter.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 15:11:56 +08:00
eac0766e79 feat: migrate ORM base to sqlmodel-ext, add file viewers and WOPI integration
- Migrate SQLModel base classes, mixins, and database management to
  external sqlmodel-ext package; remove sqlmodels/base/, sqlmodels/mixin/,
  and sqlmodels/database.py
- Add file viewer/editor system with WOPI protocol support for
  collaborative editing (OnlyOffice, Collabora)
- Add enterprise edition license verification module (ee/)
- Add Dockerfile multi-stage build with Cython compilation support
- Add new dependencies: sqlmodel-ext, cryptography, whatthepatch

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-14 14:23:17 +08:00
53b757de7a fix: use container image for Gitea CI to provide Node.js
All checks were successful
Test / test (push) Successful in 2m15s
The act_runner doesn't have Node.js in PATH, which is required
by actions/checkout@v4. Use catthehacker/ubuntu:act-latest container.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 14:24:22 +08:00
69f852a4ce fix: align all 212 tests with current API and add CI workflows
Some checks failed
Test / test (push) Failing after 1m4s
Update integration tests to match actual endpoint responses: remove
data wrappers, use snake_case fields, correct HTTP methods (PUT→POST
for directory create), status codes (200→204 for mutations), and
request formats (params→json for 2FA). Fix root-level and unit tests
for DatabaseManager migration, model CRUD patterns, and JWT setup.
Add GitHub Actions and Gitea CI configs with ubuntu-latest + Python 3.13.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 14:21:40 +08:00
800c85bf8d feat: implement WebAuthn credential registration, login verification, and management
Complete the WebAuthn/Passkey flow that was previously stubbed out:
- Add ChallengeStore (Redis + TTLCache fallback) for challenge lifecycle
- Add RP config helper to extract rp_id/origin from site settings
- Fix registration start (exclude_credentials, user_id, challenge storage)
- Implement registration finish (verify + create UserAuthn & AuthIdentity)
- Add authentication options endpoint for Discoverable Credentials login
- Fix passkey login to use challenge_token and base64url encoding
- Add credential management endpoints (list/rename/delete)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-13 12:56:46 +08:00
729773cae3 feat: add multi-provider auth via AuthIdentity and extend site config
- Extract AuthIdentity model for multi-provider authentication (email_password, OAuth, Passkey, Magic Link)
- Remove password field from User model, credentials now stored in AuthIdentity
- Refactor unified login/register to use AuthIdentity-based provider checking
- Add site config fields: footer_code, tos_url, privacy_url, auth_methods
- Add auth settings defaults in migration (email_password enabled by default)
- Update admin user creation to create AuthIdentity records
- Update all tests to use AuthIdentity model

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 22:49:12 +08:00
d831c9c0d6 feat: implement PATCH /user/settings/{option} and fix timezone range to UTC-12~+14
- Add SettingOption StrEnum (nickname/language/timezone) for path param validation
- Add UserSettingUpdateRequest DTO with Pydantic constraints
- Implement endpoint: extract value by option name, validate non-null for required fields
- Fix timezone upper bound from 12 to 14 (UTC+14 exists, e.g. Line Islands)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 20:15:35 +08:00
4c1b7a8aad feat: add theme preset system with admin CRUD, public listing, and user theme settings
- Add ChromaticColor (17 Tailwind colors) and NeutralColor (5 grays) enums
- Add ThemePreset table with flat color columns and unique name constraint
- Add admin theme endpoints (CRUD + set default) at /api/v1/admin/theme
- Add public theme listing at /api/v1/site/themes
- Add user theme settings (PATCH /theme) with color snapshot on User model
- User.color_* columns store per-user overrides; fallback to default preset then builtin
- Initialize default theme preset in migration
- Remove legacy defaultTheme/themes settings

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-12 19:34:41 +08:00
146 changed files with 16654 additions and 7574 deletions

@@ -5,7 +5,8 @@
         "Bash(findstr:*)",
         "Bash(find:*)",
         "Bash(yarn tsc:*)",
-        "Bash(dir:*)"
+        "Bash(dir:*)",
+        "mcp__server-notify__notify"
       ]
     }
 }

.dockerignore Normal file

@@ -0,0 +1,37 @@
.git/
.gitignore
.github/
.idea/
.vscode/
.venv/
.env
.env.*
.run/
.claude/
__pycache__/
*.py[cod]
tests/
htmlcov/
.pytest_cache/
.coverage
coverage.xml
*.db
*.sqlite
*.sqlite3
*.log
logs/
data/
Dockerfile
.dockerignore
# Cython build artifacts
*.c
build/
# License private key and tooling scripts
license_private.pem
scripts/

.gitea/workflows/test.yml Normal file

@@ -0,0 +1,31 @@
name: Test
on:
push:
branches: [main, develop]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
container:
image: ghcr.io/catthehacker/ubuntu:act-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.13"
- name: Install dependencies
run: uv sync
- name: Run tests
run: uv run pytest tests/ -v --tb=short

@@ -449,13 +449,13 @@ return device  # device has already expired at this point
 ```python
 import asyncio
 from sqlmodel import Field
-from sqlmodels.base import UUIDTableBase, SQLModelBase
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin

 class CharacterBase(SQLModelBase):
     name: str
     """Character name"""

-class Character(CharacterBase, UUIDTableBase):
+class Character(CharacterBase, UUIDTableBaseMixin):
     """Rich domain model: holds both data and business logic"""

     # ==================== Runtime attributes initialized in model_post_init ====================
@@ -570,11 +570,11 @@ async with character.init(session):
 from abc import ABC, abstractmethod
 from uuid import UUID
 from sqlmodel import Field
-from sqlmodels.base import (
+from sqlmodel_ext import (
     SQLModelBase,
-    UUIDTableBase,
+    UUIDTableBaseMixin,
     create_subclass_id_mixin,
-    AutoPolymorphicIdentityMixin
+    AutoPolymorphicIdentityMixin,
 )

 # 1. Define the Base class (fields only, no table)
@@ -591,7 +591,7 @@ class ASRBase(SQLModelBase):
 # 2. Define the abstract parent class (has a table)
 class ASR(
     ASRBase,
-    UUIDTableBase,
+    UUIDTableBaseMixin,
     ABC,
     polymorphic_on='__polymorphic_name',
     polymorphic_abstract=True
@@ -1148,7 +1148,7 @@ from sqlmodel import Field
 # 3. Local application imports (starting from packages at the project root)
 from dependencies import SessionDep
 from sqlmodels.user import User
-from sqlmodels.base import UUIDTableBase
+from sqlmodel_ext import UUIDTableBaseMixin

 # 4. Relative imports (modules within the same package)
 from .base import BaseClass

.github/workflows/test.yml vendored Normal file

@@ -0,0 +1,29 @@
name: Test
on:
push:
branches: [main, develop]
pull_request:
branches: [main]
jobs:
test:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v6
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.13"
- name: Install dependencies
run: uv sync
- name: Run tests
run: uv run pytest tests/ -v --tb=short

.gitignore vendored

@@ -1,8 +1,6 @@
 # Python
 __pycache__/
 *.py[cod]
-*.pyo
-*.pyd
 *.so
 *.egg
 *.egg-info/
@@ -69,3 +67,16 @@ data/
 # JetBrains run configurations (won't work on another machine)
 .run/
 .xml
+
+# Frontend build output (copied in during the Docker build)
+statics/
+
+# Cython build artifacts
+*.c
+
+# License keys (keep secret)
+license_private.pem
+license.key
+
+avatar/
+.dev/

.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "ee"]
path = ee
url = https://git.yxqi.cn/Yuerchu/disknext-ee.git

Dockerfile

@@ -1,13 +1,52 @@
-FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim
+# ============================================================
+# Base layer: install runtime dependencies
+# ============================================================
+FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS base
 WORKDIR /app
 COPY pyproject.toml uv.lock ./
 RUN uv sync --frozen --no-dev
 COPY . .
-EXPOSE 5213
+
+# ============================================================
+# Community build: drop the ee/ directory
+# ============================================================
+FROM base AS community
+RUN rm -rf ee/
+COPY statics/ /app/statics/
+EXPOSE 5213
+CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]
+
+# ============================================================
+# Pro build layer: Cython-compile the ee/ modules
+# ============================================================
+FROM base AS pro-builder
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends gcc libc6-dev && \
+    rm -rf /var/lib/apt/lists/*
+RUN uv sync --frozen --no-dev --extra build
+RUN uv run python setup_cython.py build_ext --inplace && \
+    uv run python setup_cython.py clean_artifacts
+
+# ============================================================
+# Pro image: ships the compiled ee/ modules (only __init__.py + .so)
+# ============================================================
+FROM ghcr.io/astral-sh/uv:python3.13-bookworm-slim AS pro
+WORKDIR /app
+COPY pyproject.toml uv.lock ./
+RUN uv sync --frozen --no-dev
+COPY --from=pro-builder /app/ /app/
+COPY statics/ /app/statics/
+EXPOSE 5213
 CMD ["uv", "run", "fastapi", "run", "main.py", "--host", "0.0.0.0", "--port", "5213"]

README.md

@@ -229,6 +229,12 @@ pytest tests/integration
 pytest --cov
 ```
+
+## Forgot your password
+
+Set the password field to `$argon2id$v=19$m=65536,t=3,p=4$09YTQpkw7eS4qW732OazkQ$Szzbi3VIaJXBJ02rkVKrSFCAKHjRTl+EQWk4PNxCYFI`
+and the password is reset to `11223344`.
+
 ## Development conventions

 For detailed development conventions, see [CLAUDE.md](CLAUDE.md); they mainly cover:

docs/file-viewer-api.md Normal file

@@ -0,0 +1,594 @@
# File Preview App Picker — Frontend Integration Guide

## Overview

The file preview system works like Android's "open with" dialog: when the user clicks a file, the frontend queries the available viewers for its extension, shows a picker dialog, and the user chooses "just once" or "always use this app".

### App types

| type | Description | Frontend handling |
|------|-------------|-------------------|
| `builtin` | Built-in frontend component | Route to the built-in component by `app_key` (e.g. `pdfjs`, `monaco`) |
| `iframe` | Embedded iframe | Replace `{file_url}` in `iframe_url_template` with the file download URL and embed it in an iframe |
| `wopi` | WOPI protocol | Call `/file/{id}/wopi-session` to get `editor_url` and embed it in an iframe |

### Built-in app_key mapping

The frontend must implement a built-in preview component for each of the following `app_key` values:

| app_key | Component | Extensions |
|---------|-----------|------------|
| `pdfjs` | PDF.js reader | pdf |
| `monaco` | Monaco Editor | txt, md, json, py, js, ts, html, css, ... |
| `markdown` | Markdown renderer | md, markdown, mdx |
| `image_viewer` | Image viewer | jpg, png, gif, webp, svg, ... |
| `video_player` | HTML5 Video | mp4, webm, ogg, mov, mkv, m3u8 |
| `audio_player` | HTML5 Audio | mp3, wav, flac, aac, m4a, opus |

> `office_viewer` (iframe), `collabora` (wopi), and `onlyoffice` (wopi) are disabled by default; an administrator must enable and configure them in the admin panel.
---

## File download URLs and iframe preview

### Existing download flow (two-step)

```
Step 1: POST /api/v1/file/download/{file_id}      → { access_token, expires_in }
Step 2: GET  /api/v1/file/download/{access_token} → file binary stream
```

- Step 1 requires JWT authentication and returns a download token (valid for 1 hour)
- Step 2 requires **no authentication**; the token alone grants the download. **The token is single-use** and becomes invalid after the download

### How each viewer type obtains the file

| type | How it gets the file | Notes |
|------|----------------------|-------|
| `builtin` | Fetched by the frontend | The frontend calls the download API from JS to get a Blob/ArrayBuffer and hands it to the built-in component |
| `iframe` | Needs a publicly reachable URL | Third-party services (e.g. Office Online) **pull the file from their own servers** |
| `wopi` | Handled by the WOPI protocol | The editor fetches via `/wopi/files/{id}/contents`; the frontend only embeds `editor_url` |

### builtin — fetched by the frontend

Built-in components (pdfjs, monaco, etc.) run in the frontend and can fetch the file content directly from JS:

```typescript
// Option A: build a URL from the download token (for components that take a URL, e.g. PDF.js)
const { access_token } = await api.post(`/file/download/${fileId}`)
const fileUrl = `${baseUrl}/api/v1/file/download/${access_token}`
// Hand to PDF.js: pdfjsLib.getDocument(fileUrl)

// Option B: fetch a Blob (for components that need an ArrayBuffer)
const { access_token } = await api.post(`/file/download/${fileId}`)
const blob = await fetch(`${baseUrl}/api/v1/file/download/${access_token}`).then(r => r.blob())
// Hand to Monaco: monaco.editor.create(el, { value: await blob.text() })
```
### iframe — the `{file_url}` substitution rule

`{file_url}` in `iframe_url_template` must be replaced with an **externally reachable direct link** to the file.

**Problem**: the current download token is single-use, while services like Office Online may request the URL more than once.

**Workable approach for now**:

```typescript
// 1. Create a download token
const { access_token } = await api.post(`/file/download/${fileId}`)
// 2. Build the full file URL (must be reachable from the public internet)
const fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
// 3. Substitute into the template
const iframeSrc = viewer.iframe_url_template.replace(
  '{file_url}',
  encodeURIComponent(fileUrl)
)
// 4. Embed the iframe
// <iframe src={iframeSrc} />
```

> **Known limitation**: download tokens are single-use. If the third-party service pulls the file more than once (Office Online may retry), the second request gets a 404. A later release will add a `/file/get/{id}/{name}` external-link endpoint (reusable), at which point iframe apps should switch to that URL. Until then:
>
> 1. **Prefer WOPI-type apps** (Collabora/OnlyOffice), which do not have this limitation
> 2. Office Online preview usually pulls **small files** only once, so most cases work
> 3. If you need a stable setup, wait for the external-link endpoint before enabling iframe apps
### wopi — no file URL involved

WOPI viewers leave file transfer entirely to the backend; the frontend only needs to:

```typescript
// 1. Create a WOPI session
const session = await api.post(`/file/${fileId}/wopi-session`)
// 2. Embed the editor directly
// <iframe src={session.editor_url} />
```

The editor (Collabora/OnlyOffice) fetches the file content itself over the WOPI protocol from `/wopi/files/{id}/contents`, authenticating with `access_token`; the frontend does nothing else.
---

## User-facing API

### 1. Query available viewers

Called when the user clicks a file, to get the list of viewers available for its extension.

```
GET /api/v1/file/viewers?ext={extension}
Authorization: Bearer {token}
```

**Query parameters**:

| Param | Type | Required | Description |
|-------|------|----------|-------------|
| ext | string | yes | File extension, at most 20 characters. A leading dot (`.pdf`) and uppercase (`PDF`) are accepted; the backend normalizes them |
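The documentation only promises that the backend normalizes a leading dot and uppercase letters; a client-side sketch of an equivalent normalization (a hypothetical helper, not part of the API) could look like:

```typescript
// Hypothetical client-side mirror of the backend's described
// normalization: trim, strip a leading dot, and lowercase.
function normalizeExt(ext: string): string {
  const trimmed = ext.trim();
  const noDot = trimmed.startsWith(".") ? trimmed.slice(1) : trimmed;
  return noDot.toLowerCase();
}

// e.g. normalizeExt(".PDF") → "pdf"
```

Normalizing on the client as well keeps cache keys and `extension` values in later requests (which require lowercase without a dot) consistent.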
**Response 200**:

```json
{
  "viewers": [
    {
      "id": "550e8400-e29b-41d4-a716-446655440000",
      "name": "PDF Reader",
      "app_key": "pdfjs",
      "type": "builtin",
      "icon": "file-pdf",
      "description": "Online PDF reader based on pdf.js",
      "iframe_url_template": null,
      "wopi_editor_url_template": null
    }
  ],
  "default_viewer_id": null
}
```
**Field reference**:

| Field | Type | Description |
|-------|------|-------------|
| viewers | FileAppSummary[] | Available viewers, already sorted by priority |
| default_viewer_id | string \| null | UUID of the viewer the user marked "always use"; null if unset |

**FileAppSummary**:

| Field | Type | Description |
|-------|------|-------------|
| id | UUID | App UUID |
| name | string | App display name |
| app_key | string | Unique app identifier, used for frontend routing |
| type | `"builtin"` \| `"iframe"` \| `"wopi"` | App type |
| icon | string \| null | Icon name (mappable to an icon library) |
| description | string \| null | App description |
| iframe_url_template | string \| null | iframe type only: URL template containing the `{file_url}` placeholder |
| wopi_editor_url_template | string \| null | wopi type only: editor URL template |
---

### 2. Set the default viewer ("always use")

Called when the user checks "always use this app" in the picker dialog.

```
PUT /api/v1/user/settings/file-viewers/default
Authorization: Bearer {token}
Content-Type: application/json
```

**Request body**:

```json
{
  "extension": "pdf",
  "app_id": "550e8400-e29b-41d4-a716-446655440000"
}
```

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| extension | string | yes | File extension (lowercase, no dot) |
| app_id | UUID | yes | UUID of the chosen viewer app |

**Response 200**:

```json
{
  "id": "660e8400-e29b-41d4-a716-446655440001",
  "extension": "pdf",
  "app": {
    "id": "550e8400-e29b-41d4-a716-446655440000",
    "name": "PDF Reader",
    "app_key": "pdfjs",
    "type": "builtin",
    "icon": "file-pdf",
    "description": "Online PDF reader based on pdf.js",
    "iframe_url_template": null,
    "wopi_editor_url_template": null
  }
}
```

**Error codes**:

| Status | Description |
|--------|-------------|
| 400 | The app does not support this extension |
| 404 | App not found |

> Only one default is allowed per extension. Repeating PUT for the same extension updates it (upsert); there is no conflict.
---

### 3. List all default viewer settings

Used by the user settings page to show the list of "always use" choices.

```
GET /api/v1/user/settings/file-viewers/defaults
Authorization: Bearer {token}
```

**Response 200**:

```json
[
  {
    "id": "660e8400-e29b-41d4-a716-446655440001",
    "extension": "pdf",
    "app": {
      "id": "550e8400-e29b-41d4-a716-446655440000",
      "name": "PDF Reader",
      "app_key": "pdfjs",
      "type": "builtin",
      "icon": "file-pdf",
      "description": null,
      "iframe_url_template": null,
      "wopi_editor_url_template": null
    }
  }
]
```
---

### 4. Remove a default viewer setting

Called when the user clicks "stop always using" on the settings page.

```
DELETE /api/v1/user/settings/file-viewers/default/{id}
Authorization: Bearer {token}
```

**Response**: 204 No Content

**Error codes**:

| Status | Description |
|--------|-------------|
| 404 | Record not found or does not belong to the current user |

---
### 5. Create a WOPI session

Called when opening a WOPI-type app (e.g. Collabora, OnlyOffice).

```
POST /api/v1/file/{file_id}/wopi-session
Authorization: Bearer {token}
```

**Response 200**:

```json
{
  "wopi_src": "http://localhost:8000/wopi/files/770e8400-e29b-41d4-a716-446655440002",
  "access_token": "eyJhbGciOiJIUzI1NiIs...",
  "access_token_ttl": 1739577600000,
  "editor_url": "http://collabora:9980/loleaflet/dist/loleaflet.html?WOPISrc=...&access_token=...&access_token_ttl=..."
}
```

| Field | Type | Description |
|-------|------|-------------|
| wopi_src | string | WOPI source URL, passed to the editor |
| access_token | string | WOPI access token |
| access_token_ttl | int | Token expiry, as a millisecond timestamp |
| editor_url | string | Full editor URL; **embed it in an iframe as-is** |

**Error codes**:

| Status | Description |
|--------|-------------|
| 400 | File has no extension / the WOPI app has no editor URL configured |
| 403 | User group lacks permission |
| 404 | File not found / no WOPI viewer available |
## 前端交互流程
### 打开文件预览
```
用户点击文件
GET /file/viewers?ext={扩展名}
├── viewers 为空 → 提示"暂无可用的预览方式"
├── default_viewer_id 不为空 → 直接用对应 viewer 打开(跳过选择弹窗)
└── viewers.length == 1 → 直接用唯一 viewer 打开(可选策略)
└── viewers.length > 1 → 展示选择弹窗
├── 用户选择 + 不勾选"始终使用" → 仅此一次打开
└── 用户选择 + 勾选"始终使用" → PUT /user/settings/file-viewers/default
└── 然后打开
```
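The branching above can be sketched as a pure helper (the function name and shape are illustrative, not part of the documented API); it returns the viewer to open immediately, or `null` when the frontend should show the picker dialog (or the empty-state message):

```typescript
interface ViewerChoice {
  id: string;
}

// Decide whether a viewer can be opened without asking the user.
// Returns the viewer to open directly, or null (empty list → show
// "no preview available"; several candidates → show the picker).
function pickViewer(
  viewers: ViewerChoice[],
  defaultViewerId: string | null,
): ViewerChoice | null {
  if (viewers.length === 0) return null;        // nothing available
  if (defaultViewerId) {
    const match = viewers.find(v => v.id === defaultViewerId);
    if (match) return match;                    // "always use" setting wins
  }
  if (viewers.length === 1) return viewers[0];  // single option: open directly
  return null;                                  // multiple options: show picker
}
```

Treating a stale `default_viewer_id` (one no longer in `viewers`) as "not set" is a deliberate choice here, so a disabled app falls back to the picker instead of failing.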
### Opening the viewer by type

```
Got a viewer object
├── type == "builtin"
│   └── Route to the built-in component by app_key
│         switch(app_key):
│           "pdfjs"        → <PdfViewer />
│           "monaco"       → <CodeEditor />
│           "markdown"     → <MarkdownPreview />
│           "image_viewer" → <ImageViewer />
│           "video_player" → <VideoPlayer />
│           "audio_player" → <AudioPlayer />
│       Fetching the content:
│         POST /file/download/{file_id} → { access_token }
│         fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
│         → hand the URL, or a fetched Blob, to the component
├── type == "iframe"
│   └── 1. POST /file/download/{file_id} → { access_token }
│       2. fileUrl = `${siteURL}/api/v1/file/download/${access_token}`
│       3. iframeSrc = viewer.iframe_url_template
│            .replace("{file_url}", encodeURIComponent(fileUrl))
│       4. <iframe src={iframeSrc} />
└── type == "wopi"
    └── 1. POST /file/{file_id}/wopi-session → { editor_url }
        2. <iframe src={editor_url} />
        (the editor fetches the file over WOPI; nothing for the frontend to do)
```
---

## Admin API

All admin endpoints require an administrator identity (group.admin == true in the JWT).

### 1. List all file apps

```
GET /api/v1/admin/file-app/list?page=1&page_size=20
Authorization: Bearer {admin_token}
```

**Response 200**:

```json
{
  "apps": [
    {
      "id": "...",
      "name": "PDF Reader",
      "app_key": "pdfjs",
      "type": "builtin",
      "icon": "file-pdf",
      "description": "...",
      "is_enabled": true,
      "is_restricted": false,
      "iframe_url_template": null,
      "wopi_discovery_url": null,
      "wopi_editor_url_template": null,
      "extensions": ["pdf"],
      "allowed_group_ids": []
    }
  ],
  "total": 9
}
```
### 2. Create a file app

```
POST /api/v1/admin/file-app/
Authorization: Bearer {admin_token}
Content-Type: application/json
```

```json
{
  "name": "Custom viewer",
  "app_key": "my_viewer",
  "type": "iframe",
  "description": "A custom iframe viewer",
  "is_enabled": true,
  "is_restricted": false,
  "iframe_url_template": "https://example.com/view?url={file_url}",
  "extensions": ["pdf", "docx"],
  "allowed_group_ids": []
}
```

**Response**: 201 — returns a FileAppResponse (same shape as a list item)

**Error codes**: 409 — app_key already exists
### 3. Get app details

```
GET /api/v1/admin/file-app/{id}
```

**Response**: 200 — FileAppResponse

### 4. Update an app

```
PATCH /api/v1/admin/file-app/{id}
```

Send only the fields to update:

```json
{
  "name": "New name",
  "is_enabled": false
}
```

**Response**: 200 — FileAppResponse

### 5. Delete an app

```
DELETE /api/v1/admin/file-app/{id}
```

**Response**: 204 No Content (cascades to extension links, user preferences, and user-group links)

### 6. Replace the extension list

```
PUT /api/v1/admin/file-app/{id}/extensions
```

```json
{
  "extensions": ["doc", "docx", "odt"]
}
```

**Response**: 200 — FileAppResponse

### 7. Replace the allowed user groups

```
PUT /api/v1/admin/file-app/{id}/groups
```

```json
{
  "group_ids": ["uuid-1", "uuid-2"]
}
```

**Response**: 200 — FileAppResponse

> When `is_restricted` is `true`, only members of the groups in `allowed_group_ids` can see the app. When `is_restricted` is `false`, the app is visible to all users and `allowed_group_ids` has no effect.
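The visibility rule in the note above can be captured as a small predicate. This is a sketch using the FileAppResponse field names; the additional assumption that a disabled app (`is_enabled == false`) is never visible is not stated in the rule itself:

```typescript
interface AppVisibility {
  is_enabled: boolean;
  is_restricted: boolean;
  allowed_group_ids: string[];
}

// An app is visible when it is enabled and either unrestricted,
// or the user belongs to at least one of its allowed groups.
function isAppVisible(app: AppVisibility, userGroupIds: string[]): boolean {
  if (!app.is_enabled) return false;   // assumption: disabled apps are hidden
  if (!app.is_restricted) return true; // allowed_group_ids has no effect
  return app.allowed_group_ids.some(id => userGroupIds.includes(id));
}
```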
---
## TypeScript type reference
```typescript
type FileAppType = 'builtin' | 'iframe' | 'wopi'
interface FileAppSummary {
id: string
name: string
app_key: string
type: FileAppType
icon: string | null
description: string | null
iframe_url_template: string | null
wopi_editor_url_template: string | null
}
interface FileViewersResponse {
viewers: FileAppSummary[]
default_viewer_id: string | null
}
interface SetDefaultViewerRequest {
extension: string
app_id: string
}
interface UserFileAppDefaultResponse {
id: string
extension: string
app: FileAppSummary
}
interface WopiSessionResponse {
wopi_src: string
access_token: string
access_token_ttl: number
editor_url: string
}
// ========== Admin types ==========
interface FileAppResponse {
id: string
name: string
app_key: string
type: FileAppType
icon: string | null
description: string | null
is_enabled: boolean
is_restricted: boolean
iframe_url_template: string | null
wopi_discovery_url: string | null
wopi_editor_url_template: string | null
extensions: string[]
allowed_group_ids: string[]
}
interface FileAppListResponse {
apps: FileAppResponse[]
total: number
}
interface FileAppCreateRequest {
name: string
app_key: string
type: FileAppType
icon?: string
description?: string
is_enabled?: boolean // default: true
is_restricted?: boolean // default: false
iframe_url_template?: string
wopi_discovery_url?: string
wopi_editor_url_template?: string
extensions?: string[] // default: []
allowed_group_ids?: string[] // default: []
}
interface FileAppUpdateRequest {
name?: string
app_key?: string
type?: FileAppType
icon?: string
description?: string
is_enabled?: boolean
is_restricted?: boolean
iframe_url_template?: string
wopi_discovery_url?: string
wopi_editor_url_template?: string
}
```

docs/text-editor-api.md Normal file

@@ -0,0 +1,242 @@
# In-Browser Text File Editing — Frontend Integration Guide

## Overview

When Monaco Editor opens a text file, it GETs the content plus a hash to use as the edit baseline. On save, it uses jsdiff to compute a unified diff and sends only the delta; the backend applies the patch after verifying there is no concurrent conflict.

```
Open file: GET   /api/v1/file/content/{file_id} → { content, hash, size }
Save file: PATCH /api/v1/file/content/{file_id} ← { patch, base_hash }
                                                → { new_hash, new_size }
```

---

## Conventions

| Item | Convention |
|------|------------|
| Encoding | UTF-8 throughout |
| Line endings | Normalized to `\n` by the backend on GET; the frontend never handles `\r\n` |
| Hash algorithm | SHA-256 (hex-encoded, 64 characters), computed over the UTF-8 bytes |
| Diff format | Standard unified diff, as produced by jsdiff's `createPatch()` |
| Empty diff | Checked by the frontend: send no request when the content is unchanged |
---

## GET /api/v1/file/content/{file_id}

Fetch the text content of a file.

### Request

```
GET /api/v1/file/content/{file_id}
Authorization: Bearer <token>
```

### Response 200

```json
{
  "content": "line1\nline2\nline3\n",
  "hash": "a1b2c3d4... 64-char SHA-256 hex",
  "size": 18
}
```

| Field | Type | Description |
|------|------|------|
| `content` | string | File text, with line endings normalized to `\n` |
| `hash` | string | SHA-256 hex of the normalized content's UTF-8 bytes |
| `size` | number | Byte size after normalization |

### Errors

| Status | Description |
|--------|-------------|
| 400 | File is not valid UTF-8 text (binary file) |
| 401 | Not authenticated |
| 404 | File not found |
---
## PATCH /api/v1/file/content/{file_id}
Save a text file incrementally.
### Request
```
PATCH /api/v1/file/content/{file_id}
Authorization: Bearer <token>
Content-Type: application/json
```
```json
{
"patch": "--- a\n+++ b\n@@ -1,3 +1,3 @@\n line1\n-line2\n+LINE2\n line3\n",
"base_hash": "a1b2c3d4...the hash returned by GET"
}
```
| Field | Type | Description |
|------|------|------|
| `patch` | string | Unified diff generated by jsdiff `createPatch()` |
| `base_hash` | string | The `hash` value returned by GET before editing |
### Response 200
```json
{
"new_hash": "e5f6a7b8...64 characters",
"new_size": 18
}
```
After a successful save, the frontend should use `new_hash` as the new `base_hash` for the next save.
### Errors
| Status | Description | Frontend handling |
|--------|------|----------|
| 401 | Not authenticated | — |
| 404 | File not found | — |
| 409 | `base_hash` mismatch (concurrent conflict) | Prompt the user to refresh and reload the content |
| 422 | Patch malformed or failed to apply | Fall back to a full-content save or notify the user |
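Server-side, the 409 branch above is plain optimistic concurrency: the backend recomputes the hash of the stored file and compares it with the client's `base_hash` before applying the patch. A minimal sketch of that check (illustrative only, not the backend's actual code; the patch application itself, done via a unified-diff library, is omitted):

```python
import hashlib


def check_base_hash(stored_content: str, base_hash: str) -> bool:
    """Return True when the client's editing baseline still matches the stored file."""
    current = hashlib.sha256(stored_content.encode("utf-8")).hexdigest()
    return current == base_hash


stored = "line1\nline2\nline3\n"
h = hashlib.sha256(stored.encode("utf-8")).hexdigest()
assert check_base_hash(stored, h)            # baseline matches: safe to apply patch
assert not check_base_hash(stored + "x", h)  # another session changed the file: answer 409
```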
---
## Frontend implementation reference
### Dependencies
```bash
npm install diff
```
jsdiff is published on npm as the `diff` package, which matches the `import { createPatch } from "diff"` statement below.
### Computing the hash
```typescript
async function sha256(text: string): Promise<string> {
  const bytes = new TextEncoder().encode(text);
  const hashBuffer = await crypto.subtle.digest("SHA-256", bytes);
  const hashArray = Array.from(new Uint8Array(hashBuffer));
  return hashArray.map(b => b.toString(16).padStart(2, "0")).join("");
}
```
### Opening a file
```typescript
interface TextContent {
  content: string;
  hash: string;
  size: number;
}

async function openFile(fileId: string): Promise<TextContent> {
  const resp = await fetch(`/api/v1/file/content/${fileId}`, {
    headers: { Authorization: `Bearer ${token}` },
  });
  if (!resp.ok) {
    if (resp.status === 400) throw new Error("This file is not a text file");
    throw new Error("Failed to fetch file content");
  }
  return resp.json();
}
```
### Saving a file
```typescript
import { createPatch } from "diff";

interface PatchResult {
  new_hash: string;
  new_size: number;
}

async function saveFile(
  fileId: string,
  originalContent: string,
  currentContent: string,
  baseHash: string,
): Promise<PatchResult | null> {
  // Content unchanged: skip the request
  if (originalContent === currentContent) return null;
  const patch = createPatch("file", originalContent, currentContent);
  const resp = await fetch(`/api/v1/file/content/${fileId}`, {
    method: "PATCH",
    headers: {
      Authorization: `Bearer ${token}`,
      "Content-Type": "application/json",
    },
    body: JSON.stringify({ patch, base_hash: baseHash }),
  });
  if (resp.status === 409) {
    // Concurrent conflict: needs a user decision
    throw new Error("CONFLICT");
  }
  if (!resp.ok) throw new Error("Save failed");
  return resp.json();
}
```
### Complete editing flow
```typescript
// 1. Open
const file = await openFile(fileId);
let baseContent = file.content;
let baseHash = file.hash;

// 2. The user edits in Monaco...
editor.setValue(baseContent);

// 3. Save (Ctrl+S)
const currentContent = editor.getValue();
const result = await saveFile(fileId, baseContent, currentContent, baseHash);
if (result) {
  // Update the baseline
  baseContent = currentContent;
  baseHash = result.new_hash;
}
```
### Handling conflicts
When PATCH returns 409, the file has been modified by another session:
```typescript
try {
  await saveFile(fileId, baseContent, currentContent, baseHash);
} catch (e) {
  if (e.message === "CONFLICT") {
    // Option A: prompt the user with "Overwrite" and "Discard" choices
    // Option B: re-fetch the latest content and show a diff so the user can merge
    const latest = await openFile(fileId);
    // Show the merge UI...
  }
}
```
---
## Hash consistency check
After a GET, the frontend can verify the hash locally to make sure the content arrived intact:
```typescript
const file = await openFile(fileId);
const localHash = await sha256(file.content);
console.assert(localHash === file.hash, "Hash mismatch; content may be corrupted");
```

1
ee Submodule

Submodule ee added at cc32d8db91

14
license_public.pem Normal file

@@ -0,0 +1,14 @@
-----BEGIN PUBLIC KEY-----
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAyNltXQ/Nuechx3kjj3T5
oR6pZvTmpsDowqqxXJy7FXUI8d7XprhV+HrBQPsrT/Ngo9FwW3XyiK10m1WrzpGW
eaf9990Z5Z2naEn5TzGrh71p/D7mZcNGVumo9uAuhtNEemm6xB3FoyGYZj7X0cwA
VDvIiKAwYyRJX2LqVh1/tZM6tTO3oaGZXRMZzCNUPFSo4ZZudU3Boa5oQg08evu4
vaOqeFrMX47R3MSUmO9hOh+NS53XNqO0f0zw5sv95CtyR5qvJ4gpkgYaRCSQFd19
TnHU5saFVrH9jdADz1tdkMYcyYE+uJActZBapxCHSYB2tSCKWjDxeUFl/oY/ZFtY
l4MNz1ovkjNhpmR3g+I5fbvN0cxDIjnZ9vJ84ozGqTGT9s1jHaLbpLri/vhuT4F2
7kifXk8ImwtMZpZvzhmucH9/5VgcWKNuMATzEMif+YjFpuOGx8gc1XL1W/3q+dH0
EFESp+/knjcVIfwpAkIKyV7XvDgFHsif1SeI0zZMW4utowVvGocP1ZzK5BGNTk2z
CEtQDO7Rqo+UDckOJSG66VW3c2QO8o6uuy6fzx7q0MFEmUMwGf2iMVtR/KnXe99C
enOT0BpU1EQvqssErUqivDss7jm98iD8M/TCE7pFboqZ+SC9G+QAqNIQNFWh8bWA
R9hyXM/x5ysHd6MC4eEQnhMCAwEAAQ==
-----END PUBLIC KEY-----

76
main.py

@@ -1,28 +1,62 @@
+from pathlib import Path
 from typing import NoReturn
 from fastapi import FastAPI, Request
-from utils.conf import appmeta
-from utils.http.http_exceptions import raise_internal_error
-from utils.lifespan import lifespan
+from loguru import logger as l
+from routers import router
+from routers.dav import dav_app
+from routers.dav.provider import EventLoopRef
+from service.redis import RedisManager
+from service.storage import S3StorageService
 from sqlmodels.database_connection import DatabaseManager
 from sqlmodels.migration import migration
 from utils import JWT
-from routers import router
-from service.redis import RedisManager
-from loguru import logger as l
+from utils.conf import appmeta
+from utils.http.http_exceptions import raise_internal_error
+from utils.lifespan import lifespan
+
+# 尝试加载企业版功能
+_has_ee: bool = False
+try:
+    from ee import init_ee
+    from ee.license import LicenseError
+    from ee.routers import ee_router
+
+    _has_ee = True
+
+    async def _init_ee() -> None:
+        """启动时验证许可证,路由由 license_valid_required 依赖保护"""
+        try:
+            await init_ee()
+        except LicenseError as exc:
+            l.critical(f"许可证验证失败: {exc}")
+            raise SystemExit(1) from exc
+
+    lifespan.add_startup(_init_ee)
+except ImportError as exc:
+    ee_router = None
+    l.info(f"以 Community 版本运行 (原因: {exc})")
+
+STATICS_DIR: Path = (Path(__file__).parent / "statics").resolve()
+"""前端静态文件目录(由 Docker 构建时复制)"""
+
 async def _init_db() -> None:
     """初始化数据库连接引擎"""
     await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
+
+# 捕获事件循环引用(供 WSGI 线程桥接使用)
+lifespan.add_startup(EventLoopRef.capture)
+
 # 添加初始化数据库启动项
 lifespan.add_startup(_init_db)
 lifespan.add_startup(migration)
 lifespan.add_startup(JWT.load_secret_key)
 lifespan.add_startup(RedisManager.connect)
+lifespan.add_startup(S3StorageService.initialize_session)
 # 添加关闭项
+lifespan.add_shutdown(S3StorageService.close_session)
 lifespan.add_shutdown(DatabaseManager.close)
 lifespan.add_shutdown(RedisManager.disconnect)
@@ -63,6 +97,36 @@ async def handle_unexpected_exceptions(
 # 挂载路由
 app.include_router(router)
+if _has_ee:
+    app.include_router(ee_router, prefix="/api/v1")
+
+# 挂载 WebDAV 协议端点(优先于 SPA catch-all)
+app.mount("/dav", dav_app)
+
+# 挂载前端静态文件(仅当 statics/ 目录存在时,即 Docker 部署环境)
+if STATICS_DIR.is_dir():
+    from starlette.staticfiles import StaticFiles
+    from fastapi.responses import FileResponse
+
+    _assets_dir: Path = STATICS_DIR / "assets"
+    if _assets_dir.is_dir():
+        app.mount("/assets", StaticFiles(directory=_assets_dir), name="assets")
+
+    @app.get("/{path:path}")
+    async def spa_fallback(path: str) -> FileResponse:
+        """
+        SPA fallback 路由
+
+        优先级:API 路由 > /assets 静态挂载 > 此 catch-all 路由。
+        若请求路径对应 statics/ 下的真实文件则直接返回,否则返回 index.html。
+        """
+        file_path: Path = (STATICS_DIR / path).resolve()
+        # 防止路径穿越
+        if file_path.is_relative_to(STATICS_DIR) and path and file_path.is_file():
+            return FileResponse(file_path)
+        return FileResponse(STATICS_DIR / "index.html")
+
+    l.info(f"前端静态文件已挂载: {STATICS_DIR}")
+
 # 防止直接运行 main.py
 if __name__ == "__main__":


@@ -17,7 +17,7 @@ from fastapi import Depends, Form, Query
 from sqlmodel.ext.asyncio.session import AsyncSession
 from sqlmodels.database_connection import DatabaseManager
-from sqlmodels.mixin import TimeFilterRequest, TableViewRequest
+from sqlmodel_ext import TimeFilterRequest, TableViewRequest
 from sqlmodels.user import UserFilterParams, UserStatus


@@ -11,10 +11,13 @@ dependencies = [
     "argon2-cffi>=25.1.0",
     "asyncpg>=0.31.0",
     "cachetools>=6.2.4",
+    "captcha>=0.7.1",
+    "cryptography>=46.0.3",
     "fastapi[standard]>=0.122.0",
     "httpx>=0.27.0",
     "itsdangerous>=2.2.0",
     "loguru>=0.7.3",
+    "orjson>=3.11.7",
     "pyjwt>=2.10.1",
     "pyotp>=2.9.0",
     "pytest>=9.0.2",
@@ -26,8 +29,18 @@ dependencies = [
     "redis[hiredis]>=7.1.0",
     "sqlalchemy>=2.0.44",
     "sqlmodel>=0.0.27",
+    "sqlmodel-ext[pgvector]>=0.1.1",
     "uvicorn>=0.38.0",
     "webauthn>=2.7.0",
+    "whatthepatch>=1.0.6",
+    "wsgidav>=4.3.0",
+    "a2wsgi>=1.10.0",
+]
+
+[project.optional-dependencies]
+build = [
+    "cython>=3.0.11",
+    "setuptools>=75.0.0",
 ]
 
 [tool.pytest.ini_options]


@@ -1,6 +1,8 @@
 from fastapi import APIRouter
 
 from .api import router as api_router
+from .wopi import wopi_router
 
 router = APIRouter()
 router.include_router(api_router)
+router.include_router(wopi_router)


@@ -5,15 +5,16 @@ from utils.conf import appmeta
 from .admin import admin_router
 from .callback import callback_router
+from .category import category_router
 from .directory import directory_router
 from .download import download_router
 from .file import router as file_router
 from .object import object_router
 from .share import share_router
+from .trash import trash_router
 from .site import site_router
 from .slave import slave_router
 from .user import user_router
+from .vas import vas_router
 from .webdav import webdav_router
 
 router = APIRouter(prefix="/v1")
@@ -23,14 +24,15 @@ router = APIRouter(prefix="/v1")
 if appmeta.mode == "master":
     router.include_router(admin_router)
     router.include_router(callback_router)
+    router.include_router(category_router)
     router.include_router(directory_router)
     router.include_router(download_router)
     router.include_router(file_router)
     router.include_router(object_router)
     router.include_router(share_router)
     router.include_router(site_router)
+    router.include_router(trash_router)
     router.include_router(user_router)
+    router.include_router(vas_router)
     router.include_router(webdav_router)
 elif appmeta.mode == "slave":
     router.include_router(slave_router)


@@ -9,20 +9,27 @@ from sqlmodels import (
     User, ResponseBase,
     Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
 )
-from sqlmodels.base import SQLModelBase
+from sqlmodel_ext import SQLModelBase
 from sqlmodels.setting import (
     SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
 )
 from sqlmodels.setting import SettingsType
 from utils import http_exceptions
 from utils.conf import appmeta
+
+try:
+    from ee.service import get_cached_license
+except ImportError:
+    get_cached_license = None
+
 from .file import admin_file_router
+from .file_app import admin_file_app_router
 from .group import admin_group_router
 from .policy import admin_policy_router
 from .share import admin_share_router
 from .task import admin_task_router
 from .user import admin_user_router
-from .vas import admin_vas_router
+from .theme import admin_theme_router
 
 class Aria2TestRequest(SQLModelBase):
@@ -43,10 +50,11 @@ admin_router = APIRouter(
 admin_router.include_router(admin_group_router)
 admin_router.include_router(admin_user_router)
 admin_router.include_router(admin_file_router)
+admin_router.include_router(admin_file_app_router)
 admin_router.include_router(admin_policy_router)
 admin_router.include_router(admin_share_router)
 admin_router.include_router(admin_task_router)
-admin_router.include_router(admin_vas_router)
+admin_router.include_router(admin_theme_router)
 
 # 离线下载 /api/admin/aria2
 admin_aria2_router = APIRouter(
@@ -155,9 +163,19 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
         if site_url_setting and site_url_setting.value:
             site_urls.append(site_url_setting.value)
-        # 许可证信息(从设置读取或使用默认值)
-        license_info = LicenseInfo(
-            expired_at=now + timedelta(days=365),
+        # 许可证信息(Pro 版本从缓存读取,CE 版本永不过期)
+        if appmeta.IsPro and get_cached_license:
+            payload = get_cached_license()
+            license_info = LicenseInfo(
+                expired_at=payload.expires_at,
+                signed_at=payload.issued_at,
+                root_domains=[],
+                domains=[payload.domain],
+                vol_domains=[],
+            )
+        else:
+            license_info = LicenseInfo(
+                expired_at=datetime.max,
                 signed_at=now,
                 root_domains=[],
                 domains=[],
@@ -221,11 +239,11 @@ async def router_admin_update_settings(
             if existing:
                 existing.value = item.value
-                await existing.save(session)
+                existing = await existing.save(session)
                 updated_count += 1
             else:
                 new_setting = Setting(type=item.type, name=item.name, value=item.value)
-                await new_setting.save(session)
+                new_setting = await new_setting.save(session)
                 created_count += 1
 
     l.info(f"管理员更新了 {updated_count} 个设置项,新建了 {created_count} 个设置项")
@@ -279,16 +297,17 @@ async def router_admin_get_settings(
     path='/test',
     summary='测试 Aria2 连接',
     description='Test Aria2 RPC connection',
-    dependencies=[Depends(admin_required)]
+    dependencies=[Depends(admin_required)],
+    status_code=204,
 )
 async def router_admin_aira2_test(
     request: Aria2TestRequest,
-) -> ResponseBase:
+) -> None:
     """
     测试 Aria2 RPC 连接。
 
     :param request: 测试请求
-    :return: 测试结果
+    :raises HTTPException: 连接失败时抛出 400
     """
     import aiohttp
@@ -303,22 +322,18 @@ async def router_admin_aira2_test(
         async with aiohttp.ClientSession() as client:
             async with client.post(request.rpc_url, json=payload, timeout=aiohttp.ClientTimeout(total=10)) as resp:
                 if resp.status != 200:
-                    return ResponseBase(
-                        code=400,
-                        msg=f"连接失败:HTTP {resp.status}"
-                    )
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"连接失败:HTTP {resp.status}",
+                    )
                 result = await resp.json()
                 if "error" in result:
-                    return ResponseBase(
-                        code=400,
-                        msg=f"Aria2 错误: {result['error']['message']}"
-                    )
-
-        version = result.get("result", {}).get("version", "unknown")
-        return ResponseBase(data={
-            "connected": True,
-            "version": version,
-        })
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"Aria2 错误: {result['error']['message']}",
+                    )
+    except HTTPException:
+        raise
     except Exception as e:
-        return ResponseBase(code=400, msg=f"连接失败: {str(e)}")
+        raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")


@@ -54,7 +54,7 @@ async def _set_ban_recursive(
             obj.banned_by = None
             obj.ban_reason = None
-            await obj.save(session)
+            obj = await obj.save(session)
             count += 1
 
     return count
@@ -131,9 +131,7 @@ async def router_admin_preview_file(
     :param file_id: 文件UUID
     :return: 文件内容
     """
-    file_obj = await Object.get(session, Object.id == file_id)
-    if not file_obj:
-        raise HTTPException(status_code=404, detail="文件不存在")
+    file_obj = await Object.get_exist_one(session, file_id)
 
     if not file_obj.is_file:
         raise HTTPException(status_code=400, detail="对象不是文件")
@@ -182,9 +180,7 @@ async def router_admin_ban_file(
     :param claims: 当前管理员 JWT claims
     :return: 封禁结果
     """
-    file_obj = await Object.get(session, Object.id == file_id)
-    if not file_obj:
-        raise HTTPException(status_code=404, detail="文件不存在")
+    file_obj = await Object.get_exist_one(session, file_id)
 
     count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
@@ -212,9 +208,7 @@ async def router_admin_delete_file(
     :param delete_physical: 是否同时删除物理文件
     :return: 删除结果
     """
-    file_obj = await Object.get(session, Object.id == file_id)
-    if not file_obj:
-        raise HTTPException(status_code=404, detail="文件不存在")
+    file_obj = await Object.get_exist_one(session, file_id)
 
     if not file_obj.is_file:
         raise HTTPException(status_code=400, detail="对象不是文件")


@@ -0,0 +1,450 @@
"""
管理员文件应用管理端点
提供文件查看器应用的 CRUD、扩展名管理、用户组权限管理和 WOPI Discovery。
"""
from uuid import UUID
import aiohttp
from fastapi import APIRouter, Depends, status
from loguru import logger as l
from sqlalchemy import select
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from service.wopi import parse_wopi_discovery_xml
from sqlmodels import (
FileApp,
FileAppCreateRequest,
FileAppExtension,
FileAppGroupLink,
FileAppListResponse,
FileAppResponse,
FileAppUpdateRequest,
ExtensionUpdateRequest,
GroupAccessUpdateRequest,
WopiDiscoveredExtension,
WopiDiscoveryResponse,
)
from sqlmodels.file_app import FileAppType
from utils import http_exceptions
admin_file_app_router = APIRouter(
prefix="/file-app",
tags=["admin", "file_app"],
dependencies=[Depends(admin_required)],
)
@admin_file_app_router.get(
path='/list',
summary='列出所有文件应用',
)
async def list_file_apps(
session: SessionDep,
table_view: TableViewRequestDep,
) -> FileAppListResponse:
"""
列出所有文件应用端点(分页)
认证:管理员权限
"""
result = await FileApp.get_with_count(
session,
table_view=table_view,
)
apps: list[FileAppResponse] = []
for app in result.items:
extensions = await FileAppExtension.get(
session,
FileAppExtension.app_id == app.id,
fetch_mode="all",
)
group_links_result = await session.exec(
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
)
group_links: list[FileAppGroupLink] = list(group_links_result.all())
apps.append(FileAppResponse.from_app(app, extensions, group_links))
return FileAppListResponse(apps=apps, total=result.count)
@admin_file_app_router.post(
path='/',
summary='创建文件应用',
status_code=status.HTTP_201_CREATED,
)
async def create_file_app(
session: SessionDep,
request: FileAppCreateRequest,
) -> FileAppResponse:
"""
创建文件应用端点
认证:管理员权限
错误处理:
- 409: app_key 已存在
"""
# 检查 app_key 唯一
existing = await FileApp.get(session, FileApp.app_key == request.app_key)
if existing:
http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")
# 创建应用
app = FileApp(
name=request.name,
app_key=request.app_key,
type=request.type,
icon=request.icon,
description=request.description,
is_enabled=request.is_enabled,
is_restricted=request.is_restricted,
iframe_url_template=request.iframe_url_template,
wopi_discovery_url=request.wopi_discovery_url,
wopi_editor_url_template=request.wopi_editor_url_template,
)
app = await app.save(session)
app_id = app.id
# 创建扩展名关联
extensions: list[FileAppExtension] = []
for i, ext in enumerate(request.extensions):
normalized = ext.lower().strip().lstrip('.')
ext_record = FileAppExtension(
app_id=app_id,
extension=normalized,
priority=i,
)
ext_record = await ext_record.save(session)
extensions.append(ext_record)
# 创建用户组关联
group_links: list[FileAppGroupLink] = []
for group_id in request.allowed_group_ids:
link = FileAppGroupLink(app_id=app_id, group_id=group_id)
session.add(link)
group_links.append(link)
if group_links:
await session.commit()
await session.refresh(app)
l.info(f"创建文件应用: {app.name} ({app.app_key})")
return FileAppResponse.from_app(app, extensions, group_links)
@admin_file_app_router.get(
path='/{app_id}',
summary='获取文件应用详情',
)
async def get_file_app(
session: SessionDep,
app_id: UUID,
) -> FileAppResponse:
"""
获取文件应用详情端点
认证:管理员权限
错误处理:
- 404: 应用不存在
"""
app = await FileApp.get_exist_one(session, app_id)
extensions = await FileAppExtension.get(
session,
FileAppExtension.app_id == app.id,
fetch_mode="all",
)
group_links_result = await session.exec(
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
)
group_links: list[FileAppGroupLink] = list(group_links_result.all())
return FileAppResponse.from_app(app, extensions, group_links)
@admin_file_app_router.patch(
path='/{app_id}',
summary='更新文件应用',
)
async def update_file_app(
session: SessionDep,
app_id: UUID,
request: FileAppUpdateRequest,
) -> FileAppResponse:
"""
更新文件应用端点
认证:管理员权限
错误处理:
- 404: 应用不存在
- 409: 新 app_key 已被其他应用使用
"""
app = await FileApp.get_exist_one(session, app_id)
# 检查 app_key 唯一性
if request.app_key is not None and request.app_key != app.app_key:
existing = await FileApp.get(session, FileApp.app_key == request.app_key)
if existing:
http_exceptions.raise_conflict(f"应用标识 '{request.app_key}' 已存在")
# 更新非 None 字段
update_data = request.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(app, key, value)
app = await app.save(session)
extensions = await FileAppExtension.get(
session,
FileAppExtension.app_id == app.id,
fetch_mode="all",
)
group_links_result = await session.exec(
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app.id)
)
group_links: list[FileAppGroupLink] = list(group_links_result.all())
l.info(f"更新文件应用: {app.name} ({app.app_key})")
return FileAppResponse.from_app(app, extensions, group_links)
@admin_file_app_router.delete(
path='/{app_id}',
summary='删除文件应用',
status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_file_app(
session: SessionDep,
app_id: UUID,
) -> None:
"""
删除文件应用端点(级联删除扩展名、用户偏好和用户组关联)
认证:管理员权限
错误处理:
- 404: 应用不存在
"""
app = await FileApp.get_exist_one(session, app_id)
app_name = app.app_key
await FileApp.delete(session, app)
l.info(f"删除文件应用: {app_name}")
@admin_file_app_router.put(
path='/{app_id}/extensions',
summary='全量替换扩展名列表',
)
async def update_extensions(
session: SessionDep,
app_id: UUID,
request: ExtensionUpdateRequest,
) -> FileAppResponse:
"""
全量替换扩展名列表端点
先删除旧的扩展名关联,再创建新的。
认证:管理员权限
错误处理:
- 404: 应用不存在
"""
app = await FileApp.get_exist_one(session, app_id)
# 保留旧扩展名的 wopi_action_url(Discovery 填充的值)
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
session,
FileAppExtension.app_id == app_id,
fetch_mode="all",
)
old_url_map: dict[str, str] = {
ext.extension: ext.wopi_action_url
for ext in old_extensions
if ext.wopi_action_url
}
for old_ext in old_extensions:
await FileAppExtension.delete(session, old_ext, commit=False)
await session.flush()
# 创建新的扩展名(保留已有的 wopi_action_url)
new_extensions: list[FileAppExtension] = []
for i, ext in enumerate(request.extensions):
normalized = ext.lower().strip().lstrip('.')
ext_record = FileAppExtension(
app_id=app_id,
extension=normalized,
priority=i,
wopi_action_url=old_url_map.get(normalized),
)
session.add(ext_record)
new_extensions.append(ext_record)
await session.commit()
# refresh commit 后过期的对象
await session.refresh(app)
for ext_record in new_extensions:
await session.refresh(ext_record)
group_links_result = await session.exec(
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
)
group_links: list[FileAppGroupLink] = list(group_links_result.all())
l.info(f"更新文件应用 {app.app_key} 的扩展名: {request.extensions}")
return FileAppResponse.from_app(app, new_extensions, group_links)
@admin_file_app_router.put(
path='/{app_id}/groups',
summary='全量替换允许的用户组',
)
async def update_group_access(
session: SessionDep,
app_id: UUID,
request: GroupAccessUpdateRequest,
) -> FileAppResponse:
"""
全量替换允许的用户组端点
先删除旧的关联,再创建新的。
认证:管理员权限
错误处理:
- 404: 应用不存在
"""
app = await FileApp.get_exist_one(session, app_id)
# 删除旧的用户组关联
old_links_result = await session.exec(
select(FileAppGroupLink).where(FileAppGroupLink.app_id == app_id)
)
old_links: list[FileAppGroupLink] = list(old_links_result.all())
for old_link in old_links:
await session.delete(old_link)
# 创建新的用户组关联
new_links: list[FileAppGroupLink] = []
for group_id in request.group_ids:
link = FileAppGroupLink(app_id=app_id, group_id=group_id)
session.add(link)
new_links.append(link)
await session.commit()
await session.refresh(app)
extensions = await FileAppExtension.get(
session,
FileAppExtension.app_id == app_id,
fetch_mode="all",
)
l.info(f"更新文件应用 {app.app_key} 的用户组权限: {request.group_ids}")
return FileAppResponse.from_app(app, extensions, new_links)
@admin_file_app_router.post(
path='/{app_id}/discover',
summary='执行 WOPI Discovery',
)
async def discover_wopi(
session: SessionDep,
app_id: UUID,
) -> WopiDiscoveryResponse:
"""
从 WOPI 服务端获取 Discovery XML 并自动配置扩展名和 URL 模板。
流程:
1. 验证 FileApp 存在且为 WOPI 类型
2. 使用 FileApp.wopi_discovery_url 获取 Discovery XML
3. 解析 XML提取扩展名和动作 URL
4. 全量替换 FileAppExtension 记录(带 wopi_action_url)
认证:管理员权限
错误处理:
- 404: 应用不存在
- 400: 非 WOPI 类型 / discovery URL 未配置 / XML 解析失败
- 502: WOPI 服务端不可达或返回无效响应
"""
app = await FileApp.get_exist_one(session, app_id)
if app.type != FileAppType.WOPI:
http_exceptions.raise_bad_request("仅 WOPI 类型应用支持自动发现")
if not app.wopi_discovery_url:
http_exceptions.raise_bad_request("未配置 WOPI Discovery URL")
# commit 后对象会过期,先保存需要的值
discovery_url = app.wopi_discovery_url
app_key = app.app_key
# 获取 Discovery XML
try:
async with aiohttp.ClientSession() as client:
async with client.get(
discovery_url,
timeout=aiohttp.ClientTimeout(total=15),
) as resp:
if resp.status != 200:
http_exceptions.raise_bad_gateway(
f"WOPI 服务端返回 HTTP {resp.status}"
)
xml_content = await resp.text()
except aiohttp.ClientError as e:
http_exceptions.raise_bad_gateway(f"无法连接 WOPI 服务端: {e}")
# 解析 XML
try:
action_urls, app_names = parse_wopi_discovery_xml(xml_content)
except ValueError as e:
http_exceptions.raise_bad_request(str(e))
if not action_urls:
return WopiDiscoveryResponse(app_names=app_names)
# 全量替换扩展名
old_extensions: list[FileAppExtension] = await FileAppExtension.get(
session,
FileAppExtension.app_id == app_id,
fetch_mode="all",
)
for old_ext in old_extensions:
await FileAppExtension.delete(session, old_ext, commit=False)
await session.flush()
new_extensions: list[FileAppExtension] = []
discovered: list[WopiDiscoveredExtension] = []
for i, (ext, action_url) in enumerate(sorted(action_urls.items())):
ext_record = FileAppExtension(
app_id=app_id,
extension=ext,
priority=i,
wopi_action_url=action_url,
)
session.add(ext_record)
new_extensions.append(ext_record)
discovered.append(WopiDiscoveredExtension(extension=ext, action_url=action_url))
await session.commit()
l.info(
f"WOPI Discovery 完成: 应用 {app_key}, "
f"发现 {len(discovered)} 个扩展名"
)
return WopiDiscoveryResponse(
discovered_extensions=discovered,
app_names=app_names,
applied_count=len(discovered),
)
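`parse_wopi_discovery_xml` is only referenced above, not shown. As an illustration of what such a parser involves, the sketch below walks the public WOPI discovery format (`app` elements with `action` children carrying `ext` and `urlsrc` attributes) and builds the extension-to-URL map plus app-name list that the endpoint expects. The function name and exact return shape here are assumptions inferred from this file's usage, not the project's real implementation:

```python
import xml.etree.ElementTree as ET


def parse_discovery(xml_text: str) -> tuple[dict[str, str], list[str]]:
    """Extract {extension: action_url} and app names from WOPI discovery XML."""
    try:
        root = ET.fromstring(xml_text)
    except ET.ParseError as e:
        # Mirrors the ValueError the route handler catches for malformed XML
        raise ValueError(f"invalid discovery XML: {e}") from e
    action_urls: dict[str, str] = {}
    app_names: list[str] = []
    for app in root.iter("app"):
        name = app.get("name")
        if name:
            app_names.append(name)
        for action in app.iter("action"):
            ext, url = action.get("ext"), action.get("urlsrc")
            if ext and url:
                # Keep the first URL seen per extension
                action_urls.setdefault(ext.lower(), url)
    return action_urls, app_names


xml = """<wopi-discovery><net-zone name="external-http">
<app name="writer"><action name="edit" ext="odt" urlsrc="https://host/loleaflet/"/></app>
</net-zone></wopi-discovery>"""
urls, names = parse_discovery(xml)
assert urls == {"odt": "https://host/loleaflet/"}
assert names == ["writer"]
```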


@@ -55,7 +55,7 @@ async def router_admin_get_groups(
 async def router_admin_get_group(
     session: SessionDep,
     group_id: UUID,
-) -> ResponseBase:
+) -> GroupDetailResponse:
     """
     根据用户组ID获取用户组详细信息。
@@ -63,17 +63,12 @@ async def router_admin_get_group(
     :param group_id: 用户组UUID
     :return: 用户组详情
     """
-    group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
-    if not group:
-        raise HTTPException(status_code=404, detail="用户组不存在")
+    group = await Group.get_exist_one(session, group_id, load=[Group.options, Group.policies])
 
     # 直接访问已加载的关系,无需额外查询
     policies = group.policies
     user_count = await User.count(session, User.group_id == group_id)
 
-    response = GroupDetailResponse.from_group(group, user_count, policies)
-    return ResponseBase(data=response.model_dump())
+    return GroupDetailResponse.from_group(group, user_count, policies)
 
 @admin_group_router.get(
@@ -96,9 +91,7 @@ async def router_admin_get_group_members(
     :return: 分页成员列表
     """
     # 验证组存在
-    group = await Group.get(session, Group.id == group_id)
-    if not group:
-        raise HTTPException(status_code=404, detail="用户组不存在")
+    await Group.get_exist_one(session, group_id)
 
     result = await User.get_with_count(session, User.group_id == group_id, table_view=table_view)
@@ -140,10 +133,11 @@ async def router_admin_create_group(
         speed_limit=request.speed_limit,
     )
     group = await group.save(session)
+    group_id_val: UUID = group.id
 
     # 创建选项
     options = GroupOptions(
-        group_id=group.id,
+        group_id=group_id_val,
         share_download=request.share_download,
         share_free=request.share_free,
         relocate=request.relocate,
@@ -156,11 +150,11 @@ async def router_admin_create_group(
         aria2=request.aria2,
         redirected_source=request.redirected_source,
     )
-    await options.save(session)
+    options = await options.save(session)
 
     # 关联存储策略
     for policy_id in request.policy_ids:
-        link = GroupPolicyLink(group_id=group.id, policy_id=policy_id)
+        link = GroupPolicyLink(group_id=group_id_val, policy_id=policy_id)
         session.add(link)
 
     await session.commit()
@@ -187,9 +181,7 @@ async def router_admin_update_group(
     :param request: 更新请求
     :return: 更新结果
     """
-    group = await Group.get(session, Group.id == group_id, load=Group.options)
-    if not group:
-        raise HTTPException(status_code=404, detail="用户组不存在")
+    group = await Group.get_exist_one(session, group_id, load=Group.options)
 
     # 检查名称唯一性(如果要更新名称)
     if request.name and request.name != group.name:
@@ -219,7 +211,7 @@ async def router_admin_update_group(
if options_data: if options_data:
for key, value in options_data.items(): for key, value in options_data.items():
setattr(group.options, key, value) setattr(group.options, key, value)
await group.options.save(session) group.options = await group.options.save(session)
# 更新策略关联 # 更新策略关联
if request.policy_ids is not None: if request.policy_ids is not None:
@@ -257,9 +249,7 @@ async def router_admin_delete_group(
:param group_id: 用户组UUID :param group_id: 用户组UUID
:return: 删除结果 :return: 删除结果
""" """
group = await Group.get(session, Group.id == group_id) group = await Group.get_exist_one(session, group_id)
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# 检查是否有用户属于该组 # 检查是否有用户属于该组
user_count = await User.count(session, User.group_id == group_id) user_count = await User.count(session, User.group_id == group_id)
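The hunks above repeatedly collapse a manual "fetch, then raise 404 if missing" check into one `get_exist_one()` call. A minimal sketch of the assumed helper semantics — the class, rows, and exception stand-in here are illustrative, not the project's actual `sqlmodel_ext` API:

```python
import asyncio

class HTTPException(Exception):
    """Stand-in for fastapi.HTTPException."""
    def __init__(self, status_code: int, detail: str):
        super().__init__(detail)
        self.status_code, self.detail = status_code, detail

class FakeGroup:
    _rows = {"g1": "admin-group"}  # toy table: id -> row

    @classmethod
    async def get(cls, session, obj_id):
        # Returns the row or None, like the old Group.get call
        return cls._rows.get(obj_id)

    @classmethod
    async def get_exist_one(cls, session, obj_id):
        # Folds the repeated get + 404 check into a single call
        obj = await cls.get(session, obj_id)
        if obj is None:
            raise HTTPException(status_code=404, detail="not found")
        return obj

row = asyncio.run(FakeGroup.get_exist_one(None, "g1"))
```

Centralizing the existence check is what lets each route body shrink from three lines to one across these files.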

View File

@@ -1,3 +1,4 @@
+from typing import Any
 from uuid import UUID

 from fastapi import APIRouter, Depends, HTTPException
@@ -7,16 +8,95 @@ from sqlmodel import Field
 from middleware.auth import admin_required
 from middleware.dependencies import SessionDep, TableViewRequestDep
 from sqlmodels import (
-    Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
-    ListResponse, Object, )
-from sqlmodels.base import SQLModelBase
-from service.storage import DirectoryCreationError, LocalStorageService
+    Policy, PolicyCreateRequest, PolicyOptions, PolicyType, PolicySummary,
+    PolicyUpdateRequest, ResponseBase, ListResponse, Object,
+)
+from sqlmodel_ext import SQLModelBase
+from service.storage import DirectoryCreationError, LocalStorageService, S3StorageService

 admin_policy_router = APIRouter(
     prefix='/policy',
     tags=['admin', 'admin_policy']
 )

+class PathTestResponse(SQLModelBase):
+    """路径测试响应"""
+
+    path: str
+    """解析后的路径"""
+    is_exists: bool
+    """路径是否存在"""
+    is_writable: bool
+    """路径是否可写"""
+
+class PolicyGroupInfo(SQLModelBase):
+    """策略关联的用户组信息"""
+
+    id: str
+    """用户组UUID"""
+    name: str
+    """用户组名称"""
+
+class PolicyDetailResponse(SQLModelBase):
+    """存储策略详情响应"""
+
+    id: str
+    """策略UUID"""
+    name: str
+    """策略名称"""
+    type: str
+    """策略类型"""
+    server: str | None
+    """服务器地址"""
+    bucket_name: str | None
+    """存储桶名称"""
+    is_private: bool
+    """是否私有"""
+    base_url: str | None
+    """基础URL"""
+    access_key: str | None
+    """Access Key"""
+    secret_key: str | None
+    """Secret Key"""
+    max_size: int
+    """最大文件尺寸"""
+    auto_rename: bool
+    """是否自动重命名"""
+    dir_name_rule: str | None
+    """目录命名规则"""
+    file_name_rule: str | None
+    """文件命名规则"""
+    is_origin_link_enable: bool
+    """是否启用外链"""
+    options: dict[str, Any] | None
+    """策略选项"""
+    groups: list[PolicyGroupInfo]
+    """关联的用户组"""
+    object_count: int
+    """使用此策略的对象数量"""
+
 class PolicyTestPathRequest(SQLModelBase):
     """测试本地路径请求 DTO"""
@@ -33,9 +113,45 @@ class PolicyTestSlaveRequest(SQLModelBase):
     secret: str
     """从机通信密钥"""

-class PolicyCreateRequest(PolicyBase):
-    """创建存储策略请求 DTO继承 PolicyBase 中的所有字段"""
-    pass
+class PolicyTestS3Request(SQLModelBase):
+    """测试 S3 连接请求 DTO"""
+
+    server: str = Field(max_length=255)
+    """S3 端点地址"""
+    bucket_name: str = Field(max_length=255)
+    """存储桶名称"""
+    access_key: str
+    """Access Key"""
+    secret_key: str
+    """Secret Key"""
+    s3_region: str = Field(default='us-east-1', max_length=64)
+    """S3 区域"""
+    s3_path_style: bool = False
+    """是否使用路径风格"""
+
+class PolicyTestS3Response(SQLModelBase):
+    """S3 连接测试响应"""
+
+    is_connected: bool
+    """连接是否成功"""
+    message: str
+    """测试结果消息"""
+
+# ==================== Options 字段集合(用于分离 Policy 与 Options 字段) ====================
+_OPTIONS_FIELDS: set[str] = {
+    'token', 'file_type', 'mimetype', 'od_redirect',
+    'chunk_size', 's3_path_style', 's3_region',
+}

 @admin_policy_router.get(
     path='/list',
@@ -70,7 +186,7 @@ async def router_policy_list(
 )
 async def router_policy_test_path(
     request: PolicyTestPathRequest,
-) -> ResponseBase:
+) -> PathTestResponse:
     """
     测试本地存储路径是否可用。
@@ -97,22 +213,23 @@ async def router_policy_test_path(
         except Exception:
             pass

-    return ResponseBase(data={
-        "path": str(path),
-        "exists": is_exists,
-        "writable": is_writable,
-    })
+    return PathTestResponse(
+        path=str(path),
+        is_exists=is_exists,
+        is_writable=is_writable,
+    )

 @admin_policy_router.post(
     path='/test/slave',
     summary='测试从机通信',
     description='Test slave node communication',
-    dependencies=[Depends(admin_required)]
+    dependencies=[Depends(admin_required)],
+    status_code=204,
 )
 async def router_policy_test_slave(
     request: PolicyTestSlaveRequest,
-) -> ResponseBase:
+) -> None:
     """
     测试从机RPC通信。
@@ -129,25 +246,28 @@ async def router_policy_test_slave(
             timeout=aiohttp.ClientTimeout(total=10)
         ) as resp:
             if resp.status == 200:
-                return ResponseBase(data={"connected": True})
+                return
             else:
-                return ResponseBase(
-                    code=400,
-                    msg=f"从机响应错误HTTP {resp.status}"
-                )
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"从机响应错误HTTP {resp.status}",
+                )
+    except HTTPException:
+        raise
     except Exception as e:
-        return ResponseBase(code=400, msg=f"连接失败: {str(e)}")
+        raise HTTPException(status_code=400, detail=f"连接失败: {str(e)}")

 @admin_policy_router.post(
     path='/',
     summary='创建存储策略',
     description='创建新的存储策略。对于本地存储策略,会自动创建物理目录。',
-    dependencies=[Depends(admin_required)]
+    dependencies=[Depends(admin_required)],
+    status_code=204,
 )
 async def router_policy_add_policy(
     session: SessionDep,
     request: PolicyCreateRequest,
-) -> ResponseBase:
+) -> None:
     """
     创建存储策略端点
@@ -201,12 +321,18 @@ async def router_policy_add_policy(
     # 保存到数据库
     policy = await policy.save(session)

-    return ResponseBase(data={
-        "id": str(policy.id),
-        "name": policy.name,
-        "type": policy.type.value,
-        "server": policy.server,
-    })
+    # 创建策略选项
+    options = PolicyOptions(
+        policy_id=policy.id,
+        token=request.token,
+        file_type=request.file_type,
+        mimetype=request.mimetype,
+        od_redirect=request.od_redirect,
+        chunk_size=request.chunk_size,
+        s3_path_style=request.s3_path_style,
+        s3_region=request.s3_region,
+    )
+    options = await options.save(session)

 @admin_policy_router.post(
     path='/cors',
@@ -257,9 +383,7 @@ async def router_policy_onddrive_oauth(
     :param policy_id: 存储策略UUID
     :return: OAuth URL
     """
-    policy = await Policy.get(session, Policy.id == policy_id)
-    if not policy:
-        raise HTTPException(status_code=404, detail="存储策略不存在")
+    policy = await Policy.get_exist_one(session, policy_id)

     # TODO: 实现OneDrive OAuth
     raise HTTPException(status_code=501, detail="OneDrive OAuth暂未实现")
@@ -274,7 +398,7 @@ async def router_policy_onddrive_oauth(
 async def router_policy_get_policy(
     session: SessionDep,
     policy_id: UUID,
-) -> ResponseBase:
+) -> PolicyDetailResponse:
     """
     获取存储策略详情。
@@ -282,9 +406,7 @@ async def router_policy_get_policy(
     :param policy_id: 存储策略UUID
     :return: 策略详情
     """
-    policy = await Policy.get(session, Policy.id == policy_id, load=Policy.options)
-    if not policy:
-        raise HTTPException(status_code=404, detail="存储策略不存在")
+    policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)

     # 获取使用此策略的用户组
     groups = await policy.awaitable_attrs.groups
@@ -292,35 +414,38 @@ async def router_policy_get_policy(
     # 统计使用此策略的对象数量
     object_count = await Object.count(session, Object.policy_id == policy_id)

-    return ResponseBase(data={
-        "id": str(policy.id),
-        "name": policy.name,
-        "type": policy.type.value,
-        "server": policy.server,
-        "bucket_name": policy.bucket_name,
-        "is_private": policy.is_private,
-        "base_url": policy.base_url,
-        "max_size": policy.max_size,
-        "auto_rename": policy.auto_rename,
-        "dir_name_rule": policy.dir_name_rule,
-        "file_name_rule": policy.file_name_rule,
-        "is_origin_link_enable": policy.is_origin_link_enable,
-        "options": policy.options.model_dump() if policy.options else None,
-        "groups": [{"id": str(g.id), "name": g.name} for g in groups],
-        "object_count": object_count,
-    })
+    return PolicyDetailResponse(
+        id=str(policy.id),
+        name=policy.name,
+        type=policy.type.value,
+        server=policy.server,
+        bucket_name=policy.bucket_name,
+        is_private=policy.is_private,
+        base_url=policy.base_url,
+        access_key=policy.access_key,
+        secret_key=policy.secret_key,
+        max_size=policy.max_size,
+        auto_rename=policy.auto_rename,
+        dir_name_rule=policy.dir_name_rule,
+        file_name_rule=policy.file_name_rule,
+        is_origin_link_enable=policy.is_origin_link_enable,
+        options=policy.options.model_dump() if policy.options else None,
+        groups=[PolicyGroupInfo(id=str(g.id), name=g.name) for g in groups],
+        object_count=object_count,
+    )

 @admin_policy_router.delete(
     path='/{policy_id}',
     summary='删除存储策略',
     description='Delete storage policy by ID',
-    dependencies=[Depends(admin_required)]
+    dependencies=[Depends(admin_required)],
+    status_code=204,
 )
 async def router_policy_delete_policy(
     session: SessionDep,
     policy_id: UUID,
-) -> ResponseBase:
+) -> None:
     """
     删除存储策略。
@@ -330,9 +455,7 @@ async def router_policy_delete_policy(
     :param policy_id: 存储策略UUID
     :return: 删除结果
     """
-    policy = await Policy.get(session, Policy.id == policy_id)
-    if not policy:
-        raise HTTPException(status_code=404, detail="存储策略不存在")
+    policy = await Policy.get_exist_one(session, policy_id)

     # 检查是否有文件使用此策略
     file_count = await Object.count(session, Object.policy_id == policy_id)
@@ -346,4 +469,105 @@ async def router_policy_delete_policy(
     await Policy.delete(session, policy)
     l.info(f"管理员删除了存储策略: {policy_name}")
-    return ResponseBase(data={"deleted": True})
+
+@admin_policy_router.patch(
+    path='/{policy_id}',
+    summary='更新存储策略',
+    description='更新存储策略配置。策略类型创建后不可更改。',
+    dependencies=[Depends(admin_required)],
+    status_code=204,
+)
+async def router_policy_update_policy(
+    session: SessionDep,
+    policy_id: UUID,
+    request: PolicyUpdateRequest,
+) -> None:
+    """
+    更新存储策略端点
+
+    功能:
+    - 更新策略基础字段和扩展选项
+    - 策略类型type不可更改
+
+    认证:
+    - 需要管理员权限
+
+    :param session: 数据库会话
+    :param policy_id: 存储策略UUID
+    :param request: 更新请求
+    """
+    policy = await Policy.get_exist_one(session, policy_id, load=Policy.options)
+
+    # 检查名称唯一性(如果要更新名称)
+    if request.name and request.name != policy.name:
+        existing = await Policy.get(session, Policy.name == request.name)
+        if existing:
+            raise HTTPException(status_code=409, detail="策略名称已存在")
+
+    # 分离 Policy 字段和 Options 字段
+    all_data = request.model_dump(exclude_unset=True)
+    policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
+    options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
+
+    # 更新 Policy 基础字段
+    if policy_data:
+        for key, value in policy_data.items():
+            setattr(policy, key, value)
+        policy = await policy.save(session)
+
+    # 更新或创建 PolicyOptions
+    if options_data:
+        if policy.options:
+            for key, value in options_data.items():
+                setattr(policy.options, key, value)
+            policy.options = await policy.options.save(session)
+        else:
+            options = PolicyOptions(policy_id=policy.id, **options_data)
+            options = await options.save(session)
+
+    l.info(f"管理员更新了存储策略: {policy_id}")
+
+@admin_policy_router.post(
+    path='/test/s3',
+    summary='测试 S3 连接',
+    description='测试 S3 存储端点的连通性和凭据有效性。',
+    dependencies=[Depends(admin_required)],
+)
+async def router_policy_test_s3(
+    request: PolicyTestS3Request,
+) -> PolicyTestS3Response:
+    """
+    测试 S3 连接端点
+
+    通过向 S3 端点发送 HEAD Bucket 请求,验证凭据和网络连通性。
+
+    :param request: 测试请求
+    :return: 测试结果
+    """
+    from service.storage import S3APIError
+
+    # 构造临时 Policy 对象用于创建 S3StorageService
+    temp_policy = Policy(
+        name="__test__",
+        type=PolicyType.S3,
+        server=request.server,
+        bucket_name=request.bucket_name,
+        access_key=request.access_key,
+        secret_key=request.secret_key,
+    )
+    s3_service = S3StorageService(
+        temp_policy,
+        region=request.s3_region,
+        is_path_style=request.s3_path_style,
+    )
+    try:
+        # 使用 file_exists 发送 HEAD 请求来验证连通性
+        await s3_service.file_exists("__connection_test__")
+        return PolicyTestS3Response(is_connected=True, message="连接成功")
+    except S3APIError as e:
+        return PolicyTestS3Response(is_connected=False, message=f"S3 API 错误: {e}")
+    except Exception as e:
+        return PolicyTestS3Response(is_connected=False, message=f"连接失败: {e}")
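The PATCH handler above splits one flat request payload between the `Policy` row and its `PolicyOptions` row using the `_OPTIONS_FIELDS` set. A standalone sketch of that dictionary split, where the plain dict stands in for `request.model_dump(exclude_unset=True)`:

```python
# Mirrors the _OPTIONS_FIELDS set defined in the route module
_OPTIONS_FIELDS: set[str] = {
    'token', 'file_type', 'mimetype', 'od_redirect',
    'chunk_size', 's3_path_style', 's3_region',
}

def split_policy_update(all_data: dict) -> tuple[dict, dict]:
    """Split a flat update payload into Policy columns and PolicyOptions columns."""
    policy_data = {k: v for k, v in all_data.items() if k not in _OPTIONS_FIELDS}
    options_data = {k: v for k, v in all_data.items() if k in _OPTIONS_FIELDS}
    return policy_data, options_data

# exclude_unset means absent fields never appear, so only what the client
# actually sent is applied to either table
policy_data, options_data = split_policy_update(
    {"name": "s3-main", "chunk_size": 8 * 1024 * 1024, "s3_region": "us-east-1"}
)
```

Keeping the split in one module-level set means adding a new option column only requires touching that set, not the handler logic.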

View File

@@ -1,3 +1,4 @@
+from datetime import datetime
 from uuid import UUID

 from fastapi import APIRouter, Depends, HTTPException
@@ -6,8 +7,53 @@ from loguru import logger as l
 from middleware.auth import admin_required
 from middleware.dependencies import SessionDep, TableViewRequestDep
 from sqlmodels import (
-    ResponseBase, ListResponse,
-    Share, AdminShareListItem, )
+    ListResponse,
+    Share, AdminShareListItem,
+)
+from sqlmodel_ext import SQLModelBase

+class ShareDetailResponse(SQLModelBase):
+    """分享详情响应"""
+
+    id: UUID
+    """分享UUID"""
+    code: str
+    """分享码"""
+    views: int
+    """浏览次数"""
+    downloads: int
+    """下载次数"""
+    remain_downloads: int | None
+    """剩余下载次数"""
+    expires: datetime | None
+    """过期时间"""
+    preview_enabled: bool
+    """是否启用预览"""
+    score: int
+    """评分"""
+    has_password: bool
+    """是否有密码"""
+    user_id: str
+    """用户UUID"""
+    username: str | None
+    """用户名"""
+    object: dict | None
+    """关联对象信息"""
+    created_at: str
+    """创建时间"""

 admin_share_router = APIRouter(
     prefix='/share',
@@ -53,8 +99,8 @@ async def router_admin_get_share_list(
 )
 async def router_admin_get_share(
     session: SessionDep,
-    share_id: int,
-) -> ResponseBase:
+    share_id: UUID,
+) -> ShareDetailResponse:
     """
     获取分享详情。
@@ -69,38 +115,39 @@ async def router_admin_get_share(
     obj = await share.awaitable_attrs.object
     user = await share.awaitable_attrs.user

-    return ResponseBase(data={
-        "id": share.id,
-        "code": share.code,
-        "views": share.views,
-        "downloads": share.downloads,
-        "remain_downloads": share.remain_downloads,
-        "expires": share.expires.isoformat() if share.expires else None,
-        "preview_enabled": share.preview_enabled,
-        "score": share.score,
-        "has_password": bool(share.password),
-        "user_id": str(share.user_id),
-        "username": user.email if user else None,
-        "object": {
-            "id": str(obj.id),
-            "name": obj.name,
-            "type": obj.type.value,
-            "size": obj.size,
-        } if obj else None,
-        "created_at": share.created_at.isoformat(),
-    })
+    return ShareDetailResponse(
+        id=share.id,
+        code=share.code,
+        views=share.views,
+        downloads=share.downloads,
+        remain_downloads=share.remain_downloads,
+        expires=share.expires,
+        preview_enabled=share.preview_enabled,
+        score=share.score,
+        has_password=bool(share.password),
+        user_id=str(share.user_id),
+        username=user.email if user else None,
+        object={
+            "id": str(obj.id),
+            "name": obj.name,
+            "type": obj.type.value,
+            "size": obj.size,
+        } if obj else None,
+        created_at=share.created_at.isoformat(),
+    )

 @admin_share_router.delete(
     path='/{share_id}',
     summary='删除分享',
     description='Delete share by ID',
-    dependencies=[Depends(admin_required)]
+    dependencies=[Depends(admin_required)],
+    status_code=204,
 )
 async def router_admin_delete_share(
     session: SessionDep,
-    share_id: int,
-) -> ResponseBase:
+    share_id: UUID,
+) -> None:
     """
     删除分享。
@@ -108,11 +155,8 @@ async def router_admin_delete_share(
     :param share_id: 分享ID
     :return: 删除结果
     """
-    share = await Share.get(session, Share.id == share_id)
-    if not share:
-        raise HTTPException(status_code=404, detail="分享不存在")
+    share = await Share.get_exist_one(session, share_id)

     await Share.delete(session, share)
     l.info(f"管理员删除了分享: {share.code}")
-    return ResponseBase(data={"deleted": True})

View File

@@ -1,3 +1,4 @@
+from typing import Any
 from uuid import UUID

 from fastapi import APIRouter, Depends, HTTPException
@@ -6,9 +7,44 @@ from loguru import logger as l
 from middleware.auth import admin_required
 from middleware.dependencies import SessionDep, TableViewRequestDep
 from sqlmodels import (
-    ResponseBase, ListResponse,
-    Task, TaskSummary,
+    ListResponse,
+    Task, TaskSummary, TaskStatus, TaskType,
 )
+from sqlmodel_ext import SQLModelBase

+class TaskDetailResponse(SQLModelBase):
+    """任务详情响应"""
+
+    id: int
+    """任务ID"""
+    status: TaskStatus
+    """任务状态"""
+    type: TaskType
+    """任务类型"""
+    progress: int
+    """任务进度"""
+    error: str | None
+    """错误信息"""
+    user_id: str
+    """用户UUID"""
+    username: str | None
+    """用户名"""
+    props: dict[str, Any] | None
+    """任务属性"""
+    created_at: str
+    """创建时间"""
+    updated_at: str
+    """更新时间"""

 admin_task_router = APIRouter(
     prefix='/task',
@@ -67,7 +103,7 @@ async def router_admin_get_task_list(
 async def router_admin_get_task(
     session: SessionDep,
     task_id: int,
-) -> ResponseBase:
+) -> TaskDetailResponse:
     """
     获取任务详情。
@@ -82,30 +118,31 @@ async def router_admin_get_task(
     user = await task.awaitable_attrs.user
     props = await task.awaitable_attrs.props

-    return ResponseBase(data={
-        "id": task.id,
-        "status": task.status,
-        "type": task.type,
-        "progress": task.progress,
-        "error": task.error,
-        "user_id": str(task.user_id),
-        "username": user.email if user else None,
-        "props": props.model_dump() if props else None,
-        "created_at": task.created_at.isoformat(),
-        "updated_at": task.updated_at.isoformat(),
-    })
+    return TaskDetailResponse(
+        id=task.id,
+        status=task.status,
+        type=task.type,
+        progress=task.progress,
+        error=task.error,
+        user_id=str(task.user_id),
+        username=user.email if user else None,
+        props=props.model_dump() if props else None,
+        created_at=task.created_at.isoformat(),
+        updated_at=task.updated_at.isoformat(),
+    )

 @admin_task_router.delete(
     path='/{task_id}',
     summary='删除任务',
     description='Delete task by ID',
-    dependencies=[Depends(admin_required)]
+    dependencies=[Depends(admin_required)],
+    status_code=204,
 )
 async def router_admin_delete_task(
     session: SessionDep,
     task_id: int,
-) -> ResponseBase:
+) -> None:
     """
     删除任务。
@@ -113,11 +150,8 @@ async def router_admin_delete_task(
     :param task_id: 任务ID
     :return: 删除结果
     """
-    task = await Task.get(session, Task.id == task_id)
-    if not task:
-        raise HTTPException(status_code=404, detail="任务不存在")
+    task = await Task.get_exist_one(session, task_id)

     await Task.delete(session, task)
     l.info(f"管理员删除了任务: {task_id}")
-    return ResponseBase(data={"deleted": True})

View File

@@ -0,0 +1,187 @@
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, status
+from loguru import logger as l
+from sqlalchemy import update as sql_update
+
+from middleware.auth import admin_required
+from middleware.dependencies import SessionDep
+from sqlmodels import (
+    ThemePreset,
+    ThemePresetCreateRequest,
+    ThemePresetUpdateRequest,
+    ThemePresetResponse,
+    ThemePresetListResponse,
+)
+from utils import http_exceptions
+
+admin_theme_router = APIRouter(
+    prefix="/theme",
+    tags=["admin", "admin_theme"],
+    dependencies=[Depends(admin_required)],
+)
+
+@admin_theme_router.get(
+    path='/',
+    summary='获取主题预设列表',
+)
+async def router_admin_theme_list(session: SessionDep) -> ThemePresetListResponse:
+    """
+    获取所有主题预设列表
+
+    认证:需要管理员权限
+
+    响应:
+    - ThemePresetListResponse: 包含所有主题预设的列表
+    """
+    presets: list[ThemePreset] = await ThemePreset.get(session, fetch_mode="all")
+    return ThemePresetListResponse(
+        themes=[ThemePresetResponse.from_preset(p) for p in presets]
+    )
+
+@admin_theme_router.post(
+    path='/',
+    summary='创建主题预设',
+    status_code=status.HTTP_204_NO_CONTENT,
+)
+async def router_admin_theme_create(
+    session: SessionDep,
+    request: ThemePresetCreateRequest,
+) -> None:
+    """
+    创建新的主题预设
+
+    认证:需要管理员权限
+
+    请求体:
+    - name: 预设名称(唯一)
+    - colors: 颜色配置对象
+
+    错误处理:
+    - 409: 名称已存在
+    """
+    # 检查名称唯一性
+    existing = await ThemePreset.get(session, ThemePreset.name == request.name)
+    if existing:
+        http_exceptions.raise_conflict(f"主题预设名称 '{request.name}' 已存在")
+
+    preset = ThemePreset(
+        name=request.name,
+        **request.colors.model_dump(),
+    )
+    preset = await preset.save(session)
+    l.info(f"管理员创建了主题预设: {request.name}")
+
+@admin_theme_router.patch(
+    path='/{preset_id}',
+    summary='更新主题预设',
+    status_code=status.HTTP_204_NO_CONTENT,
+)
+async def router_admin_theme_update(
+    session: SessionDep,
+    preset_id: UUID,
+    request: ThemePresetUpdateRequest,
+) -> None:
+    """
+    部分更新主题预设
+
+    认证:需要管理员权限
+
+    路径参数:
+    - preset_id: 预设UUID
+
+    请求体(均可选):
+    - name: 预设名称
+    - colors: 颜色配置对象
+
+    错误处理:
+    - 404: 预设不存在
+    - 409: 名称已被其他预设使用
+    """
+    preset = await ThemePreset.get_exist_one(session, preset_id)
+
+    # 检查名称唯一性(排除自身)
+    if request.name is not None and request.name != preset.name:
+        existing = await ThemePreset.get(session, ThemePreset.name == request.name)
+        if existing:
+            http_exceptions.raise_conflict(f"主题预设名称 '{request.name}' 已存在")
+        preset.name = request.name
+
+    # 更新颜色字段
+    if request.colors is not None:
+        color_data = request.colors.model_dump()
+        for key, value in color_data.items():
+            setattr(preset, key, value)
+
+    preset = await preset.save(session)
+    l.info(f"管理员更新了主题预设: {preset.name}")
+
+@admin_theme_router.delete(
+    path='/{preset_id}',
+    summary='删除主题预设',
+    status_code=status.HTTP_204_NO_CONTENT,
+)
+async def router_admin_theme_delete(
+    session: SessionDep,
+    preset_id: UUID,
+) -> None:
+    """
+    删除主题预设
+
+    认证:需要管理员权限
+
+    路径参数:
+    - preset_id: 预设UUID
+
+    错误处理:
+    - 404: 预设不存在
+
+    副作用:
+    - 关联用户的 theme_preset_id 会被数据库 SET NULL
+    """
+    preset = await ThemePreset.get_exist_one(session, preset_id)
+    await preset.delete(session)
+    l.info(f"管理员删除了主题预设: {preset.name}")
+
+@admin_theme_router.patch(
+    path='/{preset_id}/default',
+    summary='设为默认主题预设',
+    status_code=status.HTTP_204_NO_CONTENT,
+)
+async def router_admin_theme_set_default(
+    session: SessionDep,
+    preset_id: UUID,
+) -> None:
+    """
+    将指定预设设为默认主题
+
+    认证:需要管理员权限
+
+    路径参数:
+    - preset_id: 预设UUID
+
+    错误处理:
+    - 404: 预设不存在
+
+    逻辑:
+    - 事务中先清除所有旧默认,再设新默认
+    """
+    preset = await ThemePreset.get_exist_one(session, preset_id)
+
+    # 清除所有旧默认
+    await session.execute(
+        sql_update(ThemePreset)
+        .where(ThemePreset.is_default == True)  # noqa: E712
+        .values(is_default=False)
+    )
+
+    # 设新默认
+    preset.is_default = True
+    preset = await preset.save(session)
+    l.info(f"管理员将主题预设 '{preset.name}' 设为默认")
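`router_admin_theme_set_default` clears every old default before marking the new one, so at most one preset is ever the default. The same invariant in a plain in-memory sketch, with dicts standing in for ThemePreset rows:

```python
def set_default(presets: list[dict], preset_id: str) -> None:
    """Clear all old defaults, then mark the requested preset."""
    target = next((p for p in presets if p["id"] == preset_id), None)
    if target is None:
        raise LookupError(preset_id)  # analogue of get_exist_one's 404
    for p in presets:  # analogue of the bulk UPDATE ... SET is_default = False
        p["is_default"] = False
    target["is_default"] = True

presets = [
    {"id": "a", "is_default": True},
    {"id": "b", "is_default": False},
]
set_default(presets, "b")
```

Doing the clear and the set inside one database transaction (as the route does) is what prevents a window where zero or two presets are marked default.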

View File

@@ -12,6 +12,7 @@ from sqlmodels import (
     Group, Object, ObjectType, Setting, SettingsType,
     BatchDeleteRequest,
 )
+from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
 from sqlmodels.user import (
     UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
 )
@@ -83,14 +84,27 @@ async def router_admin_create_user(
     """
     创建一个新的用户,设置邮箱、密码、用户组等信息。

+    管理员创建用户时,若提供了 email + password
+    会同时创建 AuthIdentity(provider=email_password)。
+
     :param session: 数据库会话
     :param request: 创建用户请求 DTO
     :return: 创建结果
     """
-    existing_user = await User.get(session, User.email == request.email)
-    if existing_user:
-        raise HTTPException(status_code=409, detail="该邮箱已被注册")
+    # 如果提供了邮箱,检查唯一性(User 表和 AuthIdentity 表)
+    if request.email:
+        existing_user = await User.get(session, User.email == request.email)
+        if existing_user:
+            raise HTTPException(status_code=409, detail="该邮箱已被注册")
+        existing_identity = await AuthIdentity.get(
+            session,
+            (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
+            & (AuthIdentity.identifier == request.email),
+        )
+        if existing_identity:
+            raise HTTPException(status_code=409, detail="该邮箱已被绑定")

     # 验证用户组存在
     group = await Group.get(session, Group.id == request.group_id)
     if not group:
@@ -98,12 +112,25 @@ async def router_admin_create_user(
     user = User(
         email=request.email,
-        password=Password.hash(request.password),
         nickname=request.nickname,
         group_id=request.group_id,
         status=request.status,
     )
     user = await user.save(session)

+    # 如果提供了邮箱和密码,创建邮箱密码认证身份
+    if request.email and request.password:
+        identity = AuthIdentity(
+            provider=AuthProviderType.EMAIL_PASSWORD,
+            identifier=request.email,
+            credential=Password.hash(request.password),
+            is_primary=True,
+            is_verified=True,
+            user_id=user.id,
+        )
+        identity = await identity.save(session)
+
     user = await User.get(session, User.id == user.id, load=User.group)
     return user.to_public()
@@ -127,9 +154,7 @@ async def router_admin_update_user(
     :param request: 更新请求
     :return: 更新结果
     """
-    user = await User.get(session, User.id == user_id)
-    if not user:
-        raise HTTPException(status_code=404, detail="用户不存在")
+    user = await User.get_exist_one(session, user_id)

     # 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
     default_admin_setting = await Setting.get(
@@ -148,17 +173,7 @@ async def router_admin_update_user(
         if not group:
             raise HTTPException(status_code=400, detail="目标用户组不存在")

-    # 如果更新密码,需要加密
     update_data = request.model_dump(exclude_unset=True)
-    if 'password' in update_data and update_data['password']:
-        update_data['password'] = Password.hash(update_data['password'])
-    elif 'password' in update_data:
-        del update_data['password']  # 空密码不更新
-
-    # 验证两步验证密钥格式(如果提供了值且不为 None,长度必须为 32)
-    if 'two_factor' in update_data and update_data['two_factor'] is not None:
-        if len(update_data['two_factor']) != 32:
-            raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")

     # 记录旧 status 以便检测变更
     old_status = user.status
@@ -175,7 +190,7 @@ async def router_admin_update_user(
     elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
         await UserBanStore.unban(str(user_id))

-    l.info(f"管理员更新了用户: {request.email}")
+    l.info(f"管理员更新了用户: {user.email}")

 @admin_user_router.delete(
@@ -198,9 +213,17 @@ async def router_admin_delete_users(
     :param request: 批量删除请求,包含待删除用户的 UUID 列表
     :return: 删除结果(已删除数 / 总请求数)
     """
     deleted = 0
     for uid in request.ids:
-        user = await User.get(session, User.id == uid)
+        user = await User.get(session, User.id == uid, load=User.group)
+        # 安全检查:默认管理员不允许被删除(通过 Setting 中的 default_admin_id 识别)
+        default_admin_setting = await Setting.get(
+            session,
+            (Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
+        )
+        if user and default_admin_setting and default_admin_setting.value == str(uid):
+            raise HTTPException(status_code=403, detail=f"默认管理员不允许被删除")
         if user:
             await User.delete(session, user)
             l.info(f"管理员删除了用户: {user.email}")
@@ -228,13 +251,12 @@ async def router_admin_calibrate_storage(
     :param user_id: 用户UUID
     :return: 校准结果
     """
-    user = await User.get(session, User.id == user_id)
-    if not user:
-        raise HTTPException(status_code=404, detail="用户不存在")
+    user = await User.get_exist_one(session, user_id)

     previous_storage = user.storage

     # 计算实际存储量 - 使用 SQL 聚合
+    # [TODO] 不应这么计算,看看 SQLModel_Ext 库怎么解决
     from sqlmodel import select
     result = await session.execute(
         select(func.sum(Object.size), func.count(Object.id)).where(
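The calibration endpoint recomputes a user's real usage with a single SQL aggregate (`sum(Object.size)`, `count(Object.id)`). A pure-Python equivalent of that aggregation, with dicts standing in for Object rows:

```python
def calibrate_storage(objects: list[dict], user_id: str) -> tuple[int, int]:
    """Return (total_bytes, object_count) for one user, like the SELECT sum/count."""
    sizes = [o["size"] for o in objects if o["user_id"] == user_id]
    return sum(sizes), len(sizes)

rows = [
    {"user_id": "u1", "size": 100},
    {"user_id": "u1", "size": 250},
    {"user_id": "u2", "size": 999},
]
total, count = calibrate_storage(rows, "u1")
```

Pushing the aggregation into SQL, as the route does, avoids loading every Object row just to sum one column.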

View File

@@ -1,81 +0,0 @@
-from uuid import UUID
-
-from fastapi import APIRouter, Depends, HTTPException
-
-from middleware.auth import admin_required
-from middleware.dependencies import SessionDep
-from sqlmodels import (
-    ResponseBase,
-)
-
-admin_vas_router = APIRouter(
-    prefix='/vas',
-    tags=['admin', 'admin_vas']
-)
-
-@admin_vas_router.get(
-    path='/list',
-    summary='获取增值服务列表',
-    description='Get VAS list (orders and storage packs)',
-    dependencies=[Depends(admin_required)]
-)
-async def router_admin_get_vas_list(
-    session: SessionDep,
-    user_id: UUID | None = None,
-    page: int = 1,
-    page_size: int = 20,
-) -> ResponseBase:
-    """
-    获取增值服务列表(订单和存储包)。
-
-    :param session: 数据库会话
-    :param user_id: 按用户筛选
-    :param page: 页码
-    :param page_size: 每页数量
-    :return: 增值服务列表
-    """
-    # TODO: 实现增值服务列表
-    # 需要查询 Order 和 StoragePack 模型
-    raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
-
-@admin_vas_router.get(
-    path='/{vas_id}',
-    summary='获取增值服务详情',
-    description='Get VAS detail by ID',
-    dependencies=[Depends(admin_required)]
-)
-async def router_admin_get_vas(
-    session: SessionDep,
-    vas_id: UUID,
-) -> ResponseBase:
-    """
-    获取增值服务详情。
-
-    :param session: 数据库会话
-    :param vas_id: 增值服务UUID
-    :return: 增值服务详情
-    """
-    # TODO: 实现增值服务详情
-    raise HTTPException(status_code=501, detail="增值服务管理暂未实现")
-
-@admin_vas_router.delete(
-    path='/{vas_id}',
-    summary='删除增值服务',
-    description='Delete VAS by ID',
-    dependencies=[Depends(admin_required)]
-)
-async def router_admin_delete_vas(
-    session: SessionDep,
-    vas_id: UUID,
-) -> ResponseBase:
-    """
-    删除增值服务。
-
-    :param session: 数据库会话
-    :param vas_id: 增值服务UUID
-    :return: 删除结果
-    """
-    # TODO: 实现增值服务删除
-    raise HTTPException(status_code=501, detail="增值服务管理暂未实现")

View File

@@ -1,5 +1,6 @@
 from fastapi import APIRouter, Query
 from fastapi.responses import PlainTextResponse
+from loguru import logger as l
 from sqlmodels import ResponseBase
 import service.oauth
@@ -15,18 +16,12 @@ oauth_router = APIRouter(
     tags=["callback", "oauth"],
 )
-pay_router = APIRouter(
-    prefix='/callback/pay',
-    tags=["callback", "pay"],
-)
 upload_router = APIRouter(
     prefix='/callback/upload',
     tags=["callback", "upload"],
 )
 callback_router.include_router(oauth_router)
-callback_router.include_router(pay_router)
 callback_router.include_router(upload_router)
 @oauth_router.post(
@@ -64,91 +59,17 @@ async def router_callback_github(
     """
     try:
         access_token = await service.oauth.github.get_access_token(code)
-        # [TODO] 把access_token写数据库里
         if not access_token:
-            return PlainTextResponse("Failed to retrieve access token from GitHub.", status_code=400)
+            return PlainTextResponse("GitHub 认证失败", status_code=400)
         user_data = await service.oauth.github.get_user_info(access_token.access_token)
-        # [TODO] 把user_data写数据库
-        return PlainTextResponse(f"User information processed successfully, code: {code}, user_data: {user_data.json_dump()}", status_code=200)
+        # [TODO] 把 access_token 和 user_data 写数据库,生成 JWT,重定向到前端
+        l.info(f"GitHub OAuth 回调成功: user={user_data.user_data.login}")
+        return PlainTextResponse("认证成功,功能开发中", status_code=200)
     except Exception as e:
-        return PlainTextResponse(f"An error occurred: {str(e)}", status_code=500)
+        l.error(f"GitHub OAuth 回调异常: {e}")
+        return PlainTextResponse("认证过程中发生错误,请重试", status_code=500)
-@pay_router.post(
-    path='/alipay',
-    summary='支付宝支付回调',
-    description='Handle Alipay payment callback and return payment status.',
-)
-def router_callback_alipay() -> ResponseBase:
-    """
-    Handle Alipay payment callback and return payment status.
-    Returns:
-        ResponseBase: A model containing the response data for the Alipay payment callback.
-    """
-    http_exceptions.raise_not_implemented()
-@pay_router.post(
-    path='/wechat',
-    summary='微信支付回调',
-    description='Handle WeChat Pay payment callback and return payment status.',
-)
-def router_callback_wechat() -> ResponseBase:
-    """
-    Handle WeChat Pay payment callback and return payment status.
-    Returns:
-        ResponseBase: A model containing the response data for the WeChat Pay payment callback.
-    """
-    http_exceptions.raise_not_implemented()
-@pay_router.post(
-    path='/stripe',
-    summary='Stripe支付回调',
-    description='Handle Stripe payment callback and return payment status.',
-)
-def router_callback_stripe() -> ResponseBase:
-    """
-    Handle Stripe payment callback and return payment status.
-    Returns:
-        ResponseBase: A model containing the response data for the Stripe payment callback.
-    """
-    http_exceptions.raise_not_implemented()
-@pay_router.get(
-    path='/easypay',
-    summary='易支付回调',
-    description='Handle EasyPay payment callback and return payment status.',
-)
-def router_callback_easypay() -> PlainTextResponse:
-    """
-    Handle EasyPay payment callback and return payment status.
-    Returns:
-        PlainTextResponse: A response containing the payment status for the EasyPay payment callback.
-    """
-    http_exceptions.raise_not_implemented()
-    # return PlainTextResponse("success", status_code=200)
-@pay_router.get(
-    path='/custom/{order_no}/{id}',
-    summary='自定义支付回调',
-    description='Handle custom payment callback and return payment status.',
-)
-def router_callback_custom(order_no: str, id: str) -> ResponseBase:
-    """
-    Handle custom payment callback and return payment status.
-    Args:
-        order_no (str): The order number for the payment.
-        id (str): The ID associated with the payment.
-    Returns:
-        ResponseBase: A model containing the response data for the custom payment callback.
-    """
-    http_exceptions.raise_not_implemented()
 @upload_router.post(
     path='/remote/{session_id}/{key}',

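The reworked GitHub callback above reduces to a token-exchange-then-lookup flow: exchange the `code` for an access token, fail with 400 if the exchange returns nothing, look up the user, and map any unexpected error to 500. A minimal sketch of that decision logic, with hypothetical stand-ins for the `service.oauth.github` calls (not the project's real API):

```python
from dataclasses import dataclass


@dataclass
class CallbackResult:
    """Toy stand-in for the PlainTextResponse returned by the handler."""
    status_code: int
    body: str


def handle_github_callback(code, get_access_token, get_user_info) -> CallbackResult:
    """Sketch of the callback's control flow; the two callables are injected
    stand-ins for service.oauth.github.get_access_token / get_user_info."""
    try:
        token = get_access_token(code)
        if not token:
            # token exchange failed -> 400, as in the new handler
            return CallbackResult(400, "GitHub auth failed")
        user = get_user_info(token)
        # the diff leaves persisting token/user and issuing a JWT as a TODO
        return CallbackResult(200, f"authenticated as {user}")
    except Exception:
        # any unexpected failure is reported as 500
        return CallbackResult(500, "error during authentication")
```

The injected callables make the branch structure testable without touching the network.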

@@ -0,0 +1,100 @@
"""
文件分类筛选端点
按文件类型分类(图片/视频/音频/文档)查询用户的所有文件,
跨目录搜索,支持分页。扩展名映射从数据库 Setting 表读取。
"""
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from sqlmodels import (
FileCategory,
ListResponse,
Object,
ObjectResponse,
ObjectType,
Setting,
SettingsType,
User,
)
category_router = APIRouter(
prefix="/category",
tags=["category"],
)
@category_router.get(
path="/{category}",
summary="按分类获取文件列表",
)
async def router_category_list(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
category: FileCategory,
table_view: TableViewRequestDep,
) -> ListResponse[ObjectResponse]:
"""
按文件类型分类查询用户的所有文件
跨所有目录搜索,返回分页结果。
扩展名配置从数据库 Setting 表读取type=file_category
认证:
- JWT token in Authorization header
路径参数:
- category: 文件分类image / video / audio / document
查询参数:
- offset: 分页偏移量默认0
- limit: 每页数量默认20最大100
- desc: 是否降序默认true
- order: 排序字段created_at / updated_at
响应:
- ListResponse[ObjectResponse]: 分页文件列表
错误处理:
- HTTPException 422: category 参数无效
- HTTPException 404: 该分类未配置扩展名
"""
# 从数据库读取该分类的扩展名配置
setting = await Setting.get(
session,
(Setting.type == SettingsType.FILE_CATEGORY) & (Setting.name == category.value),
)
if not setting or not setting.value:
raise HTTPException(status_code=404, detail=f"分类 {category.value} 未配置扩展名")
extensions = [ext.strip() for ext in setting.value.split(",") if ext.strip()]
if not extensions:
raise HTTPException(status_code=404, detail=f"分类 {category.value} 扩展名列表为空")
result = await Object.get_by_category(
session,
user.id,
extensions,
table_view=table_view,
)
items = [
ObjectResponse(
id=obj.id,
name=obj.name,
type=ObjectType.FILE,
size=obj.size,
mime_type=obj.mime_type,
thumb=False,
created_at=obj.created_at,
updated_at=obj.updated_at,
source_enabled=False,
)
for obj in result.items
]
return ListResponse(count=result.count, items=items)
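The endpoint above stores each category's extensions as a comma-separated string in the Setting table and parses it with a strip-and-filter comprehension. A small self-contained sketch of that parsing, plus an illustrative suffix check (`matches_category` is an assumption for demonstration, not a helper from the project):

```python
def parse_extensions(value: str) -> list[str]:
    """Mirror of the endpoint's parsing of Setting.value:
    split on commas, strip whitespace, drop empty entries."""
    return [ext.strip() for ext in value.split(",") if ext.strip()]


def matches_category(filename: str, extensions: list[str]) -> bool:
    """Hypothetical membership test: compare the final suffix,
    case-insensitively, against the configured extension list."""
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    return ext in extensions


image_exts = parse_extensions("jpg, png, gif,,webp")
```

Empty entries from stray commas are dropped, so a sloppily edited setting value still parses cleanly.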


@@ -57,7 +57,7 @@ async def _get_directory_response(
     policy_response = PolicyResponse(
         id=policy.id,
         name=policy.name,
-        type=policy.type.value,
+        type=policy.type,
         max_size=policy.max_size,
     )
@@ -139,12 +139,13 @@ async def router_directory_get(
 @directory_router.post(
     path="/",
     summary="创建目录",
+    status_code=204,
 )
 async def router_directory_create(
     session: SessionDep,
     user: Annotated[User, Depends(auth_required)],
     request: DirectoryCreateRequest
-) -> ResponseBase:
+) -> None:
     """
     创建目录
@@ -162,8 +163,11 @@ async def router_directory_create(
     if "/" in name or "\\" in name:
         raise HTTPException(status_code=400, detail="目录名称不能包含斜杠")
-    # 通过 UUID 获取父目录
-    parent = await Object.get(session, Object.id == request.parent_id)
+    # 通过 UUID 获取父目录(排除已删除的)
+    parent = await Object.get(
+        session,
+        (Object.id == request.parent_id) & (Object.deleted_at == None)
+    )
     if not parent or parent.owner_id != user.id:
         raise HTTPException(status_code=404, detail="父目录不存在")
@@ -173,17 +177,26 @@ async def router_directory_create(
     if parent.is_banned:
         http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
-    # 检查是否已存在同名对象
+    # 检查是否已存在同名对象(仅检查未删除的)
     existing = await Object.get(
         session,
         (Object.owner_id == user.id) &
         (Object.parent_id == parent.id) &
-        (Object.name == name)
+        (Object.name == name) &
+        (Object.deleted_at == None)
     )
     if existing:
         raise HTTPException(status_code=409, detail="同名文件或目录已存在")
     policy_id = request.policy_id if request.policy_id else parent.policy_id
+    # 校验用户组是否有权使用该策略(仅当用户显式指定 policy_id 时)
+    if request.policy_id:
+        group = await user.awaitable_attrs.group
+        await session.refresh(group, ['policies'])
+        if request.policy_id not in {p.id for p in group.policies}:
+            raise HTTPException(status_code=403, detail="当前用户组无权使用该存储策略")
     parent_id = parent.id  # 在 save 前保存
     new_folder = Object(
@@ -193,14 +206,4 @@ async def router_directory_create(
         parent_id=parent_id,
         policy_id=policy_id,
     )
-    new_folder_id = new_folder.id  # 在 save 前保存 UUID
-    new_folder_name = new_folder.name
-    await new_folder.save(session)
-    return ResponseBase(
-        data={
-            "id": new_folder_id,
-            "name": new_folder_name,
-            "parent_id": parent_id,
-        }
-    )
+    new_folder = await new_folder.save(session)

File diff suppressed because it is too large


@@ -0,0 +1,106 @@
"""
文件查看器查询端点
提供按文件扩展名查询可用查看器的功能,包含用户组访问控制过滤。
"""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from sqlalchemy import and_
from sqlalchemy import select
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
FileApp,
FileAppExtension,
FileAppGroupLink,
FileAppSummary,
FileViewersResponse,
User,
UserFileAppDefault,
)
viewers_router = APIRouter(prefix="/viewers", tags=["file", "viewers"])
@viewers_router.get(
path='',
summary='查询可用文件查看器',
description='根据文件扩展名查询可用的查看器应用列表。',
)
async def get_viewers(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
ext: Annotated[str, Query(max_length=20, description="文件扩展名")],
) -> FileViewersResponse:
"""
查询可用文件查看器端点
流程:
1. 规范化扩展名(小写,去点号)
2. 查询匹配的已启用应用
3. 按用户组权限过滤
4. 按 priority 排序
5. 查询用户默认偏好
认证JWT token 必填
错误处理:
- 401: 未授权
"""
# 规范化扩展名
normalized_ext = ext.lower().strip().lstrip('.')
# 查询匹配扩展名的应用(已启用的)
ext_records: list[FileAppExtension] = await FileAppExtension.get(
session,
and_(
FileAppExtension.extension == normalized_ext,
),
fetch_mode="all",
load=FileAppExtension.app,
)
# 过滤和收集可用应用
user_group_id = user.group_id
viewers: list[tuple[FileAppSummary, int]] = []
for ext_record in ext_records:
app: FileApp = ext_record.app
if not app.is_enabled:
continue
if app.is_restricted:
# 检查用户组权限FileAppGroupLink 是纯关联表,使用 session 查询)
stmt = select(FileAppGroupLink).where(
and_(
FileAppGroupLink.app_id == app.id,
FileAppGroupLink.group_id == user_group_id,
)
)
result = await session.exec(stmt)
group_link = result.first()
if not group_link:
continue
viewers.append((app.to_summary(), ext_record.priority))
# 按 priority 排序
viewers.sort(key=lambda x: x[1])
# 查询用户默认偏好
user_default: UserFileAppDefault | None = await UserFileAppDefault.get(
session,
and_(
UserFileAppDefault.user_id == user.id,
UserFileAppDefault.extension == normalized_ext,
),
)
return FileViewersResponse(
viewers=[v[0] for v in viewers],
default_viewer_id=user_default.app_id if user_default else None,
)
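The filtering loop above can be exercised in isolation: keep only enabled apps, drop restricted apps the user's group is not linked to, then order by ascending priority. A plain-data sketch of that logic, where the record/app dictionaries and `allowed_groups` mapping are invented stand-ins for the ORM rows:

```python
def pick_viewers(records, user_group_id, allowed_groups) -> list[str]:
    """Sketch of the viewer-selection step.

    records: list of {"app": {...}, "priority": int} dicts (toy rows).
    allowed_groups: app_id -> set of group ids, standing in for
    the FileAppGroupLink association table.
    Returns app ids ordered by ascending priority.
    """
    viewers = []
    for rec in records:
        app = rec["app"]
        if not app["is_enabled"]:
            continue  # disabled apps never appear
        if app["is_restricted"] and user_group_id not in allowed_groups.get(app["id"], set()):
            continue  # restricted app without a group link is filtered out
        viewers.append((app["id"], rec["priority"]))
    viewers.sort(key=lambda x: x[1])  # lower priority value sorts first
    return [app_id for app_id, _ in viewers]
```

Because `list.sort` is stable, apps sharing a priority keep their query order, which matches the endpoint's behavior.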


@@ -8,14 +8,14 @@
from typing import Annotated from typing import Annotated
from uuid import UUID from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from loguru import logger as l from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from sqlmodels import ( from sqlmodels import (
CreateFileRequest, CreateFileRequest,
Group,
Object, Object,
ObjectCopyRequest, ObjectCopyRequest,
ObjectDeleteRequest, ObjectDeleteRequest,
@@ -23,170 +23,54 @@ from sqlmodels import (
ObjectPropertyDetailResponse, ObjectPropertyDetailResponse,
ObjectPropertyResponse, ObjectPropertyResponse,
ObjectRenameRequest, ObjectRenameRequest,
ObjectSwitchPolicyRequest,
ObjectType, ObjectType,
PhysicalFile, PhysicalFile,
Policy, Policy,
PolicyType, PolicyType,
ResponseBase, Task,
TaskProps,
TaskStatus,
TaskSummaryBase,
TaskType,
User, User,
# 元数据相关
ObjectMetadata,
MetadataResponse,
MetadataPatchRequest,
INTERNAL_NAMESPACES,
USER_WRITABLE_NAMESPACES,
) )
from service.storage import LocalStorageService from service.storage import (
LocalStorageService,
adjust_user_storage,
copy_object_recursive,
migrate_file_with_task,
migrate_directory_files,
)
from service.storage.object import soft_delete_objects
from sqlmodels.database_connection import DatabaseManager
from utils import http_exceptions from utils import http_exceptions
from .custom_property import router as custom_property_router
object_router = APIRouter( object_router = APIRouter(
prefix="/object", prefix="/object",
tags=["object"] tags=["object"]
) )
object_router.include_router(custom_property_router)
async def _delete_object_recursive(
session: AsyncSession,
obj: Object,
user_id: UUID,
) -> int:
"""
递归删除对象(软删除)
对于文件:
- 减少 PhysicalFile 引用计数
- 只有引用计数为0时才移动物理文件到回收站
对于目录:
- 递归处理所有子对象
:param session: 数据库会话
:param obj: 要删除的对象
:param user_id: 用户UUID
:return: 删除的对象数量
"""
deleted_count = 0
# 在任何数据库操作前保存所有需要的属性,避免 commit 后对象过期导致懒加载失败
obj_id = obj.id
obj_name = obj.name
obj_is_folder = obj.is_folder
obj_is_file = obj.is_file
obj_physical_file_id = obj.physical_file_id
if obj_is_folder:
# 递归删除子对象
children = await Object.get_children(session, user_id, obj_id)
for child in children:
deleted_count += await _delete_object_recursive(session, child, user_id)
# 如果是文件,处理物理文件引用
if obj_is_file and obj_physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj_physical_file_id)
if physical_file:
# 减少引用计数
new_count = physical_file.decrement_reference()
if physical_file.can_be_deleted:
# 引用计数为0移动物理文件到回收站
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy and policy.type == PolicyType.LOCAL:
try:
storage_service = LocalStorageService(policy)
await storage_service.move_to_trash(
source_path=physical_file.storage_path,
user_id=user_id,
object_id=obj_id,
)
l.debug(f"物理文件已移动到回收站: {obj_name}")
except Exception as e:
l.warning(f"移动物理文件到回收站失败: {obj_name}, 错误: {e}")
# 删除 PhysicalFile 记录
await PhysicalFile.delete(session, physical_file)
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
else:
# 还有其他引用,只更新引用计数
await physical_file.save(session)
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
# 使用条件删除,避免访问过期的 obj 实例
await Object.delete(session, condition=Object.id == obj_id)
deleted_count += 1
return deleted_count
async def _copy_object_recursive(
session: AsyncSession,
src: Object,
dst_parent_id: UUID,
user_id: UUID,
) -> tuple[int, list[UUID]]:
"""
递归复制对象
对于文件:
- 增加 PhysicalFile 引用计数
- 创建新的 Object 记录指向同一 PhysicalFile
对于目录:
- 创建新目录
- 递归复制所有子对象
:param session: 数据库会话
:param src: 源对象
:param dst_parent_id: 目标父目录UUID
:param user_id: 用户UUID
:return: (复制数量, 新对象UUID列表)
"""
copied_count = 0
new_ids: list[UUID] = []
# 在 save() 之前保存需要的属性值,避免 commit 后对象过期导致懒加载失败
src_is_folder = src.is_folder
src_id = src.id
# 创建新的 Object 记录
new_obj = Object(
name=src.name,
type=src.type,
size=src.size,
password=src.password,
parent_id=dst_parent_id,
owner_id=user_id,
policy_id=src.policy_id,
physical_file_id=src.physical_file_id,
)
# 如果是文件,增加物理文件引用计数
if src.is_file and src.physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src.physical_file_id)
if physical_file:
physical_file.increment_reference()
await physical_file.save(session)
new_obj = await new_obj.save(session)
copied_count += 1
new_ids.append(new_obj.id)
# 如果是目录,递归复制子对象
if src_is_folder:
children = await Object.get_children(session, user_id, src_id)
for child in children:
child_count, child_ids = await _copy_object_recursive(
session, child, new_obj.id, user_id
)
copied_count += child_count
new_ids.extend(child_ids)
return copied_count, new_ids
@object_router.post( @object_router.post(
path='/', path='/',
summary='创建空白文件', summary='创建空白文件',
description='在指定目录下创建空白文件。', description='在指定目录下创建空白文件。',
status_code=204,
) )
async def router_object_create( async def router_object_create(
session: SessionDep, session: SessionDep,
user: Annotated[User, Depends(auth_required)], user: Annotated[User, Depends(auth_required)],
request: CreateFileRequest, request: CreateFileRequest,
) -> ResponseBase: ) -> None:
""" """
创建空白文件端点 创建空白文件端点
@@ -201,8 +85,11 @@ async def router_object_create(
if not request.name or '/' in request.name or '\\' in request.name: if not request.name or '/' in request.name or '\\' in request.name:
raise HTTPException(status_code=400, detail="无效的文件名") raise HTTPException(status_code=400, detail="无效的文件名")
# 验证父目录 # 验证父目录(排除已删除的)
parent = await Object.get(session, Object.id == request.parent_id) parent = await Object.get(
session,
(Object.id == request.parent_id) & (Object.deleted_at == None)
)
if not parent or parent.owner_id != user_id: if not parent or parent.owner_id != user_id:
raise HTTPException(status_code=404, detail="父目录不存在") raise HTTPException(status_code=404, detail="父目录不存在")
@@ -212,21 +99,20 @@ async def router_object_create(
if parent.is_banned: if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作") http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 检查是否已存在同名文件 # 检查是否已存在同名文件(仅检查未删除的)
existing = await Object.get( existing = await Object.get(
session, session,
(Object.owner_id == user_id) & (Object.owner_id == user_id) &
(Object.parent_id == parent.id) & (Object.parent_id == parent.id) &
(Object.name == request.name) (Object.name == request.name) &
(Object.deleted_at == None)
) )
if existing: if existing:
raise HTTPException(status_code=409, detail="同名文件已存在") raise HTTPException(status_code=409, detail="同名文件已存在")
# 确定存储策略 # 确定存储策略
policy_id = request.policy_id or parent.policy_id policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id) policy = await Policy.get_exist_one(session, policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
parent_id = parent.id parent_id = parent.id
@@ -261,44 +147,45 @@ async def router_object_create(
owner_id=user_id, owner_id=user_id,
policy_id=policy_id, policy_id=policy_id,
) )
await file_object.save(session) file_object = await file_object.save(session)
l.info(f"创建空白文件: {request.name}") l.info(f"创建空白文件: {request.name}")
return ResponseBase()
@object_router.delete( @object_router.delete(
path='/', path='/',
summary='删除对象', summary='删除对象',
description='删除一个或多个对象(文件或目录),文件会移动到用户回收站。', description='删除一个或多个对象(文件或目录),文件会移动到用户回收站。',
status_code=204,
) )
async def router_object_delete( async def router_object_delete(
session: SessionDep, session: SessionDep,
user: Annotated[User, Depends(auth_required)], user: Annotated[User, Depends(auth_required)],
request: ObjectDeleteRequest, request: ObjectDeleteRequest,
) -> ResponseBase: ) -> None:
""" """
删除对象端点(软删除) 删除对象端点(软删除到回收站
流程: 流程:
1. 验证对象存在且属于当前用户 1. 验证对象存在且属于当前用户
2. 对于文件,减少物理文件引用计数 2. 设置 deleted_at 时间戳
3. 如果引用计数为0移动物理文件到 .trash 目录 3. 保存原 parent_id 到 deleted_original_parent_id
4. 对于目录,递归处理子对象 4. 将 parent_id 置 NULL 脱离文件树
5. 从数据库中删除记录 5. 子对象和物理文件不做任何变更
:param session: 数据库会话 :param session: 数据库会话
:param user: 当前登录用户 :param user: 当前登录用户
:param request: 删除请求包含待删除对象的UUID列表 :param request: 删除请求包含待删除对象的UUID列表
:return: 删除结果 :return: 删除结果
""" """
# 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id user_id = user.id
deleted_count = 0 objects_to_delete: list[Object] = []
for obj_id in request.ids: for obj_id in request.ids:
obj = await Object.get(session, Object.id == obj_id) obj = await Object.get(
session,
(Object.id == obj_id) & (Object.deleted_at == None)
)
if not obj or obj.owner_id != user_id: if not obj or obj.owner_id != user_id:
continue continue
@@ -307,30 +194,24 @@ async def router_object_delete(
l.warning(f"尝试删除根目录被阻止: {obj.name}") l.warning(f"尝试删除根目录被阻止: {obj.name}")
continue continue
# 递归删除(包含引用计数逻辑) objects_to_delete.append(obj)
count = await _delete_object_recursive(session, obj, user_id)
deleted_count += count
l.info(f"用户 {user_id} 删除了 {deleted_count} 个对象") if objects_to_delete:
deleted_count = await soft_delete_objects(session, objects_to_delete)
return ResponseBase( l.info(f"用户 {user_id} 软删除了 {deleted_count} 个对象到回收站")
data={
"deleted": deleted_count,
"total": len(request.ids),
}
)
@object_router.patch( @object_router.patch(
path='/', path='/',
summary='移动对象', summary='移动对象',
description='移动一个或多个对象到目标目录', description='移动一个或多个对象到目标目录',
status_code=204,
) )
async def router_object_move( async def router_object_move(
session: SessionDep, session: SessionDep,
user: Annotated[User, Depends(auth_required)], user: Annotated[User, Depends(auth_required)],
request: ObjectMoveRequest, request: ObjectMoveRequest,
) -> ResponseBase: ) -> None:
""" """
移动对象端点 移动对象端点
@@ -342,8 +223,11 @@ async def router_object_move(
# 存储 user.id避免后续 save() 导致 user 过期后无法访问 # 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id user_id = user.id
# 验证目标目录 # 验证目标目录(排除已删除的)
dst = await Object.get(session, Object.id == request.dst_id) dst = await Object.get(
session,
(Object.id == request.dst_id) & (Object.deleted_at == None)
)
if not dst or dst.owner_id != user_id: if not dst or dst.owner_id != user_id:
raise HTTPException(status_code=404, detail="目标目录不存在") raise HTTPException(status_code=404, detail="目标目录不存在")
@@ -360,7 +244,10 @@ async def router_object_move(
moved_count = 0 moved_count = 0
for src_id in request.src_ids: for src_id in request.src_ids:
src = await Object.get(session, Object.id == src_id) src = await Object.get(
session,
(Object.id == src_id) & (Object.deleted_at == None)
)
if not src or src.owner_id != user_id: if not src or src.owner_id != user_id:
continue continue
@@ -388,12 +275,13 @@ async def router_object_move(
if is_cycle: if is_cycle:
continue continue
# 检查目标目录下是否存在同名对象 # 检查目标目录下是否存在同名对象(仅检查未删除的)
existing = await Object.get( existing = await Object.get(
session, session,
(Object.owner_id == user_id) & (Object.owner_id == user_id) &
(Object.parent_id == dst_id) & (Object.parent_id == dst_id) &
(Object.name == src.name) (Object.name == src.name) &
(Object.deleted_at == None)
) )
if existing: if existing:
continue # 跳过重名对象 continue # 跳过重名对象
@@ -405,24 +293,18 @@ async def router_object_move(
# 统一提交所有更改 # 统一提交所有更改
await session.commit() await session.commit()
return ResponseBase(
data={
"moved": moved_count,
"total": len(request.src_ids),
}
)
@object_router.post( @object_router.post(
path='/copy', path='/copy',
summary='复制对象', summary='复制对象',
description='复制一个或多个对象到目标目录。文件复制仅增加物理文件引用计数,不复制物理文件。', description='复制一个或多个对象到目标目录。文件复制仅增加物理文件引用计数,不复制物理文件。',
status_code=204,
) )
async def router_object_copy( async def router_object_copy(
session: SessionDep, session: SessionDep,
user: Annotated[User, Depends(auth_required)], user: Annotated[User, Depends(auth_required)],
request: ObjectCopyRequest, request: ObjectCopyRequest,
) -> ResponseBase: ) -> None:
""" """
复制对象端点 复制对象端点
@@ -443,8 +325,11 @@ async def router_object_copy(
# 存储 user.id避免后续 save() 导致 user 过期后无法访问 # 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id user_id = user.id
# 验证目标目录 # 验证目标目录(排除已删除的)
dst = await Object.get(session, Object.id == request.dst_id) dst = await Object.get(
session,
(Object.id == request.dst_id) & (Object.deleted_at == None)
)
if not dst or dst.owner_id != user_id: if not dst or dst.owner_id != user_id:
raise HTTPException(status_code=404, detail="目标目录不存在") raise HTTPException(status_code=404, detail="目标目录不存在")
@@ -456,20 +341,25 @@ async def router_object_copy(
copied_count = 0 copied_count = 0
new_ids: list[UUID] = [] new_ids: list[UUID] = []
total_copied_size = 0
for src_id in request.src_ids: for src_id in request.src_ids:
src = await Object.get(session, Object.id == src_id) src = await Object.get(
session,
(Object.id == src_id) & (Object.deleted_at == None)
)
if not src or src.owner_id != user_id: if not src or src.owner_id != user_id:
continue continue
if src.is_banned: if src.is_banned:
continue http_exceptions.raise_banned("源对象已被封禁,无法执行此操作")
# 不能复制根目录 # 不能复制根目录
if src.parent_id is None: if src.parent_id is None:
continue http_exceptions.raise_banned("无法复制根目录")
# 不能复制到自身 # 不能复制到自身
# [TODO] 视为创建副本
if src.id == dst.id: if src.id == dst.id:
continue continue
@@ -485,42 +375,42 @@ async def router_object_copy(
if is_cycle: if is_cycle:
continue continue
# 检查目标目录下是否存在同名对象 # 检查目标目录下是否存在同名对象(仅检查未删除的)
existing = await Object.get( existing = await Object.get(
session, session,
(Object.owner_id == user_id) & (Object.owner_id == user_id) &
(Object.parent_id == dst.id) & (Object.parent_id == dst.id) &
(Object.name == src.name) (Object.name == src.name) &
(Object.deleted_at == None)
) )
if existing: if existing:
continue # 跳过重名对象 # [TODO] 应当询问用户是否覆盖、跳过或创建副本
continue
# 递归复制 # 递归复制
count, ids = await _copy_object_recursive(session, src, dst.id, user_id) count, ids, copied_size = await copy_object_recursive(session, src, dst.id, user_id)
copied_count += count copied_count += count
new_ids.extend(ids) new_ids.extend(ids)
total_copied_size += copied_size
# 更新用户存储配额
if total_copied_size > 0:
await adjust_user_storage(session, user_id, total_copied_size)
l.info(f"用户 {user_id} 复制了 {copied_count} 个对象") l.info(f"用户 {user_id} 复制了 {copied_count} 个对象")
return ResponseBase(
data={
"copied": copied_count,
"total": len(request.src_ids),
"new_ids": new_ids,
}
)
@object_router.post( @object_router.post(
path='/rename', path='/rename',
summary='重命名对象', summary='重命名对象',
description='重命名对象(文件或目录)。', description='重命名对象(文件或目录)。',
status_code=204,
) )
async def router_object_rename( async def router_object_rename(
session: SessionDep, session: SessionDep,
user: Annotated[User, Depends(auth_required)], user: Annotated[User, Depends(auth_required)],
request: ObjectRenameRequest, request: ObjectRenameRequest,
) -> ResponseBase: ) -> None:
""" """
重命名对象端点 重命名对象端点
@@ -539,8 +429,11 @@ async def router_object_rename(
# 存储 user.id避免后续 save() 导致 user 过期后无法访问 # 存储 user.id避免后续 save() 导致 user 过期后无法访问
user_id = user.id user_id = user.id
# 验证对象存在 # 验证对象存在(排除已删除的)
obj = await Object.get(session, Object.id == request.id) obj = await Object.get(
session,
(Object.id == request.id) & (Object.deleted_at == None)
)
if not obj: if not obj:
raise HTTPException(status_code=404, detail="对象不存在") raise HTTPException(status_code=404, detail="对象不存在")
@@ -562,28 +455,27 @@ async def router_object_rename(
if '/' in new_name or '\\' in new_name: if '/' in new_name or '\\' in new_name:
raise HTTPException(status_code=400, detail="名称不能包含斜杠") raise HTTPException(status_code=400, detail="名称不能包含斜杠")
# 如果名称没有变化,直接返回成功 # 如果名称没有变化,直接返回
if obj.name == new_name: if obj.name == new_name:
return ResponseBase(data={"success": True}) return # noqa: already 204
# 检查同目录下是否存在同名对象 # 检查同目录下是否存在同名对象(仅检查未删除的)
existing = await Object.get( existing = await Object.get(
session, session,
(Object.owner_id == user_id) & (Object.owner_id == user_id) &
(Object.parent_id == obj.parent_id) & (Object.parent_id == obj.parent_id) &
(Object.name == new_name) (Object.name == new_name) &
(Object.deleted_at == None)
) )
if existing: if existing:
raise HTTPException(status_code=409, detail="同名对象已存在") raise HTTPException(status_code=409, detail="同名对象已存在")
# 更新名称 # 更新名称
obj.name = new_name obj.name = new_name
await obj.save(session) obj = await obj.save(session)
l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}") l.info(f"用户 {user_id} 将对象 {obj.id} 重命名为 {new_name}")
return ResponseBase(data={"success": True})
@object_router.get( @object_router.get(
path='/property/{id}', path='/property/{id}',
@@ -603,7 +495,10 @@ async def router_object_property(
:param id: 对象UUID :param id: 对象UUID
:return: 对象基本属性 :return: 对象基本属性
""" """
obj = await Object.get(session, Object.id == id) obj = await Object.get(
session,
(Object.id == id) & (Object.deleted_at == None)
)
if not obj: if not obj:
raise HTTPException(status_code=404, detail="对象不存在") raise HTTPException(status_code=404, detail="对象不存在")
@@ -615,6 +510,7 @@ async def router_object_property(
name=obj.name, name=obj.name,
type=obj.type, type=obj.type,
size=obj.size, size=obj.size,
mime_type=obj.mime_type,
created_at=obj.created_at, created_at=obj.created_at,
updated_at=obj.updated_at, updated_at=obj.updated_at,
parent_id=obj.parent_id, parent_id=obj.parent_id,
@@ -641,8 +537,8 @@ async def router_object_property_detail(
""" """
obj = await Object.get( obj = await Object.get(
session, session,
Object.id == id, (Object.id == id) & (Object.deleted_at == None),
load=Object.file_metadata, load=Object.metadata_entries,
) )
if not obj: if not obj:
raise HTTPException(status_code=404, detail="对象不存在") raise HTTPException(status_code=404, detail="对象不存在")
@@ -665,35 +561,301 @@ async def router_object_property_detail(
total_views = sum(s.views for s in shares) total_views = sum(s.views for s in shares)
total_downloads = sum(s.downloads for s in shares) total_downloads = sum(s.downloads for s in shares)
# 获取物理文件引用计数 # 获取物理文件信息(引用计数、校验和)
reference_count = 1 reference_count = 1
checksum_md5: str | None = None
checksum_sha256: str | None = None
if obj.physical_file_id: if obj.physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id) physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
if physical_file: if physical_file:
reference_count = physical_file.reference_count reference_count = physical_file.reference_count
checksum_md5 = physical_file.checksum_md5
checksum_sha256 = physical_file.checksum_sha256
# 构建响应 # 构建元数据字典(排除内部命名空间)
response = ObjectPropertyDetailResponse( metadata: dict[str, str] = {}
for entry in obj.metadata_entries:
ns = entry.name.split(":")[0] if ":" in entry.name else ""
if ns not in INTERNAL_NAMESPACES:
metadata[entry.name] = entry.value
return ObjectPropertyDetailResponse(
id=obj.id, id=obj.id,
name=obj.name, name=obj.name,
type=obj.type, type=obj.type,
size=obj.size, size=obj.size,
mime_type=obj.mime_type,
created_at=obj.created_at, created_at=obj.created_at,
updated_at=obj.updated_at, updated_at=obj.updated_at,
parent_id=obj.parent_id, parent_id=obj.parent_id,
checksum_md5=checksum_md5,
checksum_sha256=checksum_sha256,
policy_name=policy_name, policy_name=policy_name,
share_count=share_count, share_count=share_count,
total_views=total_views, total_views=total_views,
total_downloads=total_downloads, total_downloads=total_downloads,
reference_count=reference_count, reference_count=reference_count,
metadatas=metadata,
) )
# 添加文件元数据
if obj.file_metadata:
response.mime_type = obj.file_metadata.mime_type
response.width = obj.file_metadata.width
response.height = obj.file_metadata.height
response.duration = obj.file_metadata.duration
response.checksum_md5 = obj.file_metadata.checksum_md5
return response @object_router.patch(
path='/{object_id}/policy',
summary='切换对象存储策略',
)
async def router_object_switch_policy(
session: SessionDep,
background_tasks: BackgroundTasks,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
request: ObjectSwitchPolicyRequest,
) -> TaskSummaryBase:
"""
切换对象的存储策略
文件:立即创建后台迁移任务,将文件从源策略搬到目标策略。
目录:更新目录 policy_id新文件使用新策略
若 is_migrate_existing=True额外创建后台任务迁移所有已有文件。
认证JWT Bearer Token
错误处理:
- 404: 对象不存在
- 403: 无权操作此对象 / 用户组无权使用目标策略
- 400: 目标策略与当前相同 / 不能对根目录操作
"""
user_id = user.id
# 查找对象
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None)
)
if not obj:
http_exceptions.raise_not_found("对象不存在")
if obj.owner_id != user_id:
http_exceptions.raise_forbidden("无权操作此对象")
if obj.is_banned:
http_exceptions.raise_banned()
# 根目录不能直接切换策略(应通过子对象或子目录操作)
if obj.parent_id is None:
raise HTTPException(status_code=400, detail="不能对根目录切换存储策略,请对子目录操作")
# 校验目标策略存在
dest_policy = await Policy.get(session, Policy.id == request.policy_id)
if not dest_policy:
http_exceptions.raise_not_found("目标存储策略不存在")
# 校验用户组权限
group: Group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
allowed_ids = {p.id for p in group.policies}
if request.policy_id not in allowed_ids:
http_exceptions.raise_forbidden("当前用户组无权使用该存储策略")
# 不能切换到相同策略
if obj.policy_id == request.policy_id:
raise HTTPException(status_code=400, detail="目标策略与当前策略相同")
# 保存必要的属性,避免 save 后对象过期
src_policy_id = obj.policy_id
obj_id = obj.id
obj_is_file = obj.type == ObjectType.FILE
dest_policy_id = request.policy_id
dest_policy_name = dest_policy.name
# 创建任务记录
task = Task(
type=TaskType.POLICY_MIGRATE,
status=TaskStatus.QUEUED,
user_id=user_id,
)
task = await task.save(session)
task_id = task.id
task_props = TaskProps(
task_id=task_id,
source_policy_id=src_policy_id,
dest_policy_id=dest_policy_id,
object_id=obj_id,
)
task_props = await task_props.save(session)
if obj_is_file:
# 文件:后台迁移
async def _run_file_migration() -> None:
async with DatabaseManager.session() as bg_session:
bg_obj = await Object.get(bg_session, Object.id == obj_id)
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
bg_task = await Task.get(bg_session, Task.id == task_id)
await migrate_file_with_task(bg_session, bg_obj, bg_policy, bg_task)
background_tasks.add_task(_run_file_migration)
else:
# Directory: first update the directory's own policy_id
obj = await Object.get(session, Object.id == obj_id)
obj.policy_id = dest_policy_id
obj = await obj.save(session)
if request.is_migrate_existing:
# Migrate all existing files in the background
async def _run_dir_migration() -> None:
async with DatabaseManager.session() as bg_session:
bg_folder = await Object.get(bg_session, Object.id == obj_id)
bg_policy = await Policy.get(bg_session, Policy.id == dest_policy_id)
bg_task = await Task.get(bg_session, Task.id == task_id)
await migrate_directory_files(bg_session, bg_folder, bg_policy, bg_task)
background_tasks.add_task(_run_dir_migration)
else:
# No migration of existing files; mark the task completed immediately
task = await Task.get(session, Task.id == task_id)
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
# Re-fetch the task to read its latest state
task = await Task.get(session, Task.id == task_id)
l.info(f"用户 {user_id} 请求切换对象 {obj_id} 存储策略 → {dest_policy_name}")
return TaskSummaryBase(
id=task.id,
type=task.type,
status=task.status,
progress=task.progress,
error=task.error,
user_id=task.user_id,
created_at=task.created_at,
updated_at=task.updated_at,
)
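The background closures above deliberately capture plain ids (`task_id`, `obj_id`, `dest_policy_id`) and re-fetch the rows in a fresh session, because ORM instances bound to the request session expire once the response has been sent. A toy sketch of that shape (the in-memory `DB` and the `migrate` coroutine are illustrative, not part of this codebase):

```python
import asyncio

# Toy stand-in for the database; keyed by id, like the re-fetch in the real task.
DB = {"task-1": {"status": "QUEUED", "progress": 0}}

async def migrate(task_id: str) -> None:
    # Re-fetch state by id inside the job instead of closing over a live row.
    task = DB[task_id]
    task["status"] = "COMPLETED"
    task["progress"] = 100

# FastAPI's BackgroundTasks would schedule this after the response is sent.
asyncio.run(migrate("task-1"))
print(DB["task-1"]["status"])  # COMPLETED
```

Passing ids instead of instances also avoids sharing one AsyncSession across the request and the background job.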
# ==================== Metadata endpoints ====================
@object_router.get(
path='/{object_id}/metadata',
summary='获取对象元数据',
description='获取对象的元数据键值对,可按命名空间过滤。',
)
async def router_get_object_metadata(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
ns: str | None = None,
) -> MetadataResponse:
"""
获取对象元数据端点
认证JWT token 必填
查询参数:
- ns: 逗号分隔的命名空间列表(如 exif,stream不传返回所有非内部命名空间
错误处理:
- 404: 对象不存在
- 403: 无权查看此对象
"""
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None),
load=Object.metadata_entries,
)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权查看此对象")
# Parse the namespace filter
ns_filter: set[str] | None = None
if ns:
ns_filter = {n.strip() for n in ns.split(",") if n.strip()}
# Internal namespaces are never exposed
ns_filter -= INTERNAL_NAMESPACES
# Build the metadata dict
metadata: dict[str, str] = {}
for entry in obj.metadata_entries:
entry_ns = entry.name.split(":")[0] if ":" in entry.name else ""
if entry_ns in INTERNAL_NAMESPACES:
continue
if ns_filter is not None and entry_ns not in ns_filter:
continue
metadata[entry.name] = entry.value
return MetadataResponse(metadatas=metadata)
@object_router.patch(
path='/{object_id}/metadata',
summary='批量更新对象元数据',
description='批量设置或删除对象的元数据条目。仅允许修改 custom: 命名空间。',
status_code=204,
)
async def router_patch_object_metadata(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
object_id: UUID,
request: MetadataPatchRequest,
) -> None:
"""
批量更新对象元数据端点
请求体中值为 None 的键将被删除,其余键将被设置/更新。
用户只能修改 custom: 命名空间的条目。
认证JWT token 必填
错误处理:
- 400: 尝试修改非 custom: 命名空间的条目
- 404: 对象不存在
- 403: 无权操作此对象
"""
obj = await Object.get(
session,
(Object.id == object_id) & (Object.deleted_at == None),
)
if not obj:
raise HTTPException(status_code=404, detail="对象不存在")
if obj.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此对象")
for patch in request.patches:
# Validate the namespace
patch_ns = patch.key.split(":")[0] if ":" in patch.key else ""
if patch_ns not in USER_WRITABLE_NAMESPACES:
raise HTTPException(
status_code=400,
detail=f"不允许修改命名空间 '{patch_ns}' 的元数据,仅允许 custom: 命名空间",
)
if patch.value is None:
# Delete the metadata entry
existing = await ObjectMetadata.get(
session,
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
)
if existing:
await ObjectMetadata.delete(session, instances=existing)
else:
# Set or update the metadata entry
existing = await ObjectMetadata.get(
session,
(ObjectMetadata.object_id == object_id) & (ObjectMetadata.name == patch.key),
)
if existing:
existing.value = patch.value
existing = await existing.save(session)
else:
entry = ObjectMetadata(
object_id=object_id,
name=patch.key,
value=patch.value,
is_public=True,
)
entry = await entry.save(session)
l.info(f"用户 {user.id} 更新了对象 {object_id}{len(request.patches)} 条元数据")


@@ -0,0 +1,168 @@
"""
用户自定义属性定义路由
提供自定义属性模板的增删改查功能。
用户可以定义类型化的属性模板(如标签、评分、分类等),
然后通过元数据 PATCH 端点为对象设置属性值。
路由前缀:/custom_property
"""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
CustomPropertyDefinition,
CustomPropertyCreateRequest,
CustomPropertyUpdateRequest,
CustomPropertyResponse,
User,
)
router = APIRouter(
prefix="/custom_property",
tags=["custom_property"],
)
@router.get(
path='',
summary='获取自定义属性定义列表',
description='获取当前用户的所有自定义属性定义,按 sort_order 排序。',
)
async def router_list_custom_properties(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> list[CustomPropertyResponse]:
"""
获取自定义属性定义列表端点
认证JWT token 必填
返回当前用户定义的所有自定义属性模板。
"""
definitions = await CustomPropertyDefinition.get(
session,
CustomPropertyDefinition.owner_id == user.id,
fetch_mode="all",
)
return [
CustomPropertyResponse(
id=d.id,
name=d.name,
type=d.type,
icon=d.icon,
options=d.options,
default_value=d.default_value,
sort_order=d.sort_order,
)
for d in sorted(definitions, key=lambda x: x.sort_order)
]
@router.post(
path='',
summary='创建自定义属性定义',
description='创建一个新的自定义属性模板。',
status_code=204,
)
async def router_create_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: CustomPropertyCreateRequest,
) -> None:
"""
创建自定义属性定义端点
认证JWT token 必填
错误处理:
- 400: 请求数据无效
- 409: 同名属性已存在
"""
# Reject duplicate names
existing = await CustomPropertyDefinition.get(
session,
(CustomPropertyDefinition.owner_id == user.id) &
(CustomPropertyDefinition.name == request.name),
)
if existing:
raise HTTPException(status_code=409, detail="同名自定义属性已存在")
definition = CustomPropertyDefinition(
owner_id=user.id,
name=request.name,
type=request.type,
icon=request.icon,
options=request.options,
default_value=request.default_value,
)
definition = await definition.save(session)
l.info(f"用户 {user.id} 创建了自定义属性: {request.name}")
@router.patch(
path='/{id}',
summary='更新自定义属性定义',
description='更新自定义属性模板的名称、图标、选项等。',
status_code=204,
)
async def router_update_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
id: UUID,
request: CustomPropertyUpdateRequest,
) -> None:
"""
更新自定义属性定义端点
认证JWT token 必填
错误处理:
- 404: 属性定义不存在
- 403: 无权操作此属性
"""
definition = await CustomPropertyDefinition.get_exist_one(session, id)
if definition.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此属性")
definition = await definition.update(session, request)
l.info(f"用户 {user.id} 更新了自定义属性: {id}")
@router.delete(
path='/{id}',
summary='删除自定义属性定义',
description='删除自定义属性模板。注意:不会自动清理已使用该属性的元数据条目。',
status_code=204,
)
async def router_delete_custom_property(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
id: UUID,
) -> None:
"""
删除自定义属性定义端点
认证JWT token 必填
错误处理:
- 404: 属性定义不存在
- 403: 无权操作此属性
"""
definition = await CustomPropertyDefinition.get_exist_one(session, id)
if definition.owner_id != user.id:
raise HTTPException(status_code=403, detail="无权操作此属性")
await CustomPropertyDefinition.delete(session, instances=definition)
l.info(f"用户 {user.id} 删除了自定义属性: {id}")


@@ -1,5 +1,5 @@
 from typing import Annotated, Literal
-from uuid import uuid4
+from uuid import UUID, uuid4
 from datetime import datetime
 from fastapi import APIRouter, Depends, Query, HTTPException
@@ -9,11 +9,14 @@ from middleware.auth import auth_required
 from middleware.dependencies import SessionDep
 from sqlmodels import ResponseBase
 from sqlmodels.user import User
-from sqlmodels.share import Share, ShareCreateRequest, ShareResponse
-from sqlmodels.object import Object
-from sqlmodels.mixin import ListResponse, TableViewRequest
+from sqlmodels.share import (
+    Share, ShareCreateRequest, CreateShareResponse, ShareResponse,
+    ShareDetailResponse, ShareOwnerInfo, ShareObjectItem,
+)
+from sqlmodels.object import Object, ObjectType
+from sqlmodel_ext import ListResponse, TableViewRequest
 from utils import http_exceptions
-from utils.password.pwd import Password
+from utils.password.pwd import Password, PasswordStatus
 share_router = APIRouter(
     prefix='/share',
@@ -22,21 +25,92 @@ share_router = APIRouter(
 @share_router.get(
     path='/{id}',
-    summary='获取分享',
-    description='Get shared content by info type and ID.',
+    summary='获取分享详情',
+    description='Get share detail by share ID. No authentication required.',
 )
-def router_share_get(info: str, id: str) -> ResponseBase:
+async def router_share_get(
+    session: SessionDep,
+    id: UUID,
+    password: str | None = Query(default=None),
+) -> ShareDetailResponse:
     """
-    Get shared content by info type and ID.
-    Args:
-        info (str): The type of information being shared.
-        id (str): The ID of the shared content.
-    Returns:
-        dict: A dictionary containing shared content information.
+    Get share detail.
+    Auth: not required
+    Flow:
+    1. Look up the share by its ID
+    2. Check expiry and ban status
+    3. Verify the extraction code, if one is set
+    4. Return the share detail (file tree plus owner info)
     """
-    http_exceptions.raise_not_implemented()
+    # 1. Fetch the share (eager-load user and object)
+    share = await Share.get_exist_one(session, id, load=[Share.user, Share.object])
+    # 2. Check expiry
+    now = datetime.now()
+    if share.expires and share.expires < now:
+        http_exceptions.raise_not_found(detail="分享已过期")
+    # 3. Load the related objects
+    obj = await share.awaitable_attrs.object
+    user = await share.awaitable_attrs.user
+    # 4. Check ban and soft-delete status
+    if obj and obj.is_banned:
+        http_exceptions.raise_banned()
+    if obj and obj.deleted_at:
+        http_exceptions.raise_not_found(detail="分享关联的文件已被删除")
+    # 5. Check the password
+    if share.password:
+        if not password:
+            http_exceptions.raise_precondition_required(detail="请输入提取码")
+        if Password.verify(share.password, password) != PasswordStatus.VALID:
+            http_exceptions.raise_forbidden(detail="提取码错误")
+    # 6. Load children (directory shares)
+    children_items: list[ShareObjectItem] = []
+    if obj and obj.type == ObjectType.FOLDER:
+        children = await Object.get_children(session, obj.owner_id, obj.id)
+        children_items = [
+            ShareObjectItem(
+                id=child.id,
+                name=child.name,
+                type=child.type,
+                size=child.size,
+                created_at=child.created_at,
+                updated_at=child.updated_at,
+            )
+            for child in children
+        ]
+    # 7. Build the response before saving, to avoid MissingGreenlet
+    response = ShareDetailResponse(
+        expires=share.expires,
+        preview_enabled=share.preview_enabled,
+        score=share.score,
+        created_at=share.created_at,
+        owner=ShareOwnerInfo(
+            nickname=user.nickname if user else None,
+            avatar=user.avatar if user else "default",
+        ),
+        object=ShareObjectItem(
+            id=obj.id,
+            name=obj.name,
+            type=obj.type,
+            size=obj.size,
+            created_at=obj.created_at,
+            updated_at=obj.updated_at,
+        ),
+        children=children_items,
+    )
+    # 8. Increment the view counter last, again to avoid MissingGreenlet
+    share.views += 1
+    await share.save(session, refresh=False)
+    return response
 @share_router.put(
     path='/download/{id}',
@@ -226,7 +300,7 @@ async def router_share_create(
     session: SessionDep,
     user: Annotated[User, Depends(auth_required)],
     request: ShareCreateRequest,
-) -> ShareResponse:
+) -> CreateShareResponse:
     """
     Create a new share.
@@ -237,10 +311,13 @@ async def router_share_create(
     2. Generate a random share code (uuid4)
     3. Hash and store the password, if provided
     4. Create and save the Share record
-    5. Return the share info
+    5. Return the share ID
     """
-    # Verify the object exists and belongs to the current user
-    obj = await Object.get(session, Object.id == request.object_id)
+    # Verify the object exists, belongs to the current user, and is not deleted
+    obj = await Object.get(
+        session,
+        (Object.id == request.object_id) & (Object.deleted_at == None)
+    )
     if not obj or obj.owner_id != user.id:
         raise HTTPException(status_code=404, detail="对象不存在或无权限")
@@ -256,11 +333,12 @@ async def router_share_create(
         hashed_password = Password.hash(request.password)
     # Create the share record
+    user_id = user.id
     share = Share(
         code=code,
         password=hashed_password,
         object_id=request.object_id,
-        user_id=user.id,
+        user_id=user_id,
         expires=request.expires,
         remain_downloads=request.remain_downloads,
         preview_enabled=request.preview_enabled,
@@ -269,24 +347,9 @@ async def router_share_create(
     )
     share = await share.save(session)
-    l.info(f"用户 {user.id} 创建分享: {share.code}")
-    # Return the response
-    return ShareResponse(
-        id=share.id,
-        code=share.code,
-        object_id=share.object_id,
-        source_name=share.source_name,
-        views=share.views,
-        downloads=share.downloads,
-        remain_downloads=share.remain_downloads,
-        expires=share.expires,
-        preview_enabled=share.preview_enabled,
-        score=share.score,
-        created_at=share.created_at,
-        is_expired=share.expires is not None and share.expires < datetime.now(),
-        has_password=share.password is not None,
-    )
+    l.info(f"用户 {user_id} 创建分享: {share.code}")
+    return CreateShareResponse(share_id=share.id)
 @share_router.get(
     path='/',
@@ -406,16 +469,29 @@ def router_share_update(id: str) -> ResponseBase:
     path='/{id}',
     summary='删除分享',
     description='Delete a share by ID.',
-    dependencies=[Depends(auth_required)]
+    status_code=204,
 )
-def router_share_delete(id: str) -> ResponseBase:
+async def router_share_delete(
+    session: SessionDep,
+    user: Annotated[User, Depends(auth_required)],
+    id: UUID,
+) -> None:
     """
-    Delete a share by ID.
-    Args:
-        id (str): The ID of the share to be deleted.
-    Returns:
-        ResponseBase: A model containing the response data for the deleted share.
+    Delete a share.
+    Auth: JWT token required
+    Flow:
+    1. Look up the share by its ID
+    2. Verify the share belongs to the current user
+    3. Delete the share record
     """
-    http_exceptions.raise_not_implemented()
+    share = await Share.get_exist_one(session, id)
+    if share.user_id != user.id:
+        http_exceptions.raise_forbidden(detail="无权删除此分享")
+    user_id = user.id
+    share_code = share.code
+    await Share.delete(session, share)
+    l.info(f"用户 {user_id} 删除了分享: {share_code}")


@@ -1,7 +1,12 @@
 from fastapi import APIRouter
 from middleware.dependencies import SessionDep
-from sqlmodels import ResponseBase, Setting, SettingsType, SiteConfigResponse
+from sqlmodels import (
+    ResponseBase, Setting, SettingsType, SiteConfigResponse,
+    ThemePreset, ThemePresetResponse, ThemePresetListResponse,
+    AuthMethodConfig,
+)
+from sqlmodels.auth_identity import AuthProviderType
 from sqlmodels.setting import CaptchaType
 from utils import http_exceptions
@@ -41,6 +46,22 @@ def router_site_captcha():
     """
     http_exceptions.raise_not_implemented()
+@site_router.get(
+    path='/themes',
+    summary='获取主题预设列表',
+)
+async def router_site_themes(session: SessionDep) -> ThemePresetListResponse:
+    """
+    List all theme presets.
+    No auth required; the frontend calls this during initialization.
+    """
+    presets: list[ThemePreset] = await ThemePreset.get(session, fetch_mode="all")
+    return ThemePresetListResponse(
+        themes=[ThemePresetResponse.from_preset(p) for p in presets]
+    )
 @site_router.get(
     path='/config',
     summary='站点全局配置',
@@ -51,7 +72,7 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
     Get the global site configuration.
-    No auth required. The frontend calls this during initialization to learn the captcha type and whether login/registration/password-reset need a captcha.
+    No auth required. The frontend calls this during initialization to learn the captcha type, whether login/registration/password-reset need a captcha, and the available auth methods.
     """
     # Batch-fetch the required settings
     settings: list[Setting] = await Setting.get(
@@ -59,7 +80,10 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
         (Setting.type == SettingsType.BASIC) |
         (Setting.type == SettingsType.LOGIN) |
         (Setting.type == SettingsType.REGISTER) |
-        (Setting.type == SettingsType.CAPTCHA),
+        (Setting.type == SettingsType.CAPTCHA) |
+        (Setting.type == SettingsType.AUTH) |
+        (Setting.type == SettingsType.OAUTH) |
+        (Setting.type == SettingsType.AVATAR),
         fetch_mode="all",
     )
@@ -75,12 +99,32 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
     elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
         captcha_key = s.get("captcha_CloudflareKey") or None
+    # Build the list of auth methods
+    auth_methods: list[AuthMethodConfig] = [
+        AuthMethodConfig(provider=AuthProviderType.EMAIL_PASSWORD, is_enabled=s.get("auth_email_password_enabled") == "1"),
+        AuthMethodConfig(provider=AuthProviderType.PHONE_SMS, is_enabled=s.get("auth_phone_sms_enabled") == "1"),
+        AuthMethodConfig(provider=AuthProviderType.GITHUB, is_enabled=s.get("github_enabled") == "1"),
+        AuthMethodConfig(provider=AuthProviderType.QQ, is_enabled=s.get("qq_enabled") == "1"),
+        AuthMethodConfig(provider=AuthProviderType.PASSKEY, is_enabled=s.get("auth_passkey_enabled") == "1"),
+        AuthMethodConfig(provider=AuthProviderType.MAGIC_LINK, is_enabled=s.get("auth_magic_link_enabled") == "1"),
+    ]
     return SiteConfigResponse(
         title=s.get("siteName") or "DiskNext",
+        logo_light=s.get("logo_light") or None,
+        logo_dark=s.get("logo_dark") or None,
         register_enabled=s.get("register_enabled") == "1",
         login_captcha=s.get("login_captcha") == "1",
         reg_captcha=s.get("reg_captcha") == "1",
         forget_captcha=s.get("forget_captcha") == "1",
         captcha_type=captcha_type,
         captcha_key=captcha_key,
+        auth_methods=auth_methods,
+        password_required=s.get("auth_password_required") == "1",
+        phone_binding_required=s.get("auth_phone_binding_required") == "1",
+        email_binding_required=s.get("auth_email_binding_required") == "1",
+        avatar_max_size=int(s["avatar_size"]),
+        footer_code=s.get("footer_code"),
+        tos_url=s.get("tos_url"),
+        privacy_url=s.get("privacy_url"),
     )
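Settings rows store booleans as strings, so every flag in the response above is derived by comparing against "1"; a missing row therefore reads as disabled. The convention in one small helper (`flag` is an illustrative name, not part of the codebase):

```python
def flag(settings: dict[str, str], name: str) -> bool:
    # Settings rows store "1"/"0" as strings; absent keys read as disabled.
    return settings.get(name) == "1"

s = {"register_enabled": "1", "login_captcha": "0"}
print(flag(s, "register_enabled"))  # True
print(flag(s, "login_captcha"))     # False
print(flag(s, "missing_key"))       # False
```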


@@ -20,15 +20,15 @@ slave_aria2_router = APIRouter(
     summary='测试用路由',
     description='Test route for checking connectivity.',
 )
-def router_slave_ping() -> ResponseBase:
+def router_slave_ping() -> str:
     """
     Test route for checking connectivity.
     Returns:
-        ResponseBase: A response model indicating success.
+        str: the backend version string
     """
     from utils.conf.appmeta import BackendVersion
-    return ResponseBase(data=BackendVersion)
+    return BackendVersion
 @slave_router.post(
     path='/post',


@@ -0,0 +1,161 @@
"""
回收站路由
提供回收站管理功能:列出、恢复、永久删除、清空。
路由前缀:/trash
"""
from typing import Annotated
from fastapi import APIRouter, Depends
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import Object, User
from sqlmodels.object import TrashDeleteRequest, TrashItemResponse, TrashRestoreRequest
from service.storage.object import (
permanently_delete_objects,
restore_objects,
soft_delete_objects,
)
trash_router = APIRouter(
prefix="/trash",
tags=["trash"],
)
@trash_router.get(
path='/',
summary='列出回收站内容',
description='获取当前用户回收站中的所有顶层对象。',
)
async def router_trash_list(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> list[TrashItemResponse]:
"""
列出回收站内容
认证:需要 JWT token
返回回收站中被直接删除的顶层对象列表,
不包含其子对象(子对象在恢复/永久删除时会随顶层对象一起处理)。
"""
items = await Object.get_trash_items(session, user.id)
return [
TrashItemResponse(
id=item.id,
name=item.name,
type=item.type,
size=item.size,
deleted_at=item.deleted_at,
original_parent_id=item.deleted_original_parent_id,
)
for item in items
]
@trash_router.patch(
path='/restore',
summary='恢复对象',
description='从回收站恢复一个或多个对象到原位置。如果原位置不存在则恢复到根目录。',
status_code=204,
)
async def router_trash_restore(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: TrashRestoreRequest,
) -> None:
"""
从回收站恢复对象
认证:需要 JWT token
流程:
1. 验证对象存在且在回收站中deleted_at IS NOT NULL
2. 检查原父目录是否存在且未删除
3. 存在 → 恢复到原位置;不存在 → 恢复到根目录
4. 处理同名冲突(自动重命名)
5. 清除 deleted_at 和 deleted_original_parent_id
"""
user_id = user.id
objects_to_restore: list[Object] = []
for obj_id in request.ids:
obj = await Object.get(
session,
(Object.id == obj_id) & (Object.owner_id == user_id) & (Object.deleted_at != None)
)
if obj:
objects_to_restore.append(obj)
if objects_to_restore:
restored_count = await restore_objects(session, objects_to_restore, user_id)
l.info(f"用户 {user_id} 从回收站恢复了 {restored_count} 个对象")
@trash_router.delete(
path='/',
summary='永久删除对象',
description='永久删除回收站中的指定对象,包括物理文件和数据库记录。',
status_code=204,
)
async def router_trash_delete(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: TrashDeleteRequest,
) -> None:
"""
永久删除回收站中的对象
认证:需要 JWT token
流程:
1. 验证对象存在且在回收站中
2. BFS 收集所有子文件的 PhysicalFile
3. 处理引用计数,引用为 0 时物理删除文件
4. 硬删除根 ObjectCASCADE 清理子对象)
5. 更新用户存储配额
"""
user_id = user.id
objects_to_delete: list[Object] = []
for obj_id in request.ids:
obj = await Object.get(
session,
(Object.id == obj_id) & (Object.owner_id == user_id) & (Object.deleted_at != None)
)
if obj:
objects_to_delete.append(obj)
if objects_to_delete:
deleted_count = await permanently_delete_objects(session, objects_to_delete, user_id)
l.info(f"用户 {user_id} 永久删除了 {deleted_count} 个对象")
@trash_router.delete(
path='/empty',
summary='清空回收站',
description='永久删除回收站中的所有对象。',
status_code=204,
)
async def router_trash_empty(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> None:
"""
清空回收站
认证:需要 JWT token
获取回收站中所有顶层对象,逐个执行永久删除。
"""
user_id = user.id
trash_items = await Object.get_trash_items(session, user_id)
if trash_items:
deleted_count = await permanently_delete_objects(session, trash_items, user_id)
l.info(f"用户 {user_id} 清空回收站,共删除 {deleted_count} 个对象")


@@ -1,18 +1,31 @@
 from typing import Annotated, Literal
 from uuid import UUID, uuid4
+import json
 import jwt
-from fastapi import APIRouter, Depends, Form, HTTPException
+from fastapi import APIRouter, Depends, HTTPException
+from fastapi.responses import FileResponse, RedirectResponse
+from itsdangerous import URLSafeTimedSerializer
 from loguru import logger
-from webauthn import generate_registration_options
-from webauthn.helpers import options_to_json_dict
+from webauthn import (
+    generate_authentication_options,
+    generate_registration_options,
+    verify_registration_response,
+)
+from webauthn.helpers import bytes_to_base64url, options_to_json
+from webauthn.helpers.structs import PublicKeyCredentialDescriptor
 import service
 import sqlmodels
 from middleware.auth import auth_required
 from middleware.dependencies import SessionDep, require_captcha
 from service.captcha import CaptchaScene
+from service.redis.challenge_store import ChallengeStore
+from service.webauthn import get_rp_config
+from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
 from sqlmodels.user import UserStatus
+from sqlmodels.user_authn import UserAuthn
 from utils import JWT, Password, http_exceptions
 from .settings import user_settings_router
@@ -23,59 +36,36 @@ user_router = APIRouter(
 user_router.include_router(user_settings_router)
-class OAuth2PasswordWithExtrasForm:
-    """
-    Extended OAuth2 password form.
-    Adds an otp_code field on top of the standard username/password pair.
-    captcha_code is handled separately by the require_captcha dependency.
-    """
-    def __init__(
-        self,
-        *,
-        username: Annotated[str, Form()],
-        password: Annotated[str, Form()],
-        otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
-    ):
-        self.username = username
-        self.password = password
-        self.otp_code = otp_code
 @user_router.post(
     path='/session',
-    summary='用户登录',
-    description='用户登录端点,支持验证码校验和两步验证',
-    dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
+    summary='用户登录(统一入口)',
+    description='统一登录端点,支持多种认证方式',
 )
 async def router_user_session(
     session: SessionDep,
-    form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
+    request: sqlmodels.UnifiedLoginRequest,
 ) -> sqlmodels.TokenResponse:
     """
-    Login endpoint.
-    Form fields:
-    - username: user email
-    - password: user password
-    - captcha_code: captcha token (optional, validated by the require_captcha dependency)
-    - otp_code: two-factor code (optional, needed only when the user has 2FA enabled)
-    Errors:
-    - 400: captcha required but missing
-    - 401: wrong email/password, or wrong 2FA code
-    - 403: account disabled / captcha verification failed
-    - 428: two-factor auth required but otp_code missing
+    Unified login endpoint.
+    Request body:
+    - provider: login method (email_password / github / qq / passkey / magic_link)
+    - identifier: identifier (email / OAuth code / credential_id / magic-link token)
+    - credential: credential (password / WebAuthn assertion / ...)
+    - two_fa_code: two-factor code (optional)
+    - redirect_uri: OAuth callback URL (optional)
+    - captcha: captcha token (optional)
+    Errors:
+    - 400: login method disabled / invalid parameters
+    - 401: bad credentials
+    - 403: account disabled
+    - 428: two-factor code required
+    - 501: login method not yet implemented
     """
-    return await service.user.login(
-        session,
-        sqlmodels.LoginRequest(
-            email=form_data.username,
-            password=form_data.password,
-            two_fa_code=form_data.otp_code,
-        ),
-    )
+    return await service.user.unified_login(session, request)
 @user_router.post(
     path='/session/refresh',
@@ -150,41 +140,82 @@ async def router_user_session_refresh(
@user_router.post( @user_router.post(
path='/', path='/',
summary='用户注册', summary='用户注册(统一入口)',
description='User registration endpoint.', description='User registration endpoint.',
status_code=204, status_code=204,
) )
async def router_user_register( async def router_user_register(
session: SessionDep, session: SessionDep,
request: sqlmodels.RegisterRequest, request: sqlmodels.UnifiedRegisterRequest,
) -> None: ) -> None:
""" """
用户注册端点 统一注册端点
流程: 流程:
1. 验证用户名唯一性 1. 检查注册开关
2. 获取默认用户组 2. 检查 provider 启用
3. 创建用户记录 3. 验证 identifier 唯一性AuthIdentity 表)
4. 创建用户根目录name="/" 4. 创建 User + AuthIdentity + 根目录
:param session: 数据库会话 请求体:
:param request: 注册请求 - provider: 注册方式email_password / phone_sms
:return: 注册结果 - identifier: 标识符(邮箱 / 手机号)
:raises HTTPException 400: 用户名已存在 - credential: 凭证(密码 / 短信验证码)
:raises HTTPException 500: 默认用户组或存储策略不存在 - nickname: 昵称(可选)
- captcha: 验证码(可选)
错误处理:
- 400: 注册未开放 / 参数错误
- 409: 邮箱或手机号已存在
- 501: 暂未实现的注册方式
""" """
# 1. 验证邮箱唯一性 # 1. 检查注册开关
register_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
& (sqlmodels.Setting.name == "register_enabled"),
)
if not register_setting or register_setting.value != "1":
http_exceptions.raise_bad_request("注册功能未开放")
# 2. 目前只支持 email_password 注册
if request.provider == AuthProviderType.PHONE_SMS:
http_exceptions.raise_not_implemented("短信注册暂未开放")
elif request.provider != AuthProviderType.EMAIL_PASSWORD:
http_exceptions.raise_bad_request("不支持的注册方式")
# 3. 检查密码是否必填
password_required_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
& (sqlmodels.Setting.name == "auth_password_required"),
)
is_password_required = not password_required_setting or password_required_setting.value != "0"
if is_password_required and not request.credential:
http_exceptions.raise_bad_request("密码不能为空")
# 4. 验证 identifier 唯一性AuthIdentity 表)
existing_identity = await AuthIdentity.get(
session,
(AuthIdentity.provider == request.provider)
& (AuthIdentity.identifier == request.identifier),
)
if existing_identity:
raise HTTPException(status_code=409, detail="该邮箱已被注册")
# 同时检查 User.email 唯一性(防止旧数据冲突)
existing_user = await sqlmodels.User.get( existing_user = await sqlmodels.User.get(
session, session,
sqlmodels.User.email == request.email sqlmodels.User.email == request.identifier,
) )
if existing_user: if existing_user:
raise HTTPException(status_code=400, detail="邮箱已存在") raise HTTPException(status_code=409, detail="邮箱已被注册")
# 2. 获取默认用户组(从设置中读取 UUID # 5. 获取默认用户组
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get( default_group_setting = await sqlmodels.Setting.get(
session, session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group") (sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER)
& (sqlmodels.Setting.name == "default_group"),
) )
if default_group_setting is None or not default_group_setting.value: if default_group_setting is None or not default_group_setting.value:
logger.error("默认用户组不存在") logger.error("默认用户组不存在")
@@ -196,21 +227,33 @@ async def router_user_register(
logger.error("默认用户组不存在") logger.error("默认用户组不存在")
http_exceptions.raise_internal_error() http_exceptions.raise_internal_error()
# 3. 创建用户 # 6. 创建用户
hashed_password = Password.hash(request.password)
new_user = sqlmodels.User( new_user = sqlmodels.User(
email=request.email, email=request.identifier,
password=hashed_password, nickname=request.nickname,
group_id=default_group.id, group_id=default_group.id,
) )
new_user_id = new_user.id new_user_id = new_user.id
await new_user.save(session) new_user = await new_user.save(session)
# 4. 创建用户根目录 # 7. 创建 AuthIdentity
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储") hashed_password = Password.hash(request.credential) if request.credential else None
if not default_policy: identity = AuthIdentity(
logger.error("默认存储策略不存在") provider=AuthProviderType.EMAIL_PASSWORD,
identifier=request.identifier,
credential=hashed_password,
is_primary=True,
is_verified=False,
user_id=new_user_id,
)
identity = await identity.save(session)
# 8. 创建用户根目录(使用用户组关联的第一个存储策略)
await session.refresh(default_group, ['policies'])
if not default_group.policies:
logger.error("默认用户组未关联任何存储策略")
http_exceptions.raise_internal_error() http_exceptions.raise_internal_error()
default_policy = default_group.policies[0]
await sqlmodels.Object( await sqlmodels.Object(
name="/", name="/",
@@ -220,6 +263,66 @@ async def router_user_register(
policy_id=default_policy.id, policy_id=default_policy.id,
).save(session) ).save(session)
@user_router.post(
path='/magic-link',
summary='发送 Magic Link 邮件',
description='生成 Magic Link token 并发送到指定邮箱。',
status_code=204,
)
async def router_user_magic_link(
session: SessionDep,
request: sqlmodels.MagicLinkRequest,
) -> None:
"""
发送 Magic Link 邮件
流程:
1. 验证邮箱对应的 AuthIdentity 存在
2. 生成签名 token
3. 发送邮件(包含带 token 的链接)
错误处理:
- 400: Magic Link 未启用
- 404: 邮箱未注册
"""
# 检查 magic_link 是否启用
magic_link_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
& (sqlmodels.Setting.name == "auth_magic_link_enabled"),
)
if not magic_link_setting or magic_link_setting.value != "1":
http_exceptions.raise_bad_request("Magic Link 登录未启用")
# 验证邮箱存在
identity = await AuthIdentity.get(
session,
(AuthIdentity.identifier == request.email)
& (
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
),
)
if not identity:
http_exceptions.raise_not_found("该邮箱未注册")
# 生成签名 token
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
token = serializer.dumps(request.email, salt="magic-link-salt")
# 获取站点 URL
site_url_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.BASIC)
& (sqlmodels.Setting.name == "siteURL"),
)
site_url = site_url_setting.value if site_url_setting else "http://localhost"
# TODO: send the email (containing {site_url}/auth/magic-link?token={token})
logger.info(f"Magic Link token 已为 {request.email} 生成 (邮件发送待实现)")
@user_router.post(
path='/code',
summary='发送验证码邮件',
@@ -236,46 +339,6 @@ def router_user_email_code(
"""
http_exceptions.raise_not_implemented()
@user_router.get(
path='/qq',
summary='初始化QQ登录',
description='Initialize QQ login for a user.',
)
def router_user_qq() -> sqlmodels.ResponseBase:
"""
Initialize QQ login for a user.
Returns:
dict: A dictionary containing QQ login initialization information.
"""
http_exceptions.raise_not_implemented()
@user_router.get(
path='authn/{username}',
summary='WebAuthn登录初始化',
description='Initialize WebAuthn login for a user.',
)
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
http_exceptions.raise_not_implemented()
@user_router.post(
path='authn/finish/{username}',
summary='WebAuthn登录',
description='Finish WebAuthn login for a user.',
)
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
"""
Finish WebAuthn login for a user.
Args:
username (str): The username of the user.
Returns:
dict: A dictionary containing WebAuthn login information.
"""
http_exceptions.raise_not_implemented()
@user_router.get(
path='/profile/{id}',
summary='获取用户主页展示用分享',
@@ -296,20 +359,78 @@ def router_user_profile(id: str) -> sqlmodels.ResponseBase:
@user_router.get(
path='/avatar/{id}/{size}',
summary='获取用户头像',
response_model=None,
)
async def router_user_avatar(
session: SessionDep,
id: UUID,
size: int = 128,
) -> FileResponse | RedirectResponse:
"""
Get the avatar of the given user at the given size (public endpoint, no auth required)
Path parameters:
- id: user UUID
- size: requested avatar size in px (default 128)
Behavior:
- default: 302 redirect to a Gravatar identicon
- gravatar: 302 redirect to Gravatar (MD5 of the user's email)
- file: serve the local WebP file
Responses:
- 200: image/webp (file mode)
- 302: redirect to an external URL (default/gravatar mode)
- 404: user not found
Caching: Cache-Control: public, max-age=3600
"""
import aiofiles.os
from service.avatar import (
get_avatar_file_path,
get_avatar_settings,
gravatar_url,
resolve_avatar_size,
)
user = await sqlmodels.User.get(session, sqlmodels.User.id == id)
if not user:
http_exceptions.raise_not_found("用户不存在")
avatar_path, _, size_l, size_m, size_s = await get_avatar_settings(session)
if user.avatar == "file":
size_label = resolve_avatar_size(size, size_l, size_m, size_s)
file_path = get_avatar_file_path(avatar_path, user.id, size_label)
if not await aiofiles.os.path.exists(file_path):
# File missing; fall back to identicon
fallback_url = gravatar_url(str(user.id), size, "https://www.gravatar.com/")
return RedirectResponse(url=fallback_url, status_code=302)
return FileResponse(
path=file_path,
media_type="image/webp",
headers={"Cache-Control": "public, max-age=3600"},
)
elif user.avatar == "gravatar":
gravatar_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AVATAR)
& (sqlmodels.Setting.name == "gravatar_server"),
)
server = gravatar_setting.value if gravatar_setting else "https://www.gravatar.com/"
email = user.email or str(user.id)
url = gravatar_url(email, size, server)
return RedirectResponse(url=url, status_code=302)
else:
# default: identicon
email_or_id = user.email or str(user.id)
url = gravatar_url(email_or_id, size, "https://www.gravatar.com/")
return RedirectResponse(url=url, status_code=302)
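The `gravatar_url` helper from `service.avatar` is not shown in this diff; its assumed behavior (MD5 of the trimmed, lowercased email, with `d=identicon` as the generated fallback) can be sketched as:

```python
import hashlib

def gravatar_url(email: str, size: int, server: str = "https://www.gravatar.com/") -> str:
    # Gravatar addresses images by the MD5 of the trimmed, lowercased email;
    # d=identicon asks the server to generate a geometric fallback image
    # for hashes with no registered avatar.
    digest = hashlib.md5(email.strip().lower().encode("utf-8")).hexdigest()
    return f"{server.rstrip('/')}/avatar/{digest}?s={size}&d=identicon"
```

Passing `str(user.id)` instead of an email, as the default branch above does, still yields a stable per-user identicon because any string hashes deterministically.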
#####################
# Endpoints requiring login
@@ -348,8 +469,6 @@ async def router_user_me(
return sqlmodels.UserResponse(
id=user.id,
email=user.email,
status=user.status,
score=user.score,
nickname=user.nickname,
avatar=user.avatar,
created_at=user.created_at,
@@ -375,9 +494,24 @@ async def router_user_storage(
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
# Sum the size of all of the user's unexpired storage packs
from datetime import datetime
from sqlalchemy import func, select, and_, or_
now = datetime.now()
stmt = select(func.coalesce(func.sum(sqlmodels.StoragePack.size), 0)).where(
and_(
sqlmodels.StoragePack.user_id == user.id,
or_(
sqlmodels.StoragePack.expired_time.is_(None),
sqlmodels.StoragePack.expired_time > now,
),
)
)
result = await session.exec(stmt)
active_packs_total: int = result.scalar_one()
total: int = group.max_storage + active_packs_total
used: int = user.storage
free: int = max(0, total - used)
@@ -389,57 +523,177 @@ async def router_user_storage(
@user_router.put(
path='/authn/start',
summary='注册 Passkey 凭证(初始化)',
description='Initialize Passkey registration for a user.',
dependencies=[Depends(auth_required)],
)
async def router_user_authn_start(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> dict:
"""
Passkey registration initialization (login required)
Returns WebAuthn registration options; the frontend handles them with navigator.credentials.create().
Error handling:
- 400: Passkey not enabled
"""
authn_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTHN)
& (sqlmodels.Setting.name == "authn_enabled"),
)
if not authn_setting or authn_setting.value != "1":
http_exceptions.raise_bad_request("Passkey 未启用")
rp_id, rp_name, _origin = await get_rp_config(session)
# Fetch the user's registered credentials for exclude_credentials
existing_authns: list[UserAuthn] = await UserAuthn.get(
session,
UserAuthn.user_id == user.id,
fetch_mode="all",
)
exclude_credentials: list[PublicKeyCredentialDescriptor] = [
PublicKeyCredentialDescriptor(
id=authn.credential_id,
transports=authn.transports.split(",") if authn.transports else [],
)
for authn in existing_authns
]
options = generate_registration_options(
rp_id=rp_id,
rp_name=rp_name,
user_id=user.id.bytes,
user_name=user.email or str(user.id),
user_display_name=user.nickname or user.email or str(user.id),
exclude_credentials=exclude_credentials if exclude_credentials else None,
)
# Store the challenge
await ChallengeStore.store(f"reg:{user.id}", options.challenge)
return json.loads(options_to_json(options))
@user_router.put(
path='/authn/finish',
summary='注册 Passkey 凭证(完成)',
description='Finish Passkey registration for a user.',
dependencies=[Depends(auth_required)],
status_code=201,
)
async def router_user_authn_finish(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
request: sqlmodels.AuthnFinishRequest,
) -> sqlmodels.AuthnDetailResponse:
"""
Passkey registration completion (login required)
Receives the credential data returned by the frontend's navigator.credentials.create(),
verifies it, then creates a UserAuthn row + AuthIdentity (provider=passkey).
Request body:
- credential: JSON string returned by navigator.credentials.create()
- name: credential name (optional)
Error handling:
- 400: challenge expired or invalid / verification failed
"""
# Retrieve the challenge (one-time use)
challenge: bytes | None = await ChallengeStore.retrieve_and_delete(f"reg:{user.id}")
if challenge is None:
http_exceptions.raise_bad_request("注册会话已过期,请重新开始")
rp_id, _rp_name, origin = await get_rp_config(session)
# Verify the registration response
try:
verification = verify_registration_response(
credential=request.credential,
expected_challenge=challenge,
expected_rp_id=rp_id,
expected_origin=origin,
)
except Exception as e:
logger.warning(f"WebAuthn 注册验证失败: {e}")
http_exceptions.raise_bad_request("Passkey 验证失败")
# Encode as base64url for storage
credential_id_b64: str = bytes_to_base64url(verification.credential_id)
credential_public_key_b64: str = bytes_to_base64url(verification.credential_public_key)
# Extract transports
credential_dict: dict = json.loads(request.credential)
response_dict: dict = credential_dict.get("response", {})
transports_list: list[str] = response_dict.get("transports", [])
transports_str: str | None = ",".join(transports_list) if transports_list else None
# Create the UserAuthn record
authn = UserAuthn(
credential_id=credential_id_b64,
credential_public_key=credential_public_key_b64,
sign_count=verification.sign_count,
credential_device_type=verification.credential_device_type,
credential_backed_up=verification.credential_backed_up,
transports=transports_str,
name=request.name,
user_id=user.id,
)
authn = await authn.save(session)
# Create AuthIdentity (provider=passkey, identifier=credential_id_b64)
identity = AuthIdentity(
provider=AuthProviderType.PASSKEY,
identifier=credential_id_b64,
is_primary=False,
is_verified=True,
user_id=user.id,
)
identity = await identity.save(session)
return authn.to_detail_response()
@user_router.post(
path='/authn/options',
summary='获取 Passkey 登录 options无需登录',
description='Generate authentication options for Passkey login.',
)
async def router_user_authn_options(
session: SessionDep,
) -> dict:
"""
Get authentication options for Passkey login (no login required)
The frontend calls this endpoint for the options, then handles them with navigator.credentials.get().
Uses Discoverable Credentials mode (empty allow_credentials),
letting the browser/platform decide which credentials to present.
The response includes a ``challenge_token`` field, which the frontend passes as ``identifier`` in the login request.
Error handling:
- 400: Passkey not enabled
"""
authn_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTHN)
& (sqlmodels.Setting.name == "authn_enabled"),
)
if not authn_setting or authn_setting.value != "1":
http_exceptions.raise_bad_request("Passkey 未启用")
rp_id, _rp_name, _origin = await get_rp_config(session)
options = generate_authentication_options(rp_id=rp_id)
# Generate a challenge_token to correlate with the stored challenge
challenge_token: str = str(uuid4())
await ChallengeStore.store(f"auth:{challenge_token}", options.challenge)
result: dict = json.loads(options_to_json(options))
result["challenge_token"] = challenge_token
return result


@@ -1,33 +1,56 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import sqlmodels
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
BUILTIN_DEFAULT_COLORS, ThemePreset, UserThemeUpdateRequest,
SettingOption, UserSettingUpdateRequest,
AuthIdentity, AuthIdentityResponse, AuthProviderType, BindIdentityRequest,
ChangePasswordRequest,
AuthnDetailResponse, AuthnRenameRequest,
PolicySummary,
)
from sqlmodels.color import ThemeColorsBase
from sqlmodels.user_authn import UserAuthn
from utils import JWT, Password, http_exceptions
from utils.password.pwd import PasswordStatus, TwoFactorResponse, TwoFactorVerifyRequest
from .file_viewers import file_viewers_router
user_settings_router = APIRouter(
prefix='/settings',
tags=["user", "user_settings"],
dependencies=[Depends(auth_required)],
)
user_settings_router.include_router(file_viewers_router)
@user_settings_router.get(
path='/policies',
summary='获取用户可选存储策略',
)
async def router_user_settings_policies(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[PolicySummary]:
"""
Get the storage policies selectable by the current user's group
Returns summary info for every storage policy associated with the user's group.
"""
group = await user.awaitable_attrs.group
await session.refresh(group, ['policies'])
return [
PolicySummary(
id=p.id, name=p.name, type=p.type,
server=p.server, max_size=p.max_size, is_private=p.is_private,
)
for p in group.policies
]
@user_settings_router.get(
@@ -67,77 +90,309 @@ def router_user_settings_tasks() -> sqlmodels.ResponseBase:
summary='获取当前用户设定',
description='Get current user settings.',
)
async def router_user_settings(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.UserSettingResponse:
"""
Get current user settings
Theme color merge strategy:
1. User has a color snapshot (all 7 fields set) → use the snapshot directly
2. Otherwise look up the default preset → use its colors
3. No default preset → use the built-in defaults
"""
# Compute theme colors
has_snapshot = all([
user.color_primary, user.color_secondary, user.color_success,
user.color_info, user.color_warning, user.color_error, user.color_neutral,
])
if has_snapshot:
theme_colors = ThemeColorsBase(
primary=user.color_primary,
secondary=user.color_secondary,
success=user.color_success,
info=user.color_info,
warning=user.color_warning,
error=user.color_error,
neutral=user.color_neutral,
)
else:
default_preset: ThemePreset | None = await ThemePreset.get(
session, ThemePreset.is_default == True # noqa: E712
)
if default_preset:
theme_colors = ThemeColorsBase(
primary=default_preset.primary,
secondary=default_preset.secondary,
success=default_preset.success,
info=default_preset.info,
warning=default_preset.warning,
error=default_preset.error,
neutral=default_preset.neutral,
)
else:
theme_colors = BUILTIN_DEFAULT_COLORS
# Check whether two-factor auth is enabled (read from the extra_data of the email_password AuthIdentity)
has_two_factor = False
email_identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.user_id == user.id)
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
)
if email_identity and email_identity.extra_data:
import orjson
extra: dict = orjson.loads(email_identity.extra_data)
has_two_factor = bool(extra.get("two_factor"))
return sqlmodels.UserSettingResponse(
id=user.id,
email=user.email,
phone=user.phone,
nickname=user.nickname,
created_at=user.created_at,
group_name=user.group.name,
language=user.language,
timezone=user.timezone,
group_expires=user.group_expires,
two_factor=has_two_factor,
theme_preset_id=user.theme_preset_id,
theme_colors=theme_colors,
)
@user_settings_router.post(
path='/avatar',
summary='从文件上传头像',
status_code=204,
)
async def router_user_settings_avatar(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
file: UploadFile = File(...),
) -> None:
"""
Upload an avatar file
Auth: JWT token
Request body: multipart/form-data (file field)
Flow:
1. Validate the file MIME type (JPEG/PNG/GIF/WebP)
2. Validate file size <= the avatar_size setting (default 2MB)
3. Use Pillow to validate the image and process it (center crop, scale to L/M/S)
4. Save WebP files in three sizes
5. Set User.avatar = "file"
Error handling:
- 400: unsupported file type / image cannot be parsed
- 413: file too large
"""
from service.avatar import (
ALLOWED_CONTENT_TYPES,
get_avatar_settings,
process_and_save_avatar,
)
# Validate MIME type
if file.content_type not in ALLOWED_CONTENT_TYPES:
http_exceptions.raise_bad_request(
f"不支持的图片格式,允许: {', '.join(ALLOWED_CONTENT_TYPES)}"
)
# Read the file and validate its size
_, max_upload_size, _, _, _ = await get_avatar_settings(session)
raw_bytes = await file.read()
if len(raw_bytes) > max_upload_size:
raise HTTPException(
status_code=413,
detail=f"文件过大,最大允许 {max_upload_size} 字节",
)
# Process and save (validates the image internally; raises ValueError when invalid)
try:
await process_and_save_avatar(session, user.id, raw_bytes)
except ValueError as e:
http_exceptions.raise_bad_request(str(e))
# Update the user's avatar field
user.avatar = "file"
user = await user.save(session)
@user_settings_router.put(
path='/avatar',
summary='设定为 Gravatar 头像',
status_code=204,
)
async def router_user_settings_avatar_gravatar(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> None:
"""
Switch the avatar to Gravatar
Auth: JWT token
Flow:
1. Verify the user has an email (Gravatar is keyed by the email's MD5)
2. If the current avatar is FILE, delete the local files
3. Set User.avatar = "gravatar"
Error handling:
- 400: user has no email
"""
from service.avatar import delete_avatar_files
if not user.email:
http_exceptions.raise_bad_request("Gravatar 需要邮箱,请先绑定邮箱")
if user.avatar == "file":
await delete_avatar_files(session, user.id)
user.avatar = "gravatar"
user = await user.save(session)
@user_settings_router.delete(
path='/avatar',
summary='重置头像为默认',
status_code=204,
)
async def router_user_settings_avatar_delete(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> None:
"""
Reset the avatar to the default
Auth: JWT token
Flow:
1. If the current avatar is FILE, delete the local files
2. Set User.avatar = "default"
"""
from service.avatar import delete_avatar_files
if user.avatar == "file":
await delete_avatar_files(session, user.id)
user.avatar = "default"
user = await user.save(session)
@user_settings_router.patch(
path='/theme',
summary='更新用户主题设置',
status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_theme(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
request: UserThemeUpdateRequest,
) -> None:
"""
Update user theme settings
Request body (all optional):
- theme_preset_id: theme preset UUID
- theme_colors: color config object (written to the color snapshot)
Error handling:
- 404: the given theme preset does not exist
"""
# Verify the preset_id exists
if request.theme_preset_id is not None:
preset: ThemePreset | None = await ThemePreset.get(
session, ThemePreset.id == request.theme_preset_id
)
if not preset:
http_exceptions.raise_not_found("主题预设不存在")
user.theme_preset_id = request.theme_preset_id
# Destructure colors into the snapshot columns
if request.theme_colors is not None:
user.color_primary = request.theme_colors.primary
user.color_secondary = request.theme_colors.secondary
user.color_success = request.theme_colors.success
user.color_info = request.theme_colors.info
user.color_warning = request.theme_colors.warning
user.color_error = request.theme_colors.error
user.color_neutral = request.theme_colors.neutral
user = await user.save(session)
@user_settings_router.patch(
path='/password',
summary='修改密码',
status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_change_password(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
request: ChangePasswordRequest,
) -> None:
"""
Change the current user's password
Request body:
- old_password: current password
- new_password: new password (at least 8 characters)
Error handling:
- 400: user has no email-password auth identity
- 403: current password is wrong
"""
email_identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.user_id == user.id)
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
)
if not email_identity or not email_identity.credential:
http_exceptions.raise_bad_request("未找到邮箱密码认证身份")
verify_result = Password.verify(email_identity.credential, request.old_password)
if verify_result == PasswordStatus.INVALID:
http_exceptions.raise_forbidden("当前密码错误")
email_identity.credential = Password.hash(request.new_password)
email_identity = await email_identity.save(session)
@user_settings_router.patch(
path='/{option}',
summary='更新用户设定',
status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_patch(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
option: SettingOption,
request: UserSettingUpdateRequest,
) -> None:
"""
Update a single user setting
Path parameters:
- option: setting name (nickname / language / timezone)
Request body:
- a field with the same name as option, holding the new value
Error handling:
- 422: invalid option or field value violates constraints
- 400: required field value missing
"""
value = getattr(request, option.value)
# language / timezone may not be set to null
if value is None and option != SettingOption.NICKNAME:
http_exceptions.raise_bad_request(f"设置项 {option.value} 不允许为空")
setattr(user, option.value, value)
user = await user.save(session)
@user_settings_router.get(
@@ -148,17 +403,13 @@ def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase:
)
async def router_user_settings_2fa(
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> TwoFactorResponse:
"""
Get two-factor authentication setup information
Returns a setup_token (for the subsequent verify request) and a uri (for generating a QR code)
"""
return await Password.generate_totp(name=user.email or str(user.id))
@user_settings_router.post(
@@ -166,38 +417,276 @@ async def router_user_settings_2fa(
summary='启用两步验证',
description='Enable two-factor authentication.',
dependencies=[Depends(auth_required)],
status_code=204,
)
async def router_user_settings_2fa_enable(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
request: TwoFactorVerifyRequest,
) -> None:
"""
Enable two-factor authentication
Stores the 2FA secret in the extra_data of the email_password AuthIdentity.
"""
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
try:
secret = serializer.loads(request.setup_token, salt="2fa-setup-salt", max_age=600)
except SignatureExpired:
raise HTTPException(status_code=400, detail="Setup session expired")
except BadSignature:
raise HTTPException(status_code=400, detail="Invalid token")
if Password.verify_totp(secret, request.code) != PasswordStatus.VALID:
raise HTTPException(status_code=400, detail="Invalid OTP code")
# Store the secret in AuthIdentity.extra_data
email_identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.user_id == user.id)
& (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD),
)
if not email_identity:
raise HTTPException(status_code=400, detail="未找到邮箱密码认证身份")
import orjson
extra: dict = orjson.loads(email_identity.extra_data) if email_identity.extra_data else {}
extra["two_factor"] = secret
email_identity.extra_data = orjson.dumps(extra).decode('utf-8')
email_identity = await email_identity.save(session)
# ==================== Auth identity management ====================
@user_settings_router.get(
path='/identities',
summary='列出已绑定的认证身份',
)
async def router_user_settings_identities(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[AuthIdentityResponse]:
"""
List all auth identities bound to the current user
Returns:
- list of auth identities, including provider, identifier, display_name, etc.
"""
identities: list[AuthIdentity] = await AuthIdentity.get(
session,
AuthIdentity.user_id == user.id,
fetch_mode="all",
)
return [identity.to_response() for identity in identities]
@user_settings_router.post(
path='/identity',
summary='绑定新的认证身份',
status_code=status.HTTP_201_CREATED,
)
async def router_user_settings_bind_identity(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
request: BindIdentityRequest,
) -> AuthIdentityResponse:
"""
Bind a new login method
Request body:
- provider: provider type
- identifier: identifier (email / phone number / OAuth code)
- credential: credential (password, verification code, etc.)
- redirect_uri: OAuth callback URL (optional)
Error handling:
- 400: provider not enabled
- 409: this identity is already bound to another user
"""
# Check whether the identity is already bound
existing = await AuthIdentity.get(
session,
(AuthIdentity.provider == request.provider)
& (AuthIdentity.identifier == request.identifier),
)
if existing:
raise HTTPException(status_code=409, detail="该身份已被绑定")
# Handle password-type credentials
credential: str | None = None
if request.provider == AuthProviderType.EMAIL_PASSWORD and request.credential:
credential = Password.hash(request.credential)
identity = AuthIdentity(
provider=request.provider,
identifier=request.identifier,
credential=credential,
is_primary=False,
is_verified=False,
user_id=user.id,
)
identity = await identity.save(session)
return identity.to_response()
@user_settings_router.delete(
path='/identity/{identity_id}',
summary='解绑认证身份',
status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_unbind_identity(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
identity_id: UUID,
) -> None:
"""
Unbind an auth identity
Constraints:
- the last identity cannot be unbound
- when the admin enforces email/phone binding, the matching identity cannot be unbound
Error handling:
- 404: identity does not exist or does not belong to the current user
- 400: cannot unbind the last identity / cannot unbind a mandatory identity
"""
# Find the target identity
identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.id == identity_id) & (AuthIdentity.user_id == user.id),
)
if not identity:
http_exceptions.raise_not_found("认证身份不存在")
# Check whether it is the last identity
all_identities: list[AuthIdentity] = await AuthIdentity.get(
session,
AuthIdentity.user_id == user.id,
fetch_mode="all",
)
if len(all_identities) <= 1:
http_exceptions.raise_bad_request("不能解绑最后一个认证身份")
# Check mandatory-binding constraints
if identity.provider == AuthProviderType.EMAIL_PASSWORD:
email_required_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
& (sqlmodels.Setting.name == "auth_email_binding_required"),
)
if email_required_setting and email_required_setting.value == "1":
http_exceptions.raise_bad_request("站长要求必须绑定邮箱,不能解绑")
if identity.provider == AuthProviderType.PHONE_SMS:
phone_required_setting = await sqlmodels.Setting.get(
session,
(sqlmodels.Setting.type == sqlmodels.SettingsType.AUTH)
& (sqlmodels.Setting.name == "auth_phone_binding_required"),
)
if phone_required_setting and phone_required_setting.value == "1":
http_exceptions.raise_bad_request("站长要求必须绑定手机号,不能解绑")
await AuthIdentity.delete(session, identity)
# ==================== WebAuthn credential management ====================
@user_settings_router.get(
path='/authns',
summary='列出用户所有 WebAuthn 凭证',
)
async def router_user_settings_authns(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> list[AuthnDetailResponse]:
"""
List all WebAuthn credentials registered by the current user
Returns:
- list of credentials, including credential_id, name, device_type, etc.
"""
authns: list[UserAuthn] = await UserAuthn.get(
session,
UserAuthn.user_id == user.id,
fetch_mode="all",
)
return [authn.to_detail_response() for authn in authns]
@user_settings_router.patch(
path='/authn/{authn_id}',
summary='重命名 WebAuthn 凭证',
)
async def router_user_settings_rename_authn(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
authn_id: int,
request: AuthnRenameRequest,
) -> AuthnDetailResponse:
"""
Rename a WebAuthn credential
Error handling:
- 404: credential does not exist or does not belong to the current user
"""
authn: UserAuthn | None = await UserAuthn.get(
session,
(UserAuthn.id == authn_id) & (UserAuthn.user_id == user.id),
)
if not authn:
http_exceptions.raise_not_found("WebAuthn 凭证不存在")
authn.name = request.name
authn = await authn.save(session)
return authn.to_detail_response()
@user_settings_router.delete(
path='/authn/{authn_id}',
summary='删除 WebAuthn 凭证',
status_code=status.HTTP_204_NO_CONTENT,
)
async def router_user_settings_delete_authn(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
authn_id: int,
) -> None:
"""
Delete a WebAuthn credential
Also deletes the corresponding AuthIdentity (provider=passkey) record.
If this is the user's last auth identity, the deletion is refused.
Error handling:
- 404: credential does not exist or does not belong to the current user
- 400: cannot delete the last auth identity
"""
authn: UserAuthn | None = await UserAuthn.get(
session,
(UserAuthn.id == authn_id) & (UserAuthn.user_id == user.id),
)
if not authn:
http_exceptions.raise_not_found("WebAuthn 凭证不存在")
# Check whether it is the last auth identity
all_identities: list[AuthIdentity] = await AuthIdentity.get(
session,
AuthIdentity.user_id == user.id,
fetch_mode="all",
)
if len(all_identities) <= 1:
http_exceptions.raise_bad_request("不能删除最后一个认证身份")
# Delete the corresponding AuthIdentity
passkey_identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.provider == AuthProviderType.PASSKEY)
& (AuthIdentity.identifier == authn.credential_id)
& (AuthIdentity.user_id == user.id),
)
if passkey_identity:
await AuthIdentity.delete(session, passkey_identity, commit=False)
# Delete the UserAuthn
await UserAuthn.delete(session, authn)


@@ -0,0 +1,146 @@
"""
User file viewer preference endpoints
Provides create/delete/list operations for the user's "always use" default viewers.
"""
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, status
from sqlalchemy import and_
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
FileApp,
FileAppExtension,
SetDefaultViewerRequest,
User,
UserFileAppDefault,
UserFileAppDefaultResponse,
)
from utils import http_exceptions
file_viewers_router = APIRouter(
prefix='/file-viewers',
tags=["user", "user_settings", "file_viewers"],
dependencies=[Depends(auth_required)],
)
@file_viewers_router.put(
path='/default',
summary='设置默认查看器',
description='为指定扩展名设置"始终使用"的查看器。',
)
async def set_default_viewer(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: SetDefaultViewerRequest,
) -> UserFileAppDefaultResponse:
"""
设置默认查看器端点
如果用户已有该扩展名的默认设置,则更新;否则创建新记录。
认证:JWT token 必填
错误处理:
- 404: 应用不存在
- 400: 应用不支持该扩展名
"""
# 规范化扩展名
normalized_ext = request.extension.lower().strip().lstrip('.')
# 验证应用存在
app: FileApp | None = await FileApp.get(session, FileApp.id == request.app_id)
if not app:
http_exceptions.raise_not_found("应用不存在")
# 验证应用支持该扩展名
ext_record: FileAppExtension | None = await FileAppExtension.get(
session,
and_(
FileAppExtension.app_id == app.id,
FileAppExtension.extension == normalized_ext,
),
)
if not ext_record:
http_exceptions.raise_bad_request("该应用不支持此扩展名")
# 查找已有记录
existing: UserFileAppDefault | None = await UserFileAppDefault.get(
session,
and_(
UserFileAppDefault.user_id == user.id,
UserFileAppDefault.extension == normalized_ext,
),
)
if existing:
existing.app_id = request.app_id
existing = await existing.save(session, load=UserFileAppDefault.app)
return existing.to_response()
else:
new_default = UserFileAppDefault(
user_id=user.id,
extension=normalized_ext,
app_id=request.app_id,
)
new_default = await new_default.save(session, load=UserFileAppDefault.app)
return new_default.to_response()
@file_viewers_router.get(
path='/defaults',
summary='列出所有默认查看器设置',
description='获取当前用户所有"始终使用"的查看器偏好。',
)
async def list_default_viewers(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> list[UserFileAppDefaultResponse]:
"""
列出所有默认查看器设置端点
认证:JWT token 必填
"""
defaults: list[UserFileAppDefault] = await UserFileAppDefault.get(
session,
UserFileAppDefault.user_id == user.id,
fetch_mode="all",
load=UserFileAppDefault.app,
)
return [d.to_response() for d in defaults]
@file_viewers_router.delete(
path='/default/{default_id}',
summary='撤销默认查看器设置',
description='删除指定的"始终使用"偏好。',
status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_default_viewer(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
default_id: UUID,
) -> None:
"""
撤销默认查看器设置端点
认证:JWT token 必填
错误处理:
- 404: 记录不存在或不属于当前用户
"""
existing: UserFileAppDefault | None = await UserFileAppDefault.get(
session,
and_(
UserFileAppDefault.id == default_id,
UserFileAppDefault.user_id == user.id,
),
)
if not existing:
http_exceptions.raise_not_found("默认设置不存在")
await UserFileAppDefault.delete(session, existing)
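The extension normalization used by `set_default_viewer` can be exercised in isolation. A standalone sketch for illustration (`normalize_extension` is an illustrative name, not part of the codebase):

```python
def normalize_extension(raw: str) -> str:
    # Mirrors the normalization in set_default_viewer: case-fold, trim
    # surrounding whitespace, then drop any leading dots.
    return raw.lower().strip().lstrip('.')

# Leading dots and casing are stripped; inner dots are preserved.
assert normalize_extension("  .PDF ") == "pdf"
assert normalize_extension("Tar.GZ") == "tar.gz"
```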


@@ -1,106 +0,0 @@
from fastapi import APIRouter, Depends
from middleware.auth import auth_required
from sqlmodels import ResponseBase
from utils import http_exceptions
vas_router = APIRouter(
prefix="/vas",
tags=["vas"]
)
@vas_router.get(
path='/pack',
summary='获取容量包及配额信息',
description='Get information about storage packs and quotas.',
dependencies=[Depends(auth_required)]
)
def router_vas_pack() -> ResponseBase:
"""
Get information about storage packs and quotas.
Returns:
ResponseBase: A model containing the response data for storage packs and quotas.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/product',
summary='获取商品信息,同时返回支付信息',
description='Get product information along with payment details.',
dependencies=[Depends(auth_required)]
)
def router_vas_product() -> ResponseBase:
"""
Get product information along with payment details.
Returns:
ResponseBase: A model containing the response data for products and payment information.
"""
http_exceptions.raise_not_implemented()
@vas_router.post(
path='/order',
summary='新建支付订单',
description='Create an order for a product.',
dependencies=[Depends(auth_required)]
)
def router_vas_order() -> ResponseBase:
"""
Create an order for a product.
Returns:
ResponseBase: A model containing the response data for the created order.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/order/{id}',
summary='查询订单状态',
description='Get information about a specific payment order by ID.',
dependencies=[Depends(auth_required)]
)
def router_vas_order_get(id: str) -> ResponseBase:
"""
Get information about a specific payment order by ID.
Args:
id (str): The ID of the order to retrieve information for.
Returns:
ResponseBase: A model containing the response data for the specified order.
"""
http_exceptions.raise_not_implemented()
@vas_router.get(
path='/redeem',
summary='获取兑换码信息',
description='Get information about a specific redemption code.',
dependencies=[Depends(auth_required)]
)
def router_vas_redeem(code: str) -> ResponseBase:
"""
Get information about a specific redemption code.
Args:
code (str): The redemption code to retrieve information for.
Returns:
ResponseBase: A model containing the response data for the specified redemption code.
"""
http_exceptions.raise_not_implemented()
@vas_router.post(
path='/redeem',
summary='执行兑换',
description='Redeem a redemption code for a product or service.',
dependencies=[Depends(auth_required)]
)
def router_vas_redeem_post() -> ResponseBase:
"""
Redeem a redemption code for a product or service.
Returns:
ResponseBase: A model containing the response data for the redeemed code.
"""
http_exceptions.raise_not_implemented()


@@ -1,110 +1,207 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends
from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from sqlmodels import (
Object,
User,
WebDAV,
WebDAVAccountResponse,
WebDAVCreateRequest,
WebDAVUpdateRequest,
)
from service.redis.webdav_auth_cache import WebDAVAuthCache
from utils import http_exceptions
from utils.password.pwd import Password
# WebDAV 管理路由
webdav_router = APIRouter(
prefix='/webdav',
tags=["webdav"],
)
def _check_webdav_enabled(user: User) -> None:
"""检查用户组是否启用了 WebDAV 功能"""
if not user.group.web_dav_enabled:
http_exceptions.raise_forbidden("WebDAV 功能未启用")
def _to_response(account: WebDAV) -> WebDAVAccountResponse:
"""将 WebDAV 数据库模型转换为响应 DTO"""
return WebDAVAccountResponse(
id=account.id,
name=account.name,
root=account.root,
readonly=account.readonly,
use_proxy=account.use_proxy,
created_at=str(account.created_at),
updated_at=str(account.updated_at),
)
@webdav_router.get(
path='/accounts',
summary='获取账号列表',
)
async def list_accounts(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> list[WebDAVAccountResponse]:
"""
列出当前用户所有 WebDAV 账户
认证:JWT Bearer Token
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
accounts: list[WebDAV] = await WebDAV.get(
session,
WebDAV.user_id == user_id,
fetch_mode="all",
)
return [_to_response(a) for a in accounts]
@webdav_router.post(
path='/accounts',
summary='新建账号',
status_code=201,
)
async def create_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: WebDAVCreateRequest,
) -> WebDAVAccountResponse:
"""
创建 WebDAV 账户
认证:JWT Bearer Token
错误处理:
- 403: WebDAV 功能未启用
- 400: 根目录路径不存在或不是目录
- 409: 账户名已存在
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
# 验证账户名唯一
existing = await WebDAV.get(
session,
(WebDAV.name == request.name) & (WebDAV.user_id == user_id),
)
if existing:
http_exceptions.raise_conflict("账户名已存在")
# 验证 root 路径存在且为目录
root_obj = await Object.get_by_path(session, user_id, request.root)
if not root_obj or not root_obj.is_folder:
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
# 创建账户
account = WebDAV(
name=request.name,
password=Password.hash(request.password),
root=request.root,
readonly=request.readonly,
use_proxy=request.use_proxy,
user_id=user_id,
)
account = await account.save(session)
l.info(f"用户 {user_id} 创建 WebDAV 账户: {account.name}")
return _to_response(account)
@webdav_router.patch(
path='/accounts/{account_id}',
summary='更新账号',
)
async def update_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
account_id: int,
request: WebDAVUpdateRequest,
) -> WebDAVAccountResponse:
"""
更新 WebDAV 账户
认证:JWT Bearer Token
错误处理:
- 403: WebDAV 功能未启用
- 404: 账户不存在
- 400: 根目录路径不存在或不是目录
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
account = await WebDAV.get(
session,
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
)
if not account:
http_exceptions.raise_not_found("WebDAV 账户不存在")
# 验证 root 路径
if request.root is not None:
root_obj = await Object.get_by_path(session, user_id, request.root)
if not root_obj or not root_obj.is_folder:
http_exceptions.raise_bad_request("根目录路径不存在或不是目录")
# 密码哈希后原地替换:update() 会通过 model_dump(exclude_unset=True) 只取已设置字段
is_password_changed = request.password is not None
if is_password_changed:
request.password = Password.hash(request.password)
account = await account.update(session, request)
# 密码变更时清除认证缓存
if is_password_changed:
await WebDAVAuthCache.invalidate_account(user_id, account.name)
l.info(f"用户 {user_id} 更新 WebDAV 账户: {account.name}")
return _to_response(account)
@webdav_router.delete(
path='/accounts/{account_id}',
summary='删除账号',
status_code=204,
)
async def delete_account(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
account_id: int,
) -> None:
"""
删除 WebDAV 账户
认证:JWT Bearer Token
错误处理:
- 403: WebDAV 功能未启用
- 404: 账户不存在
"""
_check_webdav_enabled(user)
user_id: UUID = user.id
account = await WebDAV.get(
session,
(WebDAV.id == account_id) & (WebDAV.user_id == user_id),
)
if not account:
http_exceptions.raise_not_found("WebDAV 账户不存在")
account_name = account.name
await WebDAV.delete(session, account)
# 清除认证缓存
await WebDAVAuthCache.invalidate_account(user_id, account_name)
l.info(f"用户 {user_id} 删除 WebDAV 账户: {account_name}")


@@ -1 +0,0 @@
# WebDAV 操作路由

routers/dav/__init__.py Normal file

@@ -0,0 +1,35 @@
"""
WebDAV 协议入口
使用 WsgiDAV + a2wsgi 提供 WebDAV 协议支持。
WsgiDAV 在 a2wsgi 的线程池中运行,不阻塞 FastAPI 事件循环。
"""
from a2wsgi import WSGIMiddleware
from wsgidav.wsgidav_app import WsgiDAVApp
from .domain_controller import DiskNextDomainController
from .provider import DiskNextDAVProvider
_wsgidav_config: dict[str, object] = {
"provider_mapping": {
"/": DiskNextDAVProvider(),
},
"http_authenticator": {
"domain_controller": DiskNextDomainController,
"accept_basic": True,
"accept_digest": False,
"default_to_digest": False,
},
"verbose": 1,
# 使用 WsgiDAV 内置的内存锁管理器
"lock_storage": True,
# 禁用 WsgiDAV 的目录浏览器(纯 DAV 协议)
"dir_browser": {
"enable": False,
},
}
_wsgidav_app = WsgiDAVApp(_wsgidav_config)
dav_app = WSGIMiddleware(_wsgidav_app, workers=10)
"""ASGI 应用,挂载到 /dav 路径"""
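The reason WsgiDAV can run in a2wsgi's thread pool without blocking the event loop is the thread-to-loop bridge used by `EventLoopRef` and `_run_async` in the provider module. A self-contained sketch of that pattern (no project imports; names here are illustrative):

```python
import asyncio
import threading

# A worker thread schedules a coroutine onto the main thread's captured
# event loop and blocks on the returned future, while the loop itself
# stays free to actually run the coroutine.
async def fetch_value() -> int:
    await asyncio.sleep(0)  # stand-in for an async DB call
    return 42

def main() -> int:
    result: list[int] = []

    async def runner() -> None:
        loop = asyncio.get_running_loop()  # what EventLoopRef.capture() stores

        def wsgi_thread() -> None:
            # Equivalent of _run_async(coro) inside a WsgiDAV request thread
            future = asyncio.run_coroutine_threadsafe(fetch_value(), loop)
            result.append(future.result())  # blocks this thread, not the loop

        t = threading.Thread(target=wsgi_thread)
        t.start()
        await asyncio.to_thread(t.join)  # keep the loop free while waiting

    asyncio.run(runner())
    return result[0]

assert main() == 42
```

The same mechanism requires that `EventLoopRef.capture()` be awaited once on the main loop at startup, before the first DAV request arrives.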


@@ -0,0 +1,148 @@
"""
WebDAV 认证控制器
实现 WsgiDAV 的 BaseDomainController 接口,使用 HTTP Basic Auth
通过 DiskNext 的 WebDAV 账户模型进行认证。
用户名格式: {email}/{webdav_account_name}
"""
import asyncio
from uuid import UUID
from loguru import logger as l
from wsgidav.dc.base_dc import BaseDomainController
from routers.dav.provider import EventLoopRef, _get_session
from service.redis.webdav_auth_cache import WebDAVAuthCache
from sqlmodels.user import User, UserStatus
from sqlmodels.webdav import WebDAV
from utils.password.pwd import Password, PasswordStatus
async def _authenticate(
email: str,
account_name: str,
password: str,
) -> tuple[UUID, int] | None:
"""
异步认证 WebDAV 用户。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 明文密码
:return: (user_id, webdav_id) 或 None
"""
# 1. 查缓存
cached = await WebDAVAuthCache.get(email, account_name, password)
if cached is not None:
return cached
# 2. 缓存未命中,查库验证
async with _get_session() as session:
user = await User.get(session, User.email == email, load=User.group)
if not user:
return None
if user.status != UserStatus.ACTIVE:
return None
if not user.group.web_dav_enabled:
return None
account = await WebDAV.get(
session,
(WebDAV.name == account_name) & (WebDAV.user_id == user.id),
)
if not account:
return None
status = Password.verify(account.password, password)
if status == PasswordStatus.INVALID:
return None
user_id: UUID = user.id
webdav_id: int = account.id
# 3. 写入缓存
await WebDAVAuthCache.set(email, account_name, password, user_id, webdav_id)
return user_id, webdav_id
class DiskNextDomainController(BaseDomainController):
"""
DiskNext WebDAV 认证控制器
用户名格式: {email}/{webdav_account_name}
密码: WebDAV 账户密码(创建账户时设置)
"""
def __init__(self, wsgidav_app: object, config: dict[str, object]) -> None:
super().__init__(wsgidav_app, config)
def get_domain_realm(self, path_info: str, environ: dict[str, object]) -> str:
"""返回 realm 名称"""
return "DiskNext WebDAV"
def require_authentication(self, realm: str, environ: dict[str, object]) -> bool:
"""所有请求都需要认证"""
return True
def is_share_anonymous(self, path_info: str) -> bool:
"""不支持匿名访问"""
return False
def supports_http_digest_auth(self) -> bool:
"""不支持 Digest 认证(密码存的是 Argon2 哈希,无法反推)"""
return False
def basic_auth_user(
self,
realm: str,
user_name: str,
password: str,
environ: dict[str, object],
) -> bool:
"""
HTTP Basic Auth 认证。
用户名格式: {email}/{webdav_account_name}
在 WSGI 线程中通过 anyio.from_thread.run 调用异步认证逻辑。
"""
# 解析用户名
if "/" not in user_name:
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
return False
email, account_name = user_name.split("/", 1)
if not email or not account_name:
l.debug(f"WebDAV 认证失败: 用户名格式无效 '{user_name}'")
return False
# 在 WSGI 线程中调用异步认证
future = asyncio.run_coroutine_threadsafe(
_authenticate(email, account_name, password),
EventLoopRef.get(),
)
result = future.result()
if result is None:
l.debug(f"WebDAV 认证失败: {email}/{account_name}")
return False
user_id, webdav_id = result
# 将认证信息存入 environ,供 Provider 使用
environ["disknext.user_id"] = user_id
environ["disknext.webdav_id"] = webdav_id
environ["disknext.email"] = email
environ["disknext.account_name"] = account_name
return True
def digest_auth_user(
self,
realm: str,
user_name: str,
environ: dict[str, object],
) -> bool:
"""不支持 Digest 认证"""
return False

routers/dav/provider.py Normal file

@@ -0,0 +1,645 @@
"""
DiskNext WebDAV 存储 Provider
将 WsgiDAV 的文件操作映射到 DiskNext 的 Object 模型。
所有异步数据库/文件操作通过 asyncio.run_coroutine_threadsafe() 桥接。
"""
import asyncio
import io
import mimetypes
from pathlib import Path
from typing import ClassVar
from uuid import UUID
from loguru import logger as l
from wsgidav.dav_error import (
DAVError,
HTTP_FORBIDDEN,
HTTP_INSUFFICIENT_STORAGE,
HTTP_NOT_FOUND,
)
from wsgidav.dav_provider import DAVCollection, DAVNonCollection, DAVProvider
from service.storage import LocalStorageService, adjust_user_storage
from sqlmodels.database_connection import DatabaseManager
from sqlmodels.object import Object, ObjectType
from sqlmodels.physical_file import PhysicalFile
from sqlmodels.policy import Policy
from sqlmodels.user import User
from sqlmodels.webdav import WebDAV
class EventLoopRef:
"""持有主线程事件循环引用,供 WSGI 线程使用"""
_loop: ClassVar[asyncio.AbstractEventLoop | None] = None
@classmethod
async def capture(cls) -> None:
"""在 async 上下文中调用,捕获当前事件循环"""
cls._loop = asyncio.get_running_loop()
@classmethod
def get(cls) -> asyncio.AbstractEventLoop:
if cls._loop is None:
raise RuntimeError("事件循环尚未捕获,请先调用 EventLoopRef.capture()")
return cls._loop
def _run_async(coro): # type: ignore[no-untyped-def]
"""在 WSGI 线程中通过 run_coroutine_threadsafe 运行协程"""
future = asyncio.run_coroutine_threadsafe(coro, EventLoopRef.get())
return future.result()
def _get_session(): # type: ignore[no-untyped-def]
"""获取数据库会话上下文管理器"""
return DatabaseManager._async_session_factory()
# ==================== 异步辅助函数 ====================
async def _get_webdav_account(webdav_id: int) -> WebDAV | None:
"""获取 WebDAV 账户"""
async with _get_session() as session:
return await WebDAV.get(session, WebDAV.id == webdav_id)
async def _get_object_by_path(user_id: UUID, path: str) -> Object | None:
"""根据路径获取对象"""
async with _get_session() as session:
return await Object.get_by_path(session, user_id, path)
async def _get_children(user_id: UUID, parent_id: UUID) -> list[Object]:
"""获取目录子对象"""
async with _get_session() as session:
return await Object.get_children(session, user_id, parent_id)
async def _get_object_by_id(object_id: UUID) -> Object | None:
"""根据ID获取对象"""
async with _get_session() as session:
return await Object.get(session, Object.id == object_id, load=Object.physical_file)
async def _get_user(user_id: UUID) -> User | None:
"""获取用户(含 group 关系)"""
async with _get_session() as session:
return await User.get(session, User.id == user_id, load=User.group)
async def _get_policy(policy_id: UUID) -> Policy | None:
"""获取存储策略"""
async with _get_session() as session:
return await Policy.get(session, Policy.id == policy_id)
async def _create_folder(
name: str,
parent_id: UUID,
owner_id: UUID,
policy_id: UUID,
) -> Object:
"""创建目录对象"""
async with _get_session() as session:
obj = Object(
name=name,
type=ObjectType.FOLDER,
size=0,
parent_id=parent_id,
owner_id=owner_id,
policy_id=policy_id,
)
obj = await obj.save(session)
return obj
async def _create_file(
name: str,
parent_id: UUID,
owner_id: UUID,
policy_id: UUID,
) -> Object:
"""创建空文件对象"""
async with _get_session() as session:
obj = Object(
name=name,
type=ObjectType.FILE,
size=0,
parent_id=parent_id,
owner_id=owner_id,
policy_id=policy_id,
)
obj = await obj.save(session)
return obj
async def _soft_delete_object(object_id: UUID) -> None:
"""软删除对象(移入回收站)"""
from service.storage import soft_delete_objects
async with _get_session() as session:
obj = await Object.get(session, Object.id == object_id)
if obj:
await soft_delete_objects(session, [obj])
async def _finalize_upload(
object_id: UUID,
physical_path: str,
size: int,
owner_id: UUID,
policy_id: UUID,
) -> None:
"""上传完成后更新对象元数据和物理文件记录"""
async with _get_session() as session:
# 获取存储路径(相对路径)
policy = await Policy.get(session, Policy.id == policy_id)
if not policy or not policy.server:
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
base_path = Path(policy.server).resolve()
full_path = Path(physical_path).resolve()
storage_path = str(full_path.relative_to(base_path))
# 创建 PhysicalFile 记录
pf = PhysicalFile(
storage_path=storage_path,
size=size,
policy_id=policy_id,
reference_count=1,
)
pf = await pf.save(session)
# 更新 Object
obj = await Object.get(session, Object.id == object_id)
if obj:
obj.sqlmodel_update({'size': size, 'physical_file_id': pf.id})
obj = await obj.save(session)
# 更新用户存储用量
if size > 0:
await adjust_user_storage(session, owner_id, size)
async def _move_object(
object_id: UUID,
new_parent_id: UUID,
new_name: str,
) -> None:
"""移动/重命名对象"""
async with _get_session() as session:
obj = await Object.get(session, Object.id == object_id)
if obj:
obj.sqlmodel_update({'parent_id': new_parent_id, 'name': new_name})
obj = await obj.save(session)
async def _copy_object_recursive(
src_id: UUID,
dst_parent_id: UUID,
dst_name: str,
owner_id: UUID,
) -> None:
"""递归复制对象"""
from service.storage import copy_object_recursive
async with _get_session() as session:
src = await Object.get(session, Object.id == src_id)
if not src:
return
await copy_object_recursive(session, src, dst_parent_id, owner_id, new_name=dst_name)
# ==================== 辅助工具 ====================
def _get_environ_info(environ: dict[str, object]) -> tuple[UUID, int]:
"""从 environ 中提取认证信息"""
user_id: UUID = environ["disknext.user_id"] # type: ignore[assignment]
webdav_id: int = environ["disknext.webdav_id"] # type: ignore[assignment]
return user_id, webdav_id
def _resolve_dav_path(account_root: str, dav_path: str) -> str:
"""
将 DAV 相对路径映射到 DiskNext 绝对路径。
:param account_root: 账户挂载根路径,如 "/" 或 "/docs"
:param dav_path: DAV 请求路径,如 "/" 或 "/photos/cat.jpg"
:return: DiskNext 内部路径,如 "/docs/photos/cat.jpg"
"""
# 规范化根路径
root = account_root.rstrip("/")
if not root:
root = ""
# 规范化 DAV 路径
if not dav_path or dav_path == "/":
return root + "/" if root else "/"
if not dav_path.startswith("/"):
dav_path = "/" + dav_path
full = root + dav_path
return full if full else "/"
def _check_readonly(environ: dict[str, object]) -> None:
"""检查账户是否只读,只读则抛出 403"""
account = environ.get("disknext.webdav_account")
if account and getattr(account, 'readonly', False):
raise DAVError(HTTP_FORBIDDEN, "WebDAV 账户为只读模式")
def _check_storage_quota(user: User, additional_bytes: int) -> None:
"""检查存储配额"""
max_storage = user.group.max_storage
if max_storage > 0 and user.storage + additional_bytes > max_storage:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
class QuotaLimitedWriter(io.RawIOBase):
"""带配额限制的写入流包装器"""
def __init__(self, stream: io.BufferedWriter, max_bytes: int) -> None:
self._stream = stream
self._max_bytes = max_bytes
self._bytes_written = 0
def writable(self) -> bool:
return True
def write(self, b: bytes | bytearray) -> int:
if self._bytes_written + len(b) > self._max_bytes:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
written = self._stream.write(b)
self._bytes_written += written
return written
def close(self) -> None:
self._stream.close()
super().close()
@property
def bytes_written(self) -> int:
return self._bytes_written
# ==================== Provider ====================
class DiskNextDAVProvider(DAVProvider):
"""DiskNext WebDAV 存储 Provider"""
def __init__(self) -> None:
super().__init__()
def get_resource_inst(
self,
path: str,
environ: dict[str, object],
) -> 'DiskNextCollection | DiskNextFile | None':
"""
将 WebDAV 路径映射到资源对象。
首次调用时加载 WebDAV 账户信息并缓存到 environ。
"""
user_id, webdav_id = _get_environ_info(environ)
# 首次请求时加载账户信息
if "disknext.webdav_account" not in environ:
account = _run_async(_get_webdav_account(webdav_id))
if not account:
return None
environ["disknext.webdav_account"] = account
account: WebDAV = environ["disknext.webdav_account"] # type: ignore[no-redef]
disknext_path = _resolve_dav_path(account.root, path)
obj = _run_async(_get_object_by_path(user_id, disknext_path))
if not obj:
return None
if obj.is_folder:
return DiskNextCollection(path, environ, obj, user_id, account)
else:
return DiskNextFile(path, environ, obj, user_id, account)
def is_readonly(self) -> bool:
"""只读由账户级别控制,不在 provider 级别限制"""
return False
# ==================== Collection(目录) ====================
class DiskNextCollection(DAVCollection):
"""DiskNext 目录资源"""
def __init__(
self,
path: str,
environ: dict[str, object],
obj: Object,
user_id: UUID,
account: WebDAV,
) -> None:
super().__init__(path, environ)
self._obj = obj
self._user_id = user_id
self._account = account
def get_display_info(self) -> dict[str, str]:
return {"type": "Directory"}
def get_member_names(self) -> list[str]:
"""获取子对象名称列表"""
children = _run_async(_get_children(self._user_id, self._obj.id))
return [c.name for c in children]
def get_member(self, name: str) -> 'DiskNextCollection | DiskNextFile | None':
"""获取指定名称的子资源"""
member_path = self.path.rstrip("/") + "/" + name
account_root = self._account.root
disknext_path = _resolve_dav_path(account_root, member_path)
obj = _run_async(_get_object_by_path(self._user_id, disknext_path))
if not obj:
return None
if obj.is_folder:
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
else:
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
def get_creation_date(self) -> float | None:
if self._obj.created_at:
return self._obj.created_at.timestamp()
return None
def get_last_modified(self) -> float | None:
if self._obj.updated_at:
return self._obj.updated_at.timestamp()
return None
def create_empty_resource(self, name: str) -> 'DiskNextFile':
"""创建空文件PUT 操作的第一步)"""
_check_readonly(self.environ)
obj = _run_async(_create_file(
name=name,
parent_id=self._obj.id,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
member_path = self.path.rstrip("/") + "/" + name
return DiskNextFile(member_path, self.environ, obj, self._user_id, self._account)
def create_collection(self, name: str) -> 'DiskNextCollection':
"""创建子目录MKCOL"""
_check_readonly(self.environ)
obj = _run_async(_create_folder(
name=name,
parent_id=self._obj.id,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
member_path = self.path.rstrip("/") + "/" + name
return DiskNextCollection(member_path, self.environ, obj, self._user_id, self._account)
def delete(self) -> None:
"""软删除目录"""
_check_readonly(self.environ)
_run_async(_soft_delete_object(self._obj.id))
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
"""复制或移动目录"""
_check_readonly(self.environ)
account_root = self._account.root
dest_disknext = _resolve_dav_path(account_root, dest_path)
# 解析目标父路径和新名称
if "/" in dest_disknext.rstrip("/"):
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
new_name = dest_disknext.rsplit("/", 1)[1]
else:
parent_path = "/"
new_name = dest_disknext.lstrip("/")
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
if not dest_parent:
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
if is_move:
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
else:
_run_async(_copy_object_recursive(
self._obj.id, dest_parent.id, new_name, self._user_id,
))
return True
def support_recursive_delete(self) -> bool:
return True
def support_recursive_move(self, dest_path: str) -> bool:
return True
# ==================== NonCollection(文件) ====================
class DiskNextFile(DAVNonCollection):
"""DiskNext 文件资源"""
def __init__(
self,
path: str,
environ: dict[str, object],
obj: Object,
user_id: UUID,
account: WebDAV,
) -> None:
super().__init__(path, environ)
self._obj = obj
self._user_id = user_id
self._account = account
self._write_path: str | None = None
self._write_stream: io.BufferedWriter | QuotaLimitedWriter | None = None
def get_content_length(self) -> int | None:
return self._obj.size if self._obj.size else 0
def get_content_type(self) -> str | None:
# 尝试从文件名推断 MIME 类型
mime, _ = mimetypes.guess_type(self._obj.name)
return mime or "application/octet-stream"
def get_creation_date(self) -> float | None:
if self._obj.created_at:
return self._obj.created_at.timestamp()
return None
def get_last_modified(self) -> float | None:
if self._obj.updated_at:
return self._obj.updated_at.timestamp()
return None
def get_display_info(self) -> dict[str, str]:
return {"type": "File"}
def get_content(self) -> io.BufferedReader | None:
"""
返回文件内容的可读流。
WsgiDAV 在线程中运行,可安全使用同步 open()。
"""
obj_with_file = _run_async(_get_object_by_id(self._obj.id))
if not obj_with_file or not obj_with_file.physical_file:
return None
pf = obj_with_file.physical_file
policy = _run_async(_get_policy(obj_with_file.policy_id))
if not policy or not policy.server:
return None
full_path = Path(policy.server).resolve() / pf.storage_path
if not full_path.is_file():
l.warning(f"WebDAV: 物理文件不存在: {full_path}")
return None
return open(full_path, "rb") # noqa: SIM115
def begin_write(
self,
*,
content_type: str | None = None,
) -> io.BufferedWriter | QuotaLimitedWriter:
"""
开始写入文件(PUT 操作)。
返回一个可写的文件流WsgiDAV 将向其中写入请求体数据。
当用户有配额限制时,返回 QuotaLimitedWriter 在写入过程中实时检查配额。
"""
_check_readonly(self.environ)
# 检查配额
remaining_quota: int = 0
user = _run_async(_get_user(self._user_id))
if user:
max_storage = user.group.max_storage
if max_storage > 0:
remaining_quota = max_storage - user.storage
if remaining_quota <= 0:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
# Content-Length 预检(如果有的话;remaining_quota 为 0 表示无配额限制)
content_length = self.environ.get("CONTENT_LENGTH")
if remaining_quota > 0 and content_length and int(content_length) > remaining_quota:
raise DAVError(HTTP_INSUFFICIENT_STORAGE, "存储空间不足")
# 获取策略以确定存储路径
policy = _run_async(_get_policy(self._obj.policy_id))
if not policy or not policy.server:
raise DAVError(HTTP_NOT_FOUND, "存储策略不存在")
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = _run_async(
storage_service.generate_file_path(
user_id=self._user_id,
original_filename=self._obj.name,
)
)
self._write_path = full_path
raw_stream = open(full_path, "wb") # noqa: SIM115
# 有配额限制时使用包装流,实时检查写入量
if remaining_quota > 0:
self._write_stream = QuotaLimitedWriter(raw_stream, remaining_quota)
else:
self._write_stream = raw_stream
return self._write_stream
def end_write(self, *, with_errors: bool) -> None:
"""写入完成后的收尾工作"""
if self._write_stream:
self._write_stream.close()
self._write_stream = None
if with_errors:
if self._write_path:
file_path = Path(self._write_path)
if file_path.exists():
file_path.unlink()
return
if not self._write_path:
return
# 获取文件大小
file_path = Path(self._write_path)
if not file_path.exists():
return
size = file_path.stat().st_size
# 更新数据库记录
_run_async(_finalize_upload(
object_id=self._obj.id,
physical_path=self._write_path,
size=size,
owner_id=self._user_id,
policy_id=self._obj.policy_id,
))
l.debug(f"WebDAV 文件写入完成: {self._obj.name}, size={size}")
def delete(self) -> None:
"""软删除文件"""
_check_readonly(self.environ)
_run_async(_soft_delete_object(self._obj.id))
def copy_move_single(self, dest_path: str, *, is_move: bool) -> bool:
"""复制或移动文件"""
_check_readonly(self.environ)
account_root = self._account.root
dest_disknext = _resolve_dav_path(account_root, dest_path)
# 解析目标父路径和新名称
if "/" in dest_disknext.rstrip("/"):
parent_path = dest_disknext.rsplit("/", 1)[0] or "/"
new_name = dest_disknext.rsplit("/", 1)[1]
else:
parent_path = "/"
new_name = dest_disknext.lstrip("/")
dest_parent = _run_async(_get_object_by_path(self._user_id, parent_path))
if not dest_parent:
raise DAVError(HTTP_NOT_FOUND, "目标父目录不存在")
if is_move:
_run_async(_move_object(self._obj.id, dest_parent.id, new_name))
else:
_run_async(_copy_object_recursive(
self._obj.id, dest_parent.id, new_name, self._user_id,
))
return True
def support_content_length(self) -> bool:
return True
def get_etag(self) -> str | None:
"""返回 ETag基于ID和更新时间WsgiDAV 会自动加双引号"""
if self._obj.updated_at:
return f"{self._obj.id}-{int(self._obj.updated_at.timestamp())}"
return None
def support_etag(self) -> bool:
return True
def support_ranges(self) -> bool:
return True
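`QuotaLimitedWriter` is referenced above but not defined in this chunk. A minimal sketch of what such a wrapper could look like, assuming its contract is to reject a write once it would exceed the remaining quota — the class name comes from the source, the behavior shown here is an assumption:

```python
class QuotaLimitedWriter:
    """Wraps a binary stream and fails writes that would exceed a byte quota."""

    def __init__(self, raw, limit: int) -> None:
        self._raw = raw
        self._limit = limit
        self._written = 0

    def write(self, data: bytes) -> int:
        # Check before writing so the file never grows past the quota
        if self._written + len(data) > self._limit:
            self._raw.close()
            raise OSError("quota exceeded")
        self._written += len(data)
        return self._raw.write(data)

    def close(self) -> None:
        self._raw.close()
```

With this contract, `end_write(with_errors=True)` above would then clean up the partially written file after the failed upload.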

routers/wopi/__init__.py
@@ -0,0 +1,11 @@
"""
WOPI(Web Application Open Platform Interface)路由
挂载在根级别 /wopi(非 /api/v1 下),因为 WOPI 客户端要求标准路径。
"""
from fastapi import APIRouter
from .files import wopi_files_router
wopi_router = APIRouter(prefix="/wopi", tags=["wopi"])
wopi_router.include_router(wopi_files_router)


@@ -0,0 +1,203 @@
"""
WOPI 文件操作端点
实现 WOPI 协议的核心文件操作接口:
- CheckFileInfo: 获取文件元数据
- GetFile: 下载文件内容
- PutFile: 上传/更新文件内容
"""
from uuid import UUID
from fastapi import APIRouter, Query, Request, Response
from fastapi.responses import JSONResponse
from loguru import logger as l
from middleware.dependencies import SessionDep
from sqlmodels import Object, PhysicalFile, Policy, PolicyType, User, WopiFileInfo
from service.storage import LocalStorageService
from utils import http_exceptions
from utils.JWT.wopi_token import verify_wopi_token
wopi_files_router = APIRouter(prefix="/files", tags=["wopi"])
@wopi_files_router.get(
path='/{file_id}',
summary='WOPI CheckFileInfo',
description='返回文件的元数据信息。',
)
async def check_file_info(
session: SessionDep,
file_id: UUID,
access_token: str = Query(...),
) -> JSONResponse:
"""
WOPI CheckFileInfo 端点
认证:WOPI access_token(query 参数)
返回 WOPI 规范的 PascalCase JSON。
"""
# 验证令牌
payload = verify_wopi_token(access_token)
if not payload or payload.file_id != file_id:
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
# 获取文件
file_obj: Object | None = await Object.get(
session,
Object.id == file_id,
)
if not file_obj or not file_obj.is_file:
http_exceptions.raise_not_found("文件不存在")
# 获取用户信息
user: User | None = await User.get(session, User.id == payload.user_id)
user_name = (user.nickname or user.email or str(payload.user_id)) if user else str(payload.user_id)
# 构建响应
info = WopiFileInfo(
base_file_name=file_obj.name,
size=file_obj.size or 0,
owner_id=str(file_obj.owner_id),
user_id=str(payload.user_id),
user_friendly_name=user_name,
version=file_obj.updated_at.isoformat() if file_obj.updated_at else "",
user_can_write=payload.can_write,
read_only=not payload.can_write,
supports_update=payload.can_write,
)
return JSONResponse(content=info.to_wopi_dict())
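`WopiFileInfo.to_wopi_dict()` is defined elsewhere; per the docstring it emits the WOPI-spec PascalCase keys. A hedged sketch of the field-name conversion it presumably performs (the helper names below are illustrative, not the project's actual API):

```python
def to_pascal_case(name: str) -> str:
    # snake_case field name -> WOPI PascalCase, e.g. base_file_name -> BaseFileName
    return "".join(part.capitalize() for part in name.split("_"))

def to_wopi_dict(fields: dict) -> dict:
    # Convert every key; values pass through unchanged
    return {to_pascal_case(k): v for k, v in fields.items()}
```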
@wopi_files_router.get(
path='/{file_id}/contents',
summary='WOPI GetFile',
description='返回文件的二进制内容。',
)
async def get_file(
session: SessionDep,
file_id: UUID,
access_token: str = Query(...),
) -> Response:
"""
WOPI GetFile 端点
认证:WOPI access_token(query 参数)
返回文件的原始二进制内容。
"""
# 验证令牌
payload = verify_wopi_token(access_token)
if not payload or payload.file_id != file_id:
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
# 获取文件
file_obj: Object | None = await Object.get(session, Object.id == file_id)
if not file_obj or not file_obj.is_file:
http_exceptions.raise_not_found("文件不存在")
# 获取物理文件
physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
if not physical_file or not physical_file.storage_path:
http_exceptions.raise_internal_error("文件存储路径丢失")
# 获取策略
policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
if not policy:
http_exceptions.raise_internal_error("存储策略不存在")
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
if not await storage_service.file_exists(physical_file.storage_path):
http_exceptions.raise_not_found("物理文件不存在")
import aiofiles
async with aiofiles.open(physical_file.storage_path, 'rb') as f:
content = await f.read()
return Response(
content=content,
media_type="application/octet-stream",
headers={"X-WOPI-ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
)
else:
http_exceptions.raise_not_implemented("S3 存储暂未实现")
@wopi_files_router.post(
path='/{file_id}/contents',
summary='WOPI PutFile',
description='更新文件内容。',
)
async def put_file(
session: SessionDep,
request: Request,
file_id: UUID,
access_token: str = Query(...),
) -> JSONResponse:
"""
WOPI PutFile 端点
认证:WOPI access_token(query 参数,需要写权限)
接收请求体中的文件二进制内容并覆盖存储。
"""
# 验证令牌
payload = verify_wopi_token(access_token)
if not payload or payload.file_id != file_id:
http_exceptions.raise_unauthorized("WOPI token 无效或文件不匹配")
if not payload.can_write:
http_exceptions.raise_forbidden("没有写入权限")
# 获取文件
file_obj: Object | None = await Object.get(session, Object.id == file_id)
if not file_obj or not file_obj.is_file:
http_exceptions.raise_not_found("文件不存在")
# 获取物理文件
physical_file: PhysicalFile | None = await file_obj.awaitable_attrs.physical_file
if not physical_file or not physical_file.storage_path:
http_exceptions.raise_internal_error("文件存储路径丢失")
# 获取策略
policy: Policy | None = await Policy.get(session, Policy.id == file_obj.policy_id)
if not policy:
http_exceptions.raise_internal_error("存储策略不存在")
# 读取请求体
content = await request.body()
if policy.type == PolicyType.LOCAL:
import aiofiles
async with aiofiles.open(physical_file.storage_path, 'wb') as f:
await f.write(content)
# 更新文件大小
new_size = len(content)
old_size = file_obj.size or 0
file_obj.size = new_size
file_obj = await file_obj.save(session, commit=False)
# 更新物理文件大小
physical_file.size = new_size
await physical_file.save(session, commit=False)
# 更新用户存储配额
size_diff = new_size - old_size
if size_diff != 0:
from service.storage import adjust_user_storage
await adjust_user_storage(session, file_obj.owner_id, size_diff, commit=False)
await session.commit()
l.info(f"WOPI PutFile: file_id={file_id}, new_size={new_size}")
return JSONResponse(
content={"ItemVersion": file_obj.updated_at.isoformat() if file_obj.updated_at else ""},
status_code=200,
)
else:
http_exceptions.raise_not_implemented("S3 存储暂未实现")


@@ -57,7 +57,7 @@ class CaptchaScene(StrEnum):
async def verify_captcha_if_needed(
session: AsyncSession,
scene: CaptchaScene,
- captcha_code: str | None,
+ captcha_code: str,
) -> None:
"""
通用验证码校验:查询设置判断是否需要,需要则校验。
@@ -81,23 +81,19 @@ async def verify_captcha_if_needed(
if not scene_setting or scene_setting.value != "1":
return
- # 2. 需要但未提供
- if not captcha_code:
- http_exceptions.raise_bad_request(detail="请完成验证码验证")
- # 3. 查询验证码类型和密钥
+ # 2. 查询验证码类型和密钥
captcha_settings: list[Setting] = await Setting.get(
session, Setting.type == SettingsType.CAPTCHA, fetch_mode="all",
)
s: dict[str, str | None] = {item.name: item.value for item in captcha_settings}
captcha_type = CaptchaType(s.get("captcha_type") or "default")
- # 4. DEFAULT 图片验证码尚未实现,跳过
+ # 3. DEFAULT 图片验证码尚未实现,跳过
if captcha_type == CaptchaType.DEFAULT:
l.warning("DEFAULT 图片验证码尚未实现,跳过验证")
return
- # 5. 选择验证器和密钥
+ # 4. 选择验证器和密钥
if captcha_type == CaptchaType.GCAPTCHA:
secret = s.get("captcha_ReCaptchaSecret")
verifier: CaptchaBase = GCaptcha()
@@ -112,7 +108,7 @@ async def verify_captcha_if_needed(
l.error(f"验证码密钥未配置: captcha_type={captcha_type}")
http_exceptions.raise_internal_error()
- # 6. 调用第三方 API 校验
+ # 5. 调用第三方 API 校验
is_valid = await verifier.verify_captcha(
CaptchaRequestBase(response=captcha_code, secret=secret)
)


@@ -0,0 +1,5 @@
from captcha.image import ImageCaptcha
# 避免用 captcha 作变量名遮蔽同名包
generator = ImageCaptcha()
print(generator.generate())


@@ -0,0 +1,68 @@
"""
WebAuthn Challenge 一次性存储
支持 Redis(首选,使用 GETDEL 原子操作)和内存 TTLCache(降级)。
Challenge 存储后 5 分钟过期,取出即删除(防重放)。
"""
from typing import ClassVar
from cachetools import TTLCache
from loguru import logger as l
from . import RedisManager
# Challenge 过期时间(秒)
_CHALLENGE_TTL: int = 300
class ChallengeStore:
"""
WebAuthn Challenge 一次性存储管理器
根据 Redis 可用性自动选择存储后端:
- Redis 可用:使用 Redis GETDEL 原子操作
- Redis 不可用:使用内存 TTLCache(仅单实例)
Key 约定:
- 注册: ``reg:{user_id}``
- 登录: ``auth:{challenge_token}``
"""
_memory_cache: ClassVar[TTLCache[str, bytes]] = TTLCache(
maxsize=10000,
ttl=_CHALLENGE_TTL,
)
"""内存缓存降级方案"""
@classmethod
async def store(cls, key: str, challenge: bytes) -> None:
"""
存储 challenge,TTL 5 分钟。
:param key: 存储键(如 ``reg:{user_id}`` 或 ``auth:{token}``)
:param challenge: challenge 字节数据
"""
client = RedisManager.get_client()
if client is not None:
redis_key = f"webauthn_challenge:{key}"
await client.set(redis_key, challenge, ex=_CHALLENGE_TTL)
else:
cls._memory_cache[key] = challenge
@classmethod
async def retrieve_and_delete(cls, key: str) -> bytes | None:
"""
一次性取出并删除 challenge(防重放)
:param key: 存储键
:return: challenge 字节数据,过期或不存在时返回 None
"""
client = RedisManager.get_client()
if client is not None:
redis_key = f"webauthn_challenge:{key}"
result: bytes | None = await client.getdel(redis_key)
return result
else:
return cls._memory_cache.pop(key, None)


@@ -0,0 +1,128 @@
"""
WebDAV 认证缓存
缓存 HTTP Basic Auth 的认证结果,避免每次请求都查库 + Argon2 验证。
支持 Redis(首选)和内存缓存(降级)两种存储后端。
"""
import hashlib
from typing import ClassVar
from uuid import UUID
from cachetools import TTLCache
from loguru import logger as l
from . import RedisManager
_AUTH_TTL: int = 300
"""认证缓存 TTL5 分钟"""
class WebDAVAuthCache:
"""
WebDAV 认证结果缓存
缓存键格式: webdav_auth:{email}/{account_name}:{sha256(password)}
缓存值格式: {user_id}:{webdav_id}
密码的 SHA256 作为缓存键的一部分,密码变更后旧缓存自然 miss。
"""
_memory_cache: ClassVar[TTLCache[str, str]] = TTLCache(maxsize=10000, ttl=_AUTH_TTL)
"""内存缓存降级方案"""
@classmethod
def _build_key(cls, email: str, account_name: str, password: str) -> str:
"""构建缓存键"""
pwd_hash = hashlib.sha256(password.encode()).hexdigest()[:16]
return f"webdav_auth:{email}/{account_name}:{pwd_hash}"
@classmethod
async def get(
cls,
email: str,
account_name: str,
password: str,
) -> tuple[UUID, int] | None:
"""
查询缓存中的认证结果。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 用户提供的明文密码
:return: (user_id, webdav_id) 或 None(缓存未命中)
"""
key = cls._build_key(email, account_name, password)
client = RedisManager.get_client()
if client is not None:
value = await client.get(key)
if value is not None:
raw = value.decode() if isinstance(value, bytes) else value
user_id_str, webdav_id_str = raw.split(":", 1)
return UUID(user_id_str), int(webdav_id_str)
else:
raw = cls._memory_cache.get(key)
if raw is not None:
user_id_str, webdav_id_str = raw.split(":", 1)
return UUID(user_id_str), int(webdav_id_str)
return None
@classmethod
async def set(
cls,
email: str,
account_name: str,
password: str,
user_id: UUID,
webdav_id: int,
) -> None:
"""
写入认证结果到缓存。
:param email: 用户邮箱
:param account_name: WebDAV 账户名
:param password: 用户提供的明文密码
:param user_id: 用户UUID
:param webdav_id: WebDAV 账户ID
"""
key = cls._build_key(email, account_name, password)
value = f"{user_id}:{webdav_id}"
client = RedisManager.get_client()
if client is not None:
await client.set(key, value, ex=_AUTH_TTL)
else:
cls._memory_cache[key] = value
@classmethod
async def invalidate_account(cls, user_id: UUID, account_name: str) -> None:
"""
失效指定账户的所有缓存。
由于缓存键包含 password hash,无法精确删除:
Redis 端使用 pattern scan 删除,内存端清空全部。
:param user_id: 用户UUID
:param account_name: WebDAV 账户名
"""
client = RedisManager.get_client()
if client is not None:
pattern = f"webdav_auth:*/{account_name}:*"
cursor: int = 0
while True:
cursor, keys = await client.scan(cursor, match=pattern, count=100)
if keys:
await client.delete(*keys)
if cursor == 0:
break
else:
# 内存缓存无法按 pattern 删除,清除所有含该账户名的条目
keys_to_delete = [
k for k in cls._memory_cache
if f"/{account_name}:" in k
]
for k in keys_to_delete:
cls._memory_cache.pop(k, None)
l.debug(f"已清除 WebDAV 认证缓存: user={user_id}, account={account_name}")


@@ -3,6 +3,7 @@
提供文件存储相关的服务,包括:
- 本地存储服务
+ - S3 存储服务
- 命名规则解析器
- 存储异常定义
"""
@@ -11,6 +12,8 @@ from .exceptions import (
FileReadError,
FileWriteError,
InvalidPathError,
+ S3APIError,
+ S3MultipartUploadError,
StorageException,
StorageFileNotFoundError,
UploadSessionExpiredError,
@@ -18,3 +21,13 @@ from .exceptions import (
)
from .local_storage import LocalStorageService
from .naming_rule import NamingContext, NamingRuleParser
+ from .object import (
+ adjust_user_storage,
+ copy_object_recursive,
+ delete_object_recursive,
+ permanently_delete_objects,
+ restore_objects,
+ soft_delete_objects,
+ )
+ from .migrate import migrate_file_with_task, migrate_directory_files
+ from .s3_storage import S3StorageService


@@ -43,3 +43,13 @@ class UploadSessionExpiredError(StorageException):
class InvalidPathError(StorageException):
"""无效的路径"""
pass
+ class S3APIError(StorageException):
+ """S3 API 请求错误"""
+ pass
+ class S3MultipartUploadError(S3APIError):
+ """S3 分片上传错误"""
+ pass


@@ -263,15 +263,49 @@ class LocalStorageService:
"""
删除文件(物理删除)
+ 删除文件后会尝试清理因此变空的父目录。
:param path: 完整文件路径
"""
if await self.file_exists(path):
try:
await aiofiles.os.remove(path)
l.debug(f"已删除文件: {path}")
+ await self._cleanup_empty_parents(path)
except OSError as e:
l.warning(f"删除文件失败 {path}: {e}")
+ async def _cleanup_empty_parents(self, file_path: str) -> None:
+ """
+ 从被删文件的父目录开始,向上逐级删除空目录
+ 在以下情况停止:
+ - 到达存储根目录(_base_path)
+ - 遇到非空目录
+ - 遇到 .trash 目录
+ - 删除失败(权限、并发等)
+ :param file_path: 被删文件的完整路径
+ """
+ current = Path(file_path).parent
+ while current != self._base_path and str(current).startswith(str(self._base_path)):
+ if current.name == '.trash':
+ break
+ try:
+ entries = await aiofiles.os.listdir(str(current))
+ if entries:
+ break
+ await aiofiles.os.rmdir(str(current))
+ l.debug(f"已清理空目录: {current}")
+ current = current.parent
+ except OSError as e:
+ l.debug(f"清理空目录失败(忽略): {current}: {e}")
+ break
async def move_to_trash(
self,
source_path: str,
@@ -304,6 +338,7 @@ class LocalStorageService:
try:
await aiofiles.os.rename(source_path, str(trash_path))
l.info(f"文件已移动到回收站: {source_path} -> {trash_path}")
+ await self._cleanup_empty_parents(source_path)
return str(trash_path)
except OSError as e:
raise StorageException(f"移动文件到回收站失败: {e}")
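The empty-parent cleanup can be exercised in isolation. A synchronous sketch with the same stop conditions (storage root reached, non-empty directory, `.trash`, OSError); names here are illustrative, not the service's API:

```python
import os
from pathlib import Path

def cleanup_empty_parents(file_path: str, base_path: str) -> list[str]:
    """Remove empty ancestors of a deleted file, bottom-up. Returns removed dirs."""
    removed: list[str] = []
    base = Path(base_path).resolve()
    current = Path(file_path).resolve().parent
    while current != base and base in current.parents:
        if current.name == ".trash":
            break
        try:
            if os.listdir(current):          # non-empty: stop climbing
                break
            os.rmdir(current)
            removed.append(str(current))
            current = current.parent
        except OSError:                      # permissions / race: give up quietly
            break
    return removed
```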

service/storage/migrate.py
@@ -0,0 +1,291 @@
"""
存储策略迁移服务
提供跨存储策略的文件迁移功能:
- 单文件迁移:从源策略下载 → 上传到目标策略 → 更新数据库记录
- 目录批量迁移:递归遍历目录下所有文件逐个迁移,同时更新子目录的 policy_id
"""
from uuid import UUID
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.object import Object, ObjectType
from sqlmodels.physical_file import PhysicalFile
from sqlmodels.policy import Policy, PolicyType
from sqlmodels.task import Task, TaskStatus
from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
async def _get_storage_service(
policy: Policy,
) -> LocalStorageService | S3StorageService:
"""
根据策略类型创建对应的存储服务实例
:param policy: 存储策略
:return: 存储服务实例
"""
if policy.type == PolicyType.LOCAL:
return LocalStorageService(policy)
elif policy.type == PolicyType.S3:
return await S3StorageService.from_policy(policy)
else:
raise ValueError(f"不支持的存储策略类型: {policy.type}")
async def _read_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> bytes:
"""
从存储服务读取文件内容
:param service: 存储服务实例
:param storage_path: 文件存储路径
:return: 文件二进制内容
"""
if isinstance(service, LocalStorageService):
return await service.read_file(storage_path)
else:
return await service.download_file(storage_path)
async def _write_file_to_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
data: bytes,
) -> None:
"""
将文件内容写入存储服务
:param service: 存储服务实例
:param storage_path: 文件存储路径
:param data: 文件二进制内容
"""
if isinstance(service, LocalStorageService):
await service.write_file(storage_path, data)
else:
await service.upload_file(storage_path, data)
async def _delete_file_from_storage(
service: LocalStorageService | S3StorageService,
storage_path: str,
) -> None:
"""
从存储服务删除文件
两种存储服务的 delete_file 接口一致,无需按类型分支。
:param service: 存储服务实例
:param storage_path: 文件存储路径
"""
await service.delete_file(storage_path)
async def migrate_single_file(
session: AsyncSession,
obj: Object,
dest_policy: Policy,
) -> None:
"""
将单个文件对象从当前存储策略迁移到目标策略
流程:
1. 获取源物理文件和存储服务
2. 读取源文件内容
3. 在目标存储中生成新路径并写入
4. 创建新的 PhysicalFile 记录
5. 更新 Object 的 policy_id 和 physical_file_id
6. 旧 PhysicalFile 引用计数 -1如为 0 则删除源物理文件
:param session: 数据库会话
:param obj: 待迁移的文件对象(必须为文件类型)
:param dest_policy: 目标存储策略
"""
if obj.type != ObjectType.FILE:
raise ValueError(f"只能迁移文件对象,当前类型: {obj.type}")
# 获取源策略和物理文件
src_policy: Policy = await obj.awaitable_attrs.policy
old_physical: PhysicalFile | None = await obj.awaitable_attrs.physical_file
if not old_physical:
l.warning(f"文件 {obj.id} 没有关联物理文件,跳过迁移")
return
if src_policy.id == dest_policy.id:
l.debug(f"文件 {obj.id} 已在目标策略中,跳过")
return
# 1. 从源存储读取文件
src_service = await _get_storage_service(src_policy)
data = await _read_file_from_storage(src_service, old_physical.storage_path)
# 2. 在目标存储生成新路径并写入
dest_service = await _get_storage_service(dest_policy)
_dir_path, _storage_name, new_storage_path = await dest_service.generate_file_path(
user_id=obj.owner_id,
original_filename=obj.name,
)
await _write_file_to_storage(dest_service, new_storage_path, data)
# 3. 创建新的 PhysicalFile
new_physical = PhysicalFile(
storage_path=new_storage_path,
size=old_physical.size,
checksum_md5=old_physical.checksum_md5,
policy_id=dest_policy.id,
reference_count=1,
)
new_physical = await new_physical.save(session)
# 4. 更新 Object
obj.policy_id = dest_policy.id
obj.physical_file_id = new_physical.id
obj = await obj.save(session)
# 5. 旧 PhysicalFile 引用计数 -1
old_physical.decrement_reference()
if old_physical.can_be_deleted:
# 删除源存储中的物理文件
try:
await _delete_file_from_storage(src_service, old_physical.storage_path)
except Exception as e:
l.warning(f"删除源文件失败(不影响迁移结果): {old_physical.storage_path}: {e}")
await PhysicalFile.delete(session, old_physical)
else:
old_physical = await old_physical.save(session)
l.info(f"文件迁移完成: {obj.name} ({obj.id}), {src_policy.name}{dest_policy.name}")
async def migrate_file_with_task(
session: AsyncSession,
obj: Object,
dest_policy: Policy,
task: Task,
) -> None:
"""
迁移单个文件并更新任务状态
:param session: 数据库会话
:param obj: 待迁移的文件对象
:param dest_policy: 目标存储策略
:param task: 关联的任务记录
"""
try:
task.status = TaskStatus.RUNNING
task.progress = 0
task = await task.save(session)
await migrate_single_file(session, obj, dest_policy)
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
except Exception as e:
l.error(f"文件迁移任务失败: {obj.id}: {e}")
task.status = TaskStatus.ERROR
task.error = str(e)[:500]
task = await task.save(session)
async def migrate_directory_files(
session: AsyncSession,
folder: Object,
dest_policy: Policy,
task: Task,
) -> None:
"""
迁移目录下所有文件到目标存储策略
递归遍历目录树,将所有文件迁移到目标策略。
子目录的 policy_id 同步更新。
任务进度按文件数比例更新。
:param session: 数据库会话
:param folder: 目录对象
:param dest_policy: 目标存储策略
:param task: 关联的任务记录
"""
try:
task.status = TaskStatus.RUNNING
task.progress = 0
task = await task.save(session)
# 收集所有需要迁移的文件
files_to_migrate: list[Object] = []
folders_to_update: list[Object] = []
await _collect_objects_recursive(session, folder, files_to_migrate, folders_to_update)
total = len(files_to_migrate)
migrated = 0
errors: list[str] = []
for file_obj in files_to_migrate:
try:
await migrate_single_file(session, file_obj, dest_policy)
migrated += 1
except Exception as e:
error_msg = f"{file_obj.name}: {e}"
l.error(f"迁移文件失败: {error_msg}")
errors.append(error_msg)
# 更新进度
if total > 0:
task.progress = min(99, int(migrated / total * 100))
task = await task.save(session)
# 更新所有子目录的 policy_id
for sub_folder in folders_to_update:
sub_folder.policy_id = dest_policy.id
sub_folder = await sub_folder.save(session)
# 完成任务
if errors:
task.status = TaskStatus.ERROR
task.error = f"部分文件迁移失败 ({len(errors)}/{total}): " + "; ".join(errors[:5])
else:
task.status = TaskStatus.COMPLETED
task.progress = 100
task = await task.save(session)
l.info(
f"目录迁移完成: {folder.name} ({folder.id}), "
f"成功 {migrated}/{total}, 错误 {len(errors)}"
)
except Exception as e:
l.error(f"目录迁移任务失败: {folder.id}: {e}")
task.status = TaskStatus.ERROR
task.error = str(e)[:500]
task = await task.save(session)
async def _collect_objects_recursive(
session: AsyncSession,
folder: Object,
files: list[Object],
folders: list[Object],
) -> None:
"""
递归收集目录下所有文件和子目录
:param session: 数据库会话
:param folder: 当前目录
:param files: 文件列表(输出)
:param folders: 子目录列表(输出)
"""
children: list[Object] = await Object.get_children(session, folder.owner_id, folder.id)
for child in children:
if child.type == ObjectType.FILE:
files.append(child)
elif child.type == ObjectType.FOLDER:
folders.append(child)
await _collect_objects_recursive(session, child, files, folders)


@@ -23,7 +23,7 @@ import string
from datetime import datetime
from uuid import UUID, uuid4
- from sqlmodels.base import SQLModelBase
+ from sqlmodel_ext import SQLModelBase
class NamingContext(SQLModelBase):

service/storage/object.py
@@ -0,0 +1,505 @@
from datetime import datetime
from uuid import UUID
from loguru import logger as l
from sqlalchemy import update as sql_update
from sqlalchemy.sql.functions import func
from middleware.dependencies import SessionDep
from .local_storage import LocalStorageService
from .s3_storage import S3StorageService
from sqlmodels import (
Object,
PhysicalFile,
Policy,
PolicyType,
User,
)
async def adjust_user_storage(
session: SessionDep,
user_id: UUID,
delta: int,
commit: bool = True,
) -> None:
"""
原子更新用户已用存储空间
使用 SQL UPDATE SET storage = GREATEST(0, storage + delta) 避免竞态条件。
:param session: 数据库会话
:param user_id: 用户UUID
:param delta: 变化量(正数增加,负数减少)
:param commit: 是否立即提交
"""
if delta == 0:
return
stmt = (
sql_update(User)
.where(User.id == user_id)
.values(storage=func.greatest(0, User.storage + delta))
)
await session.execute(stmt)
if commit:
await session.commit()
l.debug(f"用户 {user_id} 存储配额变更: {'+' if delta > 0 else ''}{delta} bytes")
# ==================== 软删除 ====================
async def soft_delete_objects(
session: SessionDep,
objects: list[Object],
) -> int:
"""
软删除对象列表
只标记顶层对象:设置 deleted_at、保存原 parent_id 到 deleted_original_parent_id、
将 parent_id 置 NULL 脱离文件树。子对象保持不变,物理文件不移动。
:param session: 数据库会话
:param objects: 待软删除的对象列表
:return: 软删除的对象数量
"""
deleted_count = 0
now = datetime.now()
for obj in objects:
obj.deleted_at = now
obj.deleted_original_parent_id = obj.parent_id
obj.parent_id = None
await obj.save(session, commit=False, refresh=False)
deleted_count += 1
await session.commit()
return deleted_count
# ==================== 恢复 ====================
async def _resolve_name_conflict(
session: SessionDep,
user_id: UUID,
parent_id: UUID,
name: str,
) -> str:
"""
解决同名冲突,返回不冲突的名称
命名规则:原名称 → 原名称 (1) → 原名称 (2) → ...
对于有扩展名的文件name.ext → name (1).ext → name (2).ext → ...
:param session: 数据库会话
:param user_id: 用户UUID
:param parent_id: 父目录UUID
:param name: 原始名称
:return: 不冲突的名称
"""
existing = await Object.get(
session,
(Object.owner_id == user_id) &
(Object.parent_id == parent_id) &
(Object.name == name) &
(Object.deleted_at == None)
)
if not existing:
return name
# 分离文件名和扩展名
if '.' in name:
base, ext = name.rsplit('.', 1)
ext = f".{ext}"
else:
base = name
ext = ""
counter = 1
while True:
new_name = f"{base} ({counter}){ext}"
existing = await Object.get(
session,
(Object.owner_id == user_id) &
(Object.parent_id == parent_id) &
(Object.name == new_name) &
(Object.deleted_at == None)
)
if not existing:
return new_name
counter += 1
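The renaming scheme above can be isolated from the database: given the set of sibling names, it walks `name`, `name (1)`, `name (2)`, … until free, keeping the extension in place. A pure in-memory sketch (the function name is illustrative):

```python
def resolve_name_conflict(name: str, existing: set[str]) -> str:
    """Return a name not present in `existing`, using the ' (N)' suffix scheme."""
    if name not in existing:
        return name
    # Split off the extension so 'a.txt' becomes 'a (1).txt', not 'a.txt (1)'
    if "." in name:
        base, ext = name.rsplit(".", 1)
        ext = f".{ext}"
    else:
        base, ext = name, ""
    counter = 1
    while f"{base} ({counter}){ext}" in existing:
        counter += 1
    return f"{base} ({counter}){ext}"
```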
async def restore_objects(
session: SessionDep,
objects: list[Object],
user_id: UUID,
) -> int:
"""
从回收站恢复对象
检查原父目录是否存在且未删除:
- 存在 → 恢复到原位置
- 不存在 → 恢复到用户根目录
处理同名冲突(自动重命名)。
:param session: 数据库会话
:param objects: 待恢复的对象列表(必须是回收站中的顶层对象)
:param user_id: 用户UUID
:return: 恢复的对象数量
"""
root = await Object.get_root(session, user_id)
if not root:
raise ValueError("用户根目录不存在")
restored_count = 0
for obj in objects:
if not obj.deleted_at:
continue
# 确定恢复目标目录
target_parent_id = root.id
if obj.deleted_original_parent_id:
original_parent = await Object.get(
session,
(Object.id == obj.deleted_original_parent_id) & (Object.deleted_at == None)
)
if original_parent:
target_parent_id = original_parent.id
# 解决同名冲突
resolved_name = await _resolve_name_conflict(
session, user_id, target_parent_id, obj.name
)
# 恢复对象
obj.parent_id = target_parent_id
obj.deleted_at = None
obj.deleted_original_parent_id = None
if resolved_name != obj.name:
obj.name = resolved_name
await obj.save(session, commit=False, refresh=False)
restored_count += 1
await session.commit()
return restored_count
# ==================== 永久删除 ====================
async def _collect_file_entries_all(
session: SessionDep,
user_id: UUID,
root: Object,
) -> tuple[list[tuple[UUID, str, UUID]], int, int]:
"""
BFS 收集子树中所有文件的物理文件信息(包含已删除和未删除的子对象)
只执行 SELECT 查询,不触发 commitORM 对象始终有效。
:param session: 数据库会话
:param user_id: 用户UUID
:param root: 根对象
:return: (文件条目列表[(obj_id, name, physical_file_id)], 总对象数, 总文件大小)
"""
file_entries: list[tuple[UUID, str, UUID]] = []
total_count = 1
total_file_size = 0
# 根对象本身是文件
if root.is_file and root.physical_file_id:
file_entries.append((root.id, root.name, root.physical_file_id))
total_file_size += root.size
# BFS 遍历子目录(使用 get_all_children 包含所有子对象)
if root.is_folder:
queue: list[UUID] = [root.id]
while queue:
parent_id = queue.pop(0)
children = await Object.get_all_children(session, user_id, parent_id)
for child in children:
total_count += 1
if child.is_file and child.physical_file_id:
file_entries.append((child.id, child.name, child.physical_file_id))
total_file_size += child.size
elif child.is_folder:
queue.append(child.id)
return file_entries, total_count, total_file_size
async def permanently_delete_objects(
session: SessionDep,
objects: list[Object],
user_id: UUID,
) -> int:
"""
永久删除回收站中的对象
验证对象在回收站中(deleted_at IS NOT NULL),
BFS 收集所有子文件的 PhysicalFile 信息,
处理引用计数,引用为 0 时物理删除文件,
最后硬删除根 Object(CASCADE 自动清理子对象)。
:param session: 数据库会话
:param objects: 待永久删除的对象列表
:param user_id: 用户UUID
:return: 永久删除的对象数量
"""
total_deleted = 0
for obj in objects:
if not obj.deleted_at:
l.warning(f"对象 {obj.id} 不在回收站中,跳过永久删除")
continue
root_id = obj.id
file_entries, obj_count, total_file_size = await _collect_file_entries_all(
session, user_id, obj
)
# 处理 PhysicalFile 引用计数
for obj_id, obj_name, physical_file_id in file_entries:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == physical_file_id)
if not physical_file:
continue
physical_file.decrement_reference()
if physical_file.can_be_deleted:
# 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy:
try:
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
s3_service = await S3StorageService.from_policy(policy)
await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}")
except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
await PhysicalFile.delete(session, physical_file, commit=False)
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
else:
physical_file = await physical_file.save(session, commit=False)
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
# 更新用户存储配额
if total_file_size > 0:
await adjust_user_storage(session, user_id, -total_file_size, commit=False)
# 硬删除根对象,CASCADE 自动删除所有子对象(不立即提交,避免其余对象过期)
await Object.delete(session, condition=Object.id == root_id, commit=False)
total_deleted += obj_count
# 统一提交所有变更
await session.commit()
return total_deleted
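Both delete paths above lean on `decrement_reference()` and `can_be_deleted`, which are defined on `PhysicalFile` elsewhere. A minimal sketch of the semantics assumed from their usage here (the class below is illustrative, not the model itself):

```python
class RefCountedFile:
    """Illustrative sketch of PhysicalFile's reference-count contract."""

    def __init__(self, reference_count: int = 1) -> None:
        self.reference_count = reference_count

    def increment_reference(self) -> None:
        # Called when a copy creates another Object pointing at the same file
        self.reference_count += 1

    def decrement_reference(self) -> None:
        self.reference_count = max(0, self.reference_count - 1)

    @property
    def can_be_deleted(self) -> bool:
        # Physical deletion is only allowed once no Object references remain
        return self.reference_count <= 0
```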
# ==================== 旧接口(保持向后兼容) ====================
async def _collect_file_entries(
session: SessionDep,
user_id: UUID,
root: Object,
) -> tuple[list[tuple[UUID, str, UUID]], int, int]:
"""
BFS 收集子树中所有文件的物理文件信息
只执行 SELECT 查询,不触发 commitORM 对象始终有效。
:param session: 数据库会话
:param user_id: 用户UUID
:param root: 根对象
:return: (文件条目列表[(obj_id, name, physical_file_id)], 总对象数, 总文件大小)
"""
file_entries: list[tuple[UUID, str, UUID]] = []
total_count = 1
total_file_size = 0
# 根对象本身是文件
if root.is_file and root.physical_file_id:
file_entries.append((root.id, root.name, root.physical_file_id))
total_file_size += root.size
# BFS 遍历子目录
if root.is_folder:
queue: list[UUID] = [root.id]
while queue:
parent_id = queue.pop(0)
children = await Object.get_children(session, user_id, parent_id)
for child in children:
total_count += 1
if child.is_file and child.physical_file_id:
file_entries.append((child.id, child.name, child.physical_file_id))
total_file_size += child.size
elif child.is_folder:
queue.append(child.id)
return file_entries, total_count, total_file_size
async def delete_object_recursive(
session: SessionDep,
obj: Object,
user_id: UUID,
) -> int:
"""
删除对象及其所有子对象(硬删除)
两阶段策略:
1. BFS 只读收集所有文件的 PhysicalFile 信息
2. 批量处理引用计数(commit=False),最后删除根对象触发 CASCADE
:param session: 数据库会话
:param obj: 要删除的对象
:param user_id: 用户UUID
:return: 删除的对象数量
"""
# 阶段一:只读收集(不触发任何 commit
root_id = obj.id
file_entries, total_count, total_file_size = await _collect_file_entries(session, user_id, obj)
# 阶段二:批量处理 PhysicalFile 引用(全部 commit=False
for obj_id, obj_name, physical_file_id in file_entries:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == physical_file_id)
if not physical_file:
continue
physical_file.decrement_reference()
if physical_file.can_be_deleted:
# 物理删除文件
policy = await Policy.get(session, Policy.id == physical_file.policy_id)
if policy:
try:
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
await storage_service.delete_file(physical_file.storage_path)
elif policy.type == PolicyType.S3:
options = await policy.awaitable_attrs.options
s3_service = S3StorageService(
policy,
region=options.s3_region if options else 'us-east-1',
is_path_style=options.s3_path_style if options else False,
)
await s3_service.delete_file(physical_file.storage_path)
l.debug(f"物理文件已删除: {obj_name}")
except Exception as e:
l.warning(f"物理删除文件失败: {obj_name}, 错误: {e}")
await PhysicalFile.delete(session, physical_file, commit=False)
l.debug(f"物理文件记录已删除: {physical_file.storage_path}")
else:
physical_file = await physical_file.save(session, commit=False)
l.debug(f"物理文件仍有 {physical_file.reference_count} 个引用: {physical_file.storage_path}")
# 阶段三:更新用户存储配额(与删除在同一事务中)
if total_file_size > 0:
await adjust_user_storage(session, user_id, -total_file_size, commit=False)
# 阶段四:删除根对象,数据库 CASCADE 自动删除所有子对象
# commit=True(默认)一次性提交所有 PhysicalFile 变更 + Object 删除 + 配额更新
await Object.delete(session, condition=Object.id == root_id)
return total_count
# ==================== 复制 ====================
async def _copy_object_recursive(
session: SessionDep,
src: Object,
dst_parent_id: UUID,
user_id: UUID,
) -> tuple[int, list[UUID], int]:
"""
Recursively copy an object (internal implementation)
:param session: database session
:param src: source object
:param dst_parent_id: destination parent directory UUID
:param user_id: user UUID
:return: (copy count, list of new object UUIDs, total copied file size)
"""
copied_count = 0
new_ids: list[UUID] = []
total_copied_size = 0
# Capture the attributes we need before save(); after commit the object expires
# and lazy loading would fail
src_is_folder = src.is_folder
src_is_file = src.is_file
src_id = src.id
src_size = src.size
src_physical_file_id = src.physical_file_id
# Create the new Object record
new_obj = Object(
name=src.name,
type=src.type,
size=src.size,
password=src.password,
parent_id=dst_parent_id,
owner_id=user_id,
policy_id=src.policy_id,
physical_file_id=src.physical_file_id,
)
# For files, bump the physical file's reference count
if src_is_file and src_physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == src_physical_file_id)
if physical_file:
physical_file.increment_reference()
physical_file = await physical_file.save(session)
total_copied_size += src_size
new_obj = await new_obj.save(session)
copied_count += 1
new_ids.append(new_obj.id)
# For folders, recursively copy the children
if src_is_folder:
children = await Object.get_children(session, user_id, src_id)
for child in children:
child_count, child_ids, child_size = await _copy_object_recursive(
session, child, new_obj.id, user_id
)
copied_count += child_count
new_ids.extend(child_ids)
total_copied_size += child_size
return copied_count, new_ids, total_copied_size
async def copy_object_recursive(
session: SessionDep,
src: Object,
dst_parent_id: UUID,
user_id: UUID,
) -> tuple[int, list[UUID], int]:
"""
Recursively copy an object
For files:
- increment the PhysicalFile reference count
- create a new Object record pointing at the same PhysicalFile
For folders:
- create the new folder
- recursively copy all children
:param session: database session
:param src: source object
:param dst_parent_id: destination parent directory UUID
:param user_id: user UUID
:return: (copy count, list of new object UUIDs, total copied file size)
"""
return await _copy_object_recursive(session, src, dst_parent_id, user_id)
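The copy path never duplicates bytes: it only bumps the `PhysicalFile` reference count, and the delete path above releases the blob once the count reaches zero. A minimal in-memory sketch of that contract (`PhysicalFileModel` is a hypothetical stand-in, not the real SQLModel class):

```python
from dataclasses import dataclass

@dataclass
class PhysicalFileModel:
    """Toy stand-in for a PhysicalFile row: one stored blob, many Object rows."""
    storage_path: str
    reference_count: int = 1

    def increment_reference(self) -> None:
        self.reference_count += 1

    def decrement_reference(self) -> None:
        self.reference_count -= 1

    @property
    def can_be_deleted(self) -> bool:
        # The blob may only be removed once no Object references it
        return self.reference_count <= 0

# Copying a file bumps the count; deleting both copies releases the blob
pf = PhysicalFileModel("u/ab/cd.bin")
pf.increment_reference()   # copy created -> 2 references
pf.decrement_reference()   # original deleted -> 1 reference
assert not pf.can_be_deleted
pf.decrement_reference()   # copy deleted -> 0 references
assert pf.can_be_deleted
```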


@@ -0,0 +1,709 @@
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息,提供文件上传/下载/删除及分片上传功能。
移植自 foxline-pro-backend-server 项目的 S3APIClient
适配 DiskNext 现有的 Service 架构(与 LocalStorageService 平行)。
"""
import hashlib
import hmac
import xml.etree.ElementTree as ET
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import ClassVar, Literal
from urllib.parse import quote, urlencode
from uuid import UUID
import aiohttp
from yarl import URL
from loguru import logger as l
from sqlmodels.policy import Policy
from .exceptions import S3APIError, S3MultipartUploadError
from .naming_rule import NamingContext, NamingRuleParser
def _sign(key: bytes, msg: str) -> bytes:
"""HMAC-SHA256 签名"""
return hmac.new(key, msg.encode(), hashlib.sha256).digest()
_NS_AWS = "http://s3.amazonaws.com/doc/2006-03-01/"
class S3StorageService:
"""
S3 存储服务
使用 AWS Signature V4 签名的异步 S3 API 客户端。
从 Policy 配置中读取 S3 连接信息。
使用示例::
service = S3StorageService(policy, region='us-east-1')
await service.upload_file('path/to/file.txt', b'content')
data = await service.download_file('path/to/file.txt')
"""
_http_session: ClassVar[aiohttp.ClientSession | None] = None
def __init__(
self,
policy: Policy,
region: str = 'us-east-1',
is_path_style: bool = False,
):
"""
:param policy: 存储策略server=endpoint_url, bucket_name, access_key, secret_key
:param region: S3 区域
:param is_path_style: 是否使用路径风格 URL
"""
if not policy.server:
raise S3APIError("S3 策略必须指定 server (endpoint URL)")
if not policy.bucket_name:
raise S3APIError("S3 策略必须指定 bucket_name")
if not policy.access_key:
raise S3APIError("S3 策略必须指定 access_key")
if not policy.secret_key:
raise S3APIError("S3 策略必须指定 secret_key")
self._policy = policy
self._endpoint_url = policy.server.rstrip("/")
self._bucket_name = policy.bucket_name
self._access_key = policy.access_key
self._secret_key = policy.secret_key
self._region = region
self._is_path_style = is_path_style
self._base_url = policy.base_url
# Extract the host from endpoint_url
self._host = self._endpoint_url.replace("https://", "").replace("http://", "").split("/")[0]
# ==================== Factory methods ====================
@classmethod
async def from_policy(cls, policy: Policy) -> 'S3StorageService':
"""
根据 Policy 异步创建 S3StorageService自动加载 options
:param policy: 存储策略
:return: S3StorageService 实例
"""
options = await policy.awaitable_attrs.options
region = options.s3_region if options else 'us-east-1'
is_path_style = options.s3_path_style if options else False
return cls(policy, region=region, is_path_style=is_path_style)
# ==================== HTTP session management ====================
@classmethod
async def initialize_session(cls) -> None:
"""初始化全局 aiohttp ClientSession"""
if cls._http_session is None or cls._http_session.closed:
cls._http_session = aiohttp.ClientSession()
l.info("S3StorageService HTTP session 已初始化")
@classmethod
async def close_session(cls) -> None:
"""关闭全局 aiohttp ClientSession"""
if cls._http_session and not cls._http_session.closed:
await cls._http_session.close()
cls._http_session = None
l.info("S3StorageService HTTP session 已关闭")
@classmethod
def _get_session(cls) -> aiohttp.ClientSession:
"""获取 HTTP session"""
if cls._http_session is None or cls._http_session.closed:
# 懒初始化,以防 initialize_session 未被调用
cls._http_session = aiohttp.ClientSession()
return cls._http_session
# ==================== AWS Signature V4 signing ====================
def _get_signature_key(self, date_stamp: str) -> bytes:
"""生成 AWS Signature V4 签名密钥"""
k_date = _sign(f"AWS4{self._secret_key}".encode(), date_stamp)
k_region = _sign(k_date, self._region)
k_service = _sign(k_region, "s3")
return _sign(k_service, "aws4_request")
def _create_authorization_header(
self,
method: str,
uri: str,
query_string: str,
headers: dict[str, str],
payload_hash: str,
amz_date: str,
date_stamp: str,
) -> str:
"""创建 AWS Signature V4 授权头"""
signed_headers = ";".join(sorted(k.lower() for k in headers.keys()))
canonical_headers = "".join(
f"{k.lower()}:{v.strip()}\n" for k, v in sorted(headers.items())
)
canonical_request = (
f"{method}\n{uri}\n{query_string}\n{canonical_headers}\n"
f"{signed_headers}\n{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
string_to_sign = (
f"{algorithm}\n{amz_date}\n{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
return (
f"{algorithm} Credential={self._access_key}/{credential_scope}, "
f"SignedHeaders={signed_headers}, Signature={signature}"
)
def _build_headers(
self,
method: str,
uri: str,
query_string: str = "",
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
payload_hash: str | None = None,
host: str | None = None,
) -> dict[str, str]:
"""
构建包含 AWS V4 签名的完整请求头
:param method: HTTP 方法
:param uri: 请求 URI
:param query_string: 查询字符串
:param payload: 请求体字节(用于计算哈希)
:param content_type: Content-Type
:param extra_headers: 额外请求头
:param payload_hash: 预计算的 payload 哈希,流式上传时传 "UNSIGNED-PAYLOAD"
:param host: Host 头(默认使用 self._host
"""
now_utc = datetime.now(timezone.utc)
amz_date = now_utc.strftime("%Y%m%dT%H%M%SZ")
date_stamp = now_utc.strftime("%Y%m%d")
if payload_hash is None:
payload_hash = hashlib.sha256(payload).hexdigest()
effective_host = host or self._host
headers: dict[str, str] = {
"Host": effective_host,
"X-Amz-Date": amz_date,
"X-Amz-Content-Sha256": payload_hash,
}
if content_type:
headers["Content-Type"] = content_type
if extra_headers:
headers.update(extra_headers)
authorization = self._create_authorization_header(
method, uri, query_string, headers, payload_hash, amz_date, date_stamp
)
headers["Authorization"] = authorization
return headers
# ==================== Internal request helpers ====================
def _build_uri(self, key: str | None = None) -> str:
"""
构建请求 URI
按 AWS S3 Signature V4 规范对路径进行 URI 编码S3 仅需一次)。
斜杠作为路径分隔符保留不编码。
"""
if self._is_path_style:
if key:
return f"/{self._bucket_name}/{quote(key, safe='/')}"
return f"/{self._bucket_name}"
else:
if key:
return f"/{quote(key, safe='/')}"
return "/"
def _build_url(self, uri: str, query_string: str = "") -> str:
"""构建完整请求 URL"""
if self._is_path_style:
base = self._endpoint_url
else:
# Virtual-hosted style: bucket.endpoint
protocol = "https://" if self._endpoint_url.startswith("https://") else "http://"
base = f"{protocol}{self._bucket_name}.{self._host}"
url = f"{base}{uri}"
if query_string:
url = f"{url}?{query_string}"
return url
def _get_effective_host(self) -> str:
"""获取实际请求的 Host 头"""
if self._is_path_style:
return self._host
return f"{self._bucket_name}.{self._host}"
async def _request(
self,
method: str,
key: str | None = None,
query_params: dict[str, str] | None = None,
payload: bytes = b"",
content_type: str | None = None,
extra_headers: dict[str, str] | None = None,
) -> aiohttp.ClientResponse:
"""发送签名请求"""
uri = self._build_uri(key)
query_string = urlencode(sorted(query_params.items())) if query_params else ""
effective_host = self._get_effective_host()
headers = self._build_headers(
method, uri, query_string, payload, content_type,
extra_headers, host=effective_host,
)
url = self._build_url(uri, query_string)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=payload if payload else None,
)
return response
except Exception as e:
raise S3APIError(f"S3 请求失败: {method} {url}: {e}") from e
async def _request_streaming(
self,
method: str,
key: str,
data_stream: AsyncIterator[bytes],
content_length: int,
content_type: str | None = None,
) -> aiohttp.ClientResponse:
"""
发送流式签名请求(大文件上传)
使用 UNSIGNED-PAYLOAD 作为 payload hash。
"""
uri = self._build_uri(key)
effective_host = self._get_effective_host()
headers = self._build_headers(
method,
uri,
query_string="",
content_type=content_type,
extra_headers={"Content-Length": str(content_length)},
payload_hash="UNSIGNED-PAYLOAD",
host=effective_host,
)
url = self._build_url(uri)
try:
response = await self._get_session().request(
method, URL(url, encoded=True),
headers=headers, data=data_stream,
)
return response
except Exception as e:
raise S3APIError(f"S3 流式请求失败: {method} {url}: {e}") from e
# ==================== File operations ====================
async def upload_file(
self,
key: str,
data: bytes,
content_type: str = 'application/octet-stream',
) -> None:
"""
上传文件
:param key: S3 对象键
:param data: 文件内容
:param content_type: MIME 类型
"""
async with await self._request(
"PUT", key=key, payload=data, content_type=content_type,
) as response:
if response.status not in (200, 201):
body = await response.text()
raise S3APIError(
f"上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.debug(f"S3 上传成功: {self._bucket_name}/{key}")
async def upload_file_streaming(
self,
key: str,
data_stream: AsyncIterator[bytes],
content_length: int,
content_type: str | None = None,
) -> None:
"""
流式上传文件(大文件,避免全部加载到内存)
:param key: S3 对象键
:param data_stream: 异步字节流迭代器
:param content_length: 数据总长度(必须准确)
:param content_type: MIME 类型
"""
async with await self._request_streaming(
"PUT", key=key, data_stream=data_stream,
content_length=content_length, content_type=content_type,
) as response:
if response.status not in (200, 201):
body = await response.text()
raise S3APIError(
f"流式上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.debug(f"S3 流式上传成功: {self._bucket_name}/{key}, 大小: {content_length}")
async def download_file(self, key: str) -> bytes:
"""
下载文件
:param key: S3 对象键
:return: 文件内容
"""
async with await self._request("GET", key=key) as response:
if response.status != 200:
body = await response.text()
raise S3APIError(
f"下载失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
data = await response.read()
l.debug(f"S3 下载成功: {self._bucket_name}/{key}, 大小: {len(data)}")
return data
async def delete_file(self, key: str) -> None:
"""
删除文件
:param key: S3 对象键
"""
async with await self._request("DELETE", key=key) as response:
if response.status in (200, 204):
l.debug(f"S3 删除成功: {self._bucket_name}/{key}")
else:
body = await response.text()
raise S3APIError(
f"删除失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
async def file_exists(self, key: str) -> bool:
"""
检查文件是否存在
:param key: S3 对象键
:return: 是否存在
"""
async with await self._request("HEAD", key=key) as response:
if response.status == 200:
return True
elif response.status == 404:
return False
else:
raise S3APIError(
f"检查文件存在性失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
async def get_file_size(self, key: str) -> int:
"""
获取文件大小
:param key: S3 对象键
:return: 文件大小(字节)
"""
async with await self._request("HEAD", key=key) as response:
if response.status != 200:
raise S3APIError(
f"获取文件信息失败: {self._bucket_name}/{key}, 状态: {response.status}"
)
return int(response.headers.get("Content-Length", 0))
# ==================== Multipart Upload ====================
async def create_multipart_upload(
self,
key: str,
content_type: str = 'application/octet-stream',
) -> str:
"""
创建分片上传任务
:param key: S3 对象键
:param content_type: MIME 类型
:return: Upload ID
"""
async with await self._request(
"POST",
key=key,
query_params={"uploads": ""},
content_type=content_type,
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"创建分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
body = await response.text()
root = ET.fromstring(body)
# Locate the UploadId element (namespace-aware)
upload_id_elem = root.find("UploadId")
if upload_id_elem is None:
upload_id_elem = root.find(f"{{{_NS_AWS}}}UploadId")
if upload_id_elem is None or not upload_id_elem.text:
raise S3MultipartUploadError(
f"创建分片上传响应中未找到 UploadId: {body}"
)
upload_id = upload_id_elem.text
l.debug(f"S3 分片上传已创建: {self._bucket_name}/{key}, upload_id={upload_id}")
return upload_id
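The double `find` above exists because AWS namespaces its XML responses while some S3-compatible services do not; with `xml.etree.ElementTree`, a namespaced element is only found via the `{namespace}tag` form. Note also that the code checks `is None` rather than truthiness, because an `Element` with no children is falsy. A standalone sketch (the response body is a fabricated example):

```python
import xml.etree.ElementTree as ET

NS = "http://s3.amazonaws.com/doc/2006-03-01/"

body = (
    '<?xml version="1.0" encoding="UTF-8"?>'
    f'<InitiateMultipartUploadResult xmlns="{NS}">'
    '<Bucket>media</Bucket><Key>big.bin</Key>'
    '<UploadId>EXAMPLE-UPLOAD-ID</UploadId>'
    '</InitiateMultipartUploadResult>'
)
root = ET.fromstring(body)
elem = root.find("UploadId")          # misses: the document is namespaced
if elem is None:
    elem = root.find(f"{{{NS}}}UploadId")  # namespaced lookup succeeds
assert elem is not None and elem.text == "EXAMPLE-UPLOAD-ID"
```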
async def upload_part(
self,
key: str,
upload_id: str,
part_number: int,
data: bytes,
) -> str:
"""
上传单个分片
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param part_number: 分片编号(从 1 开始)
:param data: 分片数据
:return: ETag
"""
async with await self._request(
"PUT",
key=key,
query_params={
"partNumber": str(part_number),
"uploadId": upload_id,
},
payload=data,
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"上传分片失败: {self._bucket_name}/{key}, "
f"part={part_number}, 状态: {response.status}, {body}"
)
etag = response.headers.get("ETag", "").strip('"')
l.debug(
f"S3 分片上传成功: {self._bucket_name}/{key}, "
f"part={part_number}, etag={etag}"
)
return etag
async def complete_multipart_upload(
self,
key: str,
upload_id: str,
parts: list[tuple[int, str]],
) -> None:
"""
完成分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
:param parts: 分片列表 [(part_number, etag)]
"""
# 按 part_number 排序
parts_sorted = sorted(parts, key=lambda p: p[0])
# 构建 CompleteMultipartUpload XML
xml_parts = ''.join(
f"<Part><PartNumber>{pn}</PartNumber><ETag>{etag}</ETag></Part>"
for pn, etag in parts_sorted
)
payload = f'<?xml version="1.0" encoding="UTF-8"?><CompleteMultipartUpload>{xml_parts}</CompleteMultipartUpload>'
payload_bytes = payload.encode('utf-8')
async with await self._request(
"POST",
key=key,
query_params={"uploadId": upload_id},
payload=payload_bytes,
content_type="application/xml",
) as response:
if response.status != 200:
body = await response.text()
raise S3MultipartUploadError(
f"完成分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
l.info(
f"S3 分片上传已完成: {self._bucket_name}/{key}, "
f"{len(parts)} 个分片"
)
async def abort_multipart_upload(self, key: str, upload_id: str) -> None:
"""
取消分片上传
:param key: S3 对象键
:param upload_id: 分片上传 ID
"""
async with await self._request(
"DELETE",
key=key,
query_params={"uploadId": upload_id},
) as response:
if response.status in (200, 204):
l.debug(f"S3 分片上传已取消: {self._bucket_name}/{key}")
else:
body = await response.text()
l.warning(
f"取消分片上传失败: {self._bucket_name}/{key}, "
f"状态: {response.status}, {body}"
)
# ==================== Presigned URLs ====================
def generate_presigned_url(
self,
key: str,
method: Literal['GET', 'PUT'] = 'GET',
expires_in: int = 3600,
filename: str | None = None,
) -> str:
"""
生成 S3 预签名 URLAWS Signature V4 Query String
:param key: S3 对象键
:param method: HTTP 方法GET 下载PUT 上传)
:param expires_in: URL 有效期(秒)
:param filename: 文件名GET 请求时设置 Content-Disposition
:return: 预签名 URL
"""
current_time = datetime.now(timezone.utc)
amz_date = current_time.strftime("%Y%m%dT%H%M%SZ")
date_stamp = current_time.strftime("%Y%m%d")
credential_scope = f"{date_stamp}/{self._region}/s3/aws4_request"
credential = f"{self._access_key}/{credential_scope}"
uri = self._build_uri(key)
effective_host = self._get_effective_host()
query_params: dict[str, str] = {
'X-Amz-Algorithm': 'AWS4-HMAC-SHA256',
'X-Amz-Credential': credential,
'X-Amz-Date': amz_date,
'X-Amz-Expires': str(expires_in),
'X-Amz-SignedHeaders': 'host',
}
# Add Content-Disposition for GET requests
if method == "GET" and filename:
encoded_filename = quote(filename, safe='')
query_params['response-content-disposition'] = (
f"attachment; filename*=UTF-8''{encoded_filename}"
)
canonical_query_string = "&".join(
f"{quote(k, safe='')}={quote(v, safe='')}"
for k, v in sorted(query_params.items())
)
canonical_headers = f"host:{effective_host}\n"
signed_headers = "host"
payload_hash = "UNSIGNED-PAYLOAD"
canonical_request = (
f"{method}\n"
f"{uri}\n"
f"{canonical_query_string}\n"
f"{canonical_headers}\n"
f"{signed_headers}\n"
f"{payload_hash}"
)
algorithm = "AWS4-HMAC-SHA256"
string_to_sign = (
f"{algorithm}\n"
f"{amz_date}\n"
f"{credential_scope}\n"
f"{hashlib.sha256(canonical_request.encode()).hexdigest()}"
)
signing_key = self._get_signature_key(date_stamp)
signature = hmac.new(
signing_key, string_to_sign.encode(), hashlib.sha256
).hexdigest()
base_url = self._build_url(uri)
return (
f"{base_url}?"
f"{canonical_query_string}&"
f"X-Amz-Signature={signature}"
)
# ==================== Path generation ====================
async def generate_file_path(
self,
user_id: UUID,
original_filename: str,
) -> tuple[str, str, str]:
"""
根据命名规则生成 S3 文件存储路径
与 LocalStorageService.generate_file_path 接口一致。
:param user_id: 用户UUID
:param original_filename: 原始文件名
:return: (相对目录路径, 存储文件名, 完整存储路径)
"""
context = NamingContext(
user_id=user_id,
original_filename=original_filename,
)
# Parse the directory naming rule
dir_path = ""
if self._policy.dir_name_rule:
dir_path = NamingRuleParser.parse(self._policy.dir_name_rule, context)
# Parse the filename naming rule
if self._policy.auto_rename and self._policy.file_name_rule:
storage_name = NamingRuleParser.parse(self._policy.file_name_rule, context)
# Make sure the generated name keeps an extension
if '.' in original_filename and '.' not in storage_name:
ext = original_filename.rsplit('.', 1)[1]
storage_name = f"{storage_name}.{ext}"
else:
storage_name = original_filename
# S3 needs no directory creation; just join the path
if dir_path:
storage_path = f"{dir_path}/{storage_name}"
else:
storage_path = storage_name
return dir_path, storage_name, storage_path
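The extension-preservation step keys off the *last* dot of the original filename, so multi-dot names keep only their final extension. Isolated for clarity:

```python
def ensure_extension(storage_name: str, original_filename: str) -> str:
    # Re-attach the original extension if the naming rule dropped it
    if '.' in original_filename and '.' not in storage_name:
        ext = original_filename.rsplit('.', 1)[1]
        return f"{storage_name}.{ext}"
    return storage_name

assert ensure_extension("a1b2c3", "report.final.pdf") == "a1b2c3.pdf"
assert ensure_extension("a1b2c3.bin", "report.pdf") == "a1b2c3.bin"  # already has one
assert ensure_extension("a1b2c3", "README") == "a1b2c3"              # nothing to copy
```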


@@ -1 +1 @@
from .login import unified_login


@@ -1,83 +1,428 @@
from uuid import uuid4 """
统一登录服务
from loguru import logger 支持多种认证方式邮箱密码、GitHub OAuth、QQ OAuth、Passkey、Magic Link、手机短信预留
"""
import hashlib
from uuid import UUID, uuid4
from middleware.dependencies import SessionDep from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from sqlmodels import LoginRequest, TokenResponse, User from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from service.redis.token_store import TokenStore
from sqlmodels.auth_identity import AuthIdentity, AuthProviderType
from sqlmodels.group import GroupClaims, GroupOptions from sqlmodels.group import GroupClaims, GroupOptions
from sqlmodels.user import UserStatus from sqlmodels.object import Object, ObjectType
from utils import http_exceptions from sqlmodels.policy import Policy
from utils.JWT import create_access_token, create_refresh_token from sqlmodels.setting import Setting, SettingsType
from sqlmodels.user import TokenResponse, UnifiedLoginRequest, User, UserStatus
from utils import JWT, http_exceptions
from utils.password.pwd import Password, PasswordStatus from utils.password.pwd import Password, PasswordStatus
async def login( async def unified_login(
session: SessionDep, session: AsyncSession,
login_request: LoginRequest, request: UnifiedLoginRequest,
) -> TokenResponse: ) -> TokenResponse:
""" """
根据账号密码进行登录 统一登录入口,根据 provider 分发到不同的登录逻辑
如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。
:param session: 数据库会话 :param session: 数据库会话
:param login_request: 登录请求 :param request: 统一登录请求
:return: TokenResponse
:return: TokenResponse 对象或状态码或 None
""" """
# 获取用户信息(预加载 group 关系) await _check_provider_enabled(session, request.provider)
current_user: User = await User.get(
match request.provider:
case AuthProviderType.EMAIL_PASSWORD:
user = await _login_email_password(session, request)
case AuthProviderType.GITHUB:
user = await _login_oauth(session, request, AuthProviderType.GITHUB)
case AuthProviderType.QQ:
user = await _login_oauth(session, request, AuthProviderType.QQ)
case AuthProviderType.PASSKEY:
user = await _login_passkey(session, request)
case AuthProviderType.MAGIC_LINK:
user = await _login_magic_link(session, request)
case AuthProviderType.PHONE_SMS:
http_exceptions.raise_not_implemented("短信登录暂未开放")
case _:
http_exceptions.raise_bad_request(f"不支持的登录方式: {request.provider}")
return await _issue_tokens(session, user)
async def _check_provider_enabled(session: AsyncSession, provider: AuthProviderType) -> None:
"""检查认证方式是否已被站长启用"""
# OAuth 类型从 OAUTH 设置中查询
if provider in (AuthProviderType.GITHUB, AuthProviderType.QQ):
setting_name = f"{provider.value}_enabled"
setting = await Setting.get(
session, session,
User.email == login_request.email, (Setting.type == SettingsType.OAUTH) & (Setting.name == setting_name),
fetch_mode="first", )
load=User.group, if not setting or setting.value != "1":
) #type: ignore http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
return
# 验证用户是否存在 # 其他类型从 AUTH 设置中查询
if not current_user: setting_name = f"auth_{provider.value}_enabled"
logger.debug(f"Cannot find user with email: {login_request.email}") setting = await Setting.get(
http_exceptions.raise_unauthorized("Invalid email or password") session,
(Setting.type == SettingsType.AUTH) & (Setting.name == setting_name),
)
if not setting or setting.value != "1":
http_exceptions.raise_bad_request(f"登录方式 {provider.value} 未启用")
# 验证密码是否正确
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
logger.debug(f"Password verification failed for user: {login_request.email}")
http_exceptions.raise_unauthorized("Invalid email or password")
# 验证用户是否可登录修复显式枚举比较StrEnum 永远 truthy async def _login_email_password(
if current_user.status != UserStatus.ACTIVE: session: AsyncSession,
http_exceptions.raise_forbidden("Your account is disabled") request: UnifiedLoginRequest,
) -> User:
"""邮箱+密码登录"""
if not request.credential:
http_exceptions.raise_bad_request("密码不能为空")
# 检查两步验证 # 查找 AuthIdentity
if current_user.two_factor: identity: AuthIdentity | None = await AuthIdentity.get(
# 用户已启用两步验证 session,
if not login_request.two_fa_code: (AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
logger.debug(f"2FA required for user: {login_request.email}") & (AuthIdentity.identifier == request.identifier),
http_exceptions.raise_precondition_required("2FA required") )
if not identity:
l.debug(f"未找到邮箱密码身份: {request.identifier}")
http_exceptions.raise_unauthorized("邮箱或密码错误")
# 验证 OTP # 验证
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID: if not identity.credential:
logger.debug(f"Invalid 2FA code for user: {login_request.email}") http_exceptions.raise_unauthorized("邮箱或密码错误")
http_exceptions.raise_unauthorized("Invalid 2FA code")
if Password.verify(identity.credential, request.credential) != PasswordStatus.VALID:
l.debug(f"密码验证失败: {request.identifier}")
http_exceptions.raise_unauthorized("邮箱或密码错误")
# 加载用户
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
# 验证用户状态
if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
# 检查两步验证(从 AuthIdentity.extra_data 中读取 2FA secret
if identity.extra_data:
import orjson
extra: dict = orjson.loads(identity.extra_data)
two_factor_secret: str | None = extra.get("two_factor")
if two_factor_secret:
if not request.two_fa_code:
l.debug(f"需要两步验证: {request.identifier}")
http_exceptions.raise_precondition_required("需要两步验证")
if Password.verify_totp(two_factor_secret, request.two_fa_code) != PasswordStatus.VALID:
l.debug(f"两步验证失败: {request.identifier}")
http_exceptions.raise_unauthorized("两步验证码错误")
return user
async def _login_oauth(
session: AsyncSession,
request: UnifiedLoginRequest,
provider: AuthProviderType,
) -> User:
"""
OAuth 登录GitHub / QQ
identifier 为 OAuth authorization code后端换取 access_token 再获取用户信息。
"""
# 读取 OAuth 配置
client_id_setting = await Setting.get(
session,
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_id"),
)
client_secret_setting = await Setting.get(
session,
(Setting.type == SettingsType.OAUTH) & (Setting.name == f"{provider.value}_client_secret"),
)
if not client_id_setting or not client_secret_setting:
http_exceptions.raise_bad_request(f"{provider.value} OAuth 未配置")
client_id = client_id_setting.value or ""
client_secret = client_secret_setting.value or ""
# 根据 provider 创建对应的 OAuth 客户端
if provider == AuthProviderType.GITHUB:
from service.oauth import GithubOAuth
oauth_client = GithubOAuth(client_id, client_secret)
token_resp = await oauth_client.get_access_token(code=request.identifier)
user_info_resp = await oauth_client.get_user_info(token_resp)
openid = str(user_info_resp.user_data.id)
nickname = user_info_resp.user_data.name or user_info_resp.user_data.login
avatar_url = user_info_resp.user_data.avatar_url
email = user_info_resp.user_data.email
elif provider == AuthProviderType.QQ:
from service.oauth import QQOAuth
oauth_client = QQOAuth(client_id, client_secret)
token_resp = await oauth_client.get_access_token(
code=request.identifier,
redirect_uri=request.redirect_uri or "",
)
openid_resp = await oauth_client.get_openid(token_resp.access_token)
user_info_resp = await oauth_client.get_user_info(
token_resp,
app_id=client_id,
openid=openid_resp.openid,
)
openid = openid_resp.openid
nickname = user_info_resp.user_data.nickname
avatar_url = user_info_resp.user_data.figureurl_qq_2 or user_info_resp.user_data.figureurl_2
email = None
else:
http_exceptions.raise_bad_request(f"不支持的 OAuth 提供者: {provider.value}")
# 查找已有 AuthIdentity
identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.provider == provider) & (AuthIdentity.identifier == openid),
)
if identity:
# 已绑定 → 更新 OAuth 信息并返回关联用户
identity.display_name = nickname
identity.avatar_url = avatar_url
identity = await identity.save(session)
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
return user
# 未绑定 → 自动注册
user = await _auto_register_oauth_user(
session,
provider=provider,
openid=openid,
nickname=nickname,
avatar_url=avatar_url,
email=email,
)
return user
async def _auto_register_oauth_user(
session: AsyncSession,
*,
provider: AuthProviderType,
openid: str,
nickname: str | None,
avatar_url: str | None,
email: str | None,
) -> User:
"""OAuth 自动注册用户"""
# 获取默认用户组
default_group_setting = await Setting.get(
session,
(Setting.type == SettingsType.REGISTER) & (Setting.name == "default_group"),
)
if not default_group_setting or not default_group_setting.value:
l.error("默认用户组未配置")
http_exceptions.raise_internal_error()
default_group_id = UUID(default_group_setting.value)
# 创建用户
new_user = User(
email=email,
nickname=nickname,
avatar=avatar_url or "default",
group_id=default_group_id,
)
new_user_id = new_user.id
new_user = await new_user.save(session)
# 创建 AuthIdentity
identity = AuthIdentity(
provider=provider,
identifier=openid,
display_name=nickname,
avatar_url=avatar_url,
is_primary=True,
is_verified=True,
user_id=new_user_id,
)
identity = await identity.save(session)
# 创建用户根目录
default_policy = await Policy.get(session, Policy.name == "本地存储")
if default_policy:
await Object(
name="/",
type=ObjectType.FOLDER,
owner_id=new_user_id,
parent_id=None,
policy_id=default_policy.id,
).save(session)
# 重新加载用户(含 group 关系)
user: User = await User.get(session, User.id == new_user_id, load=User.group)
l.info(f"OAuth 自动注册用户: provider={provider.value}, openid={openid}")
return user
async def _login_passkey(
session: AsyncSession,
request: UnifiedLoginRequest,
) -> User:
"""
Passkey/WebAuthn 登录Discoverable Credentials 模式)
identifier 为 challenge_token前端从 ``POST /authn/options`` 获取),
credential 为 JSON 格式的 authenticator assertion response。
"""
from webauthn import verify_authentication_response
from webauthn.helpers import base64url_to_bytes
from service.redis.challenge_store import ChallengeStore
from service.webauthn import get_rp_config
from sqlmodels.user_authn import UserAuthn
if not request.credential:
http_exceptions.raise_bad_request("WebAuthn assertion response 不能为空")
if not request.identifier:
http_exceptions.raise_bad_request("challenge_token 不能为空")
# 从 ChallengeStore 取出 challenge一次性防重放
challenge: bytes | None = await ChallengeStore.retrieve_and_delete(f"auth:{request.identifier}")
if challenge is None:
http_exceptions.raise_unauthorized("登录会话已过期,请重新获取 options")
# 从 assertion JSON 中解析 credential_idDiscoverable Credentials 模式)
import orjson
credential_dict: dict = orjson.loads(request.credential)
credential_id_b64: str | None = credential_dict.get("id")
if not credential_id_b64:
http_exceptions.raise_bad_request("缺少凭证 ID")
# 查找 UserAuthn 记录
authn: UserAuthn | None = await UserAuthn.get(
session,
UserAuthn.credential_id == credential_id_b64,
)
if not authn:
http_exceptions.raise_unauthorized("Passkey 凭证未注册")
# 获取 RP 配置
rp_id, _rp_name, origin = await get_rp_config(session)
# 验证 WebAuthn assertion
try:
verification = verify_authentication_response(
credential=request.credential,
expected_rp_id=rp_id,
expected_origin=origin,
expected_challenge=challenge,
credential_public_key=base64url_to_bytes(authn.credential_public_key),
credential_current_sign_count=authn.sign_count,
)
except Exception as e:
l.warning(f"WebAuthn 验证失败: {e}")
http_exceptions.raise_unauthorized("Passkey 验证失败")
# 更新签名计数
authn.sign_count = verification.new_sign_count
authn = await authn.save(session)
# 加载用户
user: User = await User.get(session, User.id == authn.user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
return user
async def _login_magic_link(
session: AsyncSession,
request: UnifiedLoginRequest,
) -> User:
"""
Magic Link 登录
identifier 为签名 token由 itsdangerous 生成。
"""
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
try:
email = serializer.loads(request.identifier, salt="magic-link-salt", max_age=600)
except SignatureExpired:
http_exceptions.raise_unauthorized("Magic Link 已过期")
except BadSignature:
http_exceptions.raise_unauthorized("Magic Link 无效")
# 防重放:使用 token 哈希作为标识符
token_hash = hashlib.sha256(request.identifier.encode()).hexdigest()
is_first_use = await TokenStore.mark_used(f"magic_link:{token_hash}", ttl=600)
if not is_first_use:
http_exceptions.raise_unauthorized("Magic Link 已被使用")
# 查找绑定了该邮箱的 AuthIdentityemail_password 或 magic_link
identity: AuthIdentity | None = await AuthIdentity.get(
session,
(AuthIdentity.identifier == email)
& (
(AuthIdentity.provider == AuthProviderType.EMAIL_PASSWORD)
| (AuthIdentity.provider == AuthProviderType.MAGIC_LINK)
),
)
if not identity:
http_exceptions.raise_unauthorized("该邮箱未注册")
user: User = await User.get(session, User.id == identity.user_id, load=User.group)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
if user.status != UserStatus.ACTIVE:
http_exceptions.raise_forbidden("账户已被禁用")
# 标记邮箱已验证
if not identity.is_verified:
identity.is_verified = True
identity = await identity.save(session)
return user
async def _issue_tokens(session: AsyncSession, user: User) -> TokenResponse:
"""
签发 JWT 双令牌access + refresh
提取自原 login.py 的签发逻辑,供所有 provider 共用。
"""
     # 加载 GroupOptions
     group_options: GroupOptions | None = await GroupOptions.get(
         session,
-        GroupOptions.group_id == current_user.group_id,
+        GroupOptions.group_id == user.group_id,
     )
     # 构建权限快照
-    current_user.group.options = group_options
-    group_claims = GroupClaims.from_group(current_user.group)
+    user.group.options = group_options
+    group_claims = GroupClaims.from_group(user.group)
     # 创建令牌
-    access_token = create_access_token(
-        sub=current_user.id,
+    access_token = JWT.create_access_token(
+        sub=user.id,
         jti=uuid4(),
-        status=current_user.status.value,
+        status=user.status.value,
         group=group_claims,
     )
-    refresh_token = create_refresh_token(
-        sub=current_user.id,
-        jti=uuid4()
+    refresh_token = JWT.create_refresh_token(
+        sub=user.id,
+        jti=uuid4(),
     )
     return TokenResponse(

service/webauthn.py Normal file

@@ -0,0 +1,41 @@
"""
WebAuthn RP(Relying Party)配置辅助模块:
从数据库 Setting 中读取 siteURL / siteTitle,
解析出 rp_id、rp_name、origin,供注册/登录流程复用。
"""
from urllib.parse import urlparse
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodels.setting import Setting, SettingsType
async def get_rp_config(session: AsyncSession) -> tuple[str, str, str]:
"""
获取 WebAuthn RP 配置。
:param session: 数据库会话
:return: ``(rp_id, rp_name, origin)`` 元组
- ``rp_id``: 站点域名(从 siteURL 解析,如 ``example.com``)
- ``rp_name``: 站点标题
- ``origin``: 完整 origin(如 ``https://example.com``)
"""
site_url_setting: Setting | None = await Setting.get(
session,
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
)
site_title_setting: Setting | None = await Setting.get(
session,
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteTitle"),
)
site_url: str = site_url_setting.value if site_url_setting and site_url_setting.value else "https://localhost"
rp_name: str = site_title_setting.value if site_title_setting and site_title_setting.value else "DiskNext"
parsed = urlparse(site_url)
rp_id: str = parsed.hostname or "localhost"
origin: str = f"{parsed.scheme}://{parsed.netloc}" if parsed.netloc else site_url
return rp_id, rp_name, origin
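The rp_id / origin derivation above rests on `urllib.parse.urlparse`; a standalone sketch of the same parsing (the site URLs are invented examples):

```python
from urllib.parse import urlparse

def derive_rp(site_url: str) -> tuple[str, str]:
    """Mirror of get_rp_config's URL handling: hostname becomes rp_id,
    scheme + netloc (port included) becomes the WebAuthn origin."""
    parsed = urlparse(site_url)
    rp_id = parsed.hostname or "localhost"
    origin = f"{parsed.scheme}://{parsed.netloc}" if parsed.netloc else site_url
    return rp_id, origin

print(derive_rp("https://pan.example.com"))       # ('pan.example.com', 'https://pan.example.com')
print(derive_rp("https://pan.example.com:8443"))  # port stays in origin, not in rp_id
```

Note that a port survives in the origin (browsers compare origins including the port) but is stripped from rp_id, which WebAuthn requires to be a bare registrable domain.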

service/wopi/__init__.py Normal file

@@ -0,0 +1,185 @@
"""
WOPI Discovery 服务模块
解析 WOPI 服务端(Collabora / OnlyOffice 等)的 Discovery XML,
提取支持的文件扩展名及对应的编辑器 URL 模板。
参考:Cloudreve pkg/wopi/discovery.go 和 pkg/wopi/wopi.go
"""
import xml.etree.ElementTree as ET
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from loguru import logger as l
# WOPI URL 模板中已知的查询参数占位符及其替换值
# 值为 None 表示删除该参数,非 None 表示替换为该值
# 参考 Cloudreve pkg/wopi/wopi.go queryPlaceholders
_WOPI_QUERY_PLACEHOLDERS: dict[str, str | None] = {
'BUSINESS_USER': None,
'DC_LLCC': 'lng',
'DISABLE_ASYNC': None,
'DISABLE_CHAT': None,
'EMBEDDED': 'true',
'FULLSCREEN': 'true',
'HOST_SESSION_ID': None,
'SESSION_CONTEXT': None,
'RECORDING': None,
'THEME_ID': 'darkmode',
'UI_LLCC': 'lng',
'VALIDATOR_TEST_CATEGORY': None,
}
_WOPI_SRC_PLACEHOLDER = 'WOPI_SOURCE'
def process_wopi_action_url(raw_urlsrc: str) -> str:
"""
将 WOPI Discovery 中的原始 urlsrc 转换为 DiskNext 可用的 URL 模板。
处理流程(参考 Cloudreve generateActionUrl):
1. 去除 ``<>`` 占位符标记
2. 解析查询参数,替换/删除已知占位符
3. ``WOPI_SOURCE`` → ``{wopi_src}``
注意:access_token 和 access_token_ttl 不放在 URL 中,
根据 WOPI 规范它们通过 POST 表单字段传递给编辑器。
:param raw_urlsrc: WOPI Discovery XML 中的 urlsrc 原始值
:return: 处理后的 URL 模板字符串,包含 {wopi_src} 占位符
"""
# 去除 <> 标记
cleaned = raw_urlsrc.replace('<', '').replace('>', '')
parsed = urlparse(cleaned)
raw_params = parse_qs(parsed.query, keep_blank_values=True)
new_params: list[tuple[str, str]] = []
is_src_replaced = False
for key, values in raw_params.items():
value = values[0] if values else ''
# WOPI_SOURCE 占位符 → {wopi_src}
if value == _WOPI_SRC_PLACEHOLDER:
new_params.append((key, '{wopi_src}'))
is_src_replaced = True
continue
# 已知占位符
if value in _WOPI_QUERY_PLACEHOLDERS:
replacement = _WOPI_QUERY_PLACEHOLDERS[value]
if replacement is not None:
new_params.append((key, replacement))
# replacement 为 None 时删除该参数
continue
# 其他参数保留原值
new_params.append((key, value))
# 如果没有找到 WOPI_SOURCE 占位符,手动添加 WOPISrc
if not is_src_replaced:
new_params.append(('WOPISrc', '{wopi_src}'))
# LibreOffice/Collabora 需要 lang 参数(避免重复添加)
existing_keys = {k for k, _ in new_params}
if 'lang' not in existing_keys:
new_params.append(('lang', 'lng'))
# 注意:access_token 和 access_token_ttl 不放在 URL 中
# 根据 WOPI 规范,它们通过 POST 表单字段传递给编辑器
# 重建 URL
new_query = urlencode(new_params, safe='{}')
result = urlunparse((
parsed.scheme,
parsed.netloc,
parsed.path,
parsed.params,
new_query,
'',
))
return result
def parse_wopi_discovery_xml(xml_content: str) -> tuple[dict[str, str], list[str]]:
"""
解析 WOPI Discovery XML提取扩展名到 URL 模板的映射。
XML 结构::
<wopi-discovery>
<net-zone name="external-https">
<app name="Writer" favIconUrl="...">
<action name="edit" ext="docx" urlsrc="https://..."/>
<action name="view" ext="docx" urlsrc="https://..."/>
</app>
</net-zone>
</wopi-discovery>
动作优先级:edit > embedview > view(参考 Cloudreve discovery.go)
:param xml_content: WOPI Discovery 端点返回的 XML 字符串
:return: (action_urls, app_names) 元组
action_urls: {extension: processed_url_template}
app_names: 发现的应用名称列表
:raises ValueError: XML 解析失败或格式无效
"""
try:
root = ET.fromstring(xml_content)
except ET.ParseError as e:
raise ValueError(f"WOPI Discovery XML 解析失败: {e}")
# 查找 net-zone(可能有多个,取第一个非空的)
net_zones = root.findall('net-zone')
if not net_zones:
raise ValueError("WOPI Discovery XML 缺少 net-zone 节点")
# ext_actions: {extension: {action_name: urlsrc}}
ext_actions: dict[str, dict[str, str]] = {}
app_names: list[str] = []
for net_zone in net_zones:
for app_elem in net_zone.findall('app'):
app_name = app_elem.get('name', '')
if app_name:
app_names.append(app_name)
for action_elem in app_elem.findall('action'):
action_name = action_elem.get('name', '')
ext = action_elem.get('ext', '')
urlsrc = action_elem.get('urlsrc', '')
if not ext or not urlsrc:
continue
# 只关注 edit / embedview / view 三种动作
if action_name not in ('edit', 'embedview', 'view'):
continue
if ext not in ext_actions:
ext_actions[ext] = {}
ext_actions[ext][action_name] = urlsrc
# 为每个扩展名选择最佳 URL: edit > embedview > view
action_urls: dict[str, str] = {}
for ext, actions_map in ext_actions.items():
selected_urlsrc: str | None = None
for preferred in ('edit', 'embedview', 'view'):
if preferred in actions_map:
selected_urlsrc = actions_map[preferred]
break
if selected_urlsrc:
action_urls[ext] = process_wopi_action_url(selected_urlsrc)
# 去重 app_names
seen: set[str] = set()
unique_names: list[str] = []
for name in app_names:
if name not in seen:
seen.add(name)
unique_names.append(name)
l.info(f"WOPI Discovery 解析完成: {len(action_urls)} 个扩展名, 应用: {unique_names}")
return action_urls, unique_names
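The Discovery parsing above can be exercised end to end with a toy XML document; this is a reduced, dependency-free re-run of the same logic (the host and paths are invented):

```python
import xml.etree.ElementTree as ET

XML = """
<wopi-discovery>
  <net-zone name="external-https">
    <app name="Writer">
      <action name="view" ext="docx" urlsrc="https://office.example.com/view?WOPISrc=WOPI_SOURCE"/>
      <action name="edit" ext="docx" urlsrc="https://office.example.com/edit?WOPISrc=WOPI_SOURCE"/>
    </app>
  </net-zone>
</wopi-discovery>
"""

root = ET.fromstring(XML)
# Collect {extension: {action_name: urlsrc}}, mirroring ext_actions above
actions: dict[str, dict[str, str]] = {}
for zone in root.findall("net-zone"):
    for app in zone.findall("app"):
        for act in app.findall("action"):
            actions.setdefault(act.get("ext", ""), {})[act.get("name", "")] = act.get("urlsrc", "")

# edit wins over embedview/view, matching the priority described above
best = next(actions["docx"][n] for n in ("edit", "embedview", "view") if n in actions["docx"])
print(best)  # the edit urlsrc
```

Even though the `view` action appears first in the XML, the priority loop still selects `edit` for `docx`.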

setup_cython.py Normal file

@@ -0,0 +1,92 @@
"""
Cython 编译脚本 — 将 ee/ 下的纯逻辑文件编译为 .so
用法:
uv run --extra build python setup_cython.py build_ext --inplace
编译规则:
- 跳过 __init__.py(Python 包发现需要)
- 只编译 .py 文件(纯函数 / 服务逻辑)
编译后清理(Pro Docker 构建用):
uv run --extra build python setup_cython.py clean_artifacts
"""
import shutil
import sys
from pathlib import Path
EE_DIR = Path("ee")
# 跳过 __init__.py —— 包发现需要原始 .py
SKIP_NAMES = {"__init__.py"}
def _collect_modules() -> list[str]:
"""收集 ee/ 下需要编译的 .py 文件路径(点分模块名)。"""
modules: list[str] = []
for py_file in EE_DIR.rglob("*.py"):
if py_file.name in SKIP_NAMES:
continue
# ee/license.py → ee.license
module = str(py_file.with_suffix("")).replace("\\", "/").replace("/", ".")
modules.append(module)
return modules
def clean_artifacts() -> None:
"""删除已编译的 .py 源码、.c 中间文件和 build/ 目录。"""
for py_file in EE_DIR.rglob("*.py"):
if py_file.name in SKIP_NAMES:
continue
# 只删除有对应 .so / .pyd 的源文件
parent = py_file.parent
stem = py_file.stem
has_compiled = (
any(parent.glob(f"{stem}*.so")) or
any(parent.glob(f"{stem}*.pyd"))
)
if has_compiled:
py_file.unlink()
print(f"已删除源码: {py_file}")
# 删除 .c 中间文件
for c_file in EE_DIR.rglob("*.c"):
c_file.unlink()
print(f"已删除中间文件: {c_file}")
# 删除 build/ 目录
build_dir = Path("build")
if build_dir.exists():
shutil.rmtree(build_dir)
print(f"已删除: {build_dir}/")
if __name__ == "__main__":
if len(sys.argv) > 1 and sys.argv[1] == "clean_artifacts":
clean_artifacts()
sys.exit(0)
# 动态导入(仅在编译时需要)
from Cython.Build import cythonize
from setuptools import Extension, setup
modules = _collect_modules()
if not modules:
print("未找到需要编译的模块")
sys.exit(0)
print(f"即将编译以下模块: {modules}")
extensions = [
Extension(mod, [mod.replace(".", "/") + ".py"])
for mod in modules
]
setup(
name="disknext-ee",
packages=[],
ext_modules=cythonize(
extensions,
compiler_directives={'language_level': "3"},
),
)
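The path-to-module conversion in `_collect_modules` (`ee/license.py` → `ee.license`) is a small pure function worth isolating; a standalone sketch (file names here are examples, not guaranteed repo contents):

```python
from pathlib import Path

def module_name(py_file: Path) -> str:
    """Same dotted-name derivation _collect_modules uses:
    strip the .py suffix, normalise separators, join with dots."""
    return str(py_file.with_suffix("")).replace("\\", "/").replace("/", ".")

print(module_name(Path("ee/license.py")))      # ee.license
print(module_name(Path("ee/sub/feature.py")))  # ee.sub.feature
```

The `"\\" → "/"` step keeps the result identical on Windows, where `str(Path(...))` yields backslash-separated paths.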


@@ -954,18 +954,11 @@ class PolicyType(StrEnum):
     S3 = "s3"        # S3 兼容存储
 ```
-### StorageType
+### PolicyType
 ```python
-class StorageType(StrEnum):
+class PolicyType(StrEnum):
     LOCAL = "local"                 # 本地存储
-    QINIU = "qiniu"                 # 七牛云
-    TENCENT = "tencent"             # 腾讯云
-    ALIYUN = "aliyun"               # 阿里云
-    ONEDRIVE = "onedrive"           # OneDrive
-    GOOGLE_DRIVE = "google_drive"   # Google Drive
-    DROPBOX = "dropbox"             # Dropbox
-    WEBDAV = "webdav"               # WebDAV
-    REMOTE = "remote"               # 远程存储
+    S3 = "s3"                       # S3 兼容存储
 ```
 ### UserStatus


@@ -1,9 +1,17 @@
+from .auth_identity import (
+    AuthIdentity,
+    AuthIdentityResponse,
+    AuthProviderType,
+    BindIdentityRequest,
+    ChangePasswordRequest,
+)
 from .user import (
     BatchDeleteRequest,
     JWTPayload,
-    LoginRequest,
+    MagicLinkRequest,
+    UnifiedLoginRequest,
+    UnifiedRegisterRequest,
     RefreshTokenRequest,
-    RegisterRequest,
     AccessTokenBase,
     RefreshTokenBase,
     TokenResponse,
@@ -13,14 +21,28 @@ from .user import (
     UserPublic,
     UserResponse,
     UserSettingResponse,
+    UserThemeUpdateRequest,
+    SettingOption,
+    UserSettingUpdateRequest,
     WebAuthnInfo,
+    UserTwoFactorResponse,
     # 管理员DTO
     UserAdminUpdateRequest,
     UserCalibrateResponse,
     UserAdminDetailResponse,
 )
-from .user_authn import AuthnResponse, UserAuthn
-from .color import ThemeResponse
+from .user_authn import (
+    AuthnDetailResponse,
+    AuthnFinishRequest,
+    AuthnRenameRequest,
+    UserAuthn,
+)
+from .color import ChromaticColor, NeutralColor, ThemeColorsBase, BUILTIN_DEFAULT_COLORS
+from .theme_preset import (
+    ThemePreset, ThemePresetBase,
+    ThemePresetCreateRequest, ThemePresetUpdateRequest,
+    ThemePresetResponse, ThemePresetListResponse,
+)
 from .download import (
     Download,
@@ -47,18 +69,20 @@ from .object import (
     CreateUploadSessionRequest,
     DirectoryCreateRequest,
     DirectoryResponse,
-    FileMetadata,
-    FileMetadataBase,
     Object,
     ObjectBase,
     ObjectCopyRequest,
     ObjectDeleteRequest,
+    ObjectFileFinalize,
     ObjectMoveRequest,
+    ObjectMoveUpdate,
     ObjectPropertyDetailResponse,
     ObjectPropertyResponse,
     ObjectRenameRequest,
     ObjectResponse,
+    ObjectSwitchPolicyRequest,
     ObjectType,
+    FileCategory,
     PolicyResponse,
     UploadChunkResponse,
     UploadSession,
@@ -68,24 +92,75 @@ from .object import (
     AdminFileResponse,
     AdminFileListResponse,
     FileBanRequest,
+    # 回收站DTO
+    TrashItemResponse,
+    TrashRestoreRequest,
+    TrashDeleteRequest,
+)
+from .object_metadata import (
+    ObjectMetadata,
+    ObjectMetadataBase,
+    MetadataNamespace,
+    MetadataResponse,
+    MetadataPatchItem,
+    MetadataPatchRequest,
+    INTERNAL_NAMESPACES,
+    USER_WRITABLE_NAMESPACES,
+)
+from .custom_property import (
+    CustomPropertyDefinition,
+    CustomPropertyDefinitionBase,
+    CustomPropertyType,
+    CustomPropertyCreateRequest,
+    CustomPropertyUpdateRequest,
+    CustomPropertyResponse,
 )
 from .physical_file import PhysicalFile, PhysicalFileBase
 from .uri import DiskNextURI, FileSystemNamespace
-from .order import Order, OrderStatus, OrderType
+from .order import (
+    Order, OrderStatus, OrderType,
+    CreateOrderRequest, OrderResponse,
+)
-from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
+from .policy import (
+    Policy, PolicyBase, PolicyCreateRequest, PolicyOptions, PolicyOptionsBase,
+    PolicyType, PolicySummary, PolicyUpdateRequest,
+)
+from .product import (
+    Product, ProductBase, ProductType, PaymentMethod,
+    ProductCreateRequest, ProductUpdateRequest, ProductResponse,
+)
-from .redeem import Redeem, RedeemType
+from .redeem import (
+    Redeem, RedeemType,
+    RedeemCreateRequest, RedeemUseRequest, RedeemInfoResponse, RedeemAdminResponse,
+)
 from .report import Report, ReportReason
 from .setting import (
-    Setting, SettingsType, SiteConfigResponse,
+    Setting, SettingsType, SiteConfigResponse, AuthMethodConfig,
     # 管理员DTO
     SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
 )
-from .share import Share, ShareBase, ShareCreateRequest, ShareResponse, AdminShareListItem
+from .share import (
+    Share, ShareBase, ShareCreateRequest, CreateShareResponse, ShareResponse,
+    ShareOwnerInfo, ShareObjectItem, ShareDetailResponse,
+    AdminShareListItem,
+)
 from .source_link import SourceLink
-from .storage_pack import StoragePack
+from .storage_pack import StoragePack, StoragePackResponse
 from .tag import Tag, TagType
-from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
+from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary, TaskSummaryBase
-from .webdav import WebDAV
+from .webdav import (
+    WebDAV, WebDAVBase,
+    WebDAVCreateRequest, WebDAVUpdateRequest, WebDAVAccountResponse,
+)
+from .file_app import (
+    FileApp, FileAppType, FileAppExtension, FileAppGroupLink, UserFileAppDefault,
+    # DTO
+    FileAppSummary, FileViewersResponse, SetDefaultViewerRequest, UserFileAppDefaultResponse,
+    FileAppCreateRequest, FileAppUpdateRequest, FileAppResponse, FileAppListResponse,
+    ExtensionUpdateRequest, GroupAccessUpdateRequest, WopiSessionResponse,
+    WopiDiscoveredExtension, WopiDiscoveryResponse,
+)
+from .wopi import WopiFileInfo, WopiAccessTokenPayload
 from .database_connection import DatabaseManager
@@ -102,5 +177,5 @@ from .model_base import (
     AdminSummaryResponse,
 )
-# mixin 中的通用分页模型
-from .mixin import ListResponse
+# 通用分页模型
+from sqlmodel_ext import ListResponse

sqlmodels/auth_identity.py Normal file

@@ -0,0 +1,148 @@
"""
认证身份模块
一个用户可拥有多种登录方式(邮箱密码、OAuth、Passkey、Magic Link 等)。
AuthIdentity 表存储每种认证方式的凭证信息。
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100, Str128, Str255, Text1024
if TYPE_CHECKING:
from .user import User
class AuthProviderType(StrEnum):
"""认证提供者类型"""
EMAIL_PASSWORD = "email_password"
"""邮箱+密码"""
PHONE_SMS = "phone_sms"
"""手机号+短信验证码(预留)"""
GITHUB = "github"
"""GitHub OAuth"""
QQ = "qq"
"""QQ OAuth"""
PASSKEY = "passkey"
"""Passkey/WebAuthn"""
MAGIC_LINK = "magic_link"
"""邮箱 Magic Link"""
# ==================== DTO 模型 ====================
class AuthIdentityResponse(SQLModelBase):
"""认证身份响应 DTO(列表展示用)"""
id: UUID
"""身份 UUID"""
provider: AuthProviderType
"""提供者类型"""
identifier: str
"""标识符(邮箱/手机号/OAuth openid)"""
display_name: str | None = None
"""显示名称OAuth 昵称等)"""
avatar_url: str | None = None
"""头像 URL"""
is_primary: bool = False
"""是否主要身份"""
is_verified: bool = False
"""是否已验证"""
class BindIdentityRequest(SQLModelBase):
"""绑定认证身份请求 DTO"""
provider: AuthProviderType
"""提供者类型"""
identifier: str
"""标识符(邮箱/手机号/OAuth code)"""
credential: str | None = None
"""凭证(密码、验证码等)"""
redirect_uri: str | None = None
"""OAuth 回调地址"""
class ChangePasswordRequest(SQLModelBase):
"""修改密码请求 DTO"""
old_password: str = Field(min_length=1)
"""当前密码"""
new_password: Str128 = Field(min_length=8)
"""新密码(至少 8 位)"""
# ==================== 数据库模型 ====================
class AuthIdentity(SQLModelBase, UUIDTableBaseMixin):
"""用户认证身份 — 一个用户可以有多种登录方式"""
__table_args__ = (
UniqueConstraint("provider", "identifier", name="uq_auth_identity_provider_identifier"),
)
provider: AuthProviderType = Field(index=True)
"""提供者类型"""
identifier: Str255 = Field(index=True)
"""标识符(邮箱/手机号/OAuth openid)"""
credential: Text1024 | None = None
"""凭证(Argon2 哈希密码 / null)"""
display_name: Str100 | None = None
"""OAuth 昵称"""
avatar_url: str | None = Field(default=None, max_length=512)
"""OAuth 头像 URL"""
extra_data: str | None = None
"""JSON 附加数据(2FA secret、OAuth refresh_token 等)"""
is_primary: bool = False
"""是否主要身份"""
is_verified: bool = False
"""是否已验证"""
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE",
)
"""所属用户 UUID"""
# 关系
user: "User" = Relationship(back_populates="auth_identities")
def to_response(self) -> AuthIdentityResponse:
"""转换为响应 DTO"""
return AuthIdentityResponse(
id=self.id,
provider=self.provider,
identifier=self.identifier,
display_name=self.display_name,
avatar_url=self.avatar_url,
is_primary=self.is_primary,
is_verified=self.is_verified,
)


@@ -1,657 +0,0 @@
# SQLModels Base Module
This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility.
**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities.
## Table of Contents
- [Overview](#overview)
- [Migration Notice](#migration-notice)
- [Python 3.14 Compatibility](#python-314-compatibility)
- [Core Component](#core-component)
- [SQLModelBase](#sqlmodelbase)
- [Metaclass Features](#metaclass-features)
- [Automatic sa_type Injection](#automatic-sa_type-injection)
- [Table Configuration](#table-configuration)
- [Polymorphic Support](#polymorphic-support)
- [Custom Types Integration](#custom-types-integration)
- [Best Practices](#best-practices)
- [Troubleshooting](#troubleshooting)
## Overview
The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features:
- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations
- **Python 3.14 compatibility** through comprehensive PEP 649/749 support
- **Flexible configuration** through class parameters and automatic docstring support
- **Type-safe annotations** with automatic validation
All models in this project should directly or indirectly inherit from `SQLModelBase`.
---
## Migration Notice
As of the recent refactoring, the following components have been moved:
| Component | Old Location | New Location |
|-----------|-------------|--------------|
| `TableBase` → `TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `UUIDTableBase` → `UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` |
| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` |
| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` |
**Update your imports**:
```python
# ❌ Old (deprecated)
from sqlmodels.base import TableBase, UUIDTableBase
# ✅ New (correct)
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
```
For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Python 3.14 Compatibility
### Overview
This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default).
**Key Changes in Python 3.14**:
- Annotations are no longer evaluated at class definition time
- Type hints are stored as deferred code objects
- `__annotate__` function generates annotations on demand
- Forward references become `ForwardRef` objects
### Implementation Strategy
We use **`typing.get_type_hints()`** as the universal annotations resolver:
```python
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]:
# Create temporary proxy class
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
# Use get_type_hints with include_extras=True
evaluated = get_type_hints(
temp_cls,
globalns=module_globals,
localns=localns,
include_extras=True # Preserve Annotated metadata
)
return dict(evaluated), {}, module_globals, localns
```
**Why `get_type_hints()`?**
- ✅ Works across Python 3.10-3.14+
- ✅ Handles PEP 649 automatically
- ✅ Preserves `Annotated` metadata (with `include_extras=True`)
- ✅ Resolves forward references
- ✅ Recommended by Python documentation
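The `include_extras=True` flag is what keeps `Annotated` metadata visible to the metaclass; a small stdlib-only demonstration (`VectorTag` is a made-up marker standing in for a custom type carrier):

```python
from typing import Annotated, get_type_hints, get_args

class VectorTag:
    """Hypothetical marker carrying SQLAlchemy type info."""

class Model:
    embedding: Annotated[list, VectorTag]

# Without include_extras, the Annotated wrapper is stripped...
print(get_type_hints(Model))  # {'embedding': <class 'list'>}

# ...with it, the metadata survives for the metaclass to inspect
hints = get_type_hints(Model, include_extras=True)
print(get_args(hints["embedding"]))  # (<class 'list'>, <class 'VectorTag'>)
```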
### SQLModel Compatibility Patch
**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute.
**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic:
```python
if sys.version_info >= (3, 14):
def _patched_get_sqlalchemy_type(field):
annotation = getattr(field, 'annotation', None)
if annotation is not None:
# Priority 1: Check __sqlmodel_sa_type__ attribute
# Handles NumpyVector[dims, dtype] and similar custom types
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# Priority 2: Check Annotated metadata
if get_origin(annotation) is Annotated:
for metadata in get_args(annotation)[1:]:
if hasattr(metadata, '__sqlmodel_sa_type__'):
return metadata.__sqlmodel_sa_type__
# ... handle ForwardRef, ClassVar, etc.
return _original_get_sqlalchemy_type(field)
```
### Supported Patterns
#### Pattern 1: Direct Custom Type Usage
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Voice embedding - sa_type automatically extracted"""
```
#### Pattern 2: Annotated Wrapper
```python
from typing import Annotated
from sqlmodels.mixin import UUIDTableBaseMixin
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: EmbeddingVector
```
#### Pattern 3: Array Type
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import Array
from sqlmodels.mixin import TableBaseMixin
class ServerConfig(TableBaseMixin, table=True):
protocols: Array[ProtocolEnum]
"""Allowed protocols - sa_type from Array handler"""
```
### Migration from Python 3.13
**No code changes required!** The implementation is transparent:
- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14
- Custom types already use `__sqlmodel_sa_type__` attribute
- Monkey-patch only activates for Python 3.14+
---
## Core Component
### SQLModelBase
`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities.
**Key Features**:
- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`)
- Automatic `validate_by_name` configuration
- Custom metaclass for sa_type injection and polymorphic setup
- Integration with Pydantic v2
- Python 3.14 PEP 649 compatibility
**Usage**:
```python
from sqlmodels.base import SQLModelBase
class UserBase(SQLModelBase):
name: str
"""User's display name"""
email: str
"""User's email address"""
```
**Important Notes**:
- Use **docstrings** for field descriptions, not `Field(description=...)`
- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase)
- This class should be used for non-table models (DTOs, request/response models)
**For table models**, use mixins from `sqlmodels.mixin`:
- `TableBaseMixin` - Integer primary key with timestamps
- `UUIDTableBaseMixin` - UUID primary key with timestamps
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation.
---
## Metaclass Features
### Automatic sa_type Injection
The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types.
**Before** (verbose):
```python
from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: np.ndarray = Field(
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
)
```
**After** (clean):
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Speaker voice embedding"""
```
**How It Works**:
The metaclass uses a three-tier detection strategy:
1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1)
```python
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
```
2. **Annotated metadata** (Priority 2)
```python
# For Annotated[np.ndarray, NumpyVector[256, np.float32]]
if get_origin(annotation) is typing.Annotated:
for item in metadata_items:
if hasattr(item, '__sqlmodel_sa_type__'):
return item.__sqlmodel_sa_type__
```
3. **Pydantic Core Schema metadata** (Priority 3)
```python
schema = annotation.__get_pydantic_core_schema__(...)
if schema['metadata'].get('sa_type'):
return schema['metadata']['sa_type']
```
After extracting `sa_type`, the metaclass:
- Creates `Field(sa_type=sa_type)` if no Field is defined
- Injects `sa_type` into existing Field if not already set
- Respects explicit `Field(sa_type=...)` (no override)
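The first two tiers of the detection strategy can be sketched without SQLModel at all; `FakeSAType` and `VectorType` below are hypothetical stand-ins, not names from this codebase:

```python
from typing import Annotated, get_origin, get_args

class FakeSAType:
    """Placeholder for an SQLAlchemy TypeEngine instance."""

class VectorType:
    __sqlmodel_sa_type__ = FakeSAType()

def extract_sa_type(annotation):
    """Sketch of tiers 1 and 2 of the detection strategy above."""
    # Tier 1: direct attribute on the annotation itself
    if hasattr(annotation, "__sqlmodel_sa_type__"):
        return annotation.__sqlmodel_sa_type__
    # Tier 2: attribute on one of the Annotated metadata items
    if get_origin(annotation) is Annotated:
        for item in get_args(annotation)[1:]:
            if hasattr(item, "__sqlmodel_sa_type__"):
                return item.__sqlmodel_sa_type__
    return None

assert extract_sa_type(VectorType) is VectorType.__sqlmodel_sa_type__
assert extract_sa_type(Annotated[list, VectorType]) is VectorType.__sqlmodel_sa_type__
assert extract_sa_type(int) is None
```

Note that the tier-1 `hasattr` check on `Annotated[list, VectorType]` falls through to tier 2, because attribute access on an `Annotated` alias is forwarded to its origin type (`list`), which has no `__sqlmodel_sa_type__`.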
**Supported Patterns**:
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# Pattern 1: Direct usage (recommended)
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
# Pattern 2: With Field constraints
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32] = Field(nullable=False)
# Pattern 3: Annotated wrapper
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
class Model(UUIDTableBaseMixin, table=True):
embedding: EmbeddingVector
# Pattern 4: Explicit sa_type (override)
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32] = Field(
sa_type=_NumpyVectorSQLAlchemyType(128, np.float16)
)
```
### Table Configuration
The metaclass provides smart defaults and flexible configuration:
**Automatic `table=True`**:
```python
# Classes inheriting from TableBaseMixin automatically get table=True
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(UUIDTableBaseMixin): # table=True is automatic
pass
```
**Convenient mapper arguments**:
```python
# Instead of verbose __mapper_args__
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(
UUIDTableBaseMixin,
polymorphic_on='_polymorphic_name',
polymorphic_abstract=True
):
pass
# Equivalent to:
class MyModel(UUIDTableBaseMixin):
__mapper_args__ = {
'polymorphic_on': '_polymorphic_name',
'polymorphic_abstract': True
}
```
**Smart merging**:
```python
# Dictionary and keyword arguments are merged
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(
UUIDTableBaseMixin,
mapper_args={'version_id_col': 'version'},
polymorphic_on='type' # Merged into __mapper_args__
):
pass
```
### Polymorphic Support
The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters:
**Supported parameters**:
- `polymorphic_on`: Discriminator column name
- `polymorphic_identity`: Identity value for this class
- `polymorphic_abstract`: Whether this is an abstract base
- `table_args`: SQLAlchemy table arguments
- `table_name`: Override table name (becomes `__tablename__`)
**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Custom Types Integration
### Using NumpyVector
The `NumpyVector` type demonstrates automatic sa_type injection:
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
import numpy as np
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Speaker voice embedding - sa_type automatically injected"""
```
**How NumpyVector works**:
```python
# NumpyVector[dims, dtype] returns a class with:
class _NumpyVectorType:
__sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype)
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
return handler.generate_schema(np.ndarray)
```
This dual approach ensures:
1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__`
2. Pydantic can validate as `np.ndarray`
### Creating Custom SQLAlchemy Types
To create types that work with automatic injection, provide one of:
**Option 1: `__sqlmodel_sa_type__` attribute** (preferred):
```python
from sqlalchemy import TypeDecorator, String
class UpperCaseString(TypeDecorator):
impl = String
def process_bind_param(self, value, dialect):
return value.upper() if value else value
class UpperCaseType:
__sqlmodel_sa_type__ = UpperCaseString()
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
return core_schema.str_schema()
# Usage
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(UUIDTableBaseMixin, table=True):
code: UpperCaseType # Automatically uses UpperCaseString()
```
**Option 2: Pydantic metadata with sa_type**:
```python
def __get_pydantic_core_schema__(self, source_type, handler):
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=core_schema.str_schema(),
metadata={'sa_type': UpperCaseString()}
)
```
**Option 3: Using Annotated**:
```python
from typing import Annotated
from sqlmodels.mixin import UUIDTableBaseMixin
UpperCase = Annotated[str, UpperCaseType()]
class MyModel(UUIDTableBaseMixin, table=True):
code: UpperCase
```
---
## Best Practices
### 1. Inherit from correct base classes
```python
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
# ✅ For non-table models (DTOs, requests, responses)
class UserBase(SQLModelBase):
name: str
# ✅ For table models with UUID primary key
class User(UserBase, UUIDTableBaseMixin, table=True):
email: str
# ✅ For table models with custom primary key
class LegacyUser(TableBaseMixin, table=True):
id: int = Field(primary_key=True)
username: str
```
### 2. Use docstrings for field descriptions
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Recommended
class User(UUIDTableBaseMixin, table=True):
name: str
"""User's display name"""
# ❌ Avoid
class User(UUIDTableBaseMixin, table=True):
name: str = Field(description="User's display name")
```
**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs.
### 3. Leverage automatic sa_type injection
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Clean and recommended
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Voice embedding"""
# ❌ Verbose and unnecessary
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: np.ndarray = Field(
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
)
```
### 4. Follow polymorphic naming conventions
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`.
### 5. Separate Base, Parent, and Implementation classes
```python
from abc import ABC, abstractmethod
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin
# ✅ Recommended structure
class ASRBase(SQLModelBase):
"""Pure data fields, no table"""
name: str
base_url: str
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
"""Abstract parent with table"""
@abstractmethod
async def transcribe(self, audio: bytes) -> str:
pass
class WhisperASR(ASR, table=True):
"""Concrete implementation"""
model_size: str
async def transcribe(self, audio: bytes) -> str:
# Implementation
pass
```
**Why?**
- Base class can be reused for DTOs
- Parent class defines the polymorphic hierarchy
- Implementation classes are clean and focused
---
## Troubleshooting
### Issue: ValueError: X has no matching SQLAlchemy type
**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`.
```python
# ✅ Provide __sqlmodel_sa_type__
class MyType:
__sqlmodel_sa_type__ = MyCustomSQLAlchemyType()
```
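The detection itself is a plain attribute lookup, which can be sketched with stand-in names (`FakeSAType`, `UpperCaseStr`, and `extract_sa_type` are illustrative, not part of the module):

```python
class FakeSAType:
    """Stand-in for a SQLAlchemy type such as String(50) (illustrative only)."""
    def __init__(self, length: int):
        self.length = length

class UpperCaseStr:
    # the attribute the metaclass probes for when injecting sa_type
    __sqlmodel_sa_type__ = FakeSAType(50)

def extract_sa_type(annotation):
    # mirrors the first check in _extract_sa_type_from_annotation
    return getattr(annotation, "__sqlmodel_sa_type__", None)

assert extract_sa_type(UpperCaseStr).length == 50  # custom type is detected
assert extract_sa_type(str) is None                # plain types fall through
```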
### Issue: Can't generate DDL for NullType()
**Symptoms**: Error during table creation saying a column has `NullType`.
**Root Cause**: Custom type's `sa_type` not detected by SQLModel.
**Solution**:
1. Ensure your type has `__sqlmodel_sa_type__` class attribute
2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`)
3. Verify type annotation is correct (not a string forward reference)
```python
import numpy as np
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Correct
class Model(UUIDTableBaseMixin, table=True):
data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected
# ❌ Wrong (string annotation)
class Model(UUIDTableBaseMixin, table=True):
data: 'NumpyVector[256, np.float32]' # sa_type lost
```
### Issue: Polymorphic identity conflicts
**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities.
**Solution**:
1. Check that each concrete class has a unique identity
2. Use `AutoPolymorphicIdentityMixin` for automatic naming
3. Manually specify identity if needed:
```python
class MyClass(Parent, polymorphic_identity='unique.name', table=True):
pass
```
### Issue: Python 3.14 annotation errors
**Symptoms**: Errors related to `__annotations__` or type resolution.
**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist:
1. Check for manual `__annotations__` manipulation (avoid it)
2. Ensure all types are properly imported
3. Avoid `from __future__ import annotations` (can cause SQLModel issues)
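The key behavior can be checked with the standard library alone (a minimal sketch; `Marker` is an illustrative metadata object standing in for a custom sa_type carrier):

```python
from typing import Annotated, get_type_hints, get_args

class Marker:
    """Illustrative Annotated metadata object."""

class Model:
    code: Annotated[str, Marker()]

# include_extras=True keeps the Annotated wrapper and its metadata
hints = get_type_hints(Model, include_extras=True)
args = get_args(hints["code"])
assert args[0] is str
assert isinstance(args[1], Marker)

# without include_extras, the metadata is stripped away
assert get_type_hints(Model)["code"] is str
```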
### Issue: Polymorphic and CRUD-related errors
For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Implementation Details
For developers modifying this module:
**Core files**:
- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase`
- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin`
- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`
**Key functions in this module**:
1. **`_resolve_annotations(attrs: dict[str, Any])`**
- Uses `typing.get_type_hints()` for Python 3.14 compatibility
- Returns tuple: `(annotations, annotation_strings, globalns, localns)`
- Preserves `Annotated` metadata with `include_extras=True`
2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`**
- Extracts SQLAlchemy type from type annotations
- Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema
- Called by metaclass during class creation
3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+)
- Global monkey-patch for SQLModel
- Checks `__sqlmodel_sa_type__` before falling back to original logic
- Handles custom types like `NumpyVector` and `Array`
4. **`__DeclarativeMeta.__new__()`**
- Processes class definition parameters
- Injects `sa_type` into field definitions
- Sets up `__mapper_args__`, `__table_args__`, etc.
- Handles Python 3.14 annotations via `get_type_hints()`
**Metaclass processing order**:
1. Check if class should be a table (`_has_table_mixin`)
2. Collect `__mapper_args__` from kwargs and explicit dict
3. Process `table_args`, `table_name`, `abstract` parameters
4. Resolve annotations using `get_type_hints()`
5. For each field, try to extract `sa_type` and inject into Field
6. Call parent metaclass with cleaned kwargs
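Steps 2-3 of the processing order can be sketched as a plain function (a hypothetical helper for illustration, not the actual metaclass code):

```python
KNOWN_MAPPER_KEYS = {
    "polymorphic_on", "polymorphic_identity", "polymorphic_abstract",
    "version_id_col", "concrete",
}

def merge_mapper_args(attrs: dict, kwargs: dict) -> dict:
    """Merge mapper args: explicit dict first, then keywords (higher precedence)."""
    collected = dict(kwargs.pop("mapper_args", {}))  # lower precedence
    for key in list(kwargs):
        if key in KNOWN_MAPPER_KEYS:
            collected[key] = kwargs.pop(key)         # keyword form wins
    merged = dict(attrs.get("__mapper_args__", {}))
    merged.update(collected)
    return merged

kwargs = {"mapper_args": {"polymorphic_identity": "a"},
          "polymorphic_identity": "b", "table": True}
merged = merge_mapper_args({}, kwargs)
assert merged == {"polymorphic_identity": "b"}  # keyword overrode the dict entry
assert kwargs == {"table": True}                # mapper keys were consumed
```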
For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## See Also
**Project Documentation**:
- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns
- [Project Coding Standards (CLAUDE.md)](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/CLAUDE.md)
- [Custom SQLModel Types Guide](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/sqlmodels/sqlmodel_types/README.md)
**External References**:
- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/)
- [SQLModel Documentation](https://sqlmodel.tiangolo.com/)
- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/)
- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/)
- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html)


@@ -1,12 +0,0 @@
"""
SQLModel 基础模块
包含:
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
注意:
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
为了避免循环导入,此处不再重新导出它们
请直接从 sqlmodels.mixin 导入这些类
"""
from .sqlmodel_base import SQLModelBase


@@ -1,846 +0,0 @@
import sys
import typing
from typing import Any, Mapping, get_args, get_origin, get_type_hints
from pydantic import ConfigDict
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined as Undefined
from sqlalchemy.orm import Mapped
from sqlmodel import Field, SQLModel
from sqlmodel.main import SQLModelMetaclass
# Python 3.14+ PEP 649 support
if sys.version_info >= (3, 14):
import annotationlib
# Global monkey-patch: fix SQLModel's compatibility issues on Python 3.14
import sqlmodel.main
_original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type
def _patched_get_sqlalchemy_type(field):
"""
Fix SQLModel's get_sqlalchemy_type to handle Python 3.14 typing issues.
Problems:
1. ForwardRef objects from Relationship fields cause issubclass errors
2. typing._GenericAlias objects (e.g. ClassVar[T]) trigger the same error
3. Generic types such as list/dict can fail when there is no Field/Relationship
4. Mapped types may appear in annotations under Python 3.14
5. Annotated types may carry sa_type metadata (e.g. Array[T])
6. Custom types such as NumpyVector expose a __sqlmodel_sa_type__ attribute
7. Annotated types already processed by Pydantic keep their metadata in field.metadata
Fixes:
- first check field.metadata for __get_pydantic_core_schema__ (the Pydantic-processed case)
- detect the __sqlmodel_sa_type__ attribute (NumpyVector, etc.)
- detect Relationship/ClassVar and return None
- for Annotated types, try to extract the sa_type metadata
- otherwise call the original function
"""
# First check field.metadata (the case where Pydantic has already processed an Annotated type).
# When Array[T] or Annotated[T, metadata] is used, Pydantic stores the metadata here.
metadata = getattr(field, 'metadata', None)
if metadata:
# metadata is a list containing every Annotated metadata item
for metadata_item in metadata:
# Check whether metadata_item has a __get_pydantic_core_schema__ method
if hasattr(metadata_item, '__get_pydantic_core_schema__'):
try:
# Call it to obtain the schema
schema = metadata_item.__get_pydantic_core_schema__(None, None)
# Check whether the schema's metadata carries sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Obtaining the Pydantic schema may fail (type mismatch, missing attribute, etc.).
# This is expected; continue with the next metadata item.
pass
annotation = getattr(field, 'annotation', None)
if annotation is not None:
# First check the __sqlmodel_sa_type__ attribute.
# This handles custom types such as NumpyVector[dims, dtype].
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# Check custom types (e.g. JSON100K) that define a __get_pydantic_core_schema__ method;
# these types declare sa_type in the schema's metadata.
if hasattr(annotation, '__get_pydantic_core_schema__'):
try:
# Call it to get the schema; pass None as the handler since we only need the metadata
schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None)
# Check whether the schema's metadata carries sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Schema retrieval failed; continue with the remaining checks
pass
anno_type_name = type(annotation).__name__
# ForwardRef: the annotation of a Relationship field
if anno_type_name == 'ForwardRef':
return None
# AnnotatedAlias: check for sa_type metadata (e.g. Array[T])
if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias':
from typing import get_origin, get_args
import typing
# Try to extract the Annotated metadata
if hasattr(typing, 'get_args'):
args = get_args(annotation)
# args[0] is the actual type; args[1:] are the metadata items
for metadata in args[1:]:
# Check whether the metadata item has a __get_pydantic_core_schema__ method
if hasattr(metadata, '__get_pydantic_core_schema__'):
try:
# Call it to obtain the schema
schema = metadata.__get_pydantic_core_schema__(None, None)
# Check whether the schema carries sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Obtaining the schema from Annotated metadata may fail;
# this is part of normal type probing, continue with the next item.
pass
# _GenericAlias or GenericAlias: typing generic types
if anno_type_name in ('_GenericAlias', 'GenericAlias'):
from typing import get_origin
import typing
origin = get_origin(annotation)
# ClassVar must be skipped
if origin is typing.ClassVar:
return None
# Built-in generics (list/dict/tuple/set) are also skipped when the field has no
# explicit Field or Relationship; this usually means a Relationship field or class variable.
if origin in (list, dict, tuple, set):
# Check whether field_info exists and is meaningful;
# Relationship fields carry a special field_info.
field_info = getattr(field, 'field_info', None)
if field_info is None:
return None
# Mapped: SQLAlchemy 2.0's Mapped type, which SQLModel should not handle.
# It may be a field inherited from a parent class, or a side effect of
# Python 3.14 annotation processing. Check the type name and the annotation's string form.
if 'Mapped' in anno_type_name or 'Mapped' in str(annotation):
return None
# Check whether the annotation is the Mapped class or an instance of it
try:
from sqlalchemy.orm import Mapped as SAMapped
# Check the origin (covers the generic form Mapped[T])
from typing import get_origin
if get_origin(annotation) is SAMapped:
return None
# Check the type itself
if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped):
return None
except (ImportError, TypeError):
# If SQLAlchemy has no Mapped or the check fails, continue
pass
# Everything else: fall back to normal handling
return _original_get_sqlalchemy_type(field)
sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type
# Second monkey-patch: fix InstrumentedAttribute default values in inherited table classes.
# Under Python 3.14 + SQLModel, when a subclass (e.g. SMSBaoProvider) inherits from a
# parent class (e.g. VerificationCodeProvider), the parent's relationship fields
# (e.g. server_config) appear in the subclass's model_fields, but their default is
# wrongly set to an InstrumentedAttribute object instead of None. Instantiation then
# tries to assign the InstrumentedAttribute as a field value, triggering an internal
# SQLAlchemy error.
import sqlmodel._compat as _compat
from sqlalchemy.orm import attributes as _sa_attributes
_original_sqlmodel_table_construct = _compat.sqlmodel_table_construct
def _patched_sqlmodel_table_construct(self_instance, values):
"""
Fix sqlmodel_table_construct: skip InstrumentedAttribute defaults.
Problem:
- table classes inheriting from a polymorphic base (e.g. FishAudioTTS, SMSBaoProvider)
- have inherited fields in model_fields whose default is an InstrumentedAttribute
- the original function tries to assign that InstrumentedAttribute as the field value
- SQLAlchemy cannot handle it and raises a '_sa_instance_state' error
Fix:
- only assign user-provided values and non-InstrumentedAttribute defaults
- skip InstrumentedAttribute defaults and let SQLAlchemy handle them itself
"""
cls = type(self_instance)
# Collect the field values to set
fields_to_set = {}
for name, field in cls.model_fields.items():
# If the user supplied a value, use it directly
if name in values:
fields_to_set[name] = values[name]
continue
# Otherwise inspect the default value.
# Skip InstrumentedAttribute defaults: these are broken defaults on inherited fields.
if isinstance(field.default, _sa_attributes.InstrumentedAttribute):
continue
# Use the normal default value
if field.default is not Undefined:
fields_to_set[name] = field.default
elif field.default_factory is not None:
fields_to_set[name] = field.get_default(call_default_factory=True)
# Set attributes: only assign non-InstrumentedAttribute values
for key, value in fields_to_set.items():
if not isinstance(value, _sa_attributes.InstrumentedAttribute):
setattr(self_instance, key, value)
# Set Pydantic internal attributes
object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys()))
if not cls.__pydantic_root_model__:
_extra = None
if cls.model_config.get('extra') == 'allow':
_extra = {}
for k, v in values.items():
if k not in cls.model_fields:
_extra[k] = v
object.__setattr__(self_instance, '__pydantic_extra__', _extra)
if cls.__pydantic_post_init__:
self_instance.model_post_init(None)
elif not cls.__pydantic_root_model__:
object.__setattr__(self_instance, '__pydantic_private__', None)
# Set relationships
for key in self_instance.__sqlmodel_relationships__:
value = values.get(key, Undefined)
if value is not Undefined:
setattr(self_instance, key, value)
return self_instance
_compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct
else:
annotationlib = None
def _extract_sa_type_from_annotation(annotation: Any) -> Any | None:
"""
从类型注解中提取SQLAlchemy类型。
支持以下形式:
1. NumpyVector[256, np.float32] - 直接使用类型有__sqlmodel_sa_type__属性
2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装
3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型
Args:
annotation: 字段的类型注解
Returns:
提取到的SQLAlchemy类型如果没有则返回None
"""
# 方法1直接检查类型本身是否有__sqlmodel_sa_type__属性
# 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# 方法2检查是否为Annotated类型
if get_origin(annotation) is typing.Annotated:
# Get the metadata items (skipping the first, actual type argument)
args = get_args(annotation)
if len(args) >= 2:
metadata_items = args[1:]  # the first item is the actual type; the rest are metadata
# Walk the metadata looking for an item that carries sa_type
for item in metadata_items:
# Check whether the metadata item has a __sqlmodel_sa_type__ attribute
if hasattr(item, '__sqlmodel_sa_type__'):
return item.__sqlmodel_sa_type__
# Check whether it has a __get_pydantic_core_schema__ method
if hasattr(item, '__get_pydantic_core_schema__'):
try:
# Call it to obtain the core schema
schema = item.__get_pydantic_core_schema__(
annotation,
lambda x: None  # dummy handler
)
# Check whether the schema's metadata carries sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError, ValueError):
# Obtaining the Pydantic core schema may fail:
# - TypeError: mismatched arguments
# - AttributeError: metadata does not exist
# - KeyError: unexpected schema structure
# - ValueError: invalid type definition
# This is normal type probing; continue with the next metadata item.
pass
# Method 3: check whether the type itself has __get_pydantic_core_schema__.
# NumpyVector is already handled by method 1, but this is a generic fallback.
if hasattr(annotation, '__get_pydantic_core_schema__'):
try:
schema = annotation.__get_pydantic_core_schema__(
annotation,
lambda x: None  # dummy handler
)
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError, ValueError):
# Schema retrieval on the type itself failed;
# this is the normal fallback path (the annotation may not support the protocol).
pass
return None
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[
dict[str, Any],
dict[str, str],
Mapping[str, Any],
Mapping[str, Any],
]:
"""
Resolve annotations from a class namespace with Python 3.14 (PEP 649) support.
This helper prefers evaluated annotations (Format.VALUE) so that `typing.Annotated`
metadata and custom types remain accessible. Forward references that cannot be
evaluated are replaced with typing.ForwardRef placeholders to avoid aborting the
whole resolution process.
"""
raw_annotations = attrs.get('__annotations__') or {}
try:
base_annotations = dict(raw_annotations)
except TypeError:
base_annotations = {}
module_name = attrs.get('__module__')
module_globals: dict[str, Any]
if module_name and module_name in sys.modules:
module_globals = dict(sys.modules[module_name].__dict__)
else:
module_globals = {}
module_globals.setdefault('__builtins__', __builtins__)
localns: dict[str, Any] = dict(attrs)
try:
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
temp_cls.__module__ = module_name
extras_kw = {'include_extras': True} if sys.version_info >= (3, 10) else {}
evaluated = get_type_hints(
temp_cls,
globalns=module_globals,
localns=localns,
**extras_kw,
)
except (NameError, AttributeError, TypeError, RecursionError):
# get_type_hints may fail because of:
# - NameError: a forward reference cannot be resolved (type not yet defined)
# - AttributeError: a module or type does not exist
# - TypeError: an invalid type annotation
# - RecursionError: circularly dependent type definitions
# This is expected; fall back to the raw annotation strings.
evaluated = base_annotations
return dict(evaluated), {}, module_globals, localns
def _evaluate_annotation_from_string(
field_name: str,
annotation_strings: dict[str, str],
current_type: Any,
globalns: Mapping[str, Any],
localns: Mapping[str, Any],
) -> Any:
"""
Attempt to re-evaluate the original annotation string for a field.
This is used as a fallback when the resolved annotation lost its metadata
(e.g., Annotated wrappers) and we need to recover custom sa_type data.
"""
if not annotation_strings:
return current_type
expr = annotation_strings.get(field_name)
if not expr or not isinstance(expr, str):
return current_type
try:
return eval(expr, globalns, localns)
except (NameError, SyntaxError, AttributeError, TypeError):
# eval may fail because of:
# - NameError: the type name is missing from the namespace
# - SyntaxError: the annotation string has a syntax error
# - AttributeError: access to a nonexistent module attribute
# - TypeError: an invalid type expression
# This is the normal fallback; return the currently resolved type.
return current_type
class __DeclarativeMeta(SQLModelMetaclass):
"""
一个智能的混合模式元类,它提供了灵活性和清晰度:
1. **自动设置 `table=True`**: 如果一个类继承了 `TableBaseMixin`,则自动应用 `table=True`。
2. **明确的字典参数**: 支持 `mapper_args={...}`, `table_args={...}`, `table_name='...'`。
3. **便捷的关键字参数**: 支持最常见的 mapper 参数作为顶级关键字(如 `polymorphic_on`)。
4. **智能合并**: 当字典和关键字同时提供时,会自动合并,且关键字参数有更高优先级。
"""
_KNOWN_MAPPER_KEYS = {
"polymorphic_on",
"polymorphic_identity",
"polymorphic_abstract",
"version_id_col",
"concrete",
}
def __new__(cls, name, bases, attrs, **kwargs):
# 1. Convention over configuration: automatically set table=True
is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
if is_intended_as_table and 'table' not in kwargs:
kwargs['table'] = True
# 2. Smart merging of __mapper_args__
collected_mapper_args = {}
# First, process the explicit mapper_args dict (lower precedence)
if 'mapper_args' in kwargs:
collected_mapper_args.update(kwargs.pop('mapper_args'))
# Then process the convenience keyword parameters (higher precedence)
for key in cls._KNOWN_MAPPER_KEYS:
if key in kwargs:
# .pop() takes the value and removes it so it is not passed to the parent class
collected_mapper_args[key] = kwargs.pop(key)
# If any mapper parameters were collected, merge them into the class attributes
if collected_mapper_args:
existing = attrs.get('__mapper_args__', {}).copy()
existing.update(collected_mapper_args)
attrs['__mapper_args__'] = existing
# 3. Handle the other explicit parameters
if 'table_args' in kwargs:
attrs['__table_args__'] = kwargs.pop('table_args')
if 'table_name' in kwargs:
attrs['__tablename__'] = kwargs.pop('table_name')
if 'abstract' in kwargs:
attrs['__abstract__'] = kwargs.pop('abstract')
# 4. Extract sa_type from Annotated metadata and inject it into Field.
# Important: this must happen before calling the parent __new__, because SQLModel consumes the annotations.
#
# Python 3.14 compatibility:
# - SQLModel crashes on Python 3.14 when it sees ClassVar[T] types (issubclass error)
# - we must filter out ClassVar fields before SQLModel sees the annotations
# - PEP 749 advises against mutating __annotations__, but it is necessary to fix the SQLModel bug
#
# Strategy for obtaining annotations:
# - Python 3.14+: prefer __annotate__ if it exists
# - fallback: read __annotations__ if it exists
# - final fallback: an empty dict
annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs)
if annotations:
attrs['__annotations__'] = annotations
if annotationlib is not None:
# On Python 3.14, disable the descriptor and use a plain dict
attrs['__annotate__'] = None
for field_name, field_type in annotations.items():
field_type = _evaluate_annotation_from_string(
field_name,
annotation_strings,
field_type,
eval_globals,
eval_locals,
)
# Skip string or ForwardRef annotations and let SQLModel handle them itself
if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef):
continue
# Skip fields with special types
origin = get_origin(field_type)
# Skip ClassVar fields: they are not database fields
if origin is typing.ClassVar:
continue
# Skip Mapped fields: SQLAlchemy 2.0+ declarative fields that already have a mapped_column
if origin is Mapped:
continue
# Try to extract sa_type from the annotation
sa_type = _extract_sa_type_from_annotation(field_type)
if sa_type is not None:
# Check whether the field already has a Field definition
field_value = attrs.get(field_name, Undefined)
if field_value is Undefined:
# No Field definition: create a new Field and inject sa_type
attrs[field_name] = Field(sa_type=sa_type)
elif isinstance(field_value, FieldInfo):
# Field already defined: check whether sa_type is set.
# Note: only inject when unset, respecting explicit configuration.
# SQLModel uses Undefined as the "not set" marker.
if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined:
field_value.sa_type = sa_type
# If field_value is anything else (e.g. a plain default value), leave it alone;
# SQLModel will convert it into a Field later.
# 5. Call the parent __new__ with the cleaned kwargs
result = super().__new__(cls, name, bases, attrs, **kwargs)
# 6. Fix: under joined table inheritance, inherit the parent's __sqlmodel_relationships__.
# SQLModel creates a fresh, empty __sqlmodel_relationships__ for every table=True class,
# so subclasses lose the parent's relationship definitions, triggering bogus Column creation.
# This must be fixed after super().__new__(), because SQLModel overwrites any preset value.
if kwargs.get('table', False):
for base in bases:
if hasattr(base, '__sqlmodel_relationships__'):
for rel_name, rel_info in base.__sqlmodel_relationships__.items():
# Only inherit relationships the subclass has not redefined
if rel_name not in result.__sqlmodel_relationships__:
result.__sqlmodel_relationships__[rel_name] = rel_info
# Also repair wrongly created Columns: restore the parent's relationship
if hasattr(base, rel_name):
base_attr = getattr(base, rel_name)
setattr(result, rel_name, base_attr)
# 7. Detection: forbid subclasses from redefining a parent's Relationship fields.
# Redefining a Relationship field of the same name corrupts SQLAlchemy's relationship
# mapping, so fail fast at class definition time instead of producing
# hard-to-debug runtime errors.
for base in bases:
parent_relationships = getattr(base, '__sqlmodel_relationships__', {})
for rel_name in parent_relationships:
# Check whether the current class redefines this relationship field in attrs
if rel_name in attrs:
raise TypeError(
f"{name} must not redefine the Relationship field '{rel_name}' of parent class {base.__name__}; "
f"change the relationship configuration in the parent class instead."
)
# 8. Fix: remove Relationship fields from model_fields/__pydantic_fields__.
# SQLModel 0.0.27 bug: subclasses wrongly inherit the parent's Relationship fields
# into model_fields, so Pydantic tries to build schemas for them. Their types are
# forward references like Mapped[list['Character']] that Pydantic cannot resolve,
# leaving __pydantic_complete__ = False.
#
# Fix strategy:
# - inspect the class's __sqlmodel_relationships__ attribute
# - remove those fields from model_fields and __pydantic_fields__
# - Relationship fields are managed by SQLAlchemy; Pydantic does not need them
relationships = getattr(result, '__sqlmodel_relationships__', {})
if relationships:
model_fields = getattr(result, 'model_fields', {})
pydantic_fields = getattr(result, '__pydantic_fields__', {})
fields_removed = False
for rel_name in relationships:
if rel_name in model_fields:
del model_fields[rel_name]
fields_removed = True
if rel_name in pydantic_fields:
del pydantic_fields[rel_name]
fields_removed = True
# If any fields were removed, rebuild the Pydantic schema.
# Note: only rebuild when something was actually removed, to avoid unnecessary overhead.
if fields_removed and hasattr(result, 'model_rebuild'):
result.model_rebuild(force=True)
return result
def __init__(
cls,
classname: str,
bases: tuple[type, ...],
dict_: dict[str, typing.Any],
**kw: typing.Any,
) -> None:
"""
重写 SQLModel 的 __init__ 以支持联表继承Joined Table Inheritance
SQLModel 原始行为:
- 如果任何基类是表模型,则不调用 DeclarativeMeta.__init__
- 这阻止了子类创建自己的表
修复逻辑:
- 检测联表继承场景(子类有自己的 __tablename__ 且有外键指向父表)
- 强制调用 DeclarativeMeta.__init__ 来创建子表
"""
from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass
# Check whether this is a table model
if not is_table_model_class(cls):
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
return
# Check whether any base class is a table model
base_is_table = any(is_table_model_class(base) for base in bases)
if not base_is_table:
# No base class is a table model; follow the normal SQLModel flow.
# Set up the relationship fields.
cls._setup_relationships()
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
return
# Key step: detect joined table inheritance.
# Conditions:
# 1. the current class's __tablename__ differs from the parent's (a new table is needed)
# 2. the current class has a field whose foreign_key points at the parent table
current_tablename = getattr(cls, '__tablename__', None)
# Find the parent table information
parent_table = None
parent_tablename = None
for base in bases:
if is_table_model_class(base) and hasattr(base, '__tablename__'):
parent_tablename = base.__tablename__
break
# Check whether the tablename differs
has_different_tablename = (
current_tablename is not None
and parent_tablename is not None
and current_tablename != parent_tablename
)
# Check whether a foreign key field points at the parent table's primary key.
# Note: because of field merging, we must inspect the direct base classes'
# model_fields rather than the current class's merged model_fields.
has_fk_to_parent = False
def _normalize_tablename(name: str) -> str:
"""Normalize a table name for comparison (drop underscores, lowercase)."""
return name.replace('_', '').lower()
def _fk_matches_parent(fk_str: str, parent_table: str) -> bool:
"""Check whether an FK string points at the parent table."""
if not fk_str or not parent_table:
return False
# FK format: "tablename.column" or "schema.tablename.column"
parts = fk_str.split('.')
if len(parts) >= 2:
fk_table = parts[-2]  # the second-to-last part is the table name
# Normalized comparison (handles underscore differences)
return _normalize_tablename(fk_table) == _normalize_tablename(parent_table)
return False
if has_different_tablename and parent_tablename:
# First inspect the current class's model_fields
for field_name, field_info in cls.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
has_fk_to_parent = True
break
# If not found, inspect the direct base classes' model_fields (covers mixin fields being overridden)
if not has_fk_to_parent:
for base in bases:
if hasattr(base, 'model_fields'):
for field_name, field_info in base.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
has_fk_to_parent = True
break
if has_fk_to_parent:
break
is_joined_inheritance = has_different_tablename and has_fk_to_parent
if is_joined_inheritance:
# Joined table inheritance: a child table must be created.
# Repair the foreign key field: field merging may have dropped the FK information,
# so recover it from the base classes' mixins and rebuild the column.
from sqlalchemy import Column, ForeignKey, inspect as sa_inspect
from sqlalchemy.dialects.postgresql import UUID as SA_UUID
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm.attributes import InstrumentedAttribute
# In joined table inheritance the child table should only contain the id (FK to the
# parent table) plus the subclass's own fields; columns inherited from ancestor
# tables must not be recreated in the child table.
# Collect the column names of every ancestor table in the inheritance chain
# (these must not be duplicated in the child table). Walk the full MRO, because
# inheritance may be multi-level (e.g. Tool -> Function -> GetWeatherFunction).
ancestor_column_names: set[str] = set()
for ancestor in cls.__mro__:
if ancestor is cls:
continue  # skip the current class
if is_table_model_class(ancestor):
try:
# Use inspect() to access the mapper's public attributes.
# Confirmed in the source: mapper.local_table is public (mapper.py:979-998)
mapper = sa_inspect(ancestor)
for col in mapper.local_table.columns:
# Skip the _polymorphic_name column (the discriminator, managed by the root parent table)
if col.name.startswith('_polymorphic'):
continue
ancestor_column_names.add(col.name)
except NoInspectionAvailable:
continue
# Find the fields defined by the subclass itself (those absent from the parents)
child_own_fields: set[str] = set()
for field_name in cls.model_fields:
# Determine whether the field is defined directly on this class (not inherited)
# by checking whether any parent class has it
is_inherited = False
for base in bases:
if hasattr(base, 'model_fields') and field_name in base.model_fields:
is_inherited = True
break
if not is_inherited:
child_own_fields.add(field_name)
# Remove column definitions already present on the parent table from the subclass's
# class attributes, so SQLAlchemy does not recreate those columns in the child table
fk_field_name = None
for base in bases:
if hasattr(base, 'model_fields'):
for field_name, field_info in base.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
pk = getattr(field_info, 'primary_key', False)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
fk_field_name = field_name
# Found the foreign key field; rebuild it.
# Create a new Column object carrying the foreign key constraint.
new_col = Column(
field_name,
SA_UUID(as_uuid=True),
ForeignKey(fk),
primary_key=pk if pk else False
)
setattr(cls, field_name, new_col)
break
else:
continue
break
# Remove column attributes inherited from ancestor tables (except the FK/PK and
# the subclass's own fields); this prevents SQLAlchemy from creating duplicate
# columns in the child table.
# Note: during __init__ the columns are Column objects, not InstrumentedAttributes.
for col_name in ancestor_column_names:
if col_name == fk_field_name:
continue  # keep the FK/PK column (the child table's primary key, and FK to the parent)
if col_name == 'id':
continue  # id is replaced by the FK field
if col_name in child_own_fields:
continue  # keep fields defined by the subclass itself
# Check whether the class attribute is a Column or InstrumentedAttribute
if col_name in cls.__dict__:
attr = cls.__dict__[col_name]
# Both Column objects and InstrumentedAttributes must be removed
if isinstance(attr, (Column, InstrumentedAttribute)):
try:
delattr(cls, col_name)
except AttributeError:
pass
# Find the relationships defined by the subclass itself (absent from the parents).
# Inherited relationships come from the parent automatically; only new ones need setup.
child_own_relationships: set[str] = set()
for rel_name in cls.__sqlmodel_relationships__:
is_inherited = False
for base in bases:
if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__:
is_inherited = True
break
if not is_inherited:
child_own_relationships.add(rel_name)
# Only run relationship setup for the subclass's newly defined relationships
if child_own_relationships:
cls._setup_relationships(only_these=child_own_relationships)
# Force the call to DeclarativeMeta.__init__
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
else:
# Not joined table inheritance: single table inheritance or a normal Pydantic model
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def _setup_relationships(cls, only_these: set[str] | None = None) -> None:
"""
设置 SQLAlchemy 关系字段(从 SQLModel 源码复制)
Args:
only_these: 如果提供,只设置这些关系(用于 joined table inheritance 子类)
如果为 None设置所有关系默认行为
"""
from sqlalchemy.orm import relationship, Mapped
from sqlalchemy import inspect
from sqlmodel.main import get_relationship_to
from typing import get_origin
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
# If only_these was given, set up only those relationships
if only_these is not None and rel_name not in only_these:
continue
if rel_info.sa_relationship:
setattr(cls, rel_name, rel_info.sa_relationship)
continue
raw_ann = cls.__annotations__[rel_name]
origin: typing.Any = get_origin(raw_ann)
if origin is Mapped:
ann = raw_ann.__args__[0]
else:
ann = raw_ann
cls.__annotations__[rel_name] = Mapped[ann]
relationship_to = get_relationship_to(
name=rel_name, rel_info=rel_info, annotation=ann
)
rel_kwargs: dict[str, typing.Any] = {}
if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates
if rel_info.cascade_delete:
rel_kwargs["cascade"] = "all, delete-orphan"
if rel_info.passive_deletes:
rel_kwargs["passive_deletes"] = rel_info.passive_deletes
if rel_info.link_model:
ins = inspect(rel_info.link_model)
local_table = getattr(ins, "local_table")
if local_table is None:
raise RuntimeError(
f"Couldn't find secondary table for {rel_info.link_model}"
)
rel_kwargs["secondary"] = local_table
rel_args: list[typing.Any] = []
if rel_info.sa_relationship_args:
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
setattr(cls, rel_name, rel_value)
class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta):
"""此类必须和TableBase系列类搭配使用"""
model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)


@@ -1,7 +1,71 @@
from enum import StrEnum
from sqlmodel_ext import SQLModelBase
class ChromaticColor(StrEnum):
"""有彩色枚举17种 Tailwind 调色板颜色)"""
RED = "red"
ORANGE = "orange"
AMBER = "amber"
YELLOW = "yellow"
LIME = "lime"
GREEN = "green"
EMERALD = "emerald"
TEAL = "teal"
CYAN = "cyan"
SKY = "sky"
BLUE = "blue"
INDIGO = "indigo"
VIOLET = "violet"
PURPLE = "purple"
FUCHSIA = "fuchsia"
PINK = "pink"
ROSE = "rose"
class NeutralColor(StrEnum):
"""无彩色枚举5种灰色调"""
SLATE = "slate"
GRAY = "gray"
ZINC = "zinc"
NEUTRAL = "neutral"
STONE = "stone"
class ThemeColorsBase(SQLModelBase):
"""嵌套颜色 DTOAPI 请求/响应层使用"""
primary: ChromaticColor
"""主色调"""
secondary: ChromaticColor
"""辅助色"""
success: ChromaticColor
"""成功色"""
info: ChromaticColor
"""信息色"""
warning: ChromaticColor
"""警告色"""
error: ChromaticColor
"""错误色"""
neutral: NeutralColor
"""中性色"""
BUILTIN_DEFAULT_COLORS = ThemeColorsBase(
primary=ChromaticColor.GREEN,
secondary=ChromaticColor.BLUE,
success=ChromaticColor.GREEN,
info=ChromaticColor.BLUE,
warning=ChromaticColor.YELLOW,
error=ChromaticColor.RED,
neutral=NeutralColor.ZINC,
)


@@ -0,0 +1,135 @@
"""
用户自定义属性定义模型
允许用户定义类型化的自定义属性模板(如标签、评分、分类等),
实际值通过 ObjectMetadata KV 表存储键名格式custom:{property_definition_id}
支持的属性类型text, number, boolean, select, multi_select, rating, link
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import JSON
from sqlmodel import Field, Relationship
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str100
if TYPE_CHECKING:
from .user import User
# ==================== Enums ====================
class CustomPropertyType(StrEnum):
"""Custom property value type enum"""
TEXT = "text"
"""Text"""
NUMBER = "number"
"""Number"""
BOOLEAN = "boolean"
"""Boolean"""
SELECT = "select"
"""Single select"""
MULTI_SELECT = "multi_select"
"""Multi select"""
RATING = "rating"
"""Rating (1-5)"""
LINK = "link"
"""Link"""
# ==================== Base models ====================
class CustomPropertyDefinitionBase(SQLModelBase):
"""Base model for custom property definitions"""
name: Str100
"""Display name of the property"""
type: CustomPropertyType
"""Property value type"""
icon: Str100 | None = None
"""Icon identifier (iconify name)"""
options: list[str] | None = Field(default=None, sa_type=JSON)
"""List of allowed values (select/multi_select types only)"""
default_value: str | None = Field(default=None, max_length=500)
"""Default value"""
# ==================== Database models ====================
class CustomPropertyDefinition(CustomPropertyDefinitionBase, UUIDTableBaseMixin):
"""
User-defined custom property definition.
Each user manages their own property templates independently.
Actual property values are stored in the ObjectMetadata table under keys of the form custom:{id}.
"""
owner_id: UUID = Field(
foreign_key="user.id",
ondelete="CASCADE",
index=True,
)
"""Owner user UUID"""
sort_order: int = 0
"""Sort order"""
# Relationships
owner: "User" = Relationship()
"""Owner"""
# ==================== DTO models ====================
class CustomPropertyCreateRequest(SQLModelBase):
"""Create custom property request DTO"""
name: Str100
"""Display name of the property"""
type: CustomPropertyType
"""Property value type"""
icon: str | None = None
"""Icon identifier"""
options: list[str] | None = None
"""Allowed options (select/multi_select types only)"""
default_value: str | None = None
"""Default value"""
class CustomPropertyUpdateRequest(SQLModelBase):
"""Update custom property request DTO (all fields optional)"""
name: str | None = None
"""Display name of the property"""
icon: str | None = None
"""Icon identifier"""
options: list[str] | None = None
"""Allowed options"""
default_value: str | None = None
"""Default value"""
sort_order: int | None = None
"""Sort order"""
class CustomPropertyResponse(CustomPropertyDefinitionBase):
"""Custom property response DTO"""
id: UUID
"""Property definition UUID"""
sort_order: int
"""Sort order"""


@@ -1,33 +0,0 @@
from sqlmodel import SQLModel
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from utils.conf import appmeta
from sqlalchemy.orm import sessionmaker
from typing import AsyncGenerator
ASYNC_DATABASE_URL = appmeta.database_url
engine: AsyncEngine = create_async_engine(
ASYNC_DATABASE_URL,
echo=appmeta.debug,
connect_args={
"check_same_thread": False
} if ASYNC_DATABASE_URL.startswith("sqlite") else {},
future=True,
# pool_size=POOL_SIZE,
# max_overflow=64,
)
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with _async_session_factory() as session:
yield session
async def init_db(
url: str = ASYNC_DATABASE_URL
):
"""创建数据库结构"""
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)


@@ -4,8 +4,7 @@ from uuid import UUID
 from sqlmodel import Field, Relationship, UniqueConstraint, Index
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin, TableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, TableBaseMixin, Str255
 if TYPE_CHECKING:
 from .user import User
@@ -142,7 +141,7 @@ class Download(DownloadBase, UUIDTableBaseMixin):
 speed: int = Field(default=0)
 """Download speed (bytes/s)"""
-parent: str | None = Field(default=None, max_length=255)
+parent: Str255 | None = None
 """Parent task identifier"""
 error: str | None = Field(default=None)

sqlmodels/file_app.py Normal file (435 lines)

@@ -0,0 +1,435 @@
"""
文件查看器应用模块
提供文件预览应用选择器系统的数据模型和 DTO。
类似 Android 的"使用什么应用打开"机制:
- 管理员注册应用(内置/iframe/WOPI
- 用户按扩展名查询可用查看器
- 用户可设置"始终使用"偏好
- 支持用户组级别的访问控制
架构:
FileApp (应用注册表)
├── FileAppExtension (扩展名关联)
├── FileAppGroupLink (用户组访问控制)
└── UserFileAppDefault (用户默认偏好)
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str100, Str255, Text1024
if TYPE_CHECKING:
from .group import Group
# ==================== Enums ====================
class FileAppType(StrEnum):
"""File app type"""
BUILTIN = "builtin"
"""Built-in frontend viewer (e.g. pdf.js, Monaco)"""
IFRAME = "iframe"
"""Third-party service embedded in an iframe"""
WOPI = "wopi"
"""WOPI protocol (OnlyOffice / Collabora)"""
# ==================== Link tables ====================
class FileAppGroupLink(SQLModelBase, table=True):
"""App-to-group access control link table"""
app_id: UUID = Field(foreign_key="fileapp.id", primary_key=True, ondelete="CASCADE")
"""Associated app UUID"""
group_id: UUID = Field(foreign_key="group.id", primary_key=True, ondelete="CASCADE")
"""Associated group UUID"""
# ==================== DTO models ====================
class FileAppSummary(SQLModelBase):
"""Viewer list item DTO (used in the selector popup)"""
id: UUID
"""App UUID"""
name: str
"""App name"""
app_key: str
"""Unique app key"""
type: FileAppType
"""App type"""
icon: str | None = None
"""Icon name/URL"""
description: str | None = None
"""App description"""
iframe_url_template: str | None = None
"""iframe URL template"""
wopi_editor_url_template: str | None = None
"""WOPI editor URL template"""
class FileViewersResponse(SQLModelBase):
"""Viewer query response DTO"""
viewers: list[FileAppSummary] = []
"""Available viewers (sorted by priority)"""
default_viewer_id: UUID | None = None
"""User default viewer UUID, if an 'always use' preference is set"""
class SetDefaultViewerRequest(SQLModelBase):
"""Set default viewer request DTO"""
extension: str = Field(max_length=20)
"""File extension (lowercase, without the dot)"""
app_id: UUID
"""App UUID"""
class UserFileAppDefaultResponse(SQLModelBase):
"""User default viewer response DTO"""
id: UUID
"""Record UUID"""
extension: str
"""Extension"""
app: FileAppSummary
"""Associated app summary"""
class FileAppCreateRequest(SQLModelBase):
"""Admin create app request DTO"""
name: Str100
"""App name"""
app_key: str = Field(max_length=50)
"""Unique app key"""
type: FileAppType
"""App type"""
icon: Str255 | None = None
"""Icon name/URL"""
description: str | None = Field(default=None, max_length=500)
"""App description"""
is_enabled: bool = True
"""Whether the app is enabled"""
is_restricted: bool = False
"""Whether access is restricted to specific groups"""
iframe_url_template: Text1024 | None = None
"""iframe URL template"""
wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI discovery endpoint URL"""
wopi_editor_url_template: Text1024 | None = None
"""WOPI editor URL template"""
extensions: list[str] = []
"""Associated extensions"""
allowed_group_ids: list[UUID] = []
"""UUIDs of groups allowed access"""
class FileAppUpdateRequest(SQLModelBase):
"""Admin update app request DTO (all fields optional)"""
name: Str100 | None = None
"""App name"""
app_key: str | None = Field(default=None, max_length=50)
"""Unique app key"""
type: FileAppType | None = None
"""App type"""
icon: Str255 | None = None
"""Icon name/URL"""
description: str | None = Field(default=None, max_length=500)
"""App description"""
is_enabled: bool | None = None
"""Whether the app is enabled"""
is_restricted: bool | None = None
"""Whether access is restricted to specific groups"""
iframe_url_template: Text1024 | None = None
"""iframe URL template"""
wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI discovery endpoint URL"""
wopi_editor_url_template: Text1024 | None = None
"""WOPI editor URL template"""
class FileAppResponse(SQLModelBase):
"""Admin app detail response DTO"""
id: UUID
"""App UUID"""
name: str
"""App name"""
app_key: str
"""Unique app key"""
type: FileAppType
"""App type"""
icon: str | None = None
"""Icon name/URL"""
description: str | None = None
"""App description"""
is_enabled: bool = True
"""Whether the app is enabled"""
is_restricted: bool = False
"""Whether access is restricted to specific groups"""
iframe_url_template: str | None = None
"""iframe URL template"""
wopi_discovery_url: str | None = None
"""WOPI discovery endpoint URL"""
wopi_editor_url_template: str | None = None
"""WOPI editor URL template"""
extensions: list[str] = []
"""Associated extensions"""
allowed_group_ids: list[UUID] = []
"""UUIDs of groups allowed access"""
@classmethod
def from_app(
cls,
app: "FileApp",
extensions: list["FileAppExtension"],
group_links: list[FileAppGroupLink],
) -> "FileAppResponse":
"""从 ORM 对象构建 DTO"""
return cls(
id=app.id,
name=app.name,
app_key=app.app_key,
type=app.type,
icon=app.icon,
description=app.description,
is_enabled=app.is_enabled,
is_restricted=app.is_restricted,
iframe_url_template=app.iframe_url_template,
wopi_discovery_url=app.wopi_discovery_url,
wopi_editor_url_template=app.wopi_editor_url_template,
extensions=[ext.extension for ext in extensions],
allowed_group_ids=[link.group_id for link in group_links],
)
class FileAppListResponse(SQLModelBase):
"""Admin app list response DTO"""
apps: list[FileAppResponse] = []
"""App list"""
total: int = 0
"""Total count"""
class ExtensionUpdateRequest(SQLModelBase):
"""Full-replacement request DTO for extensions"""
extensions: list[str]
"""Extensions (lowercase, without the dot)"""
class GroupAccessUpdateRequest(SQLModelBase):
"""Full-replacement request DTO for group access"""
group_ids: list[UUID]
"""UUIDs of groups allowed access"""
class WopiSessionResponse(SQLModelBase):
"""WOPI session response DTO"""
wopi_src: str
"""WOPI source URL"""
access_token: str
"""WOPI access token"""
access_token_ttl: int
"""Token expiry timestamp in milliseconds (required by the WOPI spec)"""
editor_url: str
"""Full editor URL"""
class WopiDiscoveredExtension(SQLModelBase):
"""A single extension discovered via WOPI Discovery"""
extension: str
"""File extension"""
action_url: str
"""Processed action URL template"""
class WopiDiscoveryResponse(SQLModelBase):
"""WOPI Discovery result response DTO"""
discovered_extensions: list[WopiDiscoveredExtension] = []
"""Discovered extensions and their URL templates"""
app_names: list[str] = []
"""App names reported by the WOPI server (e.g. Writer, Calc, Impress)"""
applied_count: int = 0
"""Number of entries applied to FileAppExtension"""
# ==================== Database models ====================
class FileApp(SQLModelBase, UUIDTableBaseMixin):
"""File viewer app registry"""
name: Str100
"""App name"""
app_key: str = Field(max_length=50, unique=True, index=True)
"""Unique app key, used for frontend routing"""
type: FileAppType
"""App type"""
icon: Str255 | None = None
"""Icon name/URL"""
description: str | None = Field(default=None, max_length=500)
"""App description"""
is_enabled: bool = True
"""Whether the app is enabled"""
is_restricted: bool = False
"""Whether access is restricted to specific groups"""
iframe_url_template: Text1024 | None = None
"""iframe URL template; supports the {file_url} placeholder"""
wopi_discovery_url: str | None = Field(default=None, max_length=512)
"""WOPI client discovery endpoint URL"""
wopi_editor_url_template: Text1024 | None = None
"""WOPI editor URL template; supports {wopi_src} {access_token} {access_token_ttl}"""
# Relationships
extensions: list["FileAppExtension"] = Relationship(
back_populates="app",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
user_defaults: list["UserFileAppDefault"] = Relationship(
back_populates="app",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
allowed_groups: list["Group"] = Relationship(
link_model=FileAppGroupLink,
)
def to_summary(self) -> FileAppSummary:
"""转换为摘要 DTO"""
return FileAppSummary(
id=self.id,
name=self.name,
app_key=self.app_key,
type=self.type,
icon=self.icon,
description=self.description,
iframe_url_template=self.iframe_url_template,
wopi_editor_url_template=self.wopi_editor_url_template,
)
class FileAppExtension(SQLModelBase, TableBaseMixin):
"""Extension association table"""
__table_args__ = (
UniqueConstraint("app_id", "extension", name="uq_fileappextension_app_extension"),
)
app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
"""Associated app UUID"""
extension: str = Field(max_length=20, index=True)
"""Extension (lowercase, without the dot)"""
priority: int = Field(default=0, ge=0)
"""Sort priority (lower comes first)"""
wopi_action_url: str | None = Field(default=None, max_length=2048)
"""WOPI action URL template (auto-filled by Discovery); supports {wopi_src} {access_token} {access_token_ttl}"""
# Relationships
app: FileApp = Relationship(back_populates="extensions")
class UserFileAppDefault(SQLModelBase, UUIDTableBaseMixin):
"""User 'always use' preference"""
__table_args__ = (
UniqueConstraint("user_id", "extension", name="uq_userfileappdefault_user_extension"),
)
user_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE")
"""User UUID"""
extension: str = Field(max_length=20)
"""Extension (lowercase, without the dot)"""
app_id: UUID = Field(foreign_key="fileapp.id", index=True, ondelete="CASCADE")
"""Associated app UUID"""
# Relationships
app: FileApp = Relationship(back_populates="user_defaults")
def to_response(self) -> UserFileAppDefaultResponse:
"""Convert to a response DTO (requires the app relationship to be preloaded)"""
return UserFileAppDefaultResponse(
id=self.id,
extension=self.extension,
app=self.app.to_summary(),
)
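The `{file_url}` / `{wopi_src}` / `{access_token}` / `{access_token_ttl}` placeholders in the templates above imply simple string substitution. A hedged stdlib sketch (the project's actual rendering logic may differ) that URL-encodes each value and leaves unknown placeholders intact:

```python
from urllib.parse import quote

def render_template(template: str, values: dict[str, str]) -> str:
    """Substitute {name} placeholders, URL-encoding each value.

    Replace-based rendering is used so that unknown placeholders are
    left intact (str.format would raise KeyError on them).
    """
    out = template
    for name, value in values.items():
        out = out.replace("{" + name + "}", quote(value, safe=""))
    return out

editor_url = render_template(
    "https://office.example.com/edit?WOPISrc={wopi_src}&token={access_token}",
    {"wopi_src": "https://api.example.com/wopi/files/abc", "access_token": "tok123"},
)
print(editor_url)
```

The hostnames and token value above are illustrative only.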


@@ -2,10 +2,10 @@
 from typing import TYPE_CHECKING
 from uuid import UUID
+from sqlalchemy import BigInteger
 from sqlmodel import Field, Relationship, text
-from .base import SQLModelBase
-from .mixin import TableBaseMixin, UUIDTableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin, UUIDTableBaseMixin, Str255
 if TYPE_CHECKING:
 from .user import User
@@ -67,7 +67,7 @@ class GroupAllOptionsBase(GroupOptionsBase):
 class GroupCreateRequest(GroupAllOptionsBase):
 """Create group request DTO"""
-name: str = Field(max_length=255)
+name: Str255
 """Group name"""
 max_storage: int = Field(default=0, ge=0)
@@ -92,7 +92,7 @@ class GroupCreateRequest(GroupAllOptionsBase):
 class GroupUpdateRequest(SQLModelBase):
 """Update group request DTO (all fields optional)"""
-name: str | None = Field(default=None, max_length=255)
+name: Str255 | None = None
 """Group name"""
 max_storage: int | None = Field(default=None, ge=0)
@@ -258,10 +258,10 @@ class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
 class Group(GroupBase, UUIDTableBaseMixin):
 """Group model"""
-name: str = Field(max_length=255, unique=True)
+name: Str255 = Field(unique=True)
 """Group name"""
-max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
+max_storage: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
 """Maximum storage (bytes)"""
 share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})

File diff suppressed because one or more lines are too long


@@ -1,543 +0,0 @@
# SQLModel Mixin Module
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
## Module Overview
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
- **JWT Authentication**: Token generation and validation
- **Pagination & Sorting**: Standardized table view parameters
- **Response DTOs**: Consistent id/timestamp fields for API responses
## Module Structure
```
sqlmodels/mixin/
├── __init__.py # Module exports
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
└── jwt/ # JWT authentication
├── __init__.py
├── key.py # JWTKey database model
├── payload.py # JWTPayloadBase
├── manager.py # JWTManager singleton
├── auth.py # JWTAuthMixin
├── exceptions.py # JWT-related exceptions
└── responses.py # TokenResponse DTO
```
## Dependency Hierarchy
The module has a strict import order to avoid circular dependencies:
1. **polymorphic.py** - Only depends on `SQLModelBase`
2. **table.py** - Depends on `polymorphic.py`
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
4. **info_response.py** - Only depends on `SQLModelBase`
## Core Components
### 1. TableBaseMixin
Base mixin for database table models with integer primary keys.
**Features:**
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
- Automatic timestamp management (`created_at`, `updated_at`)
- Async relationship loading support (via `AsyncAttrs`)
- Pagination and sorting via `TableViewRequest`
- Polymorphic subclass loading support
**Fields:**
- `id: int | None` - Integer primary key (auto-increment)
- `created_at: datetime` - Record creation timestamp
- `updated_at: datetime` - Record update timestamp (auto-updated)
**Usage:**
```python
from sqlmodels.mixin import TableBaseMixin
from sqlmodels.base import SQLModelBase
class User(SQLModelBase, TableBaseMixin, table=True):
name: str
email: str
"""User email"""
# CRUD operations
async def example(session: AsyncSession):
# Add
user = User(name="Alice", email="alice@example.com")
user = await user.save(session)
# Get
user = await User.get(session, User.id == 1)
# Update
update_data = UserUpdateRequest(name="Alice Smith")
user = await user.update(session, update_data)
# Delete
await User.delete(session, user)
# Count
count = await User.count(session, User.is_active == True)
```
**Important Notes:**
- `save()` and `update()` return refreshed instances - **always use the return value**:
```python
# ✅ Correct
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device
```
### 2. UUIDTableBaseMixin
Extends `TableBaseMixin` with UUID primary keys instead of integers.
**Differences from TableBaseMixin:**
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
- `get_exist_one()` accepts `UUID` instead of `int`
**Usage:**
```python
from sqlmodels.mixin import UUIDTableBaseMixin
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
name: str
description: str | None = None
"""Character description"""
```
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
### 3. TableViewRequest
Standardized pagination and sorting parameters for LIST endpoints.
**Fields:**
- `offset: int | None` - Skip first N records (default: 0)
- `limit: int | None` - Return max N records (default: 50, max: 100)
- `desc: bool | None` - Sort descending (default: True)
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
**Usage with TableBaseMixin.get():**
```python
from dependencies import TableViewRequestDep
@router.get("/list")
async def list_characters(
session: SessionDep,
table_view: TableViewRequestDep
) -> list[Character]:
"""List characters with pagination and sorting"""
return await Character.get(
session,
fetch_mode="all",
table_view=table_view # Automatically handles pagination and sorting
)
```
**Manual usage:**
```python
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
```
**Backward Compatibility:**
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
### 4. PolymorphicBaseMixin
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
**Automatic Configuration:**
- Defines `_polymorphic_name: str` field (indexed)
- Sets `polymorphic_on='_polymorphic_name'`
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
**Methods:**
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
**Usage:**
```python
from abc import ABC, abstractmethod
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
"""Abstract base class for all tools"""
name: str
description: str
"""Tool description"""
@abstractmethod
async def execute(self, params: dict) -> dict:
"""Execute the tool"""
pass
```
**Why Single Underscore Prefix?**
- SQLAlchemy maps single-underscore fields to database columns
- Pydantic treats them as private (excluded from serialization)
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
### 5. create_subclass_id_mixin()
Factory function to create ID mixins for subclasses in joined table inheritance.
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
**Signature:**
```python
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
"""
Args:
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
Returns:
A mixin class containing id field (foreign key + primary key)
"""
```
**Usage:**
```python
from sqlmodels.mixin import create_subclass_id_mixin
# Create mixin for ASR subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
"""FunASR implementation"""
pass
```
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
### 6. AutoPolymorphicIdentityMixin
Automatically generates `polymorphic_identity` based on class name.
**Naming Convention:**
- Format: `{parent_identity}.{classname_lowercase}`
- If no parent identity exists, uses `{classname_lowercase}`
**Usage:**
```python
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
"""Base class for function-type tools"""
pass
# polymorphic_identity = 'function'
class GetWeatherFunction(Function, table=True):
"""Weather query function"""
pass
# polymorphic_identity = 'function.getweatherfunction'
```
**Manual Override:**
```python
class CustomTool(
Tool,
AutoPolymorphicIdentityMixin,
polymorphic_identity='custom_name', # Override auto-generated name
table=True
):
pass
```
### 7. JWTAuthMixin
Provides JWT token generation and validation for entity classes (User, Client).
**Methods:**
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
**Requirements:**
Subclasses must define:
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
**Usage:**
```python
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
JWTPayload = UserJWTPayload # Define payload model
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
email: str
is_admin: bool = False
is_active: bool = True
"""User active status"""
# Generate token
async def login(session: AsyncSession, user: User) -> str:
token = await user.issue_jwt(session)
return token
# Validate token
async def verify(session: AsyncSession, token: str) -> User:
user = await User.from_jwt(session, token)
return user
```
### 8. Response DTO Mixins
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
**Available Mixins:**
- `IntIdInfoMixin` - Integer ID field
- `UUIDIdInfoMixin` - UUID ID field
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
**Usage:**
```python
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
"""Character response DTO with id and timestamps"""
pass # Inherits id, created_at, updated_at from mixin
```
## Complete Joined Table Inheritance Example
Here's a complete example demonstrating polymorphic inheritance:
```python
from abc import ABC, abstractmethod
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import (
UUIDTableBaseMixin,
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
)
# 1. Define Base class (fields only, no table)
class ASRBase(SQLModelBase):
name: str
"""Configuration name"""
base_url: str
"""Service URL"""
# 2. Define abstract parent class (with table)
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
"""Abstract base class for ASR configurations"""
# PolymorphicBaseMixin automatically provides:
# - _polymorphic_name field
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True (when ABC with abstract methods)
@abstractmethod
async def transcribe(self, pcm_data: bytes) -> str:
"""Transcribe audio to text"""
pass
# 3. Create ID Mixin for second-level subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. Create second-level abstract class (if needed)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
"""FunASR abstract base (may have multiple implementations)"""
pass
# polymorphic_identity = 'funasr'
# 5. Create concrete implementation classes
class FunASRLocal(FunASR, table=True):
"""FunASR local deployment"""
# polymorphic_identity = 'funasr.funasrlocal'
async def transcribe(self, pcm_data: bytes) -> str:
# Implementation...
return "transcribed text"
# 6. Get all concrete subclasses (for selectin_polymorphic)
concrete_asrs = ASR.get_concrete_subclasses()
# Returns: [FunASRLocal, ...]
```
## Import Guidelines
**Standard Import:**
```python
from sqlmodels.mixin import (
TableBaseMixin,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
TableViewRequest,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
JWTAuthMixin,
UUIDIdDatetimeInfoMixin,
now,
now_date,
)
```
**Backward Compatibility:**
Some exports are also available from `sqlmodels.base` for backward compatibility:
```python
# Legacy import path (still works)
from sqlmodels.base import UUIDTableBase, TableViewRequest
# Recommended new import path
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
```
## Best Practices
### 1. Mixin Order Matters
**Correct Order:**
```python
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
pass
```
**Wrong Order:**
```python
# ❌ ID Mixin not first - won't override parent's id field
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
pass
```
### 2. Always Use Return Values from save() and update()
```python
# ✅ Correct - use returned instance
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device # AttributeError when accessing fields
```
### 3. Prefer table_view Over Manual Pagination
```python
# ✅ Recommended - consistent across all endpoints
characters = await Character.get(
session,
fetch_mode="all",
table_view=table_view
)
# ⚠️ Works but not recommended - manual parameter management
characters = await Character.get(
session,
fetch_mode="all",
offset=0,
limit=20,
order_by=[desc(Character.created_at)]
)
```
### 4. Polymorphic Loading for Many Subclasses
```python
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
)
# For fewer subclasses, specify the list explicitly
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
)
```
### 5. Response DTOs Should Inherit Base Classes
```python
# ✅ Correct - inherits from CharacterBase
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
pass
# ❌ Wrong - doesn't inherit from CharacterBase
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
name: str # Duplicated field definition
description: str | None = None
```
**Reason:** Inheriting from Base classes ensures:
- Type checking via `isinstance(obj, XxxBase)`
- Consistency across related DTOs
- Future field additions automatically propagate
### 6. Use Specific Types, Not Containers
```python
# ✅ Correct - specific DTO for config updates
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
weather_api_key: str | None = None
default_location: str | None = None
"""Default location"""
# ❌ Wrong - lose type safety
class ToolUpdateRequest(SQLModelBase):
config: dict[str, Any] # No field validation
```
## Type Variables
```python
from sqlmodels.mixin import T, M
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
M = TypeVar("M", bound="SQLModel") # For update() method
```
## Utility Functions
```python
from sqlmodels.mixin import now, now_date
# Lambda functions for default factories
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
```
## Related Modules
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
- **sqlmodels.user** - User model with JWT authentication
- **sqlmodels.client** - Client model with JWT authentication
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
## Additional Resources
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
- `CLAUDE.md` - Project coding standards and design philosophy
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)


@@ -1,62 +0,0 @@
"""
SQLModel Mixin模块
提供各种Mixin类供SQLModel实体使用。
包含:
- polymorphic: 联表继承工具create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin
- optimistic_lock: 乐观锁OptimisticLockMixin, OptimisticLockError
- table: 表基类TableBaseMixin, UUIDTableBaseMixin
- table: 查询参数类TimeFilterRequest, PaginationRequest, TableViewRequest
- relation_preload: 关系预加载RelationPreloadMixin, requires_relations
- jwt/: JWT认证相关JWTAuthMixin, JWTManager, JWTKey等- 需要时直接从 .jwt 导入
- info_response: InfoResponse DTO的id/时间戳Mixin
导入顺序很重要,避免循环导入:
1. polymorphic只依赖 SQLModelBase
2. optimistic_lock只依赖 SQLAlchemy
3. table依赖 polymorphic 和 optimistic_lock
4. relation_preload只依赖 SQLModelBase
注意jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
"""
# polymorphic must be imported first
from .polymorphic import (
AutoPolymorphicIdentityMixin,
PolymorphicBaseMixin,
create_subclass_id_mixin,
register_sti_column_properties_for_all_subclasses,
register_sti_columns_for_all_subclasses,
)
# optimistic_lock depends only on SQLAlchemy; must come before table
from .optimistic_lock import (
OptimisticLockError,
OptimisticLockMixin,
)
# table depends on polymorphic and optimistic_lock
from .table import (
ListResponse,
PaginationRequest,
T,
TableBaseMixin,
TableViewRequest,
TimeFilterRequest,
UUIDTableBaseMixin,
now,
now_date,
)
# relation_preload depends only on SQLModelBase
from .relation_preload import (
RelationPreloadMixin,
requires_relations,
)
# jwt is not imported here to avoid a cycle: jwt/manager.py → ServerConfig → mixin → jwt
# import directly from sqlmodels.mixin.jwt when needed
from .info_response import (
DatetimeInfoMixin,
IntIdDatetimeInfoMixin,
IntIdInfoMixin,
UUIDIdDatetimeInfoMixin,
UUIDIdInfoMixin,
)


@@ -1,46 +0,0 @@
"""
InfoResponse DTO Mixin模块
提供用于InfoResponse类型DTO的Mixin统一定义id/created_at/updated_at字段。
设计说明:
- 这些Mixin用于**响应DTO**,不是数据库表
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
- TableBase中的id=None和default_factory=now是正确的入库前为None数据库生成
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
"""
from datetime import datetime
from uuid import UUID
from sqlmodels.base import SQLModelBase
class IntIdInfoMixin(SQLModelBase):
"""Integer id response mixin, for InfoResponse DTOs"""
id: int
"""Record ID"""
class UUIDIdInfoMixin(SQLModelBase):
"""UUID id response mixin, for InfoResponse DTOs"""
id: UUID
"""Record ID"""
class DatetimeInfoMixin(SQLModelBase):
"""Timestamp response mixin, for InfoResponse DTOs"""
created_at: datetime
"""Creation time"""
updated_at: datetime
"""Update time"""
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
"""Integer id + timestamp response mixin"""
pass
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
"""UUID id + timestamp response mixin"""
pass


@@ -1,90 +0,0 @@
"""
乐观锁 Mixin
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
乐观锁适用场景:
- 涉及"状态转换"的表(如:待支付 -> 已支付)
- 涉及"数值变动"的表(如:余额、库存)
不适用场景:
- 日志表、纯插入表、低价值统计表
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
使用示例:
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
status: OrderStatusEnum
amount: Decimal
# save/update 时自动检查版本号
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
try:
order = await order.save(session)
except OptimisticLockError as e:
# 处理冲突:重新查询并重试,或报错给用户
l.warning(f"乐观锁冲突: {e}")
"""
from typing import ClassVar
from sqlalchemy.orm.exc import StaleDataError
class OptimisticLockError(Exception):
"""
乐观锁冲突异常
当 save/update 操作检测到版本号不匹配时抛出。
这意味着在读取和写入之间,其他事务已经修改了该记录。
Attributes:
model_class: 发生冲突的模型类名
record_id: 记录 ID如果可用
expected_version: 期望的版本号(如果可用)
original_error: 原始的 StaleDataError
"""
def __init__(
self,
message: str,
model_class: str | None = None,
record_id: str | None = None,
expected_version: int | None = None,
original_error: StaleDataError | None = None,
):
super().__init__(message)
self.model_class = model_class
self.record_id = record_id
self.expected_version = expected_version
self.original_error = original_error
class OptimisticLockMixin:
"""
乐观锁 Mixin
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
session.commit() 会抛出 StaleDataError被 save/update 方法捕获并转换为 OptimisticLockError。
原理:
1. 每条记录有一个 version 字段,初始值为 0
2. 每次 UPDATE 时SQLAlchemy 生成的 SQL 类似:
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
3. 如果 WHERE 条件不匹配version 已被其他事务修改),
UPDATE 影响 0 行SQLAlchemy 抛出 StaleDataError
继承顺序:
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
...
配套重试:
如果加了乐观锁,业务层需要处理 OptimisticLockError
- 报错给用户:"数据已被修改,请刷新后重试"
- 自动重试:重新查询最新数据并再次尝试
"""
_has_optimistic_lock: ClassVar[bool] = True
"""标记此类启用了乐观锁"""
version: int = 0
"""乐观锁版本号,每次更新自动递增"""

View File

@@ -1,710 +0,0 @@
"""
联表继承Joined Table Inheritance的通用工具
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
Usage Example:
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import (
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin
)
# 1. 定义Base类只有字段无表
class ASRBase(SQLModelBase):
name: str
\"\"\"配置名称\"\"\"
base_url: str
\"\"\"服务地址\"\"\"
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
class ASR(
ASRBase,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC
):
\"\"\"ASR配置的抽象基类\"\"\"
# PolymorphicBaseMixin 自动提供:
# - _polymorphic_name 字段
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True当有抽象方法时
# 3. 为第二层子类创建ID Mixin
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. 创建第二层抽象类(如果需要)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
\"\"\"FunASR的抽象基类可能有多个实现\"\"\"
pass
# 5. 创建具体实现类
class FunASRLocal(FunASR, table=True):
\"\"\"FunASR本地部署版本\"\"\"
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
pass
# 6. 获取所有具体子类(用于 selectin_polymorphic
concrete_asrs = ASR.get_concrete_subclasses()
# 返回 [FunASRLocal, ...]
"""
import uuid
from abc import ABC
from uuid import UUID
from loguru import logger as l
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from sqlalchemy import Column, String, inspect
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlmodel import Field
from sqlmodel.main import get_column_from_field
from sqlmodels.base.sqlmodel_base import SQLModelBase
# Queue of STI subclasses awaiting deferred column registration.
# After all models are loaded, call register_sti_columns_for_all_subclasses() to process it.
_sti_subclasses_to_register: list[type] = []
def register_sti_columns_for_all_subclasses() -> None:
"""
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
此函数应在 configure_mappers() 之前调用。
将 STI 子类的字段添加到父表的 metadata 中。
同时修复被 Column 对象污染的 model_fields。
"""
for cls in _sti_subclasses_to_register:
try:
cls._register_sti_columns()
except Exception as e:
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
# 修复被 Column 对象污染的 model_fields
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
try:
_fix_polluted_model_fields(cls)
except Exception as e:
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
def register_sti_column_properties_for_all_subclasses() -> None:
"""
为所有已注册的 STI 子类添加列属性到 mapper第二阶段
此函数应在 configure_mappers() 之后调用。
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
"""
for cls in _sti_subclasses_to_register:
try:
cls._register_sti_column_properties()
except Exception as e:
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
# 清空队列
_sti_subclasses_to_register.clear()
def _fix_polluted_model_fields(cls: type) -> None:
"""
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
当 SQLModel 类继承有表的父类时SQLAlchemy 会在类上创建 InstrumentedAttribute
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
时错误地使用这些 SQLAlchemy 对象作为默认值。
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
:param cls: 要修复的类
"""
if not hasattr(cls, 'model_fields'):
return
def find_original_field_info(field_name: str) -> FieldInfo | None:
"""从 MRO 中查找字段的原始定义(未被污染的)"""
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, 'model_fields') and field_name in base.model_fields:
field_info = base.model_fields[field_name]
# 跳过被 InstrumentedAttribute 或 Column 污染的
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
return field_info
return None
for field_name, current_field in cls.model_fields.items():
# 检查是否被污染default 是 InstrumentedAttribute 或 Column
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
# 被继承到有表的子类时SQLModel 会创建一个 Column 对象替换原始默认值
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
continue # 未被污染,跳过
# 从父类查找原始定义
original = find_original_field_info(field_name)
if original is None:
continue # 找不到原始定义,跳过
# 根据原始定义的 default/default_factory 来修复
if original.default_factory:
# 有 default_factory如 uuid.uuid4, now
new_field = FieldInfo(
default_factory=original.default_factory,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
elif original.default is not PydanticUndefined:
# 有明确的 default 值(如 None, 0, True且不是 PydanticUndefined
# PydanticUndefined 表示字段没有默认值(必填)
new_field = FieldInfo(
default=original.default,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
else:
continue # 既没有 default_factory 也没有有效的 default跳过
# 复制 SQLModel 特有的属性
if hasattr(current_field, 'foreign_key'):
new_field.foreign_key = current_field.foreign_key
if hasattr(current_field, 'primary_key'):
new_field.primary_key = current_field.primary_key
cls.model_fields[field_name] = new_field
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
"""
动态创建SubclassIdMixin类
在联表继承中,子类需要一个外键指向父表的主键。
此函数生成一个Mixin类提供这个外键字段并自动生成UUID。
Args:
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function'
Returns:
一个Mixin类包含id字段外键 + 主键 + default_factory=uuid.uuid4
Example:
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
... pass
Note:
- 生成的Mixin应该放在继承列表的第一位确保通过MRO覆盖UUIDTableBaseMixin的id
- 生成的类名为 {ParentTableName}SubclassIdMixinPascalCase
- 本项目所有联表继承均使用UUID主键UUIDTableBaseMixin
"""
if not parent_table_name:
raise ValueError("parent_table_name 不能为空")
# 转换为PascalCase作为类名
class_name_parts = parent_table_name.split('_')
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
# 使用闭包捕获parent_table_name
_parent_table_name = parent_table_name
# 创建带有__init_subclass__的mixin类用于在子类定义后修复model_fields
class SubclassIdMixin(SQLModelBase):
# 定义id字段
id: UUID = Field(
default_factory=uuid.uuid4,
foreign_key=f'{_parent_table_name}.id',
primary_key=True,
)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复联表继承中子类字段的 default_factory 丢失问题。
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
"""
super().__pydantic_init_subclass__(**kwargs)
_fix_polluted_model_fields(cls)
# 设置类名和文档
SubclassIdMixin.__name__ = class_name
SubclassIdMixin.__qualname__ = class_name
SubclassIdMixin.__doc__ = f"""
{parent_table_name}子类的ID Mixin
用于{parent_table_name}的子类,提供外键指向父表。
通过MRO确保此id字段覆盖继承的id字段。
"""
return SubclassIdMixin
class AutoPolymorphicIdentityMixin:
"""
自动生成polymorphic_identity的Mixin并支持STI子类列注册
使用此Mixin的类会自动根据类名生成polymorphic_identity。
格式:{parent_polymorphic_identity}.{classname_lowercase}
如果没有父类的polymorphic_identity则直接使用类名小写。
**重要:数据库迁移注意事项**
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
例如,对于以下继承链::
LLM (polymorphic_on='_polymorphic_name')
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
迁移脚本中设置 _polymorphic_name 时::
# ❌ 错误:缺少父类前缀
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
# ✅ 正确:包含完整的继承链前缀
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
**STI单表继承支持**
当子类与父类共用同一张表STI模式此Mixin会自动将子类的新字段
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
注册到父表的问题。
Example (JTI):
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
... __polymorphic_name: str
...
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
... pass
... # polymorphic_identity 会自动设置为 'function'
...
>>> class CodeInterpreterFunction(Function, table=True):
... pass
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
Example (STI):
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
... user_id: UUID
...
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
... # polymorphic_identity 会自动设置为 'pendingfile'
Note:
- 如果手动在__mapper_args__中指定了polymorphic_identity会被保留
- 此Mixin应该在继承列表中靠后的位置在表基类之前
- STI模式下新字段会在类定义时自动添加到父表的metadata中
"""
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
"""
子类化钩子自动生成polymorphic_identity并处理STI列注册
Args:
polymorphic_identity: 如果手动指定,则使用指定的值
**kwargs: 其他SQLModel参数如table=True, polymorphic_abstract=True
"""
super().__init_subclass__(**kwargs)
# 如果手动指定了polymorphic_identity使用指定的值
if polymorphic_identity is not None:
identity = polymorphic_identity
else:
# 自动生成polymorphic_identity
class_name = cls.__name__.lower()
# 尝试从父类获取polymorphic_identity作为前缀
parent_identity = None
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
parent_identity = base.__mapper_args__.get('polymorphic_identity')
if parent_identity:
break
# 构建identity
if parent_identity:
identity = f'{parent_identity}.{class_name}'
else:
identity = class_name
# 设置到__mapper_args__
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 只在尚未设置polymorphic_identity时设置
if 'polymorphic_identity' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_identity'] = identity
# 注册 STI 子类列的延迟执行
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
_sti_subclasses_to_register.append(cls)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复 STI 继承中子类字段被 Column 对象污染的问题。
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
"""
super().__pydantic_init_subclass__(**kwargs)
_fix_polluted_model_fields(cls)
@classmethod
def _register_sti_columns(cls) -> None:
"""
将STI子类的新字段注册到父表的列定义中
检测当前类是否是STI子类与父类共用同一张表
如果是则将子类定义的新字段添加到父表的metadata中。
JTI联表继承类会被自动跳过因为它们有自己独立的表。
"""
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
parent_table = None
parent_fields: set[str] = set()
for base in cls.__mro__[1:]:
if hasattr(base, '__table__') and base.__table__ is not None:
parent_table = base.__table__
# 收集父类的所有字段名
if hasattr(base, 'model_fields'):
parent_fields.update(base.model_fields.keys())
break
if parent_table is None:
return # 没有找到父表,可能是根类
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
# JTI 类有自己独立的表,不需要将列注册到父表
if hasattr(cls, '__table__') and cls.__table__ is not None:
if cls.__table__.name != parent_table.name:
return # JTI跳过 STI 列注册
# 获取当前类的新字段(不在父类中的字段)
if not hasattr(cls, 'model_fields'):
return
existing_columns = {col.name for col in parent_table.columns}
for field_name, field_info in cls.model_fields.items():
# 跳过从父类继承的字段
if field_name in parent_fields:
continue
# 跳过私有字段和ClassVar
if field_name.startswith('_'):
continue
# 跳过已存在的列
if field_name in existing_columns:
continue
# 使用 SQLModel 的内置 API 创建列
try:
column = get_column_from_field(field_info)
column.name = field_name
column.key = field_name
# STI子类字段在数据库层面必须可空因为其他子类的行不会有这些字段的值
# Pydantic层面的约束仍然有效创建特定子类时会验证必填字段
column.nullable = True
# 将列添加到父表
parent_table.append_column(column)
except Exception as e:
l.warning(f"{cls.__name__} 创建列 {field_name} 失败: {e}")
@classmethod
def _register_sti_column_properties(cls) -> None:
"""
将 STI 子类的列作为 ColumnProperty 添加到 mapper
此方法在 configure_mappers() 之后调用,将已添加到表中的列
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
这确保了查询父类时SELECT 语句包含所有 STI 子类的列,
避免在响应序列化时触发懒加载MissingGreenlet 错误)。
JTI联表继承类会被自动跳过因为它们有自己独立的表。
"""
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
parent_table = None
parent_class = None
for base in cls.__mro__[1:]:
if hasattr(base, '__table__') and base.__table__ is not None:
parent_table = base.__table__
parent_class = base
break
if parent_table is None:
return # 没有找到父表,可能是根类
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
if hasattr(cls, '__table__') and cls.__table__ is not None:
if cls.__table__.name != parent_table.name:
return # JTI跳过 STI 列属性注册
# 获取子类和父类的 mapper
child_mapper = inspect(cls).mapper
parent_mapper = inspect(parent_class).mapper
local_table = child_mapper.local_table
# 查找父类的所有字段名
parent_fields: set[str] = set()
if hasattr(parent_class, 'model_fields'):
parent_fields.update(parent_class.model_fields.keys())
if not hasattr(cls, 'model_fields'):
return
# 获取两个 mapper 已有的列属性
child_existing_props = {p.key for p in child_mapper.column_attrs}
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
for field_name in cls.model_fields:
# 跳过从父类继承的字段
if field_name in parent_fields:
continue
# 跳过私有字段
if field_name.startswith('_'):
continue
# 检查表中是否有这个列
if field_name not in local_table.columns:
continue
column = local_table.columns[field_name]
# 添加到子类的 mapper如果尚不存在
if field_name not in child_existing_props:
try:
prop = ColumnProperty(column)
child_mapper.add_property(field_name, prop)
except Exception as e:
l.warning(f"{cls.__name__} 添加列属性 {field_name} 失败: {e}")
# 同时添加到父类的 mapper确保查询父类时 SELECT 包含所有 STI 子类的列)
if field_name not in parent_existing_props:
try:
prop = ColumnProperty(column)
parent_mapper.add_property(field_name, prop)
except Exception as e:
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
class PolymorphicBaseMixin:
"""
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
此 Mixin 自动设置以下内容:
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
使用场景:
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
用法示例:
```python
from abc import ABC
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
# 定义基类
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
__tablename__ = 'mytool'
# 不需要手动定义 _polymorphic_name
# 不需要手动设置 polymorphic_on
# 不需要手动设置 polymorphic_abstract
# 定义子类
class SpecificTool(MyTool):
__tablename__ = 'specifictool'
# 会自动继承 polymorphic 配置
```
自动行为:
1. 定义 `_polymorphic_name: str` 字段(带索引)
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
3. 自动检测抽象类:
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
- 否则设置为 False
手动覆盖:
可以在类定义时手动指定参数来覆盖自动行为:
```python
class MyTool(
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC,
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
polymorphic_abstract=False # 强制不设为抽象类
):
pass
```
注意事项:
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
- 适用于联表继承joined table inheritance场景
- 子类会自动继承 _polymorphic_name 字段定义
- 使用单下划线前缀是因为:
* SQLAlchemy 会映射单下划线字段为数据库列
* Pydantic 将其视为私有属性,不参与序列化
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
"""
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
#
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
#
# 为什么这样做:
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
# 4. 字段不出现在 Pydantic 序列化中model_dump() 和 JSON schema
# 5. 内部代码仍然可以正常访问和修改此字段
#
# 详细说明请参考sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
"""
多态鉴别器字段,用于标识具体的子类类型
注意:此字段使用单下划线前缀,表示内部使用。
- ✅ 存储到数据库
- ✅ 不出现在 API 序列化中
- ✅ 防止外部直接修改
"""
def __init_subclass__(
cls,
polymorphic_on: str | None = None,
polymorphic_abstract: bool | None = None,
**kwargs
):
"""
在子类定义时自动配置 polymorphic 设置
Args:
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'
设置为其他值可以使用不同的字段作为多态鉴别器。
polymorphic_abstract: 是否为抽象类。
- None: 自动检测(默认)
- True: 强制设为抽象类
- False: 强制设为非抽象类
**kwargs: 传递给父类的其他参数
"""
super().__init_subclass__(**kwargs)
# 初始化 __mapper_args__如果还没有
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 设置 polymorphic_on默认为 _polymorphic_name
if 'polymorphic_on' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
# 自动检测或设置 polymorphic_abstract
if 'polymorphic_abstract' not in cls.__mapper_args__:
if polymorphic_abstract is None:
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
has_abc = ABC in cls.__mro__
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
polymorphic_abstract = has_abc and has_abstract_methods
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
@classmethod
def _is_joined_table_inheritance(cls) -> bool:
"""
检测当前类是否使用联表继承Joined Table Inheritance
通过检查子类是否有独立的表来判断:
- JTI: 子类有独立的 local_table与父类不同
- STI: 子类与父类共用同一个 local_table
:return: True 表示 JTIFalse 表示 STI 或无子类
"""
mapper = inspect(cls)
base_table_name = mapper.local_table.name
# 检查所有直接子类
for subclass in cls.__subclasses__():
sub_mapper = inspect(subclass)
# 如果任何子类有不同的表名,说明是 JTI
if sub_mapper.local_table.name != base_table_name:
return True
return False
@classmethod
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
"""
递归获取当前类的所有具体(非抽象)子类
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
可在任意多态基类上调用,返回该类的所有非抽象子类。
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
"""
result: list[type[PolymorphicBaseMixin]] = []
for subclass in cls.__subclasses__():
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
mapper = inspect(subclass)
if not mapper.polymorphic_abstract:
result.append(subclass)
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
if hasattr(subclass, 'get_concrete_subclasses'):
result.extend(subclass.get_concrete_subclasses())
return result
@classmethod
def get_polymorphic_discriminator(cls) -> str:
"""
获取多态鉴别字段名
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
:return: 多态鉴别字段名(如 '_polymorphic_name'
:raises ValueError: 如果类未配置 polymorphic_on
"""
polymorphic_on = inspect(cls).polymorphic_on
if polymorphic_on is None:
raise ValueError(
f"{cls.__name__} 未配置 polymorphic_on"
f"请确保正确继承 PolymorphicBaseMixin"
)
return polymorphic_on.key
@classmethod
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
"""
获取 polymorphic_identity 到具体子类的映射
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
:return: identity 到子类的映射字典
"""
result: dict[str, type[PolymorphicBaseMixin]] = {}
for subclass in cls.get_concrete_subclasses():
identity = inspect(subclass).polymorphic_identity
if identity:
result[identity] = subclass
return result

View File

@@ -1,470 +0,0 @@
"""
关系预加载 Mixin
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
设计原则:
- 按需加载:只加载被调用方法需要的关系
- 增量加载:已加载的关系不重复加载
- 查询最优:相同关系只查询一次,不同关系增量查询
- 零侵入:调用方无需任何改动
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
使用方式:
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
kling_video_generator: KlingO1Generator = Relationship(...)
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
async def cost(self, params, context, session) -> ToolCost:
# 自动加载,可以安全访问
price = self.kling_video_generator.kling_o1.pro_price_per_second
...
# 调用方 - 无需任何改动
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
await tool._call(...) # 关系相同则跳过,否则增量加载
支持 AsyncGenerator
@requires_relations('twitter_api')
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
yield ToolResponse(...) # 装饰器正确处理 async generator
"""
import inspect as python_inspect
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Any
from loguru import logger as l
from sqlalchemy import inspect as sa_inspect
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo
P = ParamSpec('P')
R = TypeVar('R')
def _extract_session(
func: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> AsyncSession | None:
"""
从方法参数中提取 AsyncSession
按以下顺序查找:
1. kwargs 中名为 'session' 的参数
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
3. kwargs 中类型为 AsyncSession 的参数
"""
# 1. 优先从 kwargs 查找
if 'session' in kwargs:
return kwargs['session']
# 2. 从函数签名定位位置参数
try:
sig = python_inspect.signature(func)
param_names = list(sig.parameters.keys())
if 'session' in param_names:
# 计算位置(减去 self
idx = param_names.index('session') - 1
if 0 <= idx < len(args):
return args[idx]
except (ValueError, TypeError):
pass
# 3. 遍历 kwargs 找 AsyncSession 类型
for value in kwargs.values():
if isinstance(value, AsyncSession):
return value
return None
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
"""
检查对象的关系是否已加载(独立函数版本)
Args:
obj: 要检查的对象
rel_name: 关系属性名
Returns:
True 如果关系已加载False 如果未加载或已过期
"""
try:
state = sa_inspect(obj)
return rel_name not in state.unloaded
except Exception:
return False
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
"""
在类中查找指向目标类的关系属性名
Args:
from_class: 源类
to_class: 目标类
Returns:
关系属性名,如果找不到则返回 None
Example:
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
# 返回 'kling_video_generator'
"""
for attr_name in dir(from_class):
try:
attr = getattr(from_class, attr_name, None)
if attr is None:
continue
# 检查是否是 SQLAlchemy InstrumentedAttribute关系属性
# parent.class_ 是关系所在的类property.mapper.class_ 是关系指向的目标类
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
target_class = attr.property.mapper.class_
if target_class == to_class:
return attr_name
except AttributeError:
continue
return None
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
装饰器:声明方法需要的关系,自动按需增量加载
参数格式:
- 字符串:本类属性名,如 'kling_video_generator'
- RelationshipInfo外部类属性如 KlingO1Generator.kling_o1
行为:
- 方法调用时自动检查关系是否已加载
- 未加载的关系会被增量加载(单次查询)
- 已加载的关系直接跳过
支持:
- 普通 async 方法:`async def cost(...) -> ToolCost`
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
Example:
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
async def cost(self, params, context, session) -> ToolCost:
# self.kling_video_generator.kling_o1 已自动加载
...
@requires_relations('twitter_api')
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
yield ToolResponse(...) # AsyncGenerator 正确处理
验证:
- 字符串格式的关系名在类创建时__init_subclass__验证
- 拼写错误会在导入时抛出 AttributeError
"""
def decorator(func: Callable[P, R]) -> Callable[P, R]:
# 检测是否是 async generator 函数
is_async_gen = python_inspect.isasyncgenfunction(func)
if is_async_gen:
# AsyncGenerator 需要特殊处理wrapper 也必须是 async generator
@wraps(func)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
session = _extract_session(func, args, kwargs)
if session is not None:
await self._ensure_relations_loaded(session, relations)
# 委托给原始 async generator逐个 yield 值
async for item in func(self, *args, **kwargs):
yield item # type: ignore
else:
# 普通 async 函数await 并返回结果
@wraps(func)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
session = _extract_session(func, args, kwargs)
if session is not None:
await self._ensure_relations_loaded(session, relations)
return await func(self, *args, **kwargs)
# 保存关系声明供验证和内省使用
wrapper._required_relations = relations # type: ignore
return wrapper
return decorator
class RelationPreloadMixin:
"""
关系预加载 Mixin
提供按需增量加载能力,确保 SQL 查询数理论最优。
特性:
- 按需加载:只加载被调用方法需要的关系
- 增量加载:已加载的关系不重复加载
- 原地更新:直接修改 self无需替换实例
- 导入时验证:字符串关系名在类创建时验证
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
"""
def __init_subclass__(cls, **kwargs) -> None:
"""类创建时验证所有 @requires_relations 声明"""
super().__init_subclass__(**kwargs)
# 收集类及其父类的所有注解(包含普通字段)
all_annotations: set[str] = set()
for klass in cls.__mro__:
if hasattr(klass, '__annotations__'):
all_annotations.update(klass.__annotations__.keys())
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__
sqlmodel_relationships: set[str] = set()
for klass in cls.__mro__:
if hasattr(klass, '__sqlmodel_relationships__'):
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
# 合并所有可用的属性名
all_available_names = all_annotations | sqlmodel_relationships
for method_name in dir(cls):
if method_name.startswith('__'):
continue
try:
method = getattr(cls, method_name, None)
except AttributeError:
continue
if method is None or not hasattr(method, '_required_relations'):
continue
# 验证字符串格式的关系名
for spec in method._required_relations:
if isinstance(spec, str):
# 检查注解、Relationship 或已有属性
if spec not in all_available_names and not hasattr(cls, spec):
raise AttributeError(
f"{cls.__name__}.{method_name} 声明了关系 '{spec}'"
f"{cls.__name__} 没有此属性"
)
def _is_relation_loaded(self, rel_name: str) -> bool:
"""
检查关系是否真正已加载(基于 SQLAlchemy inspect
使用 SQLAlchemy 的 inspect 检测真实加载状态,
自动处理 commit 导致的 expire 问题。
Args:
rel_name: 关系属性名
Returns:
True 如果关系已加载False 如果未加载或已过期
"""
try:
state = sa_inspect(self)
# unloaded 包含未加载的关系属性名
return rel_name not in state.unloaded
except Exception:
# 对象可能未被 SQLAlchemy 管理
return False
async def _ensure_relations_loaded(
self,
session: AsyncSession,
relations: tuple[str | RelationshipInfo, ...],
) -> None:
"""
确保指定关系已加载,只加载未加载的部分
基于 SQLAlchemy inspect 检测真实状态,自动处理:
- 首次访问的关系
- commit 后 expire 的关系
- 嵌套关系(如 KlingO1Generator.kling_o1
Args:
session: 数据库会话
relations: 需要的关系规格
"""
# 找出真正未加载的关系(基于 SQLAlchemy inspect
to_load: list[str | RelationshipInfo] = []
# 区分直接关系和嵌套关系的 key
direct_keys: set[str] = set() # 本类的直接关系属性名
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
for rel in relations:
if isinstance(rel, str):
# 直接关系:检查本类的关系是否已加载
if not self._is_relation_loaded(rel):
to_load.append(rel)
direct_keys.add(rel)
else:
# 嵌套关系InstrumentedAttribute如 KlingO1Generator.kling_o1
# 1. 查找指向父类的关系属性
parent_class = rel.parent.class_
parent_attr = _find_relation_to_class(self.__class__, parent_class)
if parent_attr is None:
# 找不到路径,可能是配置错误,但仍尝试加载
l.warning(
f"无法找到从 {self.__class__.__name__}{parent_class.__name__} 的关系路径,"
f"无法检查 {rel.key} 是否已加载"
)
to_load.append(rel)
continue
# 2. 检查父对象是否已加载
if not self._is_relation_loaded(parent_attr):
# 父对象未加载,需要同时加载父对象和嵌套关系
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
to_load.append(parent_attr)
nested_parent_keys.add(parent_attr)
to_load.append(rel)
else:
# 3. 父对象已加载,检查嵌套关系是否已加载
parent_obj = getattr(self, parent_attr)
if not _is_obj_relation_loaded(parent_obj, rel.key):
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
# 因为 _build_load_chains 需要完整的链来构建 selectinload
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
to_load.append(parent_attr)
nested_parent_keys.add(parent_attr)
to_load.append(rel)
if not to_load:
return # 全部已加载,跳过
# 构建 load 参数
load_options = self._specs_to_load_options(to_load)
if not load_options:
return
# 安全地获取主键值(避免触发懒加载)
state = sa_inspect(self)
pk_tuple = state.key[1] if state.key else None
if pk_tuple is None:
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
return
# 主键是元组,取第一个值(假设单列主键)
pk_value = pk_tuple[0]
# 单次查询加载缺失的关系
fresh = await self.__class__.get(
session,
self.__class__.id == pk_value,
load=load_options,
)
if fresh is None:
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
return
# 原地复制到 self只复制直接关系嵌套关系通过父关系自动可访问
all_direct_keys = direct_keys | nested_parent_keys
for key in all_direct_keys:
value = getattr(fresh, key, None)
object.__setattr__(self, key, value)
def _specs_to_load_options(
self,
specs: list[str | RelationshipInfo],
) -> list[RelationshipInfo]:
"""
将关系规格转换为 load 参数
- 字符串 → cls.{name}
- RelationshipInfo → 直接使用
"""
result: list[RelationshipInfo] = []
for spec in specs:
if isinstance(spec, str):
rel = getattr(self.__class__, spec, None)
if rel is not None:
result.append(rel)
else:
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
else:
result.append(spec)
return result
# ==================== Optional manual preload API ====================
@classmethod
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
"""
获取指定方法声明的关系(用于外部预加载场景)
Args:
method_name: 方法名
Returns:
RelationshipInfo 列表
"""
method = getattr(cls, method_name, None)
if method is None or not hasattr(method, '_required_relations'):
return []
result: list[RelationshipInfo] = []
for spec in method._required_relations:
if isinstance(spec, str):
rel = getattr(cls, spec, None)
if rel:
result.append(rel)
else:
result.append(spec)
return result
@classmethod
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
"""
获取多个方法的关系并去重(用于批量预加载场景)
Args:
method_names: 方法名列表
Returns:
去重后的 RelationshipInfo 列表
"""
seen: set[str] = set()
result: list[RelationshipInfo] = []
for method_name in method_names:
for rel in cls.get_relations_for_method(method_name):
key = rel.key
if key not in seen:
seen.add(key)
result.append(rel)
return result
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
"""
手动预加载指定方法的关系(可选优化 API
当需要确保在调用方法前完成所有加载时使用。
通常情况下不需要调用此方法,装饰器会自动处理。
Args:
session: 数据库会话
method_names: 方法名列表
Returns:
self支持链式调用
Example:
# 可选:显式预加载(通常不需要)
tool = await tool.preload_for(session, 'cost', '_call')
"""
all_relations: list[str | RelationshipInfo] = []
for method_name in method_names:
method = getattr(self.__class__, method_name, None)
if method and hasattr(method, '_required_relations'):
all_relations.extend(method._required_relations)
if all_relations:
await self._ensure_relations_loaded(session, tuple(all_relations))
return self
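The session-lookup order used by _extract_session can be exercised in isolation. This standalone sketch reimplements the same three-step search so the behavior is easy to verify; the `cost` function below is a stand-in, and step 3 (scanning kwargs for AsyncSession values) is only noted, since it needs the real session type:

```python
import inspect

def extract_session(func, args, kwargs):
    # 1. prefer a kwargs entry named 'session'
    if 'session' in kwargs:
        return kwargs['session']
    # 2. locate the positional 'session' argument via the signature (minus self)
    params = list(inspect.signature(func).parameters)
    if 'session' in params:
        idx = params.index('session') - 1
        if 0 <= idx < len(args):
            return args[idx]
    # 3. (the real helper additionally scans kwargs for AsyncSession-typed values)
    return None

async def cost(self, params, context, session):
    ...

assert extract_session(cost, ("p", "c", "SESSION"), {}) == "SESSION"
assert extract_session(cost, (), {"session": "S2"}) == "S2"
assert extract_session(cost, ("p",), {}) is None
```

Note the `- 1` offset: the decorator's wrapper receives `self` separately, so positional `args` start at the first parameter after `self`.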

File diff suppressed because it is too large

View File

@@ -4,7 +4,7 @@ from enum import StrEnum
 from sqlmodel import Field
-from .base import SQLModelBase
+from sqlmodel_ext import SQLModelBase
 class ResponseBase(SQLModelBase):

View File

@@ -3,8 +3,7 @@ from typing import TYPE_CHECKING
 from sqlmodel import Field, Relationship, text, Index
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
 if TYPE_CHECKING:
     from .download import Download
@@ -29,13 +28,13 @@ class NodeType(StrEnum):
 class Aria2ConfigurationBase(SQLModelBase):
     """Aria2 configuration base model"""
-    rpc_url: str | None = Field(default=None, max_length=255)
+    rpc_url: Str255 | None = None
     """RPC URL"""
     rpc_secret: str | None = None
     """RPC secret"""
-    temp_path: str | None = Field(default=None, max_length=255)
+    temp_path: Str255 | None = None
     """Temporary download path"""
     max_concurrent: int = Field(default=5, ge=1, le=50)
@@ -71,19 +70,19 @@ class Node(SQLModelBase, TableBaseMixin):
     status: NodeStatus = Field(default=NodeStatus.ONLINE)
     """Node status"""
-    name: str = Field(max_length=255, unique=True)
+    name: Str255 = Field(unique=True)
     """Node name"""
     type: NodeType
     """Node type"""
-    server: str = Field(max_length=255)
+    server: Str255
     """Node address (IP or domain)"""
-    slave_key: str | None = Field(default=None, max_length=255)
+    slave_key: Str255 | None = None
     """Slave communication key"""
-    master_key: str | None = Field(default=None, max_length=255)
+    master_key: Str255 | None = None
     """Master communication key"""
     aria2_enabled: bool = False
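The diff above replaces `Field(max_length=255)` with a `Str255` type alias from `sqlmodel_ext`. A plausible shape for such an alias (an assumption — the `sqlmodel_ext` source is not shown in this diff, and `MaxLen`/`NodeSketch` are illustrative names) is an `Annotated` string carrying a length marker that the base model can translate into a `String(255)` column:

```python
from typing import Annotated, get_type_hints

# Hypothetical marker class; the real sqlmodel_ext implementation is not shown here.
class MaxLen:
    def __init__(self, n: int):
        self.n = n

# A Str255-style alias: plain `str` at runtime, with length metadata attached
Str255 = Annotated[str, MaxLen(255)]

class NodeSketch:
    # annotations only; mirrors fields like Node.name / Node.server above
    name: Str255
    server: Str255

# The metadata survives annotation resolution and can be read back when building columns
hints = get_type_hints(NodeSketch, include_extras=True)
assert hints["name"].__metadata__[0].n == 255
```

The payoff visible in the diff is that field declarations shrink to `name: Str255` while the column constraint stays in one place.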

View File

@@ -5,10 +5,11 @@ from uuid import UUID
 from enum import StrEnum
 from sqlalchemy import BigInteger
-from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin
+from sqlmodel import Field, Relationship, CheckConstraint, Index, text
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255, Str256
+from .policy import PolicyType
 if TYPE_CHECKING:
     from .user import User
@@ -17,6 +18,7 @@ if TYPE_CHECKING:
     from .share import Share
     from .physical_file import PhysicalFile
     from .uri import DiskNextURI
+    from .object_metadata import ObjectMetadata
 class ObjectType(StrEnum):
@@ -24,42 +26,13 @@ class ObjectType(StrEnum):
FILE = "file" FILE = "file"
FOLDER = "folder" FOLDER = "folder"
class StorageType(StrEnum):
"""存储类型枚举"""
LOCAL = "local"
QINIU = "qiniu"
TENCENT = "tencent"
ALIYUN = "aliyun"
ONEDRIVE = "onedrive"
GOOGLE_DRIVE = "google_drive"
DROPBOX = "dropbox"
WEBDAV = "webdav"
REMOTE = "remote"
class FileCategory(StrEnum):
class FileMetadataBase(SQLModelBase): """文件类型分类枚举,用于按类别筛选文件"""
"""文件元数据基础模型""" IMAGE = "image"
VIDEO = "video"
width: int | None = Field(default=None) AUDIO = "audio"
"""图片宽度(像素)""" DOCUMENT = "document"
height: int | None = Field(default=None)
"""图片高度(像素)"""
duration: float | None = Field(default=None)
"""音视频时长(秒)"""
bitrate: int | None = Field(default=None)
"""比特率kbps"""
mime_type: str | None = Field(default=None, max_length=127)
"""MIME类型"""
checksum_md5: str | None = Field(default=None, max_length=32)
"""MD5校验和"""
checksum_sha256: str | None = Field(default=None, max_length=64)
"""SHA256校验和"""
# ==================== Base 模型 ==================== # ==================== Base 模型 ====================
@@ -76,9 +49,32 @@ class ObjectBase(SQLModelBase):
size: int | None = None size: int | None = None
"""文件大小(字节),目录为 None""" """文件大小(字节),目录为 None"""
mime_type: str | None = Field(default=None, max_length=127)
"""MIME类型仅文件有效"""
# ==================== DTO 模型 ==================== # ==================== DTO 模型 ====================
class ObjectFileFinalize(SQLModelBase):
"""文件上传完成后更新 Object 的 DTO"""
size: int
"""文件大小(字节)"""
physical_file_id: UUID
"""关联的物理文件UUID"""
class ObjectMoveUpdate(SQLModelBase):
"""移动/重命名 Object 的 DTO"""
parent_id: UUID
"""新的父目录UUID"""
name: str
"""新名称"""
class DirectoryCreateRequest(SQLModelBase): class DirectoryCreateRequest(SQLModelBase):
"""创建目录请求 DTO""" """创建目录请求 DTO"""
@@ -137,7 +133,7 @@ class PolicyResponse(SQLModelBase):
name: str name: str
"""策略名称""" """策略名称"""
type: StorageType type: PolicyType
"""存储类型""" """存储类型"""
max_size: int = Field(ge=0, default=0, sa_type=BigInteger) max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
@@ -165,22 +161,6 @@ class DirectoryResponse(SQLModelBase):
# ==================== 数据库模型 ==================== # ==================== 数据库模型 ====================
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
"""文件元数据模型与Object一对一关联"""
object_id: UUID = Field(
foreign_key="object.id",
unique=True,
index=True,
ondelete="CASCADE"
)
"""关联的对象UUID"""
# 反向关系
object: "Object" = Relationship(back_populates="file_metadata")
"""关联的对象"""
class Object(ObjectBase, UUIDTableBaseMixin): class Object(ObjectBase, UUIDTableBaseMixin):
""" """
统一对象模型 统一对象模型
@@ -195,8 +175,13 @@ class Object(ObjectBase, UUIDTableBaseMixin):
""" """
__table_args__ = ( __table_args__ = (
# 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况 # 同一父目录下名称唯一(仅对未删除记录生效
UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"), Index(
"uq_object_parent_name_active",
"owner_id", "parent_id", "name",
unique=True,
postgresql_where=text("deleted_at IS NULL"),
),
# 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/" # 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/"
CheckConstraint( CheckConstraint(
"parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')", "parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')",
@@ -207,17 +192,19 @@ class Object(ObjectBase, UUIDTableBaseMixin):
Index("ix_object_parent_updated", "parent_id", "updated_at"), Index("ix_object_parent_updated", "parent_id", "updated_at"),
Index("ix_object_owner_type", "owner_id", "type"), Index("ix_object_owner_type", "owner_id", "type"),
Index("ix_object_owner_size", "owner_id", "size"), Index("ix_object_owner_size", "owner_id", "size"),
# 回收站查询索引
Index("ix_object_owner_deleted", "owner_id", "deleted_at"),
) )
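The switch from a plain `UniqueConstraint` to a partial unique index means duplicates are only rejected among live rows; once a row is soft deleted, its name becomes available again in the same directory. A minimal stdlib sketch of that semantics, using SQLite's equivalent `WHERE` clause on a unique index (the model above targets PostgreSQL via `postgresql_where`; the table here is a simplified stand-in):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE object (name TEXT, parent_id TEXT, deleted_at TEXT)")
# Partial unique index: uniqueness only applies while deleted_at IS NULL.
con.execute(
    "CREATE UNIQUE INDEX uq_object_parent_name_active "
    "ON object (parent_id, name) WHERE deleted_at IS NULL"
)

con.execute("INSERT INTO object VALUES ('a.txt', 'root', NULL)")

# A second live 'a.txt' in the same directory is rejected ...
try:
    con.execute("INSERT INTO object VALUES ('a.txt', 'root', NULL)")
    raise AssertionError("expected IntegrityError")
except sqlite3.IntegrityError:
    pass

# ... but once the first copy is soft deleted, the name is free again.
con.execute("UPDATE object SET deleted_at = '2026-01-01' WHERE name = 'a.txt'")
con.execute("INSERT INTO object VALUES ('a.txt', 'root', NULL)")
print(con.execute("SELECT COUNT(*) FROM object").fetchone()[0])  # 2
```

This is exactly why the old constraint had to go: with soft delete, a deleted `a.txt` would otherwise block re-creating `a.txt` in the same folder forever.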
     # ==================== Basic fields ====================

-    name: str = Field(max_length=255)
+    name: Str255
     """Object name (file or directory name)"""

     type: ObjectType
     """Object type: file or folder"""

-    password: str | None = Field(default=None, max_length=255)
+    password: Str255 | None = None
     """Per-object password (effective only when the user sets one for this object)"""

     # ==================== File-only fields ====================

@@ -225,7 +212,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
     size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
     """File size in bytes; 0 for directories"""

-    upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
+    upload_session_id: Str255 | None = Field(default=None, unique=True, index=True)
     """Chunked upload session ID (files only)"""

     physical_file_id: UUID | None = Field(

@@ -280,6 +267,18 @@ class Object(ObjectBase, UUIDTableBaseMixin):
     ban_reason: str | None = Field(default=None, max_length=500)
     """Ban reason"""

+    # ==================== Soft-delete fields ====================
+
+    deleted_at: datetime | None = Field(default=None, index=True)
+    """Soft-delete timestamp; NULL means not deleted"""
+
+    deleted_original_parent_id: UUID | None = Field(
+        default=None,
+        foreign_key="object.id",
+        ondelete="SET NULL",
+    )
+    """Original parent directory UUID before soft deletion, used to restore the object's location"""

     # ==================== Relationships ====================

     owner: "User" = Relationship(

@@ -299,22 +298,28 @@ class Object(ObjectBase, UUIDTableBaseMixin):
     # Self-referential relationships
     parent: "Object" = Relationship(
         back_populates="children",
-        sa_relationship_kwargs={"remote_side": "Object.id"},
+        sa_relationship_kwargs={
+            "remote_side": "Object.id",
+            "foreign_keys": "[Object.parent_id]",
+        },
     )
     """Parent directory"""

     children: list["Object"] = Relationship(
         back_populates="parent",
-        sa_relationship_kwargs={"cascade": "all, delete-orphan"}
+        sa_relationship_kwargs={
+            "cascade": "all, delete-orphan",
+            "foreign_keys": "[Object.parent_id]",
+        },
     )
     """Child objects (files and subdirectories)"""

     # File-only relationships
-    file_metadata: FileMetadata | None = Relationship(
+    metadata_entries: list["ObjectMetadata"] = Relationship(
         back_populates="object",
-        sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
+        sa_relationship_kwargs={"cascade": "all, delete-orphan"},
     )
-    """File metadata (files only)"""
+    """Metadata key-value entries"""

     source_links: list["SourceLink"] = Relationship(
         back_populates="object",

@@ -367,7 +372,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
         """
         return await cls.get(
             session,
-            (cls.owner_id == user_id) & (cls.parent_id == None)
+            (cls.owner_id == user_id) & (cls.parent_id == None) & (cls.deleted_at == None)
         )

     @classmethod

@@ -416,7 +421,8 @@ class Object(ObjectBase, UUIDTableBaseMixin):
                 session,
                 (cls.owner_id == user_id) &
                 (cls.parent_id == current.id) &
-                (cls.name == part)
+                (cls.name == part) &
+                (cls.deleted_at == None)
             )
         return current

@@ -424,7 +430,23 @@ class Object(ObjectBase, UUIDTableBaseMixin):
     @classmethod
     async def get_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
         """
-        Get all children of a directory
+        Get all children of a directory (excluding soft-deleted ones)
+
+        :param session: database session
+        :param user_id: user UUID
+        :param parent_id: parent directory UUID
+        :return: list of children
+        """
+        return await cls.get(
+            session,
+            (cls.owner_id == user_id) & (cls.parent_id == parent_id) & (cls.deleted_at == None),
+            fetch_mode="all"
+        )
+
+    @classmethod
+    async def get_all_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
+        """
+        Get all children of a directory (including soft-deleted ones, for permanent deletion)

         :param session: database session
         :param user_id: user UUID

@@ -437,6 +459,55 @@ class Object(ObjectBase, UUIDTableBaseMixin):
             fetch_mode="all"
         )

+    @classmethod
+    async def get_trash_items(cls, session, user_id: UUID) -> list["Object"]:
+        """
+        Get the top-level objects in the user's trash
+
+        Only returns objects that were directly soft deleted (deleted_at is not NULL).
+        Their children are not returned: a child's deleted_at stays NULL, and it sits
+        in the trash indirectly through its parent relationship.
+
+        :param session: database session
+        :param user_id: user UUID
+        :return: list of top-level trash objects
+        """
+        return await cls.get(
+            session,
+            (cls.owner_id == user_id) & (cls.deleted_at != None),
+            fetch_mode="all"
+        )
+    @classmethod
+    async def get_by_category(
+        cls,
+        session: 'AsyncSession',
+        user_id: UUID,
+        extensions: list[str],
+        table_view: 'TableViewRequest | None' = None,
+    ) -> 'ListResponse[Object]':
+        """
+        Query all of the user's files matching an extension list (across directories)
+
+        Matches only file objects that are neither deleted nor banned, using ILIKE
+        against the file-name suffix.
+
+        :param session: database session
+        :param user_id: user UUID
+        :param extensions: extension list (without the dot)
+        :param table_view: pagination/sorting parameters
+        :return: paginated file list
+        """
+        from sqlalchemy import or_
+        ext_conditions = [cls.name.ilike(f"%.{ext}") for ext in extensions]
+        condition = (
+            (cls.owner_id == user_id) &
+            (cls.type == ObjectType.FILE) &
+            (cls.deleted_at == None) &
+            (cls.is_banned == False) &
+            or_(*ext_conditions)
+        )
+        return await cls.get_with_count(session, condition, table_view=table_view)
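The `ILIKE "%.ext"` patterns built above reduce to a case-insensitive suffix check on the file name. A minimal stand-in showing what the generated SQL matches (file names are made-up examples):

```python
def matches_category(name: str, extensions: list[str]) -> bool:
    """Case-insensitive suffix match, equivalent to OR-ing ILIKE '%.ext' patterns."""
    lowered = name.lower()
    return any(lowered.endswith(f".{ext.lower()}") for ext in extensions)

image_exts = ["jpg", "jpeg", "png", "gif"]
print(matches_category("Holiday.JPG", image_exts))      # True: ILIKE ignores case
print(matches_category("notes.txt", image_exts))        # False
print(matches_category("archive.png.bak", image_exts))  # False: suffix only
```

Because the pattern is anchored only at the end (`%.` prefix wildcard), the database cannot use a plain b-tree index on `name` for this query; it relies on the `owner_id`/`type` predicates to narrow the scan first.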
     @classmethod
     async def resolve_uri(
         cls,

@@ -514,7 +585,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
 class UploadSessionBase(SQLModelBase):
     """Upload session base fields"""

-    file_name: str = Field(max_length=255)
+    file_name: Str255
     """Original file name"""

     file_size: int = Field(ge=0, sa_type=BigInteger)

@@ -545,6 +616,12 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
     storage_path: str | None = Field(default=None, max_length=512)
     """File storage path"""

+    s3_upload_id: Str256 | None = None
+    """S3 multipart upload ID (S3 policies only)"""
+
+    s3_part_etags: str | None = None
+    """ETags of uploaded S3 parts, stored as JSON [[1,"etag1"],[2,"etag2"]] (S3 policies only)"""

     expires_at: datetime
     """Session expiry time"""
@@ -586,7 +663,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
 class CreateUploadSessionRequest(SQLModelBase):
     """Create upload session request DTO"""

-    file_name: str = Field(max_length=255)
+    file_name: Str255
     """File name"""

     file_size: int = Field(ge=0)

@@ -643,7 +720,7 @@ class UploadChunkResponse(SQLModelBase):
 class CreateFileRequest(SQLModelBase):
     """Create blank file request DTO"""

-    name: str = Field(max_length=255)
+    name: Str255
     """File name"""

     parent_id: UUID

@@ -653,6 +730,16 @@ class CreateFileRequest(SQLModelBase):
     """Storage policy UUID; the parent directory's policy is used when omitted"""

+class ObjectSwitchPolicyRequest(SQLModelBase):
+    """Request DTO for switching an object's storage policy"""
+
+    policy_id: UUID
+    """Target storage policy UUID"""
+
+    is_migrate_existing: bool = False
+    """(Directories only) whether to migrate existing files; false (default) affects new files only"""

 # ==================== Object operation DTOs ====================

 class ObjectCopyRequest(SQLModelBase):

@@ -671,7 +758,7 @@ class ObjectRenameRequest(SQLModelBase):
     id: UUID
     """Object UUID"""

-    new_name: str = Field(max_length=255)
+    new_name: Str255
     """New name"""

@@ -690,6 +777,9 @@ class ObjectPropertyResponse(SQLModelBase):
     size: int
     """File size (bytes)"""

+    mime_type: str | None = None
+    """MIME type"""

     created_at: datetime
     """Creation time"""

@@ -703,22 +793,13 @@ class ObjectPropertyResponse(SQLModelBase):
 class ObjectPropertyDetailResponse(ObjectPropertyResponse):
     """Detailed object property response DTO; extends the basic properties"""

-    # Metadata
-    mime_type: str | None = None
-    """MIME type"""
-    width: int | None = None
-    """Image width (pixels)"""
-    height: int | None = None
-    """Image height (pixels)"""
-    duration: float | None = None
-    """Audio/video duration (seconds)"""
+    # Checksums (read from PhysicalFile)

     checksum_md5: str | None = None
     """MD5 checksum"""
-    checksum_sha256: str | None = None
-    """SHA256 checksum"""

     # Share statistics

     share_count: int = 0
     """Share count"""

@@ -736,6 +817,10 @@ class ObjectPropertyDetailResponse(ObjectPropertyResponse):
     reference_count: int = 1
     """Physical file reference count (files only)"""

+    # Metadata (KV form)
+    metadatas: dict[str, str] = {}
+    """All metadata entries (key → value)"""

 # ==================== Admin file management DTOs ====================

@@ -805,3 +890,41 @@ class AdminFileListResponse(SQLModelBase):
     total: int = 0
     """Total count"""

+# ==================== Trash DTOs ====================
+
+class TrashItemResponse(SQLModelBase):
+    """Trash item response DTO"""
+
+    id: UUID
+    """Object UUID"""
+
+    name: str
+    """Object name"""
+
+    type: ObjectType
+    """Object type"""
+
+    size: int
+    """File size (bytes)"""
+
+    deleted_at: datetime
+    """Deletion time"""
+
+    original_parent_id: UUID | None
+    """Original parent directory UUID"""
+
+class TrashRestoreRequest(SQLModelBase):
+    """Restore objects request DTO"""
+
+    ids: list[UUID]
+    """UUIDs of the objects to restore"""
+
+class TrashDeleteRequest(SQLModelBase):
+    """Permanently delete objects request DTO"""
+
+    ids: list[UUID]
+    """UUIDs of the objects to permanently delete"""

@@ -0,0 +1,127 @@
"""
Object metadata KV model

Stores extended file metadata as key-value pairs. Keys are grouped by
namespace prefix, e.g. exif:width, stream:duration, music:artist.

Architecture:
    ObjectMetadata (KV table; an Object owns many rows)
    └── each Object can have multiple metadata records
        └── (object_id, name) is unique per object

Namespaces:
    - exif: image EXIF data (dimensions, camera parameters, capture time, ...)
    - stream: audio/video stream info (duration, bitrate, dimensions, codec, ...)
    - music: music tags (title, artist, album, ...)
    - geo: geolocation (latitude/longitude, address)
    - apk: Android package info
    - custom: user-defined properties
    - sys: internal system metadata
    - thumb: thumbnail info
"""
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID

from sqlmodel import Field, UniqueConstraint, Index, Relationship
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255

if TYPE_CHECKING:
    from .object import Object

# ==================== Enums ====================

class MetadataNamespace(StrEnum):
    """Metadata namespace enum"""

    EXIF = "exif"
    """Image EXIF data (dimensions, camera parameters, capture time, ...)"""

    MUSIC = "music"
    """Music tags (title/artist/album/genre, ...)"""

    STREAM = "stream"
    """Audio/video stream info (codec/duration/bitrate/resolution, ...)"""

    GEO = "geo"
    """Geolocation (latitude/longitude/address)"""

    APK = "apk"
    """Android package info (package_name/version, ...)"""

    THUMB = "thumb"
    """Thumbnail info (internal use)"""

    SYS = "sys"
    """System metadata (internal use)"""

    CUSTOM = "custom"
    """User-defined properties"""

# Namespaces hidden from the outside (the API never returns them to regular users)
INTERNAL_NAMESPACES: set[str] = {MetadataNamespace.SYS, MetadataNamespace.THUMB}

# Namespaces writable by users
USER_WRITABLE_NAMESPACES: set[str] = {MetadataNamespace.CUSTOM}

# ==================== Base model ====================

class ObjectMetadataBase(SQLModelBase):
    """Object metadata KV base model"""

    name: Str255
    """Metadata key in namespace:key form, e.g. exif:width, stream:duration"""

    value: str
    """Metadata value (always stored as a string)"""

# ==================== Database model ====================

class ObjectMetadata(ObjectMetadataBase, UUIDTableBaseMixin):
    """
    Object metadata KV model

    Stores extended file metadata as key-value pairs grouped by namespace prefix;
    each key is unique per object (enforced by a unique constraint).
    """

    __table_args__ = (
        UniqueConstraint("object_id", "name", name="uq_object_metadata_object_name"),
        Index("ix_object_metadata_object_id", "object_id"),
    )

    object_id: UUID = Field(
        foreign_key="object.id",
        ondelete="CASCADE",
    )
    """Associated object UUID"""

    is_public: bool = False
    """Whether the entry is visible on share pages"""

    # Relationship
    object: "Object" = Relationship(back_populates="metadata_entries")
    """Associated object"""

# ==================== DTO models ====================

class MetadataResponse(SQLModelBase):
    """Metadata query response DTO"""

    metadatas: dict[str, str]
    """Metadata dictionary (key → value)"""

class MetadataPatchItem(SQLModelBase):
    """Single metadata patch DTO"""

    key: Str255
    """Metadata key"""

    value: str | None = None
    """None means delete this entry"""

class MetadataPatchRequest(SQLModelBase):
    """Batch metadata update request DTO"""

    patches: list[MetadataPatchItem]
    """Patch list"""


@@ -1,55 +1,118 @@
from decimal import Decimal
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlalchemy import Numeric
from sqlmodel import Field, Relationship from sqlmodel import Field, Relationship
from .base import SQLModelBase from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
from .mixin import TableBaseMixin
if TYPE_CHECKING: if TYPE_CHECKING:
from .product import Product
from .user import User from .user import User
class OrderStatus(StrEnum): class OrderStatus(StrEnum):
"""订单状态枚举""" """订单状态枚举"""
PENDING = "pending" PENDING = "pending"
"""待支付""" """待支付"""
COMPLETED = "completed" COMPLETED = "completed"
"""已完成""" """已完成"""
CANCELLED = "cancelled" CANCELLED = "cancelled"
"""已取消""" """已取消"""
class OrderType(StrEnum): class OrderType(StrEnum):
"""订单类型枚举""" """订单类型枚举"""
# [TODO] 补充具体订单类型
pass
STORAGE_PACK = "storage_pack"
"""容量包"""
GROUP_TIME = "group_time"
"""用户组时长"""
SCORE = "score"
"""积分充值"""
# ==================== DTO 模型 ====================
class CreateOrderRequest(SQLModelBase):
"""创建订单请求 DTO"""
product_id: UUID
"""商品UUID"""
num: int = Field(default=1, ge=1)
"""购买数量"""
method: str
"""支付方式"""
class OrderResponse(SQLModelBase):
"""订单响应 DTO"""
id: int
"""订单ID"""
order_no: str
"""订单号"""
type: OrderType
"""订单类型"""
method: str | None = None
"""支付方式"""
product_id: UUID | None = None
"""商品UUID"""
num: int
"""购买数量"""
name: str
"""商品名称"""
price: float
"""订单价格(元)"""
status: OrderStatus
"""订单状态"""
user_id: UUID
"""所属用户UUID"""
# ==================== 数据库模型 ====================
class Order(SQLModelBase, TableBaseMixin): class Order(SQLModelBase, TableBaseMixin):
"""订单模型""" """订单模型"""
order_no: str = Field(max_length=255, unique=True, index=True) order_no: Str255 = Field(unique=True, index=True)
"""订单号,唯一""" """订单号,唯一"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) type: OrderType
"""订单类型 [TODO] 待定义枚举""" """订单类型"""
method: str | None = Field(default=None, max_length=255) method: Str255 | None = None
"""支付方式""" """支付方式"""
product_id: int | None = Field(default=None) product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
"""商品ID""" """关联商品UUID"""
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}) num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
"""购买数量""" """购买数量"""
name: str = Field(max_length=255) name: Str255
"""商品名称""" """商品名称"""
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
"""订单价格(""" """订单价格("""
status: OrderStatus = Field(default=OrderStatus.PENDING) status: OrderStatus = Field(default=OrderStatus.PENDING)
"""订单状态""" """订单状态"""
@@ -64,3 +127,19 @@ class Order(SQLModelBase, TableBaseMixin):
# 关系 # 关系
user: "User" = Relationship(back_populates="orders") user: "User" = Relationship(back_populates="orders")
product: "Product" = Relationship(back_populates="orders")
def to_response(self) -> OrderResponse:
"""转换为响应 DTO"""
return OrderResponse(
id=self.id,
order_no=self.order_no,
type=self.type,
method=self.method,
product_id=self.product_id,
num=self.num,
name=self.name,
price=float(self.price),
status=self.status,
user_id=self.user_id,
)
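Moving `price` from `int` to `Decimal` backed by `Numeric(12, 2)` keeps monetary values exact; binary floats cannot represent most decimal amounts, which is why the float conversion happens only at the response boundary in `to_response`. A small stdlib sketch of the difference (the `order_total` helper is an illustrative assumption):

```python
from decimal import Decimal, ROUND_HALF_UP

# Binary floats drift; Decimal keeps exact two-digit cents.
print(0.1 + 0.2)                        # 0.30000000000000004
print(Decimal("0.1") + Decimal("0.2"))  # 0.3

def order_total(unit_price: Decimal, num: int) -> Decimal:
    """Total rounded to the 2-decimal scale of the Numeric(12, 2) column."""
    return (unit_price * num).quantize(Decimal("0.01"), rounding=ROUND_HALF_UP)

print(order_total(Decimal("19.99"), 3))  # 59.97
```

Note that `Numeric(12, 2)` caps stored amounts at ten digits before the decimal point, which is ample for per-order prices here.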


@@ -15,8 +15,7 @@ from uuid import UUID
 from sqlalchemy import BigInteger
 from sqlmodel import Field, Relationship, Index
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str32, Str64

 if TYPE_CHECKING:
     from .object import Object

@@ -32,9 +31,12 @@ class PhysicalFileBase(SQLModelBase):
     size: int = Field(default=0, sa_type=BigInteger)
     """File size (bytes)"""

-    checksum_md5: str | None = Field(default=None, max_length=32)
+    checksum_md5: Str32 | None = None
     """MD5 checksum, used for deduplication and integrity verification"""

+    checksum_sha256: Str64 | None = None
+    """SHA256 checksum"""

 class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
     """
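Both checksum columns can be filled in a single pass over the file content; the hex digest lengths (32 and 64 characters) are what motivate the `Str32`/`Str64` column types. A stdlib sketch (the `checksums` helper is a hypothetical illustration, not code from this diff):

```python
import hashlib

def checksums(data: bytes) -> tuple[str, str]:
    """Compute MD5 and SHA-256 hex digests in one chunked pass."""
    md5 = hashlib.md5()
    sha256 = hashlib.sha256()
    for offset in range(0, len(data), 8192):  # chunked, as for large uploads
        chunk = data[offset:offset + 8192]
        md5.update(chunk)
        sha256.update(chunk)
    return md5.hexdigest(), sha256.hexdigest()

md5_hex, sha256_hex = checksums(b"hello")
print(md5_hex)          # 5d41402abc4b2a76b9719d911017c592
print(len(sha256_hex))  # 64
```

MD5 is fine as a fast dedup key, but it is not collision-resistant; keeping SHA-256 alongside it gives a cryptographically sound integrity check.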


@@ -4,8 +4,7 @@ from uuid import UUID
 from enum import StrEnum
 from sqlmodel import Field, Relationship, text
-from .base import SQLModelBase
-from .mixin import UUIDTableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255

 if TYPE_CHECKING:
     from .object import Object

@@ -38,22 +37,22 @@ class PolicyType(StrEnum):
 class PolicyBase(SQLModelBase):
     """Storage policy base fields, shared by DTOs and the database model"""

-    name: str = Field(max_length=255)
+    name: Str255
     """Policy name"""

     type: PolicyType
     """Storage policy type"""

-    server: str | None = Field(default=None, max_length=255)
+    server: Str255 | None = None
     """Server address (absolute path for the local policy)"""

-    bucket_name: str | None = Field(default=None, max_length=255)
+    bucket_name: Str255 | None = None
     """Bucket name"""

     is_private: bool = True
     """Whether the space is private"""

-    base_url: str | None = Field(default=None, max_length=255)
+    base_url: Str255 | None = None
     """Base URL for accessing files"""

     access_key: str | None = None

@@ -68,10 +67,10 @@ class PolicyBase(SQLModelBase):
     auto_rename: bool = False
     """Whether to rename automatically"""

-    dir_name_rule: str | None = Field(default=None, max_length=255)
+    dir_name_rule: Str255 | None = None
     """Directory naming rule"""

-    file_name_rule: str | None = Field(default=None, max_length=255)
+    file_name_rule: Str255 | None = None
     """File naming rule"""

     is_origin_link_enable: bool = False

@@ -103,6 +102,94 @@ class PolicySummary(SQLModelBase):
     """Whether private"""

+class PolicyCreateRequest(PolicyBase):
+    """Create storage policy request DTO (includes flattened PolicyOptions fields)"""
+
+    # PolicyOptions fields (flattened into the request body, matching the GroupCreateRequest pattern)
+    token: str | None = None
+    """Access token"""
+
+    file_type: str | None = None
+    """Allowed file types"""
+
+    mimetype: str | None = Field(default=None, max_length=127)
+    """MIME type"""
+
+    od_redirect: Str255 | None = None
+    """OneDrive redirect URL"""
+
+    chunk_size: int = Field(default=52428800, ge=1)
+    """Chunked upload size in bytes, default 50MB"""
+
+    s3_path_style: bool = False
+    """Whether to use S3 path-style addressing"""
+
+    s3_region: Str64 = 'us-east-1'
+    """S3 region (e.g. us-east-1, ap-southeast-1; S3 policies only)"""

+class PolicyUpdateRequest(SQLModelBase):
+    """Update storage policy request DTO (all fields optional)"""
+
+    name: Str255 | None = None
+    """Policy name"""
+
+    server: Str255 | None = None
+    """Server address"""
+
+    bucket_name: Str255 | None = None
+    """Bucket name"""
+
+    is_private: bool | None = None
+    """Whether the space is private"""
+
+    base_url: Str255 | None = None
+    """Base URL for accessing files"""
+
+    access_key: str | None = None
+    """Access Key"""
+
+    secret_key: str | None = None
+    """Secret Key"""
+
+    max_size: int | None = Field(default=None, ge=0)
+    """Maximum allowed upload size (bytes)"""
+
+    auto_rename: bool | None = None
+    """Whether to rename automatically"""
+
+    dir_name_rule: Str255 | None = None
+    """Directory naming rule"""
+
+    file_name_rule: Str255 | None = None
+    """File naming rule"""
+
+    is_origin_link_enable: bool | None = None
+    """Whether origin-link access is enabled"""
+
+    # PolicyOptions fields
+    token: str | None = None
+    """Access token"""
+
+    file_type: str | None = None
+    """Allowed file types"""
+
+    mimetype: str | None = Field(default=None, max_length=127)
+    """MIME type"""
+
+    od_redirect: Str255 | None = None
+    """OneDrive redirect URL"""
+
+    chunk_size: int | None = Field(default=None, ge=1)
+    """Chunked upload size (bytes)"""
+
+    s3_path_style: bool | None = None
+    """Whether to use S3 path-style addressing"""
+
+    s3_region: Str64 | None = None
+    """S3 region"""

 # ==================== Database models ====================

@@ -118,7 +205,7 @@ class PolicyOptionsBase(SQLModelBase):
     mimetype: str | None = Field(default=None, max_length=127)
     """MIME type"""

-    od_redirect: str | None = Field(default=None, max_length=255)
+    od_redirect: Str255 | None = None
     """OneDrive redirect URL"""

     chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})

@@ -127,6 +214,9 @@ class PolicyOptionsBase(SQLModelBase):
     s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
     """Whether to use S3 path-style addressing"""

+    s3_region: Str64 = Field(default='us-east-1', sa_column_kwargs={"server_default": "'us-east-1'"})
+    """S3 region (e.g. us-east-1, ap-southeast-1; S3 policies only)"""

 class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
     """Storage policy options model, one-to-one with Policy"""

@@ -147,7 +237,7 @@ class Policy(PolicyBase, UUIDTableBaseMixin):
     """Storage policy model"""

     # Override base-class fields to add database-specific configuration
-    name: str = Field(max_length=255, unique=True)
+    name: Str255 = Field(unique=True)
     """Policy name"""

     is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
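The magic default `52428800` for `chunk_size` is exactly 50 MiB. A quick arithmetic check, plus the part count a given upload would need at that chunk size (the `part_count` helper is an illustrative assumption):

```python
import math

CHUNK_SIZE = 52428800
print(CHUNK_SIZE == 50 * 1024 * 1024)  # True

def part_count(file_size: int, chunk_size: int = CHUNK_SIZE) -> int:
    """Number of chunks needed to upload file_size bytes (at least 1)."""
    return max(1, math.ceil(file_size / chunk_size))

print(part_count(120 * 1024 * 1024))  # 3 parts for a 120 MiB file
```

For S3-backed policies the chunk size must also satisfy S3's multipart minimum of 5 MiB per part (except the last), which the 50 MiB default clears comfortably.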

sqlmodels/product.py (new file, 206 lines)

@@ -0,0 +1,206 @@
from decimal import Decimal
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID

from sqlalchemy import Numeric, BigInteger
from sqlmodel import Field, Relationship, text
from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str255

if TYPE_CHECKING:
    from .order import Order
    from .redeem import Redeem

class ProductType(StrEnum):
    """Product type enum"""

    STORAGE_PACK = "storage_pack"
    """Storage pack"""

    GROUP_TIME = "group_time"
    """Group membership time"""

    SCORE = "score"
    """Credit top-up"""

class PaymentMethod(StrEnum):
    """Payment method enum"""

    ALIPAY = "alipay"
    """Alipay"""

    WECHAT = "wechat"
    """WeChat Pay"""

    STRIPE = "stripe"
    """Stripe"""

    EASYPAY = "easypay"
    """EasyPay"""

    CUSTOM = "custom"
    """Custom payment"""

# ==================== DTO models ====================

class ProductBase(SQLModelBase):
    """Product base fields"""

    name: str
    """Product name"""

    type: ProductType
    """Product type"""

    description: str | None = None
    """Product description"""

class ProductCreateRequest(ProductBase):
    """Create product request DTO"""

    name: Str255
    """Product name"""

    price: Decimal = Field(ge=0, decimal_places=2)
    """Product price (yuan)"""

    is_active: bool = True
    """Whether the product is listed"""

    sort_order: int = Field(default=0, ge=0)
    """Sort weight (higher sorts first)"""

    # storage_pack only
    size: int | None = Field(default=None, ge=0)
    """Capacity in bytes; required when type=storage_pack"""

    duration_days: int | None = Field(default=None, ge=1)
    """Validity in days; required when type=storage_pack/group_time"""

    # group_time only
    group_id: UUID | None = None
    """Target group UUID; required when type=group_time"""

    # score only
    score_amount: int | None = Field(default=None, ge=1)
    """Credit amount; required when type=score"""

class ProductUpdateRequest(SQLModelBase):
    """Update product request DTO (all fields optional)"""

    name: Str255 | None = None
    """Product name"""

    description: str | None = None
    """Product description"""

    price: Decimal | None = Field(default=None, ge=0, decimal_places=2)
    """Product price (yuan)"""

    is_active: bool | None = None
    """Whether the product is listed"""

    sort_order: int | None = Field(default=None, ge=0)
    """Sort weight"""

    size: int | None = Field(default=None, ge=0)
    """Capacity (bytes)"""

    duration_days: int | None = Field(default=None, ge=1)
    """Validity (days)"""

    group_id: UUID | None = None
    """Target group UUID"""

    score_amount: int | None = Field(default=None, ge=1)
    """Credit amount"""

class ProductResponse(ProductBase):
    """Product response DTO"""

    id: UUID
    """Product UUID"""

    price: float
    """Product price (yuan)"""

    is_active: bool
    """Whether the product is listed"""

    sort_order: int
    """Sort weight"""

    size: int | None = None
    """Capacity (bytes)"""

    duration_days: int | None = None
    """Validity (days)"""

    group_id: UUID | None = None
    """Target group UUID"""

    score_amount: int | None = None
    """Credit amount"""

# ==================== Database models ====================

class Product(ProductBase, UUIDTableBaseMixin):
    """Product model"""

    name: Str255
    """Product name"""

    price: Decimal = Field(sa_type=Numeric(12, 2), default=Decimal("0.00"))
    """Product price (yuan)"""

    is_active: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
    """Whether the product is listed"""

    sort_order: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
    """Sort weight (higher sorts first)"""

    # storage_pack only
    size: int | None = Field(default=None, sa_type=BigInteger)
    """Capacity in bytes; required when type=storage_pack"""

    duration_days: int | None = None
    """Validity in days; required when type=storage_pack/group_time"""

    # group_time only
    group_id: UUID | None = Field(default=None, foreign_key="group.id", ondelete="SET NULL")
    """Target group UUID; required when type=group_time"""

    # score only
    score_amount: int | None = None
    """Credit amount; required when type=score"""

    # Relationships
    orders: list["Order"] = Relationship(back_populates="product")
    """Associated orders"""

    redeems: list["Redeem"] = Relationship(back_populates="product")
    """Associated redeem codes"""

    def to_response(self) -> ProductResponse:
        """Convert to the response DTO"""
        return ProductResponse(
            id=self.id,
            name=self.name,
            type=self.type,
            description=self.description,
            price=float(self.price),
            is_active=self.is_active,
            sort_order=self.sort_order,
            size=self.size,
            duration_days=self.duration_days,
            group_id=self.group_id,
            score_amount=self.score_amount,
        )
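The docstrings above spell out per-type required fields ("required when type=storage_pack", and so on) without a model-level validator. A sketch of what that rule table looks like as data (hypothetical validator; where the real API enforces this is not shown in the diff):

```python
REQUIRED_BY_TYPE = {
    "storage_pack": ["size", "duration_days"],
    "group_time": ["group_id", "duration_days"],
    "score": ["score_amount"],
}

def missing_fields(product: dict) -> list[str]:
    """Return the type-specific fields that are still unset (None/absent)."""
    required = REQUIRED_BY_TYPE.get(product.get("type"), [])
    return [f for f in required if product.get(f) is None]

print(missing_fields({"type": "storage_pack", "size": 10737418240}))  # ['duration_days']
print(missing_fields({"type": "score", "score_amount": 100}))         # []
```

Encoding the rules as a dict keeps the check in one place as new `ProductType` members are added.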


@@ -1,23 +1,141 @@
from datetime import datetime
from enum import StrEnum from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, text from sqlmodel import Field, Relationship, text
from .base import SQLModelBase from sqlmodel_ext import SQLModelBase, TableBaseMixin
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .product import Product
from .user import User
class RedeemType(StrEnum): class RedeemType(StrEnum):
"""兑换码类型枚举""" """兑换码类型枚举"""
# [TODO] 补充具体兑换码类型
pass
STORAGE_PACK = "storage_pack"
"""容量包"""
GROUP_TIME = "group_time"
"""用户组时长"""
SCORE = "score"
"""积分充值"""
# ==================== DTO 模型 ====================
class RedeemCreateRequest(SQLModelBase):
"""批量生成兑换码请求 DTO"""
product_id: UUID
"""关联商品UUID"""
count: int = Field(default=1, ge=1, le=100)
"""生成数量"""
class RedeemUseRequest(SQLModelBase):
"""使用兑换码请求 DTO"""
code: str
"""兑换码"""
class RedeemInfoResponse(SQLModelBase):
"""兑换码信息响应 DTO用户侧"""
type: RedeemType
"""兑换码类型"""
product_name: str | None = None
"""关联商品名称"""
num: int
"""可兑换数量"""
is_used: bool
"""是否已使用"""
class RedeemAdminResponse(SQLModelBase):
"""兑换码管理响应 DTO管理侧"""
id: int
"""兑换码ID"""
type: RedeemType
"""兑换码类型"""
product_id: UUID | None = None
"""关联商品UUID"""
num: int
"""可兑换数量"""
code: str
"""兑换码"""
is_used: bool
"""是否已使用"""
used_at: datetime | None = None
"""使用时间"""
used_by: UUID | None = None
"""使用者UUID"""
# ==================== 数据库模型 ====================
class Redeem(SQLModelBase, TableBaseMixin): class Redeem(SQLModelBase, TableBaseMixin):
"""兑换码模型""" """兑换码模型"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) type: RedeemType
"""兑换码类型 [TODO] 待定义枚举""" """兑换码类型"""
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等") product_id: UUID | None = Field(default=None, foreign_key="product.id", ondelete="SET NULL")
code: str = Field(unique=True, index=True, description="兑换码,唯一") """关联商品UUID"""
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
"""可兑换数量/时长等"""
code: str = Field(unique=True, index=True)
"""兑换码,唯一"""
is_used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否已使用"""
used_at: datetime | None = None
"""使用时间"""
used_by: UUID | None = Field(default=None, foreign_key="user.id", ondelete="SET NULL")
"""使用者UUID"""
# 关系
product: "Product" = Relationship(back_populates="redeems")
user: "User" = Relationship(back_populates="redeems")
    def to_admin_response(self) -> RedeemAdminResponse:
        """Convert to the admin-facing response DTO"""
        return RedeemAdminResponse(
            id=self.id,
            type=self.type,
            product_id=self.product_id,
            num=self.num,
            code=self.code,
            is_used=self.is_used,
            used_at=self.used_at,
            used_by=self.used_by,
        )

    def to_info_response(self, product_name: str | None = None) -> RedeemInfoResponse:
        """Convert to the user-facing response DTO"""
        return RedeemInfoResponse(
            type=self.type,
            product_name=product_name,
            num=self.num,
            is_used=self.is_used,
        )
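The `code` column is unique and indexed, so newly issued codes need enough entropy to avoid collisions and guessing. A minimal stdlib sketch of a generator (the function name and alphabet choice are illustrative, not part of this codebase):

```python
import secrets
import string

# Uppercase alphanumerics minus look-alikes (0/O, 1/I) to keep codes typo-friendly.
ALPHABET = "".join(c for c in string.ascii_uppercase + string.digits if c not in "0O1I")

def generate_redeem_code(length: int = 16) -> str:
    # secrets.choice draws from a CSPRNG, so codes are not guessable.
    return "".join(secrets.choice(ALPHABET) for _ in range(length))
```

The unique constraint on `code` still matters: even with a CSPRNG, the insert should retry on a duplicate-key error rather than assume uniqueness.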

View File

@@ -1,10 +1,10 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin, Str255
if TYPE_CHECKING:
    from .share import Share
@@ -21,10 +21,10 @@ class Report(SQLModelBase, TableBaseMixin):
    reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
    """Report reason [TODO] enum to be defined"""
-    description: str | None = Field(default=None, max_length=255, description="Additional details")
+    description: Str255 | None = Field(default=None, description="Additional details")
    # Foreign keys
-    share_id: int = Field(
+    share_id: UUID = Field(
        foreign_key="share.id",
        index=True,
        ondelete="CASCADE"
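The commit log describes replacing `Field(max_length=X)` with `StrX` type aliases from `sqlmodel_ext`. That library's internals aren't shown here; a plausible sketch of the underlying `Annotated`-alias pattern, using a hypothetical `MaxLen` marker in place of whatever metadata `sqlmodel_ext` actually attaches:

```python
from typing import Annotated, get_args

class MaxLen:
    """Hypothetical marker carrying a column's maximum length."""
    def __init__(self, limit: int) -> None:
        self.limit = limit

# Declare the limit once, reuse the alias across models.
Str64 = Annotated[str, MaxLen(64)]
Str255 = Annotated[str, MaxLen(255)]

def max_len_of(alias) -> int:
    # A consumer (e.g. an ORM layer building VARCHAR(n) columns)
    # can recover the limit from the annotation metadata.
    _, marker = get_args(alias)
    return marker.limit
```

The payoff is the one visible in this diff: field declarations shrink to `description: Str255 | None` and the length limit lives in exactly one place.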

View File

@@ -2,8 +2,9 @@ from enum import StrEnum
from sqlmodel import UniqueConstraint
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, TableBaseMixin
from .auth_identity import AuthProviderType
from .user import UserResponse
class CaptchaType(StrEnum):
@@ -12,6 +13,19 @@ class CaptchaType(StrEnum):
    GCAPTCHA = "gcaptcha"
    CLOUD_FLARE_TURNSTILE = "cloudflare turnstile"
# ==================== Auth configuration DTOs ====================
class AuthMethodConfig(SQLModelBase):
    """Auth method configuration DTO"""
    provider: AuthProviderType
    """Provider type"""
    is_enabled: bool
    """Whether this method is enabled"""
# ==================== DTO models ====================
class SiteConfigResponse(SQLModelBase):
@@ -50,6 +64,30 @@ class SiteConfigResponse(SQLModelBase):
    captcha_key: str | None = None
    """Captcha public key (None when type is DEFAULT)"""
    auth_methods: list[AuthMethodConfig] = []
    """List of available login methods"""
    password_required: bool = True
    """Whether a password must be set at registration"""
    phone_binding_required: bool = False
    """Whether binding a phone number is mandatory"""
    email_binding_required: bool = True
    """Whether binding an email address is mandatory"""
    avatar_max_size: int = 2097152
    """Maximum avatar file size in bytes (default 2 MB)"""
    footer_code: str | None = None
    """Custom footer code"""
    tos_url: str | None = None
    """Terms-of-service URL"""
    privacy_url: str | None = None
    """Privacy policy URL"""
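`avatar_max_size` defaults to 2097152 bytes (2 MiB). The avatar-upload commit does not show the server-side check; a minimal sketch of how one might enforce the limit (the function name is illustrative):

```python
def validate_avatar_size(data: bytes, avatar_max_size: int = 2_097_152) -> None:
    # Reject oversized payloads before any image decoding or storage work.
    if len(data) > avatar_max_size:
        raise ValueError(f"avatar is {len(data)} bytes; limit is {avatar_max_size}")
```

In practice the limit should also be enforced at the HTTP layer (request body size) so oversized uploads are rejected before being buffered.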
# ==================== Admin settings DTOs ====================
@@ -125,6 +163,7 @@ class SettingsType(StrEnum):
    VERSION = "version"
    VIEW = "view"
    WOPI = "wopi"
    FILE_CATEGORY = "file_category"
# Database models
class Setting(SettingItem, TableBaseMixin):

View File

@@ -5,8 +5,10 @@ from uuid import UUID
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
-from .base import SQLModelBase
-from .mixin import TableBaseMixin
+from sqlmodel_ext import SQLModelBase, UUIDTableBaseMixin, Str64, Str255
from .model_base import ResponseBase
from .object import ObjectType
if TYPE_CHECKING:
    from .user import User
@@ -34,13 +36,13 @@ class ShareBase(SQLModelBase):
    preview_enabled: bool = True
    """Whether preview is allowed"""
-    score: int = 0
+    score: int = Field(default=0, ge=0)
    """Points required to redeem this share"""
# ==================== Database models ====================
-class Share(SQLModelBase, TableBaseMixin):
+class Share(SQLModelBase, UUIDTableBaseMixin):
    """Share model"""
    __table_args__ = (
@@ -50,10 +52,10 @@ class Share(SQLModelBase, TableBaseMixin):
        Index("ix_share_object", "object_id"),
    )
-    code: str = Field(max_length=64, nullable=False, index=True)
+    code: Str64 = Field(nullable=False, index=True)
    """Share code"""
-    password: str | None = Field(default=None, max_length=255)
+    password: Str255 | None = None
    """Share password (encrypted)"""
    object_id: UUID = Field(
@@ -78,10 +80,10 @@ class Share(SQLModelBase, TableBaseMixin):
    preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
    """Whether preview is allowed"""
-    source_name: str | None = Field(default=None, max_length=255)
+    source_name: Str255 | None = None
    """Source name (denormalized for display)"""
-    score: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
+    score: int = Field(default=0, ge=0)
    """Points required to redeem this share"""
    # Foreign keys
@@ -119,10 +121,17 @@ class ShareCreateRequest(ShareBase):
    pass
-class ShareResponse(SQLModelBase):
-    """Share response DTO"""
-    id: int
+class CreateShareResponse(ResponseBase):
+    """Create-share response DTO"""
+    share_id: UUID
+    """ID of the newly created share record"""
+class ShareResponse(SQLModelBase):
+    """View-share response DTO"""
+    id: UUID
    """Share ID"""
    code: str
@@ -162,10 +171,67 @@ class ShareResponse(SQLModelBase):
    """Whether a password is set"""
class ShareOwnerInfo(SQLModelBase):
    """Public share-owner info DTO"""
    nickname: str | None
    """Nickname"""
    avatar: str
    """Avatar"""


class ShareObjectItem(SQLModelBase):
    """DTO for a file/folder inside a share"""
    id: UUID
    """Object UUID"""
    name: str
    """Name"""
    type: ObjectType
    """Type: file or folder"""
    size: int
    """File size in bytes (0 for directories)"""
    created_at: datetime
    """Creation time"""
    updated_at: datetime
    """Modification time"""


class ShareDetailResponse(SQLModelBase):
    """Share detail response DTO (visitor-facing; hides internal statistics)"""
    expires: datetime | None
    """Expiry time"""
    preview_enabled: bool
    """Whether preview is allowed"""
    score: int
    """Points"""
    created_at: datetime
    """Creation time"""
    owner: ShareOwnerInfo
    """Share owner info"""
    object: ShareObjectItem
    """Root object of the share"""
    children: list[ShareObjectItem]
    """List of child files/folders (non-empty only for directory shares)"""
class ShareListItemBase(SQLModelBase):
    """Base fields for share list items"""
-    id: int
+    id: UUID
    """Share ID"""
    code: str

Some files were not shown because too many files have changed in this diff.