feat: add models for physical files, policies, and user management

- Implement PhysicalFile model to manage physical file references and reference counting.
- Create Policy model with associated options and group links for storage policies.
- Introduce Redeem and Report models for handling redeem codes and reports.
- Add Settings model for site configuration and user settings management.
- Develop Share model for sharing objects with unique codes and associated metadata.
- Implement SourceLink model for managing download links associated with objects.
- Create StoragePack model for managing user storage packages.
- Add Tag model for user-defined tags with manual and automatic types.
- Implement Task model for managing background tasks with status tracking.
- Develop User model with comprehensive user management features including authentication.
- Introduce UserAuthn model for managing WebAuthn credentials.
- Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
2026-02-10 16:25:49 +08:00
parent 62c671e07b
commit 209cb24ab4
92 changed files with 3640 additions and 1444 deletions

1274
sqlmodels/README.md Normal file

File diff suppressed because it is too large Load Diff

105
sqlmodels/__init__.py Normal file
View File

@@ -0,0 +1,105 @@
from .user import (
BatchDeleteRequest,
LoginRequest,
RefreshTokenRequest,
RegisterRequest,
AccessTokenBase,
RefreshTokenBase,
TokenResponse,
User,
UserBase,
UserStorageResponse,
UserPublic,
UserResponse,
UserSettingResponse,
WebAuthnInfo,
# 管理员DTO
UserAdminUpdateRequest,
UserCalibrateResponse,
UserAdminDetailResponse,
)
from .user_authn import AuthnResponse, UserAuthn
from .color import ThemeResponse
from .download import (
Download,
DownloadAria2File,
DownloadAria2Info,
DownloadAria2InfoBase,
DownloadStatus,
DownloadType,
)
from .node import (
Aria2Configuration,
Aria2ConfigurationBase,
Node,
NodeStatus,
NodeType,
)
from .group import (
Group, GroupBase, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
# 管理员DTO
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, GroupListResponse,
)
from .object import (
CreateFileRequest,
CreateUploadSessionRequest,
DirectoryCreateRequest,
DirectoryResponse,
FileMetadata,
FileMetadataBase,
Object,
ObjectBase,
ObjectCopyRequest,
ObjectDeleteRequest,
ObjectMoveRequest,
ObjectPropertyDetailResponse,
ObjectPropertyResponse,
ObjectRenameRequest,
ObjectResponse,
ObjectType,
PolicyResponse,
UploadChunkResponse,
UploadSession,
UploadSessionBase,
UploadSessionResponse,
# 管理员DTO
AdminFileResponse,
AdminFileListResponse,
FileBanRequest,
)
from .physical_file import PhysicalFile, PhysicalFileBase
from .uri import DiskNextURI, FileSystemNamespace
from .order import Order, OrderStatus, OrderType
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
from .redeem import Redeem, RedeemType
from .report import Report, ReportReason
from .setting import (
Setting, SettingsType, SiteConfigResponse,
# 管理员DTO
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
)
from .share import Share, ShareBase, ShareCreateRequest, ShareResponse, AdminShareListItem
from .source_link import SourceLink
from .storage_pack import StoragePack
from .tag import Tag, TagType
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
from .webdav import WebDAV
from .database_connection import DatabaseManager
from .model_base import (
MCPBase,
MCPMethod,
MCPRequestBase,
MCPResponseBase,
ResponseBase,
# Admin Summary DTO
MetricsSummary,
LicenseInfo,
VersionInfo,
AdminSummaryResponse,
)
# mixin 中的通用分页模型
from .mixin import ListResponse

657
sqlmodels/base/README.md Normal file
View File

