feat: add models for physical files, policies, and user management
- Implement PhysicalFile model to manage physical file references and reference counting. - Create Policy model with associated options and group links for storage policies. - Introduce Redeem and Report models for handling redeem codes and reports. - Add Settings model for site configuration and user settings management. - Develop Share model for sharing objects with unique codes and associated metadata. - Implement SourceLink model for managing download links associated with objects. - Create StoragePack model for managing user storage packages. - Add Tag model for user-defined tags with manual and automatic types. - Implement Task model for managing background tasks with status tracking. - Develop User model with comprehensive user management features including authentication. - Introduce UserAuthn model for managing WebAuthn credentials. - Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
@@ -3,7 +3,9 @@
|
||||
"allow": [
|
||||
"Bash(git rev-parse:*)",
|
||||
"Bash(findstr:*)",
|
||||
"Bash(find:*)"
|
||||
"Bash(find:*)",
|
||||
"Bash(yarn tsc:*)",
|
||||
"Bash(dir:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
12
main.py
12
main.py
@@ -1,25 +1,29 @@
|
||||
from typing import NoReturn
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from utils.conf import appmeta
|
||||
from utils.http.http_exceptions import raise_internal_error
|
||||
from utils.lifespan import lifespan
|
||||
from models.database import init_db
|
||||
from models.migration import migration
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.migration import migration
|
||||
from utils import JWT
|
||||
from routers import router
|
||||
from service.redis import RedisManager
|
||||
from loguru import logger as l
|
||||
|
||||
async def _init_db() -> None:
|
||||
"""初始化数据库连接引擎"""
|
||||
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
||||
|
||||
# 添加初始化数据库启动项
|
||||
lifespan.add_startup(init_db)
|
||||
lifespan.add_startup(_init_db)
|
||||
lifespan.add_startup(migration)
|
||||
lifespan.add_startup(JWT.load_secret_key)
|
||||
lifespan.add_startup(RedisManager.connect)
|
||||
|
||||
# 添加关闭项
|
||||
lifespan.add_shutdown(DatabaseManager.close)
|
||||
lifespan.add_shutdown(RedisManager.disconnect)
|
||||
|
||||
# 创建应用实例并设置元数据
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import UUID
|
||||
from fastapi import Depends
|
||||
import jwt
|
||||
|
||||
from models.user import User
|
||||
from sqlmodels.user import User
|
||||
from utils import JWT
|
||||
from .dependencies import SessionDep
|
||||
from utils import http_exceptions
|
||||
@@ -25,8 +25,8 @@ async def auth_required(
|
||||
|
||||
user_id = UUID(user_id)
|
||||
|
||||
# 从数据库获取用户信息
|
||||
user = await User.get(session, User.id == user_id)
|
||||
# 从数据库获取用户信息(预加载 group 关系)
|
||||
user = await User.get(session, User.id == user_id, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
@@ -44,8 +44,7 @@ async def admin_required(
|
||||
使用方法:
|
||||
>>> APIRouter(dependencies=[Depends(admin_required)])
|
||||
"""
|
||||
group = await user.awaitable_attrs.group
|
||||
if group.admin:
|
||||
if user.group.admin:
|
||||
return user
|
||||
raise http_exceptions.raise_forbidden("Admin Required")
|
||||
|
||||
|
||||
@@ -14,14 +14,14 @@ from uuid import UUID
|
||||
from fastapi import Depends, Query
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.database import get_session
|
||||
from models.mixin import TimeFilterRequest, TableViewRequest
|
||||
from models.user import UserFilterParams, UserStatus
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.mixin import TimeFilterRequest, TableViewRequest
|
||||
from sqlmodels.user import UserFilterParams, UserStatus
|
||||
|
||||
|
||||
# --- 数据库会话依赖 ---
|
||||
|
||||
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(get_session)]
|
||||
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(DatabaseManager.get_session)]
|
||||
"""数据库会话依赖,用于路由函数中获取数据库会话"""
|
||||
|
||||
|
||||
@@ -79,14 +79,14 @@ TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends(_get_table_
|
||||
|
||||
async def _get_user_filter_params(
|
||||
group_id: Annotated[UUID | None, Query(description="按用户组UUID筛选")] = None,
|
||||
username: Annotated[str | None, Query(max_length=50, description="按用户名模糊搜索")] = None,
|
||||
email: Annotated[str | None, Query(max_length=50, description="按邮箱模糊搜索")] = None,
|
||||
nickname: Annotated[str | None, Query(max_length=50, description="按昵称模糊搜索")] = None,
|
||||
status: Annotated[UserStatus | None, Query(description="按用户状态筛选")] = None,
|
||||
) -> UserFilterParams:
|
||||
"""解析用户过滤查询参数"""
|
||||
return UserFilterParams(
|
||||
group_id=group_id,
|
||||
username_contains=username,
|
||||
email_contains=email,
|
||||
nickname_contains=nickname,
|
||||
status=status,
|
||||
)
|
||||
|
||||
@@ -1,456 +0,0 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import String, inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
|
||||
from models.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的default_factory丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
|
||||
通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory,
|
||||
遵循单一真相原则(不硬编码 default_factory)。
|
||||
|
||||
需要修复的字段:
|
||||
- id: 主键(从父类获取 default_factory)
|
||||
- created_at: 创建时间戳(从父类获取 default_factory)
|
||||
- updated_at: 更新时间戳(从父类获取 default_factory)
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 污染的
|
||||
if not isinstance(field_info.default, InstrumentedAttribute):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
# 动态检测所有需要修复的字段
|
||||
# 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断:
|
||||
# 1. default 是 InstrumentedAttribute(被 SQLAlchemy 污染)
|
||||
# 2. 原始定义有 default_factory 或明确的 default 值
|
||||
#
|
||||
# 覆盖场景:
|
||||
# - UUID主键(UUIDTableBaseMixin):id 有 default_factory=uuid.uuid4,需要修复
|
||||
# - int主键(TableBaseMixin):id 用 default=None,不需要修复(数据库自增)
|
||||
# - created_at/updated_at:有 default_factory=now,需要修复
|
||||
# - 外键字段(created_by_id等):有 default=None,需要修复
|
||||
# - 普通字段(name, temperature等):无 default_factory,不需要修复
|
||||
#
|
||||
# MRO 查找保证:
|
||||
# - 在多重继承场景下,MRO 顺序是确定性的
|
||||
# - find_original_field_info 会找到第一个未被污染且有该字段的父类
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute)
|
||||
if not isinstance(current_field.default, InstrumentedAttribute):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制SQLModel特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
Example:
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
@@ -5,15 +5,15 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase,
|
||||
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
||||
)
|
||||
from models.base import SQLModelBase
|
||||
from models.setting import (
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.setting import (
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
from models.setting import SettingsType
|
||||
from sqlmodels.setting import SettingsType
|
||||
from utils import http_exceptions
|
||||
from utils.conf import appmeta
|
||||
from .file import admin_file_router
|
||||
|
||||
@@ -5,14 +5,60 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
Policy, PolicyType, User, ResponseBase, ListResponse,
|
||||
from sqlmodels import (
|
||||
Policy, PolicyType, User, ListResponse,
|
||||
Object, ObjectType, AdminFileResponse, FileBanRequest, )
|
||||
from service.storage import LocalStorageService
|
||||
|
||||
async def _set_ban_recursive(
|
||||
session: AsyncSession,
|
||||
obj: Object,
|
||||
ban: bool,
|
||||
admin_id: UUID,
|
||||
reason: str | None,
|
||||
) -> int:
|
||||
"""
|
||||
递归设置封禁状态,返回受影响对象数量。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param obj: 要封禁/解禁的对象
|
||||
:param ban: True=封禁, False=解禁
|
||||
:param admin_id: 管理员UUID
|
||||
:param reason: 封禁原因
|
||||
:return: 受影响的对象数量
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# 如果是文件夹,先递归处理子对象
|
||||
if obj.is_folder:
|
||||
children = await Object.get(
|
||||
session,
|
||||
Object.parent_id == obj.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
for child in children:
|
||||
count += await _set_ban_recursive(session, child, ban, admin_id, reason)
|
||||
|
||||
# 设置当前对象
|
||||
obj.is_banned = ban
|
||||
if ban:
|
||||
obj.banned_at = datetime.now()
|
||||
obj.banned_by = admin_id
|
||||
obj.ban_reason = reason
|
||||
else:
|
||||
obj.banned_at = None
|
||||
obj.banned_by = None
|
||||
obj.ban_reason = None
|
||||
|
||||
await obj.save(session)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
admin_file_router = APIRouter(
|
||||
prefix="/file",
|
||||
tags=["admin", "admin_file"],
|
||||
@@ -119,15 +165,17 @@ async def router_admin_preview_file(
|
||||
summary='封禁/解禁文件',
|
||||
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_ban_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
request: FileBanRequest,
|
||||
admin: Annotated[User, Depends(admin_required)],
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
封禁或解禁文件。封禁后用户无法访问该文件。
|
||||
封禁或解禁文件/文件夹。封禁后用户无法访问该文件。
|
||||
封禁文件夹时会级联封禁所有子对象。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param file_id: 文件UUID
|
||||
@@ -139,24 +187,10 @@ async def router_admin_ban_file(
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
file_obj.is_banned = request.is_banned
|
||||
if request.is_banned:
|
||||
file_obj.banned_at = datetime.now()
|
||||
file_obj.banned_by = admin.id
|
||||
file_obj.ban_reason = request.reason
|
||||
else:
|
||||
file_obj.banned_at = None
|
||||
file_obj.banned_by = None
|
||||
file_obj.ban_reason = None
|
||||
count = await _set_ban_recursive(session, file_obj, request.ban, admin.id, request.reason)
|
||||
|
||||
file_obj = await file_obj.save(session)
|
||||
|
||||
action = "封禁" if request.is_banned else "解禁"
|
||||
l.info(f"管理员{action}了文件: {file_obj.name}")
|
||||
return ResponseBase(data={
|
||||
"id": str(file_obj.id),
|
||||
"is_banned": file_obj.is_banned,
|
||||
})
|
||||
action = "封禁" if request.ban else "解禁"
|
||||
l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象")
|
||||
|
||||
|
||||
@admin_file_router.delete(
|
||||
@@ -164,12 +198,13 @@ async def router_admin_ban_file(
|
||||
summary='删除文件',
|
||||
description='Delete file by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
delete_physical: bool = True,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
删除文件。
|
||||
|
||||
@@ -211,5 +246,4 @@ async def router_admin_delete_file(
|
||||
# 使用条件删除
|
||||
await Object.delete(session, condition=Object.id == file_obj.id)
|
||||
|
||||
l.info(f"管理员删除了文件: {file_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
l.info(f"管理员删除了文件: {file_name}")
|
||||
@@ -5,12 +5,12 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, GroupOptions, )
|
||||
from models.group import (
|
||||
from sqlmodels.group import (
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, )
|
||||
from models.policy import GroupPolicyLink
|
||||
from sqlmodels.policy import GroupPolicyLink
|
||||
|
||||
admin_group_router = APIRouter(
|
||||
prefix="/group",
|
||||
@@ -113,11 +113,12 @@ async def router_admin_get_group_members(
|
||||
summary='创建用户组',
|
||||
description='Create a new user group',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_create_group(
|
||||
session: SessionDep,
|
||||
request: GroupCreateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
创建新的用户组。
|
||||
|
||||
@@ -164,7 +165,6 @@ async def router_admin_create_group(
|
||||
await session.commit()
|
||||
|
||||
l.info(f"管理员创建了用户组: {group.name}")
|
||||
return ResponseBase(data={"id": str(group.id), "name": group.name})
|
||||
|
||||
|
||||
@admin_group_router.patch(
|
||||
@@ -172,12 +172,13 @@ async def router_admin_create_group(
|
||||
summary='更新用户组信息',
|
||||
description='Update user group information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_update_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
request: GroupUpdateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户组ID更新用户组信息。
|
||||
|
||||
@@ -233,8 +234,7 @@ async def router_admin_update_group(
|
||||
session.add(link)
|
||||
await session.commit()
|
||||
|
||||
l.info(f"管理员更新了用户组: {group.name}")
|
||||
return ResponseBase(data={"id": str(group.id)})
|
||||
l.info(f"管理员更新了用户组: {group_id}")
|
||||
|
||||
|
||||
@admin_group_router.delete(
|
||||
@@ -242,11 +242,12 @@ async def router_admin_update_group(
|
||||
summary='删除用户组',
|
||||
description='Delete user group by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户组ID删除用户组。
|
||||
|
||||
@@ -271,5 +272,4 @@ async def router_admin_delete_group(
|
||||
group_name = group.name
|
||||
await Group.delete(session, group)
|
||||
|
||||
l.info(f"管理员删除了用户组: {group_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
l.info(f"管理员删除了用户组: {group_id}")
|
||||
@@ -6,10 +6,10 @@ from sqlmodel import Field
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||
ListResponse, Object, )
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from service.storage import DirectoryCreationError, LocalStorageService
|
||||
|
||||
admin_policy_router = APIRouter(
|
||||
|
||||
@@ -5,7 +5,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase, ListResponse,
|
||||
Share, AdminShareListItem, )
|
||||
|
||||
@@ -80,7 +80,7 @@ async def router_admin_get_share(
|
||||
"score": share.score,
|
||||
"has_password": bool(share.password),
|
||||
"user_id": str(share.user_id),
|
||||
"username": user.username if user else None,
|
||||
"username": user.email if user else None,
|
||||
"object": {
|
||||
"id": str(obj.id),
|
||||
"name": obj.name,
|
||||
|
||||
@@ -5,7 +5,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase, ListResponse,
|
||||
Task, TaskSummary,
|
||||
)
|
||||
@@ -89,7 +89,7 @@ async def router_admin_get_task(
|
||||
"progress": task.progress,
|
||||
"error": task.error,
|
||||
"user_id": str(task.user_id),
|
||||
"username": user.username if user else None,
|
||||
"username": user.email if user else None,
|
||||
"props": props.model_dump() if props else None,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
"updated_at": task.updated_at.isoformat(),
|
||||
|
||||
@@ -6,11 +6,13 @@ from sqlalchemy import func
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, Object, ObjectType, )
|
||||
from models.user import (
|
||||
UserAdminUpdateRequest, UserCalibrateResponse,
|
||||
Group, Object, ObjectType, Setting, SettingsType,
|
||||
BatchDeleteRequest,
|
||||
)
|
||||
from sqlmodels.user import (
|
||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse,
|
||||
)
|
||||
from utils import Password, http_exceptions
|
||||
|
||||
@@ -26,19 +28,19 @@ admin_user_router = APIRouter(
|
||||
description='Get user information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase:
|
||||
async def router_admin_get_user(session: SessionDep, user_id: UUID) -> UserPublic:
|
||||
"""
|
||||
根据用户ID获取用户信息,包括用户名、邮箱、注册时间等。
|
||||
|
||||
Args:
|
||||
session(SessionDep): 数据库会话依赖项。
|
||||
user_id (int): 用户ID。
|
||||
user_id (UUID): 用户ID。
|
||||
|
||||
Returns:
|
||||
ResponseBase: 包含用户信息的响应模型。
|
||||
"""
|
||||
user = await User.get_exist_one(session, user_id)
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@admin_user_router.get(
|
||||
@@ -60,7 +62,7 @@ async def router_admin_get_users(
|
||||
:param filter_params: 用户筛选参数(用户组、用户名、昵称、状态)
|
||||
:return: 分页用户列表
|
||||
"""
|
||||
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view)
|
||||
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view, load=User.group)
|
||||
return ListResponse(
|
||||
items=[user.to_public() for user in result.items],
|
||||
count=result.count,
|
||||
@@ -75,22 +77,33 @@ async def router_admin_get_users(
|
||||
)
|
||||
async def router_admin_create_user(
|
||||
session: SessionDep,
|
||||
user: User,
|
||||
) -> ResponseBase:
|
||||
request: UserAdminCreateRequest,
|
||||
) -> UserPublic:
|
||||
"""
|
||||
创建一个新的用户,设置用户名、密码等信息。
|
||||
创建一个新的用户,设置邮箱、密码、用户组等信息。
|
||||
|
||||
Returns:
|
||||
ResponseBase: 包含创建结果的响应模型。
|
||||
:param session: 数据库会话
|
||||
:param request: 创建用户请求 DTO
|
||||
:return: 创建结果
|
||||
"""
|
||||
existing_user = await User.get(session, User.username == user.username)
|
||||
existing_user = await User.get(session, User.email == request.email)
|
||||
if existing_user:
|
||||
return ResponseBase(
|
||||
code=400,
|
||||
msg="User with this username already exists."
|
||||
)
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 验证用户组存在
|
||||
group = await Group.get(session, Group.id == request.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
||||
|
||||
user = User(
|
||||
email=request.email,
|
||||
password=Password.hash(request.password),
|
||||
nickname=request.nickname,
|
||||
group_id=request.group_id,
|
||||
status=request.status,
|
||||
)
|
||||
user = await user.save(session)
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@admin_user_router.patch(
|
||||
@@ -98,12 +111,13 @@ async def router_admin_create_user(
|
||||
summary='更新用户信息',
|
||||
description='Update user information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204
|
||||
)
|
||||
async def router_admin_update_user(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
request: UserAdminUpdateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户ID更新用户信息。
|
||||
|
||||
@@ -116,8 +130,15 @@ async def router_admin_update_user(
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 默认管理员(用户名为 admin)不允许更改用户组
|
||||
if request.group_id and user.username == "admin" and request.group_id != user.group_id:
|
||||
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
||||
default_admin_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||
)
|
||||
if (request.group_id
|
||||
and default_admin_setting
|
||||
and default_admin_setting.value == str(user_id)
|
||||
and request.group_id != user.group_id):
|
||||
http_exceptions.raise_forbidden("默认管理员不允许更改用户组")
|
||||
|
||||
# 如果更新用户组,验证新组存在
|
||||
@@ -143,38 +164,35 @@ async def router_admin_update_user(
|
||||
setattr(user, key, value)
|
||||
user = await user.save(session)
|
||||
|
||||
l.info(f"管理员更新了用户: {user.username}")
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
l.info(f"管理员更新了用户: {request.email}")
|
||||
|
||||
|
||||
@admin_user_router.delete(
|
||||
path='/{user_id}',
|
||||
summary='删除用户',
|
||||
description='Delete user by ID',
|
||||
path='/',
|
||||
summary='删除用户(支持批量)',
|
||||
description='Delete users by ID list',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_user(
|
||||
async def router_admin_delete_users(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
) -> ResponseBase:
|
||||
request: BatchDeleteRequest,
|
||||
) -> None:
|
||||
"""
|
||||
根据用户ID删除用户及其所有数据。
|
||||
批量删除用户及其所有数据。
|
||||
|
||||
注意: 这是一个危险操作,会级联删除用户的所有文件、分享、任务等。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:return: 删除结果
|
||||
:param request: 批量删除请求,包含待删除用户的 UUID 列表
|
||||
:return: 删除结果(已删除数 / 总请求数)
|
||||
"""
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
username = user.username
|
||||
await User.delete(session, user)
|
||||
|
||||
l.info(f"管理员删除了用户: {username}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
deleted = 0
|
||||
for uid in request.ids:
|
||||
user = await User.get(session, User.id == uid)
|
||||
if user:
|
||||
await User.delete(session, user)
|
||||
l.info(f"管理员删除了用户: {user.email}")
|
||||
|
||||
|
||||
@admin_user_router.post(
|
||||
@@ -186,7 +204,7 @@ async def router_admin_delete_user(
|
||||
async def router_admin_calibrate_storage(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> UserCalibrateResponse:
|
||||
"""
|
||||
重新计算用户的已用存储空间。
|
||||
|
||||
@@ -228,5 +246,5 @@ async def router_admin_calibrate_storage(
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
l.info(f"管理员校准了用户存储: {user.username}, 差值: {actual_storage - previous_storage}")
|
||||
return ResponseBase(data=response.model_dump())
|
||||
l.info(f"管理员校准了用户存储: {user.email}, 差值: {actual_storage - previous_storage}")
|
||||
return response
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
import service.oauth
|
||||
from utils import http_exceptions
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
DirectoryCreateRequest,
|
||||
DirectoryResponse,
|
||||
Object,
|
||||
@@ -14,50 +16,28 @@ from models import (
|
||||
User,
|
||||
ResponseBase,
|
||||
)
|
||||
from utils import http_exceptions
|
||||
|
||||
directory_router = APIRouter(
|
||||
prefix="/directory",
|
||||
tags=["directory"]
|
||||
)
|
||||
|
||||
@directory_router.get(
|
||||
path="/{path:path}",
|
||||
summary="获取目录内容",
|
||||
)
|
||||
async def router_directory_get(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
path: str
|
||||
|
||||
async def _get_directory_response(
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
folder: Object,
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取目录内容
|
||||
|
||||
路径必须以用户名或 `.crash` 开头,如 /api/directory/admin 或 /api/directory/admin/docs
|
||||
`.crash` 代表回收站,也就意味着用户名禁止为 `.crash`
|
||||
构建目录响应 DTO
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param path: 目录路径(必须以用户名开头)
|
||||
:return: 目录内容
|
||||
:param user_id: 用户UUID
|
||||
:param folder: 目录对象
|
||||
:return: DirectoryResponse
|
||||
"""
|
||||
# 路径必须以用户名开头
|
||||
path = path.strip("/")
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空,请使用 /{username} 格式")
|
||||
|
||||
path_parts = path.split("/")
|
||||
if path_parts[0] != user.username:
|
||||
raise HTTPException(status_code=403, detail="无权访问其他用户的目录")
|
||||
|
||||
folder = await Object.get_by_path(session, user.id, "/" + path, user.username)
|
||||
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="目录不存在")
|
||||
|
||||
if not folder.is_folder:
|
||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||
|
||||
children = await Object.get_children(session, user.id, folder.id)
|
||||
children = await Object.get_children(session, user_id, folder.id)
|
||||
policy = await folder.awaitable_attrs.policy
|
||||
|
||||
objects = [
|
||||
@@ -67,8 +47,8 @@ async def router_directory_get(
|
||||
thumb=False,
|
||||
size=child.size,
|
||||
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
|
||||
date=child.updated_at,
|
||||
create_date=child.created_at,
|
||||
created_at=child.created_at,
|
||||
updated_at=child.updated_at,
|
||||
source_enabled=False,
|
||||
)
|
||||
for child in children
|
||||
@@ -89,7 +69,74 @@ async def router_directory_get(
|
||||
)
|
||||
|
||||
|
||||
@directory_router.put(
|
||||
@directory_router.get(
|
||||
path="/",
|
||||
summary="获取根目录内容",
|
||||
)
|
||||
async def router_directory_root(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取当前用户的根目录内容
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:return: 根目录内容
|
||||
"""
|
||||
root = await Object.get_root(session, user.id)
|
||||
if not root:
|
||||
raise HTTPException(status_code=404, detail="根目录不存在")
|
||||
|
||||
if root.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
return await _get_directory_response(session, user.id, root)
|
||||
|
||||
|
||||
@directory_router.get(
|
||||
path="/{path:path}",
|
||||
summary="获取目录内容",
|
||||
)
|
||||
async def router_directory_get(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
path: str
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取目录内容
|
||||
|
||||
路径从用户根目录开始,不包含用户名前缀。
|
||||
如 /api/v1/directory/docs 表示根目录下的 docs 目录。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param path: 目录路径(从根目录开始的相对路径)
|
||||
:return: 目录内容
|
||||
"""
|
||||
path = path.strip("/")
|
||||
if not path:
|
||||
# 空路径交给根目录端点处理(理论上不会到达这里)
|
||||
root = await Object.get_root(session, user.id)
|
||||
if not root:
|
||||
raise HTTPException(status_code=404, detail="根目录不存在")
|
||||
return await _get_directory_response(session, user.id, root)
|
||||
|
||||
folder = await Object.get_by_path(session, user.id, "/" + path)
|
||||
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="目录不存在")
|
||||
|
||||
if not folder.is_folder:
|
||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||
|
||||
if folder.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
return await _get_directory_response(session, user.id, folder)
|
||||
|
||||
|
||||
@directory_router.post(
|
||||
path="/",
|
||||
summary="创建目录",
|
||||
)
|
||||
@@ -123,6 +170,9 @@ async def router_directory_create(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父路径不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名对象
|
||||
existing = await Object.get(
|
||||
session,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
download_router = APIRouter(
|
||||
|
||||
@@ -18,7 +18,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required, verify_download_token
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
CreateUploadSessionRequest,
|
||||
Object,
|
||||
@@ -91,6 +91,9 @@ async def create_upload_session(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
@@ -100,7 +103,7 @@ async def create_upload_session(
|
||||
# 验证文件大小限制
|
||||
if policy.max_size > 0 and request.file_size > policy.max_size:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
status_code=413,
|
||||
detail=f"文件大小超过限制 ({policy.max_size} bytes)"
|
||||
)
|
||||
|
||||
@@ -221,30 +224,40 @@ async def upload_chunk(
|
||||
upload_session.uploaded_size += len(content)
|
||||
upload_session = await upload_session.save(session)
|
||||
|
||||
# 检查是否完成
|
||||
# 在后续可能的 commit 前保存需要的属性
|
||||
is_complete = upload_session.is_complete
|
||||
uploaded_chunks = upload_session.uploaded_chunks
|
||||
total_chunks = upload_session.total_chunks
|
||||
file_object_id: UUID | None = None
|
||||
|
||||
if is_complete:
|
||||
# 保存 upload_session 属性(commit 后会过期)
|
||||
file_name = upload_session.file_name
|
||||
uploaded_size = upload_session.uploaded_size
|
||||
storage_path = upload_session.storage_path
|
||||
upload_session_id = upload_session.id
|
||||
parent_id = upload_session.parent_id
|
||||
policy_id = upload_session.policy_id
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=upload_session.storage_path,
|
||||
size=upload_session.uploaded_size,
|
||||
policy_id=upload_session.policy_id,
|
||||
storage_path=storage_path,
|
||||
size=uploaded_size,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
physical_file = await physical_file.save(session, commit=False)
|
||||
|
||||
# 创建 Object 记录
|
||||
file_object = Object(
|
||||
name=upload_session.file_name,
|
||||
name=file_name,
|
||||
type=ObjectType.FILE,
|
||||
size=upload_session.uploaded_size,
|
||||
size=uploaded_size,
|
||||
physical_file_id=physical_file.id,
|
||||
upload_session_id=str(upload_session.id),
|
||||
parent_id=upload_session.parent_id,
|
||||
upload_session_id=str(upload_session_id),
|
||||
parent_id=parent_id,
|
||||
owner_id=user_id,
|
||||
policy_id=upload_session.policy_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
file_object = await file_object.save(session, commit=False)
|
||||
file_object_id = file_object.id
|
||||
@@ -252,18 +265,18 @@ async def upload_chunk(
|
||||
# 删除上传会话(使用条件删除)
|
||||
await UploadSession.delete(
|
||||
session,
|
||||
condition=UploadSession.id == upload_session.id,
|
||||
condition=UploadSession.id == upload_session_id,
|
||||
commit=False
|
||||
)
|
||||
|
||||
# 统一提交所有更改
|
||||
await session.commit()
|
||||
|
||||
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")
|
||||
l.info(f"文件上传完成: {file_name}, size={uploaded_size}, id={file_object_id}")
|
||||
|
||||
return UploadChunkResponse(
|
||||
uploaded_chunks=upload_session.uploaded_chunks if not is_complete else upload_session.total_chunks,
|
||||
total_chunks=upload_session.total_chunks,
|
||||
uploaded_chunks=uploaded_chunks if not is_complete else total_chunks,
|
||||
total_chunks=total_chunks,
|
||||
is_complete=is_complete,
|
||||
object_id=file_object_id,
|
||||
)
|
||||
@@ -368,6 +381,9 @@ async def create_download_token_endpoint(
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
|
||||
if file_obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
token = create_download_token(file_id, user.id)
|
||||
|
||||
l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}")
|
||||
@@ -410,6 +426,9 @@ async def download_file(
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
|
||||
if file_obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 预加载 physical_file 关系以获取存储路径
|
||||
physical_file = await file_obj.awaitable_attrs.physical_file
|
||||
if not physical_file or not physical_file.storage_path:
|
||||
@@ -470,6 +489,9 @@ async def create_empty_file(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名文件
|
||||
existing = await Object.get(
|
||||
session,
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from models import MCPRequestBase, MCPResponseBase, MCPMethod
|
||||
|
||||
# MCP 路由
|
||||
MCP_router = APIRouter(
|
||||
prefix='/mcp',
|
||||
tags=["mcp"],
|
||||
)
|
||||
|
||||
@MCP_router.get(
|
||||
"/",
|
||||
)
|
||||
async def mcp_root(
|
||||
param: MCPRequestBase
|
||||
):
|
||||
match param.method:
|
||||
case MCPMethod.PING:
|
||||
return MCPResponseBase(result="pong", **param.model_dump())
|
||||
@@ -14,7 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
Object,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
@@ -26,10 +27,11 @@ from models import (
|
||||
PhysicalFile,
|
||||
Policy,
|
||||
PolicyType,
|
||||
ResponseBase,
|
||||
User,
|
||||
)
|
||||
from models import ResponseBase
|
||||
from service.storage import LocalStorageService
|
||||
from utils import http_exceptions
|
||||
|
||||
object_router = APIRouter(
|
||||
prefix="/object",
|
||||
@@ -59,15 +61,22 @@ async def _delete_object_recursive(
|
||||
"""
|
||||
deleted_count = 0
|
||||
|
||||
if obj.is_folder:
|
||||
# 在任何数据库操作前保存所有需要的属性,避免 commit 后对象过期导致懒加载失败
|
||||
obj_id = obj.id
|
||||
obj_name = obj.name
|
||||
obj_is_folder = obj.is_folder
|
||||
obj_is_file = obj.is_file
|
||||
obj_physical_file_id = obj.physical_file_id
|
||||
|
||||
if obj_is_folder:
|
||||
# 递归删除子对象
|
||||
children = await Object.get_children(session, user_id, obj.id)
|
||||
children = await Object.get_children(session, user_id, obj_id)
|
||||
for child in children:
|
||||
deleted_count += await _delete_object_recursive(session, child, user_id)
|
||||
|
||||
# 如果是文件,处理物理文件引用
|
||||
if obj.is_file and obj.physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
||||
if obj_is_file and obj_physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj_physical_file_id)
|
||||
if physical_file:
|
||||
# 减少引用计数
|
||||
new_count = physical_file.decrement_reference()
|
||||
@@ -81,11 +90,11 @@ async def _delete_object_recursive(
|
||||
await storage_service.move_to_trash(
|
||||
source_path=physical_file.storage_path,
|
||||
user_id=user_id,
|
||||
object_id=obj.id,
|
||||
object_id=obj_id,
|
||||
)
|
||||
l.debug(f"物理文件已移动到回收站: {obj.name}")
|
||||
l.debug(f"物理文件已移动到回收站: {obj_name}")
|
||||
except Exception as e:
|
||||
l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}")
|
||||
l.warning(f"移动物理文件到回收站失败: {obj_name}, 错误: {e}")
|
||||
|
||||
# 删除 PhysicalFile 记录
|
||||
await PhysicalFile.delete(session, physical_file)
|
||||
@@ -95,8 +104,8 @@ async def _delete_object_recursive(
|
||||
await physical_file.save(session)
|
||||
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
|
||||
|
||||
# 删除数据库记录
|
||||
await Object.delete(session, obj)
|
||||
# 使用条件删除,避免访问过期的 obj 实例
|
||||
await Object.delete(session, condition=Object.id == obj_id)
|
||||
deleted_count += 1
|
||||
|
||||
return deleted_count
|
||||
@@ -168,6 +177,97 @@ async def _copy_object_recursive(
|
||||
return copied_count, new_ids
|
||||
|
||||
|
||||
@object_router.post(
|
||||
path='/',
|
||||
summary='创建空白文件',
|
||||
description='在指定目录下创建空白文件。',
|
||||
)
|
||||
async def router_object_create(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: CreateFileRequest,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
创建空白文件端点
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param request: 创建文件请求(parent_id, name)
|
||||
:return: 创建结果
|
||||
"""
|
||||
user_id = user.id
|
||||
|
||||
# 验证文件名
|
||||
if not request.name or '/' in request.name or '\\' in request.name:
|
||||
raise HTTPException(status_code=400, detail="无效的文件名")
|
||||
|
||||
# 验证父目录
|
||||
parent = await Object.get(session, Object.id == request.parent_id)
|
||||
if not parent or parent.owner_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="父目录不存在")
|
||||
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名文件
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user_id) &
|
||||
(Object.parent_id == parent.id) &
|
||||
(Object.name == request.name)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="同名文件已存在")
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
|
||||
parent_id = parent.id
|
||||
|
||||
# 生成存储路径并创建空文件
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
||||
user_id=user_id,
|
||||
original_filename=request.name,
|
||||
)
|
||||
await storage_service.create_empty_file(full_path)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
size=0,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
physical_file = await physical_file.save(session)
|
||||
|
||||
# 创建 Object 记录
|
||||
file_object = Object(
|
||||
name=request.name,
|
||||
type=ObjectType.FILE,
|
||||
size=0,
|
||||
physical_file_id=physical_file.id,
|
||||
parent_id=parent_id,
|
||||
owner_id=user_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
await file_object.save(session)
|
||||
|
||||
l.info(f"创建空白文件: {request.name}")
|
||||
|
||||
return ResponseBase()
|
||||
|
||||
|
||||
@object_router.delete(
|
||||
path='/',
|
||||
summary='删除对象',
|
||||
@@ -197,10 +297,7 @@ async def router_object_delete(
|
||||
user_id = user.id
|
||||
deleted_count = 0
|
||||
|
||||
# 处理单个 UUID 或 UUID 列表
|
||||
ids = request.ids if isinstance(request.ids, list) else [request.ids]
|
||||
|
||||
for obj_id in ids:
|
||||
for obj_id in request.ids:
|
||||
obj = await Object.get(session, Object.id == obj_id)
|
||||
if not obj or obj.owner_id != user_id:
|
||||
continue
|
||||
@@ -219,7 +316,7 @@ async def router_object_delete(
|
||||
return ResponseBase(
|
||||
data={
|
||||
"deleted": deleted_count,
|
||||
"total": len(ids),
|
||||
"total": len(request.ids),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -253,6 +350,9 @@ async def router_object_move(
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
if dst.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 存储 dst 的属性,避免后续数据库操作导致 dst 过期后无法访问
|
||||
dst_id = dst.id
|
||||
dst_parent_id = dst.parent_id
|
||||
@@ -264,6 +364,9 @@ async def router_object_move(
|
||||
if not src or src.owner_id != user_id:
|
||||
continue
|
||||
|
||||
if src.is_banned:
|
||||
continue
|
||||
|
||||
# 不能移动根目录
|
||||
if src.parent_id is None:
|
||||
continue
|
||||
@@ -348,6 +451,9 @@ async def router_object_copy(
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
if dst.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
copied_count = 0
|
||||
new_ids: list[UUID] = []
|
||||
|
||||
@@ -356,6 +462,9 @@ async def router_object_copy(
|
||||
if not src or src.owner_id != user_id:
|
||||
continue
|
||||
|
||||
if src.is_banned:
|
||||
continue
|
||||
|
||||
# 不能复制根目录
|
||||
if src.parent_id is None:
|
||||
continue
|
||||
@@ -438,6 +547,9 @@ async def router_object_rename(
|
||||
if obj.owner_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此对象")
|
||||
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 不能重命名根目录
|
||||
if obj.parent_id is None:
|
||||
raise HTTPException(status_code=400, detail="无法重命名根目录")
|
||||
@@ -543,7 +655,7 @@ async def router_object_property_detail(
|
||||
policy_name = policy.name if policy else None
|
||||
|
||||
# 获取分享统计
|
||||
from models import Share
|
||||
from sqlmodels import Share
|
||||
shares = await Share.get(
|
||||
session,
|
||||
Share.object_id == obj.id,
|
||||
|
||||
@@ -7,11 +7,11 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase
|
||||
from models.user import User
|
||||
from models.share import Share, ShareCreateRequest, ShareResponse
|
||||
from models.object import Object
|
||||
from models.mixin import ListResponse, TableViewRequest
|
||||
from sqlmodels import ResponseBase
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.share import Share, ShareCreateRequest, ShareResponse
|
||||
from sqlmodels.object import Object
|
||||
from sqlmodels.mixin import ListResponse, TableViewRequest
|
||||
from utils import http_exceptions
|
||||
from utils.password.pwd import Password
|
||||
|
||||
@@ -72,23 +72,6 @@ def router_share_preview(id: str) -> ResponseBase:
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/doc/{id}',
|
||||
summary='取得Office文档预览地址',
|
||||
description='Get Office document preview URL by ID.',
|
||||
)
|
||||
def router_share_doc(id: str) -> ResponseBase:
|
||||
"""
|
||||
Get Office document preview URL by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the Office document.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the document preview URL.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/content/{id}',
|
||||
summary='获取文本文件内容',
|
||||
@@ -261,6 +244,9 @@ async def router_share_create(
|
||||
if not obj or obj.owner_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="对象不存在或无权限")
|
||||
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 生成分享码
|
||||
code = str(uuid4())
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
from sqlmodels import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
from sqlmodels.setting import CaptchaType
|
||||
from utils import http_exceptions
|
||||
|
||||
site_router = APIRouter(
|
||||
@@ -43,16 +44,43 @@ def router_site_captcha():
|
||||
@site_router.get(
|
||||
path='/config',
|
||||
summary='站点全局配置',
|
||||
description='Get the configuration file.',
|
||||
response_model=ResponseBase,
|
||||
description='获取站点全局配置,包括验证码设置、注册开关等。',
|
||||
)
|
||||
async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
"""
|
||||
Get the configuration file.
|
||||
获取站点全局配置
|
||||
|
||||
Returns:
|
||||
dict: The site configuration.
|
||||
无需认证。前端在初始化时调用此端点获取验证码类型、
|
||||
登录/注册/找回密码是否需要验证码等配置。
|
||||
"""
|
||||
# 批量查询所需设置
|
||||
settings: list[Setting] = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) |
|
||||
(Setting.type == SettingsType.LOGIN) |
|
||||
(Setting.type == SettingsType.REGISTER) |
|
||||
(Setting.type == SettingsType.CAPTCHA),
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
# 构建 name→value 映射
|
||||
s: dict[str, str | None] = {item.name: item.value for item in settings}
|
||||
|
||||
# 根据 captcha_type 选择对应的 public key
|
||||
captcha_type_str = s.get("captcha_type", "default")
|
||||
captcha_type = CaptchaType(captcha_type_str) if captcha_type_str else CaptchaType.DEFAULT
|
||||
captcha_key: str | None = None
|
||||
if captcha_type == CaptchaType.GCAPTCHA:
|
||||
captcha_key = s.get("captcha_ReCaptchaKey") or None
|
||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||
captcha_key = s.get("captcha_CloudflareKey") or None
|
||||
|
||||
return SiteConfigResponse(
|
||||
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
|
||||
title=s.get("siteName") or "DiskNext",
|
||||
register_enabled=s.get("register_enabled") == "1",
|
||||
login_captcha=s.get("login_captcha") == "1",
|
||||
reg_captcha=s.get("reg_captcha") == "1",
|
||||
forget_captcha=s.get("forget_captcha") == "1",
|
||||
captcha_type=captcha_type,
|
||||
captcha_key=captcha_key,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
slave_router = APIRouter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import auth_required
|
||||
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
tag_router = APIRouter(
|
||||
|
||||
@@ -1,30 +1,26 @@
|
||||
from typing import Annotated, Literal
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from loguru import logger
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from loguru import logger
|
||||
|
||||
import models
|
||||
import service
|
||||
import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from utils.JWT import SECRET_KEY
|
||||
from utils import Password, http_exceptions
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from .settings import user_settings_router
|
||||
|
||||
user_router = APIRouter(
|
||||
prefix="/user",
|
||||
tags=["user"],
|
||||
)
|
||||
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/user/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
user_router.include_router(user_settings_router)
|
||||
|
||||
@user_router.post(
|
||||
path='/session',
|
||||
@@ -34,7 +30,7 @@ user_settings_router = APIRouter(
|
||||
async def router_user_session(
|
||||
session: SessionDep,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> models.TokenResponse:
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
用户登录端点。
|
||||
|
||||
@@ -43,7 +39,7 @@ async def router_user_session(
|
||||
|
||||
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
|
||||
"""
|
||||
username = form_data.username
|
||||
email = form_data.username # OAuth2 表单字段名为 username,实际传入的是 email
|
||||
password = form_data.password
|
||||
|
||||
# 从 scopes 中提取 OTP 验证码(OAuth2.1 扩展方式)
|
||||
@@ -59,8 +55,8 @@ async def router_user_session(
|
||||
|
||||
result = await service.user.login(
|
||||
session,
|
||||
models.LoginRequest(
|
||||
username=username,
|
||||
sqlmodels.LoginRequest(
|
||||
email=email,
|
||||
password=password,
|
||||
two_fa_code=otp_code,
|
||||
),
|
||||
@@ -75,19 +71,70 @@ async def router_user_session(
|
||||
)
|
||||
async def router_user_session_refresh(
|
||||
session: SessionDep,
|
||||
request, # RefreshTokenRequest
|
||||
) -> models.TokenResponse:
|
||||
http_exceptions.raise_not_implemented()
|
||||
request: sqlmodels.RefreshTokenRequest,
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
使用 refresh_token 签发新的 access_token 和 refresh_token。
|
||||
|
||||
流程:
|
||||
1. 解码 refresh_token JWT
|
||||
2. 验证 token_type 为 refresh
|
||||
3. 验证用户存在且状态正常
|
||||
4. 签发新的 access_token + refresh_token
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 刷新令牌请求
|
||||
:return: 新的 TokenResponse
|
||||
"""
|
||||
|
||||
try:
|
||||
payload = jwt.decode(request.refresh_token, JWT.SECRET_KEY, algorithms=["HS256"])
|
||||
except jwt.InvalidTokenError:
|
||||
http_exceptions.raise_unauthorized("刷新令牌无效或已过期")
|
||||
|
||||
# 验证是 refresh token
|
||||
if payload.get("token_type") != "refresh":
|
||||
http_exceptions.raise_unauthorized("非刷新令牌")
|
||||
|
||||
user_id_str = payload.get("sub")
|
||||
if not user_id_str:
|
||||
http_exceptions.raise_unauthorized("令牌缺少用户标识")
|
||||
|
||||
user_id = UUID(user_id_str)
|
||||
user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
|
||||
if not user.status:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 签发新令牌
|
||||
access_token = JWT.create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
)
|
||||
refresh_token = JWT.create_refresh_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
)
|
||||
|
||||
return sqlmodels.TokenResponse(
|
||||
access_token=access_token.access_token,
|
||||
access_expires=access_token.access_expires,
|
||||
refresh_token=refresh_token.refresh_token,
|
||||
refresh_expires=refresh_token.refresh_expires,
|
||||
)
|
||||
|
||||
@user_router.post(
|
||||
path='/',
|
||||
summary='用户注册',
|
||||
description='User registration endpoint.',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_register(
|
||||
session: SessionDep,
|
||||
request: models.RegisterRequest,
|
||||
) -> models.ResponseBase:
|
||||
request: sqlmodels.RegisterRequest,
|
||||
) -> None:
|
||||
"""
|
||||
用户注册端点
|
||||
|
||||
@@ -95,7 +142,7 @@ async def router_user_register(
|
||||
1. 验证用户名唯一性
|
||||
2. 获取默认用户组
|
||||
3. 创建用户记录
|
||||
4. 创建以用户名命名的根目录
|
||||
4. 创建用户根目录(name="/")
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 注册请求
|
||||
@@ -103,62 +150,53 @@ async def router_user_register(
|
||||
:raises HTTPException 400: 用户名已存在
|
||||
:raises HTTPException 500: 默认用户组或存储策略不存在
|
||||
"""
|
||||
# 1. 验证用户名唯一性
|
||||
existing_user = await models.User.get(
|
||||
# 1. 验证邮箱唯一性
|
||||
existing_user = await sqlmodels.User.get(
|
||||
session,
|
||||
models.User.username == request.username
|
||||
sqlmodels.User.email == request.email
|
||||
)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
raise HTTPException(status_code=400, detail="邮箱已存在")
|
||||
|
||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||
default_group_setting: models.Setting | None = await models.Setting.get(
|
||||
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group")
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
default_group_id = UUID(default_group_setting.value)
|
||||
default_group = await models.Group.get(session, models.Group.id == default_group_id)
|
||||
default_group = await sqlmodels.Group.get(session, sqlmodels.Group.id == default_group_id)
|
||||
if not default_group:
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 3. 创建用户
|
||||
hashed_password = Password.hash(request.password)
|
||||
new_user = models.User(
|
||||
username=request.username,
|
||||
new_user = sqlmodels.User(
|
||||
email=request.email,
|
||||
password=hashed_password,
|
||||
group_id=default_group.id,
|
||||
)
|
||||
new_user_id = new_user.id # 在 save 前保存 UUID
|
||||
new_user_username = new_user.username
|
||||
new_user_id = new_user.id
|
||||
await new_user.save(session)
|
||||
|
||||
# 4. 创建以用户名命名的根目录
|
||||
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储")
|
||||
# 4. 创建用户根目录
|
||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
logger.error("默认存储策略不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
await models.Object(
|
||||
name=new_user_username,
|
||||
type=models.ObjectType.FOLDER,
|
||||
await sqlmodels.Object(
|
||||
name="/",
|
||||
type=sqlmodels.ObjectType.FOLDER,
|
||||
owner_id=new_user_id,
|
||||
parent_id=None,
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={
|
||||
"user_id": new_user_id,
|
||||
"username": new_user_username,
|
||||
},
|
||||
msg="注册成功",
|
||||
)
|
||||
|
||||
@user_router.post(
|
||||
path='/code',
|
||||
summary='发送验证码邮件',
|
||||
@@ -166,7 +204,7 @@ async def router_user_register(
|
||||
)
|
||||
def router_user_email_code(
|
||||
reason: Literal['register', 'reset'] = 'register',
|
||||
) -> models.ResponseBase:
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Send a verification code email.
|
||||
|
||||
@@ -180,7 +218,7 @@ def router_user_email_code(
|
||||
summary='初始化QQ登录',
|
||||
description='Initialize QQ login for a user.',
|
||||
)
|
||||
def router_user_qq() -> models.ResponseBase:
|
||||
def router_user_qq() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize QQ login for a user.
|
||||
|
||||
@@ -194,7 +232,7 @@ def router_user_qq() -> models.ResponseBase:
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
)
|
||||
async def router_user_authn(username: str) -> models.ResponseBase:
|
||||
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
|
||||
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@@ -203,7 +241,7 @@ async def router_user_authn(username: str) -> models.ResponseBase:
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
)
|
||||
def router_user_authn_finish(username: str) -> models.ResponseBase:
|
||||
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
@@ -220,7 +258,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase:
|
||||
summary='获取用户主页展示用分享',
|
||||
description='Get user profile for display.',
|
||||
)
|
||||
def router_user_profile(id: str) -> models.ResponseBase:
|
||||
def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user profile for display.
|
||||
|
||||
@@ -237,7 +275,7 @@ def router_user_profile(id: str) -> models.ResponseBase:
|
||||
summary='获取用户头像',
|
||||
description='Get user avatar by ID and size.',
|
||||
)
|
||||
def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
|
||||
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user avatar by ID and size.
|
||||
|
||||
@@ -259,12 +297,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
|
||||
summary='获取用户信息',
|
||||
description='Get user information.',
|
||||
dependencies=[Depends(dependency=auth_required)],
|
||||
response_model=models.UserResponse,
|
||||
response_model=sqlmodels.UserResponse,
|
||||
)
|
||||
async def router_user_me(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserResponse:
|
||||
"""
|
||||
获取用户信息.
|
||||
|
||||
@@ -272,10 +310,10 @@ async def router_user_me(
|
||||
:rtype: ResponseBase
|
||||
"""
|
||||
# 加载 group 及其 options 关系
|
||||
group = await models.Group.get(
|
||||
group = await sqlmodels.Group.get(
|
||||
session,
|
||||
models.Group.id == user.group_id,
|
||||
load=models.Group.options
|
||||
sqlmodels.Group.id == user.group_id,
|
||||
load=sqlmodels.Group.options
|
||||
)
|
||||
|
||||
# 构建 GroupResponse
|
||||
@@ -284,9 +322,9 @@ async def router_user_me(
|
||||
# 异步加载 tags 关系
|
||||
user_tags = await user.awaitable_attrs.tags
|
||||
|
||||
return models.UserResponse(
|
||||
return sqlmodels.UserResponse(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
status=user.status,
|
||||
score=user.score,
|
||||
nickname=user.nickname,
|
||||
@@ -304,30 +342,26 @@ async def router_user_me(
|
||||
)
|
||||
async def router_user_storage(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserStorageResponse:
|
||||
"""
|
||||
获取用户存储空间信息。
|
||||
|
||||
返回值:
|
||||
- used: 已使用空间(字节)
|
||||
- free: 剩余空间(字节)
|
||||
- total: 总容量(字节)= 用户组容量
|
||||
"""
|
||||
# 获取用户组的基础存储容量
|
||||
group = await models.Group.get(session, models.Group.id == user.group_id)
|
||||
group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=500, detail="用户组不存在")
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
|
||||
# [TODO] 总空间加上用户购买的额外空间
|
||||
|
||||
total: int = group.max_storage
|
||||
used: int = user.storage
|
||||
free: int = max(0, total - used)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={
|
||||
"used": used,
|
||||
"free": free,
|
||||
"total": total,
|
||||
}
|
||||
return sqlmodels.UserStorageResponse(
|
||||
used=used,
|
||||
free=free,
|
||||
total=total,
|
||||
)
|
||||
|
||||
@user_router.put(
|
||||
@@ -338,8 +372,8 @@ async def router_user_storage(
|
||||
)
|
||||
async def router_user_authn_start(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize WebAuthn login for a user.
|
||||
|
||||
@@ -347,30 +381,30 @@ async def router_user_authn_start(
|
||||
dict: A dictionary containing WebAuthn initialization information.
|
||||
"""
|
||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
||||
authn_setting = await models.Setting.get(
|
||||
authn_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "authn") & (models.Setting.name == "authn_enabled")
|
||||
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
|
||||
)
|
||||
if not authn_setting or authn_setting.value != "1":
|
||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
||||
|
||||
site_url_setting = await models.Setting.get(
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "basic") & (models.Setting.name == "siteURL")
|
||||
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteURL")
|
||||
)
|
||||
site_title_setting = await models.Setting.get(
|
||||
site_title_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "basic") & (models.Setting.name == "siteTitle")
|
||||
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteTitle")
|
||||
)
|
||||
|
||||
options = generate_registration_options(
|
||||
rp_id=site_url_setting.value if site_url_setting else "",
|
||||
rp_name=site_title_setting.value if site_title_setting else "",
|
||||
user_name=user.username,
|
||||
user_display_name=user.nick or user.username,
|
||||
user_name=user.email,
|
||||
user_display_name=user.nickname or user.email,
|
||||
)
|
||||
|
||||
return models.ResponseBase(data=options_to_json_dict(options))
|
||||
return sqlmodels.ResponseBase(data=options_to_json_dict(options))
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/finish',
|
||||
@@ -378,179 +412,11 @@ async def router_user_authn_start(
|
||||
description='Finish WebAuthn login for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_authn_finish() -> models.ResponseBase:
|
||||
def router_user_authn_finish() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
summary='获取用户可选存储策略',
|
||||
description='Get user selectable storage policies.',
|
||||
)
|
||||
def router_user_settings_policies() -> models.ResponseBase:
|
||||
"""
|
||||
Get user selectable storage policies.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/nodes',
|
||||
summary='获取用户可选节点',
|
||||
description='Get user selectable nodes.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_nodes() -> models.ResponseBase:
|
||||
"""
|
||||
Get user selectable nodes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available nodes for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/tasks',
|
||||
summary='任务队列',
|
||||
description='Get user task queue.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_tasks() -> models.ResponseBase:
|
||||
"""
|
||||
Get user task queue.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the user's task queue information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/',
|
||||
summary='获取当前用户设定',
|
||||
description='Get current user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings() -> models.ResponseBase:
|
||||
"""
|
||||
Get current user settings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the current user settings.
|
||||
"""
|
||||
return models.ResponseBase(data=models.UserSettingResponse().model_dump())
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/avatar',
|
||||
summary='从文件上传头像',
|
||||
description='Upload user avatar from file.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar() -> models.ResponseBase:
|
||||
"""
|
||||
Upload user avatar from file.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the avatar upload.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.put(
|
||||
path='/avatar',
|
||||
summary='设定为Gravatar头像',
|
||||
description='Set user avatar to Gravatar.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar_gravatar() -> models.ResponseBase:
|
||||
"""
|
||||
Set user avatar to Gravatar.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.patch(
|
||||
path='/{option}',
|
||||
summary='更新用户设定',
|
||||
description='Update user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_patch(option: str) -> models.ResponseBase:
|
||||
"""
|
||||
Update user settings.
|
||||
|
||||
Args:
|
||||
option (str): The setting option to update.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the settings update.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/2fa',
|
||||
summary='获取两步验证初始化信息',
|
||||
description='Get two-factor authentication initialization information.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa(
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
Get two-factor authentication initialization information.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing two-factor authentication setup information.
|
||||
"""
|
||||
|
||||
return models.ResponseBase(
|
||||
data=await Password.generate_totp(user.username)
|
||||
)
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/2fa',
|
||||
summary='启用两步验证',
|
||||
description='Enable two-factor authentication.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa_enable(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
setup_token: str,
|
||||
code: str,
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
Enable two-factor authentication for the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of enabling two-factor authentication.
|
||||
"""
|
||||
|
||||
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||
|
||||
try:
|
||||
# 1. 解包 Token,设置有效期(例如 600秒)
|
||||
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||
except BadSignature:
|
||||
raise HTTPException(status_code=400, detail="Invalid token")
|
||||
|
||||
# 2. 验证用户输入的 6 位验证码
|
||||
if not Password.verify_totp(secret, code):
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={"message": "Two-factor authentication enabled successfully"}
|
||||
)
|
||||
http_exceptions.raise_not_implemented()
|
||||
203
routers/api/v1/user/settings/__init__.py
Normal file
203
routers/api/v1/user/settings/__init__.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
|
||||
import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from utils import JWT, Password, http_exceptions
|
||||
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
summary='获取用户可选存储策略',
|
||||
description='Get user selectable storage policies.',
|
||||
)
|
||||
def router_user_settings_policies() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user selectable storage policies.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/nodes',
|
||||
summary='获取用户可选节点',
|
||||
description='Get user selectable nodes.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_nodes() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user selectable nodes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available nodes for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/tasks',
|
||||
summary='任务队列',
|
||||
description='Get user task queue.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_tasks() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user task queue.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the user's task queue information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/',
|
||||
summary='获取当前用户设定',
|
||||
description='Get current user settings.',
|
||||
)
|
||||
def router_user_settings(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserSettingResponse:
|
||||
"""
|
||||
Get current user settings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the current user settings.
|
||||
"""
|
||||
return sqlmodels.UserSettingResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
nickname=user.nickname,
|
||||
created_at=user.created_at,
|
||||
group_name=user.group.name,
|
||||
language=user.language,
|
||||
timezone=user.timezone,
|
||||
group_expires=user.group_expires,
|
||||
two_factor=user.two_factor is not None,
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/avatar',
|
||||
summary='从文件上传头像',
|
||||
description='Upload user avatar from file.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Upload user avatar from file.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the avatar upload.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.put(
|
||||
path='/avatar',
|
||||
summary='设定为Gravatar头像',
|
||||
description='Set user avatar to Gravatar.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar_gravatar() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Set user avatar to Gravatar.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.patch(
|
||||
path='/{option}',
|
||||
summary='更新用户设定',
|
||||
description='Update user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Update user settings.
|
||||
|
||||
Args:
|
||||
option (str): The setting option to update.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the settings update.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/2fa',
|
||||
summary='获取两步验证初始化信息',
|
||||
description='Get two-factor authentication initialization information.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get two-factor authentication initialization information.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing two-factor authentication setup information.
|
||||
"""
|
||||
|
||||
return sqlmodels.ResponseBase(
|
||||
data=await Password.generate_totp(user.email)
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/2fa',
|
||||
summary='启用两步验证',
|
||||
description='Enable two-factor authentication.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa_enable(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
setup_token: str,
|
||||
code: str,
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Enable two-factor authentication for the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of enabling two-factor authentication.
|
||||
"""
|
||||
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
|
||||
try:
|
||||
# 1. 解包 Token,设置有效期(例如 600秒)
|
||||
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||
except BadSignature:
|
||||
raise HTTPException(status_code=400, detail="Invalid token")
|
||||
|
||||
# 2. 验证用户输入的 6 位验证码
|
||||
if not Password.verify_totp(secret, code):
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
|
||||
return sqlmodels.ResponseBase(
|
||||
data={"message": "Two-factor authentication enabled successfully"}
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
vas_router = APIRouter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
# WebDAV 管理路由
|
||||
|
||||
@@ -15,7 +15,7 @@ import aiofiles
|
||||
import aiofiles.os
|
||||
from loguru import logger as l
|
||||
|
||||
from models.policy import Policy
|
||||
from sqlmodels.policy import Policy
|
||||
from .exceptions import (
|
||||
DirectoryCreationError,
|
||||
FileReadError,
|
||||
|
||||
@@ -23,7 +23,7 @@ import string
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
|
||||
class NamingContext(SQLModelBase):
|
||||
|
||||
@@ -3,7 +3,7 @@ from uuid import uuid4
|
||||
from loguru import logger
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import LoginRequest, TokenResponse, User
|
||||
from sqlmodels import LoginRequest, TokenResponse, User
|
||||
from utils import http_exceptions
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
@@ -30,17 +30,17 @@ async def login(
|
||||
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
||||
|
||||
# 获取用户信息
|
||||
current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore
|
||||
current_user: User = await User.get(session, User.email == login_request.email, fetch_mode="first") #type: ignore
|
||||
|
||||
# 验证用户是否存在
|
||||
if not current_user:
|
||||
logger.debug(f"Cannot find user with username: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid username or password")
|
||||
logger.debug(f"Cannot find user with email: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
|
||||
# 验证密码是否正确
|
||||
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
|
||||
logger.debug(f"Password verification failed for user: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid username or password")
|
||||
logger.debug(f"Password verification failed for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
|
||||
# 验证用户是否可登录
|
||||
if not current_user.status:
|
||||
@@ -50,23 +50,23 @@ async def login(
|
||||
if current_user.two_factor:
|
||||
# 用户已启用两步验证
|
||||
if not login_request.two_fa_code:
|
||||
logger.debug(f"2FA required for user: {login_request.username}")
|
||||
logger.debug(f"2FA required for user: {login_request.email}")
|
||||
http_exceptions.raise_precondition_required("2FA required")
|
||||
|
||||
# 验证 OTP 码
|
||||
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.username}")
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid 2FA code")
|
||||
|
||||
# 创建令牌
|
||||
access_token = create_access_token(data={
|
||||
'sub': str(current_user.id),
|
||||
'jti': str(uuid4())
|
||||
})
|
||||
refresh_token = create_refresh_token(data={
|
||||
'sub': str(current_user.id),
|
||||
'jti': str(uuid4())
|
||||
})
|
||||
access_token = create_access_token(
|
||||
sub=current_user.id,
|
||||
jti=uuid4()
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
sub=current_user.id,
|
||||
jti=uuid4()
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token.access_token,
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from .user import (
|
||||
BatchDeleteRequest,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
RegisterRequest,
|
||||
AccessTokenBase,
|
||||
RefreshTokenBase,
|
||||
TokenResponse,
|
||||
User,
|
||||
UserBase,
|
||||
UserStorageResponse,
|
||||
UserPublic,
|
||||
UserResponse,
|
||||
UserSettingResponse,
|
||||
@@ -66,6 +69,7 @@ from .object import (
|
||||
FileBanRequest,
|
||||
)
|
||||
from .physical_file import PhysicalFile, PhysicalFileBase
|
||||
from .uri import DiskNextURI, FileSystemNamespace
|
||||
from .order import Order, OrderStatus, OrderType
|
||||
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
|
||||
from .redeem import Redeem, RedeemType
|
||||
@@ -82,7 +86,7 @@ from .tag import Tag, TagType
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
||||
from .webdav import WebDAV
|
||||
|
||||
from .database import engine, get_session
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
from .model_base import (
|
||||
MCPBase,
|
||||
@@ -630,7 +630,7 @@ For developers modifying this module:
|
||||
- Handles Python 3.14 annotations via `get_type_hints()`
|
||||
|
||||
**Metaclass processing order**:
|
||||
1. Check if class should be a table (`_is_table_mixin`)
|
||||
1. Check if class should be a table (`_has_table_mixin`)
|
||||
2. Collect `__mapper_args__` from kwargs and explicit dict
|
||||
3. Process `table_args`, `table_name`, `abstract` parameters
|
||||
4. Resolve annotations using `get_type_hints()`
|
||||
@@ -5,8 +5,8 @@ SQLModel 基础模块
|
||||
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
|
||||
|
||||
注意:
|
||||
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 models.mixin
|
||||
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
|
||||
为了避免循环导入,此处不再重新导出它们
|
||||
请直接从 models.mixin 导入这些类
|
||||
请直接从 sqlmodels.mixin 导入这些类
|
||||
"""
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
@@ -414,7 +414,7 @@ class __DeclarativeMeta(SQLModelMetaclass):
|
||||
|
||||
def __new__(cls, name, bases, attrs, **kwargs):
|
||||
# 1. 约定优于配置:自动设置 table=True
|
||||
is_intended_as_table = any(getattr(b, '_is_table_mixin', False) for b in bases)
|
||||
is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
|
||||
if is_intended_as_table and 'table' not in kwargs:
|
||||
kwargs['table'] = True
|
||||
|
||||
78
sqlmodels/database_connection.py
Normal file
78
sqlmodels/database_connection.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import AsyncGenerator, ClassVar
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import NullPool, AsyncAdaptedQueuePool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
engine: ClassVar[AsyncEngine | None] = None
|
||||
_async_session_factory: ClassVar[sessionmaker | None] = None
|
||||
|
||||
@classmethod
|
||||
async def get_session(cls) -> AsyncGenerator[AsyncSession]:
|
||||
assert cls._async_session_factory is not None, "数据库引擎未初始化,请先调用 DatabaseManager.init()"
|
||||
async with cls._async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
@classmethod
|
||||
async def init(
|
||||
cls,
|
||||
database_url: str,
|
||||
debug: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化数据库连接引擎。
|
||||
|
||||
:param database_url: 数据库连接URL
|
||||
:param debug: 是否开启调试模式
|
||||
"""
|
||||
# 构建引擎参数
|
||||
engine_kwargs: dict = {
|
||||
'echo': debug,
|
||||
'future': True,
|
||||
}
|
||||
|
||||
if debug:
|
||||
# Debug 模式使用 NullPool(无连接池,每次创建新连接)
|
||||
engine_kwargs['poolclass'] = NullPool
|
||||
else:
|
||||
# 生产模式使用 AsyncAdaptedQueuePool 连接池
|
||||
engine_kwargs.update({
|
||||
'poolclass': AsyncAdaptedQueuePool,
|
||||
'pool_size': 40,
|
||||
'max_overflow': 80,
|
||||
'pool_timeout': 30,
|
||||
'pool_recycle': 1800,
|
||||
'pool_pre_ping': True,
|
||||
})
|
||||
|
||||
# 只在需要时添加 connect_args
|
||||
if database_url.startswith("sqlite"):
|
||||
engine_kwargs['connect_args'] = {'check_same_thread': False}
|
||||
|
||||
cls.engine = create_async_engine(database_url, **engine_kwargs)
|
||||
|
||||
cls._async_session_factory = sessionmaker(cls.engine, class_=AsyncSession)
|
||||
|
||||
# 开发阶段直接 create_all 创建表结构
|
||||
async with cls.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
logger.info("数据库引擎初始化完成")
|
||||
|
||||
@classmethod
|
||||
async def close(cls):
|
||||
"""
|
||||
优雅地关闭数据库连接引擎。
|
||||
仅应在应用结束时调用。
|
||||
"""
|
||||
if cls.engine:
|
||||
logger.info("正在关闭数据库连接引擎...")
|
||||
await cls.engine.dispose()
|
||||
logger.info("数据库连接引擎已成功关闭。")
|
||||
else:
|
||||
logger.info("数据库连接引擎未初始化,无需关闭。")
|
||||
@@ -104,9 +104,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
Setting(name="captcha_IsShowSlimeLine", value="1", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowSineLine", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CaptchaLen", value="6", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsUseReCaptcha", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaKey", value="defaultKey", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaSecret", value="defaultSecret", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_type", value="default", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaKey", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaSecret", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CloudflareKey", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CloudflareSecret", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="thumb_width", value="400", type=SettingsType.THUMB),
|
||||
Setting(name="thumb_height", value="300", type=SettingsType.THUMB),
|
||||
Setting(name="pwa_small_icon", value="/static/img/favicon.ico", type=SettingsType.PWA),
|
||||
@@ -119,11 +121,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
|
||||
async def init_default_settings() -> None:
|
||||
from .setting import Setting
|
||||
from .database import get_session
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化设置...')
|
||||
|
||||
async for session in get_session():
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查是否已经存在版本设置
|
||||
ver = await Setting.get(
|
||||
session,
|
||||
@@ -139,11 +141,11 @@ async def init_default_group() -> None:
|
||||
from .group import Group, GroupOptions
|
||||
from .policy import Policy, GroupPolicyLink
|
||||
from .setting import Setting
|
||||
from .database import get_session
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化用户组...')
|
||||
|
||||
async for session in get_session():
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 获取默认存储策略
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
default_policy_id = default_policy.id if default_policy else None
|
||||
@@ -231,13 +233,20 @@ async def init_default_user() -> None:
|
||||
from .group import Group
|
||||
from .object import Object, ObjectType
|
||||
from .policy import Policy
|
||||
from .database import get_session
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化管理员用户...')
|
||||
|
||||
async for session in get_session():
|
||||
# 检查管理员用户是否存在
|
||||
admin_user = await User.get(session, User.username == "admin")
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查管理员用户是否存在(通过 Setting 中的 default_admin_id 判断)
|
||||
admin_id_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||
)
|
||||
admin_user = None
|
||||
if admin_id_setting and admin_id_setting.value:
|
||||
from uuid import UUID
|
||||
admin_user = await User.get(session, User.id == UUID(admin_id_setting.value))
|
||||
|
||||
if not admin_user:
|
||||
# 获取管理员组
|
||||
@@ -256,18 +265,24 @@ async def init_default_user() -> None:
|
||||
hashed_admin_password = Password.hash(admin_password)
|
||||
|
||||
admin_user = User(
|
||||
username="admin",
|
||||
email="admin@disknext.local",
|
||||
nickname="admin",
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
)
|
||||
admin_user_id = admin_user.id # 在 save 前保存 UUID
|
||||
admin_username = admin_user.username
|
||||
await admin_user.save(session)
|
||||
|
||||
# 为管理员创建根目录(使用用户名作为目录名)
|
||||
# 记录默认管理员 ID 到 Setting
|
||||
await Setting(
|
||||
name="default_admin_id",
|
||||
value=str(admin_user_id),
|
||||
type=SettingsType.AUTH,
|
||||
).save(session)
|
||||
|
||||
# 为管理员创建根目录
|
||||
await Object(
|
||||
name=admin_username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=admin_user_id,
|
||||
parent_id=None,
|
||||
@@ -275,18 +290,18 @@ async def init_default_user() -> None:
|
||||
).save(session)
|
||||
|
||||
log.warning('请注意,账号密码仅显示一次,请妥善保管')
|
||||
log.info(f'初始管理员账号: admin')
|
||||
log.info(f'初始管理员邮箱: admin@disknext.local')
|
||||
log.info(f'初始管理员密码: {admin_password}')
|
||||
|
||||
|
||||
async def init_default_policy() -> None:
|
||||
from .policy import Policy, PolicyType
|
||||
from .database import get_session
|
||||
from .database_connection import DatabaseManager
|
||||
from service.storage import LocalStorageService
|
||||
|
||||
log.info('初始化默认存储策略...')
|
||||
|
||||
async for session in get_session():
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查默认存储策略是否存在
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
|
||||
@@ -5,42 +5,58 @@ SQLModel Mixin模块
|
||||
|
||||
包含:
|
||||
- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
|
||||
- optimistic_lock: 乐观锁(OptimisticLockMixin, OptimisticLockError)
|
||||
- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
|
||||
- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
|
||||
- relation_preload: 关系预加载(RelationPreloadMixin, requires_relations)
|
||||
- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
|
||||
- info_response: InfoResponse DTO的id/时间戳Mixin
|
||||
|
||||
导入顺序很重要,避免循环导入:
|
||||
1. polymorphic(只依赖 SQLModelBase)
|
||||
2. table(依赖 polymorphic)
|
||||
2. optimistic_lock(只依赖 SQLAlchemy)
|
||||
3. table(依赖 polymorphic 和 optimistic_lock)
|
||||
4. relation_preload(只依赖 SQLModelBase)
|
||||
|
||||
注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
|
||||
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
|
||||
"""
|
||||
# polymorphic 必须先导入
|
||||
from .polymorphic import (
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
register_sti_column_properties_for_all_subclasses,
|
||||
register_sti_columns_for_all_subclasses,
|
||||
)
|
||||
# table 依赖 polymorphic
|
||||
# optimistic_lock 只依赖 SQLAlchemy,必须在 table 之前
|
||||
from .optimistic_lock import (
|
||||
OptimisticLockError,
|
||||
OptimisticLockMixin,
|
||||
)
|
||||
# table 依赖 polymorphic 和 optimistic_lock
|
||||
from .table import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
TimeFilterRequest,
|
||||
PaginationRequest,
|
||||
TableViewRequest,
|
||||
ListResponse,
|
||||
PaginationRequest,
|
||||
T,
|
||||
TableBaseMixin,
|
||||
TableViewRequest,
|
||||
TimeFilterRequest,
|
||||
UUIDTableBaseMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
# relation_preload 只依赖 SQLModelBase
|
||||
from .relation_preload import (
|
||||
RelationPreloadMixin,
|
||||
requires_relations,
|
||||
)
|
||||
# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
|
||||
# 需要时直接从 sqlmodels.mixin.jwt 导入
|
||||
from .info_response import (
|
||||
IntIdInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
DatetimeInfoMixin,
|
||||
IntIdDatetimeInfoMixin,
|
||||
IntIdInfoMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
)
|
||||
@@ -12,7 +12,7 @@ InfoResponse DTO Mixin模块
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
|
||||
class IntIdInfoMixin(SQLModelBase):
|
||||
90
sqlmodels/mixin/optimistic_lock.py
Normal file
90
sqlmodels/mixin/optimistic_lock.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
|
||||
|
||||
乐观锁适用场景:
|
||||
- 涉及"状态转换"的表(如:待支付 -> 已支付)
|
||||
- 涉及"数值变动"的表(如:余额、库存)
|
||||
|
||||
不适用场景:
|
||||
- 日志表、纯插入表、低价值统计表
|
||||
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
|
||||
|
||||
使用示例:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
status: OrderStatusEnum
|
||||
amount: Decimal
|
||||
|
||||
# save/update 时自动检查版本号
|
||||
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
|
||||
try:
|
||||
order = await order.save(session)
|
||||
except OptimisticLockError as e:
|
||||
# 处理冲突:重新查询并重试,或报错给用户
|
||||
l.warning(f"乐观锁冲突: {e}")
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
|
||||
|
||||
class OptimisticLockError(Exception):
|
||||
"""
|
||||
乐观锁冲突异常
|
||||
|
||||
当 save/update 操作检测到版本号不匹配时抛出。
|
||||
这意味着在读取和写入之间,其他事务已经修改了该记录。
|
||||
|
||||
Attributes:
|
||||
model_class: 发生冲突的模型类名
|
||||
record_id: 记录 ID(如果可用)
|
||||
expected_version: 期望的版本号(如果可用)
|
||||
original_error: 原始的 StaleDataError
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
model_class: str | None = None,
|
||||
record_id: str | None = None,
|
||||
expected_version: int | None = None,
|
||||
original_error: StaleDataError | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.model_class = model_class
|
||||
self.record_id = record_id
|
||||
self.expected_version = expected_version
|
||||
self.original_error = original_error
|
||||
|
||||
|
||||
class OptimisticLockMixin:
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
|
||||
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
|
||||
session.commit() 会抛出 StaleDataError,被 save/update 方法捕获并转换为 OptimisticLockError。
|
||||
|
||||
原理:
|
||||
1. 每条记录有一个 version 字段,初始值为 0
|
||||
2. 每次 UPDATE 时,SQLAlchemy 生成的 SQL 类似:
|
||||
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
|
||||
3. 如果 WHERE 条件不匹配(version 已被其他事务修改),
|
||||
UPDATE 影响 0 行,SQLAlchemy 抛出 StaleDataError
|
||||
|
||||
继承顺序:
|
||||
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
...
|
||||
|
||||
配套重试:
|
||||
如果加了乐观锁,业务层需要处理 OptimisticLockError:
|
||||
- 报错给用户:"数据已被修改,请刷新后重试"
|
||||
- 自动重试:重新查询最新数据并再次尝试
|
||||
"""
|
||||
_has_optimistic_lock: ClassVar[bool] = True
|
||||
"""标记此类启用了乐观锁"""
|
||||
|
||||
version: int = 0
|
||||
"""乐观锁版本号,每次更新自动递增"""
|
||||
710
sqlmodels/mixin/polymorphic.py
Normal file
710
sqlmodels/mixin/polymorphic.py
Normal file
@@ -0,0 +1,710 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import Column, String, inspect
|
||||
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
from sqlmodel.main import get_column_from_field
|
||||
|
||||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# 用于延迟注册 STI 子类列的队列
|
||||
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
|
||||
_sti_subclasses_to_register: list[type] = []
|
||||
|
||||
|
||||
def register_sti_columns_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
|
||||
|
||||
此函数应在 configure_mappers() 之前调用。
|
||||
将 STI 子类的字段添加到父表的 metadata 中。
|
||||
同时修复被 Column 对象污染的 model_fields。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_columns()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
|
||||
|
||||
# 修复被 Column 对象污染的 model_fields
|
||||
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
|
||||
try:
|
||||
_fix_polluted_model_fields(cls)
|
||||
except Exception as e:
|
||||
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
|
||||
|
||||
|
||||
def register_sti_column_properties_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类添加列属性到 mapper(第二阶段)
|
||||
|
||||
此函数应在 configure_mappers() 之后调用。
|
||||
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_column_properties()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
|
||||
|
||||
# 清空队列
|
||||
_sti_subclasses_to_register.clear()
|
||||
|
||||
|
||||
def _fix_polluted_model_fields(cls: type) -> None:
|
||||
"""
|
||||
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
|
||||
|
||||
当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute
|
||||
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
|
||||
时错误地使用这些 SQLAlchemy 对象作为默认值。
|
||||
|
||||
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
|
||||
|
||||
:param cls: 要修复的类
|
||||
"""
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 或 Column 污染的
|
||||
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute 或 Column)
|
||||
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
|
||||
# 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值
|
||||
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制 SQLModel 特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的 default_factory 丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin,并支持STI子类列注册
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
**重要:数据库迁移注意事项**
|
||||
|
||||
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
|
||||
|
||||
例如,对于以下继承链::
|
||||
|
||||
LLM (polymorphic_on='_polymorphic_name')
|
||||
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
|
||||
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
|
||||
|
||||
迁移脚本中设置 _polymorphic_name 时::
|
||||
|
||||
# ❌ 错误:缺少父类前缀
|
||||
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
|
||||
|
||||
# ✅ 正确:包含完整的继承链前缀
|
||||
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
|
||||
|
||||
**STI(单表继承)支持**:
|
||||
当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段
|
||||
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
|
||||
注册到父表的问题。
|
||||
|
||||
Example (JTI):
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Example (STI):
|
||||
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
|
||||
... user_id: UUID
|
||||
...
|
||||
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
|
||||
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
|
||||
... # polymorphic_identity 会自动设置为 'pendingfile'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
- STI模式下,新字段会在类定义时自动添加到父表的metadata中
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity并处理STI列注册
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
# 注册 STI 子类列的延迟执行
|
||||
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
|
||||
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
|
||||
_sti_subclasses_to_register.append(cls)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复 STI 继承中子类字段被 Column 对象污染的问题。
|
||||
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
|
||||
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
@classmethod
|
||||
def _register_sti_columns(cls) -> None:
|
||||
"""
|
||||
将STI子类的新字段注册到父表的列定义中
|
||||
|
||||
检测当前类是否是STI子类(与父类共用同一张表),
|
||||
如果是,则将子类定义的新字段添加到父表的metadata中。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_fields: set[str] = set()
|
||||
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
# 收集父类的所有字段名
|
||||
if hasattr(base, 'model_fields'):
|
||||
parent_fields.update(base.model_fields.keys())
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列注册到父表
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列注册
|
||||
|
||||
# 获取当前类的新字段(不在父类中的字段)
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
existing_columns = {col.name for col in parent_table.columns}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段和ClassVar
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 跳过已存在的列
|
||||
if field_name in existing_columns:
|
||||
continue
|
||||
|
||||
# 使用 SQLModel 的内置 API 创建列
|
||||
try:
|
||||
column = get_column_from_field(field_info)
|
||||
column.name = field_name
|
||||
column.key = field_name
|
||||
# STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值
|
||||
# Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段)
|
||||
column.nullable = True
|
||||
|
||||
# 将列添加到父表
|
||||
parent_table.append_column(column)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}")
|
||||
|
||||
@classmethod
|
||||
def _register_sti_column_properties(cls) -> None:
|
||||
"""
|
||||
将 STI 子类的列作为 ColumnProperty 添加到 mapper
|
||||
|
||||
此方法在 configure_mappers() 之后调用,将已添加到表中的列
|
||||
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
|
||||
|
||||
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
|
||||
这确保了查询父类时,SELECT 语句包含所有 STI 子类的列,
|
||||
避免在响应序列化时触发懒加载(MissingGreenlet 错误)。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_class = None
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
parent_class = base
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列属性注册
|
||||
|
||||
# 获取子类和父类的 mapper
|
||||
child_mapper = inspect(cls).mapper
|
||||
parent_mapper = inspect(parent_class).mapper
|
||||
local_table = child_mapper.local_table
|
||||
|
||||
# 查找父类的所有字段名
|
||||
parent_fields: set[str] = set()
|
||||
if hasattr(parent_class, 'model_fields'):
|
||||
parent_fields.update(parent_class.model_fields.keys())
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
# 获取两个 mapper 已有的列属性
|
||||
child_existing_props = {p.key for p in child_mapper.column_attrs}
|
||||
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
|
||||
|
||||
for field_name in cls.model_fields:
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 检查表中是否有这个列
|
||||
if field_name not in local_table.columns:
|
||||
continue
|
||||
|
||||
column = local_table.columns[field_name]
|
||||
|
||||
# 添加到子类的 mapper(如果尚不存在)
|
||||
if field_name not in child_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
child_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}")
|
||||
|
||||
# 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列)
|
||||
if field_name not in parent_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
parent_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def _is_joined_table_inheritance(cls) -> bool:
|
||||
"""
|
||||
检测当前类是否使用联表继承(Joined Table Inheritance)
|
||||
|
||||
通过检查子类是否有独立的表来判断:
|
||||
- JTI: 子类有独立的 local_table(与父类不同)
|
||||
- STI: 子类与父类共用同一个 local_table
|
||||
|
||||
:return: True 表示 JTI,False 表示 STI 或无子类
|
||||
"""
|
||||
mapper = inspect(cls)
|
||||
base_table_name = mapper.local_table.name
|
||||
|
||||
# 检查所有直接子类
|
||||
for subclass in cls.__subclasses__():
|
||||
sub_mapper = inspect(subclass)
|
||||
# 如果任何子类有不同的表名,说明是 JTI
|
||||
if sub_mapper.local_table.name != base_table_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
470
sqlmodels/mixin/relation_preload.py
Normal file
470
sqlmodels/mixin/relation_preload.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
|
||||
|
||||
设计原则:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 查询最优:相同关系只查询一次,不同关系增量查询
|
||||
- 零侵入:调用方无需任何改动
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
|
||||
|
||||
使用方式:
|
||||
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
|
||||
|
||||
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
|
||||
kling_video_generator: KlingO1Generator = Relationship(...)
|
||||
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# 自动加载,可以安全访问
|
||||
price = self.kling_video_generator.kling_o1.pro_price_per_second
|
||||
...
|
||||
|
||||
# 调用方 - 无需任何改动
|
||||
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
|
||||
await tool._call(...) # 关系相同则跳过,否则增量加载
|
||||
|
||||
支持 AsyncGenerator:
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # 装饰器正确处理 async generator
|
||||
"""
|
||||
import inspect as python_inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar, ParamSpec, Any
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
P = ParamSpec('P')
|
||||
R = TypeVar('R')
|
||||
|
||||
|
||||
def _extract_session(
|
||||
func: Callable,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> AsyncSession | None:
|
||||
"""
|
||||
从方法参数中提取 AsyncSession
|
||||
|
||||
按以下顺序查找:
|
||||
1. kwargs 中名为 'session' 的参数
|
||||
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
|
||||
3. kwargs 中类型为 AsyncSession 的参数
|
||||
"""
|
||||
# 1. 优先从 kwargs 查找
|
||||
if 'session' in kwargs:
|
||||
return kwargs['session']
|
||||
|
||||
# 2. 从函数签名定位位置参数
|
||||
try:
|
||||
sig = python_inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
if 'session' in param_names:
|
||||
# 计算位置(减去 self)
|
||||
idx = param_names.index('session') - 1
|
||||
if 0 <= idx < len(args):
|
||||
return args[idx]
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. 遍历 kwargs 找 AsyncSession 类型
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, AsyncSession):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
|
||||
"""
|
||||
检查对象的关系是否已加载(独立函数版本)
|
||||
|
||||
Args:
|
||||
obj: 要检查的对象
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(obj)
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
|
||||
"""
|
||||
在类中查找指向目标类的关系属性名
|
||||
|
||||
Args:
|
||||
from_class: 源类
|
||||
to_class: 目标类
|
||||
|
||||
Returns:
|
||||
关系属性名,如果找不到则返回 None
|
||||
|
||||
Example:
|
||||
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
|
||||
# 返回 'kling_video_generator'
|
||||
"""
|
||||
for attr_name in dir(from_class):
|
||||
try:
|
||||
attr = getattr(from_class, attr_name, None)
|
||||
if attr is None:
|
||||
continue
|
||||
# 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
|
||||
# parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
|
||||
target_class = attr.property.mapper.class_
|
||||
if target_class == to_class:
|
||||
return attr_name
|
||||
except AttributeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""
|
||||
装饰器:声明方法需要的关系,自动按需增量加载
|
||||
|
||||
参数格式:
|
||||
- 字符串:本类属性名,如 'kling_video_generator'
|
||||
- RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
|
||||
|
||||
行为:
|
||||
- 方法调用时自动检查关系是否已加载
|
||||
- 未加载的关系会被增量加载(单次查询)
|
||||
- 已加载的关系直接跳过
|
||||
|
||||
支持:
|
||||
- 普通 async 方法:`async def cost(...) -> ToolCost`
|
||||
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
|
||||
|
||||
Example:
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# self.kling_video_generator.kling_o1 已自动加载
|
||||
...
|
||||
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # AsyncGenerator 正确处理
|
||||
|
||||
验证:
|
||||
- 字符串格式的关系名在类创建时(__init_subclass__)验证
|
||||
- 拼写错误会在导入时抛出 AttributeError
|
||||
"""
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
# 检测是否是 async generator 函数
|
||||
is_async_gen = python_inspect.isasyncgenfunction(func)
|
||||
|
||||
if is_async_gen:
|
||||
# AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
# 委托给原始 async generator,逐个 yield 值
|
||||
async for item in func(self, *args, **kwargs):
|
||||
yield item # type: ignore
|
||||
else:
|
||||
# 普通 async 函数:await 并返回结果
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
# 保存关系声明供验证和内省使用
|
||||
wrapper._required_relations = relations # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RelationPreloadMixin:
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供按需增量加载能力,确保 SQL 查询数理论最优。
|
||||
|
||||
特性:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 原地更新:直接修改 self,无需替换实例
|
||||
- 导入时验证:字符串关系名在类创建时验证
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""类创建时验证所有 @requires_relations 声明"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 收集类及其父类的所有注解(包含普通字段)
|
||||
all_annotations: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__annotations__'):
|
||||
all_annotations.update(klass.__annotations__.keys())
|
||||
|
||||
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
|
||||
sqlmodel_relationships: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__sqlmodel_relationships__'):
|
||||
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
|
||||
|
||||
# 合并所有可用的属性名
|
||||
all_available_names = all_annotations | sqlmodel_relationships
|
||||
|
||||
for method_name in dir(cls):
|
||||
if method_name.startswith('__'):
|
||||
continue
|
||||
|
||||
try:
|
||||
method = getattr(cls, method_name, None)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
continue
|
||||
|
||||
# 验证字符串格式的关系名
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
# 检查注解、Relationship 或已有属性
|
||||
if spec not in all_available_names and not hasattr(cls, spec):
|
||||
raise AttributeError(
|
||||
f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
|
||||
f"但 {cls.__name__} 没有此属性"
|
||||
)
|
||||
|
||||
def _is_relation_loaded(self, rel_name: str) -> bool:
|
||||
"""
|
||||
检查关系是否真正已加载(基于 SQLAlchemy inspect)
|
||||
|
||||
使用 SQLAlchemy 的 inspect 检测真实加载状态,
|
||||
自动处理 commit 导致的 expire 问题。
|
||||
|
||||
Args:
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(self)
|
||||
# unloaded 包含未加载的关系属性名
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
# 对象可能未被 SQLAlchemy 管理
|
||||
return False
|
||||
|
||||
async def _ensure_relations_loaded(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
relations: tuple[str | RelationshipInfo, ...],
|
||||
) -> None:
|
||||
"""
|
||||
确保指定关系已加载,只加载未加载的部分
|
||||
|
||||
基于 SQLAlchemy inspect 检测真实状态,自动处理:
|
||||
- 首次访问的关系
|
||||
- commit 后 expire 的关系
|
||||
- 嵌套关系(如 KlingO1Generator.kling_o1)
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
relations: 需要的关系规格
|
||||
"""
|
||||
# 找出真正未加载的关系(基于 SQLAlchemy inspect)
|
||||
to_load: list[str | RelationshipInfo] = []
|
||||
# 区分直接关系和嵌套关系的 key
|
||||
direct_keys: set[str] = set() # 本类的直接关系属性名
|
||||
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
|
||||
|
||||
for rel in relations:
|
||||
if isinstance(rel, str):
|
||||
# 直接关系:检查本类的关系是否已加载
|
||||
if not self._is_relation_loaded(rel):
|
||||
to_load.append(rel)
|
||||
direct_keys.add(rel)
|
||||
else:
|
||||
# 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
|
||||
# 1. 查找指向父类的关系属性
|
||||
parent_class = rel.parent.class_
|
||||
parent_attr = _find_relation_to_class(self.__class__, parent_class)
|
||||
|
||||
if parent_attr is None:
|
||||
# 找不到路径,可能是配置错误,但仍尝试加载
|
||||
l.warning(
|
||||
f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
|
||||
f"无法检查 {rel.key} 是否已加载"
|
||||
)
|
||||
to_load.append(rel)
|
||||
continue
|
||||
|
||||
# 2. 检查父对象是否已加载
|
||||
if not self._is_relation_loaded(parent_attr):
|
||||
# 父对象未加载,需要同时加载父对象和嵌套关系
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
else:
|
||||
# 3. 父对象已加载,检查嵌套关系是否已加载
|
||||
parent_obj = getattr(self, parent_attr)
|
||||
if not _is_obj_relation_loaded(parent_obj, rel.key):
|
||||
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
|
||||
# 因为 _build_load_chains 需要完整的链来构建 selectinload
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
|
||||
if not to_load:
|
||||
return # 全部已加载,跳过
|
||||
|
||||
# 构建 load 参数
|
||||
load_options = self._specs_to_load_options(to_load)
|
||||
if not load_options:
|
||||
return
|
||||
|
||||
# 安全地获取主键值(避免触发懒加载)
|
||||
state = sa_inspect(self)
|
||||
pk_tuple = state.key[1] if state.key else None
|
||||
if pk_tuple is None:
|
||||
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
|
||||
return
|
||||
# 主键是元组,取第一个值(假设单列主键)
|
||||
pk_value = pk_tuple[0]
|
||||
|
||||
# 单次查询加载缺失的关系
|
||||
fresh = await self.__class__.get(
|
||||
session,
|
||||
self.__class__.id == pk_value,
|
||||
load=load_options,
|
||||
)
|
||||
|
||||
if fresh is None:
|
||||
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
|
||||
return
|
||||
|
||||
# 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
|
||||
all_direct_keys = direct_keys | nested_parent_keys
|
||||
for key in all_direct_keys:
|
||||
value = getattr(fresh, key, None)
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def _specs_to_load_options(
|
||||
self,
|
||||
specs: list[str | RelationshipInfo],
|
||||
) -> list[RelationshipInfo]:
|
||||
"""
|
||||
将关系规格转换为 load 参数
|
||||
|
||||
- 字符串 → cls.{name}
|
||||
- RelationshipInfo → 直接使用
|
||||
"""
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for spec in specs:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(self.__class__, spec, None)
|
||||
if rel is not None:
|
||||
result.append(rel)
|
||||
else:
|
||||
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
# ==================== 可选的手动预加载 API ====================
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取指定方法声明的关系(用于外部预加载场景)
|
||||
|
||||
Args:
|
||||
method_name: 方法名
|
||||
|
||||
Returns:
|
||||
RelationshipInfo 列表
|
||||
"""
|
||||
method = getattr(cls, method_name, None)
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
return []
|
||||
|
||||
result: list[RelationshipInfo] = []
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(cls, spec, None)
|
||||
if rel:
|
||||
result.append(rel)
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取多个方法的关系并去重(用于批量预加载场景)
|
||||
|
||||
Args:
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
去重后的 RelationshipInfo 列表
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
for rel in cls.get_relations_for_method(method_name):
|
||||
key = rel.key
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
result.append(rel)
|
||||
|
||||
return result
|
||||
|
||||
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
|
||||
"""
|
||||
手动预加载指定方法的关系(可选优化 API)
|
||||
|
||||
当需要确保在调用方法前完成所有加载时使用。
|
||||
通常情况下不需要调用此方法,装饰器会自动处理。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Example:
|
||||
# 可选:显式预加载(通常不需要)
|
||||
tool = await tool.preload_for(session, 'cost', '_call')
|
||||
"""
|
||||
all_relations: list[str | RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
method = getattr(self.__class__, method_name, None)
|
||||
if method and hasattr(method, '_required_relations'):
|
||||
all_relations.extend(method._required_relations)
|
||||
|
||||
if all_relations:
|
||||
await self._ensure_relations_loaded(session, tuple(all_relations))
|
||||
|
||||
return self
|
||||
@@ -12,7 +12,14 @@
|
||||
mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin
|
||||
↓
|
||||
base/__init__.py ← 从 mixin 重新导出(保持向后兼容)
|
||||
|
||||
维护须知:
|
||||
增删功能时必须更新 __version__ 字段(遵循语义化版本)
|
||||
|
||||
版本历史:
|
||||
0.1.0 - delete() 方法支持条件删除(condition 参数)
|
||||
"""
|
||||
__version__ = "0.1.0"
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Literal, override, Any, ClassVar, Generic
|
||||
@@ -26,16 +33,19 @@ from typing import TypeVar, Literal, override, Any, ClassVar, Generic
|
||||
# 未来: PR #1275合并后可改回继承SQLModelBase
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct, delete as sql_delete, inspect
|
||||
from sqlalchemy.orm import selectinload, Relationship, with_polymorphic
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
from sqlmodel import Field, select
|
||||
|
||||
from .optimistic_lock import OptimisticLockError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
from .polymorphic import PolymorphicBaseMixin
|
||||
from models.base.sqlmodel_base import SQLModelBase
|
||||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# Type variables for generic type hints, improving code completion and analysis.
|
||||
T = TypeVar("T", bound="TableBaseMixin")
|
||||
@@ -196,8 +206,8 @@ class TableBaseMixin(AsyncAttrs):
|
||||
created_at (datetime): 记录创建时的时间戳, 自动设置.
|
||||
updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
|
||||
"""
|
||||
_is_table_mixin: ClassVar[bool] = True
|
||||
"""标记此类为表混入类的内部属性"""
|
||||
_has_table_mixin: ClassVar[bool] = True
|
||||
"""标记此类继承了表混入类的内部属性"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
@@ -218,7 +228,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]:
|
||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
|
||||
"""
|
||||
向数据库中添加一个新的或多个新的记录.
|
||||
|
||||
@@ -230,8 +240,6 @@ class TableBaseMixin(AsyncAttrs):
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
|
||||
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
|
||||
Returns:
|
||||
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
|
||||
@@ -246,11 +254,6 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# 添加单个实例
|
||||
item3 = Item(name="Cherry")
|
||||
added_item = await Item.add(session, item3)
|
||||
|
||||
# 批量操作,减少提交次数
|
||||
await Item.add(session, [item1, item2], commit=False)
|
||||
await Item.add(session, [item3, item4], commit=False)
|
||||
await session.commit()
|
||||
"""
|
||||
is_list = False
|
||||
if isinstance(instances, list):
|
||||
@@ -259,10 +262,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
else:
|
||||
session.add(instances)
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
|
||||
if refresh:
|
||||
if is_list:
|
||||
@@ -278,14 +278,16 @@ class TableBaseMixin(AsyncAttrs):
|
||||
session: AsyncSession,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
refresh: bool = True,
|
||||
commit: bool = True
|
||||
commit: bool = True,
|
||||
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
optimistic_retry_count: int = 0,
|
||||
) -> T:
|
||||
"""
|
||||
保存(插入或更新)当前模型实例到数据库.
|
||||
|
||||
这是一个实例方法,它将当前对象添加到会话中并提交更改。
|
||||
可以用于创建新记录或更新现有记录。还可以选择在保存后
|
||||
预加载(eager load)一个或多个关联关系.
|
||||
预加载(eager load)一个关联关系.
|
||||
|
||||
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
||||
如果需要继续使用该对象,必须使用返回值:
|
||||
@@ -298,13 +300,17 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.save(session, refresh=False)
|
||||
|
||||
# ✅ 正确:批量操作,减少提交次数
|
||||
await item1.save(session, commit=False)
|
||||
await item2.save(session, commit=False)
|
||||
# ✅ 正确:批量操作时延迟提交
|
||||
for item in items:
|
||||
item = await item.save(session, commit=False)
|
||||
await session.commit()
|
||||
|
||||
# ✅ 正确:批量操作并预加载多个关联关系
|
||||
user = await user.save(session, load=[User.group, User.tags])
|
||||
# ✅ 正确:保存后需要访问多态关系时
|
||||
tool_set = await tool_set.save(session, load=ToolSet.tools, jti_subclasses='all')
|
||||
return tool_set # tools 关系已正确加载子类数据
|
||||
|
||||
# ✅ 正确:启用乐观锁自动重试
|
||||
order = await order.save(session, optimistic_retry_count=3)
|
||||
|
||||
# ❌ 错误:需要返回值但未使用
|
||||
await client.save(session)
|
||||
@@ -313,34 +319,77 @@ class TableBaseMixin(AsyncAttrs):
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
load (Relationship | list[Relationship] | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.posts` 或 `[User.group, User.tags]`.
|
||||
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||
例如 `User.posts`.
|
||||
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
commit (bool): 是否在保存后提交事务。如果为 False,只会 flush 获取 ID
|
||||
但不提交,适用于批量操作场景。默认为 True.
|
||||
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
|
||||
- 'all': 两阶段查询,只加载实际关联的子类
|
||||
- None(默认): 不使用多态加载
|
||||
optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0(不重试)。
|
||||
重试时会重新查询最新数据,将当前修改合并后再次保存。
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
|
||||
Raises:
|
||||
OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽
|
||||
"""
|
||||
session.add(self)
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
cls = type(self)
|
||||
instance = self
|
||||
retries_remaining = optimistic_retry_count
|
||||
current_data: dict[str, Any] | None = None # 延迟计算,仅在需要重试时
|
||||
|
||||
while True:
|
||||
session.add(instance)
|
||||
try:
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
break # 成功,退出循环
|
||||
except StaleDataError as e:
|
||||
await session.rollback()
|
||||
if retries_remaining <= 0:
|
||||
raise OptimisticLockError(
|
||||
message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
|
||||
model_class=cls.__name__,
|
||||
record_id=str(getattr(instance, 'id', None)),
|
||||
expected_version=getattr(instance, 'version', None),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
# 失败后重试:重新查询最新数据并合并修改
|
||||
retries_remaining -= 1
|
||||
if current_data is None:
|
||||
current_data = self.model_dump(exclude={'id', 'version', 'created_at', 'updated_at'})
|
||||
|
||||
fresh = await cls.get(session, cls.id == self.id)
|
||||
if fresh is None:
|
||||
raise OptimisticLockError(
|
||||
message=f"{cls.__name__} 重试失败:记录已被删除",
|
||||
model_class=cls.__name__,
|
||||
record_id=str(getattr(self, 'id', None)),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
for key, value in current_data.items():
|
||||
if hasattr(fresh, key):
|
||||
setattr(fresh, key, value)
|
||||
instance = fresh
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
return instance
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
await session.refresh(self)
|
||||
# 如果指定了 load, 重新获取实例并加载关联关系
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
await session.refresh(instance)
|
||||
return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
async def update(
|
||||
self: T,
|
||||
@@ -351,7 +400,9 @@ class TableBaseMixin(AsyncAttrs):
|
||||
exclude: set[str] | None = None,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
refresh: bool = True,
|
||||
commit: bool = True
|
||||
commit: bool = True,
|
||||
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
optimistic_retry_count: int = 0,
|
||||
) -> T:
|
||||
"""
|
||||
使用另一个模型实例或字典中的数据来更新当前实例.
|
||||
@@ -371,16 +422,20 @@ class TableBaseMixin(AsyncAttrs):
|
||||
user = await user.update(session, update_data, load=User.permission)
|
||||
return user
|
||||
|
||||
# ✅ 正确:更新后需要访问多态关系时
|
||||
tool_set = await tool_set.update(session, data, load=ToolSet.tools, jti_subclasses='all')
|
||||
return tool_set # tools 关系已正确加载子类数据
|
||||
|
||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.update(session, update_data, refresh=False)
|
||||
|
||||
# ✅ 正确:批量操作,减少提交次数
|
||||
await user1.update(session, data1, commit=False)
|
||||
await user2.update(session, data2, commit=False)
|
||||
# ✅ 正确:批量操作时延迟提交
|
||||
for item in items:
|
||||
item = await item.update(session, data, commit=False)
|
||||
await session.commit()
|
||||
|
||||
# ✅ 正确:批量操作并预加载多个关联关系
|
||||
user = await user.update(session, data, load=[User.group, User.tags])
|
||||
# ✅ 正确:启用乐观锁自动重试
|
||||
order = await order.update(session, update_data, optimistic_retry_count=3)
|
||||
|
||||
# ❌ 错误:需要返回值但未使用
|
||||
await client.update(session, update_data)
|
||||
@@ -394,111 +449,134 @@ class TableBaseMixin(AsyncAttrs):
|
||||
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
|
||||
的字段将被忽略. 默认为 True.
|
||||
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
|
||||
load (Relationship | list[Relationship] | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.permission` 或 `[User.group, User.tags]`.
|
||||
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||
例如 `User.permission`.
|
||||
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
commit (bool): 是否在更新后提交事务。如果为 False,只会 flush
|
||||
但不提交,适用于批量操作场景。默认为 True.
|
||||
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
|
||||
- 'all': 两阶段查询,只加载实际关联的子类
|
||||
- None(默认): 不使用多态加载
|
||||
optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0(不重试)。
|
||||
重试时会重新查询最新数据,将 other 的更新重新应用后再次保存。
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
"""
|
||||
self.sqlmodel_update(
|
||||
other.model_dump(exclude_unset=exclude_unset, exclude=exclude),
|
||||
update=extra_data
|
||||
)
|
||||
|
||||
session.add(self)
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
Raises:
|
||||
OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽
|
||||
"""
|
||||
cls = type(self)
|
||||
update_data = other.model_dump(exclude_unset=exclude_unset, exclude=exclude)
|
||||
instance = self
|
||||
retries_remaining = optimistic_retry_count
|
||||
|
||||
while True:
|
||||
instance.sqlmodel_update(update_data, update=extra_data)
|
||||
session.add(instance)
|
||||
|
||||
try:
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
break # 成功,退出循环
|
||||
except StaleDataError as e:
|
||||
await session.rollback()
|
||||
if retries_remaining <= 0:
|
||||
raise OptimisticLockError(
|
||||
message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
|
||||
model_class=cls.__name__,
|
||||
record_id=str(getattr(instance, 'id', None)),
|
||||
expected_version=getattr(instance, 'version', None),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
# 失败后重试:重新查询最新数据并重新应用更新
|
||||
retries_remaining -= 1
|
||||
fresh = await cls.get(session, cls.id == self.id)
|
||||
if fresh is None:
|
||||
raise OptimisticLockError(
|
||||
message=f"{cls.__name__} 重试失败:记录已被删除",
|
||||
model_class=cls.__name__,
|
||||
record_id=str(getattr(self, 'id', None)),
|
||||
original_error=e,
|
||||
) from e
|
||||
instance = fresh
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
return instance
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
await session.refresh(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
await session.refresh(instance)
|
||||
return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
instances: T | list[T] | None = None,
|
||||
*,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
commit: bool = True
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
instances: T | list[T] | None = None,
|
||||
*,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
从数据库中删除记录.
|
||||
|
||||
支持两种删除方式:
|
||||
1. 实例删除:传入 instances 参数,先加载再删除
|
||||
2. 条件删除:传入 condition 参数,直接 SQL 删除(更高效)
|
||||
从数据库中删除记录,支持实例删除和条件删除两种模式。
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表(可选).
|
||||
condition (BinaryExpression | ClauseElement | None): 删除条件(可选,与 instances 二选一).
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
session: 用于数据库操作的异步会话对象
|
||||
instances: 要删除的单个模型实例或模型实例列表(实例删除模式)
|
||||
condition: WHERE 条件表达式(条件删除模式,直接执行 SQL DELETE)
|
||||
commit: 是否在删除后提交事务。默认为 True
|
||||
|
||||
Returns:
|
||||
int: 删除的记录数量
|
||||
删除的记录数(条件删除模式返回实际删除数,实例删除模式返回实例数)
|
||||
|
||||
Raises:
|
||||
ValueError: 同时提供 instances 和 condition,或两者都未提供
|
||||
|
||||
Usage:
|
||||
# 实例删除
|
||||
item_to_delete = await Item.get(session, Item.id == 1)
|
||||
if item_to_delete:
|
||||
deleted_count = await Item.delete(session, item_to_delete)
|
||||
# 实例删除模式
|
||||
item = await Item.get(session, Item.id == 1)
|
||||
if item:
|
||||
await Item.delete(session, item)
|
||||
|
||||
# 条件删除(更高效,无需加载实例)
|
||||
items = await Item.get(session, Item.name.in_(["A", "B"]), fetch_mode="all")
|
||||
if items:
|
||||
await Item.delete(session, items)
|
||||
|
||||
# 条件删除模式(高效批量删除,不加载实例到内存)
|
||||
deleted_count = await Item.delete(
|
||||
session,
|
||||
condition=(Item.status == "inactive") & (Item.created_at < cutoff_date)
|
||||
condition=(Item.user_id == user_id) & (Item.status == "expired"),
|
||||
)
|
||||
|
||||
# 批量删除后手动提交
|
||||
await Item.delete(session, item1, commit=False)
|
||||
await Item.delete(session, item2, commit=False)
|
||||
await session.commit()
|
||||
"""
|
||||
# 条件删除模式
|
||||
if condition is not None:
|
||||
from sqlmodel import delete as sql_delete
|
||||
|
||||
if instances is not None:
|
||||
raise ValueError("不能同时指定 instances 和 condition")
|
||||
|
||||
# 执行条件删除
|
||||
stmt = sql_delete(cls).where(condition)
|
||||
result = await session.exec(stmt)
|
||||
deleted_count = result.rowcount
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
|
||||
return deleted_count
|
||||
|
||||
# 实例删除模式(原有逻辑)
|
||||
if instances is None:
|
||||
raise ValueError("必须指定 instances 或 condition")
|
||||
if instances is not None and condition is not None:
|
||||
raise ValueError("不能同时提供 instances 和 condition 参数")
|
||||
if instances is None and condition is None:
|
||||
raise ValueError("必须提供 instances 或 condition 参数之一")
|
||||
|
||||
deleted_count = 0
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
deleted_count += 1
|
||||
|
||||
if condition is not None:
|
||||
# 条件删除模式:直接执行 SQL DELETE
|
||||
stmt = sql_delete(cls).where(condition)
|
||||
result = await session.execute(stmt)
|
||||
deleted_count = result.rowcount
|
||||
else:
|
||||
await session.delete(instances)
|
||||
deleted_count = 1
|
||||
# 实例删除模式
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
deleted_count = len(instances)
|
||||
else:
|
||||
await session.delete(instances)
|
||||
deleted_count = 1
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
@@ -552,7 +630,8 @@ class TableBaseMixin(AsyncAttrs):
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
with_for_update: bool = False,
|
||||
table_view: TableViewRequest | None = None,
|
||||
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
populate_existing: bool = False,
|
||||
created_before_datetime: datetime | None = None,
|
||||
created_after_datetime: datetime | None = None,
|
||||
updated_before_datetime: datetime | None = None,
|
||||
@@ -581,8 +660,10 @@ class TableBaseMixin(AsyncAttrs):
|
||||
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
|
||||
例如 `[selectinload(User.posts)]`.
|
||||
load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.profile` 或 `[User.group, User.tags]`.
|
||||
可以是单个关系或关系列表。支持嵌套关系预加载:
|
||||
当传入多个关系时,会自动检测依赖关系并构建链式 selectinload。
|
||||
例如 `[NodeGroupNode.element_links, NodeGroupElementLink.node]`
|
||||
会自动构建 `selectinload(element_links).selectinload(node)`。
|
||||
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
|
||||
例如 `[User.name.asc(), User.created_at.desc()]`.
|
||||
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
|
||||
@@ -593,11 +674,16 @@ class TableBaseMixin(AsyncAttrs):
|
||||
会覆盖offset、limit、order_by及时间筛选参数。
|
||||
这是推荐的分页排序方式,统一了所有LIST端点的参数格式。
|
||||
|
||||
load_polymorphic: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
|
||||
- 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益)
|
||||
- None(默认): 不使用多态加载
|
||||
|
||||
populate_existing (bool): 如果为 True,强制用数据库数据覆盖 session 中已存在的对象(identity map)。
|
||||
用于批量刷新对象,避免循环调用 session.refresh() 导致的 N 次查询。
|
||||
注意:只刷新标量字段,不影响运行时属性(_开头的属性)。
|
||||
对于 STI(单表继承)对象,推荐按子类分组查询以包含子类字段。默认为 False。
|
||||
|
||||
created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
|
||||
created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
|
||||
updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
|
||||
@@ -607,7 +693,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: 如果提供了无效的 `fetch_mode` 值,或 load_polymorphic 未与 load 配合使用.
|
||||
ValueError: 如果提供了无效的 `fetch_mode` 值,或 jti_subclasses 未与 load 配合使用.
|
||||
|
||||
Examples:
|
||||
# 使用table_view参数(推荐)
|
||||
@@ -621,13 +707,13 @@ class TableBaseMixin(AsyncAttrs):
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic='all' # 只加载实际关联的子类
|
||||
jti_subclasses='all' # 只加载实际关联的子类
|
||||
)
|
||||
"""
|
||||
# 参数验证:load_polymorphic 需要与 load 配合使用
|
||||
if load_polymorphic is not None and load is None:
|
||||
# 参数验证:jti_subclasses 需要与 load 配合使用
|
||||
if jti_subclasses is not None and load is None:
|
||||
raise ValueError(
|
||||
"load_polymorphic 参数需要与 load 参数配合使用,"
|
||||
"jti_subclasses 参数需要与 load 参数配合使用,"
|
||||
"请同时指定要加载的关系"
|
||||
)
|
||||
|
||||
@@ -656,13 +742,34 @@ class TableBaseMixin(AsyncAttrs):
|
||||
|
||||
# 对于多态基类,使用 with_polymorphic 预加载所有子类的列
|
||||
# 这避免了在响应序列化时的延迟加载问题(MissingGreenlet 错误)
|
||||
if issubclass(cls, PolymorphicBaseMixin):
|
||||
polymorphic_cls = None # 保存多态实体,用于子类关系预加载
|
||||
is_polymorphic = issubclass(cls, PolymorphicBaseMixin)
|
||||
is_jti = is_polymorphic and cls._is_joined_table_inheritance()
|
||||
is_sti = is_polymorphic and not cls._is_joined_table_inheritance()
|
||||
|
||||
# JTI 模式:总是使用 with_polymorphic(避免 N+1 查询)
|
||||
# STI 模式:不使用 with_polymorphic(批量刷新时请按子类分组查询)
|
||||
if is_jti:
|
||||
# '*' 表示加载所有子类
|
||||
polymorphic_cls = with_polymorphic(cls, '*')
|
||||
statement = select(polymorphic_cls)
|
||||
else:
|
||||
statement = select(cls)
|
||||
|
||||
# 对于 STI(单表继承)子类,自动添加多态过滤条件
|
||||
# SQLAlchemy/SQLModel 在 STI 模式下不会自动添加 WHERE discriminator = 'identity' 过滤
|
||||
# 这是已知行为,参考:
|
||||
# - https://github.com/sqlalchemy/sqlalchemy/issues/5018 (bulk operations 不自动添加多态过滤)
|
||||
# - https://github.com/fastapi/sqlmodel/issues/488 (SQLModel STI 支持不完整)
|
||||
# 社区最佳实践是显式添加多态过滤条件
|
||||
if issubclass(cls, PolymorphicBaseMixin) and not cls._is_joined_table_inheritance():
|
||||
mapper = inspect(cls)
|
||||
# 检查是否有 polymorphic_identity 且不是抽象类
|
||||
if mapper.polymorphic_identity is not None and not mapper.polymorphic_abstract:
|
||||
poly_on = mapper.polymorphic_on
|
||||
if poly_on is not None:
|
||||
statement = statement.where(poly_on == mapper.polymorphic_identity)
|
||||
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
@@ -688,12 +795,19 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# 标准化为列表
|
||||
load_list = load if isinstance(load, list) else [load]
|
||||
|
||||
# 处理多态加载
|
||||
if load_polymorphic is not None:
|
||||
# 多态加载只支持单个关系
|
||||
if len(load_list) > 1:
|
||||
raise ValueError("load_polymorphic 仅支持单个关系")
|
||||
target_class = load_list[0].property.mapper.class_
|
||||
# 构建链式 selectinload(支持嵌套关系预加载)
|
||||
# 例如:load=[NodeGroupNode.element_links, NodeGroupElementLink.node]
|
||||
# 会构建:selectinload(element_links).selectinload(node)
|
||||
load_chains = cls._build_load_chains(load_list)
|
||||
|
||||
# 处理多态加载(仅支持单链且只有一个关系)
|
||||
if jti_subclasses is not None:
|
||||
if len(load_chains) > 1 or len(load_chains[0]) > 1:
|
||||
raise ValueError(
|
||||
"jti_subclasses 仅支持单个关系(无嵌套链),请不要传入多个关系"
|
||||
)
|
||||
single_load = load_chains[0][0]
|
||||
target_class = single_load.property.mapper.class_
|
||||
|
||||
# 检查目标类是否继承自 PolymorphicBaseMixin
|
||||
if not issubclass(target_class, PolymorphicBaseMixin):
|
||||
@@ -702,26 +816,48 @@ class TableBaseMixin(AsyncAttrs):
|
||||
f"请确保其继承自 PolymorphicBaseMixin"
|
||||
)
|
||||
|
||||
if load_polymorphic == 'all':
|
||||
if jti_subclasses == 'all':
|
||||
# 两阶段查询:获取实际关联的多态类型
|
||||
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
|
||||
session, condition, load_list[0], target_class
|
||||
session, condition, single_load, target_class
|
||||
)
|
||||
else:
|
||||
subclasses_to_load = load_polymorphic
|
||||
subclasses_to_load = jti_subclasses
|
||||
|
||||
if subclasses_to_load:
|
||||
# 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项
|
||||
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
|
||||
statement = statement.options(
|
||||
selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load)
|
||||
selectinload(single_load).selectin_polymorphic(subclasses_to_load)
|
||||
)
|
||||
else:
|
||||
statement = statement.options(selectinload(load_list[0]))
|
||||
statement = statement.options(selectinload(single_load))
|
||||
else:
|
||||
# 为每个关系添加 selectinload
|
||||
for rel in load_list:
|
||||
statement = statement.options(selectinload(rel))
|
||||
# 为每条链构建链式 selectinload
|
||||
for chain in load_chains:
|
||||
# 获取第一个关系并检查是否需要通过多态实体访问
|
||||
first_rel = chain[0]
|
||||
first_rel_parent = first_rel.property.parent.class_
|
||||
|
||||
# 如果关系的 parent_class 是当前类的子类(不是 cls 本身),
|
||||
# 且当前是多态查询,则需要通过 polymorphic_cls.SubclassName 访问
|
||||
if (
|
||||
polymorphic_cls is not None
|
||||
and first_rel_parent is not cls
|
||||
and issubclass(first_rel_parent, cls)
|
||||
):
|
||||
# 通过多态实体访问子类的关系属性
|
||||
# 例如:polymorphic_cls.NodeGroupNode.element_links
|
||||
subclass_alias = getattr(polymorphic_cls, first_rel_parent.__name__)
|
||||
rel_name = first_rel.key
|
||||
first_rel_via_poly = getattr(subclass_alias, rel_name)
|
||||
loader = selectinload(first_rel_via_poly)
|
||||
else:
|
||||
loader = selectinload(first_rel)
|
||||
|
||||
for rel in chain[1:]:
|
||||
loader = loader.selectinload(rel)
|
||||
statement = statement.options(loader)
|
||||
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
@@ -736,7 +872,17 @@ class TableBaseMixin(AsyncAttrs):
|
||||
statement = statement.filter(filter)
|
||||
|
||||
if with_for_update:
|
||||
statement = statement.with_for_update()
|
||||
# 对于联表继承的多态模型,使用 FOR UPDATE OF <主表> 来避免 PostgreSQL 的限制
|
||||
# PostgreSQL 不支持在 LEFT OUTER JOIN 的可空侧使用 FOR UPDATE
|
||||
if issubclass(cls, PolymorphicBaseMixin):
|
||||
statement = statement.with_for_update(of=cls)
|
||||
else:
|
||||
statement = statement.with_for_update()
|
||||
|
||||
if populate_existing:
|
||||
# 强制用数据库数据覆盖 identity map 中的对象
|
||||
# 用于批量刷新,避免循环 refresh() 的 N 次查询
|
||||
statement = statement.execution_options(populate_existing=True)
|
||||
|
||||
result = await session.exec(statement)
|
||||
|
||||
@@ -749,6 +895,79 @@ class TableBaseMixin(AsyncAttrs):
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@staticmethod
|
||||
def _build_load_chains(load_list: list[RelationshipInfo]) -> list[list[RelationshipInfo]]:
|
||||
"""
|
||||
将关系列表构建为链式加载结构
|
||||
|
||||
自动检测关系之间的依赖关系,构建嵌套预加载链。
|
||||
例如:[NodeGroupNode.element_links, NodeGroupElementLink.node]
|
||||
会构建:[[element_links, node]](一条链)
|
||||
|
||||
算法:
|
||||
1. 获取每个关系的 parent class 和 target class
|
||||
2. 如果关系 B 的 parent class 等于关系 A 的 target class,则 B 链在 A 后面
|
||||
3. 独立的关系各自成为一条链
|
||||
|
||||
Args:
|
||||
load_list: 关系属性列表
|
||||
|
||||
Returns:
|
||||
链式关系列表,每条链是一个关系列表
|
||||
"""
|
||||
if not load_list:
|
||||
return []
|
||||
|
||||
# 构建关系信息:{关系: (parent_class, target_class)}
|
||||
rel_info: dict[RelationshipInfo, tuple[type, type]] = {}
|
||||
for rel in load_list:
|
||||
parent_class = rel.property.parent.class_
|
||||
target_class = rel.property.mapper.class_
|
||||
rel_info[rel] = (parent_class, target_class)
|
||||
|
||||
# 构建依赖图:{关系: 其前置关系}
|
||||
predecessors: dict[RelationshipInfo, RelationshipInfo | None] = {rel: None for rel in load_list}
|
||||
for rel_b in load_list:
|
||||
parent_b, _ = rel_info[rel_b]
|
||||
for rel_a in load_list:
|
||||
if rel_a is rel_b:
|
||||
continue
|
||||
_, target_a = rel_info[rel_a]
|
||||
# 如果 B 的 parent 精确等于 A 的 target,则 B 链在 A 后面
|
||||
# 使用精确匹配避免继承关系导致的误判(如 NodeGroupNode 是 CanvasNode 子类)
|
||||
if parent_b is target_a:
|
||||
predecessors[rel_b] = rel_a
|
||||
break
|
||||
|
||||
# 找出所有链的起点(没有前置关系的)
|
||||
roots = [rel for rel, pred in predecessors.items() if pred is None]
|
||||
|
||||
# 构建链
|
||||
chains: list[list[RelationshipInfo]] = []
|
||||
used: set[RelationshipInfo] = set()
|
||||
|
||||
for root in roots:
|
||||
chain = [root]
|
||||
used.add(root)
|
||||
# 找后续节点
|
||||
current = root
|
||||
while True:
|
||||
# 找以 current 的 target 为 parent 的关系
|
||||
_, current_target = rel_info[current]
|
||||
next_rel = None
|
||||
for rel, (parent, _) in rel_info.items():
|
||||
if rel not in used and parent is current_target:
|
||||
next_rel = rel
|
||||
break
|
||||
if next_rel is None:
|
||||
break
|
||||
chain.append(next_rel)
|
||||
used.add(next_rel)
|
||||
current = next_rel
|
||||
chains.append(chain)
|
||||
|
||||
return chains
|
||||
|
||||
@classmethod
|
||||
async def _resolve_polymorphic_subclasses(
|
||||
cls: type[T],
|
||||
@@ -791,12 +1010,15 @@ class TableBaseMixin(AsyncAttrs):
|
||||
))
|
||||
)
|
||||
else:
|
||||
# 一对多关系:通过外键查询
|
||||
foreign_key_col = relationship_property.local_remote_pairs[0][1]
|
||||
# 多对一/一对多关系:通过外键查询
|
||||
# local_remote_pairs[0] = (local_fk_col, remote_pk_col)
|
||||
# 对于多对一:local 是当前类的外键,remote 是目标类的主键
|
||||
local_fk_col = relationship_property.local_remote_pairs[0][0]
|
||||
remote_pk_col = relationship_property.local_remote_pairs[0][1]
|
||||
type_query = (
|
||||
select(distinct(poly_name_col))
|
||||
.where(foreign_key_col.in_(
|
||||
select(cls.id).where(condition) if condition is not None else select(cls.id)
|
||||
.where(remote_pk_col.in_(
|
||||
select(local_fk_col).where(condition) if condition is not None else select(local_fk_col)
|
||||
))
|
||||
)
|
||||
|
||||
@@ -898,7 +1120,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
table_view: TableViewRequest | None = None,
|
||||
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
) -> 'ListResponse[T]':
|
||||
"""
|
||||
获取分页列表及总数,直接返回 ListResponse
|
||||
@@ -918,7 +1140,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
order_by: 排序子句
|
||||
filter: 附加过滤条件
|
||||
table_view: 分页排序参数(推荐使用)
|
||||
load_polymorphic: 多态子类加载选项
|
||||
jti_subclasses: 多态子类加载选项
|
||||
|
||||
Returns:
|
||||
ListResponse[T]: 包含 count 和 items 的分页响应
|
||||
@@ -957,7 +1179,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
order_by=order_by,
|
||||
filter=filter,
|
||||
table_view=table_view,
|
||||
load_polymorphic=load_polymorphic,
|
||||
jti_subclasses=jti_subclasses,
|
||||
)
|
||||
|
||||
return ListResponse(count=total_count, items=items)
|
||||
@@ -973,8 +1195,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (int): 要查找的记录的主键 ID.
|
||||
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
@@ -1002,7 +1223,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T:
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
|
||||
"""
|
||||
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||
|
||||
@@ -1012,8 +1233,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (uuid.UUID): 要查找的记录的 UUID 主键.
|
||||
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
@@ -119,4 +119,5 @@ class MCPResponseBase(MCPBase):
|
||||
"""MCP 响应模型基础类"""
|
||||
|
||||
result: str
|
||||
"""方法返回结果"""
|
||||
"""方法返回结果"""
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from enum import StrEnum
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
@@ -15,6 +16,7 @@ if TYPE_CHECKING:
|
||||
from .source_link import SourceLink
|
||||
from .share import Share
|
||||
from .physical_file import PhysicalFile
|
||||
from .uri import DiskNextURI
|
||||
|
||||
|
||||
class ObjectType(StrEnum):
|
||||
@@ -103,7 +105,7 @@ class ObjectMoveRequest(SQLModelBase):
|
||||
class ObjectDeleteRequest(SQLModelBase):
|
||||
"""删除对象请求 DTO"""
|
||||
|
||||
ids: UUID | list[UUID]
|
||||
ids: list[UUID]
|
||||
"""待删除对象UUID列表"""
|
||||
|
||||
|
||||
@@ -116,12 +118,12 @@ class ObjectResponse(ObjectBase):
|
||||
thumb: bool = False
|
||||
"""是否有缩略图"""
|
||||
|
||||
date: datetime
|
||||
"""对象修改时间"""
|
||||
|
||||
create_date: datetime
|
||||
created_at: datetime
|
||||
"""对象创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""对象修改时间"""
|
||||
|
||||
source_enabled: bool = False
|
||||
"""是否启用离线下载源"""
|
||||
|
||||
@@ -138,7 +140,7 @@ class PolicyResponse(SQLModelBase):
|
||||
type: StorageType
|
||||
"""存储类型"""
|
||||
|
||||
max_size: int = Field(ge=0, default=0)
|
||||
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
|
||||
"""单文件最大限制,单位字节,0表示不限制"""
|
||||
|
||||
file_type: list[str] | None = None
|
||||
@@ -186,18 +188,18 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
合并了原有的 File 和 Folder 模型,通过 type 字段区分文件和目录。
|
||||
|
||||
根目录规则:
|
||||
- 每个用户有一个显式根目录对象(name=用户的username, parent_id=NULL)
|
||||
- 每个用户有一个显式根目录对象(name="/", parent_id=NULL)
|
||||
- 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id
|
||||
- 根目录的 policy_id 指定用户默认存储策略
|
||||
- 路径格式:/username/path/to/file(如 /admin/docs/readme.md)
|
||||
- 路径格式:/path/to/file(如 /docs/readme.md),不包含用户名前缀
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
# 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况)
|
||||
UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"),
|
||||
# 名称不能包含斜杠 ([TODO] 还有特殊字符)
|
||||
# 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/")
|
||||
CheckConstraint(
|
||||
"name NOT LIKE '%/%' AND name NOT LIKE '%\\%'",
|
||||
"parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')",
|
||||
name="ck_object_name_no_slash",
|
||||
),
|
||||
# 性能索引
|
||||
@@ -220,7 +222,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
|
||||
# ==================== 文件专属字段 ====================
|
||||
|
||||
size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||
"""文件大小(字节),目录为 0"""
|
||||
|
||||
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
|
||||
@@ -374,15 +376,16 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
session,
|
||||
user_id: UUID,
|
||||
path: str,
|
||||
username: str,
|
||||
) -> "Object | None":
|
||||
"""
|
||||
根据路径获取对象
|
||||
|
||||
路径从用户根目录开始,不包含用户名前缀。
|
||||
如 "/" 表示根目录,"/docs/images" 表示根目录下的 docs/images。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:param path: 路径,如 "/username" 或 "/username/docs/images"
|
||||
:param username: 用户名,用于识别根目录
|
||||
:param path: 路径,如 "/" 或 "/docs/images"
|
||||
:return: Object 或 None
|
||||
"""
|
||||
path = path.strip()
|
||||
@@ -403,16 +406,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
if not parts:
|
||||
return root
|
||||
|
||||
# 检查第一部分是否是用户名(根目录名)
|
||||
if parts[0] == username:
|
||||
# 路径以用户名开头,如 /admin/docs
|
||||
if len(parts) == 1:
|
||||
# 只有用户名,返回根目录
|
||||
return root
|
||||
# 去掉用户名部分,从第二个部分开始遍历
|
||||
parts = parts[1:]
|
||||
|
||||
# 从根目录开始遍历剩余路径
|
||||
# 从根目录开始遍历路径
|
||||
current = root
|
||||
for part in parts:
|
||||
if not current:
|
||||
@@ -443,6 +437,77 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
fetch_mode="all"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def resolve_uri(
|
||||
cls,
|
||||
session,
|
||||
uri: "DiskNextURI",
|
||||
requesting_user_id: UUID | None = None,
|
||||
) -> "Object":
|
||||
"""
|
||||
将 URI 解析为 Object 实例
|
||||
|
||||
分派逻辑(类似 Cloudreve 的 getNavigator):
|
||||
- MY → user_id = uri.id(str(requesting_user_id))
|
||||
验证权限(自己的或管理员),然后 get_by_path
|
||||
- SHARE → 通过 uri.fs_id 查 Share 表,验证密码和有效期
|
||||
获取 share.object,然后沿 uri.path 遍历子对象
|
||||
- TRASH → 延后实现
|
||||
|
||||
:param session: 数据库会话
|
||||
:param uri: DiskNextURI 实例
|
||||
:param requesting_user_id: 请求用户UUID
|
||||
:return: Object 实例
|
||||
:raises ValueError: URI 无法解析
|
||||
:raises PermissionError: 权限不足
|
||||
:raises NotImplementedError: 不支持的命名空间
|
||||
"""
|
||||
from .uri import FileSystemNamespace
|
||||
|
||||
if uri.namespace == FileSystemNamespace.MY:
|
||||
# 确定目标用户
|
||||
target_user_id_str = uri.id(str(requesting_user_id) if requesting_user_id else None)
|
||||
if not target_user_id_str:
|
||||
raise ValueError("MY 命名空间需要提供 fs_id 或 requesting_user_id")
|
||||
|
||||
target_user_id = UUID(target_user_id_str)
|
||||
|
||||
# 权限检查:只能访问自己的空间(管理员权限由路由层判断)
|
||||
if requesting_user_id and target_user_id != requesting_user_id:
|
||||
raise PermissionError("无权访问其他用户的文件空间")
|
||||
|
||||
obj = await cls.get_by_path(session, target_user_id, uri.path)
|
||||
if not obj:
|
||||
raise ValueError(f"路径不存在: {uri.path}")
|
||||
return obj
|
||||
|
||||
elif uri.namespace == FileSystemNamespace.SHARE:
|
||||
raise NotImplementedError("分享空间解析尚未实现")
|
||||
|
||||
elif uri.namespace == FileSystemNamespace.TRASH:
|
||||
raise NotImplementedError("回收站解析尚未实现")
|
||||
|
||||
else:
|
||||
raise ValueError(f"未知的命名空间: {uri.namespace}")
|
||||
|
||||
async def get_full_path(self, session) -> str:
|
||||
"""
|
||||
从当前对象沿 parent_id 向上遍历到根目录,返回完整路径
|
||||
|
||||
:param session: 数据库会话
|
||||
:return: 完整路径,如 "/docs/images/photo.jpg"
|
||||
"""
|
||||
parts: list[str] = []
|
||||
current: Object | None = self
|
||||
|
||||
while current and current.parent_id is not None:
|
||||
parts.append(current.name)
|
||||
current = await Object.get(session, Object.id == current.parent_id)
|
||||
|
||||
# 反转顺序(从根到当前)
|
||||
parts.reverse()
|
||||
return "/" + "/".join(parts)
|
||||
|
||||
|
||||
# ==================== 上传会话模型 ====================
|
||||
|
||||
@@ -452,10 +517,10 @@ class UploadSessionBase(SQLModelBase):
|
||||
file_name: str = Field(max_length=255)
|
||||
"""原始文件名"""
|
||||
|
||||
file_size: int = Field(ge=0)
|
||||
file_size: int = Field(ge=0, sa_type=BigInteger)
|
||||
"""文件总大小(字节)"""
|
||||
|
||||
chunk_size: int = Field(ge=1)
|
||||
chunk_size: int = Field(ge=1, sa_type=BigInteger)
|
||||
"""分片大小(字节)"""
|
||||
|
||||
total_chunks: int = Field(ge=1)
|
||||
@@ -474,7 +539,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
||||
uploaded_chunks: int = 0
|
||||
"""已上传分片数"""
|
||||
|
||||
uploaded_size: int = 0
|
||||
uploaded_size: int = Field(default=0, sa_type=BigInteger)
|
||||
"""已上传大小(字节)"""
|
||||
|
||||
storage_path: str | None = Field(default=None, max_length=512)
|
||||
@@ -680,8 +745,8 @@ class AdminFileResponse(ObjectResponse):
|
||||
owner_id: UUID
|
||||
"""所有者UUID"""
|
||||
|
||||
owner_username: str
|
||||
"""所有者用户名"""
|
||||
owner_email: str
|
||||
"""所有者邮箱"""
|
||||
|
||||
policy_name: str
|
||||
"""存储策略名称"""
|
||||
@@ -709,12 +774,12 @@ class AdminFileResponse(ObjectResponse):
|
||||
# ObjectResponse 字段
|
||||
id=obj.id,
|
||||
thumb=False,
|
||||
date=obj.updated_at,
|
||||
create_date=obj.created_at,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
source_enabled=False,
|
||||
# AdminFileResponse 字段
|
||||
owner_id=obj.owner_id,
|
||||
owner_username=owner.username if owner else "unknown",
|
||||
owner_email=owner.email if owner else "unknown",
|
||||
policy_name=policy.name if policy else "unknown",
|
||||
is_banned=obj.is_banned,
|
||||
banned_at=obj.banned_at,
|
||||
@@ -725,7 +790,7 @@ class AdminFileResponse(ObjectResponse):
|
||||
class FileBanRequest(SQLModelBase):
|
||||
"""文件封禁请求 DTO"""
|
||||
|
||||
is_banned: bool = True
|
||||
ban: bool = True
|
||||
"""是否封禁"""
|
||||
|
||||
reason: str | None = Field(default=None, max_length=500)
|
||||
@@ -12,6 +12,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
@@ -28,7 +29,7 @@ class PhysicalFileBase(SQLModelBase):
|
||||
storage_path: str = Field(max_length=512)
|
||||
"""物理存储路径(相对于存储策略根目录)"""
|
||||
|
||||
size: int = 0
|
||||
size: int = Field(default=0, sa_type=BigInteger)
|
||||
"""文件大小(字节)"""
|
||||
|
||||
checksum_md5: str | None = Field(default=None, max_length=32)
|
||||
@@ -20,16 +20,10 @@ class SiteConfigResponse(SQLModelBase):
|
||||
title: str = "DiskNext"
|
||||
"""网站标题"""
|
||||
|
||||
# themes: dict[str, str] = {}
|
||||
# """网站主题配置"""
|
||||
|
||||
# default_theme: dict[str, str] = {}
|
||||
# """默认主题RGB色号"""
|
||||
|
||||
site_notice: str | None = None
|
||||
"""网站公告"""
|
||||
|
||||
user: UserResponse
|
||||
user: UserResponse | None = None
|
||||
"""用户信息"""
|
||||
|
||||
logo_light: str | None = None
|
||||
@@ -38,11 +32,23 @@ class SiteConfigResponse(SQLModelBase):
|
||||
logo_dark: str | None = None
|
||||
"""网站Logo URL(深色模式)"""
|
||||
|
||||
captcha_type: CaptchaType | None = None
|
||||
register_enabled: bool = True
|
||||
"""是否允许注册"""
|
||||
|
||||
login_captcha: bool = False
|
||||
"""登录是否需要验证码"""
|
||||
|
||||
reg_captcha: bool = False
|
||||
"""注册是否需要验证码"""
|
||||
|
||||
forget_captcha: bool = False
|
||||
"""找回密码是否需要验证码"""
|
||||
|
||||
captcha_type: CaptchaType = CaptchaType.DEFAULT
|
||||
"""验证码类型"""
|
||||
|
||||
captcha_key: str | None = None
|
||||
"""验证码密钥"""
|
||||
"""验证码 public key(DEFAULT 类型时为 None)"""
|
||||
|
||||
|
||||
# ==================== 管理员设置 DTO ====================
|
||||
@@ -215,6 +215,6 @@ class AdminShareListItem(ShareListItemBase):
|
||||
"""从 Share ORM 对象构建"""
|
||||
return cls(
|
||||
**ShareListItemBase.model_validate(share, from_attributes=True).model_dump(),
|
||||
username=user.username if user else None,
|
||||
username=user.email if user else None,
|
||||
object_name=obj.name if obj else None,
|
||||
)
|
||||
@@ -73,7 +73,7 @@ class TaskSummary(TaskSummaryBase):
|
||||
"""从 Task ORM 对象构建"""
|
||||
return cls(
|
||||
**TaskSummaryBase.model_validate(task, from_attributes=True).model_dump(),
|
||||
username=user.username if user else None,
|
||||
username=user.email if user else None,
|
||||
)
|
||||
|
||||
|
||||
258
sqlmodels/uri.py
Normal file
258
sqlmodels/uri.py
Normal file
@@ -0,0 +1,258 @@
|
||||
|
||||
from enum import StrEnum
|
||||
from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote
|
||||
|
||||
from .base import SQLModelBase
|
||||
|
||||
|
||||
class FileSystemNamespace(StrEnum):
|
||||
"""文件系统命名空间"""
|
||||
MY = "my"
|
||||
"""用户个人空间"""
|
||||
|
||||
SHARE = "share"
|
||||
"""分享空间"""
|
||||
|
||||
TRASH = "trash"
|
||||
"""回收站"""
|
||||
|
||||
|
||||
class DiskNextURI(SQLModelBase):
|
||||
"""
|
||||
DiskNext 文件 URI
|
||||
|
||||
URI 格式: disknext://[fs_id[:password]@]namespace[/path][?query]
|
||||
|
||||
fs_id 可省略:
|
||||
- my/trash 命名空间省略时默认当前用户
|
||||
- share 命名空间必须提供 fs_id(Share.code)
|
||||
"""
|
||||
|
||||
fs_id: str | None = None
|
||||
"""文件系统标识符,可省略"""
|
||||
|
||||
namespace: FileSystemNamespace
|
||||
"""命名空间"""
|
||||
|
||||
path: str = "/"
|
||||
"""路径"""
|
||||
|
||||
password: str | None = None
|
||||
"""访问密码(用于有密码的分享)"""
|
||||
|
||||
query: dict[str, str] | None = None
|
||||
"""查询参数"""
|
||||
|
||||
# === 属性 ===
|
||||
|
||||
@property
|
||||
def path_parts(self) -> list[str]:
|
||||
"""路径分割为列表(过滤空串)"""
|
||||
return [p for p in self.path.split("/") if p]
|
||||
|
||||
@property
|
||||
def is_root(self) -> bool:
|
||||
"""是否指向根目录"""
|
||||
return self.path.strip("/") == ""
|
||||
|
||||
# === 核心方法 ===
|
||||
|
||||
def id(self, default_id: str | None = None) -> str | None:
|
||||
"""
|
||||
获取 fs_id,省略时返回 default_id
|
||||
|
||||
参考 Cloudreve URI.ID(defaultUid) 方法
|
||||
|
||||
:param default_id: 默认值(通常为当前用户 ID)
|
||||
:return: fs_id 或 default_id
|
||||
"""
|
||||
return self.fs_id if self.fs_id else default_id
|
||||
|
||||
# === 类方法 ===
|
||||
|
||||
@classmethod
|
||||
def parse(cls, uri: str) -> "DiskNextURI":
|
||||
"""
|
||||
解析 URI 字符串
|
||||
|
||||
实现方式:替换 disknext:// 为 http:// 后用 urllib.parse.urlparse 解析
|
||||
- hostname → namespace
|
||||
- username → fs_id
|
||||
- password → password
|
||||
- path → path
|
||||
- query → query dict
|
||||
|
||||
:param uri: URI 字符串,如 "disknext://my/docs/readme.md"
|
||||
:return: DiskNextURI 实例
|
||||
:raises ValueError: URI 格式无效
|
||||
"""
|
||||
if not uri.startswith("disknext://"):
|
||||
raise ValueError(f"URI 必须以 disknext:// 开头: {uri}")
|
||||
|
||||
# 替换协议为 http:// 以利用 urllib.parse 解析
|
||||
http_uri = "http://" + uri[len("disknext://"):]
|
||||
parsed = urlparse(http_uri)
|
||||
|
||||
# 解析 namespace
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise ValueError(f"URI 缺少命名空间: {uri}")
|
||||
|
||||
try:
|
||||
namespace = FileSystemNamespace(hostname)
|
||||
except ValueError:
|
||||
raise ValueError(f"无效的命名空间 '{hostname}',有效值: {[e.value for e in FileSystemNamespace]}")
|
||||
|
||||
# 解析 fs_id 和 password
|
||||
fs_id = unquote(parsed.username) if parsed.username else None
|
||||
password = unquote(parsed.password) if parsed.password else None
|
||||
|
||||
# 解析 path
|
||||
path = unquote(parsed.path) if parsed.path else "/"
|
||||
if not path:
|
||||
path = "/"
|
||||
|
||||
# 解析 query
|
||||
query: dict[str, str] | None = None
|
||||
if parsed.query:
|
||||
raw_query = parse_qs(parsed.query, keep_blank_values=True)
|
||||
query = {k: v[0] for k, v in raw_query.items()}
|
||||
|
||||
return cls(
|
||||
fs_id=fs_id,
|
||||
namespace=namespace,
|
||||
path=path,
|
||||
password=password,
|
||||
query=query,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
namespace: FileSystemNamespace,
|
||||
path: str = "/",
|
||||
fs_id: str | None = None,
|
||||
password: str | None = None,
|
||||
) -> "DiskNextURI":
|
||||
"""
|
||||
构建 URI 实例
|
||||
|
||||
:param namespace: 命名空间
|
||||
:param path: 路径
|
||||
:param fs_id: 文件系统标识符
|
||||
:param password: 访问密码
|
||||
:return: DiskNextURI 实例
|
||||
"""
|
||||
# 确保 path 以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
return cls(
|
||||
fs_id=fs_id,
|
||||
namespace=namespace,
|
||||
path=path,
|
||||
password=password,
|
||||
)
|
||||
|
||||
# === 实例方法 ===
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""
|
||||
序列化为 URI 字符串
|
||||
|
||||
:return: URI 字符串,如 "disknext://my/docs/readme.md"
|
||||
"""
|
||||
result = "disknext://"
|
||||
|
||||
# fs_id 和 password
|
||||
if self.fs_id:
|
||||
result += quote(self.fs_id, safe="")
|
||||
if self.password:
|
||||
result += ":" + quote(self.password, safe="")
|
||||
result += "@"
|
||||
|
||||
# namespace
|
||||
result += self.namespace.value
|
||||
|
||||
# path
|
||||
result += self.path
|
||||
|
||||
# query
|
||||
if self.query:
|
||||
result += "?" + urlencode(self.query)
|
||||
|
||||
return result
|
||||
|
||||
def join(self, *elements: str) -> "DiskNextURI":
|
||||
"""
|
||||
拼接路径元素,返回新 URI
|
||||
|
||||
:param elements: 路径元素
|
||||
:return: 新的 DiskNextURI 实例
|
||||
"""
|
||||
base = self.path.rstrip("/")
|
||||
for element in elements:
|
||||
element = element.strip("/")
|
||||
if element:
|
||||
base += "/" + element
|
||||
|
||||
if not base:
|
||||
base = "/"
|
||||
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path=base,
|
||||
password=self.password,
|
||||
query=self.query,
|
||||
)
|
||||
|
||||
def dir_uri(self) -> "DiskNextURI":
|
||||
"""
|
||||
返回父目录的 URI
|
||||
|
||||
:return: 父目录的 DiskNextURI 实例
|
||||
"""
|
||||
parts = self.path_parts
|
||||
if not parts:
|
||||
# 已经是根目录
|
||||
return self.root()
|
||||
|
||||
parent_path = "/" + "/".join(parts[:-1])
|
||||
if not parent_path.endswith("/"):
|
||||
parent_path += "/"
|
||||
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path=parent_path,
|
||||
password=self.password,
|
||||
)
|
||||
|
||||
def root(self) -> "DiskNextURI":
|
||||
"""
|
||||
返回根目录的 URI(保留 namespace 和 fs_id)
|
||||
|
||||
:return: 根目录的 DiskNextURI 实例
|
||||
"""
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path="/",
|
||||
password=self.password,
|
||||
)
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
返回路径的最后一段(文件名或目录名)
|
||||
|
||||
:return: 文件名或目录名,根目录返回空字符串
|
||||
"""
|
||||
parts = self.path_parts
|
||||
return parts[-1] if parts else ""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_string()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DiskNextURI({self.to_string()!r})"
|
||||
@@ -60,8 +60,8 @@ class UserFilterParams(SQLModelBase):
|
||||
group_id: UUID | None = None
|
||||
"""按用户组UUID筛选"""
|
||||
|
||||
username_contains: str | None = Field(default=None, max_length=50)
|
||||
"""用户名包含(不区分大小写的模糊搜索)"""
|
||||
email_contains: str | None = Field(default=None, max_length=50)
|
||||
"""邮箱包含(不区分大小写的模糊搜索)"""
|
||||
|
||||
nickname_contains: str | None = Field(default=None, max_length=50)
|
||||
"""昵称包含(不区分大小写的模糊搜索)"""
|
||||
@@ -75,8 +75,8 @@ class UserFilterParams(SQLModelBase):
|
||||
class UserBase(SQLModelBase):
|
||||
"""用户基础字段,供数据库模型和 DTO 共享"""
|
||||
|
||||
username: str
|
||||
"""用户名"""
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
@@ -90,8 +90,8 @@ class UserBase(SQLModelBase):
|
||||
class LoginRequest(SQLModelBase):
|
||||
"""登录请求 DTO"""
|
||||
|
||||
username: str
|
||||
"""用户名或邮箱"""
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
@@ -106,8 +106,8 @@ class LoginRequest(SQLModelBase):
|
||||
class RegisterRequest(SQLModelBase):
|
||||
"""注册请求 DTO"""
|
||||
|
||||
username: str
|
||||
"""用户名,唯一,一经注册不可更改"""
|
||||
email: str
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
@@ -116,6 +116,20 @@ class RegisterRequest(SQLModelBase):
|
||||
"""验证码"""
|
||||
|
||||
|
||||
class BatchDeleteRequest(SQLModelBase):
|
||||
"""批量删除请求 DTO"""
|
||||
|
||||
ids: list[UUID]
|
||||
"""待删除 UUID 列表"""
|
||||
|
||||
|
||||
class RefreshTokenRequest(SQLModelBase):
|
||||
"""刷新令牌请求 DTO"""
|
||||
|
||||
refresh_token: str
|
||||
"""刷新令牌"""
|
||||
|
||||
|
||||
class WebAuthnInfo(SQLModelBase):
|
||||
"""WebAuthn 信息 DTO"""
|
||||
|
||||
@@ -166,6 +180,9 @@ class UserResponse(ResponseBase):
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""用户昵称"""
|
||||
|
||||
@@ -184,11 +201,23 @@ class UserResponse(ResponseBase):
|
||||
tags: list[str] = []
|
||||
"""用户标签列表"""
|
||||
|
||||
class UserStorageResponse(SQLModelBase):
|
||||
"""用户存储信息 DTO"""
|
||||
|
||||
used: int
|
||||
"""已用存储空间(字节)"""
|
||||
|
||||
free: int
|
||||
"""剩余存储空间(字节)"""
|
||||
|
||||
total: int
|
||||
"""总存储空间(字节)"""
|
||||
|
||||
|
||||
class UserPublic(UserBase):
|
||||
"""用户公开信息 DTO,用于 API 响应"""
|
||||
|
||||
id: UUID | None = None
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
nickname: str | None = None
|
||||
@@ -206,6 +235,9 @@ class UserPublic(UserBase):
|
||||
group_id: UUID | None = None
|
||||
"""所属用户组UUID"""
|
||||
|
||||
group_name: str | None = None
|
||||
"""用户组名称"""
|
||||
|
||||
two_factor: str | None = None
|
||||
"""两步验证密钥(32位字符串,null 表示未启用)"""
|
||||
|
||||
@@ -219,29 +251,63 @@ class UserPublic(UserBase):
|
||||
class UserSettingResponse(SQLModelBase):
|
||||
"""用户设置响应 DTO"""
|
||||
|
||||
authn: "AuthnResponse | None" = None
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""昵称"""
|
||||
|
||||
created_at: datetime
|
||||
"""用户注册时间"""
|
||||
|
||||
group_name: str
|
||||
"""用户所属用户组名称"""
|
||||
|
||||
language: str
|
||||
"""语言偏好"""
|
||||
|
||||
timezone: int
|
||||
"""时区"""
|
||||
|
||||
authn: "list[AuthnResponse] | None" = None
|
||||
"""认证信息"""
|
||||
|
||||
group_expires: datetime | None = None
|
||||
"""用户组过期时间"""
|
||||
|
||||
prefer_theme: str = "#5898d4"
|
||||
"""用户首选主题"""
|
||||
|
||||
themes: dict[str, str] = {}
|
||||
"""用户主题配置"""
|
||||
|
||||
two_factor: bool = False
|
||||
"""是否启用两步验证"""
|
||||
|
||||
uid: UUID | None = None
|
||||
"""用户UUID"""
|
||||
|
||||
|
||||
# ==================== 管理员用户管理 DTO ====================
|
||||
|
||||
class UserAdminCreateRequest(SQLModelBase):
|
||||
"""管理员创建用户请求 DTO"""
|
||||
|
||||
email: str = Field(max_length=50)
|
||||
"""用户邮箱"""
|
||||
|
||||
password: str
|
||||
"""用户密码(明文,由服务端加密)"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
|
||||
group_id: UUID
|
||||
"""所属用户组UUID"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
|
||||
|
||||
class UserAdminUpdateRequest(SQLModelBase):
|
||||
"""管理员更新用户请求 DTO"""
|
||||
|
||||
email: str = Field(max_length=50)
|
||||
"""邮箱"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
@@ -317,8 +383,8 @@ UserSettingResponse.model_rebuild()
|
||||
class User(UserBase, UUIDTableBaseMixin):
|
||||
"""用户模型"""
|
||||
|
||||
username: str = Field(max_length=50, unique=True, index=True)
|
||||
"""用户名,唯一,一经注册不可更改"""
|
||||
email: str = Field(max_length=50, unique=True, index=True)
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""用于公开展示的名字,可使用真实姓名或昵称"""
|
||||
@@ -426,8 +492,10 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
)
|
||||
|
||||
def to_public(self) -> "UserPublic":
|
||||
"""转换为公开 DTO,排除敏感字段"""
|
||||
return UserPublic.model_validate(self)
|
||||
"""转换为公开 DTO,排除敏感字段。需要预加载 group 关系。"""
|
||||
data = UserPublic.model_validate(self)
|
||||
data.group_name = self.group.name
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def get_with_count(
|
||||
@@ -457,8 +525,8 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
if filter_params.group_id is not None:
|
||||
filter_conditions.append(cls.group_id == filter_params.group_id)
|
||||
|
||||
if filter_params.username_contains is not None:
|
||||
filter_conditions.append(cls.username.ilike(f"%{filter_params.username_contains}%"))
|
||||
if filter_params.email_contains is not None:
|
||||
filter_conditions.append(cls.email.ilike(f"%{filter_params.email_contains}%"))
|
||||
|
||||
if filter_params.nickname_contains is not None:
|
||||
filter_conditions.append(cls.nickname.ilike(f"%{filter_params.nickname_contains}%"))
|
||||
@@ -482,4 +550,5 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
order_by=order_by,
|
||||
filter=filter,
|
||||
table_view=table_view,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -49,13 +49,13 @@ def main():
|
||||
("itsdangerous", "签名工具"),
|
||||
|
||||
# 项目模块
|
||||
("models", "数据库模型"),
|
||||
("models.user", "用户模型"),
|
||||
("models.group", "用户组模型"),
|
||||
("models.object", "对象模型"),
|
||||
("models.setting", "设置模型"),
|
||||
("models.policy", "策略模型"),
|
||||
("models.database", "数据库连接"),
|
||||
("sqlmodels", "数据库模型"),
|
||||
("sqlmodels.user", "用户模型"),
|
||||
("sqlmodels.group", "用户组模型"),
|
||||
("sqlmodels.object", "对象模型"),
|
||||
("sqlmodels.setting", "设置模型"),
|
||||
("sqlmodels.policy", "策略模型"),
|
||||
("sqlmodels.database", "数据库连接"),
|
||||
("utils.password.pwd", "密码工具"),
|
||||
("utils.JWT.JWT", "JWT 工具"),
|
||||
("service.user.login", "登录服务"),
|
||||
|
||||
@@ -23,12 +23,12 @@ from sqlalchemy.orm import sessionmaker
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from main import app
|
||||
from models.database import get_session
|
||||
from models.group import Group, GroupOptions
|
||||
from models.migration import migration
|
||||
from models.object import Object, ObjectType
|
||||
from models.policy import Policy, PolicyType
|
||||
from models.user import User
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
from sqlmodels.migration import migration
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
from sqlmodels.user import User
|
||||
from utils.JWT.JWT import create_access_token
|
||||
from utils.password.pwd import Password
|
||||
|
||||
@@ -153,7 +153,7 @@ def override_get_session(db_session: AsyncSession):
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
"""
|
||||
创建测试用户并返回 {id, username, password, token}
|
||||
创建测试用户并返回 {id, email, password, token}
|
||||
|
||||
创建一个普通用户,包含用户组、存储策略和根目录。
|
||||
"""
|
||||
@@ -190,7 +190,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
# 创建测试用户
|
||||
password = "test_password_123"
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
@@ -202,7 +202,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
# 创建用户根目录
|
||||
root_folder = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -216,7 +216,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"group_id": group.id,
|
||||
@@ -227,7 +227,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
"""
|
||||
获取管理员用户 {id, username, token}
|
||||
获取管理员用户 {id, email, token}
|
||||
|
||||
创建具有管理员权限的用户。
|
||||
"""
|
||||
@@ -267,7 +267,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
# 创建管理员用户
|
||||
password = "admin_password_456"
|
||||
admin = User(
|
||||
username="admin",
|
||||
email="admin@disknext.local",
|
||||
nickname="管理员",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
@@ -279,7 +279,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
# 创建管理员根目录
|
||||
root_folder = Object(
|
||||
name=admin.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=admin.id,
|
||||
@@ -293,7 +293,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
return {
|
||||
"id": admin.id,
|
||||
"username": admin.username,
|
||||
"email": admin.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"group_id": admin_group.id,
|
||||
|
||||
@@ -8,9 +8,9 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from models.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from tests.fixtures import UserFactory, GroupFactory, ObjectFactory
|
||||
|
||||
|
||||
@@ -24,13 +24,13 @@ async def test_user_factory(db_session: AsyncSession):
|
||||
user = await UserFactory.create(
|
||||
db_session,
|
||||
group_id=group.id,
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
password="password123"
|
||||
)
|
||||
|
||||
# 验证
|
||||
assert user.id is not None
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "testuser@test.local"
|
||||
assert user.group_id == group.id
|
||||
assert user.status is True
|
||||
|
||||
@@ -51,7 +51,7 @@ async def test_group_factory(db_session: AsyncSession):
|
||||
async def test_object_factory(db_session: AsyncSession):
|
||||
"""测试对象工厂的基本功能"""
|
||||
# 准备依赖
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = await GroupFactory.create(db_session)
|
||||
user = await UserFactory.create(db_session, group_id=group.id)
|
||||
@@ -102,7 +102,7 @@ async def test_conftest_fixtures(
|
||||
"""测试 conftest.py 中的 fixtures"""
|
||||
# 验证 test_user fixture
|
||||
assert test_user["id"] is not None
|
||||
assert test_user["username"] == "testuser"
|
||||
assert test_user["email"] == "testuser@test.local"
|
||||
assert test_user["token"] is not None
|
||||
|
||||
# 验证 auth_headers fixture
|
||||
@@ -112,7 +112,7 @@ async def test_conftest_fixtures(
|
||||
# 验证用户在数据库中存在
|
||||
user = await User.get(db_session, User.id == test_user["id"])
|
||||
assert user is not None
|
||||
assert user.username == test_user["username"]
|
||||
assert user.email == test_user["email"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@@ -145,7 +145,7 @@ async def test_test_directory_fixture(
|
||||
@pytest.mark.integration
|
||||
async def test_nested_structure_factory(db_session: AsyncSession):
|
||||
"""测试嵌套结构工厂"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
# 准备依赖
|
||||
group = await GroupFactory.create(db_session)
|
||||
|
||||
2
tests/fixtures/groups.py
vendored
2
tests/fixtures/groups.py
vendored
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.group import Group, GroupOptions
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
|
||||
|
||||
class GroupFactory:
|
||||
|
||||
6
tests/fixtures/objects.py
vendored
6
tests/fixtures/objects.py
vendored
@@ -7,8 +7,8 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.object import Object, ObjectType
|
||||
from models.user import User
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
|
||||
|
||||
class ObjectFactory:
|
||||
@@ -119,7 +119,7 @@ class ObjectFactory:
|
||||
Object: 创建的根目录实例
|
||||
"""
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
|
||||
50
tests/fixtures/users.py
vendored
50
tests/fixtures/users.py
vendored
@@ -7,7 +7,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from sqlmodels.user import User
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class UserFactory:
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
username: str | None = None,
|
||||
email: str | None = None,
|
||||
password: str | None = None,
|
||||
**kwargs
|
||||
) -> User:
|
||||
@@ -28,7 +28,7 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
username: 用户名(默认: test_user_{随机})
|
||||
email: 用户邮箱(默认: test_user_{随机}@test.local)
|
||||
password: 明文密码(默认: password123)
|
||||
**kwargs: 其他用户字段
|
||||
|
||||
@@ -37,15 +37,15 @@ class UserFactory:
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"test_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"test_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
if password is None:
|
||||
password = "password123"
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
nickname=kwargs.get("nickname", username),
|
||||
email=email,
|
||||
nickname=kwargs.get("nickname", email),
|
||||
password=Password.hash(password),
|
||||
status=kwargs.get("status", True),
|
||||
storage=kwargs.get("storage", 0),
|
||||
@@ -67,7 +67,7 @@ class UserFactory:
|
||||
async def create_admin(
|
||||
session: AsyncSession,
|
||||
admin_group_id: UUID,
|
||||
username: str | None = None,
|
||||
email: str | None = None,
|
||||
password: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
@@ -76,7 +76,7 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
admin_group_id: 管理员组UUID
|
||||
username: 用户名(默认: admin_{随机})
|
||||
email: 用户邮箱(默认: admin_{随机}@disknext.local)
|
||||
password: 明文密码(默认: admin_password)
|
||||
|
||||
返回:
|
||||
@@ -84,15 +84,15 @@ class UserFactory:
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"admin_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"admin_{uuid.uuid4().hex[:8]}@disknext.local"
|
||||
|
||||
if password is None:
|
||||
password = "admin_password"
|
||||
|
||||
admin = User(
|
||||
username=username,
|
||||
nickname=f"管理员 {username}",
|
||||
email=email,
|
||||
nickname=f"管理员 {email}",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
storage=0,
|
||||
@@ -108,7 +108,7 @@ class UserFactory:
|
||||
async def create_banned(
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
username: str | None = None
|
||||
email: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
创建被封禁用户
|
||||
@@ -116,19 +116,19 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
username: 用户名(默认: banned_user_{随机})
|
||||
email: 用户邮箱(默认: banned_user_{随机}@test.local)
|
||||
|
||||
返回:
|
||||
User: 创建的被封禁用户实例
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"banned_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"banned_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
banned_user = User(
|
||||
username=username,
|
||||
nickname=f"封禁用户 {username}",
|
||||
email=email,
|
||||
nickname=f"封禁用户 {email}",
|
||||
password=Password.hash("banned_password"),
|
||||
status=False, # 封禁状态
|
||||
storage=0,
|
||||
@@ -145,7 +145,7 @@ class UserFactory:
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
storage_bytes: int,
|
||||
username: str | None = None
|
||||
email: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
创建已使用指定存储空间的用户
|
||||
@@ -154,19 +154,19 @@ class UserFactory:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
storage_bytes: 已使用的存储空间(字节)
|
||||
username: 用户名(默认: storage_user_{随机})
|
||||
email: 用户邮箱(默认: storage_user_{随机}@test.local)
|
||||
|
||||
返回:
|
||||
User: 创建的用户实例
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"storage_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"storage_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
email=email,
|
||||
nickname=email,
|
||||
password=Password.hash("password123"),
|
||||
status=True,
|
||||
storage=storage_bytes,
|
||||
|
||||
@@ -124,7 +124,7 @@ async def test_admin_get_user_list_contains_user_data(
|
||||
if len(users) > 0:
|
||||
user = users[0]
|
||||
assert "id" in user
|
||||
assert "username" in user
|
||||
assert "email" in user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -132,7 +132,7 @@ async def test_admin_create_user_requires_auth(async_client: AsyncClient):
|
||||
"""测试创建用户需要认证"""
|
||||
response = await async_client.post(
|
||||
"/api/admin/user/create",
|
||||
json={"username": "newadminuser", "password": "pass123"}
|
||||
json={"email": "newadminuser@test.local", "password": "pass123"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -146,7 +146,7 @@ async def test_admin_create_user_requires_admin(
|
||||
response = await async_client.post(
|
||||
"/api/admin/user/create",
|
||||
headers=auth_headers,
|
||||
json={"username": "newadminuser", "password": "pass123"}
|
||||
json={"email": "newadminuser@test.local", "password": "pass123"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from uuid import UUID
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_requires_auth(async_client: AsyncClient):
|
||||
"""测试获取目录需要认证"""
|
||||
response = await async_client.get("/api/directory/testuser")
|
||||
response = await async_client.get("/api/directory/")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ async def test_directory_get_root(
|
||||
):
|
||||
"""测试获取用户根目录"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -45,7 +45,7 @@ async def test_directory_get_nested(
|
||||
):
|
||||
"""测试获取嵌套目录"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/docs",
|
||||
"/api/directory/docs",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -63,7 +63,7 @@ async def test_directory_get_contains_children(
|
||||
):
|
||||
"""测试目录包含子对象"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/docs",
|
||||
"/api/directory/docs",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -75,19 +75,6 @@ async def test_directory_get_contains_children(
|
||||
assert len(objects) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_forbidden_other_user(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str]
|
||||
):
|
||||
"""测试访问他人目录返回 403"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/admin",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_not_found(
|
||||
async_client: AsyncClient,
|
||||
@@ -95,23 +82,23 @@ async def test_directory_not_found(
|
||||
):
|
||||
"""测试目录不存在返回 404"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/nonexistent",
|
||||
"/api/directory/nonexistent",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_empty_path_returns_400(
|
||||
async def test_directory_root_returns_200(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str]
|
||||
):
|
||||
"""测试空路径返回 400"""
|
||||
"""测试根目录端点返回 200"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -121,7 +108,7 @@ async def test_directory_response_includes_policy(
|
||||
):
|
||||
"""测试目录响应包含存储策略"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -284,7 +271,7 @@ async def test_directory_create_other_user_parent(
|
||||
"""测试在他人目录下创建目录返回 404"""
|
||||
# 先用管理员账号获取管理员的根目录ID
|
||||
admin_response = await async_client.get(
|
||||
"/api/directory/admin",
|
||||
"/api/directory/",
|
||||
headers=admin_headers
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
|
||||
@@ -16,7 +16,7 @@ async def test_user_login_success(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["username"],
|
||||
"username": test_user_info["email"],
|
||||
"password": test_user_info["password"],
|
||||
}
|
||||
)
|
||||
@@ -38,7 +38,7 @@ async def test_user_login_wrong_password(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["username"],
|
||||
"username": test_user_info["email"],
|
||||
"password": "wrongpassword",
|
||||
}
|
||||
)
|
||||
@@ -51,7 +51,7 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient):
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": "nonexistent",
|
||||
"username": "nonexistent@test.local",
|
||||
"password": "anypassword",
|
||||
}
|
||||
)
|
||||
@@ -67,7 +67,7 @@ async def test_user_login_user_banned(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": banned_user_info["username"],
|
||||
"username": banned_user_info["email"],
|
||||
"password": banned_user_info["password"],
|
||||
}
|
||||
)
|
||||
@@ -82,7 +82,7 @@ async def test_user_register_success(async_client: AsyncClient):
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"username": "newuser",
|
||||
"email": "newuser@test.local",
|
||||
"password": "newpass123",
|
||||
}
|
||||
)
|
||||
@@ -91,20 +91,20 @@ async def test_user_register_success(async_client: AsyncClient):
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "user_id" in data["data"]
|
||||
assert "username" in data["data"]
|
||||
assert data["data"]["username"] == "newuser"
|
||||
assert "email" in data["data"]
|
||||
assert data["data"]["email"] == "newuser@test.local"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_register_duplicate_username(
|
||||
async def test_user_register_duplicate_email(
|
||||
async_client: AsyncClient,
|
||||
test_user_info: dict[str, str]
|
||||
):
|
||||
"""测试重复用户名返回 400"""
|
||||
"""测试重复邮箱返回 400"""
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"username": test_user_info["username"],
|
||||
"email": test_user_info["email"],
|
||||
"password": "anypassword",
|
||||
}
|
||||
)
|
||||
@@ -143,8 +143,8 @@ async def test_user_me_returns_user_info(
|
||||
assert "data" in data
|
||||
user_data = data["data"]
|
||||
assert "id" in user_data
|
||||
assert "username" in user_data
|
||||
assert user_data["username"] == "testuser"
|
||||
assert "email" in user_data
|
||||
assert user_data["email"] == "testuser@test.local"
|
||||
assert "group" in user_data
|
||||
assert "tags" in user_data
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
|
||||
from main import app
|
||||
from models import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
from utils.JWT import JWT
|
||||
@@ -92,6 +92,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
Setting(type=SettingsType.VIEW, name="home_view_method", value="list"),
|
||||
Setting(type=SettingsType.VIEW, name="share_view_method", value="grid"),
|
||||
Setting(type=SettingsType.AUTHN, name="authn_enabled", value="0"),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_type", value="default"),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_ReCaptchaKey", value=""),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""),
|
||||
Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"),
|
||||
@@ -180,7 +181,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 6. 创建测试用户
|
||||
test_user = User(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
password=Password.hash("testpass123"),
|
||||
nickname="测试用户",
|
||||
status=True,
|
||||
@@ -194,7 +195,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
admin_user = User(
|
||||
id=uuid4(),
|
||||
username="admin",
|
||||
email="admin@disknext.local",
|
||||
password=Password.hash("adminpass123"),
|
||||
nickname="管理员",
|
||||
status=True,
|
||||
@@ -208,7 +209,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
banned_user = User(
|
||||
id=uuid4(),
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("banned123"),
|
||||
nickname="封禁用户",
|
||||
status=False, # 封禁状态
|
||||
@@ -230,7 +231,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 7. 创建用户根目录
|
||||
test_user_root = Object(
|
||||
id=uuid4(),
|
||||
name=test_user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=test_user.id,
|
||||
parent_id=None,
|
||||
@@ -241,7 +242,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
admin_user_root = Object(
|
||||
id=uuid4(),
|
||||
name=admin_user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=admin_user.id,
|
||||
parent_id=None,
|
||||
@@ -264,7 +265,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
def test_user_info() -> dict[str, str]:
|
||||
"""测试用户信息"""
|
||||
return {
|
||||
"username": "testuser",
|
||||
"email": "testuser@test.local",
|
||||
"password": "testpass123",
|
||||
}
|
||||
|
||||
@@ -273,7 +274,7 @@ def test_user_info() -> dict[str, str]:
|
||||
def admin_user_info() -> dict[str, str]:
|
||||
"""管理员用户信息"""
|
||||
return {
|
||||
"username": "admin",
|
||||
"email": "admin@disknext.local",
|
||||
"password": "adminpass123",
|
||||
}
|
||||
|
||||
@@ -282,7 +283,7 @@ def admin_user_info() -> dict[str, str]:
|
||||
def banned_user_info() -> dict[str, str]:
|
||||
"""封禁用户信息"""
|
||||
return {
|
||||
"username": "banneduser",
|
||||
"email": "banneduser@test.local",
|
||||
"password": "banned123",
|
||||
}
|
||||
|
||||
@@ -293,7 +294,7 @@ def banned_user_info() -> dict[str, str]:
|
||||
def test_user_token(test_user_info: dict[str, str]) -> str:
|
||||
"""生成测试用户的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
@@ -303,7 +304,7 @@ def test_user_token(test_user_info: dict[str, str]) -> str:
|
||||
def admin_user_token(admin_user_info: dict[str, str]) -> str:
|
||||
"""生成管理员的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": admin_user_info["username"]},
|
||||
data={"sub": admin_user_info["email"]},
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
@@ -313,7 +314,7 @@ def admin_user_token(admin_user_info: dict[str, str]) -> str:
|
||||
def expired_token() -> str:
|
||||
"""生成过期的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "testuser"},
|
||||
data={"sub": "testuser@test.local"},
|
||||
expires_delta=timedelta(seconds=-1), # 已过期
|
||||
)
|
||||
return token
|
||||
@@ -362,7 +363,7 @@ async def test_directory_structure(initialized_db: AsyncSession) -> dict[str, UU
|
||||
"""创建测试目录结构"""
|
||||
|
||||
# 获取测试用户和根目录
|
||||
test_user = await User.get(initialized_db, User.username == "testuser")
|
||||
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
test_user_root = await Object.get_root(initialized_db, test_user.id)
|
||||
|
||||
default_policy = await Policy.get(initialized_db, Policy.name == "本地存储")
|
||||
|
||||
@@ -83,7 +83,7 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
|
||||
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
|
||||
"""测试用户不存在的token返回 401"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "nonexistent_user"},
|
||||
data={"sub": "nonexistent_user@test.local"},
|
||||
expires_delta=timedelta(hours=1)
|
||||
)
|
||||
|
||||
@@ -178,12 +178,12 @@ async def test_auth_on_directory_endpoint(
|
||||
):
|
||||
"""测试目录端点应用认证"""
|
||||
# 无认证
|
||||
response_no_auth = await async_client.get("/api/directory/testuser")
|
||||
response_no_auth = await async_client.get("/api/directory/")
|
||||
assert response_no_auth.status_code == 401
|
||||
|
||||
# 有认证
|
||||
response_with_auth = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response_with_auth.status_code == 200
|
||||
@@ -235,7 +235,7 @@ async def test_auth_on_storage_endpoint(
|
||||
async def test_refresh_token_format(test_user_info: dict[str, str]):
|
||||
"""测试刷新token格式正确"""
|
||||
refresh_token, _ = JWT.create_refresh_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(days=7)
|
||||
)
|
||||
|
||||
@@ -247,7 +247,7 @@ async def test_refresh_token_format(test_user_info: dict[str, str]):
|
||||
async def test_access_token_format(test_user_info: dict[str, str]):
|
||||
"""测试访问token格式正确"""
|
||||
access_token, expires = JWT.create_access_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(hours=1)
|
||||
)
|
||||
|
||||
|
||||
@@ -3,14 +3,14 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_db():
|
||||
"""测试创建数据库结构"""
|
||||
from models import database
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
"""测试获取数据库连接Session"""
|
||||
from models import database
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
@@ -20,8 +20,8 @@ async def db_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration():
|
||||
"""测试数据库创建并初始化配置"""
|
||||
from models import migration
|
||||
from models import database
|
||||
from sqlmodels import migration
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database, migration
|
||||
from models.group import Group
|
||||
from sqlmodels import database, migration
|
||||
from sqlmodels.group import Group
|
||||
|
||||
await database.init_db(url='sqlite+aiosqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database
|
||||
from models.setting import Setting
|
||||
from sqlmodels import database
|
||||
from sqlmodels.setting import Setting
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database, migration
|
||||
from models.group import Group
|
||||
from models.user import User
|
||||
from sqlmodels import database, migration
|
||||
from sqlmodels.group import Group
|
||||
from sqlmodels.user import User
|
||||
|
||||
await database.init_db(url='sqlite+aiosqlite:///:memory:')
|
||||
|
||||
@@ -17,7 +17,7 @@ async def test_user_curd():
|
||||
created_group = await test_user_group.save(session)
|
||||
|
||||
test_user = User(
|
||||
username='test_user',
|
||||
email='test_user@test.local',
|
||||
password='test_password',
|
||||
group_id=created_group.id
|
||||
)
|
||||
@@ -27,7 +27,7 @@ async def test_user_curd():
|
||||
|
||||
# 验证用户是否存在
|
||||
assert created_user.id is not None
|
||||
assert created_user.username == 'test_user'
|
||||
assert created_user.email == 'test_user@test.local'
|
||||
assert created_user.password == 'test_password'
|
||||
assert created_user.group_id == created_group.id
|
||||
|
||||
@@ -35,18 +35,18 @@ async def test_user_curd():
|
||||
fetched_user = await User.get(session, User.id == created_user.id)
|
||||
|
||||
assert fetched_user is not None
|
||||
assert fetched_user.username == 'test_user'
|
||||
assert fetched_user.email == 'test_user@test.local'
|
||||
assert fetched_user.password == 'test_password'
|
||||
assert fetched_user.group_id == created_group.id
|
||||
|
||||
# 测试改 Update
|
||||
updated_user = await fetched_user.update(
|
||||
session,
|
||||
{"username": "updated_user", "password": "updated_password"}
|
||||
{"email": "updated_user@test.local", "password": "updated_password"}
|
||||
)
|
||||
|
||||
assert updated_user is not None
|
||||
assert updated_user.username == 'updated_user'
|
||||
assert updated_user.email == 'updated_user@test.local'
|
||||
assert updated_user.password == 'updated_password'
|
||||
|
||||
# 测试删除 Delete
|
||||
|
||||
@@ -8,8 +8,8 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -62,7 +62,7 @@ async def test_table_base_update(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
# 更新数据
|
||||
from models.group import GroupBase
|
||||
from sqlmodels.group import GroupBase
|
||||
update_data = GroupBase(name="更新后名称")
|
||||
updated_group = await group.update(db_session, update_data)
|
||||
|
||||
@@ -200,7 +200,7 @@ async def test_timestamps_auto_update(db_session: AsyncSession):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 更新记录
|
||||
from models.group import GroupBase
|
||||
from sqlmodels.group import GroupBase
|
||||
update_data = GroupBase(name="更新后的名称")
|
||||
group = await group.update(db_session, update_data)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ Group 和 GroupOptions 模型的单元测试
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.group import Group, GroupOptions, GroupResponse
|
||||
from sqlmodels.group import Group, GroupOptions, GroupResponse
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -5,21 +5,21 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.object import Object, ObjectType
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_create_folder(db_session: AsyncSession):
|
||||
"""测试创建目录"""
|
||||
# 创建必要的依赖数据
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -48,12 +48,12 @@ async def test_object_create_folder(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_create_file(db_session: AsyncSession):
|
||||
"""测试创建文件"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -65,7 +65,7 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -81,7 +81,6 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id,
|
||||
size=1024,
|
||||
source_name="test_source.txt"
|
||||
)
|
||||
file = await file.save(db_session)
|
||||
|
||||
@@ -89,18 +88,17 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
assert file.name == "test.txt"
|
||||
assert file.type == ObjectType.FILE
|
||||
assert file.size == 1024
|
||||
assert file.source_name == "test_source.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_is_file_property(db_session: AsyncSession):
|
||||
"""测试 is_file 属性"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -122,12 +120,12 @@ async def test_object_is_file_property(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_is_folder_property(db_session: AsyncSession):
|
||||
"""测试 is_folder 属性"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -148,12 +146,12 @@ async def test_object_is_folder_property(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_root(db_session: AsyncSession):
|
||||
"""测试 get_root() 方法"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="rootuser", password="password", group_id=group.id)
|
||||
user = User(email="rootuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -161,7 +159,7 @@ async def test_object_get_root(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -180,12 +178,12 @@ async def test_object_get_root(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
"""测试获取根目录"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="pathuser", password="password", group_id=group.id)
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -193,7 +191,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -202,7 +200,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
root = await root.save(db_session)
|
||||
|
||||
# 通过路径获取根目录
|
||||
result = await Object.get_by_path(db_session, user.id, "/pathuser", user.username)
|
||||
result = await Object.get_by_path(db_session, user.id, "/")
|
||||
|
||||
assert result is not None
|
||||
assert result.id == root.id
|
||||
@@ -211,12 +209,12 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
"""测试获取嵌套路径"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="nesteduser", password="password", group_id=group.id)
|
||||
user = User(email="nesteduser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -224,7 +222,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
|
||||
# 创建目录结构: root -> docs -> work -> project
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -263,8 +261,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
result = await Object.get_by_path(
|
||||
db_session,
|
||||
user.id,
|
||||
"/nesteduser/docs/work/project",
|
||||
user.username
|
||||
"/docs/work/project",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
@@ -275,12 +272,12 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
"""测试路径不存在"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="notfounduser", password="password", group_id=group.id)
|
||||
user = User(email="notfounduser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -288,7 +285,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -300,8 +297,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
result = await Object.get_by_path(
|
||||
db_session,
|
||||
user.id,
|
||||
"/notfounduser/nonexistent",
|
||||
user.username
|
||||
"/nonexistent",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
@@ -310,12 +306,12 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_children(db_session: AsyncSession):
|
||||
"""测试 get_children() 方法"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="childrenuser", password="password", group_id=group.id)
|
||||
user = User(email="childrenuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -362,12 +358,12 @@ async def test_object_get_children(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_parent_child_relationship(db_session: AsyncSession):
|
||||
"""测试父子关系"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="reluser", password="password", group_id=group.id)
|
||||
user = User(email="reluser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -407,12 +403,12 @@ async def test_object_parent_child_relationship(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_unique_constraint(db_session: AsyncSession):
|
||||
"""测试同目录名称唯一约束"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="uniqueuser", password="password", group_id=group.id)
|
||||
user = User(email="uniqueuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -450,3 +446,64 @@ async def test_object_unique_constraint(db_session: AsyncSession):
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
await file2.save(db_session)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_full_path(db_session: AsyncSession):
|
||||
"""测试 get_full_path() 方法"""
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
policy = await policy.save(db_session)
|
||||
|
||||
# 创建目录结构: root -> docs -> images -> photo.jpg
|
||||
root = Object(
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
root = await root.save(db_session)
|
||||
|
||||
docs = Object(
|
||||
name="docs",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=root.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
docs = await docs.save(db_session)
|
||||
|
||||
images = Object(
|
||||
name="images",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=docs.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
images = await images.save(db_session)
|
||||
|
||||
photo = Object(
|
||||
name="photo.jpg",
|
||||
type=ObjectType.FILE,
|
||||
parent_id=images.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id,
|
||||
size=2048
|
||||
)
|
||||
photo = await photo.save(db_session)
|
||||
|
||||
# 测试完整路径
|
||||
full_path = await photo.get_full_path(db_session)
|
||||
assert full_path == "/docs/images/photo.jpg"
|
||||
|
||||
# 测试根目录的 full_path
|
||||
root_path = await root.get_full_path(db_session)
|
||||
assert root_path == "/"
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.setting import Setting, SettingsType
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession):
|
||||
setting = await setting.save(db_session)
|
||||
|
||||
# 更新值
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class SettingUpdate(SQLModelBase):
|
||||
value: str | None = None
|
||||
|
||||
273
tests/unit/models/test_uri.py
Normal file
273
tests/unit/models/test_uri.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
DiskNextURI 模型的单元测试
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from sqlmodels.uri import DiskNextURI, FileSystemNamespace
|
||||
|
||||
|
||||
class TestDiskNextURIParse:
|
||||
"""测试 URI 解析"""
|
||||
|
||||
def test_parse_my_root(self):
|
||||
"""测试解析个人空间根目录"""
|
||||
uri = DiskNextURI.parse("disknext://my/")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/"
|
||||
assert uri.fs_id is None
|
||||
assert uri.password is None
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_parse_my_with_path(self):
|
||||
"""测试解析个人空间带路径"""
|
||||
uri = DiskNextURI.parse("disknext://my/docs/readme.md")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/docs/readme.md"
|
||||
assert uri.fs_id is None
|
||||
assert uri.path_parts == ["docs", "readme.md"]
|
||||
assert uri.is_root is False
|
||||
|
||||
def test_parse_my_with_fs_id(self):
|
||||
"""测试解析带 fs_id 的个人空间"""
|
||||
uri = DiskNextURI.parse("disknext://some-uuid@my/docs")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.fs_id == "some-uuid"
|
||||
assert uri.path == "/docs"
|
||||
|
||||
def test_parse_share_with_code(self):
|
||||
"""测试解析分享链接"""
|
||||
uri = DiskNextURI.parse("disknext://abc123@share/")
|
||||
assert uri.namespace == FileSystemNamespace.SHARE
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.path == "/"
|
||||
assert uri.password is None
|
||||
|
||||
def test_parse_share_with_password(self):
|
||||
"""测试解析带密码的分享链接"""
|
||||
uri = DiskNextURI.parse("disknext://abc123:mypass@share/sub/dir")
|
||||
assert uri.namespace == FileSystemNamespace.SHARE
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.password == "mypass"
|
||||
assert uri.path == "/sub/dir"
|
||||
|
||||
def test_parse_trash(self):
|
||||
"""测试解析回收站"""
|
||||
uri = DiskNextURI.parse("disknext://trash/")
|
||||
assert uri.namespace == FileSystemNamespace.TRASH
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_parse_with_query(self):
|
||||
"""测试解析带查询参数的 URI"""
|
||||
uri = DiskNextURI.parse("disknext://my/?name=report&type=file")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.query is not None
|
||||
assert uri.query["name"] == "report"
|
||||
assert uri.query["type"] == "file"
|
||||
|
||||
def test_parse_invalid_scheme(self):
|
||||
"""测试无效的协议前缀"""
|
||||
with pytest.raises(ValueError, match="disknext://"):
|
||||
DiskNextURI.parse("http://my/docs")
|
||||
|
||||
def test_parse_invalid_namespace(self):
|
||||
"""测试无效的命名空间"""
|
||||
with pytest.raises(ValueError, match="无效的命名空间"):
|
||||
DiskNextURI.parse("disknext://invalid/docs")
|
||||
|
||||
def test_parse_no_namespace(self):
|
||||
"""测试缺少命名空间"""
|
||||
with pytest.raises(ValueError):
|
||||
DiskNextURI.parse("disknext://")
|
||||
|
||||
|
||||
class TestDiskNextURIBuild:
|
||||
"""测试 URI 构建"""
|
||||
|
||||
def test_build_simple(self):
|
||||
"""测试简单构建"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/"
|
||||
assert uri.fs_id is None
|
||||
|
||||
def test_build_with_path(self):
|
||||
"""测试带路径构建"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.path == "/docs/readme.md"
|
||||
|
||||
def test_build_path_auto_prefix(self):
|
||||
"""测试路径自动添加 / 前缀"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="docs/readme.md")
|
||||
assert uri.path == "/docs/readme.md"
|
||||
|
||||
def test_build_with_fs_id(self):
|
||||
"""测试带 fs_id 构建"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.SHARE,
|
||||
fs_id="abc123",
|
||||
password="secret",
|
||||
)
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.password == "secret"
|
||||
|
||||
|
||||
class TestDiskNextURIToString:
|
||||
"""测试 URI 序列化"""
|
||||
|
||||
def test_to_string_simple(self):
|
||||
"""测试简单序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.to_string() == "disknext://my/"
|
||||
|
||||
def test_to_string_with_path(self):
|
||||
"""测试带路径序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.to_string() == "disknext://my/docs/readme.md"
|
||||
|
||||
def test_to_string_with_fs_id(self):
|
||||
"""测试带 fs_id 序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="uuid-123")
|
||||
assert uri.to_string() == "disknext://uuid-123@my/"
|
||||
|
||||
def test_to_string_with_password(self):
|
||||
"""测试带密码序列化"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.SHARE,
|
||||
fs_id="code",
|
||||
password="pass",
|
||||
)
|
||||
assert uri.to_string() == "disknext://code:pass@share/"
|
||||
|
||||
def test_to_string_roundtrip(self):
|
||||
"""测试序列化-反序列化往返"""
|
||||
original = "disknext://abc123:pass@share/sub/dir"
|
||||
uri = DiskNextURI.parse(original)
|
||||
result = uri.to_string()
|
||||
assert result == original
|
||||
|
||||
|
||||
class TestDiskNextURIId:
|
||||
"""测试 id() 方法"""
|
||||
|
||||
def test_id_with_fs_id(self):
|
||||
"""测试有 fs_id 时返回 fs_id"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="my-uuid")
|
||||
assert uri.id("default") == "my-uuid"
|
||||
|
||||
def test_id_without_fs_id(self):
|
||||
"""测试无 fs_id 时返回默认值"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.id("default-uuid") == "default-uuid"
|
||||
|
||||
def test_id_without_fs_id_no_default(self):
|
||||
"""测试无 fs_id 且无默认值时返回 None"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.id() is None
|
||||
|
||||
|
||||
class TestDiskNextURIJoin:
|
||||
"""测试 join() 方法"""
|
||||
|
||||
def test_join_single(self):
|
||||
"""测试拼接单个路径元素"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
joined = uri.join("readme.md")
|
||||
assert joined.path == "/docs/readme.md"
|
||||
|
||||
def test_join_multiple(self):
|
||||
"""测试拼接多个路径元素"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
joined = uri.join("docs", "work", "report.pdf")
|
||||
assert joined.path == "/docs/work/report.pdf"
|
||||
|
||||
def test_join_preserves_metadata(self):
|
||||
"""测试 join 保留 namespace 和 fs_id"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.SHARE, fs_id="code123")
|
||||
joined = uri.join("sub")
|
||||
assert joined.namespace == FileSystemNamespace.SHARE
|
||||
assert joined.fs_id == "code123"
|
||||
|
||||
|
||||
class TestDiskNextURIDirUri:
|
||||
"""测试 dir_uri() 方法"""
|
||||
|
||||
def test_dir_uri_file(self):
|
||||
"""测试获取文件的父目录 URI"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
parent = uri.dir_uri()
|
||||
assert parent.path == "/docs/"
|
||||
|
||||
def test_dir_uri_root(self):
|
||||
"""测试根目录的 dir_uri 返回自身"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
parent = uri.dir_uri()
|
||||
assert parent.path == "/"
|
||||
|
||||
|
||||
class TestDiskNextURIRoot:
|
||||
"""测试 root() 方法"""
|
||||
|
||||
def test_root_resets_path(self):
|
||||
"""测试 root 重置路径"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.MY,
|
||||
path="/docs/work/report.pdf",
|
||||
fs_id="uuid-123",
|
||||
)
|
||||
root = uri.root()
|
||||
assert root.path == "/"
|
||||
assert root.fs_id == "uuid-123"
|
||||
assert root.namespace == FileSystemNamespace.MY
|
||||
|
||||
|
||||
class TestDiskNextURIName:
|
||||
"""测试 name() 方法"""
|
||||
|
||||
def test_name_file(self):
|
||||
"""测试获取文件名"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.name() == "readme.md"
|
||||
|
||||
def test_name_directory(self):
|
||||
"""测试获取目录名"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work")
|
||||
assert uri.name() == "work"
|
||||
|
||||
def test_name_root(self):
|
||||
"""测试根目录的 name 返回空字符串"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.name() == ""
|
||||
|
||||
|
||||
class TestDiskNextURIProperties:
|
||||
"""测试属性方法"""
|
||||
|
||||
def test_path_parts(self):
|
||||
"""测试路径分割"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work/report.pdf")
|
||||
assert uri.path_parts == ["docs", "work", "report.pdf"]
|
||||
|
||||
def test_path_parts_root(self):
|
||||
"""测试根路径分割"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.path_parts == []
|
||||
|
||||
def test_is_root_true(self):
|
||||
"""测试 is_root 为真"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_is_root_false(self):
|
||||
"""测试 is_root 为假"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert uri.is_root is False
|
||||
|
||||
def test_str_representation(self):
|
||||
"""测试字符串表示"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert str(uri) == "disknext://my/docs"
|
||||
|
||||
def test_repr(self):
|
||||
"""测试 repr"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert "disknext://my/docs" in repr(uri)
|
||||
@@ -5,8 +5,8 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User, ThemeType, UserPublic
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User, ThemeType, UserPublic
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -18,7 +18,7 @@ async def test_user_create(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password="hashed_password",
|
||||
group_id=group.id
|
||||
@@ -26,7 +26,7 @@ async def test_user_create(db_session: AsyncSession):
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.id is not None
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "testuser@test.local"
|
||||
assert user.nickname == "测试用户"
|
||||
assert user.status is True
|
||||
assert user.storage == 0
|
||||
@@ -34,15 +34,15 @@ async def test_user_create(db_session: AsyncSession):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_unique_username(db_session: AsyncSession):
|
||||
"""测试用户名唯一约束"""
|
||||
async def test_user_unique_email(db_session: AsyncSession):
|
||||
"""测试邮箱唯一约束"""
|
||||
# 创建用户组
|
||||
group = Group(name="默认组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
# 创建第一个用户
|
||||
user1 = User(
|
||||
username="duplicate",
|
||||
email="duplicate@test.local",
|
||||
password="password1",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -50,7 +50,7 @@ async def test_user_unique_username(db_session: AsyncSession):
|
||||
|
||||
# 尝试创建同名用户
|
||||
user2 = User(
|
||||
username="duplicate",
|
||||
email="duplicate@test.local",
|
||||
password="password2",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -68,7 +68,7 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="publicuser",
|
||||
email="publicuser@test.local",
|
||||
nickname="公开用户",
|
||||
password="secret_password",
|
||||
storage=1024,
|
||||
@@ -82,7 +82,7 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
|
||||
assert isinstance(public_user, UserPublic)
|
||||
assert public_user.id == user.id
|
||||
assert public_user.username == "publicuser"
|
||||
assert public_user.email == "publicuser@test.local"
|
||||
# 注意: UserPublic.nick 字段名与 User.nickname 不同,
|
||||
# model_validate 不会自动映射,所以 nick 为 None
|
||||
# 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段
|
||||
@@ -101,7 +101,7 @@ async def test_user_group_relationship(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="vipuser",
|
||||
email="vipuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -125,7 +125,7 @@ async def test_user_status_default(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="defaultuser",
|
||||
email="defaultuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -141,7 +141,7 @@ async def test_user_storage_default(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="storageuser",
|
||||
email="storageuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -158,7 +158,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试默认值
|
||||
user1 = User(
|
||||
username="user1",
|
||||
email="user1@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -167,7 +167,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试设置为 LIGHT
|
||||
user2 = User(
|
||||
username="user2",
|
||||
email="user2@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.LIGHT,
|
||||
group_id=group.id
|
||||
@@ -177,7 +177,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试设置为 DARK
|
||||
user3 = User(
|
||||
username="user3",
|
||||
email="user3@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.DARK,
|
||||
group_id=group.id
|
||||
|
||||
@@ -4,8 +4,8 @@ Login 服务的单元测试
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User, LoginRequest, TokenResponse
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse
|
||||
from sqlmodels.group import Group
|
||||
from service.user.login import login
|
||||
from utils.password.pwd import Password
|
||||
|
||||
@@ -20,7 +20,7 @@ async def setup_user(db_session: AsyncSession):
|
||||
# 创建正常用户
|
||||
plain_password = "secure_password_123"
|
||||
user = User(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=Password.hash(plain_password),
|
||||
status=True,
|
||||
group_id=group.id
|
||||
@@ -41,7 +41,7 @@ async def setup_banned_user(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=False, # 封禁状态
|
||||
group_id=group.id
|
||||
@@ -61,7 +61,7 @@ async def setup_2fa_user(db_session: AsyncSession):
|
||||
|
||||
secret = pyotp.random_base32()
|
||||
user = User(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=True,
|
||||
two_factor=secret,
|
||||
@@ -82,7 +82,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
async def test_login_user_not_found(db_session: AsyncSession):
|
||||
"""测试用户不存在"""
|
||||
login_request = LoginRequest(
|
||||
username="nonexistent_user",
|
||||
email="nonexistent@test.local",
|
||||
password="any_password"
|
||||
)
|
||||
|
||||
@@ -112,7 +112,7 @@ async def test_login_user_not_found(db_session: AsyncSession):
|
||||
async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
"""测试密码错误"""
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password="wrong_password"
|
||||
)
|
||||
|
||||
@@ -125,7 +125,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
|
||||
"""测试用户被封禁"""
|
||||
login_request = LoginRequest(
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password="password"
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"]
|
||||
# 未提供 two_fa_code
|
||||
)
|
||||
@@ -156,7 +156,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code="000000" # 错误的验证码
|
||||
)
|
||||
@@ -179,7 +179,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
|
||||
valid_code = totp.now()
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code=valid_code
|
||||
)
|
||||
@@ -198,7 +198,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
@@ -217,17 +217,17 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_case_sensitive_username(db_session: AsyncSession, setup_user):
|
||||
"""测试用户名大小写敏感"""
|
||||
async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user):
|
||||
"""测试邮箱大小写敏感"""
|
||||
user_data = setup_user
|
||||
|
||||
# 使用大写用户名登录(如果数据库是 loginuser)
|
||||
# 使用大写邮箱登录
|
||||
login_request = LoginRequest(
|
||||
username="LOGINUSER",
|
||||
email="LOGINUSER@TEST.LOCAL",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
# 应该失败,因为用户名大小写不匹配
|
||||
# 应该失败,因为邮箱大小写不匹配
|
||||
assert result is None
|
||||
|
||||
@@ -72,9 +72,9 @@ def test_password_verify_expired():
|
||||
@pytest.mark.asyncio
|
||||
async def test_totp_generate():
|
||||
"""测试 TOTP 密钥生成"""
|
||||
username = "testuser"
|
||||
email = "testuser@test.local"
|
||||
|
||||
response = await Password.generate_totp(username)
|
||||
response = await Password.generate_totp(email)
|
||||
|
||||
assert response.setup_token is not None
|
||||
assert response.uri is not None
|
||||
@@ -82,7 +82,7 @@ async def test_totp_generate():
|
||||
assert isinstance(response.uri, str)
|
||||
# TOTP URI 格式: otpauth://totp/...
|
||||
assert response.uri.startswith("otpauth://totp/")
|
||||
assert username in response.uri
|
||||
assert email in response.uri
|
||||
|
||||
|
||||
def test_totp_verify_valid():
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import UUID, uuid4
|
||||
import jwt
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from models import AccessTokenBase, RefreshTokenBase
|
||||
from sqlmodels import AccessTokenBase, RefreshTokenBase, TokenResponse
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
scheme_name='获取 JWT Bearer 令牌',
|
||||
@@ -21,8 +21,8 @@ async def load_secret_key() -> None:
|
||||
从数据库读取 JWT 的密钥。
|
||||
"""
|
||||
# 延迟导入以避免循环依赖
|
||||
from models.database import get_session
|
||||
from models.setting import Setting
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.setting import Setting
|
||||
|
||||
global SECRET_KEY
|
||||
async for session in get_session():
|
||||
@@ -69,19 +69,29 @@ def build_token_payload(
|
||||
|
||||
# 访问令牌
|
||||
def create_access_token(
|
||||
data: dict,
|
||||
sub: UUID,
|
||||
jti: UUID,
|
||||
expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256"
|
||||
algorithm: str = "HS256",
|
||||
**kwargs
|
||||
) -> AccessTokenBase:
|
||||
"""
|
||||
生成访问令牌,默认有效期 3 小时。
|
||||
|
||||
:param data: 需要放进 JWT Payload 的字段。
|
||||
:param sub: 令牌的主题,通常是用户 ID。
|
||||
:param jti: 令牌的唯一标识符,通常是一个 UUID。
|
||||
:param expires_delta: 过期时间, 缺省时为 3 小时。
|
||||
:param algorithm: JWT 密钥强度,缺省时为 HS256
|
||||
:param kwargs: 需要放进 JWT Payload 的字段。
|
||||
|
||||
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
|
||||
"""
|
||||
|
||||
data = {"sub": str(sub), "jti": str(jti)}
|
||||
|
||||
# 将额外的字段添加到 Payload 中
|
||||
for key, value in kwargs.items():
|
||||
data[key] = value
|
||||
|
||||
access_token, expire_at = build_token_payload(
|
||||
data,
|
||||
@@ -97,19 +107,29 @@ def create_access_token(
|
||||
|
||||
# 刷新令牌
|
||||
def create_refresh_token(
|
||||
data: dict,
|
||||
sub: UUID,
|
||||
jti: UUID,
|
||||
expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256"
|
||||
algorithm: str = "HS256",
|
||||
**kwargs,
|
||||
) -> RefreshTokenBase:
|
||||
"""
|
||||
生成刷新令牌,默认有效期 30 天。
|
||||
|
||||
:param data: 需要放进 JWT Payload 的字段。
|
||||
:param sub: 令牌的主题,通常是用户 ID。
|
||||
:param jti: 令牌的唯一标识符,通常是一个 UUID。
|
||||
:param expires_delta: 过期时间, 缺省时为 30 天。
|
||||
:param algorithm: JWT 密钥强度,缺省时为 HS256
|
||||
:param kwargs: 需要放进 JWT Payload 的字段。
|
||||
|
||||
:return: 包含密钥本身和过期时间的 `RefreshTokenBase`
|
||||
"""
|
||||
|
||||
data = {"sub": str(sub), "jti": str(jti)}
|
||||
|
||||
# 将额外的字段添加到 Payload 中
|
||||
for key, value in kwargs.items():
|
||||
data[key] = value
|
||||
|
||||
refresh_token, expire_at = build_token_payload(
|
||||
data,
|
||||
|
||||
@@ -28,6 +28,10 @@ def raise_forbidden(detail: str | None = None, *args, **kwargs) -> NoReturn:
|
||||
"""Raises an HTTP 403 Forbidden exception."""
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs)
|
||||
|
||||
def raise_banned(detail: str = "此文件已被管理员封禁,仅允许删除操作", *args, **kwargs) -> NoReturn:
|
||||
"""Raises an HTTP 403 Forbidden exception for banned objects."""
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs)
|
||||
|
||||
def raise_not_found(detail: str | None = None, *args, **kwargs) -> NoReturn:
|
||||
"""Raises an HTTP 404 Not Found exception."""
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail, *args, **kwargs)
|
||||
|
||||
@@ -73,6 +73,8 @@ class Password:
|
||||
|
||||
:param length: 密码长度
|
||||
:type length: int
|
||||
:param url_safe: 是否生成 URL 安全的密码
|
||||
:type url_safe: bool
|
||||
:return: 随机密码
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user