feat: add models for physical files, policies, and user management
- Implement PhysicalFile model to manage physical file references and reference counting. - Create Policy model with associated options and group links for storage policies. - Introduce Redeem and Report models for handling redeem codes and reports. - Add Settings model for site configuration and user settings management. - Develop Share model for sharing objects with unique codes and associated metadata. - Implement SourceLink model for managing download links associated with objects. - Create StoragePack model for managing user storage packages. - Add Tag model for user-defined tags with manual and automatic types. - Implement Task model for managing background tasks with status tracking. - Develop User model with comprehensive user management features including authentication. - Introduce UserAuthn model for managing WebAuthn credentials. - Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
1274
sqlmodels/README.md
Normal file
1274
sqlmodels/README.md
Normal file
File diff suppressed because it is too large
Load Diff
105
sqlmodels/__init__.py
Normal file
105
sqlmodels/__init__.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from .user import (
|
||||
BatchDeleteRequest,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
RegisterRequest,
|
||||
AccessTokenBase,
|
||||
RefreshTokenBase,
|
||||
TokenResponse,
|
||||
User,
|
||||
UserBase,
|
||||
UserStorageResponse,
|
||||
UserPublic,
|
||||
UserResponse,
|
||||
UserSettingResponse,
|
||||
WebAuthnInfo,
|
||||
# 管理员DTO
|
||||
UserAdminUpdateRequest,
|
||||
UserCalibrateResponse,
|
||||
UserAdminDetailResponse,
|
||||
)
|
||||
from .user_authn import AuthnResponse, UserAuthn
|
||||
from .color import ThemeResponse
|
||||
|
||||
from .download import (
|
||||
Download,
|
||||
DownloadAria2File,
|
||||
DownloadAria2Info,
|
||||
DownloadAria2InfoBase,
|
||||
DownloadStatus,
|
||||
DownloadType,
|
||||
)
|
||||
from .node import (
|
||||
Aria2Configuration,
|
||||
Aria2ConfigurationBase,
|
||||
Node,
|
||||
NodeStatus,
|
||||
NodeType,
|
||||
)
|
||||
from .group import (
|
||||
Group, GroupBase, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
|
||||
# 管理员DTO
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, GroupListResponse,
|
||||
)
|
||||
from .object import (
|
||||
CreateFileRequest,
|
||||
CreateUploadSessionRequest,
|
||||
DirectoryCreateRequest,
|
||||
DirectoryResponse,
|
||||
FileMetadata,
|
||||
FileMetadataBase,
|
||||
Object,
|
||||
ObjectBase,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
ObjectMoveRequest,
|
||||
ObjectPropertyDetailResponse,
|
||||
ObjectPropertyResponse,
|
||||
ObjectRenameRequest,
|
||||
ObjectResponse,
|
||||
ObjectType,
|
||||
PolicyResponse,
|
||||
UploadChunkResponse,
|
||||
UploadSession,
|
||||
UploadSessionBase,
|
||||
UploadSessionResponse,
|
||||
# 管理员DTO
|
||||
AdminFileResponse,
|
||||
AdminFileListResponse,
|
||||
FileBanRequest,
|
||||
)
|
||||
from .physical_file import PhysicalFile, PhysicalFileBase
|
||||
from .uri import DiskNextURI, FileSystemNamespace
|
||||
from .order import Order, OrderStatus, OrderType
|
||||
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
|
||||
from .redeem import Redeem, RedeemType
|
||||
from .report import Report, ReportReason
|
||||
from .setting import (
|
||||
Setting, SettingsType, SiteConfigResponse,
|
||||
# 管理员DTO
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
from .share import Share, ShareBase, ShareCreateRequest, ShareResponse, AdminShareListItem
|
||||
from .source_link import SourceLink
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag, TagType
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
||||
from .webdav import WebDAV
|
||||
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
from .model_base import (
|
||||
MCPBase,
|
||||
MCPMethod,
|
||||
MCPRequestBase,
|
||||
MCPResponseBase,
|
||||
ResponseBase,
|
||||
# Admin Summary DTO
|
||||
MetricsSummary,
|
||||
LicenseInfo,
|
||||
VersionInfo,
|
||||
AdminSummaryResponse,
|
||||
)
|
||||
|
||||
# mixin 中的通用分页模型
|
||||
from .mixin import ListResponse
|
||||
657
sqlmodels/base/README.md
Normal file
657
sqlmodels/base/README.md
Normal file
@@ -0,0 +1,657 @@
|
||||
# SQLModels Base Module
|
||||
|
||||
This module provides `SQLModelBase`, the root base class for all SQLModel models in this project. It includes a custom metaclass with automatic type injection and Python 3.14 compatibility.
|
||||
|
||||
**Note**: Table base classes (`TableBaseMixin`, `UUIDTableBaseMixin`) and polymorphic utilities have been migrated to the [`sqlmodels.mixin`](../mixin/README.md) module. See the mixin documentation for CRUD operations, polymorphic inheritance patterns, and pagination utilities.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Migration Notice](#migration-notice)
|
||||
- [Python 3.14 Compatibility](#python-314-compatibility)
|
||||
- [Core Component](#core-component)
|
||||
- [SQLModelBase](#sqlmodelbase)
|
||||
- [Metaclass Features](#metaclass-features)
|
||||
- [Automatic sa_type Injection](#automatic-sa_type-injection)
|
||||
- [Table Configuration](#table-configuration)
|
||||
- [Polymorphic Support](#polymorphic-support)
|
||||
- [Custom Types Integration](#custom-types-integration)
|
||||
- [Best Practices](#best-practices)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The `sqlmodels.base` module provides `SQLModelBase`, the foundational base class for all SQLModel models. It features:
|
||||
|
||||
- **Smart metaclass** that automatically extracts and injects SQLAlchemy types from type annotations
|
||||
- **Python 3.14 compatibility** through comprehensive PEP 649/749 support
|
||||
- **Flexible configuration** through class parameters and automatic docstring support
|
||||
- **Type-safe annotations** with automatic validation
|
||||
|
||||
All models in this project should directly or indirectly inherit from `SQLModelBase`.
|
||||
|
||||
---
|
||||
|
||||
## Migration Notice
|
||||
|
||||
As of the recent refactoring, the following components have been moved:
|
||||
|
||||
| Component | Old Location | New Location |
|
||||
|-----------|-------------|--------------|
|
||||
| `TableBase` → `TableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `UUIDTableBase` → `UUIDTableBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `PolymorphicBaseMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `create_subclass_id_mixin()` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `AutoPolymorphicIdentityMixin` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `TableViewRequest` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
| `now()`, `now_date()` | `sqlmodels.base` | `sqlmodels.mixin` |
|
||||
|
||||
**Update your imports**:
|
||||
|
||||
```python
|
||||
# ❌ Old (deprecated)
|
||||
from sqlmodels.base import TableBase, UUIDTableBase
|
||||
|
||||
# ✅ New (correct)
|
||||
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
```
|
||||
|
||||
For detailed documentation on table mixins, CRUD operations, and polymorphic patterns, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Python 3.14 Compatibility
|
||||
|
||||
### Overview
|
||||
|
||||
This module provides full compatibility with **Python 3.14's PEP 649** (Deferred Evaluation of Annotations) and **PEP 749** (making it the default).
|
||||
|
||||
**Key Changes in Python 3.14**:
|
||||
- Annotations are no longer evaluated at class definition time
|
||||
- Type hints are stored as deferred code objects
|
||||
- `__annotate__` function generates annotations on demand
|
||||
- Forward references become `ForwardRef` objects
|
||||
|
||||
### Implementation Strategy
|
||||
|
||||
We use **`typing.get_type_hints()`** as the universal annotations resolver:
|
||||
|
||||
```python
|
||||
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[...]:
|
||||
# Create temporary proxy class
|
||||
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
|
||||
|
||||
# Use get_type_hints with include_extras=True
|
||||
evaluated = get_type_hints(
|
||||
temp_cls,
|
||||
globalns=module_globals,
|
||||
localns=localns,
|
||||
include_extras=True # Preserve Annotated metadata
|
||||
)
|
||||
|
||||
return dict(evaluated), {}, module_globals, localns
|
||||
```
|
||||
|
||||
**Why `get_type_hints()`?**
|
||||
- ✅ Works across Python 3.10-3.14+
|
||||
- ✅ Handles PEP 649 automatically
|
||||
- ✅ Preserves `Annotated` metadata (with `include_extras=True`)
|
||||
- ✅ Resolves forward references
|
||||
- ✅ Recommended by Python documentation
|
||||
|
||||
### SQLModel Compatibility Patch
|
||||
|
||||
**Problem**: SQLModel's `get_sqlalchemy_type()` doesn't recognize custom types with `__sqlmodel_sa_type__` attribute.
|
||||
|
||||
**Solution**: Global monkey-patch that checks for SQLAlchemy type before falling back to original logic:
|
||||
|
||||
```python
|
||||
if sys.version_info >= (3, 14):
|
||||
def _patched_get_sqlalchemy_type(field):
|
||||
annotation = getattr(field, 'annotation', None)
|
||||
if annotation is not None:
|
||||
# Priority 1: Check __sqlmodel_sa_type__ attribute
|
||||
# Handles NumpyVector[dims, dtype] and similar custom types
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
|
||||
# Priority 2: Check Annotated metadata
|
||||
if get_origin(annotation) is Annotated:
|
||||
for metadata in get_args(annotation)[1:]:
|
||||
if hasattr(metadata, '__sqlmodel_sa_type__'):
|
||||
return metadata.__sqlmodel_sa_type__
|
||||
|
||||
# ... handle ForwardRef, ClassVar, etc.
|
||||
|
||||
return _original_get_sqlalchemy_type(field)
|
||||
```
|
||||
|
||||
### Supported Patterns
|
||||
|
||||
#### Pattern 1: Direct Custom Type Usage
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Voice embedding - sa_type automatically extracted"""
|
||||
```
|
||||
|
||||
#### Pattern 2: Annotated Wrapper
|
||||
```python
|
||||
from typing import Annotated
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: EmbeddingVector
|
||||
```
|
||||
|
||||
#### Pattern 3: Array Type
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import Array
|
||||
from sqlmodels.mixin import TableBaseMixin
|
||||
|
||||
class ServerConfig(TableBaseMixin, table=True):
|
||||
protocols: Array[ProtocolEnum]
|
||||
"""Allowed protocols - sa_type from Array handler"""
|
||||
```
|
||||
|
||||
### Migration from Python 3.13
|
||||
|
||||
**No code changes required!** The implementation is transparent:
|
||||
|
||||
- Uses `typing.get_type_hints()` which works in both Python 3.13 and 3.14
|
||||
- Custom types already use `__sqlmodel_sa_type__` attribute
|
||||
- Monkey-patch only activates for Python 3.14+
|
||||
|
||||
---
|
||||
|
||||
## Core Component
|
||||
|
||||
### SQLModelBase
|
||||
|
||||
`SQLModelBase` is the root base class for all SQLModel models. It uses a custom metaclass (`__DeclarativeMeta`) that provides advanced features beyond standard SQLModel capabilities.
|
||||
|
||||
**Key Features**:
|
||||
- Automatic `use_attribute_docstrings` configuration (use docstrings instead of `Field(description=...)`)
|
||||
- Automatic `validate_by_name` configuration
|
||||
- Custom metaclass for sa_type injection and polymorphic setup
|
||||
- Integration with Pydantic v2
|
||||
- Python 3.14 PEP 649 compatibility
|
||||
|
||||
**Usage**:
|
||||
|
||||
```python
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class UserBase(SQLModelBase):
|
||||
name: str
|
||||
"""User's display name"""
|
||||
|
||||
email: str
|
||||
"""User's email address"""
|
||||
```
|
||||
|
||||
**Important Notes**:
|
||||
- Use **docstrings** for field descriptions, not `Field(description=...)`
|
||||
- Do NOT override `model_config` in subclasses (it's already configured in SQLModelBase)
|
||||
- This class should be used for non-table models (DTOs, request/response models)
|
||||
|
||||
**For table models**, use mixins from `sqlmodels.mixin`:
|
||||
- `TableBaseMixin` - Integer primary key with timestamps
|
||||
- `UUIDTableBaseMixin` - UUID primary key with timestamps
|
||||
|
||||
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete table mixin documentation.
|
||||
|
||||
---
|
||||
|
||||
## Metaclass Features
|
||||
|
||||
### Automatic sa_type Injection
|
||||
|
||||
The metaclass automatically extracts SQLAlchemy types from custom type annotations, enabling clean syntax for complex database types.
|
||||
|
||||
**Before** (verbose):
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql.numpy_vector import _NumpyVectorSQLAlchemyType
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: np.ndarray = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
|
||||
)
|
||||
```
|
||||
|
||||
**After** (clean):
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Speaker voice embedding"""
|
||||
```
|
||||
|
||||
**How It Works**:
|
||||
|
||||
The metaclass uses a three-tier detection strategy:
|
||||
|
||||
1. **Direct `__sqlmodel_sa_type__` attribute** (Priority 1)
|
||||
```python
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
```
|
||||
|
||||
2. **Annotated metadata** (Priority 2)
|
||||
```python
|
||||
# For Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
if get_origin(annotation) is typing.Annotated:
|
||||
for item in metadata_items:
|
||||
if hasattr(item, '__sqlmodel_sa_type__'):
|
||||
return item.__sqlmodel_sa_type__
|
||||
```
|
||||
|
||||
3. **Pydantic Core Schema metadata** (Priority 3)
|
||||
```python
|
||||
schema = annotation.__get_pydantic_core_schema__(...)
|
||||
if schema['metadata'].get('sa_type'):
|
||||
return schema['metadata']['sa_type']
|
||||
```
|
||||
|
||||
After extracting `sa_type`, the metaclass:
|
||||
- Creates `Field(sa_type=sa_type)` if no Field is defined
|
||||
- Injects `sa_type` into existing Field if not already set
|
||||
- Respects explicit `Field(sa_type=...)` (no override)
|
||||
|
||||
**Supported Patterns**:
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# Pattern 1: Direct usage (recommended)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
|
||||
# Pattern 2: With Field constraints
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32] = Field(nullable=False)
|
||||
|
||||
# Pattern 3: Annotated wrapper
|
||||
EmbeddingVector = Annotated[np.ndarray, NumpyVector[256, np.float32]]
|
||||
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: EmbeddingVector
|
||||
|
||||
# Pattern 4: Explicit sa_type (override)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32] = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(128, np.float16)
|
||||
)
|
||||
```
|
||||
|
||||
### Table Configuration
|
||||
|
||||
The metaclass provides smart defaults and flexible configuration:
|
||||
|
||||
**Automatic `table=True`**:
|
||||
```python
|
||||
# Classes inheriting from TableBaseMixin automatically get table=True
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(UUIDTableBaseMixin): # table=True is automatic
|
||||
pass
|
||||
```
|
||||
|
||||
**Convenient mapper arguments**:
|
||||
```python
|
||||
# Instead of verbose __mapper_args__
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(
|
||||
UUIDTableBaseMixin,
|
||||
polymorphic_on='_polymorphic_name',
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
pass
|
||||
|
||||
# Equivalent to:
|
||||
class MyModel(UUIDTableBaseMixin):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': '_polymorphic_name',
|
||||
'polymorphic_abstract': True
|
||||
}
|
||||
```
|
||||
|
||||
**Smart merging**:
|
||||
```python
|
||||
# Dictionary and keyword arguments are merged
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(
|
||||
UUIDTableBaseMixin,
|
||||
mapper_args={'version_id_col': 'version'},
|
||||
polymorphic_on='type' # Merged into __mapper_args__
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
### Polymorphic Support
|
||||
|
||||
The metaclass supports SQLAlchemy's joined table inheritance through convenient parameters:
|
||||
|
||||
**Supported parameters**:
|
||||
- `polymorphic_on`: Discriminator column name
|
||||
- `polymorphic_identity`: Identity value for this class
|
||||
- `polymorphic_abstract`: Whether this is an abstract base
|
||||
- `table_args`: SQLAlchemy table arguments
|
||||
- `table_name`: Override table name (becomes `__tablename__`)
|
||||
|
||||
**For complete polymorphic inheritance patterns**, including `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Custom Types Integration
|
||||
|
||||
### Using NumpyVector
|
||||
|
||||
The `NumpyVector` type demonstrates automatic sa_type injection:
|
||||
|
||||
```python
|
||||
from sqlmodels.sqlmodel_types.dialects.postgresql import NumpyVector
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
import numpy as np
|
||||
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Speaker voice embedding - sa_type automatically injected"""
|
||||
```
|
||||
|
||||
**How NumpyVector works**:
|
||||
|
||||
```python
|
||||
# NumpyVector[dims, dtype] returns a class with:
|
||||
class _NumpyVectorType:
|
||||
__sqlmodel_sa_type__ = _NumpyVectorSQLAlchemyType(dimensions, dtype)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
return handler.generate_schema(np.ndarray)
|
||||
```
|
||||
|
||||
This dual approach ensures:
|
||||
1. Metaclass can extract `sa_type` via `__sqlmodel_sa_type__`
|
||||
2. Pydantic can validate as `np.ndarray`
|
||||
|
||||
### Creating Custom SQLAlchemy Types
|
||||
|
||||
To create types that work with automatic injection, provide one of:
|
||||
|
||||
**Option 1: `__sqlmodel_sa_type__` attribute** (preferred):
|
||||
|
||||
```python
|
||||
from sqlalchemy import TypeDecorator, String
|
||||
|
||||
class UpperCaseString(TypeDecorator):
|
||||
impl = String
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return value.upper() if value else value
|
||||
|
||||
class UpperCaseType:
|
||||
__sqlmodel_sa_type__ = UpperCaseString()
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type, handler):
|
||||
return core_schema.str_schema()
|
||||
|
||||
# Usage
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class MyModel(UUIDTableBaseMixin, table=True):
|
||||
code: UpperCaseType # Automatically uses UpperCaseString()
|
||||
```
|
||||
|
||||
**Option 2: Pydantic metadata with sa_type**:
|
||||
|
||||
```python
|
||||
def __get_pydantic_core_schema__(self, source_type, handler):
|
||||
return core_schema.json_or_python_schema(
|
||||
json_schema=core_schema.str_schema(),
|
||||
python_schema=core_schema.str_schema(),
|
||||
metadata={'sa_type': UpperCaseString()}
|
||||
)
|
||||
```
|
||||
|
||||
**Option 3: Using Annotated**:
|
||||
|
||||
```python
|
||||
from typing import Annotated
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
UpperCase = Annotated[str, UpperCaseType()]
|
||||
|
||||
class MyModel(UUIDTableBaseMixin, table=True):
|
||||
code: UpperCase
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Inherit from correct base classes
|
||||
|
||||
```python
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
# ✅ For non-table models (DTOs, requests, responses)
|
||||
class UserBase(SQLModelBase):
|
||||
name: str
|
||||
|
||||
# ✅ For table models with UUID primary key
|
||||
class User(UserBase, UUIDTableBaseMixin, table=True):
|
||||
email: str
|
||||
|
||||
# ✅ For table models with custom primary key
|
||||
class LegacyUser(TableBaseMixin, table=True):
|
||||
id: int = Field(primary_key=True)
|
||||
username: str
|
||||
```
|
||||
|
||||
### 2. Use docstrings for field descriptions
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Recommended
|
||||
class User(UUIDTableBaseMixin, table=True):
|
||||
name: str
|
||||
"""User's display name"""
|
||||
|
||||
# ❌ Avoid
|
||||
class User(UUIDTableBaseMixin, table=True):
|
||||
name: str = Field(description="User's display name")
|
||||
```
|
||||
|
||||
**Why?** SQLModelBase has `use_attribute_docstrings=True`, so docstrings automatically become field descriptions in API docs.
|
||||
|
||||
### 3. Leverage automatic sa_type injection
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Clean and recommended
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: NumpyVector[256, np.float32]
|
||||
"""Voice embedding"""
|
||||
|
||||
# ❌ Verbose and unnecessary
|
||||
class SpeakerInfo(UUIDTableBaseMixin, table=True):
|
||||
embedding: np.ndarray = Field(
|
||||
sa_type=_NumpyVectorSQLAlchemyType(256, np.float32)
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Follow polymorphic naming conventions
|
||||
|
||||
See [`sqlmodels/mixin/README.md`](../mixin/README.md) for complete polymorphic inheritance patterns using `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`.
|
||||
|
||||
### 5. Separate Base, Parent, and Implementation classes
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin, PolymorphicBaseMixin
|
||||
|
||||
# ✅ Recommended structure
|
||||
class ASRBase(SQLModelBase):
|
||||
"""Pure data fields, no table"""
|
||||
name: str
|
||||
base_url: str
|
||||
|
||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
"""Abstract parent with table"""
|
||||
@abstractmethod
|
||||
async def transcribe(self, audio: bytes) -> str:
|
||||
pass
|
||||
|
||||
class WhisperASR(ASR, table=True):
|
||||
"""Concrete implementation"""
|
||||
model_size: str
|
||||
|
||||
async def transcribe(self, audio: bytes) -> str:
|
||||
# Implementation
|
||||
pass
|
||||
```
|
||||
|
||||
**Why?**
|
||||
- Base class can be reused for DTOs
|
||||
- Parent class defines the polymorphic hierarchy
|
||||
- Implementation classes are clean and focused
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: ValueError: X has no matching SQLAlchemy type
|
||||
|
||||
**Solution**: Ensure your custom type provides `__sqlmodel_sa_type__` attribute or proper Pydantic metadata with `sa_type`.
|
||||
|
||||
```python
|
||||
# ✅ Provide __sqlmodel_sa_type__
|
||||
class MyType:
|
||||
__sqlmodel_sa_type__ = MyCustomSQLAlchemyType()
|
||||
```
|
||||
|
||||
### Issue: Can't generate DDL for NullType()
|
||||
|
||||
**Symptoms**: Error during table creation saying a column has `NullType`.
|
||||
|
||||
**Root Cause**: Custom type's `sa_type` not detected by SQLModel.
|
||||
|
||||
**Solution**:
|
||||
1. Ensure your type has `__sqlmodel_sa_type__` class attribute
|
||||
2. Check that the monkey-patch is active (`sys.version_info >= (3, 14)`)
|
||||
3. Verify type annotation is correct (not a string forward reference)
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
# ✅ Correct
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
data: NumpyVector[256, np.float32] # __sqlmodel_sa_type__ detected
|
||||
|
||||
# ❌ Wrong (string annotation)
|
||||
class Model(UUIDTableBaseMixin, table=True):
|
||||
data: 'NumpyVector[256, np.float32]' # sa_type lost
|
||||
```
|
||||
|
||||
### Issue: Polymorphic identity conflicts
|
||||
|
||||
**Symptoms**: SQLAlchemy raises errors about duplicate polymorphic identities.
|
||||
|
||||
**Solution**:
|
||||
1. Check that each concrete class has a unique identity
|
||||
2. Use `AutoPolymorphicIdentityMixin` for automatic naming
|
||||
3. Manually specify identity if needed:
|
||||
```python
|
||||
class MyClass(Parent, polymorphic_identity='unique.name', table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
### Issue: Python 3.14 annotation errors
|
||||
|
||||
**Symptoms**: Errors related to `__annotations__` or type resolution.
|
||||
|
||||
**Solution**: The implementation uses `get_type_hints()` which handles PEP 649 automatically. If issues persist:
|
||||
1. Check for manual `__annotations__` manipulation (avoid it)
|
||||
2. Ensure all types are properly imported
|
||||
3. Avoid `from __future__ import annotations` (can cause SQLModel issues)
|
||||
|
||||
### Issue: Polymorphic and CRUD-related errors
|
||||
|
||||
For issues related to polymorphic inheritance, CRUD operations, or table mixins, see the troubleshooting section in [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Implementation Details
|
||||
|
||||
For developers modifying this module:
|
||||
|
||||
**Core files**:
|
||||
- `sqlmodel_base.py` - Contains `__DeclarativeMeta` and `SQLModelBase`
|
||||
- `../mixin/table.py` - Contains `TableBaseMixin` and `UUIDTableBaseMixin`
|
||||
- `../mixin/polymorphic.py` - Contains `PolymorphicBaseMixin`, `create_subclass_id_mixin()`, and `AutoPolymorphicIdentityMixin`
|
||||
|
||||
**Key functions in this module**:
|
||||
|
||||
1. **`_resolve_annotations(attrs: dict[str, Any])`**
|
||||
- Uses `typing.get_type_hints()` for Python 3.14 compatibility
|
||||
- Returns tuple: `(annotations, annotation_strings, globalns, localns)`
|
||||
- Preserves `Annotated` metadata with `include_extras=True`
|
||||
|
||||
2. **`_extract_sa_type_from_annotation(annotation: Any) -> Any | None`**
|
||||
- Extracts SQLAlchemy type from type annotations
|
||||
- Supports `__sqlmodel_sa_type__`, `Annotated`, and Pydantic core schema
|
||||
- Called by metaclass during class creation
|
||||
|
||||
3. **`_patched_get_sqlalchemy_type(field)`** (Python 3.14+)
|
||||
- Global monkey-patch for SQLModel
|
||||
- Checks `__sqlmodel_sa_type__` before falling back to original logic
|
||||
- Handles custom types like `NumpyVector` and `Array`
|
||||
|
||||
4. **`__DeclarativeMeta.__new__()`**
|
||||
- Processes class definition parameters
|
||||
- Injects `sa_type` into field definitions
|
||||
- Sets up `__mapper_args__`, `__table_args__`, etc.
|
||||
- Handles Python 3.14 annotations via `get_type_hints()`
|
||||
|
||||
**Metaclass processing order**:
|
||||
1. Check if class should be a table (`_has_table_mixin`)
|
||||
2. Collect `__mapper_args__` from kwargs and explicit dict
|
||||
3. Process `table_args`, `table_name`, `abstract` parameters
|
||||
4. Resolve annotations using `get_type_hints()`
|
||||
5. For each field, try to extract `sa_type` and inject into Field
|
||||
6. Call parent metaclass with cleaned kwargs
|
||||
|
||||
For table mixin implementation details, see [`sqlmodels/mixin/README.md`](../mixin/README.md).
|
||||
|
||||
---
|
||||
|
||||
## See Also
|
||||
|
||||
**Project Documentation**:
|
||||
- [SQLModel Mixin Documentation](../mixin/README.md) - Table mixins, CRUD operations, polymorphic patterns
|
||||
- [Project Coding Standards (CLAUDE.md)](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/CLAUDE.md)
|
||||
- [Custom SQLModel Types Guide](/mnt/c/Users/Administrator/PycharmProjects/emoecho-backend-server/sqlmodels/sqlmodel_types/README.md)
|
||||
|
||||
**External References**:
|
||||
- [SQLAlchemy Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
||||
- [Pydantic V2 Documentation](https://docs.pydantic.dev/latest/)
|
||||
- [SQLModel Documentation](https://sqlmodel.tiangolo.com/)
|
||||
- [PEP 649: Deferred Evaluation of Annotations](https://peps.python.org/pep-0649/)
|
||||
- [PEP 749: Implementing PEP 649](https://peps.python.org/pep-0749/)
|
||||
- [Python Annotations Best Practices](https://docs.python.org/3/howto/annotations.html)
|
||||
12
sqlmodels/base/__init__.py
Normal file
12
sqlmodels/base/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
SQLModel 基础模块
|
||||
|
||||
包含:
|
||||
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
|
||||
|
||||
注意:
|
||||
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
|
||||
为了避免循环导入,此处不再重新导出它们
|
||||
请直接从 sqlmodels.mixin 导入这些类
|
||||
"""
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
846
sqlmodels/base/sqlmodel_base.py
Normal file
846
sqlmodels/base/sqlmodel_base.py
Normal file
@@ -0,0 +1,846 @@
|
||||
import sys
|
||||
import typing
|
||||
from typing import Any, Mapping, get_args, get_origin, get_type_hints
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined as Undefined
|
||||
from sqlalchemy.orm import Mapped
|
||||
from sqlmodel import Field, SQLModel
|
||||
from sqlmodel.main import SQLModelMetaclass
|
||||
|
||||
# Python 3.14+ PEP 649支持
|
||||
if sys.version_info >= (3, 14):
|
||||
import annotationlib
|
||||
|
||||
# 全局Monkey-patch: 修复SQLModel在Python 3.14上的兼容性问题
|
||||
import sqlmodel.main
|
||||
_original_get_sqlalchemy_type = sqlmodel.main.get_sqlalchemy_type
|
||||
|
||||
def _patched_get_sqlalchemy_type(field):
|
||||
"""
|
||||
修复SQLModel的get_sqlalchemy_type函数,处理Python 3.14的类型问题。
|
||||
|
||||
问题:
|
||||
1. ForwardRef对象(来自Relationship字段)会导致issubclass错误
|
||||
2. typing._GenericAlias对象(如ClassVar[T])也会导致同样问题
|
||||
3. list/dict等泛型类型在没有Field/Relationship时可能导致错误
|
||||
4. Mapped类型在Python 3.14下可能出现在annotation中
|
||||
5. Annotated类型可能包含sa_type metadata(如Array[T])
|
||||
6. 自定义类型(如NumpyVector)有__sqlmodel_sa_type__属性
|
||||
7. Pydantic已处理的Annotated类型会将metadata存储在field.metadata中
|
||||
|
||||
解决:
|
||||
- 优先检查field.metadata中的__get_pydantic_core_schema__(Pydantic已处理的情况)
|
||||
- 检测__sqlmodel_sa_type__属性(NumpyVector等)
|
||||
- 检测Relationship/ClassVar等返回None
|
||||
- 对于Annotated类型,尝试提取sa_type metadata
|
||||
- 其他情况调用原始函数
|
||||
"""
|
||||
# 优先检查 field.metadata(Pydantic已处理Annotated类型的情况)
|
||||
# 当使用 Array[T] 或 Annotated[T, metadata] 时,Pydantic会将metadata存储在这里
|
||||
metadata = getattr(field, 'metadata', None)
|
||||
if metadata:
|
||||
# metadata是一个列表,包含所有Annotated的元数据项
|
||||
for metadata_item in metadata:
|
||||
# 检查metadata_item是否有__get_pydantic_core_schema__方法
|
||||
if hasattr(metadata_item, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用获取schema
|
||||
schema = metadata_item.__get_pydantic_core_schema__(None, None)
|
||||
# 检查schema的metadata中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError):
|
||||
# Pydantic schema获取可能失败(类型不匹配、缺少属性等)
|
||||
# 这是正常情况,继续检查下一个metadata项
|
||||
pass
|
||||
|
||||
annotation = getattr(field, 'annotation', None)
|
||||
if annotation is not None:
|
||||
# 优先检查 __sqlmodel_sa_type__ 属性
|
||||
# 这处理 NumpyVector[dims, dtype] 等自定义类型
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
|
||||
# 检查自定义类型(如JSON100K)的 __get_pydantic_core_schema__ 方法
|
||||
# 这些类型在schema的metadata中定义sa_type
|
||||
if hasattr(annotation, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用获取schema(传None作为handler,因为我们只需要metadata)
|
||||
schema = annotation.__get_pydantic_core_schema__(annotation, lambda x: None)
|
||||
# 检查schema的metadata中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError):
|
||||
# Schema获取失败,继续其他检查
|
||||
pass
|
||||
|
||||
anno_type_name = type(annotation).__name__
|
||||
|
||||
# ForwardRef: Relationship字段的annotation
|
||||
if anno_type_name == 'ForwardRef':
|
||||
return None
|
||||
|
||||
# AnnotatedAlias: 检查是否有sa_type metadata(如Array[T])
|
||||
if anno_type_name == 'AnnotatedAlias' or anno_type_name == '_AnnotatedAlias':
|
||||
from typing import get_origin, get_args
|
||||
import typing
|
||||
|
||||
# 尝试提取Annotated的metadata
|
||||
if hasattr(typing, 'get_args'):
|
||||
args = get_args(annotation)
|
||||
# args[0]是实际类型,args[1:]是metadata
|
||||
for metadata in args[1:]:
|
||||
# 检查metadata是否有__get_pydantic_core_schema__方法
|
||||
if hasattr(metadata, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用获取schema
|
||||
schema = metadata.__get_pydantic_core_schema__(None, None)
|
||||
# 检查schema中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError):
|
||||
# Annotated metadata的schema获取可能失败
|
||||
# 这是正常的类型检查过程,继续检查下一个metadata
|
||||
pass
|
||||
|
||||
# _GenericAlias或GenericAlias: typing泛型类型
|
||||
if anno_type_name in ('_GenericAlias', 'GenericAlias'):
|
||||
from typing import get_origin
|
||||
import typing
|
||||
origin = get_origin(annotation)
|
||||
|
||||
# ClassVar必须跳过
|
||||
if origin is typing.ClassVar:
|
||||
return None
|
||||
|
||||
# list/dict/tuple/set等内置泛型,如果字段没有明确的Field或Relationship,也跳过
|
||||
# 这通常意味着它是Relationship字段或类变量
|
||||
if origin in (list, dict, tuple, set):
|
||||
# 检查field_info是否存在且有意义
|
||||
# Relationship字段会有特殊的field_info
|
||||
field_info = getattr(field, 'field_info', None)
|
||||
if field_info is None:
|
||||
return None
|
||||
|
||||
# Mapped: SQLAlchemy 2.0的Mapped类型,SQLModel不应该处理
|
||||
# 这可能是从父类继承的字段或Python 3.14注解处理的副作用
|
||||
# 检查类型名称和annotation的字符串表示
|
||||
if 'Mapped' in anno_type_name or 'Mapped' in str(annotation):
|
||||
return None
|
||||
|
||||
# 检查annotation是否是Mapped类或其实例
|
||||
try:
|
||||
from sqlalchemy.orm import Mapped as SAMapped
|
||||
# 检查origin(对于Mapped[T]这种泛型)
|
||||
from typing import get_origin
|
||||
if get_origin(annotation) is SAMapped:
|
||||
return None
|
||||
# 检查类型本身
|
||||
if annotation is SAMapped or isinstance(annotation, type) and issubclass(annotation, SAMapped):
|
||||
return None
|
||||
except (ImportError, TypeError):
|
||||
# 如果SQLAlchemy没有Mapped或检查失败,继续
|
||||
pass
|
||||
|
||||
# 其他情况正常处理
|
||||
return _original_get_sqlalchemy_type(field)
|
||||
|
||||
sqlmodel.main.get_sqlalchemy_type = _patched_get_sqlalchemy_type
|
||||
|
||||
# 第二个Monkey-patch: 修复继承表类中InstrumentedAttribute作为默认值的问题
|
||||
# 在Python 3.14 + SQLModel组合下,当子类(如SMSBaoProvider)继承父类(如VerificationCodeProvider)时,
|
||||
# 父类的关系字段(如server_config)会在子类的model_fields中出现,
|
||||
# 但其default值错误地设置为InstrumentedAttribute对象,而不是None
|
||||
# 这导致实例化时尝试设置InstrumentedAttribute为字段值,触发SQLAlchemy内部错误
|
||||
import sqlmodel._compat as _compat
|
||||
from sqlalchemy.orm import attributes as _sa_attributes
|
||||
|
||||
_original_sqlmodel_table_construct = _compat.sqlmodel_table_construct
|
||||
|
||||
def _patched_sqlmodel_table_construct(self_instance, values):
|
||||
"""
|
||||
修复sqlmodel_table_construct,跳过InstrumentedAttribute默认值
|
||||
|
||||
问题:
|
||||
- 继承自polymorphic基类的表类(如FishAudioTTS, SMSBaoProvider)
|
||||
- 其model_fields中的继承字段default值为InstrumentedAttribute
|
||||
- 原函数尝试将InstrumentedAttribute设置为字段值
|
||||
- SQLAlchemy无法处理,抛出 '_sa_instance_state' 错误
|
||||
|
||||
解决:
|
||||
- 只设置用户提供的值和非InstrumentedAttribute默认值
|
||||
- InstrumentedAttribute默认值跳过(让SQLAlchemy自己处理)
|
||||
"""
|
||||
cls = type(self_instance)
|
||||
|
||||
# 收集要设置的字段值
|
||||
fields_to_set = {}
|
||||
|
||||
for name, field in cls.model_fields.items():
|
||||
# 如果用户提供了值,直接使用
|
||||
if name in values:
|
||||
fields_to_set[name] = values[name]
|
||||
continue
|
||||
|
||||
# 否则检查默认值
|
||||
# 跳过InstrumentedAttribute默认值 - 这些是继承字段的错误默认值
|
||||
if isinstance(field.default, _sa_attributes.InstrumentedAttribute):
|
||||
continue
|
||||
|
||||
# 使用正常的默认值
|
||||
if field.default is not Undefined:
|
||||
fields_to_set[name] = field.default
|
||||
elif field.default_factory is not None:
|
||||
fields_to_set[name] = field.get_default(call_default_factory=True)
|
||||
|
||||
# 设置属性 - 只设置非InstrumentedAttribute值
|
||||
for key, value in fields_to_set.items():
|
||||
if not isinstance(value, _sa_attributes.InstrumentedAttribute):
|
||||
setattr(self_instance, key, value)
|
||||
|
||||
# 设置Pydantic内部属性
|
||||
object.__setattr__(self_instance, '__pydantic_fields_set__', set(values.keys()))
|
||||
if not cls.__pydantic_root_model__:
|
||||
_extra = None
|
||||
if cls.model_config.get('extra') == 'allow':
|
||||
_extra = {}
|
||||
for k, v in values.items():
|
||||
if k not in cls.model_fields:
|
||||
_extra[k] = v
|
||||
object.__setattr__(self_instance, '__pydantic_extra__', _extra)
|
||||
|
||||
if cls.__pydantic_post_init__:
|
||||
self_instance.model_post_init(None)
|
||||
elif not cls.__pydantic_root_model__:
|
||||
object.__setattr__(self_instance, '__pydantic_private__', None)
|
||||
|
||||
# 设置关系
|
||||
for key in self_instance.__sqlmodel_relationships__:
|
||||
value = values.get(key, Undefined)
|
||||
if value is not Undefined:
|
||||
setattr(self_instance, key, value)
|
||||
|
||||
return self_instance
|
||||
|
||||
_compat.sqlmodel_table_construct = _patched_sqlmodel_table_construct
|
||||
else:
|
||||
annotationlib = None
|
||||
|
||||
|
||||
def _extract_sa_type_from_annotation(annotation: Any) -> Any | None:
|
||||
"""
|
||||
从类型注解中提取SQLAlchemy类型。
|
||||
|
||||
支持以下形式:
|
||||
1. NumpyVector[256, np.float32] - 直接使用类型(有__sqlmodel_sa_type__属性)
|
||||
2. Annotated[np.ndarray, NumpyVector[256, np.float32]] - Annotated包装
|
||||
3. 任何有__get_pydantic_core_schema__且返回metadata['sa_type']的类型
|
||||
|
||||
Args:
|
||||
annotation: 字段的类型注解
|
||||
|
||||
Returns:
|
||||
提取到的SQLAlchemy类型,如果没有则返回None
|
||||
"""
|
||||
# 方法1:直接检查类型本身是否有__sqlmodel_sa_type__属性
|
||||
# 这涵盖了 NumpyVector[256, np.float32] 这种直接使用的情况
|
||||
if hasattr(annotation, '__sqlmodel_sa_type__'):
|
||||
return annotation.__sqlmodel_sa_type__
|
||||
|
||||
# 方法2:检查是否为Annotated类型
|
||||
if get_origin(annotation) is typing.Annotated:
|
||||
# 获取元数据项(跳过第一个实际类型参数)
|
||||
args = get_args(annotation)
|
||||
if len(args) >= 2:
|
||||
metadata_items = args[1:] # 第一个是实际类型,后面都是元数据
|
||||
|
||||
# 遍历元数据,查找包含sa_type的项
|
||||
for item in metadata_items:
|
||||
# 检查元数据项是否有__sqlmodel_sa_type__属性
|
||||
if hasattr(item, '__sqlmodel_sa_type__'):
|
||||
return item.__sqlmodel_sa_type__
|
||||
|
||||
# 检查是否有__get_pydantic_core_schema__方法
|
||||
if hasattr(item, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
# 调用该方法获取core schema
|
||||
schema = item.__get_pydantic_core_schema__(
|
||||
annotation,
|
||||
lambda x: None # 虚拟handler
|
||||
)
|
||||
# 检查schema的metadata中是否有sa_type
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError, ValueError):
|
||||
# Pydantic core schema获取可能失败:
|
||||
# - TypeError: 参数不匹配
|
||||
# - AttributeError: metadata不存在
|
||||
# - KeyError: schema结构不符合预期
|
||||
# - ValueError: 无效的类型定义
|
||||
# 这是正常的类型探测过程,继续检查下一个metadata项
|
||||
pass
|
||||
|
||||
# 方法3:检查类型本身是否有__get_pydantic_core_schema__
|
||||
# (虽然NumpyVector已经在方法1处理,但这是通用的fallback)
|
||||
if hasattr(annotation, '__get_pydantic_core_schema__'):
|
||||
try:
|
||||
schema = annotation.__get_pydantic_core_schema__(
|
||||
annotation,
|
||||
lambda x: None # 虚拟handler
|
||||
)
|
||||
if isinstance(schema, dict) and 'metadata' in schema:
|
||||
sa_type = schema['metadata'].get('sa_type')
|
||||
if sa_type is not None:
|
||||
return sa_type
|
||||
except (TypeError, AttributeError, KeyError, ValueError):
|
||||
# 类型本身的schema获取失败
|
||||
# 这是正常的fallback机制,annotation可能不支持此协议
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_annotations(attrs: dict[str, Any]) -> tuple[
|
||||
dict[str, Any],
|
||||
dict[str, str],
|
||||
Mapping[str, Any],
|
||||
Mapping[str, Any],
|
||||
]:
|
||||
"""
|
||||
Resolve annotations from a class namespace with Python 3.14 (PEP 649) support.
|
||||
|
||||
This helper prefers evaluated annotations (Format.VALUE) so that `typing.Annotated`
|
||||
metadata and custom types remain accessible. Forward references that cannot be
|
||||
evaluated are replaced with typing.ForwardRef placeholders to avoid aborting the
|
||||
whole resolution process.
|
||||
"""
|
||||
raw_annotations = attrs.get('__annotations__') or {}
|
||||
try:
|
||||
base_annotations = dict(raw_annotations)
|
||||
except TypeError:
|
||||
base_annotations = {}
|
||||
|
||||
module_name = attrs.get('__module__')
|
||||
module_globals: dict[str, Any]
|
||||
if module_name and module_name in sys.modules:
|
||||
module_globals = dict(sys.modules[module_name].__dict__)
|
||||
else:
|
||||
module_globals = {}
|
||||
|
||||
module_globals.setdefault('__builtins__', __builtins__)
|
||||
localns: dict[str, Any] = dict(attrs)
|
||||
|
||||
try:
|
||||
temp_cls = type('AnnotationProxy', (object,), dict(attrs))
|
||||
temp_cls.__module__ = module_name
|
||||
extras_kw = {'include_extras': True} if sys.version_info >= (3, 10) else {}
|
||||
evaluated = get_type_hints(
|
||||
temp_cls,
|
||||
globalns=module_globals,
|
||||
localns=localns,
|
||||
**extras_kw,
|
||||
)
|
||||
except (NameError, AttributeError, TypeError, RecursionError):
|
||||
# get_type_hints可能失败的原因:
|
||||
# - NameError: 前向引用无法解析(类型尚未定义)
|
||||
# - AttributeError: 模块或类型不存在
|
||||
# - TypeError: 无效的类型注解
|
||||
# - RecursionError: 循环依赖的类型定义
|
||||
# 这是正常情况,回退到原始注解字符串
|
||||
evaluated = base_annotations
|
||||
|
||||
return dict(evaluated), {}, module_globals, localns
|
||||
|
||||
|
||||
def _evaluate_annotation_from_string(
|
||||
field_name: str,
|
||||
annotation_strings: dict[str, str],
|
||||
current_type: Any,
|
||||
globalns: Mapping[str, Any],
|
||||
localns: Mapping[str, Any],
|
||||
) -> Any:
|
||||
"""
|
||||
Attempt to re-evaluate the original annotation string for a field.
|
||||
|
||||
This is used as a fallback when the resolved annotation lost its metadata
|
||||
(e.g., Annotated wrappers) and we need to recover custom sa_type data.
|
||||
"""
|
||||
if not annotation_strings:
|
||||
return current_type
|
||||
|
||||
expr = annotation_strings.get(field_name)
|
||||
if not expr or not isinstance(expr, str):
|
||||
return current_type
|
||||
|
||||
try:
|
||||
return eval(expr, globalns, localns)
|
||||
except (NameError, SyntaxError, AttributeError, TypeError):
|
||||
# eval可能失败的原因:
|
||||
# - NameError: 类型名称在namespace中不存在
|
||||
# - SyntaxError: 注解字符串有语法错误
|
||||
# - AttributeError: 访问不存在的模块属性
|
||||
# - TypeError: 无效的类型表达式
|
||||
# 这是正常的fallback机制,返回当前已解析的类型
|
||||
return current_type
|
||||
|
||||
|
||||
class __DeclarativeMeta(SQLModelMetaclass):
|
||||
"""
|
||||
一个智能的混合模式元类,它提供了灵活性和清晰度:
|
||||
|
||||
1. **自动设置 `table=True`**: 如果一个类继承了 `TableBaseMixin`,则自动应用 `table=True`。
|
||||
2. **明确的字典参数**: 支持 `mapper_args={...}`, `table_args={...}`, `table_name='...'`。
|
||||
3. **便捷的关键字参数**: 支持最常见的 mapper 参数作为顶级关键字(如 `polymorphic_on`)。
|
||||
4. **智能合并**: 当字典和关键字同时提供时,会自动合并,且关键字参数有更高优先级。
|
||||
"""
|
||||
|
||||
_KNOWN_MAPPER_KEYS = {
|
||||
"polymorphic_on",
|
||||
"polymorphic_identity",
|
||||
"polymorphic_abstract",
|
||||
"version_id_col",
|
||||
"concrete",
|
||||
}
|
||||
|
||||
def __new__(cls, name, bases, attrs, **kwargs):
|
||||
# 1. 约定优于配置:自动设置 table=True
|
||||
is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
|
||||
if is_intended_as_table and 'table' not in kwargs:
|
||||
kwargs['table'] = True
|
||||
|
||||
# 2. 智能合并 __mapper_args__
|
||||
collected_mapper_args = {}
|
||||
|
||||
# 首先,处理明确的 mapper_args 字典 (优先级较低)
|
||||
if 'mapper_args' in kwargs:
|
||||
collected_mapper_args.update(kwargs.pop('mapper_args'))
|
||||
|
||||
# 其次,处理便捷的关键字参数 (优先级更高)
|
||||
for key in cls._KNOWN_MAPPER_KEYS:
|
||||
if key in kwargs:
|
||||
# .pop() 获取值并移除,避免传递给父类
|
||||
collected_mapper_args[key] = kwargs.pop(key)
|
||||
|
||||
# 如果收集到了任何 mapper 参数,则更新到类的属性中
|
||||
if collected_mapper_args:
|
||||
existing = attrs.get('__mapper_args__', {}).copy()
|
||||
existing.update(collected_mapper_args)
|
||||
attrs['__mapper_args__'] = existing
|
||||
|
||||
# 3. 处理其他明确的参数
|
||||
if 'table_args' in kwargs:
|
||||
attrs['__table_args__'] = kwargs.pop('table_args')
|
||||
if 'table_name' in kwargs:
|
||||
attrs['__tablename__'] = kwargs.pop('table_name')
|
||||
if 'abstract' in kwargs:
|
||||
attrs['__abstract__'] = kwargs.pop('abstract')
|
||||
|
||||
# 4. 从Annotated元数据中提取sa_type并注入到Field
|
||||
# 重要:必须在调用父类__new__之前处理,因为SQLModel会消费annotations
|
||||
#
|
||||
# Python 3.14兼容性问题:
|
||||
# - SQLModel在Python 3.14上会因为ClassVar[T]类型而崩溃(issubclass错误)
|
||||
# - 我们必须在SQLModel看到annotations之前过滤掉ClassVar字段
|
||||
# - 虽然PEP 749建议不修改__annotations__,但这是修复SQLModel bug的必要措施
|
||||
#
|
||||
# 获取annotations的策略:
|
||||
# - Python 3.14+: 优先从__annotate__获取(如果存在)
|
||||
# - fallback: 从__annotations__读取(如果存在)
|
||||
# - 最终fallback: 空字典
|
||||
annotations, annotation_strings, eval_globals, eval_locals = _resolve_annotations(attrs)
|
||||
|
||||
if annotations:
|
||||
attrs['__annotations__'] = annotations
|
||||
if annotationlib is not None:
|
||||
# 在Python 3.14中禁用descriptor,转为普通dict
|
||||
attrs['__annotate__'] = None
|
||||
|
||||
for field_name, field_type in annotations.items():
|
||||
field_type = _evaluate_annotation_from_string(
|
||||
field_name,
|
||||
annotation_strings,
|
||||
field_type,
|
||||
eval_globals,
|
||||
eval_locals,
|
||||
)
|
||||
|
||||
# 跳过字符串或ForwardRef类型注解,让SQLModel自己处理
|
||||
if isinstance(field_type, str) or isinstance(field_type, typing.ForwardRef):
|
||||
continue
|
||||
|
||||
# 跳过特殊类型的字段
|
||||
origin = get_origin(field_type)
|
||||
|
||||
# 跳过 ClassVar 字段 - 它们不是数据库字段
|
||||
if origin is typing.ClassVar:
|
||||
continue
|
||||
|
||||
# 跳过 Mapped 字段 - SQLAlchemy 2.0+ 的声明式字段,已经有 mapped_column
|
||||
if origin is Mapped:
|
||||
continue
|
||||
|
||||
# 尝试从注解中提取sa_type
|
||||
sa_type = _extract_sa_type_from_annotation(field_type)
|
||||
|
||||
if sa_type is not None:
|
||||
# 检查字段是否已有Field定义
|
||||
field_value = attrs.get(field_name, Undefined)
|
||||
|
||||
if field_value is Undefined:
|
||||
# 没有Field定义,创建一个新的Field并注入sa_type
|
||||
attrs[field_name] = Field(sa_type=sa_type)
|
||||
elif isinstance(field_value, FieldInfo):
|
||||
# 已有Field定义,检查是否已设置sa_type
|
||||
# 注意:只有在未设置时才注入,尊重显式配置
|
||||
# SQLModel使用Undefined作为"未设置"的标记
|
||||
if not hasattr(field_value, 'sa_type') or field_value.sa_type is Undefined:
|
||||
field_value.sa_type = sa_type
|
||||
# 如果field_value是其他类型(如默认值),不处理
|
||||
# SQLModel会在后续处理中将其转换为Field
|
||||
|
||||
# 5. 调用父类的 __new__ 方法,传入被清理过的 kwargs
|
||||
result = super().__new__(cls, name, bases, attrs, **kwargs)
|
||||
|
||||
# 6. 修复:在联表继承场景下,继承父类的 __sqlmodel_relationships__
|
||||
# SQLModel 为每个 table=True 的类创建新的空 __sqlmodel_relationships__
|
||||
# 这导致子类丢失父类的关系定义,触发错误的 Column 创建
|
||||
# 必须在 super().__new__() 之后修复,因为 SQLModel 会覆盖我们预设的值
|
||||
if kwargs.get('table', False):
|
||||
for base in bases:
|
||||
if hasattr(base, '__sqlmodel_relationships__'):
|
||||
for rel_name, rel_info in base.__sqlmodel_relationships__.items():
|
||||
# 只继承子类没有重新定义的关系
|
||||
if rel_name not in result.__sqlmodel_relationships__:
|
||||
result.__sqlmodel_relationships__[rel_name] = rel_info
|
||||
# 同时修复被错误创建的 Column - 恢复为父类的 relationship
|
||||
if hasattr(base, rel_name):
|
||||
base_attr = getattr(base, rel_name)
|
||||
setattr(result, rel_name, base_attr)
|
||||
|
||||
# 7. 检测:禁止子类重定义父类的 Relationship 字段
|
||||
# 子类重定义同名的 Relationship 字段会导致 SQLAlchemy 关系映射混乱,
|
||||
# 应该在类定义时立即报错,而不是在运行时出现难以调试的问题。
|
||||
for base in bases:
|
||||
parent_relationships = getattr(base, '__sqlmodel_relationships__', {})
|
||||
for rel_name in parent_relationships:
|
||||
# 检查当前类是否在 attrs 中重新定义了这个关系字段
|
||||
if rel_name in attrs:
|
||||
raise TypeError(
|
||||
f"类 {name} 不允许重定义父类 {base.__name__} 的 Relationship 字段 '{rel_name}'。"
|
||||
f"如需修改关系配置,请在父类中修改。"
|
||||
)
|
||||
|
||||
# 8. 修复:从 model_fields/__pydantic_fields__ 中移除 Relationship 字段
|
||||
# SQLModel 0.0.27 bug:子类会错误地继承父类的 Relationship 字段到 model_fields
|
||||
# 这导致 Pydantic 尝试为 Relationship 字段生成 schema,因为类型是
|
||||
# Mapped[list['Character']] 这种前向引用,Pydantic 无法解析,
|
||||
# 导致 __pydantic_complete__ = False
|
||||
#
|
||||
# 修复策略:
|
||||
# - 检查类的 __sqlmodel_relationships__ 属性
|
||||
# - 从 model_fields 和 __pydantic_fields__ 中移除这些字段
|
||||
# - Relationship 字段由 SQLAlchemy 管理,不需要 Pydantic 参与
|
||||
relationships = getattr(result, '__sqlmodel_relationships__', {})
|
||||
if relationships:
|
||||
model_fields = getattr(result, 'model_fields', {})
|
||||
pydantic_fields = getattr(result, '__pydantic_fields__', {})
|
||||
|
||||
fields_removed = False
|
||||
for rel_name in relationships:
|
||||
if rel_name in model_fields:
|
||||
del model_fields[rel_name]
|
||||
fields_removed = True
|
||||
if rel_name in pydantic_fields:
|
||||
del pydantic_fields[rel_name]
|
||||
fields_removed = True
|
||||
|
||||
# 如果移除了字段,重新构建 Pydantic 模式
|
||||
# 注意:只在有字段被移除时才 rebuild,避免不必要的开销
|
||||
if fields_removed and hasattr(result, 'model_rebuild'):
|
||||
result.model_rebuild(force=True)
|
||||
|
||||
return result
|
||||
|
||||
def __init__(
|
||||
cls,
|
||||
classname: str,
|
||||
bases: tuple[type, ...],
|
||||
dict_: dict[str, typing.Any],
|
||||
**kw: typing.Any,
|
||||
) -> None:
|
||||
"""
|
||||
重写 SQLModel 的 __init__ 以支持联表继承(Joined Table Inheritance)
|
||||
|
||||
SQLModel 原始行为:
|
||||
- 如果任何基类是表模型,则不调用 DeclarativeMeta.__init__
|
||||
- 这阻止了子类创建自己的表
|
||||
|
||||
修复逻辑:
|
||||
- 检测联表继承场景(子类有自己的 __tablename__ 且有外键指向父表)
|
||||
- 强制调用 DeclarativeMeta.__init__ 来创建子表
|
||||
"""
|
||||
from sqlmodel.main import is_table_model_class, DeclarativeMeta, ModelMetaclass
|
||||
|
||||
# 检查是否是表模型
|
||||
if not is_table_model_class(cls):
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
return
|
||||
|
||||
# 检查是否有基类是表模型
|
||||
base_is_table = any(is_table_model_class(base) for base in bases)
|
||||
|
||||
if not base_is_table:
|
||||
# 没有基类是表模型,走正常的 SQLModel 流程
|
||||
# 处理关系字段
|
||||
cls._setup_relationships()
|
||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
||||
return
|
||||
|
||||
# 关键:检测联表继承场景
|
||||
# 条件:
|
||||
# 1. 当前类的 __tablename__ 与父类不同(表示需要新表)
|
||||
# 2. 当前类有字段带有 foreign_key 指向父表
|
||||
current_tablename = getattr(cls, '__tablename__', None)
|
||||
|
||||
# 查找父表信息
|
||||
parent_table = None
|
||||
parent_tablename = None
|
||||
for base in bases:
|
||||
if is_table_model_class(base) and hasattr(base, '__tablename__'):
|
||||
parent_tablename = base.__tablename__
|
||||
break
|
||||
|
||||
# 检查是否有不同的 tablename
|
||||
has_different_tablename = (
|
||||
current_tablename is not None
|
||||
and parent_tablename is not None
|
||||
and current_tablename != parent_tablename
|
||||
)
|
||||
|
||||
# 检查是否有外键字段指向父表的主键
|
||||
# 注意:由于字段合并,我们需要检查直接基类的 model_fields
|
||||
# 而不是当前类的合并后的 model_fields
|
||||
has_fk_to_parent = False
|
||||
|
||||
def _normalize_tablename(name: str) -> str:
|
||||
"""标准化表名以进行比较(移除下划线,转小写)"""
|
||||
return name.replace('_', '').lower()
|
||||
|
||||
def _fk_matches_parent(fk_str: str, parent_table: str) -> bool:
|
||||
"""检查 FK 字符串是否指向父表"""
|
||||
if not fk_str or not parent_table:
|
||||
return False
|
||||
# FK 格式: "tablename.column" 或 "schema.tablename.column"
|
||||
parts = fk_str.split('.')
|
||||
if len(parts) >= 2:
|
||||
fk_table = parts[-2] # 取倒数第二个作为表名
|
||||
# 标准化比较(处理下划线差异)
|
||||
return _normalize_tablename(fk_table) == _normalize_tablename(parent_table)
|
||||
return False
|
||||
|
||||
if has_different_tablename and parent_tablename:
|
||||
# 首先检查当前类的 model_fields
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
fk = getattr(field_info, 'foreign_key', None)
|
||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
||||
has_fk_to_parent = True
|
||||
break
|
||||
|
||||
# 如果没找到,检查直接基类的 model_fields(解决 mixin 字段被覆盖的问题)
|
||||
if not has_fk_to_parent:
|
||||
for base in bases:
|
||||
if hasattr(base, 'model_fields'):
|
||||
for field_name, field_info in base.model_fields.items():
|
||||
fk = getattr(field_info, 'foreign_key', None)
|
||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
||||
has_fk_to_parent = True
|
||||
break
|
||||
if has_fk_to_parent:
|
||||
break
|
||||
|
||||
is_joined_inheritance = has_different_tablename and has_fk_to_parent
|
||||
|
||||
if is_joined_inheritance:
|
||||
# 联表继承:需要创建子表
|
||||
|
||||
# 修复外键字段:由于字段合并,外键信息可能丢失
|
||||
# 需要从基类的 mixin 中找回外键信息,并重建列
|
||||
from sqlalchemy import Column, ForeignKey, inspect as sa_inspect
|
||||
from sqlalchemy.dialects.postgresql import UUID as SA_UUID
|
||||
from sqlalchemy.exc import NoInspectionAvailable
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
|
||||
# 联表继承:子表只应该有 id(FK 到父表)+ 子类特有的字段
|
||||
# 所有继承自祖先表的列都不应该在子表中重复创建
|
||||
|
||||
# 收集整个继承链中所有祖先表的列名(这些列不应该在子表中重复)
|
||||
# 需要遍历整个 MRO,因为可能是多级继承(如 Tool -> Function -> GetWeatherFunction)
|
||||
ancestor_column_names: set[str] = set()
|
||||
for ancestor in cls.__mro__:
|
||||
if ancestor is cls:
|
||||
continue # 跳过当前类
|
||||
if is_table_model_class(ancestor):
|
||||
try:
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.local_table 是公开属性 (mapper.py:979-998)
|
||||
mapper = sa_inspect(ancestor)
|
||||
for col in mapper.local_table.columns:
|
||||
# 跳过 _polymorphic_name 列(鉴别器,由根父表管理)
|
||||
if col.name.startswith('_polymorphic'):
|
||||
continue
|
||||
ancestor_column_names.add(col.name)
|
||||
except NoInspectionAvailable:
|
||||
continue
|
||||
|
||||
# 找到子类自己定义的字段(不在父类中的)
|
||||
child_own_fields: set[str] = set()
|
||||
for field_name in cls.model_fields:
|
||||
# 检查这个字段是否是在当前类直接定义的(不是继承的)
|
||||
# 通过检查父类是否有这个字段来判断
|
||||
is_inherited = False
|
||||
for base in bases:
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
is_inherited = True
|
||||
break
|
||||
if not is_inherited:
|
||||
child_own_fields.add(field_name)
|
||||
|
||||
# 从子类类属性中移除父表已有的列定义
|
||||
# 这样 SQLAlchemy 就不会在子表中创建这些列
|
||||
fk_field_name = None
|
||||
for base in bases:
|
||||
if hasattr(base, 'model_fields'):
|
||||
for field_name, field_info in base.model_fields.items():
|
||||
fk = getattr(field_info, 'foreign_key', None)
|
||||
pk = getattr(field_info, 'primary_key', False)
|
||||
if fk is not None and isinstance(fk, str) and _fk_matches_parent(fk, parent_tablename):
|
||||
fk_field_name = field_name
|
||||
# 找到了外键字段,重建它
|
||||
# 创建一个新的 Column 对象包含外键约束
|
||||
new_col = Column(
|
||||
field_name,
|
||||
SA_UUID(as_uuid=True),
|
||||
ForeignKey(fk),
|
||||
primary_key=pk if pk else False
|
||||
)
|
||||
setattr(cls, field_name, new_col)
|
||||
break
|
||||
else:
|
||||
continue
|
||||
break
|
||||
|
||||
# 移除继承自祖先表的列属性(除了 FK/PK 和子类自己的字段)
|
||||
# 这防止 SQLAlchemy 在子表中创建重复列
|
||||
# 注意:在 __init__ 阶段,列是 Column 对象,不是 InstrumentedAttribute
|
||||
for col_name in ancestor_column_names:
|
||||
if col_name == fk_field_name:
|
||||
continue # 保留 FK/PK 列(子表的主键,同时是父表的外键)
|
||||
if col_name == 'id':
|
||||
continue # id 会被 FK 字段覆盖
|
||||
if col_name in child_own_fields:
|
||||
continue # 保留子类自己定义的字段
|
||||
|
||||
# 检查类属性是否是 Column 或 InstrumentedAttribute
|
||||
if col_name in cls.__dict__:
|
||||
attr = cls.__dict__[col_name]
|
||||
# Column 对象或 InstrumentedAttribute 都需要删除
|
||||
if isinstance(attr, (Column, InstrumentedAttribute)):
|
||||
try:
|
||||
delattr(cls, col_name)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# 找到子类自己定义的关系(不在父类中的)
|
||||
# 继承的关系会从父类自动获取,只需要设置子类新增的关系
|
||||
child_own_relationships: set[str] = set()
|
||||
for rel_name in cls.__sqlmodel_relationships__:
|
||||
is_inherited = False
|
||||
for base in bases:
|
||||
if hasattr(base, '__sqlmodel_relationships__') and rel_name in base.__sqlmodel_relationships__:
|
||||
is_inherited = True
|
||||
break
|
||||
if not is_inherited:
|
||||
child_own_relationships.add(rel_name)
|
||||
|
||||
# 只为子类自己定义的新关系调用关系设置
|
||||
if child_own_relationships:
|
||||
cls._setup_relationships(only_these=child_own_relationships)
|
||||
|
||||
# 强制调用 DeclarativeMeta.__init__
|
||||
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
|
||||
else:
|
||||
# 非联表继承:单表继承或正常 Pydantic 模型
|
||||
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)
|
||||
|
||||
def _setup_relationships(cls, only_these: set[str] | None = None) -> None:
|
||||
"""
|
||||
设置 SQLAlchemy 关系字段(从 SQLModel 源码复制)
|
||||
|
||||
Args:
|
||||
only_these: 如果提供,只设置这些关系(用于 joined table inheritance 子类)
|
||||
如果为 None,设置所有关系(默认行为)
|
||||
"""
|
||||
from sqlalchemy.orm import relationship, Mapped
|
||||
from sqlalchemy import inspect
|
||||
from sqlmodel.main import get_relationship_to
|
||||
from typing import get_origin
|
||||
|
||||
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
|
||||
# 如果指定了 only_these,只设置这些关系
|
||||
if only_these is not None and rel_name not in only_these:
|
||||
continue
|
||||
if rel_info.sa_relationship:
|
||||
setattr(cls, rel_name, rel_info.sa_relationship)
|
||||
continue
|
||||
|
||||
raw_ann = cls.__annotations__[rel_name]
|
||||
origin: typing.Any = get_origin(raw_ann)
|
||||
if origin is Mapped:
|
||||
ann = raw_ann.__args__[0]
|
||||
else:
|
||||
ann = raw_ann
|
||||
cls.__annotations__[rel_name] = Mapped[ann]
|
||||
|
||||
relationship_to = get_relationship_to(
|
||||
name=rel_name, rel_info=rel_info, annotation=ann
|
||||
)
|
||||
rel_kwargs: dict[str, typing.Any] = {}
|
||||
if rel_info.back_populates:
|
||||
rel_kwargs["back_populates"] = rel_info.back_populates
|
||||
if rel_info.cascade_delete:
|
||||
rel_kwargs["cascade"] = "all, delete-orphan"
|
||||
if rel_info.passive_deletes:
|
||||
rel_kwargs["passive_deletes"] = rel_info.passive_deletes
|
||||
if rel_info.link_model:
|
||||
ins = inspect(rel_info.link_model)
|
||||
local_table = getattr(ins, "local_table")
|
||||
if local_table is None:
|
||||
raise RuntimeError(
|
||||
f"Couldn't find secondary table for {rel_info.link_model}"
|
||||
)
|
||||
rel_kwargs["secondary"] = local_table
|
||||
|
||||
rel_args: list[typing.Any] = []
|
||||
if rel_info.sa_relationship_args:
|
||||
rel_args.extend(rel_info.sa_relationship_args)
|
||||
if rel_info.sa_relationship_kwargs:
|
||||
rel_kwargs.update(rel_info.sa_relationship_kwargs)
|
||||
|
||||
rel_value = relationship(relationship_to, *rel_args, **rel_kwargs)
|
||||
setattr(cls, rel_name, rel_value)
|
||||
|
||||
|
||||
class SQLModelBase(SQLModel, metaclass=__DeclarativeMeta):
|
||||
"""此类必须和TableBase系列类搭配使用"""
|
||||
|
||||
model_config = ConfigDict(use_attribute_docstrings=True, validate_by_name=True)
|
||||
7
sqlmodels/color.py
Normal file
7
sqlmodels/color.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .base import SQLModelBase
|
||||
|
||||
class ThemeResponse(SQLModelBase):
|
||||
"""主题响应 DTO"""
|
||||
|
||||
pass
|
||||
|
||||
33
sqlmodels/database.py
Normal file
33
sqlmodels/database.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from sqlmodel import SQLModel
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from utils.conf import appmeta
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import AsyncGenerator
|
||||
|
||||
ASYNC_DATABASE_URL = appmeta.database_url
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
echo=appmeta.debug,
|
||||
connect_args={
|
||||
"check_same_thread": False
|
||||
} if ASYNC_DATABASE_URL.startswith("sqlite") else {},
|
||||
future=True,
|
||||
# pool_size=POOL_SIZE,
|
||||
# max_overflow=64,
|
||||
)
|
||||
|
||||
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
async with _async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
async def init_db(
|
||||
url: str = ASYNC_DATABASE_URL
|
||||
):
|
||||
"""创建数据库结构"""
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
78
sqlmodels/database_connection.py
Normal file
78
sqlmodels/database_connection.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import AsyncGenerator, ClassVar
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import NullPool, AsyncAdaptedQueuePool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
engine: ClassVar[AsyncEngine | None] = None
|
||||
_async_session_factory: ClassVar[sessionmaker | None] = None
|
||||
|
||||
@classmethod
|
||||
async def get_session(cls) -> AsyncGenerator[AsyncSession]:
|
||||
assert cls._async_session_factory is not None, "数据库引擎未初始化,请先调用 DatabaseManager.init()"
|
||||
async with cls._async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
@classmethod
|
||||
async def init(
|
||||
cls,
|
||||
database_url: str,
|
||||
debug: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化数据库连接引擎。
|
||||
|
||||
:param database_url: 数据库连接URL
|
||||
:param debug: 是否开启调试模式
|
||||
"""
|
||||
# 构建引擎参数
|
||||
engine_kwargs: dict = {
|
||||
'echo': debug,
|
||||
'future': True,
|
||||
}
|
||||
|
||||
if debug:
|
||||
# Debug 模式使用 NullPool(无连接池,每次创建新连接)
|
||||
engine_kwargs['poolclass'] = NullPool
|
||||
else:
|
||||
# 生产模式使用 AsyncAdaptedQueuePool 连接池
|
||||
engine_kwargs.update({
|
||||
'poolclass': AsyncAdaptedQueuePool,
|
||||
'pool_size': 40,
|
||||
'max_overflow': 80,
|
||||
'pool_timeout': 30,
|
||||
'pool_recycle': 1800,
|
||||
'pool_pre_ping': True,
|
||||
})
|
||||
|
||||
# 只在需要时添加 connect_args
|
||||
if database_url.startswith("sqlite"):
|
||||
engine_kwargs['connect_args'] = {'check_same_thread': False}
|
||||
|
||||
cls.engine = create_async_engine(database_url, **engine_kwargs)
|
||||
|
||||
cls._async_session_factory = sessionmaker(cls.engine, class_=AsyncSession)
|
||||
|
||||
# 开发阶段直接 create_all 创建表结构
|
||||
async with cls.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
logger.info("数据库引擎初始化完成")
|
||||
|
||||
@classmethod
|
||||
async def close(cls):
|
||||
"""
|
||||
优雅地关闭数据库连接引擎。
|
||||
仅应在应用结束时调用。
|
||||
"""
|
||||
if cls.engine:
|
||||
logger.info("正在关闭数据库连接引擎...")
|
||||
await cls.engine.dispose()
|
||||
logger.info("数据库连接引擎已成功关闭。")
|
||||
else:
|
||||
logger.info("数据库连接引擎未初始化,无需关闭。")
|
||||
198
sqlmodels/download.py
Normal file
198
sqlmodels/download.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin, TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .task import Task
|
||||
from .node import Node
|
||||
|
||||
class DownloadStatus(StrEnum):
|
||||
"""下载状态枚举"""
|
||||
PREPARING = "preparing"
|
||||
"""准备中"""
|
||||
RUNNING = "running"
|
||||
"""进行中"""
|
||||
COMPLETED = "completed"
|
||||
"""已完成"""
|
||||
ERROR = "error"
|
||||
"""错误"""
|
||||
|
||||
|
||||
class DownloadType(StrEnum):
|
||||
"""下载类型枚举"""
|
||||
# [TODO] 补充具体下载类型
|
||||
pass
|
||||
|
||||
|
||||
# ==================== Aria2 信息模型 ====================
|
||||
|
||||
class DownloadAria2InfoBase(SQLModelBase):
|
||||
"""Aria2下载信息基础模型"""
|
||||
|
||||
info_hash: Annotated[str | None, Field(max_length=40)] = None
|
||||
"""InfoHash(BT种子)"""
|
||||
|
||||
piece_length: int = 0
|
||||
"""分片大小"""
|
||||
|
||||
num_pieces: int = 0
|
||||
"""分片数量"""
|
||||
|
||||
num_seeders: int = 0
|
||||
"""做种人数"""
|
||||
|
||||
connections: int = 0
|
||||
"""连接数"""
|
||||
|
||||
upload_speed: int = 0
|
||||
"""上传速度(bytes/s)"""
|
||||
|
||||
upload_length: int = 0
|
||||
"""已上传大小(字节)"""
|
||||
|
||||
error_code: str | None = None
|
||||
"""错误代码"""
|
||||
|
||||
error_message: str | None = None
|
||||
"""错误信息"""
|
||||
|
||||
|
||||
class DownloadAria2Info(DownloadAria2InfoBase, SQLModelBase, table=True):
|
||||
"""Aria2下载信息模型(与Download一对一关联)"""
|
||||
|
||||
download_id: UUID = Field(
|
||||
foreign_key="download.id",
|
||||
primary_key=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的下载任务UUID"""
|
||||
|
||||
# 反向关系
|
||||
download: "Download" = Relationship(back_populates="aria2_info")
|
||||
"""关联的下载任务"""
|
||||
|
||||
|
||||
class DownloadAria2File(SQLModelBase, TableBaseMixin):
|
||||
"""Aria2下载文件列表(与Download一对多关联)"""
|
||||
|
||||
download_id: UUID = Field(
|
||||
foreign_key="download.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的下载任务UUID"""
|
||||
|
||||
file_index: int = Field(ge=1)
|
||||
"""文件索引(从1开始)"""
|
||||
|
||||
path: str
|
||||
"""文件路径"""
|
||||
|
||||
length: int = 0
|
||||
"""文件大小(字节)"""
|
||||
|
||||
completed_length: int = 0
|
||||
"""已完成大小(字节)"""
|
||||
|
||||
is_selected: bool = True
|
||||
"""是否选中下载"""
|
||||
|
||||
# 反向关系
|
||||
download: "Download" = Relationship(back_populates="aria2_files")
|
||||
"""关联的下载任务"""
|
||||
|
||||
|
||||
# ==================== 主模型 ====================
|
||||
|
||||
class DownloadBase(SQLModelBase):
|
||||
pass
|
||||
|
||||
class Download(DownloadBase, UUIDTableBaseMixin):
|
||||
"""离线下载任务模型"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("node_id", "g_id", name="uq_download_node_gid"),
|
||||
Index("ix_download_user_status", "user_id", "status"),
|
||||
)
|
||||
|
||||
status: DownloadStatus = Field(default=DownloadStatus.PREPARING, index=True)
|
||||
"""下载状态"""
|
||||
|
||||
type: int = Field(default=0)
|
||||
"""任务类型 [TODO] 待定义枚举"""
|
||||
|
||||
source: str
|
||||
"""来源URL或标识"""
|
||||
|
||||
total_size: int = Field(default=0)
|
||||
"""总大小(字节)"""
|
||||
|
||||
downloaded_size: int = Field(default=0)
|
||||
"""已下载大小(字节)"""
|
||||
|
||||
g_id: str | None = Field(default=None, index=True)
|
||||
"""Aria2 GID"""
|
||||
|
||||
speed: int = Field(default=0)
|
||||
"""下载速度(bytes/s)"""
|
||||
|
||||
parent: str | None = Field(default=None, max_length=255)
|
||||
"""父任务标识"""
|
||||
|
||||
error: str | None = Field(default=None)
|
||||
"""错误信息"""
|
||||
|
||||
dst: str
|
||||
"""目标存储路径"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
task_id: int | None = Field(
|
||||
default=None,
|
||||
foreign_key="task.id",
|
||||
index=True,
|
||||
ondelete="SET NULL"
|
||||
)
|
||||
"""关联的任务ID"""
|
||||
|
||||
node_id: int = Field(
|
||||
foreign_key="node.id",
|
||||
index=True,
|
||||
ondelete="RESTRICT"
|
||||
)
|
||||
"""执行下载的节点ID"""
|
||||
|
||||
# 关系
|
||||
aria2_info: DownloadAria2Info | None = Relationship(
|
||||
back_populates="download",
|
||||
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
|
||||
)
|
||||
"""Aria2下载信息"""
|
||||
|
||||
aria2_files: list[DownloadAria2File] = Relationship(
|
||||
back_populates="download",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
"""Aria2文件列表"""
|
||||
|
||||
user: "User" = Relationship(back_populates="downloads")
|
||||
"""所属用户"""
|
||||
|
||||
task: "Task" = Relationship(back_populates="downloads")
|
||||
"""关联的任务"""
|
||||
|
||||
node: "Node" = Relationship(back_populates="downloads")
|
||||
"""执行下载的节点"""
|
||||
|
||||
299
sqlmodels/group.py
Normal file
299
sqlmodels/group.py
Normal file
@@ -0,0 +1,299 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .policy import Policy
|
||||
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class GroupBase(SQLModelBase):
|
||||
"""用户组基础字段,供数据库模型和 DTO 共享"""
|
||||
|
||||
name: str
|
||||
"""用户组名称"""
|
||||
|
||||
|
||||
class GroupOptionsBase(SQLModelBase):
|
||||
"""用户组基础选项字段"""
|
||||
|
||||
share_download: bool = False
|
||||
"""是否允许分享下载"""
|
||||
|
||||
share_free: bool = False
|
||||
"""是否免积分获取需要积分的内容"""
|
||||
|
||||
relocate: bool = False
|
||||
"""是否允许文件重定位"""
|
||||
|
||||
source_batch: int = 0
|
||||
"""批量获取源地址数量"""
|
||||
|
||||
select_node: bool = False
|
||||
"""是否允许选择节点"""
|
||||
|
||||
advance_delete: bool = False
|
||||
"""是否允许高级删除"""
|
||||
|
||||
|
||||
class GroupAllOptionsBase(GroupOptionsBase):
|
||||
"""用户组完整选项字段,供 DTO 和数据库模型共享"""
|
||||
|
||||
archive_download: bool = False
|
||||
"""是否允许打包下载"""
|
||||
|
||||
archive_task: bool = False
|
||||
"""是否允许创建打包任务"""
|
||||
|
||||
webdav_proxy: bool = False
|
||||
"""是否允许WebDAV代理"""
|
||||
|
||||
aria2: bool = False
|
||||
"""是否允许使用aria2"""
|
||||
|
||||
redirected_source: bool = False
|
||||
"""是否使用重定向源"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class GroupCreateRequest(GroupAllOptionsBase):
|
||||
"""创建用户组请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""用户组名称"""
|
||||
|
||||
max_storage: int = Field(default=0, ge=0)
|
||||
"""最大存储空间(字节),0表示不限制"""
|
||||
|
||||
share_enabled: bool = False
|
||||
"""是否允许创建分享"""
|
||||
|
||||
web_dav_enabled: bool = False
|
||||
"""是否允许使用WebDAV"""
|
||||
|
||||
speed_limit: int = Field(default=0, ge=0)
|
||||
"""速度限制 (KB/s), 0为不限制"""
|
||||
|
||||
source_batch: int = Field(default=0, ge=0)
|
||||
"""批量获取源地址数量(覆盖基类以添加 ge 约束)"""
|
||||
|
||||
policy_ids: list[UUID] = []
|
||||
"""关联的存储策略UUID列表"""
|
||||
|
||||
|
||||
class GroupUpdateRequest(SQLModelBase):
|
||||
"""更新用户组请求 DTO(所有字段可选)"""
|
||||
|
||||
name: str | None = Field(default=None, max_length=255)
|
||||
"""用户组名称"""
|
||||
|
||||
max_storage: int | None = Field(default=None, ge=0)
|
||||
"""最大存储空间(字节)"""
|
||||
|
||||
share_enabled: bool | None = None
|
||||
"""是否允许创建分享"""
|
||||
|
||||
web_dav_enabled: bool | None = None
|
||||
"""是否允许使用WebDAV"""
|
||||
|
||||
speed_limit: int | None = Field(default=None, ge=0)
|
||||
"""速度限制 (KB/s)"""
|
||||
|
||||
# 用户组选项
|
||||
share_download: bool | None = None
|
||||
share_free: bool | None = None
|
||||
relocate: bool | None = None
|
||||
source_batch: int | None = None
|
||||
select_node: bool | None = None
|
||||
advance_delete: bool | None = None
|
||||
archive_download: bool | None = None
|
||||
archive_task: bool | None = None
|
||||
webdav_proxy: bool | None = None
|
||||
aria2: bool | None = None
|
||||
redirected_source: bool | None = None
|
||||
|
||||
policy_ids: list[UUID] | None = None
|
||||
"""关联的存储策略UUID列表"""
|
||||
|
||||
|
||||
class GroupCoreBase(SQLModelBase):
|
||||
"""用户组核心字段(从 Group 模型提取)"""
|
||||
|
||||
id: UUID
|
||||
"""用户组UUID"""
|
||||
|
||||
name: str
|
||||
"""用户组名称"""
|
||||
|
||||
max_storage: int = 0
|
||||
"""最大存储空间(字节)"""
|
||||
|
||||
share_enabled: bool = False
|
||||
"""是否允许创建分享"""
|
||||
|
||||
web_dav_enabled: bool = False
|
||||
"""是否允许使用WebDAV"""
|
||||
|
||||
admin: bool = False
|
||||
"""是否为管理员组"""
|
||||
|
||||
speed_limit: int = 0
|
||||
"""速度限制 (KB/s)"""
|
||||
|
||||
|
||||
class GroupDetailResponse(GroupCoreBase, GroupAllOptionsBase):
|
||||
"""用户组详情响应 DTO"""
|
||||
|
||||
user_count: int = 0
|
||||
"""用户数量"""
|
||||
|
||||
policy_ids: list[UUID] = []
|
||||
"""关联的存储策略UUID列表"""
|
||||
|
||||
@classmethod
|
||||
def from_group(
|
||||
cls,
|
||||
group: "Group",
|
||||
user_count: int,
|
||||
policies: list["Policy"],
|
||||
) -> "GroupDetailResponse":
|
||||
"""从 Group ORM 对象构建"""
|
||||
opts = group.options
|
||||
return cls(
|
||||
# GroupCoreBase 字段(从 Group 模型提取)
|
||||
**GroupCoreBase.model_validate(group, from_attributes=True).model_dump(),
|
||||
# GroupAllOptionsBase 字段(从 GroupOptions 提取)
|
||||
**(GroupAllOptionsBase.model_validate(opts, from_attributes=True).model_dump() if opts else {}),
|
||||
# 计算字段
|
||||
user_count=user_count,
|
||||
policy_ids=[p.id for p in policies],
|
||||
)
|
||||
|
||||
|
||||
class GroupListResponse(SQLModelBase):
|
||||
"""用户组列表响应 DTO"""
|
||||
|
||||
groups: list["GroupDetailResponse"] = []
|
||||
"""用户组列表"""
|
||||
|
||||
total: int = 0
|
||||
"""总数"""
|
||||
|
||||
|
||||
class GroupResponse(GroupBase, GroupOptionsBase):
|
||||
"""用户组响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""用户组UUID"""
|
||||
|
||||
allow_share: bool = False
|
||||
"""是否允许分享"""
|
||||
|
||||
allow_remote_download: bool = False
|
||||
"""是否允许离线下载"""
|
||||
|
||||
allow_archive_download: bool = False
|
||||
"""是否允许打包下载"""
|
||||
|
||||
compress: bool = False
|
||||
"""是否允许压缩"""
|
||||
|
||||
webdav: bool = False
|
||||
"""是否允许WebDAV"""
|
||||
|
||||
allow_webdav_proxy: bool = False
|
||||
"""是否允许WebDAV代理"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
# GroupPolicyLink 定义在 policy.py 中以避免循环导入
|
||||
from .policy import GroupPolicyLink
|
||||
|
||||
|
||||
class GroupOptions(GroupAllOptionsBase, TableBaseMixin):
|
||||
"""用户组选项模型"""
|
||||
|
||||
group_id: UUID = Field(
|
||||
foreign_key="group.id",
|
||||
unique=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的用户组UUID"""
|
||||
|
||||
# 反向关系
|
||||
group: "Group" = Relationship(back_populates="options")
|
||||
|
||||
|
||||
class Group(GroupBase, UUIDTableBaseMixin):
|
||||
"""用户组模型"""
|
||||
|
||||
name: str = Field(max_length=255, unique=True)
|
||||
"""用户组名"""
|
||||
|
||||
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""最大存储空间(字节)"""
|
||||
|
||||
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
"""是否允许创建分享"""
|
||||
|
||||
web_dav_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
"""是否允许使用WebDAV"""
|
||||
|
||||
admin: bool = False
|
||||
"""是否为管理员组"""
|
||||
|
||||
speed_limit: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""速度限制 (KB/s), 0为不限制"""
|
||||
|
||||
# 一对一关系:用户组选项
|
||||
options: GroupOptions | None = Relationship(
|
||||
back_populates="group",
|
||||
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"}
|
||||
)
|
||||
|
||||
# 多对多关系:用户组可以关联多个存储策略
|
||||
policies: list["Policy"] = Relationship(
|
||||
back_populates="groups",
|
||||
link_model=GroupPolicyLink,
|
||||
)
|
||||
|
||||
# 关系:一个组可以有多个用户
|
||||
users: list["User"] = Relationship(
|
||||
back_populates="group",
|
||||
sa_relationship_kwargs={"foreign_keys": "User.group_id"}
|
||||
)
|
||||
"""当前属于该组的用户列表"""
|
||||
|
||||
previous_users: list["User"] = Relationship(
|
||||
back_populates="previous_group",
|
||||
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
|
||||
)
|
||||
"""之前属于该组的用户列表(用于过期后恢复)"""
|
||||
|
||||
def to_response(self) -> "GroupResponse":
|
||||
"""转换为响应 DTO"""
|
||||
opts = self.options
|
||||
return GroupResponse(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
allow_share=self.share_enabled,
|
||||
webdav=self.web_dav_enabled,
|
||||
share_download=opts.share_download if opts else False,
|
||||
share_free=opts.share_free if opts else False,
|
||||
relocate=opts.relocate if opts else False,
|
||||
source_batch=opts.source_batch if opts else 0,
|
||||
select_node=opts.select_node if opts else False,
|
||||
advance_delete=opts.advance_delete if opts else False,
|
||||
allow_remote_download=opts.aria2 if opts else False,
|
||||
allow_archive_download=opts.archive_download if opts else False,
|
||||
allow_webdav_proxy=opts.webdav_proxy if opts else False,
|
||||
)
|
||||
326
sqlmodels/migration.py
Normal file
326
sqlmodels/migration.py
Normal file
@@ -0,0 +1,326 @@
|
||||
|
||||
from .setting import Setting, SettingsType
|
||||
from .color import ThemeResponse
|
||||
from utils.conf.appmeta import BackendVersion
|
||||
from utils.password.pwd import Password
|
||||
from loguru import logger as log
|
||||
|
||||
async def migration() -> None:
|
||||
"""
|
||||
数据库迁移函数,初始化默认设置和用户组。
|
||||
|
||||
:return: None
|
||||
"""
|
||||
|
||||
log.info('开始进行数据库初始化...')
|
||||
|
||||
await init_default_settings()
|
||||
await init_default_policy()
|
||||
await init_default_group()
|
||||
await init_default_user()
|
||||
|
||||
log.info('数据库初始化结束')
|
||||
|
||||
default_settings: list[Setting] = [
|
||||
Setting(name="siteURL", value="http://localhost", type=SettingsType.BASIC),
|
||||
Setting(name="siteName", value="DiskNext", type=SettingsType.BASIC),
|
||||
Setting(name="register_enabled", value="1", type=SettingsType.REGISTER),
|
||||
Setting(name="default_group", value="", type=SettingsType.REGISTER),
|
||||
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
|
||||
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
|
||||
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
|
||||
Setting(name="fromName", value="DiskNext", type=SettingsType.MAIL),
|
||||
Setting(name="mail_keepalive", value="30", type=SettingsType.MAIL),
|
||||
Setting(name="fromAdress", value="no-reply@yxqi.cn", type=SettingsType.MAIL),
|
||||
Setting(name="smtpHost", value="smtp.yxqi.cn", type=SettingsType.MAIL),
|
||||
Setting(name="smtpPort", value="25", type=SettingsType.MAIL),
|
||||
Setting(name="replyTo", value="feedback@yxqi.cn", type=SettingsType.MAIL),
|
||||
Setting(name="smtpUser", value="no-reply@yxqi.cn", type=SettingsType.MAIL),
|
||||
Setting(name="smtpPass", value="", type=SettingsType.MAIL),
|
||||
Setting(name="maxEditSize", value="4194304", type=SettingsType.FILE_EDIT),
|
||||
Setting(name="archive_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="download_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="preview_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="doc_preview_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="upload_credential_timeout", value="1800", type=SettingsType.TIMEOUT),
|
||||
Setting(name="upload_session_timeout", value="86400", type=SettingsType.TIMEOUT),
|
||||
Setting(name="slave_api_timeout", value="60", type=SettingsType.TIMEOUT),
|
||||
Setting(name="onedrive_monitor_timeout", value="600", type=SettingsType.TIMEOUT),
|
||||
Setting(name="share_download_session_timeout", value="2073600", type=SettingsType.TIMEOUT),
|
||||
Setting(name="onedrive_callback_check", value="20", type=SettingsType.TIMEOUT),
|
||||
Setting(name="aria2_call_timeout", value="5", type=SettingsType.TIMEOUT),
|
||||
Setting(name="onedrive_chunk_retries", value="1", type=SettingsType.RETRY),
|
||||
Setting(name="onedrive_source_timeout", value="1800", type=SettingsType.TIMEOUT),
|
||||
Setting(name="reset_after_upload_failed", value="0", type=SettingsType.UPLOAD),
|
||||
Setting(name="login_captcha", value="0", type=SettingsType.LOGIN),
|
||||
Setting(name="reg_captcha", value="0", type=SettingsType.LOGIN),
|
||||
Setting(name="email_active", value="0", type=SettingsType.REGISTER),
|
||||
Setting(name="mail_activation_template", value="""<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
|
||||
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>激活您的账户</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
|
||||
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
|
||||
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
|
||||
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
|
||||
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #009688; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">激活{siteTitle}账户</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
|
||||
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong>:</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您注册{siteTitle},请点击下方按钮完成账户激活。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{activationUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #009688; margin: 0; border-color: #009688; border-style: solid; border-width: 10px 20px;">激活账户</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>""", type=SettingsType.MAIL_TEMPLATE),
|
||||
Setting(name="forget_captcha", value="0", type=SettingsType.LOGIN),
|
||||
Setting(name="mail_reset_pwd_template", value="""<!DOCTYPE html PUBLIC"-//W3C//DTD XHTML 1.0 Transitional//EN""http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"><html xmlns="http://www.w3.org/1999/xhtml"style="font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; box-sizing: border-box;
|
||||
font-size: 14px; margin: 0;"><head><meta name="viewport"content="width=device-width"/><meta http-equiv="Content-Type"content="text/html; charset=UTF-8"/><title>重设密码</title><style type="text/css">img{max-width:100%}body{-webkit-font-smoothing:antialiased;-webkit-text-size-adjust:none;width:100%!important;height:100%;line-height:1.6em}body{background-color:#f6f6f6}@media only screen and(max-width:640px){body{padding:0!important}h1{font-weight:800!important;margin:20px 0 5px!important}h2{font-weight:800!important;margin:20px 0 5px!important}h3{font-weight:800!important;margin:20px 0 5px!important}h4{font-weight:800!important;margin:20px 0 5px!important}h1{font-size:22px!important}h2{font-size:18px!important}h3{font-size:16px!important}.container{padding:0!important;width:100%!important}.content{padding:0!important}.content-wrap{padding:10px!important}.invoice{width:100%!important}}</style></head><body itemscope itemtype="http://schema.org/EmailMessage"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing:
|
||||
border-box; font-size: 14px; -webkit-font-smoothing: antialiased; -webkit-text-size-adjust: none; width: 100% !important; height: 100%; line-height: 1.6em; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><table class="body-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; background-color: #f6f6f6; margin: 0;"bgcolor="#f6f6f6"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif;
|
||||
box-sizing: border-box; font-size: 14px; margin: 0;"><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td><td class="container"width="600"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; display: block !important; max-width: 600px !important; clear: both !important; margin: 0 auto;"valign="top"><div class="content"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; max-width: 600px; display: block; margin: 0 auto; padding: 20px;"><table class="main"width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; border-radius: 3px; background-color: #fff; margin: 0; border: 1px
|
||||
solid #e9e9e9;"bgcolor="#fff"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size:
|
||||
14px; margin: 0;"><td class="alert alert-warning"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 16px; vertical-align: top; color: #fff; font-weight: 500; text-align: center; border-radius: 3px 3px 0 0; background-color: #2196F3; margin: 0; padding: 20px;"align="center"bgcolor="#FF9F00"valign="top">重设{siteTitle}密码</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-wrap"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 20px;"valign="top"><table width="100%"cellpadding="0"cellspacing="0"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica
|
||||
Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">亲爱的<strong style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;">{userName}</strong>:</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">请点击下方按钮完成密码重设。如果非你本人操作,请忽略此邮件。</td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top"><a href="{resetUrl}"class="btn-primary"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; color: #FFF; text-decoration: none; line-height: 2em; font-weight: bold; text-align: center; cursor: pointer; display: inline-block; border-radius: 5px; text-transform: capitalize; background-color: #2196F3; margin: 0; border-color: #2196F3; border-style: solid; border-width: 10px 20px;">重设密码</a></td></tr><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0; padding: 0 0 20px;"valign="top">感谢您选择{siteTitle}。</td></tr></table></td></tr></table><div class="footer"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; width: 100%; clear: both; color: #999; margin: 0; padding: 20px;"><table width="100%"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><tr style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; margin: 0;"><td class="aligncenter content-block"style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 12px; vertical-align: top; color: #999; text-align: center; margin: 0; padding: 0 0 20px;"align="center"valign="top">此邮件由系统自动发送,请不要直接回复。</td></tr></table></div></div></td><td style="font-family: 'Helvetica Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; vertical-align: top; margin: 0;"valign="top"></td></tr></table></body></html>""", type=SettingsType.MAIL_TEMPLATE),
|
||||
Setting(name=f"db_version_{BackendVersion}", value="installed", type=SettingsType.VERSION),
|
||||
Setting(name="hot_share_num", value="10", type=SettingsType.SHARE),
|
||||
Setting(name="gravatar_server", value="https://www.gravatar.com/", type=SettingsType.AVATAR),
|
||||
Setting(name="defaultTheme", value="#3f51b5", type=SettingsType.BASIC),
|
||||
Setting(name="themes", value=ThemeResponse().model_dump_json(), type=SettingsType.BASIC),
|
||||
Setting(name="aria2_token", value="", type=SettingsType.ARIA2),
|
||||
Setting(name="aria2_rpcurl", value="", type=SettingsType.ARIA2),
|
||||
Setting(name="aria2_temp_path", value="", type=SettingsType.ARIA2),
|
||||
Setting(name="aria2_options", value="{}", type=SettingsType.ARIA2),
|
||||
Setting(name="aria2_interval", value="60", type=SettingsType.ARIA2),
|
||||
Setting(name="max_worker_num", value="10", type=SettingsType.TASK),
|
||||
Setting(name="max_parallel_transfer", value="4", type=SettingsType.TASK),
|
||||
Setting(name="secret_key", value=Password.generate(256), type=SettingsType.AUTH),
|
||||
Setting(name="temp_path", value="temp", type=SettingsType.PATH),
|
||||
Setting(name="avatar_path", value="avatar", type=SettingsType.PATH),
|
||||
Setting(name="avatar_size", value="2097152", type=SettingsType.AVATAR),
|
||||
Setting(name="avatar_size_l", value="200", type=SettingsType.AVATAR),
|
||||
Setting(name="avatar_size_m", value="130", type=SettingsType.AVATAR),
|
||||
Setting(name="avatar_size_s", value="50", type=SettingsType.AVATAR),
|
||||
Setting(name="home_view_method", value="icon", type=SettingsType.VIEW),
|
||||
Setting(name="share_view_method", value="list", type=SettingsType.VIEW),
|
||||
Setting(name="cron_garbage_collect", value="@hourly", type=SettingsType.CRON),
|
||||
Setting(name="authn_enabled", value="0", type=SettingsType.AUTHN),
|
||||
Setting(name="captcha_height", value="60", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_width", value="240", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_mode", value="3", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ComplexOfNoiseText", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ComplexOfNoiseDot", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowHollowLine", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowNoiseDot", value="1", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowNoiseText", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowSlimeLine", value="1", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowSineLine", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CaptchaLen", value="6", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_type", value="default", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaKey", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaSecret", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CloudflareKey", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CloudflareSecret", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="thumb_width", value="400", type=SettingsType.THUMB),
|
||||
Setting(name="thumb_height", value="300", type=SettingsType.THUMB),
|
||||
Setting(name="pwa_small_icon", value="/static/img/favicon.ico", type=SettingsType.PWA),
|
||||
Setting(name="pwa_medium_icon", value="/static/img/logo192.png", type=SettingsType.PWA),
|
||||
Setting(name="pwa_large_icon", value="/static/img/logo512.png", type=SettingsType.PWA),
|
||||
Setting(name="pwa_display", value="standalone", type=SettingsType.PWA),
|
||||
Setting(name="pwa_theme_color", value="#000000", type=SettingsType.PWA),
|
||||
Setting(name="pwa_background_color", value="#ffffff", type=SettingsType.PWA),
|
||||
]
|
||||
|
||||
async def init_default_settings() -> None:
|
||||
from .setting import Setting
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化设置...')
|
||||
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查是否已经存在版本设置
|
||||
ver = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.VERSION) & (Setting.name == f"db_version_{BackendVersion}")
|
||||
)
|
||||
if ver and ver.value == "installed":
|
||||
return
|
||||
|
||||
# 批量添加默认设置
|
||||
await Setting.add(session, default_settings)
|
||||
|
||||
async def init_default_group() -> None:
|
||||
from .group import Group, GroupOptions
|
||||
from .policy import Policy, GroupPolicyLink
|
||||
from .setting import Setting
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化用户组...')
|
||||
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 获取默认存储策略
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
default_policy_id = default_policy.id if default_policy else None
|
||||
|
||||
# 未找到初始管理组时,则创建
|
||||
if not await Group.get(session, Group.name == "管理员"):
|
||||
admin_group = Group(
|
||||
name="管理员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
web_dav_enabled=True,
|
||||
admin=True,
|
||||
)
|
||||
admin_group_id = admin_group.id # 在 save 前保存 UUID
|
||||
await admin_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=admin_group_id,
|
||||
archive_download=True,
|
||||
archive_task=True,
|
||||
share_download=True,
|
||||
share_free=True,
|
||||
aria2=True,
|
||||
select_node=True,
|
||||
advance_delete=True,
|
||||
).save(session)
|
||||
|
||||
# 关联默认存储策略
|
||||
if default_policy_id:
|
||||
session.add(GroupPolicyLink(
|
||||
group_id=admin_group_id,
|
||||
policy_id=default_policy_id,
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
# 未找到初始注册会员时,则创建
|
||||
if not await Group.get(session, Group.name == "注册会员"):
|
||||
member_group = Group(
|
||||
name="注册会员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
web_dav_enabled=True,
|
||||
)
|
||||
member_group_id = member_group.id # 在 save 前保存 UUID
|
||||
await member_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=member_group_id,
|
||||
share_download=True,
|
||||
).save(session)
|
||||
|
||||
# 关联默认存储策略
|
||||
if default_policy_id:
|
||||
session.add(GroupPolicyLink(
|
||||
group_id=member_group_id,
|
||||
policy_id=default_policy_id,
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
# 更新 default_group 设置为注册会员组的 UUID
|
||||
default_group_setting = await Setting.get(session, Setting.name == "default_group")
|
||||
if default_group_setting:
|
||||
default_group_setting.value = str(member_group_id)
|
||||
await default_group_setting.save(session)
|
||||
|
||||
# 未找到初始游客组时,则创建
|
||||
if not await Group.get(session, Group.name == "游客"):
|
||||
guest_group = Group(
|
||||
name="游客",
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
)
|
||||
guest_group_id = guest_group.id # 在 save 前保存 UUID
|
||||
await guest_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=guest_group_id,
|
||||
share_download=True,
|
||||
).save(session)
|
||||
|
||||
# 游客组不关联存储策略(无法上传)
|
||||
|
||||
async def init_default_user() -> None:
|
||||
from .user import User
|
||||
from .group import Group
|
||||
from .object import Object, ObjectType
|
||||
from .policy import Policy
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化管理员用户...')
|
||||
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查管理员用户是否存在(通过 Setting 中的 default_admin_id 判断)
|
||||
admin_id_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||
)
|
||||
admin_user = None
|
||||
if admin_id_setting and admin_id_setting.value:
|
||||
from uuid import UUID
|
||||
admin_user = await User.get(session, User.id == UUID(admin_id_setting.value))
|
||||
|
||||
if not admin_user:
|
||||
# 获取管理员组
|
||||
admin_group = await Group.get(session, Group.name == "管理员")
|
||||
if not admin_group:
|
||||
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
|
||||
|
||||
# 获取默认存储策略
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
raise RuntimeError("默认存储策略不存在,无法创建管理员用户")
|
||||
default_policy_id = default_policy.id # 在后续 save 前保存 UUID
|
||||
|
||||
# 生成管理员密码
|
||||
admin_password = Password.generate(8)
|
||||
hashed_admin_password = Password.hash(admin_password)
|
||||
|
||||
admin_user = User(
|
||||
email="admin@disknext.local",
|
||||
nickname="admin",
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
)
|
||||
admin_user_id = admin_user.id # 在 save 前保存 UUID
|
||||
await admin_user.save(session)
|
||||
|
||||
# 记录默认管理员 ID 到 Setting
|
||||
await Setting(
|
||||
name="default_admin_id",
|
||||
value=str(admin_user_id),
|
||||
type=SettingsType.AUTH,
|
||||
).save(session)
|
||||
|
||||
# 为管理员创建根目录
|
||||
await Object(
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=admin_user_id,
|
||||
parent_id=None,
|
||||
policy_id=default_policy_id,
|
||||
).save(session)
|
||||
|
||||
log.warning('请注意,账号密码仅显示一次,请妥善保管')
|
||||
log.info(f'初始管理员邮箱: admin@disknext.local')
|
||||
log.info(f'初始管理员密码: {admin_password}')
|
||||
|
||||
|
||||
async def init_default_policy() -> None:
|
||||
from .policy import Policy, PolicyType
|
||||
from .database_connection import DatabaseManager
|
||||
from service.storage import LocalStorageService
|
||||
|
||||
log.info('初始化默认存储策略...')
|
||||
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查默认存储策略是否存在
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
|
||||
if not default_policy:
|
||||
local_policy = Policy(
|
||||
name="本地存储",
|
||||
type=PolicyType.LOCAL,
|
||||
server="./data",
|
||||
is_private=True,
|
||||
max_size=0,
|
||||
auto_rename=True,
|
||||
dir_name_rule="{date}/{randomkey16}",
|
||||
file_name_rule="{randomkey16}_{originname}",
|
||||
)
|
||||
|
||||
local_policy = await local_policy.save(session)
|
||||
|
||||
# 创建物理存储目录
|
||||
storage_service = LocalStorageService(local_policy)
|
||||
await storage_service.ensure_base_directory()
|
||||
|
||||
log.info('已创建默认本地存储策略,存储目录:./data')
|
||||
543
sqlmodels/mixin/README.md
Normal file
543
sqlmodels/mixin/README.md
Normal file
@@ -0,0 +1,543 @@
|
||||
# SQLModel Mixin Module
|
||||
|
||||
This module provides composable Mixin classes for SQLModel entities, enabling reusable functionality such as CRUD operations, polymorphic inheritance, JWT authentication, and standardized response DTOs.
|
||||
|
||||
## Module Overview
|
||||
|
||||
The `sqlmodels.mixin` module contains various Mixin classes that follow the "Composition over Inheritance" design philosophy. These mixins provide:
|
||||
|
||||
- **CRUD Operations**: Async database operations (add, save, update, delete, get, count)
|
||||
- **Polymorphic Inheritance**: Tools for joined table inheritance patterns
|
||||
- **JWT Authentication**: Token generation and validation
|
||||
- **Pagination & Sorting**: Standardized table view parameters
|
||||
- **Response DTOs**: Consistent id/timestamp fields for API responses
|
||||
|
||||
## Module Structure
|
||||
|
||||
```
|
||||
sqlmodels/mixin/
|
||||
├── __init__.py # Module exports
|
||||
├── polymorphic.py # PolymorphicBaseMixin, create_subclass_id_mixin, AutoPolymorphicIdentityMixin
|
||||
├── table.py # TableBaseMixin, UUIDTableBaseMixin, TableViewRequest
|
||||
├── info_response.py # Response DTO Mixins (IntIdInfoMixin, UUIDIdInfoMixin, etc.)
|
||||
└── jwt/ # JWT authentication
|
||||
├── __init__.py
|
||||
├── key.py # JWTKey database model
|
||||
├── payload.py # JWTPayloadBase
|
||||
├── manager.py # JWTManager singleton
|
||||
├── auth.py # JWTAuthMixin
|
||||
├── exceptions.py # JWT-related exceptions
|
||||
└── responses.py # TokenResponse DTO
|
||||
```
|
||||
|
||||
## Dependency Hierarchy
|
||||
|
||||
The module has a strict import order to avoid circular dependencies:
|
||||
|
||||
1. **polymorphic.py** - Only depends on `SQLModelBase`
|
||||
2. **table.py** - Depends on `polymorphic.py`
|
||||
3. **jwt/** - May depend on both `polymorphic.py` and `table.py`
|
||||
4. **info_response.py** - Only depends on `SQLModelBase`
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. TableBaseMixin
|
||||
|
||||
Base mixin for database table models with integer primary keys.
|
||||
|
||||
**Features:**
|
||||
- Provides CRUD methods: `add()`, `save()`, `update()`, `delete()`, `get()`, `count()`, `get_exist_one()`
|
||||
- Automatic timestamp management (`created_at`, `updated_at`)
|
||||
- Async relationship loading support (via `AsyncAttrs`)
|
||||
- Pagination and sorting via `TableViewRequest`
|
||||
- Polymorphic subclass loading support
|
||||
|
||||
**Fields:**
|
||||
- `id: int | None` - Integer primary key (auto-increment)
|
||||
- `created_at: datetime` - Record creation timestamp
|
||||
- `updated_at: datetime` - Record update timestamp (auto-updated)
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import TableBaseMixin
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class User(SQLModelBase, TableBaseMixin, table=True):
|
||||
name: str
|
||||
email: str
|
||||
"""User email"""
|
||||
|
||||
# CRUD operations
|
||||
async def example(session: AsyncSession):
|
||||
# Add
|
||||
user = User(name="Alice", email="alice@example.com")
|
||||
user = await user.save(session)
|
||||
|
||||
# Get
|
||||
user = await User.get(session, User.id == 1)
|
||||
|
||||
# Update
|
||||
update_data = UserUpdateRequest(name="Alice Smith")
|
||||
user = await user.update(session, update_data)
|
||||
|
||||
# Delete
|
||||
await User.delete(session, user)
|
||||
|
||||
# Count
|
||||
count = await User.count(session, User.is_active == True)
|
||||
```
|
||||
|
||||
**Important Notes:**
|
||||
- `save()` and `update()` return refreshed instances - **always use the return value**:
|
||||
```python
|
||||
# ✅ Correct
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device
|
||||
```
|
||||
|
||||
### 2. UUIDTableBaseMixin
|
||||
|
||||
Extends `TableBaseMixin` with UUID primary keys instead of integers.
|
||||
|
||||
**Differences from TableBaseMixin:**
|
||||
- `id: UUID` - UUID primary key (auto-generated via `uuid.uuid4()`)
|
||||
- `get_exist_one()` accepts `UUID` instead of `int`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
|
||||
class Character(SQLModelBase, UUIDTableBaseMixin, table=True):
|
||||
name: str
|
||||
description: str | None = None
|
||||
"""Character description"""
|
||||
```
|
||||
|
||||
**Recommendation:** Use `UUIDTableBaseMixin` for most new models, as UUIDs provide better scalability and avoid ID collisions.
|
||||
|
||||
### 3. TableViewRequest
|
||||
|
||||
Standardized pagination and sorting parameters for LIST endpoints.
|
||||
|
||||
**Fields:**
|
||||
- `offset: int | None` - Skip first N records (default: 0)
|
||||
- `limit: int | None` - Return max N records (default: 50, max: 100)
|
||||
- `desc: bool | None` - Sort descending (default: True)
|
||||
- `order: Literal["created_at", "updated_at"] | None` - Sort field (default: "created_at")
|
||||
|
||||
**Usage with TableBaseMixin.get():**
|
||||
```python
|
||||
from dependencies import TableViewRequestDep
|
||||
|
||||
@router.get("/list")
|
||||
async def list_characters(
|
||||
session: SessionDep,
|
||||
table_view: TableViewRequestDep
|
||||
) -> list[Character]:
|
||||
"""List characters with pagination and sorting"""
|
||||
return await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view # Automatically handles pagination and sorting
|
||||
)
|
||||
```
|
||||
|
||||
**Manual usage:**
|
||||
```python
|
||||
table_view = TableViewRequest(offset=0, limit=20, desc=True, order="created_at")
|
||||
characters = await Character.get(session, fetch_mode="all", table_view=table_view)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
The traditional `offset`, `limit`, `order_by` parameters still work, but `table_view` is recommended for new code.
|
||||
|
||||
### 4. PolymorphicBaseMixin
|
||||
|
||||
Base mixin for joined table inheritance, automatically configuring polymorphic settings.
|
||||
|
||||
**Automatic Configuration:**
|
||||
- Defines `_polymorphic_name: str` field (indexed)
|
||||
- Sets `polymorphic_on='_polymorphic_name'`
|
||||
- Detects abstract classes (via ABC and abstract methods) and sets `polymorphic_abstract=True`
|
||||
|
||||
**Methods:**
|
||||
- `get_concrete_subclasses()` - Get all non-abstract subclasses (for `selectin_polymorphic`)
|
||||
- `get_polymorphic_discriminator()` - Get the polymorphic discriminator field name
|
||||
- `get_identity_to_class_map()` - Map `polymorphic_identity` to subclass types
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.mixin import PolymorphicBaseMixin, UUIDTableBaseMixin
|
||||
|
||||
class Tool(PolymorphicBaseMixin, UUIDTableBaseMixin, ABC):
|
||||
"""Abstract base class for all tools"""
|
||||
name: str
|
||||
description: str
|
||||
"""Tool description"""
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, params: dict) -> dict:
|
||||
"""Execute the tool"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Why Single Underscore Prefix?**
|
||||
- SQLAlchemy maps single-underscore fields to database columns
|
||||
- Pydantic treats them as private (excluded from serialization)
|
||||
- Double-underscore fields would be excluded by SQLAlchemy (not mapped to database)
|
||||
|
||||
### 5. create_subclass_id_mixin()
|
||||
|
||||
Factory function to create ID mixins for subclasses in joined table inheritance.
|
||||
|
||||
**Purpose:** In joined table inheritance, subclasses need a foreign key pointing to the parent table's primary key. This function generates a mixin class providing that foreign key field.
|
||||
|
||||
**Signature:**
|
||||
```python
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type[SQLModelBase]:
|
||||
"""
|
||||
Args:
|
||||
parent_table_name: Parent table name (e.g., 'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
A mixin class containing id field (foreign key + primary key)
|
||||
"""
|
||||
```
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import create_subclass_id_mixin
|
||||
|
||||
# Create mixin for ASR subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
class FunASR(ASRSubclassIdMixin, ASR, AutoPolymorphicIdentityMixin, table=True):
|
||||
"""FunASR implementation"""
|
||||
pass
|
||||
```
|
||||
|
||||
**Important:** The ID mixin **must be first in the inheritance list** to ensure MRO (Method Resolution Order) correctly overrides the parent's `id` field.
|
||||
|
||||
### 6. AutoPolymorphicIdentityMixin
|
||||
|
||||
Automatically generates `polymorphic_identity` based on class name.
|
||||
|
||||
**Naming Convention:**
|
||||
- Format: `{parent_identity}.{classname_lowercase}`
|
||||
- If no parent identity exists, uses `{classname_lowercase}`
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import AutoPolymorphicIdentityMixin
|
||||
|
||||
class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
"""Base class for function-type tools"""
|
||||
pass
|
||||
# polymorphic_identity = 'function'
|
||||
|
||||
class GetWeatherFunction(Function, table=True):
|
||||
"""Weather query function"""
|
||||
pass
|
||||
# polymorphic_identity = 'function.getweatherfunction'
|
||||
```
|
||||
|
||||
**Manual Override:**
|
||||
```python
|
||||
class CustomTool(
|
||||
Tool,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_identity='custom_name', # Override auto-generated name
|
||||
table=True
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
### 7. JWTAuthMixin
|
||||
|
||||
Provides JWT token generation and validation for entity classes (User, Client).
|
||||
|
||||
**Methods:**
|
||||
- `async issue_jwt(session: AsyncSession) -> str` - Generate JWT token for current instance
|
||||
- `@classmethod async from_jwt(session: AsyncSession, token: str) -> Self` - Validate token and retrieve entity
|
||||
|
||||
**Requirements:**
|
||||
Subclasses must define:
|
||||
- `JWTPayload` - Payload model (inherits from `JWTPayloadBase`)
|
||||
- `jwt_key_purpose` - ClassVar specifying the JWT key purpose enum value
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import JWTAuthMixin, UUIDTableBaseMixin
|
||||
|
||||
class User(SQLModelBase, UUIDTableBaseMixin, JWTAuthMixin, table=True):
|
||||
JWTPayload = UserJWTPayload # Define payload model
|
||||
jwt_key_purpose: ClassVar[JWTKeyPurposeEnum] = JWTKeyPurposeEnum.user
|
||||
|
||||
email: str
|
||||
is_admin: bool = False
|
||||
is_active: bool = True
|
||||
"""User active status"""
|
||||
|
||||
# Generate token
|
||||
async def login(session: AsyncSession, user: User) -> str:
|
||||
token = await user.issue_jwt(session)
|
||||
return token
|
||||
|
||||
# Validate token
|
||||
async def verify(session: AsyncSession, token: str) -> User:
|
||||
user = await User.from_jwt(session, token)
|
||||
return user
|
||||
```
|
||||
|
||||
### 8. Response DTO Mixins
|
||||
|
||||
Mixins for standardized InfoResponse DTOs, defining id and timestamp fields.
|
||||
|
||||
**Available Mixins:**
|
||||
- `IntIdInfoMixin` - Integer ID field
|
||||
- `UUIDIdInfoMixin` - UUID ID field
|
||||
- `DatetimeInfoMixin` - `created_at` and `updated_at` fields
|
||||
- `IntIdDatetimeInfoMixin` - Integer ID + timestamps
|
||||
- `UUIDIdDatetimeInfoMixin` - UUID ID + timestamps
|
||||
|
||||
**Design Note:** These fields are non-nullable in DTOs because database records always have these values when returned.
|
||||
|
||||
**Usage:**
|
||||
```python
|
||||
from sqlmodels.mixin import UUIDIdDatetimeInfoMixin
|
||||
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
"""Character response DTO with id and timestamps"""
|
||||
pass # Inherits id, created_at, updated_at from mixin
|
||||
```
|
||||
|
||||
## Complete Joined Table Inheritance Example
|
||||
|
||||
Here's a complete example demonstrating polymorphic inheritance:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import (
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
)
|
||||
|
||||
# 1. Define Base class (fields only, no table)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
"""Configuration name"""
|
||||
|
||||
base_url: str
|
||||
"""Service URL"""
|
||||
|
||||
# 2. Define abstract parent class (with table)
|
||||
class ASR(ASRBase, UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
"""Abstract base class for ASR configurations"""
|
||||
# PolymorphicBaseMixin automatically provides:
|
||||
# - _polymorphic_name field
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True (when ABC with abstract methods)
|
||||
|
||||
@abstractmethod
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
"""Transcribe audio to text"""
|
||||
pass
|
||||
|
||||
# 3. Create ID Mixin for second-level subclasses
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. Create second-level abstract class (if needed)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
"""FunASR abstract base (may have multiple implementations)"""
|
||||
pass
|
||||
# polymorphic_identity = 'funasr'
|
||||
|
||||
# 5. Create concrete implementation classes
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
"""FunASR local deployment"""
|
||||
# polymorphic_identity = 'funasr.funasrlocal'
|
||||
|
||||
async def transcribe(self, pcm_data: bytes) -> str:
|
||||
# Implementation...
|
||||
return "transcribed text"
|
||||
|
||||
# 6. Get all concrete subclasses (for selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# Returns: [FunASRLocal, ...]
|
||||
```
|
||||
|
||||
## Import Guidelines
|
||||
|
||||
**Standard Import:**
|
||||
```python
|
||||
from sqlmodels.mixin import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
TableViewRequest,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
JWTAuthMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
```
|
||||
|
||||
**Backward Compatibility:**
|
||||
Some exports are also available from `sqlmodels.base` for backward compatibility:
|
||||
```python
|
||||
# Legacy import path (still works)
|
||||
from sqlmodels.base import UUIDTableBase, TableViewRequest
|
||||
|
||||
# Recommended new import path
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin, TableViewRequest
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Mixin Order Matters
|
||||
|
||||
**Correct Order:**
|
||||
```python
|
||||
# ✅ ID Mixin first, then parent, then AutoPolymorphicIdentityMixin
|
||||
class SubTool(ToolSubclassIdMixin, Tool, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
**Wrong Order:**
|
||||
```python
|
||||
# ❌ ID Mixin not first - won't override parent's id field
|
||||
class SubTool(Tool, ToolSubclassIdMixin, AutoPolymorphicIdentityMixin, table=True):
|
||||
pass
|
||||
```
|
||||
|
||||
### 2. Always Use Return Values from save() and update()
|
||||
|
||||
```python
|
||||
# ✅ Correct - use returned instance
|
||||
device = await device.save(session)
|
||||
return device
|
||||
|
||||
# ❌ Wrong - device is expired after commit
|
||||
await device.save(session)
|
||||
return device # AttributeError when accessing fields
|
||||
```
|
||||
|
||||
### 3. Prefer table_view Over Manual Pagination
|
||||
|
||||
```python
|
||||
# ✅ Recommended - consistent across all endpoints
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
table_view=table_view
|
||||
)
|
||||
|
||||
# ⚠️ Works but not recommended - manual parameter management
|
||||
characters = await Character.get(
|
||||
session,
|
||||
fetch_mode="all",
|
||||
offset=0,
|
||||
limit=20,
|
||||
order_by=[desc(Character.created_at)]
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Polymorphic Loading for Many Subclasses
|
||||
|
||||
```python
|
||||
# When loading relationships with > 10 polymorphic subclasses, use load_polymorphic='all'
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic='all' # Two-phase query - only loads actual related subclasses
|
||||
)
|
||||
|
||||
# For fewer subclasses, specify the list explicitly
|
||||
tool_set = await ToolSet.get(
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic=[GetWeatherFunction, CodeInterpreterFunction]
|
||||
)
|
||||
```
|
||||
|
||||
### 5. Response DTOs Should Inherit Base Classes
|
||||
|
||||
```python
|
||||
# ✅ Correct - inherits from CharacterBase
|
||||
class CharacterInfoResponse(CharacterBase, UUIDIdDatetimeInfoMixin):
|
||||
pass
|
||||
|
||||
# ❌ Wrong - doesn't inherit from CharacterBase
|
||||
class CharacterInfoResponse(SQLModelBase, UUIDIdDatetimeInfoMixin):
|
||||
name: str # Duplicated field definition
|
||||
description: str | None = None
|
||||
```
|
||||
|
||||
**Reason:** Inheriting from Base classes ensures:
|
||||
- Type checking via `isinstance(obj, XxxBase)`
|
||||
- Consistency across related DTOs
|
||||
- Future field additions automatically propagate
|
||||
|
||||
### 6. Use Specific Types, Not Containers
|
||||
|
||||
```python
|
||||
# ✅ Correct - specific DTO for config updates
|
||||
class GetWeatherFunctionUpdateRequest(GetWeatherFunctionConfigBase):
|
||||
weather_api_key: str | None = None
|
||||
default_location: str | None = None
|
||||
"""Default location"""
|
||||
|
||||
# ❌ Wrong - lose type safety
|
||||
class ToolUpdateRequest(SQLModelBase):
|
||||
config: dict[str, Any] # No field validation
|
||||
```
|
||||
|
||||
## Type Variables
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import T, M
|
||||
|
||||
T = TypeVar("T", bound="TableBaseMixin") # For CRUD methods
|
||||
M = TypeVar("M", bound="SQLModel") # For update() method
|
||||
```
|
||||
|
||||
## Utility Functions
|
||||
|
||||
```python
|
||||
from sqlmodels.mixin import now, now_date
|
||||
|
||||
# Lambda functions for default factories
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
```
|
||||
|
||||
## Related Modules
|
||||
|
||||
- **sqlmodels.base** - Base classes (`SQLModelBase`, backward-compatible exports)
|
||||
- **dependencies** - FastAPI dependencies (`SessionDep`, `TableViewRequestDep`)
|
||||
- **sqlmodels.user** - User model with JWT authentication
|
||||
- **sqlmodels.client** - Client model with JWT authentication
|
||||
- **sqlmodels.character.llm.openai_compatibles.tools** - Polymorphic tool hierarchy
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- `POLYMORPHIC_NAME_DESIGN.md` - Design rationale for `_polymorphic_name` field
|
||||
- `CLAUDE.md` - Project coding standards and design philosophy
|
||||
- SQLAlchemy Documentation - [Joined Table Inheritance](https://docs.sqlalchemy.org/en/20/orm/inheritance.html#joined-table-inheritance)
|
||||
62
sqlmodels/mixin/__init__.py
Normal file
62
sqlmodels/mixin/__init__.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
SQLModel Mixin模块
|
||||
|
||||
提供各种Mixin类供SQLModel实体使用。
|
||||
|
||||
包含:
|
||||
- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
|
||||
- optimistic_lock: 乐观锁(OptimisticLockMixin, OptimisticLockError)
|
||||
- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
|
||||
- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
|
||||
- relation_preload: 关系预加载(RelationPreloadMixin, requires_relations)
|
||||
- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
|
||||
- info_response: InfoResponse DTO的id/时间戳Mixin
|
||||
|
||||
导入顺序很重要,避免循环导入:
|
||||
1. polymorphic(只依赖 SQLModelBase)
|
||||
2. optimistic_lock(只依赖 SQLAlchemy)
|
||||
3. table(依赖 polymorphic 和 optimistic_lock)
|
||||
4. relation_preload(只依赖 SQLModelBase)
|
||||
|
||||
注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
|
||||
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
|
||||
"""
|
||||
# polymorphic 必须先导入
|
||||
from .polymorphic import (
|
||||
AutoPolymorphicIdentityMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
register_sti_column_properties_for_all_subclasses,
|
||||
register_sti_columns_for_all_subclasses,
|
||||
)
|
||||
# optimistic_lock 只依赖 SQLAlchemy,必须在 table 之前
|
||||
from .optimistic_lock import (
|
||||
OptimisticLockError,
|
||||
OptimisticLockMixin,
|
||||
)
|
||||
# table 依赖 polymorphic 和 optimistic_lock
|
||||
from .table import (
|
||||
ListResponse,
|
||||
PaginationRequest,
|
||||
T,
|
||||
TableBaseMixin,
|
||||
TableViewRequest,
|
||||
TimeFilterRequest,
|
||||
UUIDTableBaseMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
# relation_preload 只依赖 SQLModelBase
|
||||
from .relation_preload import (
|
||||
RelationPreloadMixin,
|
||||
requires_relations,
|
||||
)
|
||||
# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
|
||||
# 需要时直接从 sqlmodels.mixin.jwt 导入
|
||||
from .info_response import (
|
||||
DatetimeInfoMixin,
|
||||
IntIdDatetimeInfoMixin,
|
||||
IntIdInfoMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
)
|
||||
46
sqlmodels/mixin/info_response.py
Normal file
46
sqlmodels/mixin/info_response.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
InfoResponse DTO Mixin模块
|
||||
|
||||
提供用于InfoResponse类型DTO的Mixin,统一定义id/created_at/updated_at字段。
|
||||
|
||||
设计说明:
|
||||
- 这些Mixin用于**响应DTO**,不是数据库表
|
||||
- 从数据库返回时这些字段永远不为空,所以定义为必填字段
|
||||
- TableBase中的id=None和default_factory=now是正确的(入库前为None,数据库生成)
|
||||
- 这些Mixin让DTO明确表示"返回给客户端时这些字段必定有值"
|
||||
"""
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
|
||||
class IntIdInfoMixin(SQLModelBase):
|
||||
"""整数ID响应mixin - 用于InfoResponse DTO"""
|
||||
id: int
|
||||
"""记录ID"""
|
||||
|
||||
|
||||
class UUIDIdInfoMixin(SQLModelBase):
|
||||
"""UUID ID响应mixin - 用于InfoResponse DTO"""
|
||||
id: UUID
|
||||
"""记录ID"""
|
||||
|
||||
|
||||
class DatetimeInfoMixin(SQLModelBase):
|
||||
"""时间戳响应mixin - 用于InfoResponse DTO"""
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""更新时间"""
|
||||
|
||||
|
||||
class IntIdDatetimeInfoMixin(IntIdInfoMixin, DatetimeInfoMixin):
|
||||
"""整数ID + 时间戳响应mixin"""
|
||||
pass
|
||||
|
||||
|
||||
class UUIDIdDatetimeInfoMixin(UUIDIdInfoMixin, DatetimeInfoMixin):
|
||||
"""UUID ID + 时间戳响应mixin"""
|
||||
pass
|
||||
90
sqlmodels/mixin/optimistic_lock.py
Normal file
90
sqlmodels/mixin/optimistic_lock.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
|
||||
|
||||
乐观锁适用场景:
|
||||
- 涉及"状态转换"的表(如:待支付 -> 已支付)
|
||||
- 涉及"数值变动"的表(如:余额、库存)
|
||||
|
||||
不适用场景:
|
||||
- 日志表、纯插入表、低价值统计表
|
||||
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
|
||||
|
||||
使用示例:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
status: OrderStatusEnum
|
||||
amount: Decimal
|
||||
|
||||
# save/update 时自动检查版本号
|
||||
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
|
||||
try:
|
||||
order = await order.save(session)
|
||||
except OptimisticLockError as e:
|
||||
# 处理冲突:重新查询并重试,或报错给用户
|
||||
l.warning(f"乐观锁冲突: {e}")
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
|
||||
|
||||
class OptimisticLockError(Exception):
|
||||
"""
|
||||
乐观锁冲突异常
|
||||
|
||||
当 save/update 操作检测到版本号不匹配时抛出。
|
||||
这意味着在读取和写入之间,其他事务已经修改了该记录。
|
||||
|
||||
Attributes:
|
||||
model_class: 发生冲突的模型类名
|
||||
record_id: 记录 ID(如果可用)
|
||||
expected_version: 期望的版本号(如果可用)
|
||||
original_error: 原始的 StaleDataError
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
model_class: str | None = None,
|
||||
record_id: str | None = None,
|
||||
expected_version: int | None = None,
|
||||
original_error: StaleDataError | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.model_class = model_class
|
||||
self.record_id = record_id
|
||||
self.expected_version = expected_version
|
||||
self.original_error = original_error
|
||||
|
||||
|
||||
class OptimisticLockMixin:
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
|
||||
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
|
||||
session.commit() 会抛出 StaleDataError,被 save/update 方法捕获并转换为 OptimisticLockError。
|
||||
|
||||
原理:
|
||||
1. 每条记录有一个 version 字段,初始值为 0
|
||||
2. 每次 UPDATE 时,SQLAlchemy 生成的 SQL 类似:
|
||||
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
|
||||
3. 如果 WHERE 条件不匹配(version 已被其他事务修改),
|
||||
UPDATE 影响 0 行,SQLAlchemy 抛出 StaleDataError
|
||||
|
||||
继承顺序:
|
||||
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
...
|
||||
|
||||
配套重试:
|
||||
如果加了乐观锁,业务层需要处理 OptimisticLockError:
|
||||
- 报错给用户:"数据已被修改,请刷新后重试"
|
||||
- 自动重试:重新查询最新数据并再次尝试
|
||||
"""
|
||||
_has_optimistic_lock: ClassVar[bool] = True
|
||||
"""标记此类启用了乐观锁"""
|
||||
|
||||
version: int = 0
|
||||
"""乐观锁版本号,每次更新自动递增"""
|
||||
710
sqlmodels/mixin/polymorphic.py
Normal file
710
sqlmodels/mixin/polymorphic.py
Normal file
@@ -0,0 +1,710 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import Column, String, inspect
|
||||
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
from sqlmodel.main import get_column_from_field
|
||||
|
||||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# 用于延迟注册 STI 子类列的队列
|
||||
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
|
||||
_sti_subclasses_to_register: list[type] = []
|
||||
|
||||
|
||||
def register_sti_columns_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
|
||||
|
||||
此函数应在 configure_mappers() 之前调用。
|
||||
将 STI 子类的字段添加到父表的 metadata 中。
|
||||
同时修复被 Column 对象污染的 model_fields。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_columns()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
|
||||
|
||||
# 修复被 Column 对象污染的 model_fields
|
||||
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
|
||||
try:
|
||||
_fix_polluted_model_fields(cls)
|
||||
except Exception as e:
|
||||
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
|
||||
|
||||
|
||||
def register_sti_column_properties_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类添加列属性到 mapper(第二阶段)
|
||||
|
||||
此函数应在 configure_mappers() 之后调用。
|
||||
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_column_properties()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
|
||||
|
||||
# 清空队列
|
||||
_sti_subclasses_to_register.clear()
|
||||
|
||||
|
||||
def _fix_polluted_model_fields(cls: type) -> None:
|
||||
"""
|
||||
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
|
||||
|
||||
当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute
|
||||
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
|
||||
时错误地使用这些 SQLAlchemy 对象作为默认值。
|
||||
|
||||
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
|
||||
|
||||
:param cls: 要修复的类
|
||||
"""
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 或 Column 污染的
|
||||
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute 或 Column)
|
||||
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
|
||||
# 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值
|
||||
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制 SQLModel 特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的 default_factory 丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin,并支持STI子类列注册
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
**重要:数据库迁移注意事项**
|
||||
|
||||
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
|
||||
|
||||
例如,对于以下继承链::
|
||||
|
||||
LLM (polymorphic_on='_polymorphic_name')
|
||||
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
|
||||
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
|
||||
|
||||
迁移脚本中设置 _polymorphic_name 时::
|
||||
|
||||
# ❌ 错误:缺少父类前缀
|
||||
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
|
||||
|
||||
# ✅ 正确:包含完整的继承链前缀
|
||||
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
|
||||
|
||||
**STI(单表继承)支持**:
|
||||
当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段
|
||||
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
|
||||
注册到父表的问题。
|
||||
|
||||
Example (JTI):
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Example (STI):
|
||||
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
|
||||
... user_id: UUID
|
||||
...
|
||||
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
|
||||
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
|
||||
... # polymorphic_identity 会自动设置为 'pendingfile'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
- STI模式下,新字段会在类定义时自动添加到父表的metadata中
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity并处理STI列注册
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
# 注册 STI 子类列的延迟执行
|
||||
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
|
||||
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
|
||||
_sti_subclasses_to_register.append(cls)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复 STI 继承中子类字段被 Column 对象污染的问题。
|
||||
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
|
||||
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
@classmethod
|
||||
def _register_sti_columns(cls) -> None:
|
||||
"""
|
||||
将STI子类的新字段注册到父表的列定义中
|
||||
|
||||
检测当前类是否是STI子类(与父类共用同一张表),
|
||||
如果是,则将子类定义的新字段添加到父表的metadata中。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_fields: set[str] = set()
|
||||
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
# 收集父类的所有字段名
|
||||
if hasattr(base, 'model_fields'):
|
||||
parent_fields.update(base.model_fields.keys())
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列注册到父表
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列注册
|
||||
|
||||
# 获取当前类的新字段(不在父类中的字段)
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
existing_columns = {col.name for col in parent_table.columns}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段和ClassVar
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 跳过已存在的列
|
||||
if field_name in existing_columns:
|
||||
continue
|
||||
|
||||
# 使用 SQLModel 的内置 API 创建列
|
||||
try:
|
||||
column = get_column_from_field(field_info)
|
||||
column.name = field_name
|
||||
column.key = field_name
|
||||
# STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值
|
||||
# Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段)
|
||||
column.nullable = True
|
||||
|
||||
# 将列添加到父表
|
||||
parent_table.append_column(column)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}")
|
||||
|
||||
@classmethod
|
||||
def _register_sti_column_properties(cls) -> None:
|
||||
"""
|
||||
将 STI 子类的列作为 ColumnProperty 添加到 mapper
|
||||
|
||||
此方法在 configure_mappers() 之后调用,将已添加到表中的列
|
||||
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
|
||||
|
||||
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
|
||||
这确保了查询父类时,SELECT 语句包含所有 STI 子类的列,
|
||||
避免在响应序列化时触发懒加载(MissingGreenlet 错误)。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_class = None
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
parent_class = base
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列属性注册
|
||||
|
||||
# 获取子类和父类的 mapper
|
||||
child_mapper = inspect(cls).mapper
|
||||
parent_mapper = inspect(parent_class).mapper
|
||||
local_table = child_mapper.local_table
|
||||
|
||||
# 查找父类的所有字段名
|
||||
parent_fields: set[str] = set()
|
||||
if hasattr(parent_class, 'model_fields'):
|
||||
parent_fields.update(parent_class.model_fields.keys())
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
# 获取两个 mapper 已有的列属性
|
||||
child_existing_props = {p.key for p in child_mapper.column_attrs}
|
||||
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
|
||||
|
||||
for field_name in cls.model_fields:
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 检查表中是否有这个列
|
||||
if field_name not in local_table.columns:
|
||||
continue
|
||||
|
||||
column = local_table.columns[field_name]
|
||||
|
||||
# 添加到子类的 mapper(如果尚不存在)
|
||||
if field_name not in child_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
child_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}")
|
||||
|
||||
# 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列)
|
||||
if field_name not in parent_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
parent_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def _is_joined_table_inheritance(cls) -> bool:
|
||||
"""
|
||||
检测当前类是否使用联表继承(Joined Table Inheritance)
|
||||
|
||||
通过检查子类是否有独立的表来判断:
|
||||
- JTI: 子类有独立的 local_table(与父类不同)
|
||||
- STI: 子类与父类共用同一个 local_table
|
||||
|
||||
:return: True 表示 JTI,False 表示 STI 或无子类
|
||||
"""
|
||||
mapper = inspect(cls)
|
||||
base_table_name = mapper.local_table.name
|
||||
|
||||
# 检查所有直接子类
|
||||
for subclass in cls.__subclasses__():
|
||||
sub_mapper = inspect(subclass)
|
||||
# 如果任何子类有不同的表名,说明是 JTI
|
||||
if sub_mapper.local_table.name != base_table_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
470
sqlmodels/mixin/relation_preload.py
Normal file
470
sqlmodels/mixin/relation_preload.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
|
||||
|
||||
设计原则:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 查询最优:相同关系只查询一次,不同关系增量查询
|
||||
- 零侵入:调用方无需任何改动
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
|
||||
|
||||
使用方式:
|
||||
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
|
||||
|
||||
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
|
||||
kling_video_generator: KlingO1Generator = Relationship(...)
|
||||
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# 自动加载,可以安全访问
|
||||
price = self.kling_video_generator.kling_o1.pro_price_per_second
|
||||
...
|
||||
|
||||
# 调用方 - 无需任何改动
|
||||
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
|
||||
await tool._call(...) # 关系相同则跳过,否则增量加载
|
||||
|
||||
支持 AsyncGenerator:
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # 装饰器正确处理 async generator
|
||||
"""
|
||||
import inspect as python_inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar, ParamSpec, Any
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
P = ParamSpec('P')
|
||||
R = TypeVar('R')
|
||||
|
||||
|
||||
def _extract_session(
|
||||
func: Callable,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> AsyncSession | None:
|
||||
"""
|
||||
从方法参数中提取 AsyncSession
|
||||
|
||||
按以下顺序查找:
|
||||
1. kwargs 中名为 'session' 的参数
|
||||
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
|
||||
3. kwargs 中类型为 AsyncSession 的参数
|
||||
"""
|
||||
# 1. 优先从 kwargs 查找
|
||||
if 'session' in kwargs:
|
||||
return kwargs['session']
|
||||
|
||||
# 2. 从函数签名定位位置参数
|
||||
try:
|
||||
sig = python_inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
if 'session' in param_names:
|
||||
# 计算位置(减去 self)
|
||||
idx = param_names.index('session') - 1
|
||||
if 0 <= idx < len(args):
|
||||
return args[idx]
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. 遍历 kwargs 找 AsyncSession 类型
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, AsyncSession):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
|
||||
"""
|
||||
检查对象的关系是否已加载(独立函数版本)
|
||||
|
||||
Args:
|
||||
obj: 要检查的对象
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(obj)
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
|
||||
"""
|
||||
在类中查找指向目标类的关系属性名
|
||||
|
||||
Args:
|
||||
from_class: 源类
|
||||
to_class: 目标类
|
||||
|
||||
Returns:
|
||||
关系属性名,如果找不到则返回 None
|
||||
|
||||
Example:
|
||||
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
|
||||
# 返回 'kling_video_generator'
|
||||
"""
|
||||
for attr_name in dir(from_class):
|
||||
try:
|
||||
attr = getattr(from_class, attr_name, None)
|
||||
if attr is None:
|
||||
continue
|
||||
# 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
|
||||
# parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
|
||||
target_class = attr.property.mapper.class_
|
||||
if target_class == to_class:
|
||||
return attr_name
|
||||
except AttributeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""
|
||||
装饰器:声明方法需要的关系,自动按需增量加载
|
||||
|
||||
参数格式:
|
||||
- 字符串:本类属性名,如 'kling_video_generator'
|
||||
- RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
|
||||
|
||||
行为:
|
||||
- 方法调用时自动检查关系是否已加载
|
||||
- 未加载的关系会被增量加载(单次查询)
|
||||
- 已加载的关系直接跳过
|
||||
|
||||
支持:
|
||||
- 普通 async 方法:`async def cost(...) -> ToolCost`
|
||||
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
|
||||
|
||||
Example:
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# self.kling_video_generator.kling_o1 已自动加载
|
||||
...
|
||||
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # AsyncGenerator 正确处理
|
||||
|
||||
验证:
|
||||
- 字符串格式的关系名在类创建时(__init_subclass__)验证
|
||||
- 拼写错误会在导入时抛出 AttributeError
|
||||
"""
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
# 检测是否是 async generator 函数
|
||||
is_async_gen = python_inspect.isasyncgenfunction(func)
|
||||
|
||||
if is_async_gen:
|
||||
# AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
# 委托给原始 async generator,逐个 yield 值
|
||||
async for item in func(self, *args, **kwargs):
|
||||
yield item # type: ignore
|
||||
else:
|
||||
# 普通 async 函数:await 并返回结果
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
# 保存关系声明供验证和内省使用
|
||||
wrapper._required_relations = relations # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RelationPreloadMixin:
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供按需增量加载能力,确保 SQL 查询数理论最优。
|
||||
|
||||
特性:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 原地更新:直接修改 self,无需替换实例
|
||||
- 导入时验证:字符串关系名在类创建时验证
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""类创建时验证所有 @requires_relations 声明"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 收集类及其父类的所有注解(包含普通字段)
|
||||
all_annotations: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__annotations__'):
|
||||
all_annotations.update(klass.__annotations__.keys())
|
||||
|
||||
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
|
||||
sqlmodel_relationships: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__sqlmodel_relationships__'):
|
||||
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
|
||||
|
||||
# 合并所有可用的属性名
|
||||
all_available_names = all_annotations | sqlmodel_relationships
|
||||
|
||||
for method_name in dir(cls):
|
||||
if method_name.startswith('__'):
|
||||
continue
|
||||
|
||||
try:
|
||||
method = getattr(cls, method_name, None)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
continue
|
||||
|
||||
# 验证字符串格式的关系名
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
# 检查注解、Relationship 或已有属性
|
||||
if spec not in all_available_names and not hasattr(cls, spec):
|
||||
raise AttributeError(
|
||||
f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
|
||||
f"但 {cls.__name__} 没有此属性"
|
||||
)
|
||||
|
||||
def _is_relation_loaded(self, rel_name: str) -> bool:
|
||||
"""
|
||||
检查关系是否真正已加载(基于 SQLAlchemy inspect)
|
||||
|
||||
使用 SQLAlchemy 的 inspect 检测真实加载状态,
|
||||
自动处理 commit 导致的 expire 问题。
|
||||
|
||||
Args:
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(self)
|
||||
# unloaded 包含未加载的关系属性名
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
# 对象可能未被 SQLAlchemy 管理
|
||||
return False
|
||||
|
||||
async def _ensure_relations_loaded(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
relations: tuple[str | RelationshipInfo, ...],
|
||||
) -> None:
|
||||
"""
|
||||
确保指定关系已加载,只加载未加载的部分
|
||||
|
||||
基于 SQLAlchemy inspect 检测真实状态,自动处理:
|
||||
- 首次访问的关系
|
||||
- commit 后 expire 的关系
|
||||
- 嵌套关系(如 KlingO1Generator.kling_o1)
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
relations: 需要的关系规格
|
||||
"""
|
||||
# 找出真正未加载的关系(基于 SQLAlchemy inspect)
|
||||
to_load: list[str | RelationshipInfo] = []
|
||||
# 区分直接关系和嵌套关系的 key
|
||||
direct_keys: set[str] = set() # 本类的直接关系属性名
|
||||
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
|
||||
|
||||
for rel in relations:
|
||||
if isinstance(rel, str):
|
||||
# 直接关系:检查本类的关系是否已加载
|
||||
if not self._is_relation_loaded(rel):
|
||||
to_load.append(rel)
|
||||
direct_keys.add(rel)
|
||||
else:
|
||||
# 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
|
||||
# 1. 查找指向父类的关系属性
|
||||
parent_class = rel.parent.class_
|
||||
parent_attr = _find_relation_to_class(self.__class__, parent_class)
|
||||
|
||||
if parent_attr is None:
|
||||
# 找不到路径,可能是配置错误,但仍尝试加载
|
||||
l.warning(
|
||||
f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
|
||||
f"无法检查 {rel.key} 是否已加载"
|
||||
)
|
||||
to_load.append(rel)
|
||||
continue
|
||||
|
||||
# 2. 检查父对象是否已加载
|
||||
if not self._is_relation_loaded(parent_attr):
|
||||
# 父对象未加载,需要同时加载父对象和嵌套关系
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
else:
|
||||
# 3. 父对象已加载,检查嵌套关系是否已加载
|
||||
parent_obj = getattr(self, parent_attr)
|
||||
if not _is_obj_relation_loaded(parent_obj, rel.key):
|
||||
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
|
||||
# 因为 _build_load_chains 需要完整的链来构建 selectinload
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
|
||||
if not to_load:
|
||||
return # 全部已加载,跳过
|
||||
|
||||
# 构建 load 参数
|
||||
load_options = self._specs_to_load_options(to_load)
|
||||
if not load_options:
|
||||
return
|
||||
|
||||
# 安全地获取主键值(避免触发懒加载)
|
||||
state = sa_inspect(self)
|
||||
pk_tuple = state.key[1] if state.key else None
|
||||
if pk_tuple is None:
|
||||
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
|
||||
return
|
||||
# 主键是元组,取第一个值(假设单列主键)
|
||||
pk_value = pk_tuple[0]
|
||||
|
||||
# 单次查询加载缺失的关系
|
||||
fresh = await self.__class__.get(
|
||||
session,
|
||||
self.__class__.id == pk_value,
|
||||
load=load_options,
|
||||
)
|
||||
|
||||
if fresh is None:
|
||||
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
|
||||
return
|
||||
|
||||
# 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
|
||||
all_direct_keys = direct_keys | nested_parent_keys
|
||||
for key in all_direct_keys:
|
||||
value = getattr(fresh, key, None)
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def _specs_to_load_options(
|
||||
self,
|
||||
specs: list[str | RelationshipInfo],
|
||||
) -> list[RelationshipInfo]:
|
||||
"""
|
||||
将关系规格转换为 load 参数
|
||||
|
||||
- 字符串 → cls.{name}
|
||||
- RelationshipInfo → 直接使用
|
||||
"""
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for spec in specs:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(self.__class__, spec, None)
|
||||
if rel is not None:
|
||||
result.append(rel)
|
||||
else:
|
||||
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
# ==================== 可选的手动预加载 API ====================
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取指定方法声明的关系(用于外部预加载场景)
|
||||
|
||||
Args:
|
||||
method_name: 方法名
|
||||
|
||||
Returns:
|
||||
RelationshipInfo 列表
|
||||
"""
|
||||
method = getattr(cls, method_name, None)
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
return []
|
||||
|
||||
result: list[RelationshipInfo] = []
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(cls, spec, None)
|
||||
if rel:
|
||||
result.append(rel)
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取多个方法的关系并去重(用于批量预加载场景)
|
||||
|
||||
Args:
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
去重后的 RelationshipInfo 列表
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
for rel in cls.get_relations_for_method(method_name):
|
||||
key = rel.key
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
result.append(rel)
|
||||
|
||||
return result
|
||||
|
||||
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
|
||||
"""
|
||||
手动预加载指定方法的关系(可选优化 API)
|
||||
|
||||
当需要确保在调用方法前完成所有加载时使用。
|
||||
通常情况下不需要调用此方法,装饰器会自动处理。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Example:
|
||||
# 可选:显式预加载(通常不需要)
|
||||
tool = await tool.preload_for(session, 'cost', '_call')
|
||||
"""
|
||||
all_relations: list[str | RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
method = getattr(self.__class__, method_name, None)
|
||||
if method and hasattr(method, '_required_relations'):
|
||||
all_relations.extend(method._required_relations)
|
||||
|
||||
if all_relations:
|
||||
await self._ensure_relations_loaded(session, tuple(all_relations))
|
||||
|
||||
return self
|
||||
1247
sqlmodels/mixin/table.py
Normal file
1247
sqlmodels/mixin/table.py
Normal file
File diff suppressed because it is too large
Load Diff
123
sqlmodels/model_base.py
Normal file
123
sqlmodels/model_base.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlmodel import Field
|
||||
|
||||
from .base import SQLModelBase
|
||||
|
||||
|
||||
class ResponseBase(SQLModelBase):
|
||||
"""通用响应模型"""
|
||||
|
||||
instance_id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||
"""实例ID,用于标识请求的唯一性"""
|
||||
|
||||
|
||||
# ==================== Admin Summary DTO ====================
|
||||
|
||||
|
||||
class MetricsSummary(SQLModelBase):
|
||||
"""站点统计摘要"""
|
||||
|
||||
dates: list[datetime]
|
||||
"""日期列表"""
|
||||
|
||||
files: list[int]
|
||||
"""每日新增文件数"""
|
||||
|
||||
users: list[int]
|
||||
"""每日新增用户数"""
|
||||
|
||||
shares: list[int]
|
||||
"""每日新增分享数"""
|
||||
|
||||
file_total: int
|
||||
"""文件总数"""
|
||||
|
||||
user_total: int
|
||||
"""用户总数"""
|
||||
|
||||
share_total: int
|
||||
"""分享总数"""
|
||||
|
||||
entities_total: int
|
||||
"""实体总数"""
|
||||
|
||||
generated_at: datetime
|
||||
"""生成时间"""
|
||||
|
||||
|
||||
class LicenseInfo(SQLModelBase):
|
||||
"""许可证信息"""
|
||||
|
||||
expired_at: datetime
|
||||
"""过期时间"""
|
||||
|
||||
signed_at: datetime
|
||||
"""签发时间"""
|
||||
|
||||
root_domains: list[str]
|
||||
"""根域名列表"""
|
||||
|
||||
domains: list[str]
|
||||
"""域名列表"""
|
||||
|
||||
vol_domains: list[str]
|
||||
"""卷域名列表"""
|
||||
|
||||
|
||||
class VersionInfo(SQLModelBase):
|
||||
"""版本信息"""
|
||||
|
||||
version: str
|
||||
"""版本号"""
|
||||
|
||||
pro: bool
|
||||
"""是否为专业版"""
|
||||
|
||||
commit: str
|
||||
"""提交哈希"""
|
||||
|
||||
class AdminSummaryResponse(ResponseBase):
|
||||
"""管理员概况响应"""
|
||||
|
||||
metrics_summary: MetricsSummary
|
||||
"""统计摘要"""
|
||||
|
||||
site_urls: list[str]
|
||||
"""站点URL列表"""
|
||||
|
||||
license: LicenseInfo
|
||||
"""许可证信息"""
|
||||
|
||||
version: VersionInfo
|
||||
"""版本信息"""
|
||||
|
||||
class MCPMethod(StrEnum):
|
||||
"""MCP 方法枚举"""
|
||||
|
||||
PING = "ping"
|
||||
"""Ping 方法,用于测试连接"""
|
||||
|
||||
class MCPBase(SQLModelBase):
|
||||
"""MCP 请求基础模型"""
|
||||
|
||||
jsonrpc: str = "2.0"
|
||||
"""JSON-RPC 版本"""
|
||||
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||
"""请求/响应 ID,用于标识请求/响应的唯一性"""
|
||||
|
||||
class MCPRequestBase(MCPBase):
|
||||
"""MCP 请求模型基础类"""
|
||||
|
||||
method: str
|
||||
"""方法名称"""
|
||||
|
||||
class MCPResponseBase(MCPBase):
|
||||
"""MCP 响应模型基础类"""
|
||||
|
||||
result: str
|
||||
"""方法返回结果"""
|
||||
|
||||
103
sqlmodels/node.py
Normal file
103
sqlmodels/node.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship, text, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .download import Download
|
||||
|
||||
|
||||
class NodeStatus(StrEnum):
|
||||
"""节点状态枚举"""
|
||||
ONLINE = "online"
|
||||
"""正常"""
|
||||
OFFLINE = "offline"
|
||||
"""离线"""
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
"""节点类型枚举"""
|
||||
MASTER = "master"
|
||||
"""主节点"""
|
||||
SLAVE = "slave"
|
||||
"""从节点"""
|
||||
|
||||
|
||||
class Aria2ConfigurationBase(SQLModelBase):
|
||||
"""Aria2配置基础模型"""
|
||||
|
||||
rpc_url: str | None = Field(default=None, max_length=255)
|
||||
"""RPC地址"""
|
||||
|
||||
rpc_secret: str | None = None
|
||||
"""RPC密钥"""
|
||||
|
||||
temp_path: str | None = Field(default=None, max_length=255)
|
||||
"""临时下载路径"""
|
||||
|
||||
max_concurrent: int = Field(default=5, ge=1, le=50)
|
||||
"""最大并发数"""
|
||||
|
||||
timeout: int = Field(default=300, ge=1)
|
||||
"""请求超时时间(秒)"""
|
||||
|
||||
|
||||
class Aria2Configuration(Aria2ConfigurationBase, TableBaseMixin):
|
||||
"""Aria2配置模型(与Node一对一关联)"""
|
||||
|
||||
node_id: int = Field(
|
||||
foreign_key="node.id",
|
||||
unique=True,
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的节点ID"""
|
||||
|
||||
# 反向关系
|
||||
node: "Node" = Relationship(back_populates="aria2_config")
|
||||
"""关联的节点"""
|
||||
|
||||
|
||||
class Node(SQLModelBase, TableBaseMixin):
|
||||
"""节点模型"""
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_node_status", "status"),
|
||||
)
|
||||
|
||||
status: NodeStatus = Field(default=NodeStatus.ONLINE)
|
||||
"""节点状态"""
|
||||
|
||||
name: str = Field(max_length=255, unique=True)
|
||||
"""节点名称"""
|
||||
|
||||
type: NodeType
|
||||
"""节点类型"""
|
||||
|
||||
server: str = Field(max_length=255)
|
||||
"""节点地址(IP或域名)"""
|
||||
|
||||
slave_key: str | None = Field(default=None, max_length=255)
|
||||
"""从机通讯密钥"""
|
||||
|
||||
master_key: str | None = Field(default=None, max_length=255)
|
||||
"""主机通讯密钥"""
|
||||
|
||||
aria2_enabled: bool = False
|
||||
"""是否启用Aria2"""
|
||||
|
||||
rank: int = 0
|
||||
"""节点排序权重"""
|
||||
|
||||
# 关系
|
||||
aria2_config: Aria2Configuration | None = Relationship(
|
||||
back_populates="node",
|
||||
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
|
||||
)
|
||||
"""Aria2配置"""
|
||||
|
||||
downloads: list["Download"] = Relationship(back_populates="node")
|
||||
"""该节点的下载任务"""
|
||||
807
sqlmodels/object.py
Normal file
807
sqlmodels/object.py
Normal file
@@ -0,0 +1,807 @@
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from enum import StrEnum
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .policy import Policy
|
||||
from .source_link import SourceLink
|
||||
from .share import Share
|
||||
from .physical_file import PhysicalFile
|
||||
from .uri import DiskNextURI
|
||||
|
||||
|
||||
class ObjectType(StrEnum):
|
||||
"""对象类型枚举"""
|
||||
FILE = "file"
|
||||
FOLDER = "folder"
|
||||
|
||||
class StorageType(StrEnum):
|
||||
"""存储类型枚举"""
|
||||
LOCAL = "local"
|
||||
QINIU = "qiniu"
|
||||
TENCENT = "tencent"
|
||||
ALIYUN = "aliyun"
|
||||
ONEDRIVE = "onedrive"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
DROPBOX = "dropbox"
|
||||
WEBDAV = "webdav"
|
||||
REMOTE = "remote"
|
||||
|
||||
|
||||
class FileMetadataBase(SQLModelBase):
|
||||
"""文件元数据基础模型"""
|
||||
|
||||
width: int | None = Field(default=None)
|
||||
"""图片宽度(像素)"""
|
||||
|
||||
height: int | None = Field(default=None)
|
||||
"""图片高度(像素)"""
|
||||
|
||||
duration: float | None = Field(default=None)
|
||||
"""音视频时长(秒)"""
|
||||
|
||||
bitrate: int | None = Field(default=None)
|
||||
"""比特率(kbps)"""
|
||||
|
||||
mime_type: str | None = Field(default=None, max_length=127)
|
||||
"""MIME类型"""
|
||||
|
||||
checksum_md5: str | None = Field(default=None, max_length=32)
|
||||
"""MD5校验和"""
|
||||
|
||||
checksum_sha256: str | None = Field(default=None, max_length=64)
|
||||
"""SHA256校验和"""
|
||||
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class ObjectBase(SQLModelBase):
|
||||
"""对象基础字段,供数据库模型和 DTO 共享"""
|
||||
|
||||
name: str
|
||||
"""对象名称(文件名或目录名)"""
|
||||
|
||||
type: ObjectType
|
||||
"""对象类型"""
|
||||
|
||||
size: int | None = None
|
||||
"""文件大小(字节),目录为 None"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class DirectoryCreateRequest(SQLModelBase):
|
||||
"""创建目录请求 DTO"""
|
||||
|
||||
parent_id: UUID
|
||||
"""父目录UUID"""
|
||||
|
||||
name: str
|
||||
"""目录名称"""
|
||||
|
||||
policy_id: UUID | None = None
|
||||
"""存储策略UUID,不指定则继承父目录"""
|
||||
|
||||
|
||||
class ObjectMoveRequest(SQLModelBase):
|
||||
"""移动对象请求 DTO"""
|
||||
|
||||
src_ids: list[UUID]
|
||||
"""源对象UUID列表"""
|
||||
|
||||
dst_id: UUID
|
||||
"""目标文件夹UUID"""
|
||||
|
||||
|
||||
class ObjectDeleteRequest(SQLModelBase):
|
||||
"""删除对象请求 DTO"""
|
||||
|
||||
ids: list[UUID]
|
||||
"""待删除对象UUID列表"""
|
||||
|
||||
|
||||
class ObjectResponse(ObjectBase):
|
||||
"""对象响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""对象UUID"""
|
||||
|
||||
thumb: bool = False
|
||||
"""是否有缩略图"""
|
||||
|
||||
created_at: datetime
|
||||
"""对象创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""对象修改时间"""
|
||||
|
||||
source_enabled: bool = False
|
||||
"""是否启用离线下载源"""
|
||||
|
||||
|
||||
class PolicyResponse(SQLModelBase):
|
||||
"""存储策略响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""策略UUID"""
|
||||
|
||||
name: str
|
||||
"""策略名称"""
|
||||
|
||||
type: StorageType
|
||||
"""存储类型"""
|
||||
|
||||
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
|
||||
"""单文件最大限制,单位字节,0表示不限制"""
|
||||
|
||||
file_type: list[str] | None = None
|
||||
"""允许的文件类型列表,None 表示不限制"""
|
||||
|
||||
|
||||
class DirectoryResponse(SQLModelBase):
|
||||
"""目录响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""当前目录UUID"""
|
||||
|
||||
parent: UUID | None = None
|
||||
"""父目录UUID,根目录为None"""
|
||||
|
||||
objects: list[ObjectResponse] = []
|
||||
"""目录下的对象列表"""
|
||||
|
||||
policy: PolicyResponse
|
||||
"""存储策略"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class FileMetadata(FileMetadataBase, UUIDTableBaseMixin):
|
||||
"""文件元数据模型(与Object一对一关联)"""
|
||||
|
||||
object_id: UUID = Field(
|
||||
foreign_key="object.id",
|
||||
unique=True,
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的对象UUID"""
|
||||
|
||||
# 反向关系
|
||||
object: "Object" = Relationship(back_populates="file_metadata")
|
||||
"""关联的对象"""
|
||||
|
||||
|
||||
class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
"""
|
||||
统一对象模型
|
||||
|
||||
合并了原有的 File 和 Folder 模型,通过 type 字段区分文件和目录。
|
||||
|
||||
根目录规则:
|
||||
- 每个用户有一个显式根目录对象(name="/", parent_id=NULL)
|
||||
- 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id
|
||||
- 根目录的 policy_id 指定用户默认存储策略
|
||||
- 路径格式:/path/to/file(如 /docs/readme.md),不包含用户名前缀
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
# 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况)
|
||||
UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"),
|
||||
# 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/")
|
||||
CheckConstraint(
|
||||
"parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')",
|
||||
name="ck_object_name_no_slash",
|
||||
),
|
||||
# 性能索引
|
||||
Index("ix_object_owner_updated", "owner_id", "updated_at"),
|
||||
Index("ix_object_parent_updated", "parent_id", "updated_at"),
|
||||
Index("ix_object_owner_type", "owner_id", "type"),
|
||||
Index("ix_object_owner_size", "owner_id", "size"),
|
||||
)
|
||||
|
||||
# ==================== 基础字段 ====================
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""对象名称(文件名或目录名)"""
|
||||
|
||||
type: ObjectType
|
||||
"""对象类型:file 或 folder"""
|
||||
|
||||
password: str | None = Field(default=None, max_length=255)
|
||||
"""对象独立密码(仅当用户为对象单独设置密码时有效)"""
|
||||
|
||||
# ==================== 文件专属字段 ====================
|
||||
|
||||
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||
"""文件大小(字节),目录为 0"""
|
||||
|
||||
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
|
||||
"""分块上传会话ID(仅文件有效)"""
|
||||
|
||||
physical_file_id: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="physicalfile.id",
|
||||
index=True,
|
||||
ondelete="SET NULL"
|
||||
)
|
||||
"""关联的物理文件UUID(仅文件有效,目录为NULL)"""
|
||||
|
||||
# ==================== 外键 ====================
|
||||
|
||||
parent_id: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="object.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""父目录UUID,NULL 表示这是用户的根目录"""
|
||||
|
||||
owner_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所有者用户UUID"""
|
||||
|
||||
policy_id: UUID = Field(
|
||||
foreign_key="policy.id",
|
||||
index=True,
|
||||
ondelete="RESTRICT"
|
||||
)
|
||||
"""存储策略UUID(文件直接使用,目录作为子文件的默认策略)"""
|
||||
|
||||
# ==================== 封禁相关字段 ====================
|
||||
|
||||
is_banned: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
"""是否被封禁"""
|
||||
|
||||
banned_at: datetime | None = None
|
||||
"""封禁时间"""
|
||||
|
||||
banned_by: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="SET NULL",
|
||||
sa_column_kwargs={"name": "banned_by"}
|
||||
)
|
||||
"""封禁操作者UUID"""
|
||||
|
||||
ban_reason: str | None = Field(default=None, max_length=500)
|
||||
"""封禁原因"""
|
||||
|
||||
# ==================== 关系 ====================
|
||||
|
||||
owner: "User" = Relationship(
|
||||
back_populates="objects",
|
||||
sa_relationship_kwargs={"foreign_keys": "[Object.owner_id]"}
|
||||
)
|
||||
"""所有者"""
|
||||
|
||||
banner: "User" = Relationship(
|
||||
sa_relationship_kwargs={"foreign_keys": "[Object.banned_by]"}
|
||||
)
|
||||
"""封禁操作者"""
|
||||
|
||||
policy: "Policy" = Relationship(back_populates="objects")
|
||||
"""存储策略"""
|
||||
|
||||
# 自引用关系
|
||||
parent: "Object" = Relationship(
|
||||
back_populates="children",
|
||||
sa_relationship_kwargs={"remote_side": "Object.id"},
|
||||
)
|
||||
"""父目录"""
|
||||
|
||||
children: list["Object"] = Relationship(
|
||||
back_populates="parent",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
"""子对象(文件和子目录)"""
|
||||
|
||||
# 仅文件有效的关系
|
||||
file_metadata: FileMetadata | None = Relationship(
|
||||
back_populates="object",
|
||||
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
|
||||
)
|
||||
"""文件元数据(仅文件有效)"""
|
||||
|
||||
source_links: list["SourceLink"] = Relationship(
|
||||
back_populates="object",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
"""源链接列表(仅文件有效)"""
|
||||
|
||||
shares: list["Share"] = Relationship(
|
||||
back_populates="object",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
"""分享列表"""
|
||||
|
||||
physical_file: "PhysicalFile" = Relationship(back_populates="objects")
|
||||
"""关联的物理文件(仅文件有效)"""
|
||||
|
||||
# ==================== 业务属性 ====================
|
||||
|
||||
@property
|
||||
def source_name(self) -> str | None:
|
||||
"""
|
||||
源文件存储路径(向后兼容属性)
|
||||
|
||||
:return: 物理文件存储路径,如果没有关联物理文件则返回 None
|
||||
"""
|
||||
if self.physical_file:
|
||||
return self.physical_file.storage_path
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_file(self) -> bool:
|
||||
"""是否为文件"""
|
||||
return self.type == ObjectType.FILE
|
||||
|
||||
@property
|
||||
def is_folder(self) -> bool:
|
||||
"""是否为目录"""
|
||||
return self.type == ObjectType.FOLDER
|
||||
|
||||
# ==================== 业务方法 ====================
|
||||
|
||||
@classmethod
|
||||
async def get_root(cls, session, user_id: UUID) -> "Object | None":
|
||||
"""
|
||||
获取用户的根目录
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:return: 根目录对象,不存在则返回 None
|
||||
"""
|
||||
return await cls.get(
|
||||
session,
|
||||
(cls.owner_id == user_id) & (cls.parent_id == None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_by_path(
|
||||
cls,
|
||||
session,
|
||||
user_id: UUID,
|
||||
path: str,
|
||||
) -> "Object | None":
|
||||
"""
|
||||
根据路径获取对象
|
||||
|
||||
路径从用户根目录开始,不包含用户名前缀。
|
||||
如 "/" 表示根目录,"/docs/images" 表示根目录下的 docs/images。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:param path: 路径,如 "/" 或 "/docs/images"
|
||||
:return: Object 或 None
|
||||
"""
|
||||
path = path.strip()
|
||||
if not path:
|
||||
raise ValueError("路径不能为空")
|
||||
|
||||
# 获取用户根目录
|
||||
root = await cls.get_root(session, user_id)
|
||||
if not root:
|
||||
return None
|
||||
|
||||
# 移除开头的斜杠并分割路径
|
||||
if path.startswith("/"):
|
||||
path = path[1:]
|
||||
parts = [p for p in path.split("/") if p]
|
||||
|
||||
# 空路径 -> 返回根目录
|
||||
if not parts:
|
||||
return root
|
||||
|
||||
# 从根目录开始遍历路径
|
||||
current = root
|
||||
for part in parts:
|
||||
if not current:
|
||||
return None
|
||||
|
||||
current = await cls.get(
|
||||
session,
|
||||
(cls.owner_id == user_id) &
|
||||
(cls.parent_id == current.id) &
|
||||
(cls.name == part)
|
||||
)
|
||||
|
||||
return current
|
||||
|
||||
@classmethod
|
||||
async def get_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
|
||||
"""
|
||||
获取目录下的所有子对象
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:param parent_id: 父目录UUID
|
||||
:return: 子对象列表
|
||||
"""
|
||||
return await cls.get(
|
||||
session,
|
||||
(cls.owner_id == user_id) & (cls.parent_id == parent_id),
|
||||
fetch_mode="all"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def resolve_uri(
|
||||
cls,
|
||||
session,
|
||||
uri: "DiskNextURI",
|
||||
requesting_user_id: UUID | None = None,
|
||||
) -> "Object":
|
||||
"""
|
||||
将 URI 解析为 Object 实例
|
||||
|
||||
分派逻辑(类似 Cloudreve 的 getNavigator):
|
||||
- MY → user_id = uri.id(str(requesting_user_id))
|
||||
验证权限(自己的或管理员),然后 get_by_path
|
||||
- SHARE → 通过 uri.fs_id 查 Share 表,验证密码和有效期
|
||||
获取 share.object,然后沿 uri.path 遍历子对象
|
||||
- TRASH → 延后实现
|
||||
|
||||
:param session: 数据库会话
|
||||
:param uri: DiskNextURI 实例
|
||||
:param requesting_user_id: 请求用户UUID
|
||||
:return: Object 实例
|
||||
:raises ValueError: URI 无法解析
|
||||
:raises PermissionError: 权限不足
|
||||
:raises NotImplementedError: 不支持的命名空间
|
||||
"""
|
||||
from .uri import FileSystemNamespace
|
||||
|
||||
if uri.namespace == FileSystemNamespace.MY:
|
||||
# 确定目标用户
|
||||
target_user_id_str = uri.id(str(requesting_user_id) if requesting_user_id else None)
|
||||
if not target_user_id_str:
|
||||
raise ValueError("MY 命名空间需要提供 fs_id 或 requesting_user_id")
|
||||
|
||||
target_user_id = UUID(target_user_id_str)
|
||||
|
||||
# 权限检查:只能访问自己的空间(管理员权限由路由层判断)
|
||||
if requesting_user_id and target_user_id != requesting_user_id:
|
||||
raise PermissionError("无权访问其他用户的文件空间")
|
||||
|
||||
obj = await cls.get_by_path(session, target_user_id, uri.path)
|
||||
if not obj:
|
||||
raise ValueError(f"路径不存在: {uri.path}")
|
||||
return obj
|
||||
|
||||
elif uri.namespace == FileSystemNamespace.SHARE:
|
||||
raise NotImplementedError("分享空间解析尚未实现")
|
||||
|
||||
elif uri.namespace == FileSystemNamespace.TRASH:
|
||||
raise NotImplementedError("回收站解析尚未实现")
|
||||
|
||||
else:
|
||||
raise ValueError(f"未知的命名空间: {uri.namespace}")
|
||||
|
||||
async def get_full_path(self, session) -> str:
|
||||
"""
|
||||
从当前对象沿 parent_id 向上遍历到根目录,返回完整路径
|
||||
|
||||
:param session: 数据库会话
|
||||
:return: 完整路径,如 "/docs/images/photo.jpg"
|
||||
"""
|
||||
parts: list[str] = []
|
||||
current: Object | None = self
|
||||
|
||||
while current and current.parent_id is not None:
|
||||
parts.append(current.name)
|
||||
current = await Object.get(session, Object.id == current.parent_id)
|
||||
|
||||
# 反转顺序(从根到当前)
|
||||
parts.reverse()
|
||||
return "/" + "/".join(parts)
|
||||
|
||||
|
||||
# ==================== 上传会话模型 ====================
|
||||
|
||||
class UploadSessionBase(SQLModelBase):
|
||||
"""上传会话基础字段"""
|
||||
|
||||
file_name: str = Field(max_length=255)
|
||||
"""原始文件名"""
|
||||
|
||||
file_size: int = Field(ge=0, sa_type=BigInteger)
|
||||
"""文件总大小(字节)"""
|
||||
|
||||
chunk_size: int = Field(ge=1, sa_type=BigInteger)
|
||||
"""分片大小(字节)"""
|
||||
|
||||
total_chunks: int = Field(ge=1)
|
||||
"""总分片数"""
|
||||
|
||||
|
||||
class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
||||
"""
|
||||
上传会话模型
|
||||
|
||||
用于管理分片上传的会话状态。
|
||||
会话有效期为24小时,过期后自动失效。
|
||||
"""
|
||||
|
||||
# 会话状态
|
||||
uploaded_chunks: int = 0
|
||||
"""已上传分片数"""
|
||||
|
||||
uploaded_size: int = Field(default=0, sa_type=BigInteger)
|
||||
"""已上传大小(字节)"""
|
||||
|
||||
storage_path: str | None = Field(default=None, max_length=512)
|
||||
"""文件存储路径"""
|
||||
|
||||
expires_at: datetime
|
||||
"""会话过期时间"""
|
||||
|
||||
# 外键
|
||||
owner_id: UUID = Field(foreign_key="user.id", index=True, ondelete="CASCADE")
|
||||
"""上传者用户UUID"""
|
||||
|
||||
parent_id: UUID = Field(foreign_key="object.id", index=True, ondelete="CASCADE")
|
||||
"""目标父目录UUID"""
|
||||
|
||||
policy_id: UUID = Field(foreign_key="policy.id", index=True, ondelete="RESTRICT")
|
||||
"""存储策略UUID"""
|
||||
|
||||
# 关系
|
||||
owner: "User" = Relationship()
|
||||
"""上传者"""
|
||||
|
||||
parent: "Object" = Relationship(
|
||||
sa_relationship_kwargs={"foreign_keys": "[UploadSession.parent_id]"}
|
||||
)
|
||||
"""目标父目录"""
|
||||
|
||||
policy: "Policy" = Relationship()
|
||||
"""存储策略"""
|
||||
|
||||
@property
|
||||
def is_expired(self) -> bool:
|
||||
"""会话是否已过期"""
|
||||
return datetime.now() > self.expires_at
|
||||
|
||||
@property
|
||||
def is_complete(self) -> bool:
|
||||
"""上传是否完成"""
|
||||
return self.uploaded_chunks >= self.total_chunks
|
||||
|
||||
|
||||
# ==================== 上传会话相关 DTO ====================
|
||||
|
||||
class CreateUploadSessionRequest(SQLModelBase):
|
||||
"""创建上传会话请求 DTO"""
|
||||
|
||||
file_name: str = Field(max_length=255)
|
||||
"""文件名"""
|
||||
|
||||
file_size: int = Field(ge=0)
|
||||
"""文件大小(字节)"""
|
||||
|
||||
parent_id: UUID
|
||||
"""父目录UUID"""
|
||||
|
||||
policy_id: UUID | None = None
|
||||
"""存储策略UUID,不指定则使用父目录的策略"""
|
||||
|
||||
|
||||
class UploadSessionResponse(SQLModelBase):
|
||||
"""上传会话响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""会话UUID"""
|
||||
|
||||
file_name: str
|
||||
"""原始文件名"""
|
||||
|
||||
file_size: int
|
||||
"""文件总大小(字节)"""
|
||||
|
||||
chunk_size: int
|
||||
"""分片大小(字节)"""
|
||||
|
||||
total_chunks: int
|
||||
"""总分片数"""
|
||||
|
||||
uploaded_chunks: int
|
||||
"""已上传分片数"""
|
||||
|
||||
expires_at: datetime
|
||||
"""过期时间"""
|
||||
|
||||
|
||||
class UploadChunkResponse(SQLModelBase):
|
||||
"""上传分片响应 DTO"""
|
||||
|
||||
uploaded_chunks: int
|
||||
"""已上传分片数"""
|
||||
|
||||
total_chunks: int
|
||||
"""总分片数"""
|
||||
|
||||
is_complete: bool
|
||||
"""是否上传完成"""
|
||||
|
||||
object_id: UUID | None = None
|
||||
"""完成后的文件对象UUID,未完成时为None"""
|
||||
|
||||
|
||||
class CreateFileRequest(SQLModelBase):
|
||||
"""创建空白文件请求 DTO"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""文件名"""
|
||||
|
||||
parent_id: UUID
|
||||
"""父目录UUID"""
|
||||
|
||||
policy_id: UUID | None = None
|
||||
"""存储策略UUID,不指定则使用父目录的策略"""
|
||||
|
||||
|
||||
# ==================== 对象操作相关 DTO ====================
|
||||
|
||||
class ObjectCopyRequest(SQLModelBase):
|
||||
"""复制对象请求 DTO"""
|
||||
|
||||
src_ids: list[UUID]
|
||||
"""源对象UUID列表"""
|
||||
|
||||
dst_id: UUID
|
||||
"""目标文件夹UUID"""
|
||||
|
||||
|
||||
class ObjectRenameRequest(SQLModelBase):
|
||||
"""重命名对象请求 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""对象UUID"""
|
||||
|
||||
new_name: str = Field(max_length=255)
|
||||
"""新名称"""
|
||||
|
||||
|
||||
class ObjectPropertyResponse(SQLModelBase):
|
||||
"""对象基本属性响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""对象UUID"""
|
||||
|
||||
name: str
|
||||
"""对象名称"""
|
||||
|
||||
type: ObjectType
|
||||
"""对象类型"""
|
||||
|
||||
size: int
|
||||
"""文件大小(字节)"""
|
||||
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""修改时间"""
|
||||
|
||||
parent_id: UUID | None
|
||||
"""父目录UUID"""
|
||||
|
||||
|
||||
class ObjectPropertyDetailResponse(ObjectPropertyResponse):
|
||||
"""对象详细属性响应 DTO(继承基本属性)"""
|
||||
|
||||
# 元数据信息
|
||||
mime_type: str | None = None
|
||||
"""MIME类型"""
|
||||
|
||||
width: int | None = None
|
||||
"""图片宽度(像素)"""
|
||||
|
||||
height: int | None = None
|
||||
"""图片高度(像素)"""
|
||||
|
||||
duration: float | None = None
|
||||
"""音视频时长(秒)"""
|
||||
|
||||
checksum_md5: str | None = None
|
||||
"""MD5校验和"""
|
||||
|
||||
# 分享统计
|
||||
share_count: int = 0
|
||||
"""分享次数"""
|
||||
|
||||
total_views: int = 0
|
||||
"""总浏览次数"""
|
||||
|
||||
total_downloads: int = 0
|
||||
"""总下载次数"""
|
||||
|
||||
# 存储信息
|
||||
policy_name: str | None = None
|
||||
"""存储策略名称"""
|
||||
|
||||
reference_count: int = 1
|
||||
"""物理文件引用计数(仅文件有效)"""
|
||||
|
||||
|
||||
# ==================== 管理员文件管理 DTO ====================
|
||||
|
||||
class AdminFileResponse(ObjectResponse):
|
||||
"""管理员文件响应 DTO"""
|
||||
|
||||
owner_id: UUID
|
||||
"""所有者UUID"""
|
||||
|
||||
owner_email: str
|
||||
"""所有者邮箱"""
|
||||
|
||||
policy_name: str
|
||||
"""存储策略名称"""
|
||||
|
||||
is_banned: bool = False
|
||||
"""是否被封禁"""
|
||||
|
||||
banned_at: datetime | None = None
|
||||
"""封禁时间"""
|
||||
|
||||
ban_reason: str | None = None
|
||||
"""封禁原因"""
|
||||
|
||||
@classmethod
|
||||
def from_object(
|
||||
cls,
|
||||
obj: "Object",
|
||||
owner: "User | None",
|
||||
policy: "Policy | None",
|
||||
) -> "AdminFileResponse":
|
||||
"""从 Object ORM 对象构建"""
|
||||
return cls(
|
||||
# ObjectBase 字段
|
||||
**ObjectBase.model_validate(obj, from_attributes=True).model_dump(),
|
||||
# ObjectResponse 字段
|
||||
id=obj.id,
|
||||
thumb=False,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
source_enabled=False,
|
||||
# AdminFileResponse 字段
|
||||
owner_id=obj.owner_id,
|
||||
owner_email=owner.email if owner else "unknown",
|
||||
policy_name=policy.name if policy else "unknown",
|
||||
is_banned=obj.is_banned,
|
||||
banned_at=obj.banned_at,
|
||||
ban_reason=obj.ban_reason,
|
||||
)
|
||||
|
||||
|
||||
class FileBanRequest(SQLModelBase):
|
||||
"""文件封禁请求 DTO"""
|
||||
|
||||
ban: bool = True
|
||||
"""是否封禁"""
|
||||
|
||||
reason: str | None = Field(default=None, max_length=500)
|
||||
"""封禁原因"""
|
||||
|
||||
|
||||
class AdminFileListResponse(SQLModelBase):
|
||||
"""管理员文件列表响应 DTO"""
|
||||
|
||||
files: list[AdminFileResponse] = []
|
||||
"""文件列表"""
|
||||
|
||||
total: int = 0
|
||||
"""总数"""
|
||||
66
sqlmodels/order.py
Normal file
66
sqlmodels/order.py
Normal file
@@ -0,0 +1,66 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class OrderStatus(StrEnum):
|
||||
"""订单状态枚举"""
|
||||
PENDING = "pending"
|
||||
"""待支付"""
|
||||
COMPLETED = "completed"
|
||||
"""已完成"""
|
||||
CANCELLED = "cancelled"
|
||||
"""已取消"""
|
||||
|
||||
|
||||
class OrderType(StrEnum):
|
||||
"""订单类型枚举"""
|
||||
# [TODO] 补充具体订单类型
|
||||
pass
|
||||
|
||||
|
||||
class Order(SQLModelBase, TableBaseMixin):
|
||||
"""订单模型"""
|
||||
|
||||
order_no: str = Field(max_length=255, unique=True, index=True)
|
||||
"""订单号,唯一"""
|
||||
|
||||
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""订单类型 [TODO] 待定义枚举"""
|
||||
|
||||
method: str | None = Field(default=None, max_length=255)
|
||||
"""支付方式"""
|
||||
|
||||
product_id: int | None = Field(default=None)
|
||||
"""商品ID"""
|
||||
|
||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"})
|
||||
"""购买数量"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""商品名称"""
|
||||
|
||||
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""订单价格(分)"""
|
||||
|
||||
status: OrderStatus = Field(default=OrderStatus.PENDING)
|
||||
"""订单状态"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="orders")
|
||||
91
sqlmodels/physical_file.py
Normal file
91
sqlmodels/physical_file.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
物理文件模型
|
||||
|
||||
表示磁盘上的实际文件。多个 Object 可以引用同一个 PhysicalFile,
|
||||
实现文件共享而不复制物理文件。
|
||||
|
||||
引用计数逻辑:
|
||||
- 每个引用此文件的 Object 都会增加引用计数
|
||||
- 当 Object 被删除时,减少引用计数
|
||||
- 只有当引用计数为 0 时,才物理删除文件
|
||||
"""
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
from .policy import Policy
|
||||
|
||||
|
||||
class PhysicalFileBase(SQLModelBase):
|
||||
"""物理文件基础模型"""
|
||||
|
||||
storage_path: str = Field(max_length=512)
|
||||
"""物理存储路径(相对于存储策略根目录)"""
|
||||
|
||||
size: int = Field(default=0, sa_type=BigInteger)
|
||||
"""文件大小(字节)"""
|
||||
|
||||
checksum_md5: str | None = Field(default=None, max_length=32)
|
||||
"""MD5校验和(用于文件去重和完整性校验)"""
|
||||
|
||||
|
||||
class PhysicalFile(PhysicalFileBase, UUIDTableBaseMixin):
|
||||
"""
|
||||
物理文件模型
|
||||
|
||||
表示磁盘上的实际文件。多个 Object 可以引用同一个 PhysicalFile,
|
||||
实现文件共享而不复制物理文件。
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_physical_file_policy_path", "policy_id", "storage_path"),
|
||||
Index("ix_physical_file_checksum", "checksum_md5"),
|
||||
)
|
||||
|
||||
policy_id: UUID = Field(
|
||||
foreign_key="policy.id",
|
||||
index=True,
|
||||
ondelete="RESTRICT",
|
||||
)
|
||||
"""存储策略UUID"""
|
||||
|
||||
reference_count: int = Field(default=1, ge=0)
|
||||
"""引用计数(有多少个 Object 引用此物理文件)"""
|
||||
|
||||
# 关系
|
||||
policy: "Policy" = Relationship()
|
||||
"""存储策略"""
|
||||
|
||||
objects: list["Object"] = Relationship(back_populates="physical_file")
|
||||
"""引用此物理文件的所有逻辑对象"""
|
||||
|
||||
def increment_reference(self) -> int:
|
||||
"""
|
||||
增加引用计数
|
||||
|
||||
:return: 更新后的引用计数
|
||||
"""
|
||||
self.reference_count += 1
|
||||
return self.reference_count
|
||||
|
||||
def decrement_reference(self) -> int:
|
||||
"""
|
||||
减少引用计数
|
||||
|
||||
:return: 更新后的引用计数
|
||||
"""
|
||||
if self.reference_count > 0:
|
||||
self.reference_count -= 1
|
||||
return self.reference_count
|
||||
|
||||
@property
|
||||
def can_be_deleted(self) -> bool:
|
||||
"""是否可以物理删除(引用计数为0)"""
|
||||
return self.reference_count == 0
|
||||
187
sqlmodels/policy.py
Normal file
187
sqlmodels/policy.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from enum import StrEnum
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import UUIDTableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
from .group import Group
|
||||
|
||||
|
||||
class GroupPolicyLink(SQLModelBase, table=True):
|
||||
"""用户组与存储策略的多对多关联表"""
|
||||
|
||||
group_id: UUID = Field(
|
||||
foreign_key="group.id",
|
||||
primary_key=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""用户组UUID"""
|
||||
|
||||
policy_id: UUID = Field(
|
||||
foreign_key="policy.id",
|
||||
primary_key=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""存储策略UUID"""
|
||||
|
||||
|
||||
class PolicyType(StrEnum):
|
||||
LOCAL = "local"
|
||||
S3 = "s3"
|
||||
|
||||
|
||||
class PolicyBase(SQLModelBase):
|
||||
"""存储策略基础字段,供 DTO 和数据库模型共享"""
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""策略名称"""
|
||||
|
||||
type: PolicyType
|
||||
"""存储策略类型"""
|
||||
|
||||
server: str | None = Field(default=None, max_length=255)
|
||||
"""服务器地址(本地策略为绝对路径)"""
|
||||
|
||||
bucket_name: str | None = Field(default=None, max_length=255)
|
||||
"""存储桶名称"""
|
||||
|
||||
is_private: bool = True
|
||||
"""是否为私有空间"""
|
||||
|
||||
base_url: str | None = Field(default=None, max_length=255)
|
||||
"""访问文件的基础URL"""
|
||||
|
||||
access_key: str | None = None
|
||||
"""Access Key"""
|
||||
|
||||
secret_key: str | None = None
|
||||
"""Secret Key"""
|
||||
|
||||
max_size: int = Field(default=0, ge=0)
|
||||
"""允许上传的最大文件尺寸(字节)"""
|
||||
|
||||
auto_rename: bool = False
|
||||
"""是否自动重命名"""
|
||||
|
||||
dir_name_rule: str | None = Field(default=None, max_length=255)
|
||||
"""目录命名规则"""
|
||||
|
||||
file_name_rule: str | None = Field(default=None, max_length=255)
|
||||
"""文件命名规则"""
|
||||
|
||||
is_origin_link_enable: bool = False
|
||||
"""是否开启源链接访问"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
|
||||
class PolicySummary(SQLModelBase):
|
||||
"""策略摘要,用于列表展示"""
|
||||
|
||||
id: UUID
|
||||
"""策略UUID"""
|
||||
|
||||
name: str
|
||||
"""策略名称"""
|
||||
|
||||
type: PolicyType
|
||||
"""策略类型"""
|
||||
|
||||
server: str | None
|
||||
"""服务器地址"""
|
||||
|
||||
max_size: int
|
||||
"""最大文件尺寸"""
|
||||
|
||||
is_private: bool
|
||||
"""是否私有"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
|
||||
class PolicyOptionsBase(SQLModelBase):
|
||||
"""存储策略选项的基础模型"""
|
||||
|
||||
token: str | None = Field(default=None)
|
||||
"""访问令牌"""
|
||||
|
||||
file_type: str | None = Field(default=None)
|
||||
"""允许的文件类型"""
|
||||
|
||||
mimetype: str | None = Field(default=None, max_length=127)
|
||||
"""MIME类型"""
|
||||
|
||||
od_redirect: str | None = Field(default=None, max_length=255)
|
||||
"""OneDrive重定向地址"""
|
||||
|
||||
chunk_size: int = Field(default=52428800, sa_column_kwargs={"server_default": "52428800"})
|
||||
"""分片上传大小(字节),默认50MB"""
|
||||
|
||||
s3_path_style: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
"""是否使用S3路径风格"""
|
||||
|
||||
|
||||
class PolicyOptions(PolicyOptionsBase, UUIDTableBaseMixin):
|
||||
"""存储策略选项模型(与Policy一对一关联)"""
|
||||
|
||||
policy_id: UUID = Field(
|
||||
foreign_key="policy.id",
|
||||
unique=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的策略UUID"""
|
||||
|
||||
# 反向关系
|
||||
policy: "Policy" = Relationship(back_populates="options")
|
||||
"""关联的策略"""
|
||||
|
||||
|
||||
class Policy(PolicyBase, UUIDTableBaseMixin):
|
||||
"""存储策略模型"""
|
||||
|
||||
# 覆盖基类字段以添加数据库专有配置
|
||||
name: str = Field(max_length=255, unique=True)
|
||||
"""策略名称"""
|
||||
|
||||
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||
"""是否为私有空间"""
|
||||
|
||||
max_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""允许上传的最大文件尺寸(字节)"""
|
||||
|
||||
auto_rename: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
"""是否自动重命名"""
|
||||
|
||||
is_origin_link_enable: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||
"""是否开启源链接访问"""
|
||||
|
||||
# 一对一关系:策略选项
|
||||
options: PolicyOptions | None = Relationship(
|
||||
back_populates="policy",
|
||||
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
|
||||
)
|
||||
"""策略的扩展选项"""
|
||||
|
||||
# 关系
|
||||
objects: list["Object"] = Relationship(back_populates="policy")
|
||||
"""策略下的所有对象"""
|
||||
|
||||
# 多对多关系:策略可以被多个用户组使用
|
||||
groups: list["Group"] = Relationship(
|
||||
back_populates="policies",
|
||||
link_model=GroupPolicyLink,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
policy: 'Policy | None' = None,
|
||||
**kwargs
|
||||
):
|
||||
pass
|
||||
23
sqlmodels/redeem.py
Normal file
23
sqlmodels/redeem.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlmodel import Field, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
|
||||
class RedeemType(StrEnum):
|
||||
"""兑换码类型枚举"""
|
||||
# [TODO] 补充具体兑换码类型
|
||||
pass
|
||||
|
||||
|
||||
class Redeem(SQLModelBase, TableBaseMixin):
|
||||
"""兑换码模型"""
|
||||
|
||||
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""兑换码类型 [TODO] 待定义枚举"""
|
||||
product_id: int | None = Field(default=None, description="关联的商品/权益ID")
|
||||
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
|
||||
code: str = Field(unique=True, index=True, description="兑换码,唯一")
|
||||
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")
|
||||
35
sqlmodels/report.py
Normal file
35
sqlmodels/report.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .share import Share
|
||||
|
||||
|
||||
class ReportReason(StrEnum):
|
||||
"""举报原因枚举"""
|
||||
# [TODO] 补充具体举报原因
|
||||
pass
|
||||
|
||||
|
||||
class Report(SQLModelBase, TableBaseMixin):
|
||||
"""举报模型"""
|
||||
|
||||
reason: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""举报原因 [TODO] 待定义枚举"""
|
||||
description: str | None = Field(default=None, max_length=255, description="补充描述")
|
||||
|
||||
# 外键
|
||||
share_id: int = Field(
|
||||
foreign_key="share.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""被举报的分享ID"""
|
||||
|
||||
# 关系
|
||||
share: "Share" = Relationship(back_populates="reports")
|
||||
136
sqlmodels/setting.py
Normal file
136
sqlmodels/setting.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlmodel import UniqueConstraint
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
from .user import UserResponse
|
||||
|
||||
class CaptchaType(StrEnum):
|
||||
"""验证码类型枚举"""
|
||||
DEFAULT = "default"
|
||||
GCAPTCHA = "gcaptcha"
|
||||
CLOUD_FLARE_TURNSTILE = "cloudflare turnstile"
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class SiteConfigResponse(SQLModelBase):
|
||||
"""站点配置响应 DTO"""
|
||||
|
||||
title: str = "DiskNext"
|
||||
"""网站标题"""
|
||||
|
||||
site_notice: str | None = None
|
||||
"""网站公告"""
|
||||
|
||||
user: UserResponse | None = None
|
||||
"""用户信息"""
|
||||
|
||||
logo_light: str | None = None
|
||||
"""网站Logo URL"""
|
||||
|
||||
logo_dark: str | None = None
|
||||
"""网站Logo URL(深色模式)"""
|
||||
|
||||
register_enabled: bool = True
|
||||
"""是否允许注册"""
|
||||
|
||||
login_captcha: bool = False
|
||||
"""登录是否需要验证码"""
|
||||
|
||||
reg_captcha: bool = False
|
||||
"""注册是否需要验证码"""
|
||||
|
||||
forget_captcha: bool = False
|
||||
"""找回密码是否需要验证码"""
|
||||
|
||||
captcha_type: CaptchaType = CaptchaType.DEFAULT
|
||||
"""验证码类型"""
|
||||
|
||||
captcha_key: str | None = None
|
||||
"""验证码 public key(DEFAULT 类型时为 None)"""
|
||||
|
||||
|
||||
# ==================== 管理员设置 DTO ====================
|
||||
|
||||
class SettingItem(SQLModelBase):
|
||||
"""单个设置项 DTO"""
|
||||
|
||||
type: str
|
||||
"""设置类型"""
|
||||
|
||||
name: str
|
||||
"""设置项名称"""
|
||||
|
||||
value: str | None = None
|
||||
"""设置值"""
|
||||
|
||||
|
||||
class SettingsListResponse(SQLModelBase):
|
||||
"""获取设置列表响应 DTO"""
|
||||
|
||||
settings: list[SettingItem]
|
||||
"""设置项列表"""
|
||||
|
||||
total: int
|
||||
"""总数"""
|
||||
|
||||
|
||||
class SettingsUpdateRequest(SQLModelBase):
|
||||
"""更新设置请求 DTO"""
|
||||
|
||||
settings: list[SettingItem]
|
||||
"""要更新的设置项列表"""
|
||||
|
||||
|
||||
class SettingsUpdateResponse(SQLModelBase):
|
||||
"""更新设置响应 DTO"""
|
||||
|
||||
updated: int
|
||||
"""更新的设置项数量"""
|
||||
|
||||
created: int
|
||||
"""新建的设置项数量"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class SettingsType(StrEnum):
|
||||
"""设置类型枚举"""
|
||||
|
||||
ARIA2 = "aria2"
|
||||
AUTH = "auth"
|
||||
AUTHN = "authn"
|
||||
AVATAR = "avatar"
|
||||
BASIC = "basic"
|
||||
CAPTCHA = "captcha"
|
||||
CRON = "cron"
|
||||
FILE_EDIT = "file_edit"
|
||||
LOGIN = "login"
|
||||
MAIL = "mail"
|
||||
MAIL_TEMPLATE = "mail_template"
|
||||
MOBILE = "mobile"
|
||||
OAUTH = "oauth"
|
||||
PATH = "path"
|
||||
PREVIEW = "preview"
|
||||
PWA = "pwa"
|
||||
REGISTER = "register"
|
||||
RETRY = "retry"
|
||||
SHARE = "share"
|
||||
SLAVE = "slave"
|
||||
TASK = "task"
|
||||
THUMB = "thumb"
|
||||
TIMEOUT = "timeout"
|
||||
UPLOAD = "upload"
|
||||
VERSION = "version"
|
||||
VIEW = "view"
|
||||
WOPI = "wopi"
|
||||
|
||||
# 数据库模型
|
||||
class Setting(SettingItem, TableBaseMixin):
|
||||
"""设置模型,继承 SettingItem 中的 name 和 value 字段"""
|
||||
|
||||
__table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),)
|
||||
|
||||
type: SettingsType
|
||||
"""设置类型/分组(覆盖基类的 str 类型为枚举类型)"""
|
||||
220
sqlmodels/share.py
Normal file
220
sqlmodels/share.py
Normal file
@@ -0,0 +1,220 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .report import Report
|
||||
from .object import Object
|
||||
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class ShareBase(SQLModelBase):
|
||||
"""分享基础字段,供 DTO 和数据库模型共享"""
|
||||
|
||||
object_id: UUID
|
||||
"""关联的对象UUID"""
|
||||
|
||||
password: str | None = None
|
||||
"""分享密码"""
|
||||
|
||||
expires: datetime | None = None
|
||||
"""过期时间(NULL为永不过期)"""
|
||||
|
||||
remain_downloads: int | None = None
|
||||
"""剩余下载次数(NULL为不限制)"""
|
||||
|
||||
preview_enabled: bool = True
|
||||
"""是否允许预览"""
|
||||
|
||||
score: int = 0
|
||||
"""兑换此分享所需的积分"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class Share(SQLModelBase, TableBaseMixin):
|
||||
"""分享模型"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("code", name="uq_share_code"),
|
||||
Index("ix_share_source_name", "source_name"),
|
||||
Index("ix_share_user_created", "user_id", "created_at"),
|
||||
Index("ix_share_object", "object_id"),
|
||||
)
|
||||
|
||||
code: str = Field(max_length=64, nullable=False, index=True)
|
||||
"""分享码"""
|
||||
|
||||
password: str | None = Field(default=None, max_length=255)
|
||||
"""分享密码(加密后)"""
|
||||
|
||||
object_id: UUID = Field(
|
||||
foreign_key="object.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的对象UUID"""
|
||||
|
||||
views: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""浏览次数"""
|
||||
|
||||
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""下载次数"""
|
||||
|
||||
remain_downloads: int | None = Field(default=None)
|
||||
"""剩余下载次数 (NULL为不限制)"""
|
||||
|
||||
expires: datetime | None = Field(default=None)
|
||||
"""过期时间 (NULL为永不过期)"""
|
||||
|
||||
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||
"""是否允许预览"""
|
||||
|
||||
source_name: str | None = Field(default=None, max_length=255)
|
||||
"""源名称(冗余字段,便于展示)"""
|
||||
|
||||
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""兑换此分享所需的积分"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""创建分享的用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="shares")
|
||||
"""分享创建者"""
|
||||
|
||||
object: "Object" = Relationship(back_populates="shares")
|
||||
"""关联的对象"""
|
||||
|
||||
reports: list["Report"] = Relationship(
|
||||
back_populates="share",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
"""举报列表"""
|
||||
|
||||
@property
|
||||
def is_dir(self) -> bool:
|
||||
"""是否为目录分享(向后兼容属性)"""
|
||||
from .object import ObjectType
|
||||
return self.object.type == ObjectType.FOLDER if self.object else False
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class ShareCreateRequest(ShareBase):
|
||||
"""创建分享请求 DTO,继承 ShareBase 中的所有字段"""
|
||||
pass
|
||||
|
||||
|
||||
class ShareResponse(SQLModelBase):
|
||||
"""分享响应 DTO"""
|
||||
|
||||
id: int
|
||||
"""分享ID"""
|
||||
|
||||
code: str
|
||||
"""分享码"""
|
||||
|
||||
object_id: UUID
|
||||
"""关联对象UUID"""
|
||||
|
||||
source_name: str | None
|
||||
"""源名称"""
|
||||
|
||||
views: int
|
||||
"""浏览次数"""
|
||||
|
||||
downloads: int
|
||||
"""下载次数"""
|
||||
|
||||
remain_downloads: int | None
|
||||
"""剩余下载次数"""
|
||||
|
||||
expires: datetime | None
|
||||
"""过期时间"""
|
||||
|
||||
preview_enabled: bool
|
||||
"""是否允许预览"""
|
||||
|
||||
score: int
|
||||
"""积分"""
|
||||
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
is_expired: bool
|
||||
"""是否已过期"""
|
||||
|
||||
has_password: bool
|
||||
"""是否有密码"""
|
||||
|
||||
|
||||
class ShareListItemBase(SQLModelBase):
|
||||
"""分享列表项基础字段"""
|
||||
|
||||
id: int
|
||||
"""分享ID"""
|
||||
|
||||
code: str
|
||||
"""分享码"""
|
||||
|
||||
views: int
|
||||
"""浏览次数"""
|
||||
|
||||
downloads: int
|
||||
"""下载次数"""
|
||||
|
||||
remain_downloads: int | None
|
||||
"""剩余下载次数"""
|
||||
|
||||
expires: datetime | None
|
||||
"""过期时间"""
|
||||
|
||||
preview_enabled: bool
|
||||
"""是否允许预览"""
|
||||
|
||||
score: int
|
||||
"""积分"""
|
||||
|
||||
user_id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
|
||||
class AdminShareListItem(ShareListItemBase):
|
||||
"""管理员分享列表项 DTO,添加关联字段"""
|
||||
|
||||
username: str | None
|
||||
"""用户名"""
|
||||
|
||||
object_name: str | None
|
||||
"""对象名称"""
|
||||
|
||||
@classmethod
|
||||
def from_share(
|
||||
cls,
|
||||
share: "Share",
|
||||
user: "User | None",
|
||||
obj: "Object | None",
|
||||
) -> "AdminShareListItem":
|
||||
"""从 Share ORM 对象构建"""
|
||||
return cls(
|
||||
**ShareListItemBase.model_validate(share, from_attributes=True).model_dump(),
|
||||
username=user.email if user else None,
|
||||
object_name=obj.name if obj else None,
|
||||
)
|
||||
37
sqlmodels/source_link.py
Normal file
37
sqlmodels/source_link.py
Normal file
@@ -0,0 +1,37 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
|
||||
|
||||
class SourceLink(SQLModelBase, TableBaseMixin):
|
||||
"""链接模型"""
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_sourcelink_object_name", "object_id", "name"),
|
||||
)
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""链接名称"""
|
||||
|
||||
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""通过此链接的下载次数"""
|
||||
|
||||
# 外键
|
||||
object_id: UUID = Field(
|
||||
foreign_key="object.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的对象UUID(必须是文件类型)"""
|
||||
|
||||
# 关系
|
||||
object: "Object" = Relationship(back_populates="source_links")
|
||||
"""关联的对象"""
|
||||
31
sqlmodels/storage_pack.py
Normal file
31
sqlmodels/storage_pack.py
Normal file
@@ -0,0 +1,31 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class StoragePack(SQLModelBase, TableBaseMixin):
|
||||
"""容量包模型"""
|
||||
|
||||
name: str = Field(max_length=255, description="容量包名称")
|
||||
active_time: datetime | None = Field(default=None, description="激活时间")
|
||||
expired_time: datetime | None = Field(default=None, index=True, description="过期时间")
|
||||
size: int = Field(description="容量包大小(字节)")
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="storage_packs")
|
||||
50
sqlmodels/tag.py
Normal file
50
sqlmodels/tag.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
class TagType(StrEnum):
|
||||
"""标签类型枚举"""
|
||||
MANUAL = "manual"
|
||||
"""手动标签"""
|
||||
AUTOMATIC = "automatic"
|
||||
"""自动标签"""
|
||||
|
||||
|
||||
class Tag(SQLModelBase, TableBaseMixin):
|
||||
"""标签模型"""
|
||||
|
||||
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
|
||||
|
||||
name: str = Field(max_length=255)
|
||||
"""标签名称"""
|
||||
|
||||
icon: str | None = Field(default=None, max_length=255)
|
||||
"""标签图标"""
|
||||
|
||||
color: str | None = Field(default=None, max_length=255)
|
||||
"""标签颜色"""
|
||||
|
||||
type: TagType = Field(default=TagType.MANUAL)
|
||||
"""标签类型"""
|
||||
expression: str | None = Field(default=None, description="自动标签的匹配表达式")
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="tags")
|
||||
153
sqlmodels/task.py
Normal file
153
sqlmodels/task.py
Normal file
@@ -0,0 +1,153 @@
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import Field, Relationship, CheckConstraint, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .download import Download
|
||||
from .user import User
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
|
||||
"""任务状态枚举"""
|
||||
QUEUED = "queued"
|
||||
"""排队中"""
|
||||
RUNNING = "running"
|
||||
"""处理中"""
|
||||
COMPLETED = "completed"
|
||||
"""已完成"""
|
||||
ERROR = "error"
|
||||
"""错误"""
|
||||
|
||||
|
||||
class TaskType(StrEnum):
|
||||
"""任务类型枚举"""
|
||||
# [TODO] 补充具体任务类型
|
||||
pass
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
|
||||
class TaskSummaryBase(SQLModelBase):
|
||||
"""任务摘要基础字段"""
|
||||
|
||||
id: int
|
||||
"""任务ID"""
|
||||
|
||||
type: int
|
||||
"""任务类型"""
|
||||
|
||||
status: TaskStatus
|
||||
"""任务状态"""
|
||||
|
||||
progress: int
|
||||
"""进度(0-100)"""
|
||||
|
||||
error: str | None
|
||||
"""错误信息"""
|
||||
|
||||
user_id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
created_at: datetime
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""更新时间"""
|
||||
|
||||
|
||||
class TaskSummary(TaskSummaryBase):
|
||||
"""任务摘要,用于管理员列表展示"""
|
||||
|
||||
username: str | None
|
||||
"""用户名"""
|
||||
|
||||
@classmethod
|
||||
def from_task(cls, task: "Task", user: "User | None") -> "TaskSummary":
|
||||
"""从 Task ORM 对象构建"""
|
||||
return cls(
|
||||
**TaskSummaryBase.model_validate(task, from_attributes=True).model_dump(),
|
||||
username=user.email if user else None,
|
||||
)
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
|
||||
class TaskPropsBase(SQLModelBase):
|
||||
"""任务属性基础模型"""
|
||||
|
||||
source_path: str | None = None
|
||||
"""源路径"""
|
||||
|
||||
dest_path: str | None = None
|
||||
"""目标路径"""
|
||||
|
||||
file_ids: str | None = None
|
||||
"""文件ID列表(逗号分隔)"""
|
||||
|
||||
# [TODO] 根据业务需求补充更多字段
|
||||
|
||||
|
||||
class TaskProps(TaskPropsBase, TableBaseMixin):
|
||||
"""任务属性模型(与Task一对一关联)"""
|
||||
|
||||
task_id: int = Field(
|
||||
foreign_key="task.id",
|
||||
primary_key=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""关联的任务ID"""
|
||||
|
||||
# 反向关系
|
||||
task: "Task" = Relationship(back_populates="props")
|
||||
"""关联的任务"""
|
||||
|
||||
|
||||
class Task(SQLModelBase, TableBaseMixin):
|
||||
"""任务模型"""
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("progress BETWEEN 0 AND 100", name="ck_task_progress_range"),
|
||||
Index("ix_task_status", "status"),
|
||||
Index("ix_task_user_status", "user_id", "status"),
|
||||
)
|
||||
|
||||
status: TaskStatus = Field(default=TaskStatus.QUEUED)
|
||||
"""任务状态"""
|
||||
|
||||
type: int = Field(default=0)
|
||||
"""任务类型 [TODO] 待定义枚举"""
|
||||
|
||||
progress: int = Field(default=0, ge=0, le=100)
|
||||
"""任务进度(0-100)"""
|
||||
|
||||
error: str | None = Field(default=None)
|
||||
"""错误信息"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
# 关系
|
||||
props: TaskProps | None = Relationship(
|
||||
back_populates="task",
|
||||
sa_relationship_kwargs={"uselist": False, "cascade": "all, delete-orphan"},
|
||||
)
|
||||
"""任务属性"""
|
||||
|
||||
user: "User" = Relationship(back_populates="tasks")
|
||||
"""所属用户"""
|
||||
|
||||
downloads: list["Download"] = Relationship(back_populates="task")
|
||||
"""关联的下载任务"""
|
||||
258
sqlmodels/uri.py
Normal file
258
sqlmodels/uri.py
Normal file
@@ -0,0 +1,258 @@
|
||||
|
||||
from enum import StrEnum
|
||||
from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote
|
||||
|
||||
from .base import SQLModelBase
|
||||
|
||||
|
||||
class FileSystemNamespace(StrEnum):
|
||||
"""文件系统命名空间"""
|
||||
MY = "my"
|
||||
"""用户个人空间"""
|
||||
|
||||
SHARE = "share"
|
||||
"""分享空间"""
|
||||
|
||||
TRASH = "trash"
|
||||
"""回收站"""
|
||||
|
||||
|
||||
class DiskNextURI(SQLModelBase):
|
||||
"""
|
||||
DiskNext 文件 URI
|
||||
|
||||
URI 格式: disknext://[fs_id[:password]@]namespace[/path][?query]
|
||||
|
||||
fs_id 可省略:
|
||||
- my/trash 命名空间省略时默认当前用户
|
||||
- share 命名空间必须提供 fs_id(Share.code)
|
||||
"""
|
||||
|
||||
fs_id: str | None = None
|
||||
"""文件系统标识符,可省略"""
|
||||
|
||||
namespace: FileSystemNamespace
|
||||
"""命名空间"""
|
||||
|
||||
path: str = "/"
|
||||
"""路径"""
|
||||
|
||||
password: str | None = None
|
||||
"""访问密码(用于有密码的分享)"""
|
||||
|
||||
query: dict[str, str] | None = None
|
||||
"""查询参数"""
|
||||
|
||||
# === 属性 ===
|
||||
|
||||
@property
|
||||
def path_parts(self) -> list[str]:
|
||||
"""路径分割为列表(过滤空串)"""
|
||||
return [p for p in self.path.split("/") if p]
|
||||
|
||||
@property
|
||||
def is_root(self) -> bool:
|
||||
"""是否指向根目录"""
|
||||
return self.path.strip("/") == ""
|
||||
|
||||
# === 核心方法 ===
|
||||
|
||||
def id(self, default_id: str | None = None) -> str | None:
|
||||
"""
|
||||
获取 fs_id,省略时返回 default_id
|
||||
|
||||
参考 Cloudreve URI.ID(defaultUid) 方法
|
||||
|
||||
:param default_id: 默认值(通常为当前用户 ID)
|
||||
:return: fs_id 或 default_id
|
||||
"""
|
||||
return self.fs_id if self.fs_id else default_id
|
||||
|
||||
# === 类方法 ===
|
||||
|
||||
@classmethod
|
||||
def parse(cls, uri: str) -> "DiskNextURI":
|
||||
"""
|
||||
解析 URI 字符串
|
||||
|
||||
实现方式:替换 disknext:// 为 http:// 后用 urllib.parse.urlparse 解析
|
||||
- hostname → namespace
|
||||
- username → fs_id
|
||||
- password → password
|
||||
- path → path
|
||||
- query → query dict
|
||||
|
||||
:param uri: URI 字符串,如 "disknext://my/docs/readme.md"
|
||||
:return: DiskNextURI 实例
|
||||
:raises ValueError: URI 格式无效
|
||||
"""
|
||||
if not uri.startswith("disknext://"):
|
||||
raise ValueError(f"URI 必须以 disknext:// 开头: {uri}")
|
||||
|
||||
# 替换协议为 http:// 以利用 urllib.parse 解析
|
||||
http_uri = "http://" + uri[len("disknext://"):]
|
||||
parsed = urlparse(http_uri)
|
||||
|
||||
# 解析 namespace
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise ValueError(f"URI 缺少命名空间: {uri}")
|
||||
|
||||
try:
|
||||
namespace = FileSystemNamespace(hostname)
|
||||
except ValueError:
|
||||
raise ValueError(f"无效的命名空间 '{hostname}',有效值: {[e.value for e in FileSystemNamespace]}")
|
||||
|
||||
# 解析 fs_id 和 password
|
||||
fs_id = unquote(parsed.username) if parsed.username else None
|
||||
password = unquote(parsed.password) if parsed.password else None
|
||||
|
||||
# 解析 path
|
||||
path = unquote(parsed.path) if parsed.path else "/"
|
||||
if not path:
|
||||
path = "/"
|
||||
|
||||
# 解析 query
|
||||
query: dict[str, str] | None = None
|
||||
if parsed.query:
|
||||
raw_query = parse_qs(parsed.query, keep_blank_values=True)
|
||||
query = {k: v[0] for k, v in raw_query.items()}
|
||||
|
||||
return cls(
|
||||
fs_id=fs_id,
|
||||
namespace=namespace,
|
||||
path=path,
|
||||
password=password,
|
||||
query=query,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
namespace: FileSystemNamespace,
|
||||
path: str = "/",
|
||||
fs_id: str | None = None,
|
||||
password: str | None = None,
|
||||
) -> "DiskNextURI":
|
||||
"""
|
||||
构建 URI 实例
|
||||
|
||||
:param namespace: 命名空间
|
||||
:param path: 路径
|
||||
:param fs_id: 文件系统标识符
|
||||
:param password: 访问密码
|
||||
:return: DiskNextURI 实例
|
||||
"""
|
||||
# 确保 path 以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
return cls(
|
||||
fs_id=fs_id,
|
||||
namespace=namespace,
|
||||
path=path,
|
||||
password=password,
|
||||
)
|
||||
|
||||
# === 实例方法 ===
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""
|
||||
序列化为 URI 字符串
|
||||
|
||||
:return: URI 字符串,如 "disknext://my/docs/readme.md"
|
||||
"""
|
||||
result = "disknext://"
|
||||
|
||||
# fs_id 和 password
|
||||
if self.fs_id:
|
||||
result += quote(self.fs_id, safe="")
|
||||
if self.password:
|
||||
result += ":" + quote(self.password, safe="")
|
||||
result += "@"
|
||||
|
||||
# namespace
|
||||
result += self.namespace.value
|
||||
|
||||
# path
|
||||
result += self.path
|
||||
|
||||
# query
|
||||
if self.query:
|
||||
result += "?" + urlencode(self.query)
|
||||
|
||||
return result
|
||||
|
||||
def join(self, *elements: str) -> "DiskNextURI":
|
||||
"""
|
||||
拼接路径元素,返回新 URI
|
||||
|
||||
:param elements: 路径元素
|
||||
:return: 新的 DiskNextURI 实例
|
||||
"""
|
||||
base = self.path.rstrip("/")
|
||||
for element in elements:
|
||||
element = element.strip("/")
|
||||
if element:
|
||||
base += "/" + element
|
||||
|
||||
if not base:
|
||||
base = "/"
|
||||
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path=base,
|
||||
password=self.password,
|
||||
query=self.query,
|
||||
)
|
||||
|
||||
def dir_uri(self) -> "DiskNextURI":
|
||||
"""
|
||||
返回父目录的 URI
|
||||
|
||||
:return: 父目录的 DiskNextURI 实例
|
||||
"""
|
||||
parts = self.path_parts
|
||||
if not parts:
|
||||
# 已经是根目录
|
||||
return self.root()
|
||||
|
||||
parent_path = "/" + "/".join(parts[:-1])
|
||||
if not parent_path.endswith("/"):
|
||||
parent_path += "/"
|
||||
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path=parent_path,
|
||||
password=self.password,
|
||||
)
|
||||
|
||||
def root(self) -> "DiskNextURI":
|
||||
"""
|
||||
返回根目录的 URI(保留 namespace 和 fs_id)
|
||||
|
||||
:return: 根目录的 DiskNextURI 实例
|
||||
"""
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path="/",
|
||||
password=self.password,
|
||||
)
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
返回路径的最后一段(文件名或目录名)
|
||||
|
||||
:return: 文件名或目录名,根目录返回空字符串
|
||||
"""
|
||||
parts = self.path_parts
|
||||
return parts[-1] if parts else ""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_string()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DiskNextURI({self.to_string()!r})"
|
||||
554
sqlmodels/user.py
Normal file
554
sqlmodels/user.py
Normal file
@@ -0,0 +1,554 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Literal, TYPE_CHECKING, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import BinaryExpression, ClauseElement, and_
|
||||
from sqlmodel import Field, Relationship
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .model_base import ResponseBase
|
||||
from .mixin import UUIDTableBaseMixin, TableViewRequest, ListResponse
|
||||
|
||||
T = TypeVar("T", bound="User")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .group import Group
|
||||
from .download import Download
|
||||
from .object import Object
|
||||
from .order import Order
|
||||
from .share import Share
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag
|
||||
from .task import Task
|
||||
from .user_authn import UserAuthn
|
||||
from .webdav import WebDAV
|
||||
|
||||
class AvatarType(StrEnum):
|
||||
"""头像类型枚举"""
|
||||
|
||||
DEFAULT = "default"
|
||||
GRAVATAR = "gravatar"
|
||||
FILE = "file"
|
||||
|
||||
class ThemeType(StrEnum):
|
||||
"""主题类型枚举"""
|
||||
|
||||
LIGHT = "light"
|
||||
DARK = "dark"
|
||||
SYSTEM = "system"
|
||||
|
||||
class UserStatus(StrEnum):
|
||||
"""用户状态枚举"""
|
||||
|
||||
ACTIVE = "active"
|
||||
ADMIN_BANNED = "admin_banned"
|
||||
SYSTEM_BANNED = "system_banned"
|
||||
|
||||
|
||||
# ==================== 筛选参数 ====================
|
||||
|
||||
class UserFilterParams(SQLModelBase):
|
||||
"""
|
||||
用户过滤参数
|
||||
|
||||
用于管理员搜索用户列表,支持用户组、用户名、昵称、状态等过滤。
|
||||
"""
|
||||
group_id: UUID | None = None
|
||||
"""按用户组UUID筛选"""
|
||||
|
||||
email_contains: str | None = Field(default=None, max_length=50)
|
||||
"""邮箱包含(不区分大小写的模糊搜索)"""
|
||||
|
||||
nickname_contains: str | None = Field(default=None, max_length=50)
|
||||
"""昵称包含(不区分大小写的模糊搜索)"""
|
||||
|
||||
status: UserStatus | None = None
|
||||
"""按用户状态筛选"""
|
||||
|
||||
|
||||
# ==================== Base 模型 ====================
|
||||
|
||||
class UserBase(SQLModelBase):
|
||||
"""用户基础字段,供数据库模型和 DTO 共享"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
|
||||
score: int = 0
|
||||
"""用户积分"""
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class LoginRequest(SQLModelBase):
|
||||
"""登录请求 DTO"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
|
||||
two_fa_code: int | None = Field(min_length=6, max_length=6)
|
||||
"""两步验证代码"""
|
||||
|
||||
|
||||
class RegisterRequest(SQLModelBase):
|
||||
"""注册请求 DTO"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
|
||||
|
||||
class BatchDeleteRequest(SQLModelBase):
|
||||
"""批量删除请求 DTO"""
|
||||
|
||||
ids: list[UUID]
|
||||
"""待删除 UUID 列表"""
|
||||
|
||||
|
||||
class RefreshTokenRequest(SQLModelBase):
|
||||
"""刷新令牌请求 DTO"""
|
||||
|
||||
refresh_token: str
|
||||
"""刷新令牌"""
|
||||
|
||||
|
||||
class WebAuthnInfo(SQLModelBase):
|
||||
"""WebAuthn 信息 DTO"""
|
||||
|
||||
credential_id: str
|
||||
"""凭证 ID"""
|
||||
|
||||
credential_public_key: str
|
||||
"""凭证公钥"""
|
||||
|
||||
sign_count: int
|
||||
"""签名计数器"""
|
||||
|
||||
credential_device_type: bool
|
||||
"""是否为平台认证器"""
|
||||
|
||||
credential_backed_up: bool
|
||||
"""凭证是否已备份"""
|
||||
|
||||
transports: list[str]
|
||||
"""支持的传输方式"""
|
||||
|
||||
class AccessTokenBase(BaseModel):
|
||||
"""访问令牌响应 DTO"""
|
||||
|
||||
access_expires: datetime
|
||||
"""访问令牌过期时间"""
|
||||
|
||||
access_token: str
|
||||
"""访问令牌"""
|
||||
|
||||
class RefreshTokenBase(BaseModel):
|
||||
"""刷新令牌响应DTO"""
|
||||
|
||||
refresh_expires: datetime
|
||||
"""刷新令牌过期时间"""
|
||||
|
||||
refresh_token: str
|
||||
"""刷新令牌"""
|
||||
|
||||
|
||||
class TokenResponse(ResponseBase, AccessTokenBase, RefreshTokenBase):
|
||||
"""令牌响应 DTO"""
|
||||
|
||||
|
||||
class UserResponse(ResponseBase):
|
||||
"""用户响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""用户昵称"""
|
||||
|
||||
avatar: Literal["default", "gravatar", "file"] = "default"
|
||||
"""头像类型"""
|
||||
|
||||
created_at: datetime
|
||||
"""用户创建时间"""
|
||||
|
||||
anonymous: bool = False
|
||||
"""是否为匿名用户"""
|
||||
|
||||
group: "GroupResponse | None" = None
|
||||
"""用户所属用户组"""
|
||||
|
||||
tags: list[str] = []
|
||||
"""用户标签列表"""
|
||||
|
||||
class UserStorageResponse(SQLModelBase):
|
||||
"""用户存储信息 DTO"""
|
||||
|
||||
used: int
|
||||
"""已用存储空间(字节)"""
|
||||
|
||||
free: int
|
||||
"""剩余存储空间(字节)"""
|
||||
|
||||
total: int
|
||||
"""总存储空间(字节)"""
|
||||
|
||||
|
||||
class UserPublic(UserBase):
|
||||
"""用户公开信息 DTO,用于 API 响应"""
|
||||
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""昵称"""
|
||||
|
||||
storage: int = 0
|
||||
"""已用存储空间(字节)"""
|
||||
|
||||
avatar: str | None = None
|
||||
"""头像地址"""
|
||||
|
||||
group_expires: datetime | None = None
|
||||
"""用户组过期时间"""
|
||||
|
||||
group_id: UUID | None = None
|
||||
"""所属用户组UUID"""
|
||||
|
||||
group_name: str | None = None
|
||||
"""用户组名称"""
|
||||
|
||||
two_factor: str | None = None
|
||||
"""两步验证密钥(32位字符串,null 表示未启用)"""
|
||||
|
||||
created_at: datetime | None = None
|
||||
"""创建时间"""
|
||||
|
||||
updated_at: datetime | None = None
|
||||
"""更新时间"""
|
||||
|
||||
|
||||
class UserSettingResponse(SQLModelBase):
|
||||
"""用户设置响应 DTO"""
|
||||
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""昵称"""
|
||||
|
||||
created_at: datetime
|
||||
"""用户注册时间"""
|
||||
|
||||
group_name: str
|
||||
"""用户所属用户组名称"""
|
||||
|
||||
language: str
|
||||
"""语言偏好"""
|
||||
|
||||
timezone: int
|
||||
"""时区"""
|
||||
|
||||
authn: "list[AuthnResponse] | None" = None
|
||||
"""认证信息"""
|
||||
|
||||
group_expires: datetime | None = None
|
||||
"""用户组过期时间"""
|
||||
|
||||
two_factor: bool = False
|
||||
"""是否启用两步验证"""
|
||||
|
||||
|
||||
# ==================== 管理员用户管理 DTO ====================
|
||||
|
||||
class UserAdminCreateRequest(SQLModelBase):
|
||||
"""管理员创建用户请求 DTO"""
|
||||
|
||||
email: str = Field(max_length=50)
|
||||
"""用户邮箱"""
|
||||
|
||||
password: str
|
||||
"""用户密码(明文,由服务端加密)"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
|
||||
group_id: UUID
|
||||
"""所属用户组UUID"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
|
||||
|
||||
class UserAdminUpdateRequest(SQLModelBase):
|
||||
"""管理员更新用户请求 DTO"""
|
||||
|
||||
email: str = Field(max_length=50)
|
||||
"""邮箱"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
|
||||
password: str | None = None
|
||||
"""新密码(为空则不修改)"""
|
||||
|
||||
group_id: UUID | None = None
|
||||
"""用户组UUID"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
|
||||
score: int | None = Field(default=None, ge=0)
|
||||
"""积分"""
|
||||
|
||||
storage: int | None = Field(default=None, ge=0)
|
||||
"""已用存储空间(用于手动校准)"""
|
||||
|
||||
group_expires: datetime | None = None
|
||||
"""用户组过期时间"""
|
||||
|
||||
two_factor: str | None = None
|
||||
"""两步验证密钥(32位字符串,传 null 可清除,不传则不修改)"""
|
||||
|
||||
|
||||
class UserCalibrateResponse(SQLModelBase):
|
||||
"""用户存储校准响应 DTO"""
|
||||
|
||||
user_id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
previous_storage: int
|
||||
"""校准前的存储空间(字节)"""
|
||||
|
||||
current_storage: int
|
||||
"""校准后的存储空间(字节)"""
|
||||
|
||||
difference: int
|
||||
"""差异值(字节)"""
|
||||
|
||||
file_count: int
|
||||
"""实际文件数量"""
|
||||
|
||||
|
||||
class UserAdminDetailResponse(UserPublic):
|
||||
"""管理员用户详情响应 DTO"""
|
||||
|
||||
two_factor_enabled: bool = False
|
||||
"""是否启用两步验证"""
|
||||
|
||||
file_count: int = 0
|
||||
"""文件数量"""
|
||||
|
||||
share_count: int = 0
|
||||
"""分享数量"""
|
||||
|
||||
task_count: int = 0
|
||||
"""任务数量"""
|
||||
|
||||
|
||||
# 前向引用导入
|
||||
from .group import GroupResponse # noqa: E402
|
||||
from .user_authn import AuthnResponse # noqa: E402
|
||||
|
||||
# 更新前向引用
|
||||
UserResponse.model_rebuild()
|
||||
UserSettingResponse.model_rebuild()
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class User(UserBase, UUIDTableBaseMixin):
|
||||
"""用户模型"""
|
||||
|
||||
email: str = Field(max_length=50, unique=True, index=True)
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""用于公开展示的名字,可使用真实姓名或昵称"""
|
||||
|
||||
password: str = Field(max_length=255)
|
||||
"""用户密码(加密后)"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
|
||||
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
|
||||
"""已用存储空间(字节)"""
|
||||
|
||||
two_factor: str | None = Field(default=None, min_length=32, max_length=32)
|
||||
"""两步验证密钥"""
|
||||
|
||||
avatar: str = Field(default="default", max_length=255)
|
||||
"""头像地址"""
|
||||
|
||||
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, ge=0)
|
||||
"""用户积分"""
|
||||
|
||||
group_expires: datetime | None = Field(default=None)
|
||||
"""当前用户组过期时间"""
|
||||
|
||||
# Option 相关字段
|
||||
# theme: ThemeType = Field(default=ThemeType.SYSTEM)
|
||||
# """主题类型: light/dark/system"""
|
||||
|
||||
language: str = Field(default="zh-CN", max_length=5)
|
||||
"""语言偏好"""
|
||||
|
||||
timezone: int = Field(default=8, ge=-12, le=12)
|
||||
"""时区,UTC 偏移小时数"""
|
||||
|
||||
# 外键
|
||||
group_id: UUID = Field(
|
||||
foreign_key="group.id",
|
||||
index=True,
|
||||
ondelete="RESTRICT"
|
||||
)
|
||||
"""所属用户组UUID"""
|
||||
|
||||
previous_group_id: UUID | None = Field(
|
||||
default=None,
|
||||
foreign_key="group.id",
|
||||
ondelete="SET NULL"
|
||||
)
|
||||
"""之前的用户组UUID(用于过期后恢复)"""
|
||||
|
||||
|
||||
# 关系
|
||||
group: "Group" = Relationship(
|
||||
back_populates="users",
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "User.group_id"
|
||||
}
|
||||
)
|
||||
previous_group: "Group" = Relationship(
|
||||
back_populates="previous_users",
|
||||
sa_relationship_kwargs={
|
||||
"foreign_keys": "User.previous_group_id"
|
||||
}
|
||||
)
|
||||
|
||||
downloads: list["Download"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
objects: list["Object"] = Relationship(
|
||||
back_populates="owner",
|
||||
sa_relationship_kwargs={
|
||||
"cascade": "all, delete-orphan",
|
||||
"foreign_keys": "[Object.owner_id]"
|
||||
}
|
||||
)
|
||||
"""用户的所有对象(文件和目录)"""
|
||||
orders: list["Order"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
shares: list["Share"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
storage_packs: list["StoragePack"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
tags: list["Tag"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
tasks: list["Task"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
webdavs: list["WebDAV"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
authns: list["UserAuthn"] = Relationship(
|
||||
back_populates="user",
|
||||
sa_relationship_kwargs={"cascade": "all, delete-orphan"}
|
||||
)
|
||||
|
||||
def to_public(self) -> "UserPublic":
|
||||
"""转换为公开 DTO,排除敏感字段。需要预加载 group 关系。"""
|
||||
data = UserPublic.model_validate(self)
|
||||
data.group_name = self.group.name
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def get_with_count(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
*,
|
||||
filter_params: 'UserFilterParams | None' = None,
|
||||
join: type[T] | tuple[type[T], ClauseElement] | None = None,
|
||||
options: list | None = None,
|
||||
load: RelationshipInfo | None = None,
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
table_view: TableViewRequest | None = None,
|
||||
) -> 'ListResponse[T]':
|
||||
"""
|
||||
获取分页用户列表及总数,支持用户过滤参数
|
||||
|
||||
:param filter_params: UserFilterParams 过滤参数对象(用户组、用户名、昵称、状态等)
|
||||
:param 其他参数: 继承自 UUIDTableBaseMixin.get_with_count()
|
||||
"""
|
||||
# 构建过滤条件
|
||||
merged_condition = condition
|
||||
if filter_params is not None:
|
||||
filter_conditions: list[BinaryExpression] = []
|
||||
|
||||
if filter_params.group_id is not None:
|
||||
filter_conditions.append(cls.group_id == filter_params.group_id)
|
||||
|
||||
if filter_params.email_contains is not None:
|
||||
filter_conditions.append(cls.email.ilike(f"%{filter_params.email_contains}%"))
|
||||
|
||||
if filter_params.nickname_contains is not None:
|
||||
filter_conditions.append(cls.nickname.ilike(f"%{filter_params.nickname_contains}%"))
|
||||
|
||||
if filter_params.status is not None:
|
||||
filter_conditions.append(cls.status == filter_params.status)
|
||||
|
||||
if filter_conditions:
|
||||
combined_filter = and_(*filter_conditions)
|
||||
if merged_condition is not None:
|
||||
merged_condition = and_(merged_condition, combined_filter)
|
||||
else:
|
||||
merged_condition = combined_filter
|
||||
|
||||
return await super().get_with_count(
|
||||
session,
|
||||
merged_condition,
|
||||
join=join,
|
||||
options=options,
|
||||
load=load,
|
||||
order_by=order_by,
|
||||
filter=filter,
|
||||
table_view=table_view,
|
||||
)
|
||||
|
||||
61
sqlmodels/user_authn.py
Normal file
61
sqlmodels/user_authn.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Column, Text
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
|
||||
# ==================== DTO 模型 ====================
|
||||
|
||||
class AuthnResponse(SQLModelBase):
|
||||
"""WebAuthn 响应 DTO"""
|
||||
|
||||
id: str
|
||||
"""凭证ID"""
|
||||
|
||||
fingerprint: str
|
||||
"""凭证指纹"""
|
||||
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class UserAuthn(SQLModelBase, TableBaseMixin):
|
||||
"""用户 WebAuthn 凭证模型,与 User 为多对一关系"""
|
||||
|
||||
credential_id: str = Field(max_length=255, unique=True, index=True)
|
||||
"""凭证 ID,Base64 编码"""
|
||||
|
||||
credential_public_key: str = Field(sa_column=Column(Text))
|
||||
"""凭证公钥,Base64 编码"""
|
||||
|
||||
sign_count: int = Field(default=0, ge=0)
|
||||
"""签名计数器,用于防重放攻击"""
|
||||
|
||||
credential_device_type: str = Field(max_length=32)
|
||||
"""凭证设备类型:'single_device' 或 'multi_device'"""
|
||||
|
||||
credential_backed_up: bool = Field(default=False)
|
||||
"""凭证是否已备份"""
|
||||
|
||||
transports: str | None = Field(default=None, max_length=255)
|
||||
"""支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'"""
|
||||
|
||||
name: str | None = Field(default=None, max_length=100)
|
||||
"""用户自定义的凭证名称,便于识别"""
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="authns")
|
||||
33
sqlmodels/webdav.py
Normal file
33
sqlmodels/webdav.py
Normal file
@@ -0,0 +1,33 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class WebDAV(SQLModelBase, TableBaseMixin):
|
||||
"""WebDAV账户模型"""
|
||||
|
||||
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
|
||||
|
||||
name: str = Field(max_length=255, description="WebDAV账户名")
|
||||
password: str = Field(max_length=255, description="WebDAV密码")
|
||||
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
|
||||
readonly: bool = Field(default=False, description="是否只读")
|
||||
use_proxy: bool = Field(default=False, description="是否使用代理下载")
|
||||
|
||||
# 外键
|
||||
user_id: UUID = Field(
|
||||
foreign_key="user.id",
|
||||
index=True,
|
||||
ondelete="CASCADE"
|
||||
)
|
||||
"""所属用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="webdavs")
|
||||
Reference in New Issue
Block a user