@@ -0,0 +1,657 @@
# SQLModels Base Module
This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility.
**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities.
## Table of Contents
- [Overview](#overview)
- [Migration Notice](#migration-notice)
- [Python 3.14 Compatibility](#python-314-compatibility)
- [Core Component](#core-component)
- [SQLModelBase](#sqlmodelbase)
- [Metaclass Features](#metaclass-features)
- [Automatic sa_type Injection](#automatic-sa_type-injection)
- [Table Configuration](#table-configuration)
- [Polymorphic Support](#polymorphic-support)
- [Custom Types Integration](#custom-types-integration)
- [Best Practices](#best-practices)
- [Troubleshooting](#troubleshooting)
## Overview
The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features:
- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations
- **Python 3.14 compatibility** through comprehensive PEP 649/749 support
- **Flexible configuration** through class parameters and automatic docstring support
- **Type-safe annotations** with automatic validation
All models in this project should directly or indirectly inherit from `SQLModelBase`.
---
## Migration Notice
As of the recent refactoring, the following components have been moved:
| Component | Old Location | New Location |
|-----------|-------------|--------------|
| `TableBase``TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `UUIDTableBase``UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` |
| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` |
| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` |
**Update your imports**:
```python
# ❌ Old (deprecated)
from sqlmodels.base import TableBase, UUIDTableBase
# ✅ New (correct)
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
```
For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Python 3.14 Compatibility
### Overview
This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default).
**Key Changes in Python 3.14**:
- Annotations are no longer evaluated at class definition time
- Type hints are stored as deferred code objects
- `__annotate__` function generates annotations on demand
- Forward references become `ForwardRef` objects
### Implementation Strategy
We use **`typing.get_type_hints()`** as the universal annotations resolver:
```python
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]:
# Create temporary proxy class
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
# Use get_type_hints with include_extras=True
evaluated = get_type_hints(
temp_cls,
globalns=module_globals,
localns=localns,
include_extras=True # Preserve Annotated metadata
)
return dict(evaluated), {}, module_globals, localns
```
**Why `get_type_hints()`?**
- ✅ Works across Python 3.10-3.14+
- ✅ Handles PEP 649 automatically
- ✅ Preserves `Annotated` metadata (with `include_extras=True`)
- ✅ Resolves forward references
- ✅ Recommended by Python documentation
### SQLModel Compatibility Patch
**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute.
**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic:
```python
if sys.version_info >= (3, 14):
def _patched_get_sqlalchemy_type(field):
annotation = getattr(field, 'annotation', None)
if annotation is not None:
# Priority 1: Check __sqlmodel_sa_type__ attribute
# Handles NumpyVector[dims, dtype] and similar custom types
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# Priority 2: Check Annotated metadata
if get_origin(annotation) is Annotated:
for metadata in get_args(annotation)[1:]:
if hasattr(metadata, '__sqlmodel_sa_type__'):
return metadata.__sqlmodel_sa_type__
# ... handle ForwardRef, ClassVar, etc.
return _original_get_sqlalchemy_type(field)
```
### Supported Patterns
#### Pattern 1: Direct Custom Type Usage
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Voice embedding - sa_type automatically extracted"""
```
#### Pattern 2: Annotated Wrapper
```python
from typing import Annotated
from sqlmodels.mixin import UUIDTableBaseMixin
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: EmbeddingVector
```
#### Pattern 3: Array Type
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import Array
from sqlmodels.mixin import TableBaseMixin
class ServerConfig(TableBaseMixin, table=True):
protocols: Array[ProtocolEnum]
"""Allowed protocols - sa_type from Array handler"""
```
### Migration from Python 3.13
**No code changes required!** The implementation is transparent:
- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14
- Custom types already use `__sqlmodel_sa_type__` attribute
- Monkey-patch only activates for Python 3.14+
---
## Core Component
### SQLModelBase
`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities.
**Key Features**:
- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`)
- Automatic `validate_by_name` configuration
- Custom metaclass for sa_type injection and polymorphic setup
- Integration with Pydantic v2
- Python 3.14 PEP 649 compatibility
**Usage**:
```python
from sqlmodels.base import SQLModelBase
class UserBase(SQLModelBase):
name: str
"""User's display name"""
email: str
"""User's email address"""
```
**Important Notes**:
- Use **docstrings** for field descriptions, not `Field(description=...)`
- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase)
- This class should be used for non-table models (DTOs, request/response models)
**For table models**, use mixins from `sqlmodels.mixin`:
- `TableBaseMixin` - Integer primary key with timestamps
- `UUIDTableBaseMixin` - UUID primary key with timestamps
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation.
---
## Metaclass Features
### Automatic sa_type Injection
The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types.
**Before** (verbose):
```python
from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: np.ndarray = Field(
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
)
```
**After** (clean):
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Speaker voice embedding"""
```
**How It Works**:
The metaclass uses a three-tier detection strategy:
1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1)
```python
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
```
2. **Annotated metadata** (Priority 2)
```python
# For Annotated[np.ndarray, NumpyVector[256, np.float32]]
if get_origin(annotation) is typing.Annotated:
for item in metadata_items:
if hasattr(item, '__sqlmodel_sa_type__'):
return item.__sqlmodel_sa_type__
```
3. **Pydantic Core Schema metadata** (Priority 3)
```python
schema = annotation.__get_pydantic_core_schema__(...)
if schema['metadata'].get('sa_type'):
return schema['metadata']['sa_type']
```
After extracting `sa_type`, the metaclass:
- Creates `Field(sa_type=sa_type)` if no Field is defined
- Injects `sa_type` into existing Field if not already set
- Respects explicit `Field(sa_type=...)` (no override)
**Supported Patterns**:
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# Pattern 1: Direct usage (recommended)
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
# Pattern 2: With Field constraints
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32] = Field(nullable=False)
# Pattern 3: Annotated wrapper
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
class Model(UUIDTableBaseMixin, table=True):
embedding: EmbeddingVector
# Pattern 4: Explicit sa_type (override)
class Model(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32] = Field(
sa_type=_NumpyVectorSQLAlchemyType(128, np.float16)
)
```
### Table Configuration
The metaclass provides smart defaults and flexible configuration:
**Automatic `table=True`**:
```python
# Classes inheriting from TableBaseMixin automatically get table=True
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(UUIDTableBaseMixin): # table=True is automatic
pass
```
**Convenient mapper arguments**:
```python
# Instead of verbose __mapper_args__
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(
UUIDTableBaseMixin,
polymorphic_on='_polymorphic_name',
polymorphic_abstract=True
):
pass
# Equivalent to:
class MyModel(UUIDTableBaseMixin):
__mapper_args__ = {
'polymorphic_on': '_polymorphic_name',
'polymorphic_abstract': True
}
```
**Smart merging**:
```python
# Dictionary and keyword arguments are merged
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(
UUIDTableBaseMixin,
mapper_args={'version_id_col': 'version'},
polymorphic_on='type' # Merged into __mapper_args__
):
pass
```
### Polymorphic Support
The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters:
**Supported parameters**:
- `polymorphic_on`: Discriminator column name
- `polymorphic_identity`: Identity value for this class
- `polymorphic_abstract`: Whether this is an abstract base
- `table_args`: SQLAlchemy table arguments
- `table_name`: Override table name (becomes `__tablename__`)
**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Custom Types Integration
### Using NumpyVector
The `NumpyVector` type demonstrates automatic sa_type injection:
```python
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
from sqlmodels.mixin import UUIDTableBaseMixin
import numpy as np
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Speaker voice embedding - sa_type automatically injected"""
```
**How NumpyVector works**:
```python
# NumpyVector[dims, dtype] returns a class with:
class _NumpyVectorType:
__sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype)
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
return handler.generate_schema(np.ndarray)
```
This dual approach ensures:
1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__`
2. Pydantic can validate as `np.ndarray`
### Creating Custom SQLAlchemy Types
To create types that work with automatic injection, provide one of:
**Option 1: `__sqlmodel_sa_type__` attribute** (preferred):
```python
from sqlalchemy import TypeDecorator, String
class UpperCaseString(TypeDecorator):
impl = String
def process_bind_param(self, value, dialect):
return value.upper() if value else value
class UpperCaseType:
__sqlmodel_sa_type__ = UpperCaseString()
@classmethod
def __get_pydantic_core_schema__(cls, source_type, handler):
return core_schema.str_schema()
# Usage
from sqlmodels.mixin import UUIDTableBaseMixin
class MyModel(UUIDTableBaseMixin, table=True):
code: UpperCaseType # Automatically uses UpperCaseString()
```
**Option 2: Pydantic metadata with sa_type**:
```python
def __get_pydantic_core_schema__(self, source_type, handler):
return core_schema.json_or_python_schema(
json_schema=core_schema.str_schema(),
python_schema=core_schema.str_schema(),
metadata={'sa_type': UpperCaseString()}
)
```
**Option 3: Using Annotated**:
```python
from typing import Annotated
from sqlmodels.mixin import UUIDTableBaseMixin
UpperCase = Annotated[str, UpperCaseType()]
class MyModel(UUIDTableBaseMixin, table=True):
code: UpperCase
```
---
## Best Practices
### 1. Inherit from correct base classes
```python
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
# ✅ For non-table models (DTOs, requests, responses)
class UserBase(SQLModelBase):
name: str
# ✅ For table models with UUID primary key
class User(UserBase, UUIDTableBaseMixin, table=True):
email: str
# ✅ For table models with custom primary key
class LegacyUser(TableBaseMixin, table=True):
id: int = Field(primary_key=True)
username: str
```
### 2. Use docstrings for field descriptions
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Recommended
class User(UUIDTableBaseMixin, table=True):
name: str
"""User's display name"""
# ❌ Avoid
class User(UUIDTableBaseMixin, table=True):
name: str = Field(description="User's display name")
```
**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs.
### 3. Leverage automatic sa_type injection
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Clean and recommended
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: NumpyVector[256, np.float32]
"""Voice embedding"""
# ❌ Verbose and unnecessary
class SpeakerInfo(UUIDTableBaseMixin, table=True):
embedding: np.ndarray = Field(
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
)
```
### 4. Follow polymorphic naming conventions
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`.
### 5. Separate Base, Parent, and Implementation classes
```python
from abc import ABC, abstractmethod
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin
# ✅ Recommended structure
class ASRBase(SQLModelBase):
"""Pure data fields, no table"""
name: str
base_url: str
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
"""Abstract parent with table"""
@abstractmethod
async def transcribe(self, audio: bytes) -> str:
pass
class WhisperASR(ASR, table=True):
"""Concrete implementation"""
model_size: str
async def transcribe(self, audio: bytes) -> str:
# Implementation
pass
```
**Why?**
- Base class can be reused for DTOs
- Parent class defines the polymorphic hierarchy
- Implementation classes are clean and focused
---
## Troubleshooting
### Issue: ValueError: X has no matching SQLAlchemy type
**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`.
```python
# ✅ Provide __sqlmodel_sa_type__
class MyType:
__sqlmodel_sa_type__ = MyCustomSQLAlchemyType()
```
### Issue: Can't generate DDL for NullType()
**Symptoms**: Error during table creation saying a column has `NullType`.
**Root Cause**: Custom type's `sa_type` not detected by SQLModel.
**Solution**:
1. Ensure your type has `__sqlmodel_sa_type__` class attribute
2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`)
3. Verify type annotation is correct (not a string forward reference)
```python
from sqlmodels.mixin import UUIDTableBaseMixin
# ✅ Correct
class Model(UUIDTableBaseMixin, table=True):
data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected
# ❌ Wrong (string annotation)
class Model(UUIDTableBaseMixin, table=True):
data: 'NumpyVector[256, np.float32]' # sa_type lost
```
### Issue: Polymorphic identity conflicts
**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities.
**Solution**:
1. Check that each concrete class has a unique identity
2. Use `AutoPolymorphicIdentityMixin` for automatic naming
3. Manually specify identity if needed:
```python
class MyClass(Parent, polymorphic_identity='unique.name', table=True):
pass
```
### Issue: Python 3.14 annotation errors
**Symptoms**: Errors related to `__annotations__` or type resolution.
**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist:
1. Check for manual `__annotations__` manipulation (avoid it)
2. Ensure all types are properly imported
3. Avoid `from __future__ import annotations` (can cause SQLModel issues)
### Issue: Polymorphic and CRUD-related errors
For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## Implementation Details
For developers modifying this module:
**Core files**:
- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase`
- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin`
- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`
**Key functions in this module**:
1. **`_resolve_annotations(attrs: dict[str, Any])`**
- Uses `typing.get_type_hints()` for Python 3.14 compatibility
- Returns tuple: `(annotations, annotation_strings, globalns, localns)`
- Preserves `Annotated` metadata with `include_extras=True`
2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`**
- Extracts SQLAlchemy type from type annotations
- Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema
- Called by metaclass during class creation
3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+)
- Global monkey-patch for SQLModel
- Checks `__sqlmodel_sa_type__` before falling back to original logic
- Handles custom types like `NumpyVector` and `Array`
4. **`__DeclarativeMeta.__new__()`**
- Processes class definition parameters
- Injects `sa_type` into field definitions
- Sets up `__mapper_args__`, `__table_args__`, etc.
- Handles Python 3.14 annotations via `get_type_hints()`
**Metaclass processing order**:
1. Check if class should be a table (`_has_table_mixin`)
2. Collect `__mapper_args__` from kwargs and explicit dict
3. Process `table_args`, `table_name`, `abstract` parameters
4. Resolve annotations using `get_type_hints()`
5. For each field, try to extract `sa_type` and inject into Field
6. Call parent metaclass with cleaned kwargs
For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
---
## See Also
**Project Documentation**:
- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns
- [Project Coding Standards (CLAUDE.md)](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/CLAUDE.md)
- [Custom SQLModel Types Guide](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/sqlmodels/sqlmodel_types/README.md)
**External References**:
- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/)
- [SQLModel Documentation](https://sqlmodel.tiangolo.com/)
- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/)
- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/)
- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html)

View File

@@ -0,0 +1,12 @@
"""
SQLModel 基础模块
包含:
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
注意:
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
为了避免循环导入,此处不再重新导出它们
请直接从 sqlmodels.mixin 导入这些类
"""
from .sqlmodel_base import SQLModelBase

View File

@@ -0,0 +1,846 @@
import sys
import typing
from typing import Any, Mapping, get_args, get_origin, get_type_hints
from pydantic import ConfigDict
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined as Undefined
from sqlalchemy.orm import Mapped
from sqlmodel import Field, SQLModel
from sqlmodel.main import SQLModelMetaclass
# Python 3.14+ PEP 649支持
if sys.version_info >= (3, 14):
import annotationlib
# 全局Monkey-patch: 修复SQLModel在Python 3.14上的兼容性问题
import sqlmodel.main
_original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type
def _patched_get_sqlalchemy_type(field):
"""
修复SQLModel的get_sqlalchemy_type函数处理Python 3.14的类型问题。
问题:
1. ForwardRef对象来自Relationship字段会导致issubclass错误
2. typing._GenericAlias对象如ClassVar[T])也会导致同样问题
3. list/dict等泛型类型在没有Field/Relationship时可能导致错误
4. Mapped类型在Python 3.14下可能出现在annotation中
5. Annotated类型可能包含sa_type metadata如Array[T]
6. 自定义类型如NumpyVector有__sqlmodel_sa_type__属性
7. Pydantic已处理的Annotated类型会将metadata存储在field.metadata中
解决:
- 优先检查field.metadata中的__get_pydantic_core_schema__Pydantic已处理的情况
- 检测__sqlmodel_sa_type__属性NumpyVector等
- 检测Relationship/ClassVar等返回None
- 对于Annotated类型尝试提取sa_type metadata
- 其他情况调用原始函数
"""
# 优先检查 field.metadataPydantic已处理Annotated类型的情况
# 当使用 Array[T] 或 Annotated[T, metadata] 时Pydantic会将metadata存储在这里
metadata = getattr(field, 'metadata', None)
if metadata:
# metadata是一个列表包含所有Annotated的元数据项
for metadata_item in metadata:
# 检查metadata_item是否有__get_pydantic_core_schema__方法
if hasattr(metadata_item, '__get_pydantic_core_schema__'):
try:
# 调用获取schema
schema = metadata_item.__get_pydantic_core_schema__(None, None)
# 检查schema的metadata中是否有sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Pydantic schema获取可能失败类型不匹配、缺少属性等
# 这是正常情况继续检查下一个metadata项
pass
annotation = getattr(field, 'annotation', None)
if annotation is not None:
# 优先检查 __sqlmodel_sa_type__ 属性
# 这处理 NumpyVector[dims, dtype] 等自定义类型
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# 检查自定义类型如JSON100K的 __get_pydantic_core_schema__ 方法
# 这些类型在schema的metadata中定义sa_type
if hasattr(annotation, '__get_pydantic_core_schema__'):
try:
# 调用获取schema传None作为handler因为我们只需要metadata
schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None)
# 检查schema的metadata中是否有sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Schema获取失败继续其他检查
pass
anno_type_name = type(annotation).__name__
# ForwardRef: Relationship字段的annotation
if anno_type_name == 'ForwardRef':
return None
# AnnotatedAlias: 检查是否有sa_type metadata如Array[T]
if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias':
from typing import get_origin, get_args
import typing
# 尝试提取Annotated的metadata
if hasattr(typing, 'get_args'):
args = get_args(annotation)
# args[0]是实际类型args[1:]是metadata
for metadata in args[1:]:
# 检查metadata是否有__get_pydantic_core_schema__方法
if hasattr(metadata, '__get_pydantic_core_schema__'):
try:
# 调用获取schema
schema = metadata.__get_pydantic_core_schema__(None, None)
# 检查schema中是否有sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError):
# Annotated metadata的schema获取可能失败
# 这是正常的类型检查过程继续检查下一个metadata
pass
# _GenericAlias或GenericAlias: typing泛型类型
if anno_type_name in ('_GenericAlias', 'GenericAlias'):
from typing import get_origin
import typing
origin = get_origin(annotation)
# ClassVar必须跳过
if origin is typing.ClassVar:
return None
# list/dict/tuple/set等内置泛型如果字段没有明确的Field或Relationship也跳过
# 这通常意味着它是Relationship字段或类变量
if origin in (list, dict, tuple, set):
# 检查field_info是否存在且有意义
# Relationship字段会有特殊的field_info
field_info = getattr(field, 'field_info', None)
if field_info is None:
return None
# Mapped: SQLAlchemy 2.0的Mapped类型SQLModel不应该处理
# 这可能是从父类继承的字段或Python 3.14注解处理的副作用
# 检查类型名称和annotation的字符串表示
if 'Mapped' in anno_type_name or 'Mapped' in str(annotation):
return None
# 检查annotation是否是Mapped类或其实例
try:
from sqlalchemy.orm import Mapped as SAMapped
# 检查origin对于Mapped[T]这种泛型)
from typing import get_origin
if get_origin(annotation) is SAMapped:
return None
# 检查类型本身
if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped):
return None
except (ImportError, TypeError):
# 如果SQLAlchemy没有Mapped或检查失败继续
pass
# 其他情况正常处理
return _original_get_sqlalchemy_type(field)
sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type
# 第二个Monkey-patch: 修复继承表类中InstrumentedAttribute作为默认值的问题
# 在Python 3.14 + SQLModel组合下当子类如SMSBaoProvider继承父类如VerificationCodeProvider
# 父类的关系字段如server_config会在子类的model_fields中出现
# 但其default值错误地设置为InstrumentedAttribute对象而不是None
# 这导致实例化时尝试设置InstrumentedAttribute为字段值触发SQLAlchemy内部错误
import sqlmodel._compat as _compat
from sqlalchemy.orm import attributes as _sa_attributes
_original_sqlmodel_table_construct = _compat.sqlmodel_table_construct
def _patched_sqlmodel_table_construct(self_instance, values):
"""
修复sqlmodel_table_construct跳过InstrumentedAttribute默认值
问题:
- 继承自polymorphic基类的表类如FishAudioTTS, SMSBaoProvider
- 其model_fields中的继承字段default值为InstrumentedAttribute
- 原函数尝试将InstrumentedAttribute设置为字段值
- SQLAlchemy无法处理抛出 '_sa_instance_state' 错误
解决:
- 只设置用户提供的值和非InstrumentedAttribute默认值
- InstrumentedAttribute默认值跳过让SQLAlchemy自己处理
"""
cls = type(self_instance)
# 收集要设置的字段值
fields_to_set = {}
for name, field in cls.model_fields.items():
# 如果用户提供了值,直接使用
if name in values:
fields_to_set[name] = values[name]
continue
# 否则检查默认值
# 跳过InstrumentedAttribute默认值 - 这些是继承字段的错误默认值
if isinstance(field.default, _sa_attributes.InstrumentedAttribute):
continue
# 使用正常的默认值
if field.default is not Undefined:
fields_to_set[name] = field.default
elif field.default_factory is not None:
fields_to_set[name] = field.get_default(call_default_factory=True)
# 设置属性 - 只设置非InstrumentedAttribute值
for key, value in fields_to_set.items():
if not isinstance(value, _sa_attributes.InstrumentedAttribute):
setattr(self_instance, key, value)
# 设置Pydantic内部属性
object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys()))
if not cls.__pydantic_root_model__:
_extra = None
if cls.model_config.get('extra') == 'allow':
_extra = {}
for k, v in values.items():
if k not in cls.model_fields:
_extra[k] = v
object.__setattr__(self_instance, '__pydantic_extra__', _extra)
if cls.__pydantic_post_init__:
self_instance.model_post_init(None)
elif not cls.__pydantic_root_model__:
object.__setattr__(self_instance, '__pydantic_private__', None)
# 设置关系
for key in self_instance.__sqlmodel_relationships__:
value = values.get(key, Undefined)
if value is not Undefined:
setattr(self_instance, key, value)
return self_instance
_compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct
else:
annotationlib = None
def _extract_sa_type_from_annotation(annotation: Any) -> Any | None:
"""
从类型注解中提取SQLAlchemy类型。
支持以下形式:
1. NumpyVector[256, np.float32] - 直接使用类型有__sqlmodel_sa_type__属性
2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装
3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型
Args:
annotation: 字段的类型注解
Returns:
提取到的SQLAlchemy类型如果没有则返回None
"""
# 方法1直接检查类型本身是否有__sqlmodel_sa_type__属性
# 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况
if hasattr(annotation, '__sqlmodel_sa_type__'):
return annotation.__sqlmodel_sa_type__
# 方法2检查是否为Annotated类型
if get_origin(annotation) is typing.Annotated:
# 获取元数据项(跳过第一个实际类型参数)
args = get_args(annotation)
if len(args) >= 2:
metadata_items = args[1:] # 第一个是实际类型,后面都是元数据
# 遍历元数据查找包含sa_type的项
for item in metadata_items:
# 检查元数据项是否有__sqlmodel_sa_type__属性
if hasattr(item, '__sqlmodel_sa_type__'):
return item.__sqlmodel_sa_type__
# 检查是否有__get_pydantic_core_schema__方法
if hasattr(item, '__get_pydantic_core_schema__'):
try:
# 调用该方法获取core schema
schema = item.__get_pydantic_core_schema__(
annotation,
lambda x: None # 虚拟handler
)
# 检查schema的metadata中是否有sa_type
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError, ValueError):
# Pydantic core schema获取可能失败
# - TypeError: 参数不匹配
# - AttributeError: metadata不存在
# - KeyError: schema结构不符合预期
# - ValueError: 无效的类型定义
# 这是正常的类型探测过程继续检查下一个metadata项
pass
# 方法3检查类型本身是否有__get_pydantic_core_schema__
# 虽然NumpyVector已经在方法1处理但这是通用的fallback
if hasattr(annotation, '__get_pydantic_core_schema__'):
try:
schema = annotation.__get_pydantic_core_schema__(
annotation,
lambda x: None # 虚拟handler
)
if isinstance(schema, dict) and 'metadata' in schema:
sa_type = schema['metadata'].get('sa_type')
if sa_type is not None:
return sa_type
except (TypeError, AttributeError, KeyError, ValueError):
# 类型本身的schema获取失败
# 这是正常的fallback机制annotation可能不支持此协议
pass
return None
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[
dict[str, Any],
dict[str, str],
Mapping[str, Any],
Mapping[str, Any],
]:
"""
Resolve annotations from a class namespace with Python 3.14 (PEP 649) support.
This helper prefers evaluated annotations (Format.VALUE) so that `typing.Annotated`
metadata and custom types remain accessible. Forward references that cannot be
evaluated are replaced with typing.ForwardRef placeholders to avoid aborting the
whole resolution process.
"""
raw_annotations = attrs.get('__annotations__') or {}
try:
base_annotations = dict(raw_annotations)
except TypeError:
base_annotations = {}
module_name = attrs.get('__module__')
module_globals: dict[str, Any]
if module_name and module_name in sys.modules:
module_globals = dict(sys.modules[module_name].__dict__)
else:
module_globals = {}
module_globals.setdefault('__builtins__', __builtins__)
localns: dict[str, Any] = dict(attrs)
try:
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
temp_cls.__module__ = module_name
extras_kw = {'include_extras': True} if sys.version_info >= (3, 10) else {}
evaluated = get_type_hints(
temp_cls,
globalns=module_globals,
localns=localns,
**extras_kw,
)
except (NameError, AttributeError, TypeError, RecursionError):
# get_type_hints可能失败的原因
# - NameError: 前向引用无法解析(类型尚未定义)
# - AttributeError: 模块或类型不存在
# - TypeError: 无效的类型注解
# - RecursionError: 循环依赖的类型定义
# 这是正常情况,回退到原始注解字符串
evaluated = base_annotations
return dict(evaluated), {}, module_globals, localns
def _evaluate_annotation_from_string(
field_name: str,
annotation_strings: dict[str, str],
current_type: Any,
globalns: Mapping[str, Any],
localns: Mapping[str, Any],
) -> Any:
"""
Attempt to re-evaluate the original annotation string for a field.
This is used as a fallback when the resolved annotation lost its metadata
(e.g., Annotated wrappers) and we need to recover custom sa_type data.
"""
if not annotation_strings:
return current_type
expr = annotation_strings.get(field_name)
if not expr or not isinstance(expr, str):
return current_type
try:
return eval(expr, globalns, localns)
except (NameError, SyntaxError, AttributeError, TypeError):
# eval可能失败的原因
# - NameError: 类型名称在namespace中不存在
# - SyntaxError: 注解字符串有语法错误
# - AttributeError: 访问不存在的模块属性
# - TypeError: 无效的类型表达式
# 这是正常的fallback机制返回当前已解析的类型
return current_type
class __DeclarativeMeta(SQLModelMetaclass):
"""
一个智能的混合模式元类,它提供了灵活性和清晰度:
1. **自动设置 `table=True`**: 如果一个类继承了 `TableBaseMixin`,则自动应用 `table=True`。
2. **明确的字典参数**: 支持 `mapper_args={...}`, `table_args={...}`, `table_name='...'`。
3. **便捷的关键字参数**: 支持最常见的 mapper 参数作为顶级关键字(如 `polymorphic_on`)。
4. **智能合并**: 当字典和关键字同时提供时,会自动合并,且关键字参数有更高优先级。
"""
_KNOWN_MAPPER_KEYS = {
"polymorphic_on",
"polymorphic_identity",
"polymorphic_abstract",
"version_id_col",
"concrete",
}
def __new__(cls, name, bases, attrs, **kwargs):
# 1. 约定优于配置:自动设置 table=True
is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
if is_intended_as_table and 'table' not in kwargs:
kwargs['table'] = True
# 2. 智能合并 __mapper_args__
collected_mapper_args = {}
# 首先,处理明确的 mapper_args 字典 (优先级较低)
if 'mapper_args' in kwargs:
collected_mapper_args.update(kwargs.pop('mapper_args'))
# 其次,处理便捷的关键字参数 (优先级更高)
for key in cls._KNOWN_MAPPER_KEYS:
if key in kwargs:
# .pop() 获取值并移除,避免传递给父类
collected_mapper_args[key] = kwargs.pop(key)
# 如果收集到了任何 mapper 参数,则更新到类的属性中
if collected_mapper_args:
existing = attrs.get('__mapper_args__', {}).copy()
existing.update(collected_mapper_args)
attrs['__mapper_args__'] = existing
# 3. 处理其他明确的参数
if 'table_args' in kwargs:
attrs['__table_args__'] = kwargs.pop('table_args')
if 'table_name' in kwargs:
attrs['__tablename__'] = kwargs.pop('table_name')
if 'abstract' in kwargs:
attrs['__abstract__'] = kwargs.pop('abstract')
# 4. 从Annotated元数据中提取sa_type并注入到Field
# 重要必须在调用父类__new__之前处理因为SQLModel会消费annotations
#
# Python 3.14兼容性问题:
# - SQLModel在Python 3.14上会因为ClassVar[T]类型而崩溃issubclass错误
# - 我们必须在SQLModel看到annotations之前过滤掉ClassVar字段
# - 虽然PEP 749建议不修改__annotations__但这是修复SQLModel bug的必要措施
#
# 获取annotations的策略
# - Python 3.14+: 优先从__annotate__获取如果存在
# - fallback: 从__annotations__读取如果存在
# - 最终fallback: 空字典
annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs)
if annotations:
attrs['__annotations__'] = annotations
if annotationlib is not None:
# 在Python 3.14中禁用descriptor转为普通dict
attrs['__annotate__'] = None
for field_name, field_type in annotations.items():
field_type = _evaluate_annotation_from_string(
field_name,
annotation_strings,
field_type,
eval_globals,
eval_locals,
)
# 跳过字符串或ForwardRef类型注解让SQLModel自己处理
if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef):
continue
# 跳过特殊类型的字段
origin = get_origin(field_type)
# 跳过 ClassVar 字段 - 它们不是数据库字段
if origin is typing.ClassVar:
continue
# 跳过 Mapped 字段 - SQLAlchemy 2.0+ 的声明式字段,已经有 mapped_column
if origin is Mapped:
continue
# 尝试从注解中提取sa_type
sa_type = _extract_sa_type_from_annotation(field_type)
if sa_type is not None:
# 检查字段是否已有Field定义
field_value = attrs.get(field_name, Undefined)
if field_value is Undefined:
# 没有Field定义创建一个新的Field并注入sa_type
attrs[field_name] = Field(sa_type=sa_type)
elif isinstance(field_value, FieldInfo):
# 已有Field定义检查是否已设置sa_type
# 注意:只有在未设置时才注入,尊重显式配置
# SQLModel使用Undefined作为"未设置"的标记
if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined:
field_value.sa_type = sa_type
# 如果field_value是其他类型如默认值不处理
# SQLModel会在后续处理中将其转换为Field
# 5. 调用父类的 __new__ 方法,传入被清理过的 kwargs
result = super().__new__(cls, name, bases, attrs, **kwargs)
# 6. 修复:在联表继承场景下,继承父类的 __sqlmodel_relationships__
# SQLModel 为每个 table=True 的类创建新的空 __sqlmodel_relationships__
# 这导致子类丢失父类的关系定义,触发错误的 Column 创建
# 必须在 super().__new__() 之后修复,因为 SQLModel 会覆盖我们预设的值
if kwargs.get('table', False):
for base in bases:
if hasattr(base, '__sqlmodel_relationships__'):
for rel_name, rel_info in base.__sqlmodel_relationships__.items():
# 只继承子类没有重新定义的关系
if rel_name not in result.__sqlmodel_relationships__:
result.__sqlmodel_relationships__[rel_name] = rel_info
# 同时修复被错误创建的 Column - 恢复为父类的 relationship
if hasattr(base, rel_name):
base_attr = getattr(base, rel_name)
setattr(result, rel_name, base_attr)
# 7. 检测:禁止子类重定义父类的 Relationship 字段
# 子类重定义同名的 Relationship 字段会导致 SQLAlchemy 关系映射混乱,
# 应该在类定义时立即报错,而不是在运行时出现难以调试的问题。
for base in bases:
parent_relationships = getattr(base, '__sqlmodel_relationships__', {})
for rel_name in parent_relationships:
# 检查当前类是否在 attrs 中重新定义了这个关系字段
if rel_name in attrs:
raise TypeError(
f"{name} 不允许重定义父类 {base.__name__} 的 Relationship 字段 '{rel_name}'"
f"如需修改关系配置,请在父类中修改。"
)
# 8. 修复:从 model_fields/__pydantic_fields__ 中移除 Relationship 字段
# SQLModel 0.0.27 bug子类会错误地继承父类的 Relationship 字段到 model_fields
# 这导致 Pydantic 尝试为 Relationship 字段生成 schema因为类型是
# Mapped[list['Character']] 这种前向引用Pydantic 无法解析,
# 导致 __pydantic_complete__ = False
#
# 修复策略:
# - 检查类的 __sqlmodel_relationships__ 属性
# - 从 model_fields 和 __pydantic_fields__ 中移除这些字段
# - Relationship 字段由 SQLAlchemy 管理,不需要 Pydantic 参与
relationships = getattr(result, '__sqlmodel_relationships__', {})
if relationships:
model_fields = getattr(result, 'model_fields', {})
pydantic_fields = getattr(result, '__pydantic_fields__', {})
fields_removed = False
for rel_name in relationships:
if rel_name in model_fields:
del model_fields[rel_name]
fields_removed = True
if rel_name in pydantic_fields:
del pydantic_fields[rel_name]
fields_removed = True
# 如果移除了字段,重新构建 Pydantic 模式
# 注意:只在有字段被移除时才 rebuild避免不必要的开销
if fields_removed and hasattr(result, 'model_rebuild'):
result.model_rebuild(force=True)
return result
def __init__(
cls,
classname: str,
bases: tuple[type, ...],
dict_: dict[str, typing.Any],
**kw: typing.Any,
) -> None:
"""
重写 SQLModel 的 __init__ 以支持联表继承Joined Table Inheritance
SQLModel 原始行为:
- 如果任何基类是表模型,则不调用 DeclarativeMeta.__init__
- 这阻止了子类创建自己的表
修复逻辑:
- 检测联表继承场景(子类有自己的 __tablename__ 且有外键指向父表)
- 强制调用 DeclarativeMeta.__init__ 来创建子表
"""
from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass
# 检查是否是表模型
if not is_table_model_class(cls):
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
return
# 检查是否有基类是表模型
base_is_table = any(is_table_model_class(base) for base in bases)
if not base_is_table:
# 没有基类是表模型,走正常的 SQLModel 流程
# 处理关系字段
cls._setup_relationships()
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
return
# 关键:检测联表继承场景
# 条件:
# 1. 当前类的 __tablename__ 与父类不同(表示需要新表)
# 2. 当前类有字段带有 foreign_key 指向父表
current_tablename = getattr(cls, '__tablename__', None)
# 查找父表信息
parent_table = None
parent_tablename = None
for base in bases:
if is_table_model_class(base) and hasattr(base, '__tablename__'):
parent_tablename = base.__tablename__
break
# 检查是否有不同的 tablename
has_different_tablename = (
current_tablename is not None
and parent_tablename is not None
and current_tablename != parent_tablename
)
# 检查是否有外键字段指向父表的主键
# 注意:由于字段合并,我们需要检查直接基类的 model_fields
# 而不是当前类的合并后的 model_fields
has_fk_to_parent = False
def _normalize_tablename(name: str) -> str:
"""标准化表名以进行比较(移除下划线,转小写)"""
return name.replace('_', '').lower()
def _fk_matches_parent(fk_str: str, parent_table: str) -> bool:
"""检查 FK 字符串是否指向父表"""
if not fk_str or not parent_table:
return False
# FK 格式: "tablename.column" 或 "schema.tablename.column"
parts = fk_str.split('.')
if len(parts) >= 2:
fk_table = parts[-2] # 取倒数第二个作为表名
# 标准化比较(处理下划线差异)
return _normalize_tablename(fk_table) == _normalize_tablename(parent_table)
return False
if has_different_tablename and parent_tablename:
# 首先检查当前类的 model_fields
for field_name, field_info in cls.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
has_fk_to_parent = True
break
# 如果没找到,检查直接基类的 model_fields解决 mixin 字段被覆盖的问题)
if not has_fk_to_parent:
for base in bases:
if hasattr(base, 'model_fields'):
for field_name, field_info in base.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
has_fk_to_parent = True
break
if has_fk_to_parent:
break
is_joined_inheritance = has_different_tablename and has_fk_to_parent
if is_joined_inheritance:
# 联表继承:需要创建子表
# 修复外键字段:由于字段合并,外键信息可能丢失
# 需要从基类的 mixin 中找回外键信息,并重建列
from sqlalchemy import Column, ForeignKey, inspect as sa_inspect
from sqlalchemy.dialects.postgresql import UUID as SA_UUID
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm.attributes import InstrumentedAttribute
# 联表继承:子表只应该有 idFK 到父表)+ 子类特有的字段
# 所有继承自祖先表的列都不应该在子表中重复创建
# 收集整个继承链中所有祖先表的列名(这些列不应该在子表中重复)
# 需要遍历整个 MRO因为可能是多级继承如 Tool -> Function -> GetWeatherFunction
ancestor_column_names: set[str] = set()
for ancestor in cls.__mro__:
if ancestor is cls:
continue # 跳过当前类
if is_table_model_class(ancestor):
try:
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.local_table 是公开属性 (mapper.py:979-998)
mapper = sa_inspect(ancestor)
for col in mapper.local_table.columns:
# 跳过 _polymorphic_name 列(鉴别器,由根父表管理)
if col.name.startswith('_polymorphic'):
continue
ancestor_column_names.add(col.name)
except NoInspectionAvailable:
continue
# 找到子类自己定义的字段(不在父类中的)
child_own_fields: set[str] = set()
for field_name in cls.model_fields:
# 检查这个字段是否是在当前类直接定义的(不是继承的)
# 通过检查父类是否有这个字段来判断
is_inherited = False
for base in bases:
if hasattr(base, 'model_fields') and field_name in base.model_fields:
is_inherited = True
break
if not is_inherited:
child_own_fields.add(field_name)
# 从子类类属性中移除父表已有的列定义
# 这样 SQLAlchemy 就不会在子表中创建这些列
fk_field_name = None
for base in bases:
if hasattr(base, 'model_fields'):
for field_name, field_info in base.model_fields.items():
fk = getattr(field_info, 'foreign_key', None)
pk = getattr(field_info, 'primary_key', False)
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
fk_field_name = field_name
# 找到了外键字段,重建它
# 创建一个新的 Column 对象包含外键约束
new_col = Column(
field_name,
SA_UUID(as_uuid=True),
ForeignKey(fk),
primary_key=pk if pk else False
)
setattr(cls, field_name, new_col)
break
else:
continue
break
# 移除继承自祖先表的列属性(除了 FK/PK 和子类自己的字段)
# 这防止 SQLAlchemy 在子表中创建重复列
# 注意:在 __init__ 阶段,列是 Column 对象,不是 InstrumentedAttribute
for col_name in ancestor_column_names:
if col_name == fk_field_name:
continue # 保留 FK/PK 列(子表的主键,同时是父表的外键)
if col_name == 'id':
continue # id 会被 FK 字段覆盖
if col_name in child_own_fields:
continue # 保留子类自己定义的字段
# 检查类属性是否是 Column 或 InstrumentedAttribute
if col_name in cls.__dict__:
attr = cls.__dict__[col_name]
# Column 对象或 InstrumentedAttribute 都需要删除
if isinstance(attr, (Column, InstrumentedAttribute)):
try:
delattr(cls, col_name)
except AttributeError:
pass
# 找到子类自己定义的关系(不在父类中的)
# 继承的关系会从父类自动获取,只需要设置子类新增的关系
child_own_relationships: set[str] = set()
for rel_name in cls.__sqlmodel_relationships__:
is_inherited = False
for base in bases:
if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__:
is_inherited = True
break
if not is_inherited:
child_own_relationships.add(rel_name)
# 只为子类自己定义的新关系调用关系设置
if child_own_relationships:
cls._setup_relationships(only_these=child_own_relationships)
# 强制调用 DeclarativeMeta.__init__
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
else:
# 非联表继承:单表继承或正常 Pydantic 模型
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
def _setup_relationships(cls, only_these: set[str] | None = None) -> None:
"""
设置 SQLAlchemy 关系字段(从 SQLModel 源码复制)
Args:
only_these: 如果提供,只设置这些关系(用于 joined table inheritance 子类)
如果为 None设置所有关系默认行为
"""
from sqlalchemy.orm import relationship, Mapped
from sqlalchemy import inspect
from sqlmodel.main import get_relationship_to
from typing import get_origin
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
# 如果指定了 only_these只设置这些关系
if only_these is not None and rel_name not in only_these:
continue
if rel_info.sa_relationship:
setattr(cls, rel_name, rel_info.sa_relationship)
continue
raw_ann = cls.__annotations__[rel_name]
origin: typing.Any = get_origin(raw_ann)
if origin is Mapped:
ann = raw_ann.__args__[0]
else:
ann = raw_ann
cls.__annotations__[rel_name] = Mapped[ann]
relationship_to = get_relationship_to(
name=rel_name, rel_info=rel_info, annotation=ann
)
rel_kwargs: dict[str, typing.Any] = {}
if rel_info.back_populates:
rel_kwargs["back_populates"] = rel_info.back_populates
if rel_info.cascade_delete:
rel_kwargs["cascade"] = "all, delete-orphan"
if rel_info.passive_deletes:
rel_kwargs["passive_deletes"] = rel_info.passive_deletes
if rel_info.link_model:
ins = inspect(rel_info.link_model)
local_table = getattr(ins, "local_table")
if local_table is None:
raise RuntimeError(
f"Couldn't find secondary table for {rel_info.link_model}"
)
rel_kwargs["secondary"] = local_table
rel_args: list[typing.Any] = []
if rel_info.sa_relationship_args:
rel_args.extend(rel_info.sa_relationship_args)
if rel_info.sa_relationship_kwargs:
rel_kwargs.update(rel_info.sa_relationship_kwargs)
rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
setattr(cls, rel_name, rel_value)
class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta):
"""此类必须和TableBase系列类搭配使用"""
model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)

7
sqlmodels/color.py Normal file
View File

@@ -0,0 +1,7 @@
from .base import SQLModelBase
class ThemeResponse(SQLModelBase):
"""主题响应 DTO"""
pass

33
sqlmodels/database.py Normal file
View File

@@ -0,0 +1,33 @@
from sqlmodel import SQLModel
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from utils.conf import appmeta
from sqlalchemy.orm import sessionmaker
from typing import AsyncGenerator
ASYNC_DATABASE_URL = appmeta.database_url
engine: AsyncEngine = create_async_engine(
ASYNC_DATABASE_URL,
echo=appmeta.debug,
connect_args={
"check_same_thread": False
} if ASYNC_DATABASE_URL.startswith("sqlite") else {},
future=True,
# pool_size=POOL_SIZE,
# max_overflow=64,
)
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async with _async_session_factory() as session:
yield session
async def init_db(
url: str = ASYNC_DATABASE_URL
):
"""创建数据库结构"""
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

View File

@@ -0,0 +1,78 @@
from typing import AsyncGenerator, ClassVar
from loguru import logger
from sqlalchemy import NullPool, AsyncAdaptedQueuePool
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
class DatabaseManager:
engine: ClassVar[AsyncEngine | None] = None
_async_session_factory: ClassVar[sessionmaker | None] = None
@classmethod
async def get_session(cls) -> AsyncGenerator[AsyncSession]:
assert cls._async_session_factory is not None, "数据库引擎未初始化,请先调用 DatabaseManager.init()"
async with cls._async_session_factory() as session:
yield session
@classmethod
async def init(
cls,
database_url: str,
debug: bool = False,
):
"""
初始化数据库连接引擎。
:param database_url: 数据库连接URL
:param debug: 是否开启调试模式
"""
# 构建引擎参数
engine_kwargs: dict = {
'echo': debug,
'future': True,
}
if debug:
# Debug 模式使用 NullPool无连接池每次创建新连接
engine_kwargs['poolclass'] = NullPool
else:
# 生产模式使用 AsyncAdaptedQueuePool 连接池
engine_kwargs.update({
'poolclass': AsyncAdaptedQueuePool,
'pool_size': 40,
'max_overflow': 80,
'pool_timeout': 30,
'pool_recycle': 1800,
'pool_pre_ping': True,
})
# 只在需要时添加 connect_args
if database_url.startswith("sqlite"):
engine_kwargs['connect_args'] = {'check_same_thread': False}
cls.engine = create_async_engine(database_url, **engine_kwargs)
cls._async_session_factory = sessionmaker(cls.engine, class_=AsyncSession)
# 开发阶段直接 create_all 创建表结构
async with cls.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
logger.info("数据库引擎初始化完成")
@classmethod
async def close(cls):
"""
优雅地关闭数据库连接引擎。
仅应在应用结束时调用。
"""
if cls.engine:
logger.info("正在关闭数据库连接引擎...")
await cls.engine.dispose()
logger.info("数据库连接引擎已成功关闭。")
else:
logger.info("数据库连接引擎未初始化,无需关闭。")

198
sqlmodels/download.py Normal file
View File

@@ -0,0 +1,198 @@
from enum import StrEnum
from typing import TYPE_CHECKING, Annotated
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, Index
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin, TableBaseMixin
if TYPE_CHECKING:
from .user import User
from .task import Task
from .node import Node
class DownloadStatus(StrEnum):
"""下载状态枚举"""
PREPARING = "preparing"
"""准备中"""
RUNNING = "running"
"""进行中"""
COMPLETED = "completed"
"""已完成"""
ERROR = "error"
"""错误"""
class DownloadType(StrEnum):
"""下载类型枚举"""
# [TODO] 补充具体下载类型
pass
# ==================== Aria2 信息模型 ====================
class DownloadAria2InfoBase(SQLModelBase):
"""Aria2下载信息基础模型"""
info_hash: Annotated[str | None, Field(max_length=40)] = None
"""InfoHashBT种子"""
piece_length: int = 0
"""分片大小"""
num_pieces: int = 0
"""分片数量"""
num_seeders: int = 0
"""做种人数"""
connections: int = 0
"""连接数"""
upload_speed: int = 0
"""上传速度bytes/s"""
upload_length: int = 0
"""已上传大小(字节)"""
error_code: str | None = None
"""错误代码"""
error_message: str | None = None
"""错误信息"""
class DownloadAria2Info(DownloadAria2InfoBase, SQLModelBase, table=True):
"""Aria2下载信息模型与Download一对一关联"""
download_id: UUID = Field(
foreign_key="download.id",
primary_key=True,
ondelete="CASCADE"
)
"""关联的下载任务UUID"""
# 反向关系
download: "Download" = Relationship(back_populates="aria2_info")
"""关联的下载任务"""
class DownloadAria2File(SQLModelBase, TableBaseMixin):
"""Aria2下载文件列表与Download一对多关联"""
download_id: UUID = Field(
foreign_key="download.id",
index=True,
ondelete="CASCADE"
)
"""关联的下载任务UUID"""
file_index: int = Field(ge=1)
"""文件索引从1开始"""
path: str
"""文件路径"""
length: int = 0
"""文件大小(字节)"""
completed_length: int = 0
"""已完成大小(字节)"""
is_selected: bool = True
"""是否选中下载"""
# 反向关系
download: "Download" = Relationship(back_populates="aria2_files")
"""关联的下载任务"""
# ==================== 主模型 ====================
class DownloadBase(SQLModelBase):
pass
class Download(DownloadBase, UUIDTableBaseMixin):
"""离线下载任务模型"""
__table_args__ = (
UniqueConstraint("node_id", "g_id", name="uq_download_node_gid"),
Index("ix_download_user_status", "user_id", "status"),
)
status: DownloadStatus = Field(default=DownloadStatus.PREPARING, index=True)
"""下载状态"""
type: int = Field(default=0)
"""任务类型 [TODO] 待定义枚举"""
source: str
"""来源URL或标识"""
total_size: int = Field(default=0)
"""总大小(字节)"""
downloaded_size: int = Field(default=0)
"""已下载大小(字节)"""
g_id: str | None = Field(default=None, index=True)
"""Aria2 GID"""
speed: int = Field(default=0)
"""下载速度bytes/s"""
parent: str | None = Field(default=None, max_length=255)
"""父任务标识"""
error: str | None = Field(default=None)
"""错误信息"""
dst: str
"""目标存储路径"""
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""所属用户UUID"""
task_id: int | None = Field(
default=None,
foreign_key="task.id",
index=True,
ondelete="SET NULL"
)
"""关联的任务ID"""
node_id: int = Field(
foreign_key="node.id",
index=True,
ondelete="RESTRICT"
)
"""执行下载的节点ID"""
# 关系
aria2_info: DownloadAria2Info | None = Relationship(
back_populates="download",
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
)
"""Aria2下载信息"""
aria2_files: list[DownloadAria2File] = Relationship(
back_populates="download",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
"""Aria2文件列表"""
user: "User" = Relationship(back_populates="downloads")
"""所属用户"""
task: "Task" = Relationship(back_populates="downloads")
"""关联的任务"""
node: "Node" = Relationship(back_populates="downloads")
"""执行下载的节点"""

299
sqlmodels/group.py Normal file
View File

@@ -0,0 +1,299 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, text
from .base import SQLModelBase
from .mixin import TableBaseMixin, UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
from .policy import Policy
# ==================== Base 模型 ====================
class GroupBase(SQLModelBase):
"""用户组基础字段,供数据库模型和 DTO 共享"""
name: str
"""用户组名称"""
class GroupOptionsBase(SQLModelBase):
"""用户组基础选项字段"""
share_download: bool = False
"""是否允许分享下载"""
share_free: bool = False
"""是否免积分获取需要积分的内容"""
relocate: bool = False
"""是否允许文件重定位"""
source_batch: int = 0
"""批量获取源地址数量"""
select_node: bool = False
"""是否允许选择节点"""
advance_delete: bool = False
"""是否允许高级删除"""
class GroupAllOptionsBase(GroupOptionsBase):
"""用户组完整选项字段,供 DTO 和数据库模型共享"""
archive_download: bool = False
"""是否允许打包下载"""
archive_task: bool = False
"""是否允许创建打包任务"""
webdav_proxy: bool = False
"""是否允许WebDAV代理"""
aria2: bool = False
"""是否允许使用aria2"""
redirected_source: bool = False
"""是否使用重定向源"""
# ==================== DTO 模型 ====================
class GroupCreateRequest(GroupAllOptionsBase):
"""创建用户组请求 DTO"""
name: str = Field(max_length=255)
"""用户组名称"""
max_storage: int = Field(default=0, ge=0)
"""最大存储空间字节0表示不限制"""
share_enabled: bool = False
"""是否允许创建分享"""
web_dav_enabled: bool = False
"""是否允许使用WebDAV"""
speed_limit: int = Field(default=0, ge=0)
"""速度限制 (KB/s), 0为不限制"""
source_batch: int = Field(default=0, ge=0)
"""批量获取源地址数量(覆盖基类以添加 ge 约束)"""
policy_ids: list[UUID] = []
"""关联的存储策略UUID列表"""
class GroupUpdateRequest(SQLModelBase):
"""更新用户组请求 DTO所有字段可选"""
name: str | None = Field(default=None, max_length=255)
"""用户组名称"""
max_storage: int | None = Field(default=None, ge=0)
"""最大存储空间(字节)"""
share_enabled: bool | None = None
"""是否允许创建分享"""
web_dav_enabled: bool | None = None
"""是否允许使用WebDAV"""
speed_limit: int | None = Field(default=None, ge=0)
"""速度限制 (KB/s)"""
# 用户组选项
share_download: bool | None = None
share_free: bool | None = None
relocate: bool | None = None
source_batch: int | None = None
select_node: bool | None = None
advance_delete: bool | None = None
archive_download: bool | None = None
archive_task: bool | None = None
webdav_proxy: bool | None = None
aria2: bool | None = None
redirected_source: bool | None = None
policy_ids: list[UUID] | None = None
"""关联的存储策略UUID列表"""
class GroupCoreBase(SQLModelBase):
"""用户组核心字段(从 Group 模型提取)"""
id: UUID
"""用户组UUID"""
name: str
"""用户组名称"""
max_storage: int = 0
"""最大存储空间(字节)"""
share_enabled: bool = False
"""是否允许创建分享"""
web_dav_enabled: bool = False
"""是否允许使用WebDAV"""
admin: bool = False
"""是否为管理员组"""
speed_limit: int = 0
"""速度限制 (KB/s)"""
class GroupDetailResponse(GroupCoreBase, GroupAllOptionsBase):
"""用户组详情响应 DTO"""
user_count: int = 0
"""用户数量"""
policy_ids: list[UUID] = []
"""关联的存储策略UUID列表"""
@classmethod
def from_group(
cls,
group: "Group",
user_count: int,
policies: list["Policy"],
) -> "GroupDetailResponse":
"""从 Group ORM 对象构建"""
opts = group.options
return cls(
# GroupCoreBase 字段(从 Group 模型提取)
**GroupCoreBase.model_validate(group, from_attributes=True).model_dump(),
# GroupAllOptionsBase 字段(从 GroupOptions 提取)
**(GroupAllOptionsBase.model_validate(opts, from_attributes=True).model_dump() if opts else {}),
# 计算字段
user_count=user_count,
policy_ids=[p.id for p in policies],
)
class GroupListResponse(SQLModelBase):
"""用户组列表响应 DTO"""
groups: list["GroupDetailResponse"] = []
"""用户组列表"""
total: int = 0
"""总数"""
class GroupResponse(GroupBase, GroupOptionsBase):
"""用户组响应 DTO"""
id: UUID
"""用户组UUID"""
allow_share: bool = False
"""是否允许分享"""
allow_remote_download: bool = False
"""是否允许离线下载"""
allow_archive_download: bool = False
"""是否允许打包下载"""
compress: bool = False
"""是否允许压缩"""
webdav: bool = False
"""是否允许WebDAV"""
allow_webdav_proxy: bool = False
"""是否允许WebDAV代理"""
# ==================== 数据库模型 ====================
# GroupPolicyLink 定义在 policy.py 中以避免循环导入
from .policy import GroupPolicyLink
class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
"""用户组选项模型"""
group_id: UUID = Field(
foreign_key="group.id",
unique=True,
ondelete="CASCADE"
)
"""关联的用户组UUID"""
# 反向关系
group: "Group" = Relationship(back_populates="options")
class Group(GroupBase, UUIDTableBaseMixin):
"""用户组模型"""
name: str = Field(max_length=255, unique=True)
"""用户组名"""
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""最大存储空间(字节)"""
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否允许创建分享"""
web_dav_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否允许使用WebDAV"""
admin: bool = False
"""是否为管理员组"""
speed_limit: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""速度限制 (KB/s), 0为不限制"""
# 一对一关系:用户组选项
options: GroupOptions | None = Relationship(
back_populates="group",
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"}
)
# 多对多关系:用户组可以关联多个存储策略
policies: list["Policy"] = Relationship(
back_populates="groups",
link_model=GroupPolicyLink,
)
# 关系:一个组可以有多个用户
users: list["User"] = Relationship(
back_populates="group",
sa_relationship_kwargs={"foreign_keys": "User.group_id"}
)
"""当前属于该组的用户列表"""
previous_users: list["User"] = Relationship(
back_populates="previous_group",
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
)
"""之前属于该组的用户列表(用于过期后恢复)"""
def to_response(self) -> "GroupResponse":
"""转换为响应 DTO"""
opts = self.options
return GroupResponse(
id=self.id,
name=self.name,
allow_share=self.share_enabled,
webdav=self.web_dav_enabled,
share_download=opts.share_download if opts else False,
share_free=opts.share_free if opts else False,
relocate=opts.relocate if opts else False,
source_batch=opts.source_batch if opts else 0,
select_node=opts.select_node if opts else False,
advance_delete=opts.advance_delete if opts else False,
allow_remote_download=opts.aria2 if opts else False,
allow_archive_download=opts.archive_download if opts else False,
allow_webdav_proxy=opts.webdav_proxy if opts else False,
)

326
sqlmodels/migration.py Normal file
View File

@@ -0,0 +1,326 @@
from .setting import Setting, SettingsType
from .color import ThemeResponse
from utils.conf.appmeta import BackendVersion
from utils.password.pwd import Password
from loguru import logger as log
async def migration() -> None:
"""
数据库迁移函数,初始化默认设置和用户组。
:return: None
"""
log.info('开始进行数据库初始化...')
await init_default_settings()
await init_default_policy()
await init_default_group()
await init_default_user()
log.info('数据库初始化结束')
default_settings: list[Setting] = [
Setting(name="siteURL", value="http://localhost", type=SettingsType.BASIC),
Setting(name="siteName", value="DiskNext", type=SettingsType.BASIC),
Setting(name="register_enabled", value="1", type=SettingsType.REGISTER),
Setting(name="default_group", value="", type=SettingsType.REGISTER),
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
Setting(name="fromName", value="DiskNext", type=SettingsType.MAIL),
Setting(name="mail_keepalive", value="30", type=SettingsType.MAIL),
Setting(name="fromAdress", value="no-reply@yxqi.cn", type=SettingsType.MAIL),
Setting(name="smtpHost", value="smtp.yxqi.cn", type=SettingsType.MAIL),
Setting(name="smtpPort", value="25", type=SettingsType.MAIL),
Setting(name="replyTo", value="feedback@yxqi.cn", type=SettingsType.MAIL),
Setting(name="smtpUser", value="no-reply@yxqi.cn", type=SettingsType.MAIL),
Setting(name="smtpPass", value="", type=SettingsType.MAIL),
Setting(name="maxEditSize", value="4194304", type=SettingsType.FILE_EDIT),
Setting(name="archive_timeout", value="60", type=SettingsType.TIMEOUT),
Setting(name="download_timeout", value="60", type=SettingsType.TIMEOUT),
Setting(name="preview_timeout", value="60", type=SettingsType.TIMEOUT),
Setting(name="doc_preview_timeout", value="60", type=SettingsType.TIMEOUT),
Setting(name="upload_credential_timeout", value="1800", type=SettingsType.TIMEOUT),
Setting(name="upload_session_timeout", value="86400", type=SettingsType.TIMEOUT),
Setting(name="slave_api_timeout", value="60", type=SettingsType.TIMEOUT),
Setting(name="onedrive_monitor_timeout", value="600", type=SettingsType.TIMEOUT),
Setting(name="share_download_session_timeout", value="2073600", type=SettingsType.TIMEOUT),
Setting(name="onedrive_callback_check", value="20", type=SettingsType.TIMEOUT),
Setting(name="aria2_call_timeout", value="5", type=SettingsType.TIMEOUT),
Setting(name="onedrive_chunk_retries", value="1", type=SettingsType.RETRY),
Setting(name="onedrive_source_timeout", value="1800", type=SettingsType.TIMEOUT),
Setting(name="reset_after_upload_failed", value="0", type=SettingsType.UPLOAD),
Setting(name="login_captcha", value="0", type=SettingsType.LOGIN),
Setting(name="reg_captcha", value="0", type=SettingsType.LOGIN),
Setting(name="email_active", value="0", type=SettingsType.REGISTER),
Setting(name="mail_activation_template", value="""<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>激活您的账户</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #009688; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">激活{siteTitle}账户</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您注册{siteTitle},请点击下方按钮完成账户激活。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{activationUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #009688; margin: 0; border-color: #009688; border-style: solid; border-width: 10px 20px;">激活账户</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>""", type=SettingsType.MAIL_TEMPLATE),
Setting(name="forget_captcha", value="0", type=SettingsType.LOGIN),
Setting(name="mail_reset_pwd_template", value="""<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>重设密码</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #2196F3; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">重设{siteTitle}密码</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">请点击下方按钮完成密码重设。如果非你本人操作,请忽略此邮件。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{resetUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #2196F3; margin: 0; border-color: #2196F3; border-style: solid; border-width: 10px 20px;">重设密码</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>""", type=SettingsType.MAIL_TEMPLATE),
Setting(name=f"db_version_{BackendVersion}", value="installed", type=SettingsType.VERSION),
Setting(name="hot_share_num", value="10", type=SettingsType.SHARE),
Setting(name="gravatar_server", value="https://www.gravatar.com/", type=SettingsType.AVATAR),
Setting(name="defaultTheme", value="#3f51b5", type=SettingsType.BASIC),
Setting(name="themes", value=ThemeResponse().model_dump_json(), type=SettingsType.BASIC),
Setting(name="aria2_token", value="", type=SettingsType.ARIA2),
Setting(name="aria2_rpcurl", value="", type=SettingsType.ARIA2),
Setting(name="aria2_temp_path", value="", type=SettingsType.ARIA2),
Setting(name="aria2_options", value="{}", type=SettingsType.ARIA2),
Setting(name="aria2_interval", value="60", type=SettingsType.ARIA2),
Setting(name="max_worker_num", value="10", type=SettingsType.TASK),
Setting(name="max_parallel_transfer", value="4", type=SettingsType.TASK),
Setting(name="secret_key", value=Password.generate(256), type=SettingsType.AUTH),
Setting(name="temp_path", value="temp", type=SettingsType.PATH),
Setting(name="avatar_path", value="avatar", type=SettingsType.PATH),
Setting(name="avatar_size", value="2097152", type=SettingsType.AVATAR),
Setting(name="avatar_size_l", value="200", type=SettingsType.AVATAR),
Setting(name="avatar_size_m", value="130", type=SettingsType.AVATAR),
Setting(name="avatar_size_s", value="50", type=SettingsType.AVATAR),
Setting(name="home_view_method", value="icon", type=SettingsType.VIEW),
Setting(name="share_view_method", value="list", type=SettingsType.VIEW),
Setting(name="cron_garbage_collect", value="@hourly", type=SettingsType.CRON),
Setting(name="authn_enabled", value="0", type=SettingsType.AUTHN),
Setting(name="captcha_height", value="60", type=SettingsType.CAPTCHA),
Setting(name="captcha_width", value="240", type=SettingsType.CAPTCHA),
Setting(name="captcha_mode", value="3", type=SettingsType.CAPTCHA),
Setting(name="captcha_ComplexOfNoiseText", value="0", type=SettingsType.CAPTCHA),
Setting(name="captcha_ComplexOfNoiseDot", value="0", type=SettingsType.CAPTCHA),
Setting(name="captcha_IsShowHollowLine", value="0", type=SettingsType.CAPTCHA),
Setting(name="captcha_IsShowNoiseDot", value="1", type=SettingsType.CAPTCHA),
Setting(name="captcha_IsShowNoiseText", value="0", type=SettingsType.CAPTCHA),
Setting(name="captcha_IsShowSlimeLine", value="1", type=SettingsType.CAPTCHA),
Setting(name="captcha_IsShowSineLine", value="0", type=SettingsType.CAPTCHA),
Setting(name="captcha_CaptchaLen", value="6", type=SettingsType.CAPTCHA),
Setting(name="captcha_type", value="default", type=SettingsType.CAPTCHA),
Setting(name="captcha_ReCaptchaKey", value="", type=SettingsType.CAPTCHA),
Setting(name="captcha_ReCaptchaSecret", value="", type=SettingsType.CAPTCHA),
Setting(name="captcha_CloudflareKey", value="", type=SettingsType.CAPTCHA),
Setting(name="captcha_CloudflareSecret", value="", type=SettingsType.CAPTCHA),
Setting(name="thumb_width", value="400", type=SettingsType.THUMB),
Setting(name="thumb_height", value="300", type=SettingsType.THUMB),
Setting(name="pwa_small_icon", value="/static/img/favicon.ico", type=SettingsType.PWA),
Setting(name="pwa_medium_icon", value="/static/img/logo192.png", type=SettingsType.PWA),
Setting(name="pwa_large_icon", value="/static/img/logo512.png", type=SettingsType.PWA),
Setting(name="pwa_display", value="standalone", type=SettingsType.PWA),
Setting(name="pwa_theme_color", value="#000000", type=SettingsType.PWA),
Setting(name="pwa_background_color", value="#ffffff", type=SettingsType.PWA),
]
async def init_default_settings() -> None:
from .setting import Setting
from .database_connection import DatabaseManager
log.info('初始化设置...')
async for session in DatabaseManager.get_session():
# 检查是否已经存在版本设置
ver = await Setting.get(
session,
(Setting.type == SettingsType.VERSION) & (Setting.name == f"db_version_{BackendVersion}")
)
if ver and ver.value == "installed":
return
# 批量添加默认设置
await Setting.add(session, default_settings)
async def init_default_group() -> None:
from .group import Group, GroupOptions
from .policy import Policy, GroupPolicyLink
from .setting import Setting
from .database_connection import DatabaseManager
log.info('初始化用户组...')
async for session in DatabaseManager.get_session():
# 获取默认存储策略
default_policy = await Policy.get(session, Policy.name == "本地存储")
default_policy_id = default_policy.id if default_policy else None
# 未找到初始管理组时,则创建
if not await Group.get(session, Group.name == "管理员"):
admin_group = Group(
name="管理员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
admin=True,
)
admin_group_id = admin_group.id # 在 save 前保存 UUID
await admin_group.save(session)
await GroupOptions(
group_id=admin_group_id,
archive_download=True,
archive_task=True,
share_download=True,
share_free=True,
aria2=True,
select_node=True,
advance_delete=True,
).save(session)
# 关联默认存储策略
if default_policy_id:
session.add(GroupPolicyLink(
group_id=admin_group_id,
policy_id=default_policy_id,
))
await session.commit()
# 未找到初始注册会员时,则创建
if not await Group.get(session, Group.name == "注册会员"):
member_group = Group(
name="注册会员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
)
member_group_id = member_group.id # 在 save 前保存 UUID
await member_group.save(session)
await GroupOptions(
group_id=member_group_id,
share_download=True,
).save(session)
# 关联默认存储策略
if default_policy_id:
session.add(GroupPolicyLink(
group_id=member_group_id,
policy_id=default_policy_id,
))
await session.commit()
# 更新 default_group 设置为注册会员组的 UUID
default_group_setting = await Setting.get(session, Setting.name == "default_group")
if default_group_setting:
default_group_setting.value = str(member_group_id)
await default_group_setting.save(session)
# 未找到初始游客组时,则创建
if not await Group.get(session, Group.name == "游客"):
guest_group = Group(
name="游客",
share_enabled=False,
web_dav_enabled=False,
)
guest_group_id = guest_group.id # 在 save 前保存 UUID
await guest_group.save(session)
await GroupOptions(
group_id=guest_group_id,
share_download=True,
).save(session)
# 游客组不关联存储策略(无法上传)
async def init_default_user() -> None:
from .user import User
from .group import Group
from .object import Object, ObjectType
from .policy import Policy
from .database_connection import DatabaseManager
log.info('初始化管理员用户...')
async for session in DatabaseManager.get_session():
# 检查管理员用户是否存在(通过 Setting 中的 default_admin_id 判断)
admin_id_setting = await Setting.get(
session,
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
)
admin_user = None
if admin_id_setting and admin_id_setting.value:
from uuid import UUID
admin_user = await User.get(session, User.id == UUID(admin_id_setting.value))
if not admin_user:
# 获取管理员组
admin_group = await Group.get(session, Group.name == "管理员")
if not admin_group:
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
# 获取默认存储策略
default_policy = await Policy.get(session, Policy.name == "本地存储")
if not default_policy:
raise RuntimeError("默认存储策略不存在,无法创建管理员用户")
default_policy_id = default_policy.id # 在后续 save 前保存 UUID
# 生成管理员密码
admin_password = Password.generate(8)
hashed_admin_password = Password.hash(admin_password)
admin_user = User(
email="admin@disknext.local",
nickname="admin",
group_id=admin_group.id,
password=hashed_admin_password,
)
admin_user_id = admin_user.id # 在 save 前保存 UUID
await admin_user.save(session)
# 记录默认管理员 ID 到 Setting
await Setting(
name="default_admin_id",
value=str(admin_user_id),
type=SettingsType.AUTH,
).save(session)
# 为管理员创建根目录
await Object(
name="/",
type=ObjectType.FOLDER,
owner_id=admin_user_id,
parent_id=None,
policy_id=default_policy_id,
).save(session)
log.warning('请注意,账号密码仅显示一次,请妥善保管')
log.info(f'初始管理员邮箱: admin@disknext.local')
log.info(f'初始管理员密码: {admin_password}')
async def init_default_policy() -> None:
from .policy import Policy, PolicyType
from .database_connection import DatabaseManager
from service.storage import LocalStorageService
log.info('初始化默认存储策略...')
async for session in DatabaseManager.get_session():
# 检查默认存储策略是否存在
default_policy = await Policy.get(session, Policy.name == "本地存储")
if not default_policy:
local_policy = Policy(
name="本地存储",
type=PolicyType.LOCAL,
server="./data",
is_private=True,
max_size=0,
auto_rename=True,
dir_name_rule="{date}/{randomkey16}",
file_name_rule="{randomkey16}_{originname}",
)
local_policy = await local_policy.save(session)
# 创建物理存储目录
storage_service = LocalStorageService(local_policy)
await storage_service.ensure_base_directory()
log.info('已创建默认本地存储策略,存储目录:./data')

543
sqlmodels/mixin/README.md Normal file
View File

@@ -0,0 +1,543 @@
# SQLModel Mixin Module
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
## Module Overview
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
- **JWT Authentication**: Token generation and validation
- **Pagination & Sorting**: Standardized table view parameters
- **Response DTOs**: Consistent id/timestamp fields for API responses
## Module Structure
```
sqlmodels/mixin/
├── __init__.py # Module exports
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
└── jwt/ # JWT authentication
├── __init__.py
├── key.py # JWTKey database model
├── payload.py # JWTPayloadBase
├── manager.py # JWTManager singleton
├── auth.py # JWTAuthMixin
├── exceptions.py # JWT-related exceptions
└── responses.py # TokenResponse DTO
```
## Dependency Hierarchy
The module has a strict import order to avoid circular dependencies:
1. **polymorphic.py** - Only depends on `SQLModelBase`
2. **table.py** - Depends on `polymorphic.py`
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
4. **info_response.py** - Only depends on `SQLModelBase`
## Core Components
### 1. TableBaseMixin
Base mixin for database table models with integer primary keys.
**Features:**
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
- Automatic timestamp management (`created_at`, `updated_at`)
- Async relationship loading support (via `AsyncAttrs`)
- Pagination and sorting via `TableViewRequest`
- Polymorphic subclass loading support
**Fields:**
- `id: int | None` - Integer primary key (auto-increment)
- `created_at: datetime` - Record creation timestamp
- `updated_at: datetime` - Record update timestamp (auto-updated)
**Usage:**
```python
from sqlmodels.mixin import TableBaseMixin
from sqlmodels.base import SQLModelBase
class User(SQLModelBase, TableBaseMixin, table=True):
name: str
email: str
"""User email"""
# CRUD operations
async def example(session: AsyncSession):
# Add
user = User(name="Alice", email="alice@example.com")
user = await user.save(session)
# Get
user = await User.get(session, User.id == 1)
# Update
update_data = UserUpdateRequest(name="Alice Smith")
user = await user.update(session, update_data)
# Delete
await User.delete(session, user)
# Count
count = await User.count(session, User.is_active == True)
```
**Important Notes:**
- `save()` and `update()` return refreshed instances - **always use the return value**:
```python
# ✅ Correct
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device
```
### 2. UUIDTableBaseMixin
Extends `TableBaseMixin` with UUID primary keys instead of integers.
**Differences from TableBaseMixin:**
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
- `get_exist_one()` accepts `UUID` instead of `int`
**Usage:**
```python
from sqlmodels.mixin import UUIDTableBaseMixin
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
name: str
description: str | None = None
"""Character description"""
```
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
### 3. TableViewRequest
Standardized pagination and sorting parameters for LIST endpoints.
**Fields:**
- `offset: int | None` - Skip first N records (default: 0)
- `limit: int | None` - Return max N records (default: 50, max: 100)
- `desc: bool | None` - Sort descending (default: True)
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
**Usage with TableBaseMixin.get():**
```python
from dependencies import TableViewRequestDep
@router.get("/list")
async def list_characters(
session: SessionDep,
table_view: TableViewRequestDep
) -> list[Character]:
"""List characters with pagination and sorting"""
return await Character.get(
session,
fetch_mode="all",
table_view=table_view # Automatically handles pagination and sorting
)
```
**Manual usage:**
```python
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
```
**Backward Compatibility:**
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
### 4. PolymorphicBaseMixin
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
**Automatic Configuration:**
- Defines `_polymorphic_name: str` field (indexed)
- Sets `polymorphic_on='_polymorphic_name'`
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
**Methods:**
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
**Usage:**
```python
from abc import ABC, abstractmethod
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
"""Abstract base class for all tools"""
name: str
description: str
"""Tool description"""
@abstractmethod
async def execute(self, params: dict) -> dict:
"""Execute the tool"""
pass
```
**Why Single Underscore Prefix?**
- SQLAlchemy maps single-underscore fields to database columns
- Pydantic treats them as private (excluded from serialization)
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
### 5. create_subclass_id_mixin()
Factory function to create ID mixins for subclasses in joined table inheritance.
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
**Signature:**
```python
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
"""
Args:
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
Returns:
A mixin class containing id field (foreign key + primary key)
"""
```
**Usage:**
```python
from sqlmodels.mixin import create_subclass_id_mixin
# Create mixin for ASR subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
"""FunASR implementation"""
pass
```
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
### 6. AutoPolymorphicIdentityMixin
Automatically generates `polymorphic_identity` based on class name.
**Naming Convention:**
- Format: `{parent_identity}.{classname_lowercase}`
- If no parent identity exists, uses `{classname_lowercase}`
**Usage:**
```python
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
"""Base class for function-type tools"""
pass
# polymorphic_identity = 'function'
class GetWeatherFunction(Function, table=True):
"""Weather query function"""
pass
# polymorphic_identity = 'function.getweatherfunction'
```
**Manual Override:**
```python
class CustomTool(
Tool,
AutoPolymorphicIdentityMixin,
polymorphic_identity='custom_name', # Override auto-generated name
table=True
):
pass
```
### 7. JWTAuthMixin
Provides JWT token generation and validation for entity classes (User, Client).
**Methods:**
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
**Requirements:**
Subclasses must define:
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
**Usage:**
```python
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
JWTPayload = UserJWTPayload # Define payload model
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
email: str
is_admin: bool = False
is_active: bool = True
"""User active status"""
# Generate token
async def login(session: AsyncSession, user: User) -> str:
token = await user.issue_jwt(session)
return token
# Validate token
async def verify(session: AsyncSession, token: str) -> User:
user = await User.from_jwt(session, token)
return user
```
### 8. Response DTO Mixins
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
**Available Mixins:**
- `IntIdInfoMixin` - Integer ID field
- `UUIDIdInfoMixin` - UUID ID field
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
**Usage:**
```python
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
"""Character response DTO with id and timestamps"""
pass # Inherits id, created_at, updated_at from mixin
```
## Complete Joined Table Inheritance Example
Here's a complete example demonstrating polymorphic inheritance:
```python
from abc import ABC, abstractmethod
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import (
UUIDTableBaseMixin,
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
)
# 1. Define Base class (fields only, no table)
class ASRBase(SQLModelBase):
name: str
"""Configuration name"""
base_url: str
"""Service URL"""
# 2. Define abstract parent class (with table)
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
"""Abstract base class for ASR configurations"""
# PolymorphicBaseMixin automatically provides:
# - _polymorphic_name field
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True (when ABC with abstract methods)
@abstractmethod
async def transcribe(self, pcm_data: bytes) -> str:
"""Transcribe audio to text"""
pass
# 3. Create ID Mixin for second-level subclasses
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. Create second-level abstract class (if needed)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
"""FunASR abstract base (may have multiple implementations)"""
pass
# polymorphic_identity = 'funasr'
# 5. Create concrete implementation classes
class FunASRLocal(FunASR, table=True):
"""FunASR local deployment"""
# polymorphic_identity = 'funasr.funasrlocal'
async def transcribe(self, pcm_data: bytes) -> str:
# Implementation...
return "transcribed text"
# 6. Get all concrete subclasses (for selectin_polymorphic)
concrete_asrs = ASR.get_concrete_subclasses()
# Returns: [FunASRLocal, ...]
```
## Import Guidelines
**Standard Import:**
```python
from sqlmodels.mixin import (
TableBaseMixin,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
TableViewRequest,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin,
JWTAuthMixin,
UUIDIdDatetimeInfoMixin,
now,
now_date,
)
```
**Backward Compatibility:**
Some exports are also available from `sqlmodels.base` for backward compatibility:
```python
# Legacy import path (still works)
from sqlmodels.base import UUIDTableBase, TableViewRequest
# Recommended new import path
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
```
## Best Practices
### 1. Mixin Order Matters
**Correct Order:**
```python
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
pass
```
**Wrong Order:**
```python
# ❌ ID Mixin not first - won't override parent's id field
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
pass
```
### 2. Always Use Return Values from save() and update()
```python
# ✅ Correct - use returned instance
device = await device.save(session)
return device
# ❌ Wrong - device is expired after commit
await device.save(session)
return device # AttributeError when accessing fields
```
### 3. Prefer table_view Over Manual Pagination
```python
# ✅ Recommended - consistent across all endpoints
characters = await Character.get(
session,
fetch_mode="all",
table_view=table_view
)
# ⚠️ Works but not recommended - manual parameter management
characters = await Character.get(
session,
fetch_mode="all",
offset=0,
limit=20,
order_by=[desc(Character.created_at)]
)
```
### 4. Polymorphic Loading for Many Subclasses
```python
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
)
# For fewer subclasses, specify the list explicitly
tool_set = await ToolSet.get(
session,
ToolSet.id == tool_set_id,
load=ToolSet.tools,
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
)
```
### 5. Response DTOs Should Inherit Base Classes
```python
# ✅ Correct - inherits from CharacterBase
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
pass
# ❌ Wrong - doesn't inherit from CharacterBase
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
name: str # Duplicated field definition
description: str | None = None
```
**Reason:** Inheriting from Base classes ensures:
- Type checking via `isinstance(obj, XxxBase)`
- Consistency across related DTOs
- Future field additions automatically propagate
### 6. Use Specific Types, Not Containers
```python
# ✅ Correct - specific DTO for config updates
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
weather_api_key: str | None = None
default_location: str | None = None
"""Default location"""
# ❌ Wrong - lose type safety
class ToolUpdateRequest(SQLModelBase):
config: dict[str, Any] # No field validation
```
## Type Variables
```python
from sqlmodels.mixin import T, M
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
M = TypeVar("M", bound="SQLModel") # For update() method
```
## Utility Functions
```python
from sqlmodels.mixin import now, now_date
# Lambda functions for default factories
now = lambda: datetime.now()
now_date = lambda: datetime.now().date()
```
## Related Modules
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
- **sqlmodels.user** - User model with JWT authentication
- **sqlmodels.client** - Client model with JWT authentication
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
## Additional Resources
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
- `CLAUDE.md` - Project coding standards and design philosophy
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)

View File

@@ -0,0 +1,62 @@
"""
SQLModel Mixin模块
提供各种Mixin类供SQLModel实体使用。
包含:
- polymorphic: 联表继承工具create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin
- optimistic_lock: 乐观锁OptimisticLockMixin, OptimisticLockError
- table: 表基类TableBaseMixin, UUIDTableBaseMixin
- table: 查询参数类TimeFilterRequest, PaginationRequest, TableViewRequest
- relation_preload: 关系预加载RelationPreloadMixin, requires_relations
- jwt/: JWT认证相关JWTAuthMixin, JWTManager, JWTKey等- 需要时直接从 .jwt 导入
- info_response: InfoResponse DTO的id/时间戳Mixin
导入顺序很重要,避免循环导入:
1. polymorphic只依赖 SQLModelBase
2. optimistic_lock只依赖 SQLAlchemy
3. table依赖 polymorphic 和 optimistic_lock
4. relation_preload只依赖 SQLModelBase
注意jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
"""
# polymorphic 必须先导入
from .polymorphic import (
AutoPolymorphicIdentityMixin,
PolymorphicBaseMixin,
create_subclass_id_mixin,
register_sti_column_properties_for_all_subclasses,
register_sti_columns_for_all_subclasses,
)
# optimistic_lock 只依赖 SQLAlchemy必须在 table 之前
from .optimistic_lock import (
OptimisticLockError,
OptimisticLockMixin,
)
# table 依赖 polymorphic 和 optimistic_lock
from .table import (
ListResponse,
PaginationRequest,
T,
TableBaseMixin,
TableViewRequest,
TimeFilterRequest,
UUIDTableBaseMixin,
now,
now_date,
)
# relation_preload 只依赖 SQLModelBase
from .relation_preload import (
RelationPreloadMixin,
requires_relations,
)
# jwt 不在此处导入避免循环jwt/manager.py → ServerConfig → mixin → jwt
# 需要时直接从 sqlmodels.mixin.jwt 导入
from .info_response import (
DatetimeInfoMixin,
IntIdDatetimeInfoMixin,
IntIdInfoMixin,
UUIDIdDatetimeInfoMixin,
UUIDIdInfoMixin,
)

View File

@@ -0,0 +1,46 @@
"""
InfoResponse DTO Mixin模块
提供用于InfoResponse类型DTO的Mixin统一定义id/created_at/updated_at字段。
设计说明:
- 这些Mixin用于**响应DTO**,不是数据库表
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
- TableBase中的id=None和default_factory=now是正确的入库前为None数据库生成
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
"""
from datetime import datetime
from uuid import UUID
from sqlmodels.base import SQLModelBase
class IntIdInfoMixin(SQLModelBase):
"""整数ID响应mixin - 用于InfoResponse DTO"""
id: int
"""记录ID"""
class UUIDIdInfoMixin(SQLModelBase):
"""UUID ID响应mixin - 用于InfoResponse DTO"""
id: UUID
"""记录ID"""
class DatetimeInfoMixin(SQLModelBase):
"""时间戳响应mixin - 用于InfoResponse DTO"""
created_at: datetime
"""创建时间"""
updated_at: datetime
"""更新时间"""
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
"""整数ID + 时间戳响应mixin"""
pass
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
"""UUID ID + 时间戳响应mixin"""
pass

View File

@@ -0,0 +1,90 @@
"""
乐观锁 Mixin
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
乐观锁适用场景:
- 涉及"状态转换"的表(如:待支付 -> 已支付)
- 涉及"数值变动"的表(如:余额、库存)
不适用场景:
- 日志表、纯插入表、低价值统计表
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
使用示例:
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
status: OrderStatusEnum
amount: Decimal
# save/update 时自动检查版本号
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
try:
order = await order.save(session)
except OptimisticLockError as e:
# 处理冲突:重新查询并重试,或报错给用户
l.warning(f"乐观锁冲突: {e}")
"""
from typing import ClassVar
from sqlalchemy.orm.exc import StaleDataError
class OptimisticLockError(Exception):
"""
乐观锁冲突异常
当 save/update 操作检测到版本号不匹配时抛出。
这意味着在读取和写入之间,其他事务已经修改了该记录。
Attributes:
model_class: 发生冲突的模型类名
record_id: 记录 ID如果可用
expected_version: 期望的版本号(如果可用)
original_error: 原始的 StaleDataError
"""
def __init__(
self,
message: str,
model_class: str | None = None,
record_id: str | None = None,
expected_version: int | None = None,
original_error: StaleDataError | None = None,
):
super().__init__(message)
self.model_class = model_class
self.record_id = record_id
self.expected_version = expected_version
self.original_error = original_error
class OptimisticLockMixin:
"""
乐观锁 Mixin
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
session.commit() 会抛出 StaleDataError被 save/update 方法捕获并转换为 OptimisticLockError。
原理:
1. 每条记录有一个 version 字段,初始值为 0
2. 每次 UPDATE 时SQLAlchemy 生成的 SQL 类似:
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
3. 如果 WHERE 条件不匹配version 已被其他事务修改),
UPDATE 影响 0 行SQLAlchemy 抛出 StaleDataError
继承顺序:
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
...
配套重试:
如果加了乐观锁,业务层需要处理 OptimisticLockError
- 报错给用户:"数据已被修改,请刷新后重试"
- 自动重试:重新查询最新数据并再次尝试
"""
_has_optimistic_lock: ClassVar[bool] = True
"""标记此类启用了乐观锁"""
version: int = 0
"""乐观锁版本号,每次更新自动递增"""

View File

@@ -0,0 +1,710 @@
"""
联表继承Joined Table Inheritance的通用工具
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
Usage Example:
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import (
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin
)
# 1. 定义Base类只有字段无表
class ASRBase(SQLModelBase):
name: str
\"\"\"配置名称\"\"\"
base_url: str
\"\"\"服务地址\"\"\"
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
class ASR(
ASRBase,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC
):
\"\"\"ASR配置的抽象基类\"\"\"
# PolymorphicBaseMixin 自动提供:
# - _polymorphic_name 字段
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True当有抽象方法时
# 3. 为第二层子类创建ID Mixin
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. 创建第二层抽象类(如果需要)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
\"\"\"FunASR的抽象基类可能有多个实现\"\"\"
pass
# 5. 创建具体实现类
class FunASRLocal(FunASR, table=True):
\"\"\"FunASR本地部署版本\"\"\"
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
pass
# 6. 获取所有具体子类(用于 selectin_polymorphic
concrete_asrs = ASR.get_concrete_subclasses()
# 返回 [FunASRLocal, ...]
"""
import uuid
from abc import ABC
from uuid import UUID
from loguru import logger as l
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from sqlalchemy import Column, String, inspect
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlmodel import Field
from sqlmodel.main import get_column_from_field
from sqlmodels.base.sqlmodel_base import SQLModelBase
# 用于延迟注册 STI 子类列的队列
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
_sti_subclasses_to_register: list[type] = []
def register_sti_columns_for_all_subclasses() -> None:
"""
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
此函数应在 configure_mappers() 之前调用。
将 STI 子类的字段添加到父表的 metadata 中。
同时修复被 Column 对象污染的 model_fields。
"""
for cls in _sti_subclasses_to_register:
try:
cls._register_sti_columns()
except Exception as e:
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
# 修复被 Column 对象污染的 model_fields
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
try:
_fix_polluted_model_fields(cls)
except Exception as e:
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
def register_sti_column_properties_for_all_subclasses() -> None:
"""
为所有已注册的 STI 子类添加列属性到 mapper第二阶段
此函数应在 configure_mappers() 之后调用。
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
"""
for cls in _sti_subclasses_to_register:
try:
cls._register_sti_column_properties()
except Exception as e:
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
# 清空队列
_sti_subclasses_to_register.clear()
def _fix_polluted_model_fields(cls: type) -> None:
"""
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
当 SQLModel 类继承有表的父类时SQLAlchemy 会在类上创建 InstrumentedAttribute
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
时错误地使用这些 SQLAlchemy 对象作为默认值。
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
:param cls: 要修复的类
"""
if not hasattr(cls, 'model_fields'):
return
def find_original_field_info(field_name: str) -> FieldInfo | None:
"""从 MRO 中查找字段的原始定义(未被污染的)"""
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, 'model_fields') and field_name in base.model_fields:
field_info = base.model_fields[field_name]
# 跳过被 InstrumentedAttribute 或 Column 污染的
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
return field_info
return None
for field_name, current_field in cls.model_fields.items():
# 检查是否被污染default 是 InstrumentedAttribute 或 Column
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
# 被继承到有表的子类时SQLModel 会创建一个 Column 对象替换原始默认值
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
continue # 未被污染,跳过
# 从父类查找原始定义
original = find_original_field_info(field_name)
if original is None:
continue # 找不到原始定义,跳过
# 根据原始定义的 default/default_factory 来修复
if original.default_factory:
# 有 default_factory如 uuid.uuid4, now
new_field = FieldInfo(
default_factory=original.default_factory,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
elif original.default is not PydanticUndefined:
# 有明确的 default 值(如 None, 0, True且不是 PydanticUndefined
# PydanticUndefined 表示字段没有默认值(必填)
new_field = FieldInfo(
default=original.default,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
else:
continue # 既没有 default_factory 也没有有效的 default跳过
# 复制 SQLModel 特有的属性
if hasattr(current_field, 'foreign_key'):
new_field.foreign_key = current_field.foreign_key
if hasattr(current_field, 'primary_key'):
new_field.primary_key = current_field.primary_key
cls.model_fields[field_name] = new_field
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
"""
动态创建SubclassIdMixin类
在联表继承中,子类需要一个外键指向父表的主键。
此函数生成一个Mixin类提供这个外键字段并自动生成UUID。
Args:
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function'
Returns:
一个Mixin类包含id字段外键 + 主键 + default_factory=uuid.uuid4
Example:
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
... pass
Note:
- 生成的Mixin应该放在继承列表的第一位确保通过MRO覆盖UUIDTableBaseMixin的id
- 生成的类名为 {ParentTableName}SubclassIdMixinPascalCase
- 本项目所有联表继承均使用UUID主键UUIDTableBaseMixin
"""
if not parent_table_name:
raise ValueError("parent_table_name 不能为空")
# 转换为PascalCase作为类名
class_name_parts = parent_table_name.split('_')
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
# 使用闭包捕获parent_table_name
_parent_table_name = parent_table_name
# 创建带有__init_subclass__的mixin类用于在子类定义后修复model_fields
class SubclassIdMixin(SQLModelBase):
# 定义id字段
id: UUID = Field(
default_factory=uuid.uuid4,
foreign_key=f'{_parent_table_name}.id',
primary_key=True,
)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复联表继承中子类字段的 default_factory 丢失问题。
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
"""
super().__pydantic_init_subclass__(**kwargs)
_fix_polluted_model_fields(cls)
# 设置类名和文档
SubclassIdMixin.__name__ = class_name
SubclassIdMixin.__qualname__ = class_name
SubclassIdMixin.__doc__ = f"""
{parent_table_name}子类的ID Mixin
用于{parent_table_name}的子类,提供外键指向父表。
通过MRO确保此id字段覆盖继承的id字段。
"""
return SubclassIdMixin
class AutoPolymorphicIdentityMixin:
"""
自动生成polymorphic_identity的Mixin并支持STI子类列注册
使用此Mixin的类会自动根据类名生成polymorphic_identity。
格式:{parent_polymorphic_identity}.{classname_lowercase}
如果没有父类的polymorphic_identity则直接使用类名小写。
**重要:数据库迁移注意事项**
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
例如,对于以下继承链::
LLM (polymorphic_on='_polymorphic_name')
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
迁移脚本中设置 _polymorphic_name 时::
# ❌ 错误:缺少父类前缀
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
# ✅ 正确:包含完整的继承链前缀
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
**STI单表继承支持**
当子类与父类共用同一张表STI模式此Mixin会自动将子类的新字段
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
注册到父表的问题。
Example (JTI):
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
... __polymorphic_name: str
...
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
... pass
... # polymorphic_identity 会自动设置为 'function'
...
>>> class CodeInterpreterFunction(Function, table=True):
... pass
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
Example (STI):
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
... user_id: UUID
...
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
... # polymorphic_identity 会自动设置为 'pendingfile'
Note:
- 如果手动在__mapper_args__中指定了polymorphic_identity会被保留
- 此Mixin应该在继承列表中靠后的位置在表基类之前
- STI模式下新字段会在类定义时自动添加到父表的metadata中
"""
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
"""
子类化钩子自动生成polymorphic_identity并处理STI列注册
Args:
polymorphic_identity: 如果手动指定,则使用指定的值
**kwargs: 其他SQLModel参数如table=True, polymorphic_abstract=True
"""
super().__init_subclass__(**kwargs)
# 如果手动指定了polymorphic_identity使用指定的值
if polymorphic_identity is not None:
identity = polymorphic_identity
else:
# 自动生成polymorphic_identity
class_name = cls.__name__.lower()
# 尝试从父类获取polymorphic_identity作为前缀
parent_identity = None
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
parent_identity = base.__mapper_args__.get('polymorphic_identity')
if parent_identity:
break
# 构建identity
if parent_identity:
identity = f'{parent_identity}.{class_name}'
else:
identity = class_name
# 设置到__mapper_args__
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 只在尚未设置polymorphic_identity时设置
if 'polymorphic_identity' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_identity'] = identity
# 注册 STI 子类列的延迟执行
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
_sti_subclasses_to_register.append(cls)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复 STI 继承中子类字段被 Column 对象污染的问题。
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
"""
super().__pydantic_init_subclass__(**kwargs)
_fix_polluted_model_fields(cls)
@classmethod
def _register_sti_columns(cls) -> None:
"""
将STI子类的新字段注册到父表的列定义中
检测当前类是否是STI子类与父类共用同一张表
如果是则将子类定义的新字段添加到父表的metadata中。
JTI联表继承类会被自动跳过因为它们有自己独立的表。
"""
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
parent_table = None
parent_fields: set[str] = set()
for base in cls.__mro__[1:]:
if hasattr(base, '__table__') and base.__table__ is not None:
parent_table = base.__table__
# 收集父类的所有字段名
if hasattr(base, 'model_fields'):
parent_fields.update(base.model_fields.keys())
break
if parent_table is None:
return # 没有找到父表,可能是根类
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
# JTI 类有自己独立的表,不需要将列注册到父表
if hasattr(cls, '__table__') and cls.__table__ is not None:
if cls.__table__.name != parent_table.name:
return # JTI跳过 STI 列注册
# 获取当前类的新字段(不在父类中的字段)
if not hasattr(cls, 'model_fields'):
return
existing_columns = {col.name for col in parent_table.columns}
for field_name, field_info in cls.model_fields.items():
# 跳过从父类继承的字段
if field_name in parent_fields:
continue
# 跳过私有字段和ClassVar
if field_name.startswith('_'):
continue
# 跳过已存在的列
if field_name in existing_columns:
continue
# 使用 SQLModel 的内置 API 创建列
try:
column = get_column_from_field(field_info)
column.name = field_name
column.key = field_name
# STI子类字段在数据库层面必须可空因为其他子类的行不会有这些字段的值
# Pydantic层面的约束仍然有效创建特定子类时会验证必填字段
column.nullable = True
# 将列添加到父表
parent_table.append_column(column)
except Exception as e:
l.warning(f"{cls.__name__} 创建列 {field_name} 失败: {e}")
@classmethod
def _register_sti_column_properties(cls) -> None:
"""
将 STI 子类的列作为 ColumnProperty 添加到 mapper
此方法在 configure_mappers() 之后调用,将已添加到表中的列
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
这确保了查询父类时SELECT 语句包含所有 STI 子类的列,
避免在响应序列化时触发懒加载MissingGreenlet 错误)。
JTI联表继承类会被自动跳过因为它们有自己独立的表。
"""
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
parent_table = None
parent_class = None
for base in cls.__mro__[1:]:
if hasattr(base, '__table__') and base.__table__ is not None:
parent_table = base.__table__
parent_class = base
break
if parent_table is None:
return # 没有找到父表,可能是根类
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
if hasattr(cls, '__table__') and cls.__table__ is not None:
if cls.__table__.name != parent_table.name:
return # JTI跳过 STI 列属性注册
# 获取子类和父类的 mapper
child_mapper = inspect(cls).mapper
parent_mapper = inspect(parent_class).mapper
local_table = child_mapper.local_table
# 查找父类的所有字段名
parent_fields: set[str] = set()
if hasattr(parent_class, 'model_fields'):
parent_fields.update(parent_class.model_fields.keys())
if not hasattr(cls, 'model_fields'):
return
# 获取两个 mapper 已有的列属性
child_existing_props = {p.key for p in child_mapper.column_attrs}
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
for field_name in cls.model_fields:
# 跳过从父类继承的字段
if field_name in parent_fields:
continue
# 跳过私有字段
if field_name.startswith('_'):
continue
# 检查表中是否有这个列
if field_name not in local_table.columns:
continue
column = local_table.columns[field_name]
# 添加到子类的 mapper如果尚不存在
if field_name not in child_existing_props:
try:
prop = ColumnProperty(column)
child_mapper.add_property(field_name, prop)
except Exception as e:
l.warning(f"{cls.__name__} 添加列属性 {field_name} 失败: {e}")
# 同时添加到父类的 mapper确保查询父类时 SELECT 包含所有 STI 子类的列)
if field_name not in parent_existing_props:
try:
prop = ColumnProperty(column)
parent_mapper.add_property(field_name, prop)
except Exception as e:
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
class PolymorphicBaseMixin:
"""
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
此 Mixin 自动设置以下内容:
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
使用场景:
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
用法示例:
```python
from abc import ABC
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
# 定义基类
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
__tablename__ = 'mytool'
# 不需要手动定义 _polymorphic_name
# 不需要手动设置 polymorphic_on
# 不需要手动设置 polymorphic_abstract
# 定义子类
class SpecificTool(MyTool):
__tablename__ = 'specifictool'
# 会自动继承 polymorphic 配置
```
自动行为:
1. 定义 `_polymorphic_name: str` 字段(带索引)
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
3. 自动检测抽象类:
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
- 否则设置为 False
手动覆盖:
可以在类定义时手动指定参数来覆盖自动行为:
```python
class MyTool(
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC,
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
polymorphic_abstract=False # 强制不设为抽象类
):
pass
```
注意事项:
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
- 适用于联表继承joined table inheritance场景
- 子类会自动继承 _polymorphic_name 字段定义
- 使用单下划线前缀是因为:
* SQLAlchemy 会映射单下划线字段为数据库列
* Pydantic 将其视为私有属性,不参与序列化
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
"""
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
#
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
#
# 为什么这样做:
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
# 4. 字段不出现在 Pydantic 序列化中model_dump() 和 JSON schema
# 5. 内部代码仍然可以正常访问和修改此字段
#
# 详细说明请参考sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
"""
多态鉴别器字段,用于标识具体的子类类型
注意:此字段使用单下划线前缀,表示内部使用。
- ✅ 存储到数据库
- ✅ 不出现在 API 序列化中
- ✅ 防止外部直接修改
"""
def __init_subclass__(
cls,
polymorphic_on: str | None = None,
polymorphic_abstract: bool | None = None,
**kwargs
):
"""
在子类定义时自动配置 polymorphic 设置
Args:
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'
设置为其他值可以使用不同的字段作为多态鉴别器。
polymorphic_abstract: 是否为抽象类。
- None: 自动检测(默认)
- True: 强制设为抽象类
- False: 强制设为非抽象类
**kwargs: 传递给父类的其他参数
"""
super().__init_subclass__(**kwargs)
# 初始化 __mapper_args__如果还没有
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 设置 polymorphic_on默认为 _polymorphic_name
if 'polymorphic_on' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
# 自动检测或设置 polymorphic_abstract
if 'polymorphic_abstract' not in cls.__mapper_args__:
if polymorphic_abstract is None:
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
has_abc = ABC in cls.__mro__
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
polymorphic_abstract = has_abc and has_abstract_methods
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
@classmethod
def _is_joined_table_inheritance(cls) -> bool:
"""
检测当前类是否使用联表继承Joined Table Inheritance
通过检查子类是否有独立的表来判断:
- JTI: 子类有独立的 local_table与父类不同
- STI: 子类与父类共用同一个 local_table
:return: True 表示 JTIFalse 表示 STI 或无子类
"""
mapper = inspect(cls)
base_table_name = mapper.local_table.name
# 检查所有直接子类
for subclass in cls.__subclasses__():
sub_mapper = inspect(subclass)
# 如果任何子类有不同的表名,说明是 JTI
if sub_mapper.local_table.name != base_table_name:
return True
return False
@classmethod
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
"""
递归获取当前类的所有具体(非抽象)子类
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
可在任意多态基类上调用,返回该类的所有非抽象子类。
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
"""
result: list[type[PolymorphicBaseMixin]] = []
for subclass in cls.__subclasses__():
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
mapper = inspect(subclass)
if not mapper.polymorphic_abstract:
result.append(subclass)
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
if hasattr(subclass, 'get_concrete_subclasses'):
result.extend(subclass.get_concrete_subclasses())
return result
@classmethod
def get_polymorphic_discriminator(cls) -> str:
"""
获取多态鉴别字段名
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
:return: 多态鉴别字段名(如 '_polymorphic_name'
:raises ValueError: 如果类未配置 polymorphic_on
"""
polymorphic_on = inspect(cls).polymorphic_on
if polymorphic_on is None:
raise ValueError(
f"{cls.__name__} 未配置 polymorphic_on"
f"请确保正确继承 PolymorphicBaseMixin"
)
return polymorphic_on.key
@classmethod
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
"""
获取 polymorphic_identity 到具体子类的映射
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
:return: identity 到子类的映射字典
"""
result: dict[str, type[PolymorphicBaseMixin]] = {}
for subclass in cls.get_concrete_subclasses():
identity = inspect(subclass).polymorphic_identity
if identity:
result[identity] = subclass
return result

View File

@@ -0,0 +1,470 @@
"""
关系预加载 Mixin
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
设计原则:
- 按需加载:只加载被调用方法需要的关系
- 增量加载:已加载的关系不重复加载
- 查询最优:相同关系只查询一次,不同关系增量查询
- 零侵入:调用方无需任何改动
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
使用方式:
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
kling_video_generator: KlingO1Generator = Relationship(...)
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
async def cost(self, params, context, session) -> ToolCost:
# 自动加载,可以安全访问
price = self.kling_video_generator.kling_o1.pro_price_per_second
...
# 调用方 - 无需任何改动
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
await tool._call(...) # 关系相同则跳过,否则增量加载
支持 AsyncGenerator
@requires_relations('twitter_api')
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
yield ToolResponse(...) # 装饰器正确处理 async generator
"""
import inspect as python_inspect
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Any
from loguru import logger as l
from sqlalchemy import inspect as sa_inspect
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo
P = ParamSpec('P')
R = TypeVar('R')
def _extract_session(
func: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> AsyncSession | None:
"""
从方法参数中提取 AsyncSession
按以下顺序查找:
1. kwargs 中名为 'session' 的参数
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
3. kwargs 中类型为 AsyncSession 的参数
"""
# 1. 优先从 kwargs 查找
if 'session' in kwargs:
return kwargs['session']
# 2. 从函数签名定位位置参数
try:
sig = python_inspect.signature(func)
param_names = list(sig.parameters.keys())
if 'session' in param_names:
# 计算位置(减去 self
idx = param_names.index('session') - 1
if 0 <= idx < len(args):
return args[idx]
except (ValueError, TypeError):
pass
# 3. 遍历 kwargs 找 AsyncSession 类型
for value in kwargs.values():
if isinstance(value, AsyncSession):
return value
return None
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
"""
检查对象的关系是否已加载(独立函数版本)
Args:
obj: 要检查的对象
rel_name: 关系属性名
Returns:
True 如果关系已加载False 如果未加载或已过期
"""
try:
state = sa_inspect(obj)
return rel_name not in state.unloaded
except Exception:
return False
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
"""
在类中查找指向目标类的关系属性名
Args:
from_class: 源类
to_class: 目标类
Returns:
关系属性名,如果找不到则返回 None
Example:
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
# 返回 'kling_video_generator'
"""
for attr_name in dir(from_class):
try:
attr = getattr(from_class, attr_name, None)
if attr is None:
continue
# 检查是否是 SQLAlchemy InstrumentedAttribute关系属性
# parent.class_ 是关系所在的类property.mapper.class_ 是关系指向的目标类
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
target_class = attr.property.mapper.class_
if target_class == to_class:
return attr_name
except AttributeError:
continue
return None
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
装饰器:声明方法需要的关系,自动按需增量加载
参数格式:
- 字符串:本类属性名,如 'kling_video_generator'
- RelationshipInfo外部类属性如 KlingO1Generator.kling_o1
行为:
- 方法调用时自动检查关系是否已加载
- 未加载的关系会被增量加载(单次查询)
- 已加载的关系直接跳过
支持:
- 普通 async 方法:`async def cost(...) -> ToolCost`
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
Example:
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
async def cost(self, params, context, session) -> ToolCost:
# self.kling_video_generator.kling_o1 已自动加载
...
@requires_relations('twitter_api')
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
yield ToolResponse(...) # AsyncGenerator 正确处理
验证:
- 字符串格式的关系名在类创建时__init_subclass__验证
- 拼写错误会在导入时抛出 AttributeError
"""
def decorator(func: Callable[P, R]) -> Callable[P, R]:
# 检测是否是 async generator 函数
is_async_gen = python_inspect.isasyncgenfunction(func)
if is_async_gen:
# AsyncGenerator 需要特殊处理wrapper 也必须是 async generator
@wraps(func)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
session = _extract_session(func, args, kwargs)
if session is not None:
await self._ensure_relations_loaded(session, relations)
# 委托给原始 async generator逐个 yield 值
async for item in func(self, *args, **kwargs):
yield item # type: ignore
else:
# 普通 async 函数await 并返回结果
@wraps(func)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
session = _extract_session(func, args, kwargs)
if session is not None:
await self._ensure_relations_loaded(session, relations)
return await func(self, *args, **kwargs)
# 保存关系声明供验证和内省使用
wrapper._required_relations = relations # type: ignore
return wrapper
return decorator
class RelationPreloadMixin:
"""
关系预加载 Mixin
提供按需增量加载能力,确保 SQL 查询数理论最优。
特性:
- 按需加载:只加载被调用方法需要的关系
- 增量加载:已加载的关系不重复加载
- 原地更新:直接修改 self无需替换实例
- 导入时验证:字符串关系名在类创建时验证
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
"""
def __init_subclass__(cls, **kwargs) -> None:
"""类创建时验证所有 @requires_relations 声明"""
super().__init_subclass__(**kwargs)
# 收集类及其父类的所有注解(包含普通字段)
all_annotations: set[str] = set()
for klass in cls.__mro__:
if hasattr(klass, '__annotations__'):
all_annotations.update(klass.__annotations__.keys())
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__
sqlmodel_relationships: set[str] = set()
for klass in cls.__mro__:
if hasattr(klass, '__sqlmodel_relationships__'):
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
# 合并所有可用的属性名
all_available_names = all_annotations | sqlmodel_relationships
for method_name in dir(cls):
if method_name.startswith('__'):
continue
try:
method = getattr(cls, method_name, None)
except AttributeError:
continue
if method is None or not hasattr(method, '_required_relations'):
continue
# 验证字符串格式的关系名
for spec in method._required_relations:
if isinstance(spec, str):
# 检查注解、Relationship 或已有属性
if spec not in all_available_names and not hasattr(cls, spec):
raise AttributeError(
f"{cls.__name__}.{method_name} 声明了关系 '{spec}'"
f"{cls.__name__} 没有此属性"
)
def _is_relation_loaded(self, rel_name: str) -> bool:
"""
检查关系是否真正已加载(基于 SQLAlchemy inspect
使用 SQLAlchemy 的 inspect 检测真实加载状态,
自动处理 commit 导致的 expire 问题。
Args:
rel_name: 关系属性名
Returns:
True 如果关系已加载False 如果未加载或已过期
"""
try:
state = sa_inspect(self)
# unloaded 包含未加载的关系属性名
return rel_name not in state.unloaded
except Exception:
# 对象可能未被 SQLAlchemy 管理
return False
async def _ensure_relations_loaded(
self,
session: AsyncSession,
relations: tuple[str | RelationshipInfo, ...],
) -> None:
"""
确保指定关系已加载,只加载未加载的部分
基于 SQLAlchemy inspect 检测真实状态,自动处理:
- 首次访问的关系
- commit 后 expire 的关系
- 嵌套关系(如 KlingO1Generator.kling_o1
Args:
session: 数据库会话
relations: 需要的关系规格
"""
# 找出真正未加载的关系(基于 SQLAlchemy inspect
to_load: list[str | RelationshipInfo] = []
# 区分直接关系和嵌套关系的 key
direct_keys: set[str] = set() # 本类的直接关系属性名
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
for rel in relations:
if isinstance(rel, str):
# 直接关系:检查本类的关系是否已加载
if not self._is_relation_loaded(rel):
to_load.append(rel)
direct_keys.add(rel)
else:
# 嵌套关系InstrumentedAttribute如 KlingO1Generator.kling_o1
# 1. 查找指向父类的关系属性
parent_class = rel.parent.class_
parent_attr = _find_relation_to_class(self.__class__, parent_class)
if parent_attr is None:
# 找不到路径,可能是配置错误,但仍尝试加载
l.warning(
f"无法找到从 {self.__class__.__name__}{parent_class.__name__} 的关系路径,"
f"无法检查 {rel.key} 是否已加载"
)
to_load.append(rel)
continue
# 2. 检查父对象是否已加载
if not self._is_relation_loaded(parent_attr):
# 父对象未加载,需要同时加载父对象和嵌套关系
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
to_load.append(parent_attr)
nested_parent_keys.add(parent_attr)
to_load.append(rel)
else:
# 3. 父对象已加载,检查嵌套关系是否已加载
parent_obj = getattr(self, parent_attr)
if not _is_obj_relation_loaded(parent_obj, rel.key):
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
# 因为 _build_load_chains 需要完整的链来构建 selectinload
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
to_load.append(parent_attr)
nested_parent_keys.add(parent_attr)
to_load.append(rel)
if not to_load:
return # 全部已加载,跳过
# 构建 load 参数
load_options = self._specs_to_load_options(to_load)
if not load_options:
return
# 安全地获取主键值(避免触发懒加载)
state = sa_inspect(self)
pk_tuple = state.key[1] if state.key else None
if pk_tuple is None:
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
return
# 主键是元组,取第一个值(假设单列主键)
pk_value = pk_tuple[0]
# 单次查询加载缺失的关系
fresh = await self.__class__.get(
session,
self.__class__.id == pk_value,
load=load_options,
)
if fresh is None:
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
return
# 原地复制到 self只复制直接关系嵌套关系通过父关系自动可访问
all_direct_keys = direct_keys | nested_parent_keys
for key in all_direct_keys:
value = getattr(fresh, key, None)
object.__setattr__(self, key, value)
def _specs_to_load_options(
self,
specs: list[str | RelationshipInfo],
) -> list[RelationshipInfo]:
"""
将关系规格转换为 load 参数
- 字符串 → cls.{name}
- RelationshipInfo → 直接使用
"""
result: list[RelationshipInfo] = []
for spec in specs:
if isinstance(spec, str):
rel = getattr(self.__class__, spec, None)
if rel is not None:
result.append(rel)
else:
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
else:
result.append(spec)
return result
# ==================== 可选的手动预加载 API ====================
@classmethod
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
"""
获取指定方法声明的关系(用于外部预加载场景)
Args:
method_name: 方法名
Returns:
RelationshipInfo 列表
"""
method = getattr(cls, method_name, None)
if method is None or not hasattr(method, '_required_relations'):
return []
result: list[RelationshipInfo] = []
for spec in method._required_relations:
if isinstance(spec, str):
rel = getattr(cls, spec, None)
if rel:
result.append(rel)
else:
result.append(spec)
return result
@classmethod
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
"""
获取多个方法的关系并去重(用于批量预加载场景)
Args:
method_names: 方法名列表
Returns:
去重后的 RelationshipInfo 列表
"""
seen: set[str] = set()
result: list[RelationshipInfo] = []
for method_name in method_names:
for rel in cls.get_relations_for_method(method_name):
key = rel.key
if key not in seen:
seen.add(key)
result.append(rel)
return result
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
"""
手动预加载指定方法的关系(可选优化 API
当需要确保在调用方法前完成所有加载时使用。
通常情况下不需要调用此方法,装饰器会自动处理。
Args:
session: 数据库会话
method_names: 方法名列表
Returns:
self支持链式调用
Example:
# 可选:显式预加载(通常不需要)
tool = await tool.preload_for(session, 'cost', '_call')
"""
all_relations: list[str | RelationshipInfo] = []
for method_name in method_names:
method = getattr(self.__class__, method_name, None)
if method and hasattr(method, '_required_relations'):
all_relations.extend(method._required_relations)
if all_relations:
await self._ensure_relations_loaded(session, tuple(all_relations))
return self

1247
sqlmodels/mixin/table.py Normal file

File diff suppressed because it is too large Load Diff

123
sqlmodels/model_base.py Normal file
View File

@@ -0,0 +1,123 @@
import uuid
from datetime import datetime
from enum import StrEnum
from sqlmodel import Field
from .base import SQLModelBase
class ResponseBase(SQLModelBase):
"""通用响应模型"""
instance_id: uuid.UUID = Field(default_factory=uuid.uuid4)
"""实例ID用于标识请求的唯一性"""
# ==================== Admin Summary DTO ====================
class MetricsSummary(SQLModelBase):
"""站点统计摘要"""
dates: list[datetime]
"""日期列表"""
files: list[int]
"""每日新增文件数"""
users: list[int]
"""每日新增用户数"""
shares: list[int]
"""每日新增分享数"""
file_total: int
"""文件总数"""
user_total: int
"""用户总数"""
share_total: int
"""分享总数"""
entities_total: int
"""实体总数"""
generated_at: datetime
"""生成时间"""
class LicenseInfo(SQLModelBase):
"""许可证信息"""
expired_at: datetime
"""过期时间"""
signed_at: datetime
"""签发时间"""
root_domains: list[str]
"""根域名列表"""
domains: list[str]
"""域名列表"""
vol_domains: list[str]
"""卷域名列表"""
class VersionInfo(SQLModelBase):
"""版本信息"""
version: str
"""版本号"""
pro: bool
"""是否为专业版"""
commit: str
"""提交哈希"""
class AdminSummaryResponse(ResponseBase):
"""管理员概况响应"""
metrics_summary: MetricsSummary
"""统计摘要"""
site_urls: list[str]
"""站点URL列表"""
license: LicenseInfo
"""许可证信息"""
version: VersionInfo
"""版本信息"""
class MCPMethod(StrEnum):
"""MCP 方法枚举"""
PING = "ping"
"""Ping 方法,用于测试连接"""
class MCPBase(SQLModelBase):
"""MCP 请求基础模型"""
jsonrpc: str = "2.0"
"""JSON-RPC 版本"""
id: uuid.UUID = Field(default_factory=uuid.uuid4)
"""请求/响应 ID用于标识请求/响应的唯一性"""
class MCPRequestBase(MCPBase):
"""MCP 请求模型基础类"""
method: str
"""方法名称"""
class MCPResponseBase(MCPBase):
"""MCP 响应模型基础类"""
result: str
"""方法返回结果"""

103
sqlmodels/node.py Normal file
View File

@@ -0,0 +1,103 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Index
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .download import Download
class NodeStatus(StrEnum):
"""节点状态枚举"""
ONLINE = "online"
"""正常"""
OFFLINE = "offline"
"""离线"""
class NodeType(StrEnum):
"""节点类型枚举"""
MASTER = "master"
"""主节点"""
SLAVE = "slave"
"""从节点"""
class Aria2ConfigurationBase(SQLModelBase):
"""Aria2配置基础模型"""
rpc_url: str | None = Field(default=None, max_length=255)
"""RPC地址"""
rpc_secret: str | None = None
"""RPC密钥"""
temp_path: str | None = Field(default=None, max_length=255)
"""临时下载路径"""
max_concurrent: int = Field(default=5, ge=1, le=50)
"""最大并发数"""
timeout: int = Field(default=300, ge=1)
"""请求超时时间(秒)"""
class Aria2Configuration(Aria2ConfigurationBase, TableBaseMixin):
"""Aria2配置模型与Node一对一关联"""
node_id: int = Field(
foreign_key="node.id",
unique=True,
index=True,
ondelete="CASCADE"
)
"""关联的节点ID"""
# 反向关系
node: "Node" = Relationship(back_populates="aria2_config")
"""关联的节点"""
class Node(SQLModelBase, TableBaseMixin):
"""节点模型"""
__table_args__ = (
Index("ix_node_status", "status"),
)
status: NodeStatus = Field(default=NodeStatus.ONLINE)
"""节点状态"""
name: str = Field(max_length=255, unique=True)
"""节点名称"""
type: NodeType
"""节点类型"""
server: str = Field(max_length=255)
"""节点地址IP或域名"""
slave_key: str | None = Field(default=None, max_length=255)
"""从机通讯密钥"""
master_key: str | None = Field(default=None, max_length=255)
"""主机通讯密钥"""
aria2_enabled: bool = False
"""是否启用Aria2"""
rank: int = 0
"""节点排序权重"""
# 关系
aria2_config: Aria2Configuration | None = Relationship(
back_populates="node",
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
)
"""Aria2配置"""
downloads: list["Download"] = Relationship(back_populates="node")
"""该节点的下载任务"""

807
sqlmodels/object.py Normal file
View File

@@ -0,0 +1,807 @@
from datetime import datetime
from typing import TYPE_CHECKING, Literal
from uuid import UUID
from enum import StrEnum
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
from .user import User
from .policy import Policy
from .source_link import SourceLink
from .share import Share
from .physical_file import PhysicalFile
from .uri import DiskNextURI
class ObjectType(StrEnum):
"""对象类型枚举"""
FILE = "file"
FOLDER = "folder"
class StorageType(StrEnum):
"""存储类型枚举"""
LOCAL = "local"
QINIU = "qiniu"
TENCENT = "tencent"
ALIYUN = "aliyun"
ONEDRIVE = "onedrive"
GOOGLE_DRIVE = "google_drive"
DROPBOX = "dropbox"
WEBDAV = "webdav"
REMOTE = "remote"
class FileMetadataBase(SQLModelBase):
"""文件元数据基础模型"""
width: int | None = Field(default=None)
"""图片宽度(像素)"""
height: int | None = Field(default=None)
"""图片高度(像素)"""
duration: float | None = Field(default=None)
"""音视频时长(秒)"""
bitrate: int | None = Field(default=None)
"""比特率kbps"""
mime_type: str | None = Field(default=None, max_length=127)
"""MIME类型"""
checksum_md5: str | None = Field(default=None, max_length=32)
"""MD5校验和"""
checksum_sha256: str | None = Field(default=None, max_length=64)
"""SHA256校验和"""
# ==================== Base 模型 ====================
class ObjectBase(SQLModelBase):
"""对象基础字段,供数据库模型和 DTO 共享"""
name: str
"""对象名称(文件名或目录名)"""
type: ObjectType
"""对象类型"""
size: int | None = None
"""文件大小(字节),目录为 None"""
# ==================== DTO 模型 ====================
class DirectoryCreateRequest(SQLModelBase):
"""创建目录请求 DTO"""
parent_id: UUID
"""父目录UUID"""
name: str
"""目录名称"""
policy_id: UUID | None = None
"""存储策略UUID不指定则继承父目录"""
class ObjectMoveRequest(SQLModelBase):
"""移动对象请求 DTO"""
src_ids: list[UUID]
"""源对象UUID列表"""
dst_id: UUID
"""目标文件夹UUID"""
class ObjectDeleteRequest(SQLModelBase):
"""删除对象请求 DTO"""
ids: list[UUID]
"""待删除对象UUID列表"""
class ObjectResponse(ObjectBase):
"""对象响应 DTO"""
id: UUID
"""对象UUID"""
thumb: bool = False
"""是否有缩略图"""
created_at: datetime
"""对象创建时间"""
updated_at: datetime
"""对象修改时间"""
source_enabled: bool = False
"""是否启用离线下载源"""
class PolicyResponse(SQLModelBase):
"""存储策略响应 DTO"""
id: UUID
"""策略UUID"""
name: str
"""策略名称"""
type: StorageType
"""存储类型"""
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
"""单文件最大限制单位字节0表示不限制"""
file_type: list[str] | None = None
"""允许的文件类型列表None 表示不限制"""
class DirectoryResponse(SQLModelBase):
"""目录响应 DTO"""
id: UUID
"""当前目录UUID"""
parent: UUID | None = None
"""父目录UUID根目录为None"""
objects: list[ObjectResponse] = []
"""目录下的对象列表"""
policy: PolicyResponse
"""存储策略"""
# ==================== 数据库模型 ====================
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
"""文件元数据模型与Object一对一关联"""
object_id: UUID = Field(
foreign_key="object.id",
unique=True,
index=True,
ondelete="CASCADE"
)
"""关联的对象UUID"""
# 反向关系
object: "Object" = Relationship(back_populates="file_metadata")
"""关联的对象"""
class Object(ObjectBase, UUIDTableBaseMixin):
"""
统一对象模型
合并了原有的 File 和 Folder 模型,通过 type 字段区分文件和目录。
根目录规则:
- 每个用户有一个显式根目录对象name="/", parent_id=NULL
- 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id
- 根目录的 policy_id 指定用户默认存储策略
- 路径格式:/path/to/file如 /docs/readme.md不包含用户名前缀
"""
__table_args__ = (
# 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况)
UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"),
# 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/"
CheckConstraint(
"parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')",
name="ck_object_name_no_slash",
),
# 性能索引
Index("ix_object_owner_updated", "owner_id", "updated_at"),
Index("ix_object_parent_updated", "parent_id", "updated_at"),
Index("ix_object_owner_type", "owner_id", "type"),
Index("ix_object_owner_size", "owner_id", "size"),
)
# ==================== 基础字段 ====================
name: str = Field(max_length=255)
"""对象名称(文件名或目录名)"""
type: ObjectType
"""对象类型file 或 folder"""
password: str | None = Field(default=None, max_length=255)
"""对象独立密码(仅当用户为对象单独设置密码时有效)"""
# ==================== 文件专属字段 ====================
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
"""文件大小(字节),目录为 0"""
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
"""分块上传会话ID仅文件有效"""
physical_file_id: UUID | None = Field(
default=None,
foreign_key="physicalfile.id",
index=True,
ondelete="SET NULL"
)
"""关联的物理文件UUID仅文件有效目录为NULL"""
# ==================== 外键 ====================
parent_id: UUID | None = Field(
default=None,
foreign_key="object.id",
index=True,
ondelete="CASCADE"
)
"""父目录UUIDNULL 表示这是用户的根目录"""
owner_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""所有者用户UUID"""
policy_id: UUID = Field(
foreign_key="policy.id",
index=True,
ondelete="RESTRICT"
)
"""存储策略UUID文件直接使用目录作为子文件的默认策略"""
# ==================== 封禁相关字段 ====================
is_banned: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否被封禁"""
banned_at: datetime | None = None
"""封禁时间"""
banned_by: UUID | None = Field(
default=None,
foreign_key="user.id",
index=True,
ondelete="SET NULL",
sa_column_kwargs={"name": "banned_by"}
)
"""封禁操作者UUID"""
ban_reason: str | None = Field(default=None, max_length=500)
"""封禁原因"""
# ==================== 关系 ====================
owner: "User" = Relationship(
back_populates="objects",
sa_relationship_kwargs={"foreign_keys": "[Object.owner_id]"}
)
"""所有者"""
banner: "User" = Relationship(
sa_relationship_kwargs={"foreign_keys": "[Object.banned_by]"}
)
"""封禁操作者"""
policy: "Policy" = Relationship(back_populates="objects")
"""存储策略"""
# 自引用关系
parent: "Object" = Relationship(
back_populates="children",
sa_relationship_kwargs={"remote_side": "Object.id"},
)
"""父目录"""
children: list["Object"] = Relationship(
back_populates="parent",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
"""子对象(文件和子目录)"""
# 仅文件有效的关系
file_metadata: FileMetadata | None = Relationship(
back_populates="object",
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
)
"""文件元数据(仅文件有效)"""
source_links: list["SourceLink"] = Relationship(
back_populates="object",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
"""源链接列表(仅文件有效)"""
shares: list["Share"] = Relationship(
back_populates="object",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
"""分享列表"""
physical_file: "PhysicalFile" = Relationship(back_populates="objects")
"""关联的物理文件(仅文件有效)"""
# ==================== 业务属性 ====================
@property
def source_name(self) -> str | None:
"""
源文件存储路径(向后兼容属性)
:return: 物理文件存储路径,如果没有关联物理文件则返回 None
"""
if self.physical_file:
return self.physical_file.storage_path
return None
@property
def is_file(self) -> bool:
"""是否为文件"""
return self.type == ObjectType.FILE
@property
def is_folder(self) -> bool:
"""是否为目录"""
return self.type == ObjectType.FOLDER
# ==================== 业务方法 ====================
@classmethod
async def get_root(cls, session, user_id: UUID) -> "Object | None":
"""
获取用户的根目录
:param session: 数据库会话
:param user_id: 用户UUID
:return: 根目录对象,不存在则返回 None
"""
return await cls.get(
session,
(cls.owner_id == user_id) & (cls.parent_id == None)
)
@classmethod
async def get_by_path(
cls,
session,
user_id: UUID,
path: str,
) -> "Object | None":
"""
根据路径获取对象
路径从用户根目录开始,不包含用户名前缀。
"/" 表示根目录,"/docs/images" 表示根目录下的 docs/images。
:param session: 数据库会话
:param user_id: 用户UUID
:param path: 路径,如 "/""/docs/images"
:return: Object 或 None
"""
path = path.strip()
if not path:
raise ValueError("路径不能为空")
# 获取用户根目录
root = await cls.get_root(session, user_id)
if not root:
return None
# 移除开头的斜杠并分割路径
if path.startswith("/"):
path = path[1:]
parts = [p for p in path.split("/") if p]
# 空路径 -> 返回根目录
if not parts:
return root
# 从根目录开始遍历路径
current = root
for part in parts:
if not current:
return None
current = await cls.get(
session,
(cls.owner_id == user_id) &
(cls.parent_id == current.id) &
(cls.name == part)
)
return current
@classmethod
async def get_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
"""
获取目录下的所有子对象
:param session: 数据库会话
:param user_id: 用户UUID
:param parent_id: 父目录UUID
:return: 子对象列表
"""
return await cls.get(
session,
(cls.owner_id == user_id) & (cls.parent_id == parent_id),
fetch_mode="all"
)
@classmethod
async def resolve_uri(
cls,
session,
uri: "DiskNextURI",
requesting_user_id: UUID | None = None,
) -> "Object":
"""
将 URI 解析为 Object 实例
分派逻辑(类似 Cloudreve 的 getNavigator
- MY → user_id = uri.id(str(requesting_user_id))
验证权限(自己的或管理员),然后 get_by_path
- SHARE → 通过 uri.fs_id 查 Share 表,验证密码和有效期
获取 share.object然后沿 uri.path 遍历子对象
- TRASH → 延后实现
:param session: 数据库会话
:param uri: DiskNextURI 实例
:param requesting_user_id: 请求用户UUID
:return: Object 实例
:raises ValueError: URI 无法解析
:raises PermissionError: 权限不足
:raises NotImplementedError: 不支持的命名空间
"""
from .uri import FileSystemNamespace
if uri.namespace == FileSystemNamespace.MY:
# 确定目标用户
target_user_id_str = uri.id(str(requesting_user_id) if requesting_user_id else None)
if not target_user_id_str:
raise ValueError("MY 命名空间需要提供 fs_id 或 requesting_user_id")
target_user_id = UUID(target_user_id_str)
# 权限检查:只能访问自己的空间(管理员权限由路由层判断)
if requesting_user_id and target_user_id != requesting_user_id:
raise PermissionError("无权访问其他用户的文件空间")
obj = await cls.get_by_path(session, target_user_id, uri.path)
if not obj:
raise ValueError(f"路径不存在: {uri.path}")
return obj
elif uri.namespace == FileSystemNamespace.SHARE:
raise NotImplementedError("分享空间解析尚未实现")
elif uri.namespace == FileSystemNamespace.TRASH:
raise NotImplementedError("回收站解析尚未实现")
else:
raise ValueError(f"未知的命名空间: {uri.namespace}")
async def get_full_path(self, session) -> str:
"""
从当前对象沿 parent_id 向上遍历到根目录,返回完整路径
:param session: 数据库会话
:return: 完整路径,如 "/docs/images/photo.jpg"
"""
parts: list[str] = []
current: Object | None = self
while current and current.parent_id is not None:
parts.append(current.name)
current = await Object.get(session, Object.id == current.parent_id)
# 反转顺序(从根到当前)
parts.reverse()
return "/" + "/".join(parts)
# ==================== 上传会话模型 ====================
class UploadSessionBase(SQLModelBase):
"""上传会话基础字段"""
file_name: str = Field(max_length=255)
"""原始文件名"""
file_size: int = Field(ge=0, sa_type=BigInteger)
"""文件总大小(字节)"""
chunk_size: int = Field(ge=1, sa_type=BigInteger)
"""分片大小(字节)"""
total_chunks: int = Field(ge=1)
"""总分片数"""
class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
"""
上传会话模型
用于管理分片上传的会话状态。
会话有效期为24小时过期后自动失效。
"""
# 会话状态
uploaded_chunks: int = 0
"""已上传分片数"""
uploaded_size: int = Field(default=0, sa_type=BigInteger)
"""已上传大小(字节)"""
storage_path: str | None = Field(default=None, max_length=512)
"""文件存储路径"""
expires_at: datetime
"""会话过期时间"""
# 外键
owner_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE")
"""上传者用户UUID"""
parent_id: UUID = Field(foreign_key="object.id", index=True, ondelete="CASCADE")
"""目标父目录UUID"""
policy_id: UUID = Field(foreign_key="policy.id", index=True, ondelete="RESTRICT")
"""存储策略UUID"""
# 关系
owner: "User" = Relationship()
"""上传者"""
parent: "Object" = Relationship(
sa_relationship_kwargs={"foreign_keys": "[UploadSession.parent_id]"}
)
"""目标父目录"""
policy: "Policy" = Relationship()
"""存储策略"""
@property
def is_expired(self) -> bool:
"""会话是否已过期"""
return datetime.now() > self.expires_at
@property
def is_complete(self) -> bool:
"""上传是否完成"""
return self.uploaded_chunks >= self.total_chunks
# ==================== 上传会话相关 DTO ====================
class CreateUploadSessionRequest(SQLModelBase):
"""创建上传会话请求 DTO"""
file_name: str = Field(max_length=255)
"""文件名"""
file_size: int = Field(ge=0)
"""文件大小(字节)"""
parent_id: UUID
"""父目录UUID"""
policy_id: UUID | None = None
"""存储策略UUID不指定则使用父目录的策略"""
class UploadSessionResponse(SQLModelBase):
"""上传会话响应 DTO"""
id: UUID
"""会话UUID"""
file_name: str
"""原始文件名"""
file_size: int
"""文件总大小(字节)"""
chunk_size: int
"""分片大小(字节)"""
total_chunks: int
"""总分片数"""
uploaded_chunks: int
"""已上传分片数"""
expires_at: datetime
"""过期时间"""
class UploadChunkResponse(SQLModelBase):
"""上传分片响应 DTO"""
uploaded_chunks: int
"""已上传分片数"""
total_chunks: int
"""总分片数"""
is_complete: bool
"""是否上传完成"""
object_id: UUID | None = None
"""完成后的文件对象UUID未完成时为None"""
class CreateFileRequest(SQLModelBase):
"""创建空白文件请求 DTO"""
name: str = Field(max_length=255)
"""文件名"""
parent_id: UUID
"""父目录UUID"""
policy_id: UUID | None = None
"""存储策略UUID不指定则使用父目录的策略"""
# ==================== 对象操作相关 DTO ====================
class ObjectCopyRequest(SQLModelBase):
"""复制对象请求 DTO"""
src_ids: list[UUID]
"""源对象UUID列表"""
dst_id: UUID
"""目标文件夹UUID"""
class ObjectRenameRequest(SQLModelBase):
"""重命名对象请求 DTO"""
id: UUID
"""对象UUID"""
new_name: str = Field(max_length=255)
"""新名称"""
class ObjectPropertyResponse(SQLModelBase):
"""对象基本属性响应 DTO"""
id: UUID
"""对象UUID"""
name: str
"""对象名称"""
type: ObjectType
"""对象类型"""
size: int
"""文件大小(字节)"""
created_at: datetime
"""创建时间"""
updated_at: datetime
"""修改时间"""
parent_id: UUID | None
"""父目录UUID"""
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
"""对象详细属性响应 DTO继承基本属性"""
# 元数据信息
mime_type: str | None = None
"""MIME类型"""
width: int | None = None
"""图片宽度(像素)"""
height: int | None = None
"""图片高度(像素)"""
duration: float | None = None
"""音视频时长(秒)"""
checksum_md5: str | None = None
"""MD5校验和"""
# 分享统计
share_count: int = 0
"""分享次数"""
total_views: int = 0
"""总浏览次数"""
total_downloads: int = 0
"""总下载次数"""
# 存储信息
policy_name: str | None = None
"""存储策略名称"""
reference_count: int = 1
"""物理文件引用计数(仅文件有效)"""
# ==================== 管理员文件管理 DTO ====================
class AdminFileResponse(ObjectResponse):
"""管理员文件响应 DTO"""
owner_id: UUID
"""所有者UUID"""
owner_email: str
"""所有者邮箱"""
policy_name: str
"""存储策略名称"""
is_banned: bool = False
"""是否被封禁"""
banned_at: datetime | None = None
"""封禁时间"""
ban_reason: str | None = None
"""封禁原因"""
@classmethod
def from_object(
cls,
obj: "Object",
owner: "User | None",
policy: "Policy | None",
) -> "AdminFileResponse":
"""从 Object ORM 对象构建"""
return cls(
# ObjectBase 字段
**ObjectBase.model_validate(obj, from_attributes=True).model_dump(),
# ObjectResponse 字段
id=obj.id,
thumb=False,
created_at=obj.created_at,
updated_at=obj.updated_at,
source_enabled=False,
# AdminFileResponse 字段
owner_id=obj.owner_id,
owner_email=owner.email if owner else "unknown",
policy_name=policy.name if policy else "unknown",
is_banned=obj.is_banned,
banned_at=obj.banned_at,
ban_reason=obj.ban_reason,
)
class FileBanRequest(SQLModelBase):
"""文件封禁请求 DTO"""
ban: bool = True
"""是否封禁"""
reason: str | None = Field(default=None, max_length=500)
"""封禁原因"""
class AdminFileListResponse(SQLModelBase):
"""管理员文件列表响应 DTO"""
files: list[AdminFileResponse] = []
"""文件列表"""
total: int = 0
"""总数"""

66
sqlmodels/order.py Normal file
View File

@@ -0,0 +1,66 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
class OrderStatus(StrEnum):
"""订单状态枚举"""
PENDING = "pending"
"""待支付"""
COMPLETED = "completed"
"""已完成"""
CANCELLED = "cancelled"
"""已取消"""
class OrderType(StrEnum):
"""订单类型枚举"""
# [TODO] 补充具体订单类型
pass
class Order(SQLModelBase, TableBaseMixin):
"""订单模型"""
order_no: str = Field(max_length=255, unique=True, index=True)
"""订单号,唯一"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""订单类型 [TODO] 待定义枚举"""
method: str | None = Field(default=None, max_length=255)
"""支付方式"""
product_id: int | None = Field(default=None)
"""商品ID"""
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
"""购买数量"""
name: str = Field(max_length=255)
"""商品名称"""
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""订单价格(分)"""
status: OrderStatus = Field(default=OrderStatus.PENDING)
"""订单状态"""
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="orders")

View File

@@ -0,0 +1,91 @@
"""
物理文件模型
表示磁盘上的实际文件。多个 Object 可以引用同一个 PhysicalFile
实现文件共享而不复制物理文件。
引用计数逻辑:
- 每个引用此文件的 Object 都会增加引用计数
- 当 Object 被删除时,减少引用计数
- 只有当引用计数为 0 时,才物理删除文件
"""
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, Index
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
from .object import Object
from .policy import Policy
class PhysicalFileBase(SQLModelBase):
"""物理文件基础模型"""
storage_path: str = Field(max_length=512)
"""物理存储路径(相对于存储策略根目录)"""
size: int = Field(default=0, sa_type=BigInteger)
"""文件大小(字节)"""
checksum_md5: str | None = Field(default=None, max_length=32)
"""MD5校验和用于文件去重和完整性校验"""
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
"""
物理文件模型
表示磁盘上的实际文件。多个 Object 可以引用同一个 PhysicalFile
实现文件共享而不复制物理文件。
"""
__table_args__ = (
Index("ix_physical_file_policy_path", "policy_id", "storage_path"),
Index("ix_physical_file_checksum", "checksum_md5"),
)
policy_id: UUID = Field(
foreign_key="policy.id",
index=True,
ondelete="RESTRICT",
)
"""存储策略UUID"""
reference_count: int = Field(default=1, ge=0)
"""引用计数(有多少个 Object 引用此物理文件)"""
# 关系
policy: "Policy" = Relationship()
"""存储策略"""
objects: list["Object"] = Relationship(back_populates="physical_file")
"""引用此物理文件的所有逻辑对象"""
def increment_reference(self) -> int:
"""
增加引用计数
:return: 更新后的引用计数
"""
self.reference_count += 1
return self.reference_count
def decrement_reference(self) -> int:
"""
减少引用计数
:return: 更新后的引用计数
"""
if self.reference_count > 0:
self.reference_count -= 1
return self.reference_count
@property
def can_be_deleted(self) -> bool:
"""是否可以物理删除引用计数为0"""
return self.reference_count == 0

187
sqlmodels/policy.py Normal file
View File

@@ -0,0 +1,187 @@
from typing import TYPE_CHECKING
from uuid import UUID
from enum import StrEnum
from sqlmodel import Field, Relationship, text
from .base import SQLModelBase
from .mixin import UUIDTableBaseMixin
if TYPE_CHECKING:
from .object import Object
from .group import Group
class GroupPolicyLink(SQLModelBase, table=True):
"""用户组与存储策略的多对多关联表"""
group_id: UUID = Field(
foreign_key="group.id",
primary_key=True,
ondelete="CASCADE"
)
"""用户组UUID"""
policy_id: UUID = Field(
foreign_key="policy.id",
primary_key=True,
ondelete="CASCADE"
)
"""存储策略UUID"""
class PolicyType(StrEnum):
LOCAL = "local"
S3 = "s3"
class PolicyBase(SQLModelBase):
"""存储策略基础字段,供 DTO 和数据库模型共享"""
name: str = Field(max_length=255)
"""策略名称"""
type: PolicyType
"""存储策略类型"""
server: str | None = Field(default=None, max_length=255)
"""服务器地址(本地策略为绝对路径)"""
bucket_name: str | None = Field(default=None, max_length=255)
"""存储桶名称"""
is_private: bool = True
"""是否为私有空间"""
base_url: str | None = Field(default=None, max_length=255)
"""访问文件的基础URL"""
access_key: str | None = None
"""Access Key"""
secret_key: str | None = None
"""Secret Key"""
max_size: int = Field(default=0, ge=0)
"""允许上传的最大文件尺寸(字节)"""
auto_rename: bool = False
"""是否自动重命名"""
dir_name_rule: str | None = Field(default=None, max_length=255)
"""目录命名规则"""
file_name_rule: str | None = Field(default=None, max_length=255)
"""文件命名规则"""
is_origin_link_enable: bool = False
"""是否开启源链接访问"""
# ==================== DTO 模型 ====================
class PolicySummary(SQLModelBase):
"""策略摘要,用于列表展示"""
id: UUID
"""策略UUID"""
name: str
"""策略名称"""
type: PolicyType
"""策略类型"""
server: str | None
"""服务器地址"""
max_size: int
"""最大文件尺寸"""
is_private: bool
"""是否私有"""
# ==================== 数据库模型 ====================
class PolicyOptionsBase(SQLModelBase):
"""存储策略选项的基础模型"""
token: str | None = Field(default=None)
"""访问令牌"""
file_type: str | None = Field(default=None)
"""允许的文件类型"""
mimetype: str | None = Field(default=None, max_length=127)
"""MIME类型"""
od_redirect: str | None = Field(default=None, max_length=255)
"""OneDrive重定向地址"""
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
"""分片上传大小字节默认50MB"""
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否使用S3路径风格"""
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
"""存储策略选项模型与Policy一对一关联"""
policy_id: UUID = Field(
foreign_key="policy.id",
unique=True,
ondelete="CASCADE"
)
"""关联的策略UUID"""
# 反向关系
policy: "Policy" = Relationship(back_populates="options")
"""关联的策略"""
class Policy(PolicyBase, UUIDTableBaseMixin):
"""存储策略模型"""
# 覆盖基类字段以添加数据库专有配置
name: str = Field(max_length=255, unique=True)
"""策略名称"""
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
"""是否为私有空间"""
max_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""允许上传的最大文件尺寸(字节)"""
auto_rename: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否自动重命名"""
is_origin_link_enable: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
"""是否开启源链接访问"""
# 一对一关系:策略选项
options: PolicyOptions | None = Relationship(
back_populates="policy",
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
)
"""策略的扩展选项"""
# 关系
objects: list["Object"] = Relationship(back_populates="policy")
"""策略下的所有对象"""
# 多对多关系:策略可以被多个用户组使用
groups: list["Group"] = Relationship(
back_populates="policies",
link_model=GroupPolicyLink,
)
@staticmethod
async def create(
policy: 'Policy | None' = None,
**kwargs
):
pass

23
sqlmodels/redeem.py Normal file
View File

@@ -0,0 +1,23 @@
from enum import StrEnum
from sqlmodel import Field, text
from .base import SQLModelBase
from .mixin import TableBaseMixin
class RedeemType(StrEnum):
"""兑换码类型枚举"""
# [TODO] 补充具体兑换码类型
pass
class Redeem(SQLModelBase, TableBaseMixin):
"""兑换码模型"""
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""兑换码类型 [TODO] 待定义枚举"""
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
code: str = Field(unique=True, index=True, description="兑换码,唯一")
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")

35
sqlmodels/report.py Normal file
View File

@@ -0,0 +1,35 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from sqlmodel import Field, Relationship
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .share import Share
class ReportReason(StrEnum):
"""举报原因枚举"""
# [TODO] 补充具体举报原因
pass
class Report(SQLModelBase, TableBaseMixin):
"""举报模型"""
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""举报原因 [TODO] 待定义枚举"""
description: str | None = Field(default=None, max_length=255, description="补充描述")
# 外键
share_id: int = Field(
foreign_key="share.id",
index=True,
ondelete="CASCADE"
)
"""被举报的分享ID"""
# 关系
share: "Share" = Relationship(back_populates="reports")

136
sqlmodels/setting.py Normal file
View File

@@ -0,0 +1,136 @@
from enum import StrEnum
from sqlmodel import UniqueConstraint
from .base import SQLModelBase
from .mixin import TableBaseMixin
from .user import UserResponse
class CaptchaType(StrEnum):
"""验证码类型枚举"""
DEFAULT = "default"
GCAPTCHA = "gcaptcha"
CLOUD_FLARE_TURNSTILE = "cloudflare turnstile"
# ==================== DTO 模型 ====================
class SiteConfigResponse(SQLModelBase):
"""站点配置响应 DTO"""
title: str = "DiskNext"
"""网站标题"""
site_notice: str | None = None
"""网站公告"""
user: UserResponse | None = None
"""用户信息"""
logo_light: str | None = None
"""网站Logo URL"""
logo_dark: str | None = None
"""网站Logo URL深色模式"""
register_enabled: bool = True
"""是否允许注册"""
login_captcha: bool = False
"""登录是否需要验证码"""
reg_captcha: bool = False
"""注册是否需要验证码"""
forget_captcha: bool = False
"""找回密码是否需要验证码"""
captcha_type: CaptchaType = CaptchaType.DEFAULT
"""验证码类型"""
captcha_key: str | None = None
"""验证码 public keyDEFAULT 类型时为 None"""
# ==================== 管理员设置 DTO ====================
class SettingItem(SQLModelBase):
"""单个设置项 DTO"""
type: str
"""设置类型"""
name: str
"""设置项名称"""
value: str | None = None
"""设置值"""
class SettingsListResponse(SQLModelBase):
"""获取设置列表响应 DTO"""
settings: list[SettingItem]
"""设置项列表"""
total: int
"""总数"""
class SettingsUpdateRequest(SQLModelBase):
"""更新设置请求 DTO"""
settings: list[SettingItem]
"""要更新的设置项列表"""
class SettingsUpdateResponse(SQLModelBase):
"""更新设置响应 DTO"""
updated: int
"""更新的设置项数量"""
created: int
"""新建的设置项数量"""
# ==================== 数据库模型 ====================
class SettingsType(StrEnum):
"""设置类型枚举"""
ARIA2 = "aria2"
AUTH = "auth"
AUTHN = "authn"
AVATAR = "avatar"
BASIC = "basic"
CAPTCHA = "captcha"
CRON = "cron"
FILE_EDIT = "file_edit"
LOGIN = "login"
MAIL = "mail"
MAIL_TEMPLATE = "mail_template"
MOBILE = "mobile"
OAUTH = "oauth"
PATH = "path"
PREVIEW = "preview"
PWA = "pwa"
REGISTER = "register"
RETRY = "retry"
SHARE = "share"
SLAVE = "slave"
TASK = "task"
THUMB = "thumb"
TIMEOUT = "timeout"
UPLOAD = "upload"
VERSION = "version"
VIEW = "view"
WOPI = "wopi"
# 数据库模型
class Setting(SettingItem, TableBaseMixin):
"""设置模型,继承 SettingItem 中的 name 和 value 字段"""
__table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),)
type: SettingsType
"""设置类型/分组(覆盖基类的 str 类型为枚举类型)"""

220
sqlmodels/share.py Normal file
View File

@@ -0,0 +1,220 @@
from typing import TYPE_CHECKING
from datetime import datetime
from uuid import UUID
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
from .report import Report
from .object import Object
# ==================== Base 模型 ====================
class ShareBase(SQLModelBase):
"""分享基础字段,供 DTO 和数据库模型共享"""
object_id: UUID
"""关联的对象UUID"""
password: str | None = None
"""分享密码"""
expires: datetime | None = None
"""过期时间NULL为永不过期"""
remain_downloads: int | None = None
"""剩余下载次数NULL为不限制"""
preview_enabled: bool = True
"""是否允许预览"""
score: int = 0
"""兑换此分享所需的积分"""
# ==================== 数据库模型 ====================
class Share(SQLModelBase, TableBaseMixin):
"""分享模型"""
__table_args__ = (
UniqueConstraint("code", name="uq_share_code"),
Index("ix_share_source_name", "source_name"),
Index("ix_share_user_created", "user_id", "created_at"),
Index("ix_share_object", "object_id"),
)
code: str = Field(max_length=64, nullable=False, index=True)
"""分享码"""
password: str | None = Field(default=None, max_length=255)
"""分享密码(加密后)"""
object_id: UUID = Field(
foreign_key="object.id",
index=True,
ondelete="CASCADE"
)
"""关联的对象UUID"""
views: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""浏览次数"""
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""下载次数"""
remain_downloads: int | None = Field(default=None)
"""剩余下载次数 (NULL为不限制)"""
expires: datetime | None = Field(default=None)
"""过期时间 (NULL为永不过期)"""
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
"""是否允许预览"""
source_name: str | None = Field(default=None, max_length=255)
"""源名称(冗余字段,便于展示)"""
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""兑换此分享所需的积分"""
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""创建分享的用户UUID"""
# 关系
user: "User" = Relationship(back_populates="shares")
"""分享创建者"""
object: "Object" = Relationship(back_populates="shares")
"""关联的对象"""
reports: list["Report"] = Relationship(
back_populates="share",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
"""举报列表"""
@property
def is_dir(self) -> bool:
"""是否为目录分享(向后兼容属性)"""
from .object import ObjectType
return self.object.type == ObjectType.FOLDER if self.object else False
# ==================== DTO 模型 ====================
class ShareCreateRequest(ShareBase):
"""创建分享请求 DTO继承 ShareBase 中的所有字段"""
pass
class ShareResponse(SQLModelBase):
"""分享响应 DTO"""
id: int
"""分享ID"""
code: str
"""分享码"""
object_id: UUID
"""关联对象UUID"""
source_name: str | None
"""源名称"""
views: int
"""浏览次数"""
downloads: int
"""下载次数"""
remain_downloads: int | None
"""剩余下载次数"""
expires: datetime | None
"""过期时间"""
preview_enabled: bool
"""是否允许预览"""
score: int
"""积分"""
created_at: datetime
"""创建时间"""
is_expired: bool
"""是否已过期"""
has_password: bool
"""是否有密码"""
class ShareListItemBase(SQLModelBase):
"""分享列表项基础字段"""
id: int
"""分享ID"""
code: str
"""分享码"""
views: int
"""浏览次数"""
downloads: int
"""下载次数"""
remain_downloads: int | None
"""剩余下载次数"""
expires: datetime | None
"""过期时间"""
preview_enabled: bool
"""是否允许预览"""
score: int
"""积分"""
user_id: UUID
"""用户UUID"""
created_at: datetime
"""创建时间"""
class AdminShareListItem(ShareListItemBase):
"""管理员分享列表项 DTO添加关联字段"""
username: str | None
"""用户名"""
object_name: str | None
"""对象名称"""
@classmethod
def from_share(
cls,
share: "Share",
user: "User | None",
obj: "Object | None",
) -> "AdminShareListItem":
"""从 Share ORM 对象构建"""
return cls(
**ShareListItemBase.model_validate(share, from_attributes=True).model_dump(),
username=user.email if user else None,
object_name=obj.name if obj else None,
)

37
sqlmodels/source_link.py Normal file
View File

@@ -0,0 +1,37 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, Index
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .object import Object
class SourceLink(SQLModelBase, TableBaseMixin):
"""链接模型"""
__table_args__ = (
Index("ix_sourcelink_object_name", "object_id", "name"),
)
name: str = Field(max_length=255)
"""链接名称"""
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""通过此链接的下载次数"""
# 外键
object_id: UUID = Field(
foreign_key="object.id",
index=True,
ondelete="CASCADE"
)
"""关联的对象UUID必须是文件类型"""
# 关系
object: "Object" = Relationship(back_populates="source_links")
"""关联的对象"""

31
sqlmodels/storage_pack.py Normal file
View File

@@ -0,0 +1,31 @@
from typing import TYPE_CHECKING
from datetime import datetime
from uuid import UUID
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
class StoragePack(SQLModelBase, TableBaseMixin):
"""容量包模型"""
name: str = Field(max_length=255, description="容量包名称")
active_time: datetime | None = Field(default=None, description="激活时间")
expired_time: datetime | None = Field(default=None, index=True, description="过期时间")
size: int = Field(description="容量包大小(字节)")
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="storage_packs")

50
sqlmodels/tag.py Normal file
View File

@@ -0,0 +1,50 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
class TagType(StrEnum):
"""标签类型枚举"""
MANUAL = "manual"
"""手动标签"""
AUTOMATIC = "automatic"
"""自动标签"""
class Tag(SQLModelBase, TableBaseMixin):
"""标签模型"""
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
name: str = Field(max_length=255)
"""标签名称"""
icon: str | None = Field(default=None, max_length=255)
"""标签图标"""
color: str | None = Field(default=None, max_length=255)
"""标签颜色"""
type: TagType = Field(default=TagType.MANUAL)
"""标签类型"""
expression: str | None = Field(default=None, description="自动标签的匹配表达式")
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="tags")

153
sqlmodels/task.py Normal file
View File

@@ -0,0 +1,153 @@
from enum import StrEnum
from typing import TYPE_CHECKING
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, Relationship, CheckConstraint, Index
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .download import Download
from .user import User
class TaskStatus(StrEnum):
"""任务状态枚举"""
QUEUED = "queued"
"""排队中"""
RUNNING = "running"
"""处理中"""
COMPLETED = "completed"
"""已完成"""
ERROR = "error"
"""错误"""
class TaskType(StrEnum):
"""任务类型枚举"""
# [TODO] 补充具体任务类型
pass
# ==================== DTO 模型 ====================
class TaskSummaryBase(SQLModelBase):
"""任务摘要基础字段"""
id: int
"""任务ID"""
type: int
"""任务类型"""
status: TaskStatus
"""任务状态"""
progress: int
"""进度0-100"""
error: str | None
"""错误信息"""
user_id: UUID
"""用户UUID"""
created_at: datetime
"""创建时间"""
updated_at: datetime
"""更新时间"""
class TaskSummary(TaskSummaryBase):
"""任务摘要,用于管理员列表展示"""
username: str | None
"""用户名"""
@classmethod
def from_task(cls, task: "Task", user: "User | None") -> "TaskSummary":
"""从 Task ORM 对象构建"""
return cls(
**TaskSummaryBase.model_validate(task, from_attributes=True).model_dump(),
username=user.email if user else None,
)
# ==================== 数据库模型 ====================
class TaskPropsBase(SQLModelBase):
"""任务属性基础模型"""
source_path: str | None = None
"""源路径"""
dest_path: str | None = None
"""目标路径"""
file_ids: str | None = None
"""文件ID列表逗号分隔"""
# [TODO] 根据业务需求补充更多字段
class TaskProps(TaskPropsBase, TableBaseMixin):
"""任务属性模型与Task一对一关联"""
task_id: int = Field(
foreign_key="task.id",
primary_key=True,
ondelete="CASCADE"
)
"""关联的任务ID"""
# 反向关系
task: "Task" = Relationship(back_populates="props")
"""关联的任务"""
class Task(SQLModelBase, TableBaseMixin):
"""任务模型"""
__table_args__ = (
CheckConstraint("progress BETWEEN 0 AND 100", name="ck_task_progress_range"),
Index("ix_task_status", "status"),
Index("ix_task_user_status", "user_id", "status"),
)
status: TaskStatus = Field(default=TaskStatus.QUEUED)
"""任务状态"""
type: int = Field(default=0)
"""任务类型 [TODO] 待定义枚举"""
progress: int = Field(default=0, ge=0, le=100)
"""任务进度0-100"""
error: str | None = Field(default=None)
"""错误信息"""
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
props: TaskProps | None = Relationship(
back_populates="task",
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
)
"""任务属性"""
user: "User" = Relationship(back_populates="tasks")
"""所属用户"""
downloads: list["Download"] = Relationship(back_populates="task")
"""关联的下载任务"""

258
sqlmodels/uri.py Normal file
View File

@@ -0,0 +1,258 @@
from enum import StrEnum
from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote
from .base import SQLModelBase
class FileSystemNamespace(StrEnum):
"""文件系统命名空间"""
MY = "my"
"""用户个人空间"""
SHARE = "share"
"""分享空间"""
TRASH = "trash"
"""回收站"""
class DiskNextURI(SQLModelBase):
"""
DiskNext 文件 URI
URI 格式: disknext://[fs_id[:password]@]namespace[/path][?query]
fs_id 可省略:
- my/trash 命名空间省略时默认当前用户
- share 命名空间必须提供 fs_idShare.code
"""
fs_id: str | None = None
"""文件系统标识符,可省略"""
namespace: FileSystemNamespace
"""命名空间"""
path: str = "/"
"""路径"""
password: str | None = None
"""访问密码(用于有密码的分享)"""
query: dict[str, str] | None = None
"""查询参数"""
# === 属性 ===
@property
def path_parts(self) -> list[str]:
"""路径分割为列表(过滤空串)"""
return [p for p in self.path.split("/") if p]
@property
def is_root(self) -> bool:
"""是否指向根目录"""
return self.path.strip("/") == ""
# === 核心方法 ===
def id(self, default_id: str | None = None) -> str | None:
"""
获取 fs_id省略时返回 default_id
参考 Cloudreve URI.ID(defaultUid) 方法
:param default_id: 默认值(通常为当前用户 ID
:return: fs_id 或 default_id
"""
return self.fs_id if self.fs_id else default_id
# === 类方法 ===
@classmethod
def parse(cls, uri: str) -> "DiskNextURI":
"""
解析 URI 字符串
实现方式:替换 disknext:// 为 http:// 后用 urllib.parse.urlparse 解析
- hostname → namespace
- username → fs_id
- password → password
- path → path
- query → query dict
:param uri: URI 字符串,如 "disknext://my/docs/readme.md"
:return: DiskNextURI 实例
:raises ValueError: URI 格式无效
"""
if not uri.startswith("disknext://"):
raise ValueError(f"URI 必须以 disknext:// 开头: {uri}")
# 替换协议为 http:// 以利用 urllib.parse 解析
http_uri = "http://" + uri[len("disknext://"):]
parsed = urlparse(http_uri)
# 解析 namespace
hostname = parsed.hostname
if not hostname:
raise ValueError(f"URI 缺少命名空间: {uri}")
try:
namespace = FileSystemNamespace(hostname)
except ValueError:
raise ValueError(f"无效的命名空间 '{hostname}',有效值: {[e.value for e in FileSystemNamespace]}")
# 解析 fs_id 和 password
fs_id = unquote(parsed.username) if parsed.username else None
password = unquote(parsed.password) if parsed.password else None
# 解析 path
path = unquote(parsed.path) if parsed.path else "/"
if not path:
path = "/"
# 解析 query
query: dict[str, str] | None = None
if parsed.query:
raw_query = parse_qs(parsed.query, keep_blank_values=True)
query = {k: v[0] for k, v in raw_query.items()}
return cls(
fs_id=fs_id,
namespace=namespace,
path=path,
password=password,
query=query,
)
@classmethod
def build(
cls,
namespace: FileSystemNamespace,
path: str = "/",
fs_id: str | None = None,
password: str | None = None,
) -> "DiskNextURI":
"""
构建 URI 实例
:param namespace: 命名空间
:param path: 路径
:param fs_id: 文件系统标识符
:param password: 访问密码
:return: DiskNextURI 实例
"""
# 确保 path 以 / 开头
if not path.startswith("/"):
path = "/" + path
return cls(
fs_id=fs_id,
namespace=namespace,
path=path,
password=password,
)
# === 实例方法 ===
def to_string(self) -> str:
"""
序列化为 URI 字符串
:return: URI 字符串,如 "disknext://my/docs/readme.md"
"""
result = "disknext://"
# fs_id 和 password
if self.fs_id:
result += quote(self.fs_id, safe="")
if self.password:
result += ":" + quote(self.password, safe="")
result += "@"
# namespace
result += self.namespace.value
# path
result += self.path
# query
if self.query:
result += "?" + urlencode(self.query)
return result
def join(self, *elements: str) -> "DiskNextURI":
"""
拼接路径元素,返回新 URI
:param elements: 路径元素
:return: 新的 DiskNextURI 实例
"""
base = self.path.rstrip("/")
for element in elements:
element = element.strip("/")
if element:
base += "/" + element
if not base:
base = "/"
return DiskNextURI(
fs_id=self.fs_id,
namespace=self.namespace,
path=base,
password=self.password,
query=self.query,
)
def dir_uri(self) -> "DiskNextURI":
"""
返回父目录的 URI
:return: 父目录的 DiskNextURI 实例
"""
parts = self.path_parts
if not parts:
# 已经是根目录
return self.root()
parent_path = "/" + "/".join(parts[:-1])
if not parent_path.endswith("/"):
parent_path += "/"
return DiskNextURI(
fs_id=self.fs_id,
namespace=self.namespace,
path=parent_path,
password=self.password,
)
def root(self) -> "DiskNextURI":
"""
返回根目录的 URI保留 namespace 和 fs_id
:return: 根目录的 DiskNextURI 实例
"""
return DiskNextURI(
fs_id=self.fs_id,
namespace=self.namespace,
path="/",
password=self.password,
)
def name(self) -> str:
"""
返回路径的最后一段(文件名或目录名)
:return: 文件名或目录名,根目录返回空字符串
"""
parts = self.path_parts
return parts[-1] if parts else ""
def __str__(self) -> str:
return self.to_string()
def __repr__(self) -> str:
return f"DiskNextURI({self.to_string()!r})"

554
sqlmodels/user.py Normal file
View File

@@ -0,0 +1,554 @@
from datetime import datetime
from enum import StrEnum
from typing import Literal, TYPE_CHECKING, TypeVar
from uuid import UUID
from pydantic import BaseModel
from sqlalchemy import BinaryExpression, ClauseElement, and_
from sqlmodel import Field, Relationship
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo
from .base import SQLModelBase
from .model_base import ResponseBase
from .mixin import UUIDTableBaseMixin, TableViewRequest, ListResponse
T = TypeVar("T", bound="User")
if TYPE_CHECKING:
from .group import Group
from .download import Download
from .object import Object
from .order import Order
from .share import Share
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .user_authn import UserAuthn
from .webdav import WebDAV
class AvatarType(StrEnum):
"""头像类型枚举"""
DEFAULT = "default"
GRAVATAR = "gravatar"
FILE = "file"
class ThemeType(StrEnum):
"""主题类型枚举"""
LIGHT = "light"
DARK = "dark"
SYSTEM = "system"
class UserStatus(StrEnum):
"""用户状态枚举"""
ACTIVE = "active"
ADMIN_BANNED = "admin_banned"
SYSTEM_BANNED = "system_banned"
# ==================== 筛选参数 ====================
class UserFilterParams(SQLModelBase):
"""
用户过滤参数
用于管理员搜索用户列表,支持用户组、用户名、昵称、状态等过滤。
"""
group_id: UUID | None = None
"""按用户组UUID筛选"""
email_contains: str | None = Field(default=None, max_length=50)
"""邮箱包含(不区分大小写的模糊搜索)"""
nickname_contains: str | None = Field(default=None, max_length=50)
"""昵称包含(不区分大小写的模糊搜索)"""
status: UserStatus | None = None
"""按用户状态筛选"""
# ==================== Base 模型 ====================
class UserBase(SQLModelBase):
"""用户基础字段,供数据库模型和 DTO 共享"""
email: str
"""用户邮箱"""
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
score: int = 0
"""用户积分"""
# ==================== DTO 模型 ====================
class LoginRequest(SQLModelBase):
"""登录请求 DTO"""
email: str
"""用户邮箱"""
password: str
"""用户密码"""
captcha: str | None = None
"""验证码"""
two_fa_code: int | None = Field(min_length=6, max_length=6)
"""两步验证代码"""
class RegisterRequest(SQLModelBase):
"""注册请求 DTO"""
email: str
"""用户邮箱,唯一"""
password: str
"""用户密码"""
captcha: str | None = None
"""验证码"""
class BatchDeleteRequest(SQLModelBase):
"""批量删除请求 DTO"""
ids: list[UUID]
"""待删除 UUID 列表"""
class RefreshTokenRequest(SQLModelBase):
"""刷新令牌请求 DTO"""
refresh_token: str
"""刷新令牌"""
class WebAuthnInfo(SQLModelBase):
"""WebAuthn 信息 DTO"""
credential_id: str
"""凭证 ID"""
credential_public_key: str
"""凭证公钥"""
sign_count: int
"""签名计数器"""
credential_device_type: bool
"""是否为平台认证器"""
credential_backed_up: bool
"""凭证是否已备份"""
transports: list[str]
"""支持的传输方式"""
class AccessTokenBase(BaseModel):
"""访问令牌响应 DTO"""
access_expires: datetime
"""访问令牌过期时间"""
access_token: str
"""访问令牌"""
class RefreshTokenBase(BaseModel):
"""刷新令牌响应DTO"""
refresh_expires: datetime
"""刷新令牌过期时间"""
refresh_token: str
"""刷新令牌"""
class TokenResponse(ResponseBase, AccessTokenBase, RefreshTokenBase):
"""令牌响应 DTO"""
class UserResponse(ResponseBase):
"""用户响应 DTO"""
id: UUID
"""用户UUID"""
email: str
"""用户邮箱"""
nickname: str | None = None
"""用户昵称"""
avatar: Literal["default", "gravatar", "file"] = "default"
"""头像类型"""
created_at: datetime
"""用户创建时间"""
anonymous: bool = False
"""是否为匿名用户"""
group: "GroupResponse | None" = None
"""用户所属用户组"""
tags: list[str] = []
"""用户标签列表"""
class UserStorageResponse(SQLModelBase):
"""用户存储信息 DTO"""
used: int
"""已用存储空间(字节)"""
free: int
"""剩余存储空间(字节)"""
total: int
"""总存储空间(字节)"""
class UserPublic(UserBase):
"""用户公开信息 DTO用于 API 响应"""
id: UUID
"""用户UUID"""
nickname: str | None = None
"""昵称"""
storage: int = 0
"""已用存储空间(字节)"""
avatar: str | None = None
"""头像地址"""
group_expires: datetime | None = None
"""用户组过期时间"""
group_id: UUID | None = None
"""所属用户组UUID"""
group_name: str | None = None
"""用户组名称"""
two_factor: str | None = None
"""两步验证密钥32位字符串null 表示未启用)"""
created_at: datetime | None = None
"""创建时间"""
updated_at: datetime | None = None
"""更新时间"""
class UserSettingResponse(SQLModelBase):
"""用户设置响应 DTO"""
id: UUID
"""用户UUID"""
email: str
"""用户邮箱"""
nickname: str | None = None
"""昵称"""
created_at: datetime
"""用户注册时间"""
group_name: str
"""用户所属用户组名称"""
language: str
"""语言偏好"""
timezone: int
"""时区"""
authn: "list[AuthnResponse] | None" = None
"""认证信息"""
group_expires: datetime | None = None
"""用户组过期时间"""
two_factor: bool = False
"""是否启用两步验证"""
# ==================== 管理员用户管理 DTO ====================
class UserAdminCreateRequest(SQLModelBase):
"""管理员创建用户请求 DTO"""
email: str = Field(max_length=50)
"""用户邮箱"""
password: str
"""用户密码(明文,由服务端加密)"""
nickname: str | None = Field(default=None, max_length=50)
"""昵称"""
group_id: UUID
"""所属用户组UUID"""
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
class UserAdminUpdateRequest(SQLModelBase):
"""管理员更新用户请求 DTO"""
email: str = Field(max_length=50)
"""邮箱"""
nickname: str | None = Field(default=None, max_length=50)
"""昵称"""
password: str | None = None
"""新密码(为空则不修改)"""
group_id: UUID | None = None
"""用户组UUID"""
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
score: int | None = Field(default=None, ge=0)
"""积分"""
storage: int | None = Field(default=None, ge=0)
"""已用存储空间(用于手动校准)"""
group_expires: datetime | None = None
"""用户组过期时间"""
two_factor: str | None = None
"""两步验证密钥32位字符串传 null 可清除,不传则不修改)"""
class UserCalibrateResponse(SQLModelBase):
"""用户存储校准响应 DTO"""
user_id: UUID
"""用户UUID"""
previous_storage: int
"""校准前的存储空间(字节)"""
current_storage: int
"""校准后的存储空间(字节)"""
difference: int
"""差异值(字节)"""
file_count: int
"""实际文件数量"""
class UserAdminDetailResponse(UserPublic):
"""管理员用户详情响应 DTO"""
two_factor_enabled: bool = False
"""是否启用两步验证"""
file_count: int = 0
"""文件数量"""
share_count: int = 0
"""分享数量"""
task_count: int = 0
"""任务数量"""
# 前向引用导入
from .group import GroupResponse # noqa: E402
from .user_authn import AuthnResponse # noqa: E402
# 更新前向引用
UserResponse.model_rebuild()
UserSettingResponse.model_rebuild()
# ==================== 数据库模型 ====================
class User(UserBase, UUIDTableBaseMixin):
"""用户模型"""
email: str = Field(max_length=50, unique=True, index=True)
"""用户邮箱,唯一"""
nickname: str | None = Field(default=None, max_length=50)
"""用于公开展示的名字,可使用真实姓名或昵称"""
password: str = Field(max_length=255)
"""用户密码(加密后)"""
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
"""已用存储空间(字节)"""
two_factor: str | None = Field(default=None, min_length=32, max_length=32)
"""两步验证密钥"""
avatar: str = Field(default="default", max_length=255)
"""头像地址"""
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
"""用户积分"""
group_expires: datetime | None = Field(default=None)
"""当前用户组过期时间"""
# Option 相关字段
# theme: ThemeType = Field(default=ThemeType.SYSTEM)
# """主题类型: light/dark/system"""
language: str = Field(default="zh-CN", max_length=5)
"""语言偏好"""
timezone: int = Field(default=8, ge=-12, le=12)
"""时区UTC 偏移小时数"""
# 外键
group_id: UUID = Field(
foreign_key="group.id",
index=True,
ondelete="RESTRICT"
)
"""所属用户组UUID"""
previous_group_id: UUID | None = Field(
default=None,
foreign_key="group.id",
ondelete="SET NULL"
)
"""之前的用户组UUID用于过期后恢复"""
# 关系
group: "Group" = Relationship(
back_populates="users",
sa_relationship_kwargs={
"foreign_keys": "User.group_id"
}
)
previous_group: "Group" = Relationship(
back_populates="previous_users",
sa_relationship_kwargs={
"foreign_keys": "User.previous_group_id"
}
)
downloads: list["Download"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
objects: list["Object"] = Relationship(
back_populates="owner",
sa_relationship_kwargs={
"cascade": "all, delete-orphan",
"foreign_keys": "[Object.owner_id]"
}
)
"""用户的所有对象(文件和目录)"""
orders: list["Order"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
shares: list["Share"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
storage_packs: list["StoragePack"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
tags: list["Tag"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
tasks: list["Task"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
webdavs: list["WebDAV"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
authns: list["UserAuthn"] = Relationship(
back_populates="user",
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
)
def to_public(self) -> "UserPublic":
"""转换为公开 DTO排除敏感字段。需要预加载 group 关系。"""
data = UserPublic.model_validate(self)
data.group_name = self.group.name
return data
@classmethod
async def get_with_count(
cls: type[T],
session: AsyncSession,
condition: BinaryExpression | ClauseElement | None = None,
*,
filter_params: 'UserFilterParams | None' = None,
join: type[T] | tuple[type[T], ClauseElement] | None = None,
options: list | None = None,
load: RelationshipInfo | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
table_view: TableViewRequest | None = None,
) -> 'ListResponse[T]':
"""
获取分页用户列表及总数,支持用户过滤参数
:param filter_params: UserFilterParams 过滤参数对象(用户组、用户名、昵称、状态等)
:param 其他参数: 继承自 UUIDTableBaseMixin.get_with_count()
"""
# 构建过滤条件
merged_condition = condition
if filter_params is not None:
filter_conditions: list[BinaryExpression] = []
if filter_params.group_id is not None:
filter_conditions.append(cls.group_id == filter_params.group_id)
if filter_params.email_contains is not None:
filter_conditions.append(cls.email.ilike(f"%{filter_params.email_contains}%"))
if filter_params.nickname_contains is not None:
filter_conditions.append(cls.nickname.ilike(f"%{filter_params.nickname_contains}%"))
if filter_params.status is not None:
filter_conditions.append(cls.status == filter_params.status)
if filter_conditions:
combined_filter = and_(*filter_conditions)
if merged_condition is not None:
merged_condition = and_(merged_condition, combined_filter)
else:
merged_condition = combined_filter
return await super().get_with_count(
session,
merged_condition,
join=join,
options=options,
load=load,
order_by=order_by,
filter=filter,
table_view=table_view,
)

61
sqlmodels/user_authn.py Normal file
View File

@@ -0,0 +1,61 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import Column, Text
from sqlmodel import Field, Relationship
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
# ==================== DTO 模型 ====================
class AuthnResponse(SQLModelBase):
"""WebAuthn 响应 DTO"""
id: str
"""凭证ID"""
fingerprint: str
"""凭证指纹"""
# ==================== 数据库模型 ====================
class UserAuthn(SQLModelBase, TableBaseMixin):
"""用户 WebAuthn 凭证模型,与 User 为多对一关系"""
credential_id: str = Field(max_length=255, unique=True, index=True)
"""凭证 IDBase64 编码"""
credential_public_key: str = Field(sa_column=Column(Text))
"""凭证公钥Base64 编码"""
sign_count: int = Field(default=0, ge=0)
"""签名计数器,用于防重放攻击"""
credential_device_type: str = Field(max_length=32)
"""凭证设备类型:'single_device''multi_device'"""
credential_backed_up: bool = Field(default=False)
"""凭证是否已备份"""
transports: str | None = Field(default=None, max_length=255)
"""支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'"""
name: str | None = Field(default=None, max_length=100)
"""用户自定义的凭证名称,便于识别"""
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="authns")

33
sqlmodels/webdav.py Normal file
View File

@@ -0,0 +1,33 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from .base import SQLModelBase
from .mixin import TableBaseMixin
if TYPE_CHECKING:
from .user import User
class WebDAV(SQLModelBase, TableBaseMixin):
"""WebDAV账户模型"""
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
name: str = Field(max_length=255, description="WebDAV账户名")
password: str = Field(max_length=255, description="WebDAV密码")
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
readonly: bool = Field(default=False, description="是否只读")
use_proxy: bool = Field(default=False, description="是否使用代理下载")
# 外键
user_id: UUID = Field(
foreign_key="user.id",
index=True,
ondelete="CASCADE"
)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="webdavs")