diff --git a/.claude/settings.local.json b/.claude/settings.local.json index b06f87e..5877287 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -3,7 +3,9 @@ "allow": [ "Bash(git rev-parse:*)", "Bash(findstr:*)", - "Bash(find:*)" + "Bash(find:*)", + "Bash(yarn tsc:*)", + "Bash(dir:*)" ] } } diff --git a/main.py b/main.py index 69edc60..8a72ca5 100644 --- a/main.py +++ b/main.py @@ -1,25 +1,29 @@ from typing import NoReturn from fastapi import FastAPI, Request -from fastapi.middleware.cors import CORSMiddleware from utils.conf import appmeta from utils.http.http_exceptions import raise_internal_error from utils.lifespan import lifespan -from models.database import init_db -from models.migration import migration +from sqlmodels.database_connection import DatabaseManager +from sqlmodels.migration import migration from utils import JWT from routers import router from service.redis import RedisManager from loguru import logger as l +async def _init_db() -> None: + """初始化数据库连接引擎""" + await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug) + # 添加初始化数据库启动项 -lifespan.add_startup(init_db) +lifespan.add_startup(_init_db) lifespan.add_startup(migration) lifespan.add_startup(JWT.load_secret_key) lifespan.add_startup(RedisManager.connect) # 添加关闭项 +lifespan.add_shutdown(DatabaseManager.close) lifespan.add_shutdown(RedisManager.disconnect) # 创建应用实例并设置元数据 diff --git a/middleware/auth.py b/middleware/auth.py index 666fc9b..feceed1 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -4,7 +4,7 @@ from uuid import UUID from fastapi import Depends import jwt -from models.user import User +from sqlmodels.user import User from utils import JWT from .dependencies import SessionDep from utils import http_exceptions @@ -25,8 +25,8 @@ async def auth_required( user_id = UUID(user_id) - # 从数据库获取用户信息 - user = await User.get(session, User.id == user_id) + # 从数据库获取用户信息(预加载 group 关系) + user = await User.get(session, User.id == user_id, load=User.group) if not user: http_exceptions.raise_unauthorized("账号或密码错误") @@ -44,8 +44,7 @@ async def admin_required( 使用方法: >>> APIRouter(dependencies=[Depends(admin_required)]) """ - group = await user.awaitable_attrs.group - if group.admin: + if user.group.admin: return user raise http_exceptions.raise_forbidden("Admin Required") diff --git a/middleware/dependencies.py b/middleware/dependencies.py index bbd1f50..96cb19f 100644 --- a/middleware/dependencies.py +++ b/middleware/dependencies.py @@ -14,14 +14,14 @@ from uuid import UUID from fastapi import Depends, Query from sqlmodel.ext.asyncio.session import AsyncSession -from models.database import get_session -from models.mixin import TimeFilterRequest, TableViewRequest -from models.user import UserFilterParams, UserStatus +from sqlmodels.database_connection import DatabaseManager +from sqlmodels.mixin import TimeFilterRequest, TableViewRequest +from sqlmodels.user import UserFilterParams, UserStatus # --- 数据库会话依赖 --- -SessionDep: TypeAlias = Annotated[AsyncSession, Depends(get_session)] +SessionDep: TypeAlias = Annotated[AsyncSession, Depends(DatabaseManager.get_session)] """数据库会话依赖,用于路由函数中获取数据库会话""" @@ -79,14 +79,14 @@ TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends(_get_table_ async def _get_user_filter_params( group_id: Annotated[UUID | None, Query(description="按用户组UUID筛选")] = None, - username: Annotated[str | None, Query(max_length=50, description="按用户名模糊搜索")] = None, + email: Annotated[str | None, Query(max_length=50, description="按邮箱模糊搜索")] = None, nickname: Annotated[str | None, Query(max_length=50, description="按昵称模糊搜索")] = None, status: Annotated[UserStatus | None, Query(description="按用户状态筛选")] = None, ) -> UserFilterParams: """解析用户过滤查询参数""" return UserFilterParams( group_id=group_id, - username_contains=username, + email_contains=email, nickname_contains=nickname, status=status, ) diff --git a/models/mixin/polymorphic.py b/models/mixin/polymorphic.py deleted file mode 100644 index 2f7f9c6..0000000 --- a/models/mixin/polymorphic.py +++ /dev/null @@ -1,456 +0,0 @@ -""" -联表继承(Joined Table Inheritance)的通用工具 - -提供用于简化SQLModel多态表设计的辅助函数和Mixin。 - -Usage Example: - - from sqlmodels.base import SQLModelBase - from sqlmodels.mixin import UUIDTableBaseMixin - from sqlmodels.mixin.polymorphic import ( - PolymorphicBaseMixin, - create_subclass_id_mixin, - AutoPolymorphicIdentityMixin - ) - - # 1. 定义Base类(只有字段,无表) - class ASRBase(SQLModelBase): - name: str - \"\"\"配置名称\"\"\" - - base_url: str - \"\"\"服务地址\"\"\" - - # 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin - class ASR( - ASRBase, - UUIDTableBaseMixin, - PolymorphicBaseMixin, - ABC - ): - \"\"\"ASR配置的抽象基类\"\"\" - # PolymorphicBaseMixin 自动提供: - # - _polymorphic_name 字段 - # - polymorphic_on='_polymorphic_name' - # - polymorphic_abstract=True(当有抽象方法时) - - # 3. 为第二层子类创建ID Mixin - ASRSubclassIdMixin = create_subclass_id_mixin('asr') - - # 4. 创建第二层抽象类(如果需要) - class FunASR( - ASRSubclassIdMixin, - ASR, - AutoPolymorphicIdentityMixin, - polymorphic_abstract=True - ): - \"\"\"FunASR的抽象基类,可能有多个实现\"\"\" - pass - - # 5. 创建具体实现类 - class FunASRLocal(FunASR, table=True): - \"\"\"FunASR本地部署版本\"\"\" - # polymorphic_identity 会自动设置为 'asr.funasrlocal' - pass - - # 6. 获取所有具体子类(用于 selectin_polymorphic) - concrete_asrs = ASR.get_concrete_subclasses() - # 返回 [FunASRLocal, ...] -""" -import uuid -from abc import ABC -from uuid import UUID - -from pydantic.fields import FieldInfo -from pydantic_core import PydanticUndefined -from sqlalchemy import String, inspect -from sqlalchemy.orm import Mapped, mapped_column -from sqlalchemy.orm.attributes import InstrumentedAttribute -from sqlmodel import Field - -from models.base.sqlmodel_base import SQLModelBase - - -def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']: - """ - 动态创建SubclassIdMixin类 - - 在联表继承中,子类需要一个外键指向父表的主键。 - 此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。 - - Args: - parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function') - - Returns: - 一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4) - - Example: - >>> ASRSubclassIdMixin = create_subclass_id_mixin('asr') - >>> class FunASR(ASRSubclassIdMixin, ASR, table=True): - ... pass - - Note: - - 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id - - 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase) - - 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin) - """ - if not parent_table_name: - raise ValueError("parent_table_name 不能为空") - - # 转换为PascalCase作为类名 - class_name_parts = parent_table_name.split('_') - class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin' - - # 使用闭包捕获parent_table_name - _parent_table_name = parent_table_name - - # 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields - class SubclassIdMixin(SQLModelBase): - # 定义id字段 - id: UUID = Field( - default_factory=uuid.uuid4, - foreign_key=f'{_parent_table_name}.id', - primary_key=True, - ) - - @classmethod - def __pydantic_init_subclass__(cls, **kwargs): - """ - Pydantic v2 的子类初始化钩子,在模型完全构建后调用 - - 修复联表继承中子类字段的default_factory丢失问题。 - SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段, - 导致 INSERT 语句中出现 `table.column` 引用而非实际值。 - - 通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory, - 遵循单一真相原则(不硬编码 default_factory)。 - - 需要修复的字段: - - id: 主键(从父类获取 default_factory) - - created_at: 创建时间戳(从父类获取 default_factory) - - updated_at: 更新时间戳(从父类获取 default_factory) - """ - super().__pydantic_init_subclass__(**kwargs) - - if not hasattr(cls, 'model_fields'): - return - - def find_original_field_info(field_name: str) -> FieldInfo | None: - """从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)""" - for base in cls.__mro__[1:]: # 跳过自己 - if hasattr(base, 'model_fields') and field_name in base.model_fields: - field_info = base.model_fields[field_name] - # 跳过被 InstrumentedAttribute 污染的 - if not isinstance(field_info.default, InstrumentedAttribute): - return field_info - return None - - # 动态检测所有需要修复的字段 - # 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断: - # 1. default 是 InstrumentedAttribute(被 SQLAlchemy 污染) - # 2. 原始定义有 default_factory 或明确的 default 值 - # - # 覆盖场景: - # - UUID主键(UUIDTableBaseMixin):id 有 default_factory=uuid.uuid4,需要修复 - # - int主键(TableBaseMixin):id 用 default=None,不需要修复(数据库自增) - # - created_at/updated_at:有 default_factory=now,需要修复 - # - 外键字段(created_by_id等):有 default=None,需要修复 - # - 普通字段(name, temperature等):无 default_factory,不需要修复 - # - # MRO 查找保证: - # - 在多重继承场景下,MRO 顺序是确定性的 - # - find_original_field_info 会找到第一个未被污染且有该字段的父类 - for field_name, current_field in cls.model_fields.items(): - # 检查是否被污染(default 是 InstrumentedAttribute) - if not isinstance(current_field.default, InstrumentedAttribute): - continue # 未被污染,跳过 - - # 从父类查找原始定义 - original = find_original_field_info(field_name) - if original is None: - continue # 找不到原始定义,跳过 - - # 根据原始定义的 default/default_factory 来修复 - if original.default_factory: - # 有 default_factory(如 uuid.uuid4, now) - new_field = FieldInfo( - default_factory=original.default_factory, - annotation=current_field.annotation, - json_schema_extra=current_field.json_schema_extra, - ) - elif original.default is not PydanticUndefined: - # 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined - # PydanticUndefined 表示字段没有默认值(必填) - new_field = FieldInfo( - default=original.default, - annotation=current_field.annotation, - json_schema_extra=current_field.json_schema_extra, - ) - else: - continue # 既没有 default_factory 也没有有效的 default,跳过 - - # 复制SQLModel特有的属性 - if hasattr(current_field, 'foreign_key'): - new_field.foreign_key = current_field.foreign_key - if hasattr(current_field, 'primary_key'): - new_field.primary_key = current_field.primary_key - - cls.model_fields[field_name] = new_field - - # 设置类名和文档 - SubclassIdMixin.__name__ = class_name - SubclassIdMixin.__qualname__ = class_name - SubclassIdMixin.__doc__ = f""" - {parent_table_name}子类的ID Mixin - - 用于{parent_table_name}的子类,提供外键指向父表。 - 通过MRO确保此id字段覆盖继承的id字段。 - """ - - return SubclassIdMixin - - -class AutoPolymorphicIdentityMixin: - """ - 自动生成polymorphic_identity的Mixin - - 使用此Mixin的类会自动根据类名生成polymorphic_identity。 - 格式:{parent_polymorphic_identity}.{classname_lowercase} - - 如果没有父类的polymorphic_identity,则直接使用类名小写。 - - Example: - >>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True): - ... __polymorphic_name: str - ... - >>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True): - ... pass - ... # polymorphic_identity 会自动设置为 'function' - ... - >>> class CodeInterpreterFunction(Function, table=True): - ... pass - ... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction' - - Note: - - 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留 - - 此Mixin应该在继承列表中靠后的位置(在表基类之前) - """ - - def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs): - """ - 子类化钩子,自动生成polymorphic_identity - - Args: - polymorphic_identity: 如果手动指定,则使用指定的值 - **kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True) - """ - super().__init_subclass__(**kwargs) - - # 如果手动指定了polymorphic_identity,使用指定的值 - if polymorphic_identity is not None: - identity = polymorphic_identity - else: - # 自动生成polymorphic_identity - class_name = cls.__name__.lower() - - # 尝试从父类获取polymorphic_identity作为前缀 - parent_identity = None - for base in cls.__mro__[1:]: # 跳过自己 - if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict): - parent_identity = base.__mapper_args__.get('polymorphic_identity') - if parent_identity: - break - - # 构建identity - if parent_identity: - identity = f'{parent_identity}.{class_name}' - else: - identity = class_name - - # 设置到__mapper_args__ - if '__mapper_args__' not in cls.__dict__: - cls.__mapper_args__ = {} - - # 只在尚未设置polymorphic_identity时设置 - if 'polymorphic_identity' not in cls.__mapper_args__: - cls.__mapper_args__['polymorphic_identity'] = identity - - -class PolymorphicBaseMixin: - """ - 为联表继承链中的基类自动配置 polymorphic 设置的 Mixin - - 此 Mixin 自动设置以下内容: - - `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器 - - `_polymorphic_name: str`: 定义多态鉴别器字段(带索引) - - `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类 - - 使用场景: - 适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。 - - 用法示例: - ```python - from abc import ABC - from sqlmodels.mixin import UUIDTableBaseMixin - from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin - - # 定义基类 - class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC): - __tablename__ = 'mytool' - - # 不需要手动定义 _polymorphic_name - # 不需要手动设置 polymorphic_on - # 不需要手动设置 polymorphic_abstract - - # 定义子类 - class SpecificTool(MyTool): - __tablename__ = 'specifictool' - - # 会自动继承 polymorphic 配置 - ``` - - 自动行为: - 1. 定义 `_polymorphic_name: str` 字段(带索引) - 2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'` - 3. 自动检测抽象类: - - 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True - - 否则设置为 False - - 手动覆盖: - 可以在类定义时手动指定参数来覆盖自动行为: - ```python - class MyTool( - UUIDTableBaseMixin, - PolymorphicBaseMixin, - ABC, - polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name - polymorphic_abstract=False # 强制不设为抽象类 - ): - pass - ``` - - 注意事项: - - 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用 - - 适用于联表继承(joined table inheritance)场景 - - 子类会自动继承 _polymorphic_name 字段定义 - - 使用单下划线前缀是因为: - * SQLAlchemy 会映射单下划线字段为数据库列 - * Pydantic 将其视为私有属性,不参与序列化 - * 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列 - """ - - # 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段 - # - # 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column - # - # 为什么这样做: - # 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改 - # 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀) - # 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用 - # 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema) - # 5. 内部代码仍然可以正常访问和修改此字段 - # - # 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md - _polymorphic_name: Mapped[str] = mapped_column(String, index=True) - """ - 多态鉴别器字段,用于标识具体的子类类型 - - 注意:此字段使用单下划线前缀,表示内部使用。 - - ✅ 存储到数据库 - - ✅ 不出现在 API 序列化中 - - ✅ 防止外部直接修改 - """ - - def __init_subclass__( - cls, - polymorphic_on: str | None = None, - polymorphic_abstract: bool | None = None, - **kwargs - ): - """ - 在子类定义时自动配置 polymorphic 设置 - - Args: - polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。 - 设置为其他值可以使用不同的字段作为多态鉴别器。 - polymorphic_abstract: 是否为抽象类。 - - None: 自动检测(默认) - - True: 强制设为抽象类 - - False: 强制设为非抽象类 - **kwargs: 传递给父类的其他参数 - """ - super().__init_subclass__(**kwargs) - - # 初始化 __mapper_args__(如果还没有) - if '__mapper_args__' not in cls.__dict__: - cls.__mapper_args__ = {} - - # 设置 polymorphic_on(默认为 _polymorphic_name) - if 'polymorphic_on' not in cls.__mapper_args__: - cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name' - - # 自动检测或设置 polymorphic_abstract - if 'polymorphic_abstract' not in cls.__mapper_args__: - if polymorphic_abstract is None: - # 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类 - has_abc = ABC in cls.__mro__ - has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set())) - polymorphic_abstract = has_abc and has_abstract_methods - - cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract - - @classmethod - def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]: - """ - 递归获取当前类的所有具体(非抽象)子类 - - 用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。 - 可在任意多态基类上调用,返回该类的所有非抽象子类。 - - :return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类) - """ - result: list[type[PolymorphicBaseMixin]] = [] - for subclass in cls.__subclasses__(): - # 使用 inspect() 获取 mapper 的公开属性 - # 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811) - mapper = inspect(subclass) - if not mapper.polymorphic_abstract: - result.append(subclass) - # 无论是否抽象,都需要递归(抽象类可能有具体子类) - if hasattr(subclass, 'get_concrete_subclasses'): - result.extend(subclass.get_concrete_subclasses()) - return result - - @classmethod - def get_polymorphic_discriminator(cls) -> str: - """ - 获取多态鉴别字段名 - - 使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。 - - :return: 多态鉴别字段名(如 '_polymorphic_name') - :raises ValueError: 如果类未配置 polymorphic_on - """ - polymorphic_on = inspect(cls).polymorphic_on - if polymorphic_on is None: - raise ValueError( - f"{cls.__name__} 未配置 polymorphic_on," - f"请确保正确继承 PolymorphicBaseMixin" - ) - return polymorphic_on.key - - @classmethod - def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]: - """ - 获取 polymorphic_identity 到具体子类的映射 - - 包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。 - - :return: identity 到子类的映射字典 - """ - result: dict[str, type[PolymorphicBaseMixin]] = {} - for subclass in cls.get_concrete_subclasses(): - identity = inspect(subclass).polymorphic_identity - if identity: - result[identity] = subclass - return result diff --git a/routers/api/v1/admin/__init__.py b/routers/api/v1/admin/__init__.py index 8898593..49333a5 100644 --- a/routers/api/v1/admin/__init__.py +++ b/routers/api/v1/admin/__init__.py @@ -5,15 +5,15 @@ from loguru import logger as l from middleware.auth import admin_required from middleware.dependencies import SessionDep -from models import ( +from sqlmodels import ( User, ResponseBase, Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo, ) -from models.base import SQLModelBase -from models.setting import ( +from sqlmodels.base import SQLModelBase +from sqlmodels.setting import ( SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse, ) -from models.setting import SettingsType +from sqlmodels.setting import SettingsType from utils import http_exceptions from utils.conf import appmeta from .file import admin_file_router diff --git a/routers/api/v1/admin/file/__init__.py b/routers/api/v1/admin/file/__init__.py index df53574..a7e8fd1 100644 --- a/routers/api/v1/admin/file/__init__.py +++ b/routers/api/v1/admin/file/__init__.py @@ -5,14 +5,60 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import FileResponse from loguru import logger as l +from sqlmodel.ext.asyncio.session import AsyncSession from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep -from models import ( - Policy, PolicyType, User, ResponseBase, ListResponse, +from sqlmodels import ( + Policy, PolicyType, User, ListResponse, Object, ObjectType, AdminFileResponse, FileBanRequest, ) from service.storage import LocalStorageService +async def _set_ban_recursive( + session: AsyncSession, + obj: Object, + ban: bool, + admin_id: UUID, + reason: str | None, +) -> int: + """ + 递归设置封禁状态,返回受影响对象数量。 + + :param session: 数据库会话 + :param obj: 要封禁/解禁的对象 + :param ban: True=封禁, False=解禁 + :param admin_id: 管理员UUID + :param reason: 封禁原因 + :return: 受影响的对象数量 + """ + count = 0 + + # 如果是文件夹,先递归处理子对象 + if obj.is_folder: + children = await Object.get( + session, + Object.parent_id == obj.id, + fetch_mode="all", + ) + for child in children: + count += await _set_ban_recursive(session, child, ban, admin_id, reason) + + # 设置当前对象 + obj.is_banned = ban + if ban: + obj.banned_at = datetime.now() + obj.banned_by = admin_id + obj.ban_reason = reason + else: + obj.banned_at = None + obj.banned_by = None + obj.ban_reason = None + + await obj.save(session) + count += 1 + return count + + admin_file_router = APIRouter( prefix="/file", tags=["admin", "admin_file"], @@ -119,15 +165,17 @@ async def router_admin_preview_file( summary='封禁/解禁文件', description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.', dependencies=[Depends(admin_required)], + status_code=204, ) async def router_admin_ban_file( session: SessionDep, file_id: UUID, request: FileBanRequest, admin: Annotated[User, Depends(admin_required)], -) -> ResponseBase: +) -> None: """ - 封禁或解禁文件。封禁后用户无法访问该文件。 + 封禁或解禁文件/文件夹。封禁后用户无法访问该文件。 + 封禁文件夹时会级联封禁所有子对象。 :param session: 数据库会话 :param file_id: 文件UUID @@ -139,24 +187,10 @@ async def router_admin_ban_file( if not file_obj: raise HTTPException(status_code=404, detail="文件不存在") - file_obj.is_banned = request.is_banned - if request.is_banned: - file_obj.banned_at = datetime.now() - file_obj.banned_by = admin.id - file_obj.ban_reason = request.reason - else: - file_obj.banned_at = None - file_obj.banned_by = None - file_obj.ban_reason = None + count = await _set_ban_recursive(session, file_obj, request.ban, admin.id, request.reason) - file_obj = await file_obj.save(session) - - action = "封禁" if request.is_banned else "解禁" - l.info(f"管理员{action}了文件: {file_obj.name}") - return ResponseBase(data={ - "id": str(file_obj.id), - "is_banned": file_obj.is_banned, - }) + action = "封禁" if request.ban else "解禁" + l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象") @admin_file_router.delete( @@ -164,12 +198,13 @@ async def router_admin_ban_file( summary='删除文件', description='Delete file by ID', dependencies=[Depends(admin_required)], + status_code=204, ) async def router_admin_delete_file( session: SessionDep, file_id: UUID, delete_physical: bool = True, -) -> ResponseBase: +) -> None: """ 删除文件。 @@ -211,5 +246,4 @@ async def router_admin_delete_file( # 使用条件删除 await Object.delete(session, condition=Object.id == file_obj.id) - l.info(f"管理员删除了文件: {file_name}") - return ResponseBase(data={"deleted": True}) \ No newline at end of file + l.info(f"管理员删除了文件: {file_name}") \ No newline at end of file diff --git a/routers/api/v1/admin/group/__init__.py b/routers/api/v1/admin/group/__init__.py index 7e58f4f..dbfba35 100644 --- a/routers/api/v1/admin/group/__init__.py +++ b/routers/api/v1/admin/group/__init__.py @@ -5,12 +5,12 @@ from loguru import logger as l from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep -from models import ( +from sqlmodels import ( User, ResponseBase, UserPublic, ListResponse, Group, GroupOptions, ) -from models.group import ( +from sqlmodels.group import ( GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, ) -from models.policy import GroupPolicyLink +from sqlmodels.policy import GroupPolicyLink admin_group_router = APIRouter( prefix="/group", @@ -113,11 +113,12 @@ async def router_admin_get_group_members( summary='创建用户组', description='Create a new user group', dependencies=[Depends(admin_required)], + status_code=204, ) async def router_admin_create_group( session: SessionDep, request: GroupCreateRequest, -) -> ResponseBase: +) -> None: """ 创建新的用户组。 @@ -164,7 +165,6 @@ async def router_admin_create_group( await session.commit() l.info(f"管理员创建了用户组: {group.name}") - return ResponseBase(data={"id": str(group.id), "name": group.name}) @admin_group_router.patch( @@ -172,12 +172,13 @@ async def router_admin_create_group( summary='更新用户组信息', description='Update user group information by ID', dependencies=[Depends(admin_required)], + status_code=204, ) async def router_admin_update_group( session: SessionDep, group_id: UUID, request: GroupUpdateRequest, -) -> ResponseBase: +) -> None: """ 根据用户组ID更新用户组信息。 @@ -233,8 +234,7 @@ async def router_admin_update_group( session.add(link) await session.commit() - l.info(f"管理员更新了用户组: {group.name}") - return ResponseBase(data={"id": str(group.id)}) + l.info(f"管理员更新了用户组: {group_id}") @admin_group_router.delete( @@ -242,11 +242,12 @@ async def router_admin_update_group( summary='删除用户组', description='Delete user group by ID', dependencies=[Depends(admin_required)], + status_code=204, ) async def router_admin_delete_group( session: SessionDep, group_id: UUID, -) -> ResponseBase: +) -> None: """ 根据用户组ID删除用户组。 @@ -271,5 +272,4 @@ async def router_admin_delete_group( group_name = group.name await Group.delete(session, group) - l.info(f"管理员删除了用户组: {group_name}") - return ResponseBase(data={"deleted": True}) \ No newline at end of file + l.info(f"管理员删除了用户组: {group_id}") \ No newline at end of file diff --git a/routers/api/v1/admin/policy/__init__.py b/routers/api/v1/admin/policy/__init__.py index dc1f30a..f95678d 100644 --- a/routers/api/v1/admin/policy/__init__.py +++ b/routers/api/v1/admin/policy/__init__.py @@ -6,10 +6,10 @@ from sqlmodel import Field from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep -from models import ( +from sqlmodels import ( Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase, ListResponse, Object, ) -from models.base import SQLModelBase +from sqlmodels.base import SQLModelBase from service.storage import DirectoryCreationError, LocalStorageService admin_policy_router = APIRouter( diff --git a/routers/api/v1/admin/share/__init__.py b/routers/api/v1/admin/share/__init__.py index 078f5c9..6d55143 100644 --- a/routers/api/v1/admin/share/__init__.py +++ b/routers/api/v1/admin/share/__init__.py @@ -5,7 +5,7 @@ from loguru import logger as l from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep -from models import ( +from sqlmodels import ( ResponseBase, ListResponse, Share, AdminShareListItem, ) @@ -80,7 +80,7 @@ async def router_admin_get_share( "score": share.score, "has_password": bool(share.password), "user_id": str(share.user_id), - "username": user.username if user else None, + "username": user.email if user else None, "object": { "id": str(obj.id), "name": obj.name, diff --git a/routers/api/v1/admin/task/__init__.py b/routers/api/v1/admin/task/__init__.py index e4c8577..f32246f 100644 --- a/routers/api/v1/admin/task/__init__.py +++ b/routers/api/v1/admin/task/__init__.py @@ -5,7 +5,7 @@ from loguru import logger as l from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep -from models import ( +from sqlmodels import ( ResponseBase, ListResponse, Task, TaskSummary, ) @@ -89,7 +89,7 @@ async def router_admin_get_task( "progress": task.progress, "error": task.error, "user_id": str(task.user_id), - "username": user.username if user else None, + "username": user.email if user else None, "props": props.model_dump() if props else None, "created_at": task.created_at.isoformat(), "updated_at": task.updated_at.isoformat(), diff --git a/routers/api/v1/admin/user/__init__.py b/routers/api/v1/admin/user/__init__.py index b5e76a5..d95806e 100644 --- a/routers/api/v1/admin/user/__init__.py +++ b/routers/api/v1/admin/user/__init__.py @@ -6,11 +6,13 @@ from sqlalchemy import func from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep -from models import ( +from sqlmodels import ( User, ResponseBase, UserPublic, ListResponse, - Group, Object, ObjectType, ) -from models.user import ( - UserAdminUpdateRequest, UserCalibrateResponse, + Group, Object, ObjectType, Setting, SettingsType, + BatchDeleteRequest, +) +from sqlmodels.user import ( + UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, ) from utils import Password, http_exceptions @@ -26,19 +28,19 @@ admin_user_router = APIRouter( description='Get user information by ID', dependencies=[Depends(admin_required)], ) -async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase: +async def router_admin_get_user(session: SessionDep, user_id: UUID) -> UserPublic: """ 根据用户ID获取用户信息,包括用户名、邮箱、注册时间等。 Args: session(SessionDep): 数据库会话依赖项。 - user_id (int): 用户ID。 + user_id (UUID): 用户ID。 Returns: ResponseBase: 包含用户信息的响应模型。 """ user = await User.get_exist_one(session, user_id) - return ResponseBase(data=user.to_public().model_dump()) + return user.to_public() @admin_user_router.get( @@ -60,7 +62,7 @@ async def router_admin_get_users( :param filter_params: 用户筛选参数(用户组、用户名、昵称、状态) :return: 分页用户列表 """ - result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view) + result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view, load=User.group) return ListResponse( items=[user.to_public() for user in result.items], count=result.count, @@ -75,22 +77,33 @@ async def router_admin_get_users( ) async def router_admin_create_user( session: SessionDep, - user: User, -) -> ResponseBase: + request: UserAdminCreateRequest, +) -> UserPublic: """ - 创建一个新的用户,设置用户名、密码等信息。 + 创建一个新的用户,设置邮箱、密码、用户组等信息。 - Returns: - ResponseBase: 包含创建结果的响应模型。 + :param session: 数据库会话 + :param request: 创建用户请求 DTO + :return: 创建结果 """ - existing_user = await User.get(session, User.username == user.username) + existing_user = await User.get(session, User.email == request.email) if existing_user: - return ResponseBase( - code=400, - msg="User with this username already exists." - ) + raise HTTPException(status_code=409, detail="该邮箱已被注册") + + # 验证用户组存在 + group = await Group.get(session, Group.id == request.group_id) + if not group: + raise HTTPException(status_code=400, detail="目标用户组不存在") + + user = User( + email=request.email, + password=Password.hash(request.password), + nickname=request.nickname, + group_id=request.group_id, + status=request.status, + ) user = await user.save(session) - return ResponseBase(data=user.to_public().model_dump()) + return user.to_public() @admin_user_router.patch( @@ -98,12 +111,13 @@ async def router_admin_create_user( summary='更新用户信息', description='Update user information by ID', dependencies=[Depends(admin_required)], + status_code=204 ) async def router_admin_update_user( session: SessionDep, user_id: UUID, request: UserAdminUpdateRequest, -) -> ResponseBase: +) -> None: """ 根据用户ID更新用户信息。 @@ -116,8 +130,15 @@ async def router_admin_update_user( if not user: raise HTTPException(status_code=404, detail="用户不存在") - # 默认管理员(用户名为 admin)不允许更改用户组 - if request.group_id and user.username == "admin" and request.group_id != user.group_id: + # 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别) + default_admin_setting = await Setting.get( + session, + (Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id") + ) + if (request.group_id + and default_admin_setting + and default_admin_setting.value == str(user_id) + and request.group_id != user.group_id): http_exceptions.raise_forbidden("默认管理员不允许更改用户组") # 如果更新用户组,验证新组存在 @@ -143,38 +164,35 @@ async def router_admin_update_user( setattr(user, key, value) user = await user.save(session) - l.info(f"管理员更新了用户: {user.username}") - return ResponseBase(data=user.to_public().model_dump()) + l.info(f"管理员更新了用户: {request.email}") @admin_user_router.delete( - path='/{user_id}', - summary='删除用户', - description='Delete user by ID', + path='/', + summary='删除用户(支持批量)', + description='Delete users by ID list', dependencies=[Depends(admin_required)], + status_code=204, ) -async def router_admin_delete_user( +async def router_admin_delete_users( session: SessionDep, - user_id: UUID, -) -> ResponseBase: + request: BatchDeleteRequest, +) -> None: """ - 根据用户ID删除用户及其所有数据。 + 批量删除用户及其所有数据。 注意: 这是一个危险操作,会级联删除用户的所有文件、分享、任务等。 :param session: 数据库会话 - :param user_id: 用户UUID - :return: 删除结果 + :param request: 批量删除请求,包含待删除用户的 UUID 列表 + :return: 删除结果(已删除数 / 总请求数) """ - user = await User.get(session, User.id == user_id) - if not user: - raise HTTPException(status_code=404, detail="用户不存在") - - username = user.username - await User.delete(session, user) - - l.info(f"管理员删除了用户: {username}") - return ResponseBase(data={"deleted": True}) + deleted = 0 + for uid in request.ids: + user = await User.get(session, User.id == uid) + if user: + await User.delete(session, user) + l.info(f"管理员删除了用户: {user.email}") @admin_user_router.post( @@ -186,7 +204,7 @@ async def router_admin_delete_user( async def router_admin_calibrate_storage( session: SessionDep, user_id: UUID, -) -> ResponseBase: +) -> UserCalibrateResponse: """ 重新计算用户的已用存储空间。 @@ -228,5 +246,5 @@ async def router_admin_calibrate_storage( file_count=file_count, ) - l.info(f"管理员校准了用户存储: {user.username}, 差值: {actual_storage - previous_storage}") - return ResponseBase(data=response.model_dump()) \ No newline at end of file + l.info(f"管理员校准了用户存储: {user.email}, 差值: {actual_storage - previous_storage}") + return response \ No newline at end of file diff --git a/routers/api/v1/admin/vas/__init__.py b/routers/api/v1/admin/vas/__init__.py index c112d31..8de1ba8 100644 --- a/routers/api/v1/admin/vas/__init__.py +++ b/routers/api/v1/admin/vas/__init__.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException from middleware.auth import admin_required from middleware.dependencies import SessionDep -from models import ( +from sqlmodels import ( ResponseBase, ) diff --git a/routers/api/v1/callback/__init__.py b/routers/api/v1/callback/__init__.py index 38aaac7..778bf5d 100644 --- a/routers/api/v1/callback/__init__.py +++ b/routers/api/v1/callback/__init__.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Query from fastapi.responses import PlainTextResponse -from models import ResponseBase +from sqlmodels import ResponseBase import service.oauth from utils import http_exceptions diff --git a/routers/api/v1/directory/__init__.py b/routers/api/v1/directory/__init__.py index 40f93cd..8599e9b 100644 --- a/routers/api/v1/directory/__init__.py +++ b/routers/api/v1/directory/__init__.py @@ -1,10 +1,12 @@ from typing import Annotated +from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from sqlmodel.ext.asyncio.session import AsyncSession from middleware.auth import auth_required from middleware.dependencies import SessionDep -from models import ( +from sqlmodels import ( DirectoryCreateRequest, DirectoryResponse, Object, @@ -14,50 +16,28 @@ from models import ( User, ResponseBase, ) +from utils import http_exceptions directory_router = APIRouter( prefix="/directory", tags=["directory"] ) -@directory_router.get( - path="/{path:path}", - summary="获取目录内容", -) -async def router_directory_get( - session: SessionDep, - user: Annotated[User, Depends(auth_required)], - path: str + +async def _get_directory_response( + session: AsyncSession, + user_id: UUID, + folder: Object, ) -> DirectoryResponse: """ - 获取目录内容 - - 路径必须以用户名或 `.crash` 开头,如 /api/directory/admin 或 /api/directory/admin/docs - `.crash` 代表回收站,也就意味着用户名禁止为 `.crash` + 构建目录响应 DTO :param session: 数据库会话 - :param user: 当前登录用户 - :param path: 目录路径(必须以用户名开头) - :return: 目录内容 + :param user_id: 用户UUID + :param folder: 目录对象 + :return: DirectoryResponse """ - # 路径必须以用户名开头 - path = path.strip("/") - if not path: - raise HTTPException(status_code=400, detail="路径不能为空,请使用 /{username} 格式") - - path_parts = path.split("/") - if path_parts[0] != user.username: - raise HTTPException(status_code=403, detail="无权访问其他用户的目录") - - folder = await Object.get_by_path(session, user.id, "/" + path, user.username) - - if not folder: - raise HTTPException(status_code=404, detail="目录不存在") - - if not folder.is_folder: - raise HTTPException(status_code=400, detail="指定路径不是目录") - - children = await Object.get_children(session, user.id, folder.id) + children = await Object.get_children(session, user_id, folder.id) policy = await folder.awaitable_attrs.policy objects = [ @@ -67,8 +47,8 @@ async def router_directory_get( thumb=False, size=child.size, type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE, - date=child.updated_at, - create_date=child.created_at, + created_at=child.created_at, + updated_at=child.updated_at, source_enabled=False, ) for child in children @@ -89,7 +69,74 @@ async def router_directory_get( ) -@directory_router.put( +@directory_router.get( + path="/", + summary="获取根目录内容", +) +async def router_directory_root( + session: SessionDep, + user: Annotated[User, Depends(auth_required)], +) -> DirectoryResponse: + """ + 获取当前用户的根目录内容 + + :param session: 数据库会话 + :param user: 当前登录用户 + :return: 根目录内容 + """ + root = await Object.get_root(session, user.id) + if not root: + raise HTTPException(status_code=404, detail="根目录不存在") + + if root.is_banned: + http_exceptions.raise_banned() + + return await _get_directory_response(session, user.id, root) + + +@directory_router.get( + path="/{path:path}", + summary="获取目录内容", +) +async def router_directory_get( + session: SessionDep, + user: Annotated[User, Depends(auth_required)], + path: str +) -> DirectoryResponse: + """ + 获取目录内容 + + 路径从用户根目录开始,不包含用户名前缀。 + 如 /api/v1/directory/docs 表示根目录下的 docs 目录。 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param path: 目录路径(从根目录开始的相对路径) + :return: 目录内容 + """ + path = path.strip("/") + if not path: + # 空路径交给根目录端点处理(理论上不会到达这里) + root = await Object.get_root(session, user.id) + if not root: + raise HTTPException(status_code=404, detail="根目录不存在") + return await _get_directory_response(session, user.id, root) + + folder = await Object.get_by_path(session, user.id, "/" + path) + + if not folder: + raise HTTPException(status_code=404, detail="目录不存在") + + if not folder.is_folder: + raise HTTPException(status_code=400, detail="指定路径不是目录") + + if folder.is_banned: + http_exceptions.raise_banned() + + return await _get_directory_response(session, user.id, folder) + + +@directory_router.post( path="/", summary="创建目录", ) @@ -123,6 +170,9 @@ async def router_directory_create( if not parent.is_folder: raise HTTPException(status_code=400, detail="父路径不是目录") + if parent.is_banned: + http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作") + # 检查是否已存在同名对象 existing = await Object.get( session, diff --git a/routers/api/v1/download/__init__.py b/routers/api/v1/download/__init__.py index f6bb6db..ac7e1db 100644 --- a/routers/api/v1/download/__init__.py +++ b/routers/api/v1/download/__init__.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from middleware.auth import auth_required -from models import ResponseBase +from sqlmodels import ResponseBase from utils import http_exceptions download_router = APIRouter( diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py index 81fee7b..5fa6fef 100644 --- a/routers/api/v1/file/__init__.py +++ b/routers/api/v1/file/__init__.py @@ -18,7 +18,7 @@ from loguru import logger as l from middleware.auth import auth_required, verify_download_token from middleware.dependencies import SessionDep -from models import ( +from sqlmodels import ( CreateFileRequest, CreateUploadSessionRequest, Object, @@ -91,6 +91,9 @@ async def create_upload_session( if not parent.is_folder: raise HTTPException(status_code=400, detail="父对象不是目录") + if parent.is_banned: + http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作") + # 确定存储策略 policy_id = request.policy_id or parent.policy_id policy = await Policy.get(session, Policy.id == policy_id) @@ -100,7 +103,7 @@ async def create_upload_session( # 验证文件大小限制 if policy.max_size > 0 and request.file_size > policy.max_size: raise HTTPException( - status_code=400, + status_code=413, detail=f"文件大小超过限制 ({policy.max_size} bytes)" ) @@ -221,30 +224,40 @@ async def upload_chunk( upload_session.uploaded_size += len(content) upload_session = await upload_session.save(session) - # 检查是否完成 + # 在后续可能的 commit 前保存需要的属性 is_complete = upload_session.is_complete + uploaded_chunks = upload_session.uploaded_chunks + total_chunks = upload_session.total_chunks file_object_id: UUID | None = None if is_complete: + # 保存 upload_session 属性(commit 后会过期) + file_name = upload_session.file_name + uploaded_size = upload_session.uploaded_size + storage_path = upload_session.storage_path + upload_session_id = upload_session.id + parent_id = upload_session.parent_id + policy_id = upload_session.policy_id + # 创建 PhysicalFile 记录 physical_file = PhysicalFile( - storage_path=upload_session.storage_path, - size=upload_session.uploaded_size, - policy_id=upload_session.policy_id, + storage_path=storage_path, + size=uploaded_size, + policy_id=policy_id, reference_count=1, ) physical_file = await physical_file.save(session, commit=False) # 创建 Object 记录 file_object = Object( - name=upload_session.file_name, + name=file_name, type=ObjectType.FILE, - size=upload_session.uploaded_size, + size=uploaded_size, physical_file_id=physical_file.id, - upload_session_id=str(upload_session.id), - parent_id=upload_session.parent_id, + upload_session_id=str(upload_session_id), + parent_id=parent_id, owner_id=user_id, - policy_id=upload_session.policy_id, + policy_id=policy_id, ) file_object = await file_object.save(session, commit=False) file_object_id = file_object.id @@ -252,18 +265,18 @@ async def upload_chunk( # 删除上传会话(使用条件删除) await UploadSession.delete( session, - condition=UploadSession.id == upload_session.id, + condition=UploadSession.id == upload_session_id, commit=False ) # 统一提交所有更改 await session.commit() - l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}") + l.info(f"文件上传完成: {file_name}, size={uploaded_size}, id={file_object_id}") return UploadChunkResponse( - uploaded_chunks=upload_session.uploaded_chunks if not is_complete else upload_session.total_chunks, - total_chunks=upload_session.total_chunks, + uploaded_chunks=uploaded_chunks if not is_complete else total_chunks, + total_chunks=total_chunks, is_complete=is_complete, object_id=file_object_id, ) @@ -368,6 +381,9 @@ async def create_download_token_endpoint( if not file_obj.is_file: raise HTTPException(status_code=400, detail="对象不是文件") + if file_obj.is_banned: + http_exceptions.raise_banned() + token = create_download_token(file_id, user.id) l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}") @@ -410,6 +426,9 @@ async def download_file( if not file_obj.is_file: raise HTTPException(status_code=400, detail="对象不是文件") + if file_obj.is_banned: + http_exceptions.raise_banned() + # 预加载 physical_file 关系以获取存储路径 physical_file = await file_obj.awaitable_attrs.physical_file if not physical_file or not physical_file.storage_path: @@ -470,6 +489,9 @@ async def create_empty_file( if not parent.is_folder: raise HTTPException(status_code=400, detail="父对象不是目录") + if parent.is_banned: + http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作") + # 检查是否已存在同名文件 existing = await Object.get( session, diff --git a/routers/api/v1/mcp/__init__.py b/routers/api/v1/mcp/__init__.py deleted file mode 100644 index e2fe322..0000000 --- a/routers/api/v1/mcp/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import APIRouter - -from models import MCPRequestBase, MCPResponseBase, MCPMethod - -# MCP 路由 -MCP_router = APIRouter( - prefix='/mcp', - tags=["mcp"], -) - -@MCP_router.get( - "/", -) -async def mcp_root( - param: MCPRequestBase -): - match param.method: - case MCPMethod.PING: - return MCPResponseBase(result="pong", **param.model_dump()) \ No newline at end of file diff --git a/routers/api/v1/object/__init__.py b/routers/api/v1/object/__init__.py index 93a10c4..c097e3b 100644 --- a/routers/api/v1/object/__init__.py +++ b/routers/api/v1/object/__init__.py @@ -14,7 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession from middleware.auth import auth_required from middleware.dependencies import SessionDep -from models import ( +from sqlmodels import ( + CreateFileRequest, Object, ObjectCopyRequest, ObjectDeleteRequest, @@ -26,10 +27,11 @@ from models import ( PhysicalFile, Policy, PolicyType, + ResponseBase, User, ) -from models import ResponseBase from service.storage import LocalStorageService +from utils import http_exceptions object_router = APIRouter( prefix="/object", @@ -59,15 +61,22 @@ async def _delete_object_recursive( """ deleted_count = 0 - if obj.is_folder: + # 在任何数据库操作前保存所有需要的属性,避免 commit 后对象过期导致懒加载失败 + obj_id = obj.id + obj_name = obj.name + obj_is_folder = obj.is_folder + obj_is_file = obj.is_file + obj_physical_file_id = obj.physical_file_id + + if obj_is_folder: # 递归删除子对象 - children = await Object.get_children(session, user_id, obj.id) + children = await Object.get_children(session, user_id, obj_id) for child in children: deleted_count += await _delete_object_recursive(session, child, user_id) # 如果是文件,处理物理文件引用 - if obj.is_file and obj.physical_file_id: - physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id) + if obj_is_file and obj_physical_file_id: + physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj_physical_file_id) if physical_file: # 减少引用计数 new_count = physical_file.decrement_reference() @@ -81,11 +90,11 @@ async def _delete_object_recursive( await storage_service.move_to_trash( source_path=physical_file.storage_path, user_id=user_id, - object_id=obj.id, + object_id=obj_id, ) - l.debug(f"物理文件已移动到回收站: {obj.name}") + l.debug(f"物理文件已移动到回收站: {obj_name}") except Exception as e: - l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}") + l.warning(f"移动物理文件到回收站失败: {obj_name}, 错误: {e}") # 删除 PhysicalFile 记录 await PhysicalFile.delete(session, physical_file) @@ -95,8 +104,8 @@ async def _delete_object_recursive( await physical_file.save(session) l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}") - # 删除数据库记录 - await Object.delete(session, obj) + # 使用条件删除,避免访问过期的 obj 实例 + await Object.delete(session, condition=Object.id == obj_id) deleted_count += 1 return deleted_count @@ -168,6 +177,97 @@ async def _copy_object_recursive( return copied_count, new_ids +@object_router.post( + path='/', + summary='创建空白文件', + description='在指定目录下创建空白文件。', +) +async def router_object_create( + session: SessionDep, + user: Annotated[User, Depends(auth_required)], + request: CreateFileRequest, +) -> ResponseBase: + """ + 创建空白文件端点 + + :param session: 数据库会话 + :param user: 当前登录用户 + :param request: 创建文件请求(parent_id, name) + :return: 创建结果 + """ + user_id = user.id + + # 验证文件名 + if not request.name or '/' in request.name or '\\' in request.name: + raise HTTPException(status_code=400, detail="无效的文件名") + + # 验证父目录 + parent = await Object.get(session, Object.id == request.parent_id) + if not parent or parent.owner_id != user_id: + raise HTTPException(status_code=404, detail="父目录不存在") + + if not parent.is_folder: + raise HTTPException(status_code=400, detail="父对象不是目录") + + if parent.is_banned: + http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作") + + # 检查是否已存在同名文件 + existing = await Object.get( + session, + (Object.owner_id == user_id) & + (Object.parent_id == parent.id) & + (Object.name == request.name) + ) + if existing: + raise HTTPException(status_code=409, detail="同名文件已存在") + + # 确定存储策略 + policy_id = request.policy_id or parent.policy_id + policy = await Policy.get(session, Policy.id == policy_id) + if not policy: + raise HTTPException(status_code=404, detail="存储策略不存在") + + parent_id = parent.id + + # 生成存储路径并创建空文件 + if policy.type == PolicyType.LOCAL: + storage_service = LocalStorageService(policy) + dir_path, storage_name, full_path = await storage_service.generate_file_path( + user_id=user_id, + original_filename=request.name, + ) + await storage_service.create_empty_file(full_path) + storage_path = full_path + else: + raise HTTPException(status_code=501, detail="S3 存储暂未实现") + + # 创建 PhysicalFile 记录 + physical_file = PhysicalFile( + storage_path=storage_path, + size=0, + policy_id=policy_id, + reference_count=1, + ) + physical_file = await physical_file.save(session) + + # 创建 Object 记录 + file_object = Object( + name=request.name, + type=ObjectType.FILE, + size=0, + physical_file_id=physical_file.id, + parent_id=parent_id, + owner_id=user_id, + policy_id=policy_id, + ) + await file_object.save(session) + + l.info(f"创建空白文件: {request.name}") + + return ResponseBase() + + @object_router.delete( path='/', summary='删除对象', @@ -197,10 +297,7 @@ async def router_object_delete( user_id = user.id deleted_count = 0 - # 处理单个 UUID 或 UUID 列表 - ids = request.ids if isinstance(request.ids, list) else [request.ids] - - for obj_id in ids: + for obj_id in request.ids: obj = await Object.get(session, Object.id == obj_id) if not obj or obj.owner_id != user_id: continue @@ -219,7 +316,7 @@ async def router_object_delete( return ResponseBase( data={ "deleted": deleted_count, - "total": len(ids), + "total": len(request.ids), } ) @@ -253,6 +350,9 @@ async def router_object_move( if not dst.is_folder: raise HTTPException(status_code=400, detail="目标不是有效文件夹") + if dst.is_banned: + http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作") + # 存储 dst 的属性,避免后续数据库操作导致 dst 过期后无法访问 dst_id = dst.id dst_parent_id = dst.parent_id @@ -264,6 +364,9 @@ async def router_object_move( if not src or src.owner_id != user_id: continue + if src.is_banned: + continue + # 不能移动根目录 if src.parent_id is None: continue @@ -348,6 +451,9 @@ async def router_object_copy( if not dst.is_folder: raise HTTPException(status_code=400, detail="目标不是有效文件夹") + if dst.is_banned: + http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作") + copied_count = 0 new_ids: list[UUID] = [] @@ -356,6 +462,9 @@ async def router_object_copy( if not src or src.owner_id != user_id: continue + if src.is_banned: + continue + # 不能复制根目录 if src.parent_id is None: continue @@ -438,6 +547,9 @@ async def router_object_rename( if obj.owner_id != user_id: raise HTTPException(status_code=403, detail="无权操作此对象") + if obj.is_banned: + http_exceptions.raise_banned() + # 不能重命名根目录 if obj.parent_id is None: raise HTTPException(status_code=400, detail="无法重命名根目录") @@ -543,7 +655,7 @@ async def router_object_property_detail( policy_name = policy.name if policy else None # 获取分享统计 - from models import Share + from sqlmodels import Share shares = await Share.get( session, Share.object_id == obj.id, diff --git a/routers/api/v1/share/__init__.py b/routers/api/v1/share/__init__.py index ba0d142..d0522f5 100644 --- a/routers/api/v1/share/__init__.py +++ b/routers/api/v1/share/__init__.py @@ -7,11 +7,11 @@ from loguru import logger as l from middleware.auth import auth_required from middleware.dependencies import SessionDep -from models import ResponseBase -from models.user import User -from models.share import Share, ShareCreateRequest, ShareResponse -from models.object import Object -from models.mixin import ListResponse, TableViewRequest +from sqlmodels import ResponseBase +from sqlmodels.user import User +from sqlmodels.share import Share, ShareCreateRequest, ShareResponse +from sqlmodels.object import Object +from sqlmodels.mixin import ListResponse, TableViewRequest from utils import http_exceptions from utils.password.pwd import Password @@ -72,23 +72,6 @@ def router_share_preview(id: str) -> ResponseBase: """ http_exceptions.raise_not_implemented() -@share_router.get( - path='/doc/{id}', - summary='取得Office文档预览地址', - description='Get Office document preview URL by ID.', -) -def router_share_doc(id: str) -> ResponseBase: - """ - Get Office document preview URL by ID. - - Args: - id (str): The ID of the Office document. - - Returns: - dict: A dictionary containing the document preview URL. - """ - http_exceptions.raise_not_implemented() - @share_router.get( path='/content/{id}', summary='获取文本文件内容', @@ -261,6 +244,9 @@ async def router_share_create( if not obj or obj.owner_id != user.id: raise HTTPException(status_code=404, detail="对象不存在或无权限") + if obj.is_banned: + http_exceptions.raise_banned() + # 生成分享码 code = str(uuid4()) diff --git a/routers/api/v1/site/__init__.py b/routers/api/v1/site/__init__.py index 69b0db6..068cc6d 100644 --- a/routers/api/v1/site/__init__.py +++ b/routers/api/v1/site/__init__.py @@ -1,7 +1,8 @@ from fastapi import APIRouter from middleware.dependencies import SessionDep -from models import ResponseBase, Setting, SettingsType, SiteConfigResponse +from sqlmodels import ResponseBase, Setting, SettingsType, SiteConfigResponse +from sqlmodels.setting import CaptchaType from utils import http_exceptions site_router = APIRouter( @@ -43,16 +44,43 @@ def router_site_captcha(): @site_router.get( path='/config', summary='站点全局配置', - description='Get the configuration file.', - response_model=ResponseBase, + description='获取站点全局配置,包括验证码设置、注册开关等。', ) async def router_site_config(session: SessionDep) -> SiteConfigResponse: """ - Get the configuration file. + 获取站点全局配置 - Returns: - dict: The site configuration. + 无需认证。前端在初始化时调用此端点获取验证码类型、 + 登录/注册/找回密码是否需要验证码等配置。 """ + # 批量查询所需设置 + settings: list[Setting] = await Setting.get( + session, + (Setting.type == SettingsType.BASIC) | + (Setting.type == SettingsType.LOGIN) | + (Setting.type == SettingsType.REGISTER) | + (Setting.type == SettingsType.CAPTCHA), + fetch_mode="all", + ) + + # 构建 name→value 映射 + s: dict[str, str | None] = {item.name: item.value for item in settings} + + # 根据 captcha_type 选择对应的 public key + captcha_type_str = s.get("captcha_type", "default") + captcha_type = CaptchaType(captcha_type_str) if captcha_type_str else CaptchaType.DEFAULT + captcha_key: str | None = None + if captcha_type == CaptchaType.GCAPTCHA: + captcha_key = s.get("captcha_ReCaptchaKey") or None + elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE: + captcha_key = s.get("captcha_CloudflareKey") or None + return SiteConfigResponse( - title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")), + title=s.get("siteName") or "DiskNext", + register_enabled=s.get("register_enabled") == "1", + login_captcha=s.get("login_captcha") == "1", + reg_captcha=s.get("reg_captcha") == "1", + forget_captcha=s.get("forget_captcha") == "1", + captcha_type=captcha_type, + captcha_key=captcha_key, ) \ No newline at end of file diff --git a/routers/api/v1/slave/__init__.py b/routers/api/v1/slave/__init__.py index ad751cb..37f5e0a 100644 --- a/routers/api/v1/slave/__init__.py +++ b/routers/api/v1/slave/__init__.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends from fastapi.responses import FileResponse from middleware.auth import auth_required -from models import ResponseBase +from sqlmodels import ResponseBase from utils import http_exceptions slave_router = APIRouter( diff --git a/routers/api/v1/tag/__init__.py b/routers/api/v1/tag/__init__.py index edc1c7a..dd751a7 100644 --- a/routers/api/v1/tag/__init__.py +++ b/routers/api/v1/tag/__init__.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from middleware.auth import auth_required -from models import ResponseBase +from sqlmodels import ResponseBase from utils import http_exceptions tag_router = APIRouter( diff --git a/routers/api/v1/user/__init__.py b/routers/api/v1/user/__init__.py index 132b6ce..b237159 100644 --- a/routers/api/v1/user/__init__.py +++ b/routers/api/v1/user/__init__.py @@ -1,30 +1,26 @@ from typing import Annotated, Literal -from uuid import UUID +from uuid import UUID, uuid4 +import jwt from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm +from loguru import logger from webauthn import generate_registration_options from webauthn.helpers import options_to_json_dict -from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired -from loguru import logger -import models import service +import sqlmodels from middleware.auth import auth_required from middleware.dependencies import SessionDep -from utils.JWT import SECRET_KEY -from utils import Password, http_exceptions +from utils import JWT, Password, http_exceptions +from .settings import user_settings_router user_router = APIRouter( prefix="/user", tags=["user"], ) -user_settings_router = APIRouter( - prefix='/user/settings', - tags=["user", "user_settings"], - dependencies=[Depends(auth_required)], -) +user_router.include_router(user_settings_router) @user_router.post( path='/session', @@ -34,7 +30,7 @@ user_settings_router = APIRouter( async def router_user_session( session: SessionDep, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], -) -> models.TokenResponse: +) -> sqlmodels.TokenResponse: """ 用户登录端点。 @@ -43,7 +39,7 @@ async def router_user_session( OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码 """ - username = form_data.username + email = form_data.username # OAuth2 表单字段名为 username,实际传入的是 email password = form_data.password # 从 scopes 中提取 OTP 验证码(OAuth2.1 扩展方式) @@ -59,8 +55,8 @@ async def router_user_session( result = await service.user.login( session, - models.LoginRequest( - username=username, + sqlmodels.LoginRequest( + email=email, password=password, two_fa_code=otp_code, ), @@ -75,19 +71,70 @@ async def router_user_session( ) async def router_user_session_refresh( session: SessionDep, - request, # RefreshTokenRequest -) -> models.TokenResponse: - http_exceptions.raise_not_implemented() + request: sqlmodels.RefreshTokenRequest, +) -> sqlmodels.TokenResponse: + """ + 使用 refresh_token 签发新的 access_token 和 refresh_token。 + + 流程: + 1. 解码 refresh_token JWT + 2. 验证 token_type 为 refresh + 3. 验证用户存在且状态正常 + 4. 签发新的 access_token + refresh_token + + :param session: 数据库会话 + :param request: 刷新令牌请求 + :return: 新的 TokenResponse + """ + + try: + payload = jwt.decode(request.refresh_token, JWT.SECRET_KEY, algorithms=["HS256"]) + except jwt.InvalidTokenError: + http_exceptions.raise_unauthorized("刷新令牌无效或已过期") + + # 验证是 refresh token + if payload.get("token_type") != "refresh": + http_exceptions.raise_unauthorized("非刷新令牌") + + user_id_str = payload.get("sub") + if not user_id_str: + http_exceptions.raise_unauthorized("令牌缺少用户标识") + + user_id = UUID(user_id_str) + user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id) + if not user: + http_exceptions.raise_unauthorized("用户不存在") + + if not user.status: + http_exceptions.raise_forbidden("账户已被禁用") + + # 签发新令牌 + access_token = JWT.create_access_token( + sub=user.id, + jti=uuid4(), + ) + refresh_token = JWT.create_refresh_token( + sub=user.id, + jti=uuid4(), + ) + + return sqlmodels.TokenResponse( + access_token=access_token.access_token, + access_expires=access_token.access_expires, + refresh_token=refresh_token.refresh_token, + refresh_expires=refresh_token.refresh_expires, + ) @user_router.post( path='/', summary='用户注册', description='User registration endpoint.', + status_code=204, ) async def router_user_register( session: SessionDep, - request: models.RegisterRequest, -) -> models.ResponseBase: + request: sqlmodels.RegisterRequest, +) -> None: """ 用户注册端点 @@ -95,7 +142,7 @@ async def router_user_register( 1. 验证用户名唯一性 2. 获取默认用户组 3. 创建用户记录 - 4. 创建以用户名命名的根目录 + 4. 创建用户根目录(name="/") :param session: 数据库会话 :param request: 注册请求 @@ -103,62 +150,53 @@ async def router_user_register( :raises HTTPException 400: 用户名已存在 :raises HTTPException 500: 默认用户组或存储策略不存在 """ - # 1. 验证用户名唯一性 - existing_user = await models.User.get( + # 1. 验证邮箱唯一性 + existing_user = await sqlmodels.User.get( session, - models.User.username == request.username + sqlmodels.User.email == request.email ) if existing_user: - raise HTTPException(status_code=400, detail="用户名已存在") + raise HTTPException(status_code=400, detail="邮箱已存在") # 2. 获取默认用户组(从设置中读取 UUID) - default_group_setting: models.Setting | None = await models.Setting.get( + default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get( session, - (models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group") + (sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group") ) if default_group_setting is None or not default_group_setting.value: logger.error("默认用户组不存在") http_exceptions.raise_internal_error() default_group_id = UUID(default_group_setting.value) - default_group = await models.Group.get(session, models.Group.id == default_group_id) + default_group = await sqlmodels.Group.get(session, sqlmodels.Group.id == default_group_id) if not default_group: logger.error("默认用户组不存在") http_exceptions.raise_internal_error() # 3. 创建用户 hashed_password = Password.hash(request.password) - new_user = models.User( - username=request.username, + new_user = sqlmodels.User( + email=request.email, password=hashed_password, group_id=default_group.id, ) - new_user_id = new_user.id # 在 save 前保存 UUID - new_user_username = new_user.username + new_user_id = new_user.id await new_user.save(session) - # 4. 创建以用户名命名的根目录 - default_policy = await models.Policy.get(session, models.Policy.name == "本地存储") + # 4. 创建用户根目录 + default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储") if not default_policy: logger.error("默认存储策略不存在") http_exceptions.raise_internal_error() - await models.Object( - name=new_user_username, - type=models.ObjectType.FOLDER, + await sqlmodels.Object( + name="/", + type=sqlmodels.ObjectType.FOLDER, owner_id=new_user_id, parent_id=None, policy_id=default_policy.id, ).save(session) - return models.ResponseBase( - data={ - "user_id": new_user_id, - "username": new_user_username, - }, - msg="注册成功", - ) - @user_router.post( path='/code', summary='发送验证码邮件', @@ -166,7 +204,7 @@ async def router_user_register( ) def router_user_email_code( reason: Literal['register', 'reset'] = 'register', -) -> models.ResponseBase: +) -> sqlmodels.ResponseBase: """ Send a verification code email. @@ -180,7 +218,7 @@ def router_user_email_code( summary='初始化QQ登录', description='Initialize QQ login for a user.', ) -def router_user_qq() -> models.ResponseBase: +def router_user_qq() -> sqlmodels.ResponseBase: """ Initialize QQ login for a user. @@ -194,7 +232,7 @@ def router_user_qq() -> models.ResponseBase: summary='WebAuthn登录初始化', description='Initialize WebAuthn login for a user.', ) -async def router_user_authn(username: str) -> models.ResponseBase: +async def router_user_authn(username: str) -> sqlmodels.ResponseBase: http_exceptions.raise_not_implemented() @@ -203,7 +241,7 @@ async def router_user_authn(username: str) -> models.ResponseBase: summary='WebAuthn登录', description='Finish WebAuthn login for a user.', ) -def router_user_authn_finish(username: str) -> models.ResponseBase: +def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase: """ Finish WebAuthn login for a user. @@ -220,7 +258,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase: summary='获取用户主页展示用分享', description='Get user profile for display.', ) -def router_user_profile(id: str) -> models.ResponseBase: +def router_user_profile(id: str) -> sqlmodels.ResponseBase: """ Get user profile for display. @@ -237,7 +275,7 @@ def router_user_profile(id: str) -> models.ResponseBase: summary='获取用户头像', description='Get user avatar by ID and size.', ) -def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase: +def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase: """ Get user avatar by ID and size. @@ -259,12 +297,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase: summary='获取用户信息', description='Get user information.', dependencies=[Depends(dependency=auth_required)], - response_model=models.UserResponse, + response_model=sqlmodels.UserResponse, ) async def router_user_me( session: SessionDep, - user: Annotated[models.User, Depends(auth_required)], -) -> models.ResponseBase: + user: Annotated[sqlmodels.User, Depends(auth_required)], +) -> sqlmodels.UserResponse: """ 获取用户信息. @@ -272,10 +310,10 @@ async def router_user_me( :rtype: ResponseBase """ # 加载 group 及其 options 关系 - group = await models.Group.get( + group = await sqlmodels.Group.get( session, - models.Group.id == user.group_id, - load=models.Group.options + sqlmodels.Group.id == user.group_id, + load=sqlmodels.Group.options ) # 构建 GroupResponse @@ -284,9 +322,9 @@ async def router_user_me( # 异步加载 tags 关系 user_tags = await user.awaitable_attrs.tags - return models.UserResponse( + return sqlmodels.UserResponse( id=user.id, - username=user.username, + email=user.email, status=user.status, score=user.score, nickname=user.nickname, @@ -304,30 +342,26 @@ async def router_user_me( ) async def router_user_storage( session: SessionDep, - user: Annotated[models.user.User, Depends(auth_required)], -) -> models.ResponseBase: + user: Annotated[sqlmodels.user.User, Depends(auth_required)], +) -> sqlmodels.UserStorageResponse: """ 获取用户存储空间信息。 - - 返回值: - - used: 已使用空间(字节) - - free: 剩余空间(字节) - - total: 总容量(字节)= 用户组容量 """ # 获取用户组的基础存储容量 - group = await models.Group.get(session, models.Group.id == user.group_id) + group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id) if not group: - raise HTTPException(status_code=500, detail="用户组不存在") + raise HTTPException(status_code=404, detail="用户组不存在") + + # [TODO] 总空间加上用户购买的额外空间 + total: int = group.max_storage used: int = user.storage free: int = max(0, total - used) - return models.ResponseBase( - data={ - "used": used, - "free": free, - "total": total, - } + return sqlmodels.UserStorageResponse( + used=used, + free=free, + total=total, ) @user_router.put( @@ -338,8 +372,8 @@ async def router_user_storage( ) async def router_user_authn_start( session: SessionDep, - user: Annotated[models.user.User, Depends(auth_required)], -) -> models.ResponseBase: + user: Annotated[sqlmodels.user.User, Depends(auth_required)], +) -> sqlmodels.ResponseBase: """ Initialize WebAuthn login for a user. @@ -347,30 +381,30 @@ async def router_user_authn_start( dict: A dictionary containing WebAuthn initialization information. """ # TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等 - authn_setting = await models.Setting.get( + authn_setting = await sqlmodels.Setting.get( session, - (models.Setting.type == "authn") & (models.Setting.name == "authn_enabled") + (sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled") ) if not authn_setting or authn_setting.value != "1": raise HTTPException(status_code=400, detail="WebAuthn is not enabled") - site_url_setting = await models.Setting.get( + site_url_setting = await sqlmodels.Setting.get( session, - (models.Setting.type == "basic") & (models.Setting.name == "siteURL") + (sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteURL") ) - site_title_setting = await models.Setting.get( + site_title_setting = await sqlmodels.Setting.get( session, - (models.Setting.type == "basic") & (models.Setting.name == "siteTitle") + (sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteTitle") ) options = generate_registration_options( rp_id=site_url_setting.value if site_url_setting else "", rp_name=site_title_setting.value if site_title_setting else "", - user_name=user.username, - user_display_name=user.nick or user.username, + user_name=user.email, + user_display_name=user.nickname or user.email, ) - return models.ResponseBase(data=options_to_json_dict(options)) + return sqlmodels.ResponseBase(data=options_to_json_dict(options)) @user_router.put( path='/authn/finish', @@ -378,179 +412,11 @@ async def router_user_authn_start( description='Finish WebAuthn login for a user.', dependencies=[Depends(auth_required)], ) -def router_user_authn_finish() -> models.ResponseBase: +def router_user_authn_finish() -> sqlmodels.ResponseBase: """ Finish WebAuthn login for a user. Returns: dict: A dictionary containing WebAuthn login information. """ - http_exceptions.raise_not_implemented() - -@user_settings_router.get( - path='/policies', - summary='获取用户可选存储策略', - description='Get user selectable storage policies.', -) -def router_user_settings_policies() -> models.ResponseBase: - """ - Get user selectable storage policies. - - Returns: - dict: A dictionary containing available storage policies for the user. - """ - http_exceptions.raise_not_implemented() - -@user_settings_router.get( - path='/nodes', - summary='获取用户可选节点', - description='Get user selectable nodes.', - dependencies=[Depends(auth_required)], -) -def router_user_settings_nodes() -> models.ResponseBase: - """ - Get user selectable nodes. - - Returns: - dict: A dictionary containing available nodes for the user. - """ - http_exceptions.raise_not_implemented() - -@user_settings_router.get( - path='/tasks', - summary='任务队列', - description='Get user task queue.', - dependencies=[Depends(auth_required)], -) -def router_user_settings_tasks() -> models.ResponseBase: - """ - Get user task queue. - - Returns: - dict: A dictionary containing the user's task queue information. - """ - http_exceptions.raise_not_implemented() - -@user_settings_router.get( - path='/', - summary='获取当前用户设定', - description='Get current user settings.', - dependencies=[Depends(auth_required)], -) -def router_user_settings() -> models.ResponseBase: - """ - Get current user settings. - - Returns: - dict: A dictionary containing the current user settings. - """ - return models.ResponseBase(data=models.UserSettingResponse().model_dump()) - -@user_settings_router.post( - path='/avatar', - summary='从文件上传头像', - description='Upload user avatar from file.', - dependencies=[Depends(auth_required)], -) -def router_user_settings_avatar() -> models.ResponseBase: - """ - Upload user avatar from file. - - Returns: - dict: A dictionary containing the result of the avatar upload. - """ - http_exceptions.raise_not_implemented() - -@user_settings_router.put( - path='/avatar', - summary='设定为Gravatar头像', - description='Set user avatar to Gravatar.', - dependencies=[Depends(auth_required)], -) -def router_user_settings_avatar_gravatar() -> models.ResponseBase: - """ - Set user avatar to Gravatar. - - Returns: - dict: A dictionary containing the result of setting the Gravatar avatar. - """ - http_exceptions.raise_not_implemented() - -@user_settings_router.patch( - path='/{option}', - summary='更新用户设定', - description='Update user settings.', - dependencies=[Depends(auth_required)], -) -def router_user_settings_patch(option: str) -> models.ResponseBase: - """ - Update user settings. - - Args: - option (str): The setting option to update. - - Returns: - dict: A dictionary containing the result of the settings update. - """ - http_exceptions.raise_not_implemented() - -@user_settings_router.get( - path='/2fa', - summary='获取两步验证初始化信息', - description='Get two-factor authentication initialization information.', - dependencies=[Depends(auth_required)], -) -async def router_user_settings_2fa( - user: Annotated[models.user.User, Depends(auth_required)], -) -> models.ResponseBase: - """ - Get two-factor authentication initialization information. - - Returns: - dict: A dictionary containing two-factor authentication setup information. - """ - - return models.ResponseBase( - data=await Password.generate_totp(user.username) - ) - -@user_settings_router.post( - path='/2fa', - summary='启用两步验证', - description='Enable two-factor authentication.', - dependencies=[Depends(auth_required)], -) -async def router_user_settings_2fa_enable( - session: SessionDep, - user: Annotated[models.user.User, Depends(auth_required)], - setup_token: str, - code: str, -) -> models.ResponseBase: - """ - Enable two-factor authentication for the user. - - Returns: - dict: A dictionary containing the result of enabling two-factor authentication. - """ - - serializer = URLSafeTimedSerializer(SECRET_KEY) - - try: - # 1. 解包 Token,设置有效期(例如 600秒) - secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600) - except SignatureExpired: - raise HTTPException(status_code=400, detail="Setup session expired") - except BadSignature: - raise HTTPException(status_code=400, detail="Invalid token") - - # 2. 验证用户输入的 6 位验证码 - if not Password.verify_totp(secret, code): - raise HTTPException(status_code=400, detail="Invalid OTP code") - - # 3. 将 secret 存储到用户的数据库记录中,启用 2FA - user.two_factor = secret - user = await user.save(session) - - return models.ResponseBase( - data={"message": "Two-factor authentication enabled successfully"} - ) \ No newline at end of file + http_exceptions.raise_not_implemented() \ No newline at end of file diff --git a/routers/api/v1/user/settings/__init__.py b/routers/api/v1/user/settings/__init__.py new file mode 100644 index 0000000..7511af4 --- /dev/null +++ b/routers/api/v1/user/settings/__init__.py @@ -0,0 +1,203 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException +from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired + +import sqlmodels +from middleware.auth import auth_required +from middleware.dependencies import SessionDep +from utils import JWT, Password, http_exceptions + +user_settings_router = APIRouter( + prefix='/settings', + tags=["user", "user_settings"], + dependencies=[Depends(auth_required)], +) + + +@user_settings_router.get( + path='/policies', + summary='获取用户可选存储策略', + description='Get user selectable storage policies.', +) +def router_user_settings_policies() -> sqlmodels.ResponseBase: + """ + Get user selectable storage policies. + + Returns: + dict: A dictionary containing available storage policies for the user. + """ + http_exceptions.raise_not_implemented() + + +@user_settings_router.get( + path='/nodes', + summary='获取用户可选节点', + description='Get user selectable nodes.', + dependencies=[Depends(auth_required)], +) +def router_user_settings_nodes() -> sqlmodels.ResponseBase: + """ + Get user selectable nodes. + + Returns: + dict: A dictionary containing available nodes for the user. + """ + http_exceptions.raise_not_implemented() + + +@user_settings_router.get( + path='/tasks', + summary='任务队列', + description='Get user task queue.', + dependencies=[Depends(auth_required)], +) +def router_user_settings_tasks() -> sqlmodels.ResponseBase: + """ + Get user task queue. + + Returns: + dict: A dictionary containing the user's task queue information. + """ + http_exceptions.raise_not_implemented() + + +@user_settings_router.get( + path='/', + summary='获取当前用户设定', + description='Get current user settings.', +) +def router_user_settings( + user: Annotated[sqlmodels.user.User, Depends(auth_required)], +) -> sqlmodels.UserSettingResponse: + """ + Get current user settings. + + Returns: + dict: A dictionary containing the current user settings. + """ + return sqlmodels.UserSettingResponse( + id=user.id, + email=user.email, + nickname=user.nickname, + created_at=user.created_at, + group_name=user.group.name, + language=user.language, + timezone=user.timezone, + group_expires=user.group_expires, + two_factor=user.two_factor is not None, + ) + + +@user_settings_router.post( + path='/avatar', + summary='从文件上传头像', + description='Upload user avatar from file.', + dependencies=[Depends(auth_required)], +) +def router_user_settings_avatar() -> sqlmodels.ResponseBase: + """ + Upload user avatar from file. + + Returns: + dict: A dictionary containing the result of the avatar upload. + """ + http_exceptions.raise_not_implemented() + + +@user_settings_router.put( + path='/avatar', + summary='设定为Gravatar头像', + description='Set user avatar to Gravatar.', + dependencies=[Depends(auth_required)], +) +def router_user_settings_avatar_gravatar() -> sqlmodels.ResponseBase: + """ + Set user avatar to Gravatar. + + Returns: + dict: A dictionary containing the result of setting the Gravatar avatar. + """ + http_exceptions.raise_not_implemented() + + +@user_settings_router.patch( + path='/{option}', + summary='更新用户设定', + description='Update user settings.', + dependencies=[Depends(auth_required)], +) +def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase: + """ + Update user settings. + + Args: + option (str): The setting option to update. + + Returns: + dict: A dictionary containing the result of the settings update. + """ + http_exceptions.raise_not_implemented() + + +@user_settings_router.get( + path='/2fa', + summary='获取两步验证初始化信息', + description='Get two-factor authentication initialization information.', + dependencies=[Depends(auth_required)], +) +async def router_user_settings_2fa( + user: Annotated[sqlmodels.user.User, Depends(auth_required)], +) -> sqlmodels.ResponseBase: + """ + Get two-factor authentication initialization information. + + Returns: + dict: A dictionary containing two-factor authentication setup information. + """ + + return sqlmodels.ResponseBase( + data=await Password.generate_totp(user.email) + ) + + +@user_settings_router.post( + path='/2fa', + summary='启用两步验证', + description='Enable two-factor authentication.', + dependencies=[Depends(auth_required)], +) +async def router_user_settings_2fa_enable( + session: SessionDep, + user: Annotated[sqlmodels.user.User, Depends(auth_required)], + setup_token: str, + code: str, +) -> sqlmodels.ResponseBase: + """ + Enable two-factor authentication for the user. + + Returns: + dict: A dictionary containing the result of enabling two-factor authentication. + """ + + serializer = URLSafeTimedSerializer(JWT.SECRET_KEY) + + try: + # 1. 解包 Token,设置有效期(例如 600秒) + secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600) + except SignatureExpired: + raise HTTPException(status_code=400, detail="Setup session expired") + except BadSignature: + raise HTTPException(status_code=400, detail="Invalid token") + + # 2. 验证用户输入的 6 位验证码 + if not Password.verify_totp(secret, code): + raise HTTPException(status_code=400, detail="Invalid OTP code") + + # 3. 将 secret 存储到用户的数据库记录中,启用 2FA + user.two_factor = secret + user = await user.save(session) + + return sqlmodels.ResponseBase( + data={"message": "Two-factor authentication enabled successfully"} + ) \ No newline at end of file diff --git a/routers/api/v1/vas/__init__.py b/routers/api/v1/vas/__init__.py index c2c26ef..6b43b6f 100644 --- a/routers/api/v1/vas/__init__.py +++ b/routers/api/v1/vas/__init__.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from middleware.auth import auth_required -from models import ResponseBase +from sqlmodels import ResponseBase from utils import http_exceptions vas_router = APIRouter( diff --git a/routers/api/v1/webdav/__init__.py b/routers/api/v1/webdav/__init__.py index d9047a9..270e2ce 100644 --- a/routers/api/v1/webdav/__init__.py +++ b/routers/api/v1/webdav/__init__.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends from middleware.auth import auth_required -from models import ResponseBase +from sqlmodels import ResponseBase from utils import http_exceptions # WebDAV 管理路由 diff --git a/service/storage/local_storage.py b/service/storage/local_storage.py index 7eb16ca..f40b771 100644 --- a/service/storage/local_storage.py +++ b/service/storage/local_storage.py @@ -15,7 +15,7 @@ import aiofiles import aiofiles.os from loguru import logger as l -from models.policy import Policy +from sqlmodels.policy import Policy from .exceptions import ( DirectoryCreationError, FileReadError, diff --git a/service/storage/naming_rule.py b/service/storage/naming_rule.py index beb823a..dd7d873 100644 --- a/service/storage/naming_rule.py +++ b/service/storage/naming_rule.py @@ -23,7 +23,7 @@ import string from datetime import datetime from uuid import UUID, uuid4 -from models.base import SQLModelBase +from sqlmodels.base import SQLModelBase class NamingContext(SQLModelBase): diff --git a/service/user/login.py b/service/user/login.py index 5ee1dec..a7521c7 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -3,7 +3,7 @@ from uuid import uuid4 from loguru import logger from middleware.dependencies import SessionDep -from models import LoginRequest, TokenResponse, User +from sqlmodels import LoginRequest, TokenResponse, User from utils import http_exceptions from utils.JWT import create_access_token, create_refresh_token from utils.password.pwd import Password, PasswordStatus @@ -30,17 +30,17 @@ async def login( # is_captcha_required = captcha_setting and captcha_setting.value == "1" # 获取用户信息 - current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore + current_user: User = await User.get(session, User.email == login_request.email, fetch_mode="first") #type: ignore # 验证用户是否存在 if not current_user: - logger.debug(f"Cannot find user with username: {login_request.username}") - http_exceptions.raise_unauthorized("Invalid username or password") + logger.debug(f"Cannot find user with email: {login_request.email}") + http_exceptions.raise_unauthorized("Invalid email or password") # 验证密码是否正确 if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID: - logger.debug(f"Password verification failed for user: {login_request.username}") - http_exceptions.raise_unauthorized("Invalid username or password") + logger.debug(f"Password verification failed for user: {login_request.email}") + http_exceptions.raise_unauthorized("Invalid email or password") # 验证用户是否可登录 if not current_user.status: @@ -50,23 +50,23 @@ async def login( if current_user.two_factor: # 用户已启用两步验证 if not login_request.two_fa_code: - logger.debug(f"2FA required for user: {login_request.username}") + logger.debug(f"2FA required for user: {login_request.email}") http_exceptions.raise_precondition_required("2FA required") # 验证 OTP 码 if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID: - logger.debug(f"Invalid 2FA code for user: {login_request.username}") + logger.debug(f"Invalid 2FA code for user: {login_request.email}") http_exceptions.raise_unauthorized("Invalid 2FA code") # 创建令牌 - access_token = create_access_token(data={ - 'sub': str(current_user.id), - 'jti': str(uuid4()) - }) - refresh_token = create_refresh_token(data={ - 'sub': str(current_user.id), - 'jti': str(uuid4()) - }) + access_token = create_access_token( + sub=current_user.id, + jti=uuid4() + ) + refresh_token = create_refresh_token( + sub=current_user.id, + jti=uuid4() + ) return TokenResponse( access_token=access_token.access_token, diff --git a/models/README.md b/sqlmodels/README.md similarity index 100% rename from models/README.md rename to sqlmodels/README.md diff --git a/models/__init__.py b/sqlmodels/__init__.py similarity index 93% rename from models/__init__.py rename to sqlmodels/__init__.py index f8431d5..2faa072 100644 --- a/models/__init__.py +++ b/sqlmodels/__init__.py @@ -1,11 +1,14 @@ from .user import ( + BatchDeleteRequest, LoginRequest, + RefreshTokenRequest, RegisterRequest, AccessTokenBase, RefreshTokenBase, TokenResponse, User, UserBase, + UserStorageResponse, UserPublic, UserResponse, UserSettingResponse, @@ -66,6 +69,7 @@ from .object import ( FileBanRequest, ) from .physical_file import PhysicalFile, PhysicalFileBase +from .uri import DiskNextURI, FileSystemNamespace from .order import Order, OrderStatus, OrderType from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary from .redeem import Redeem, RedeemType @@ -82,7 +86,7 @@ from .tag import Tag, TagType from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary from .webdav import WebDAV -from .database import engine, get_session +from .database_connection import DatabaseManager from .model_base import ( MCPBase, diff --git a/models/base/README.md b/sqlmodels/base/README.md similarity index 99% rename from models/base/README.md rename to sqlmodels/base/README.md index 346771f..9ff36ce 100644 --- a/models/base/README.md +++ b/sqlmodels/base/README.md @@ -630,7 +630,7 @@ For developers modifying this module: - Handles Python 3.14 annotations via `get_type_hints()` **Metaclass processing order**: -1. Check if class should be a table (`_is_table_mixin`) +1. Check if class should be a table (`_has_table_mixin`) 2. Collect `__mapper_args__` from kwargs and explicit dict 3. Process `table_args`, `table_name`, `abstract` parameters 4. Resolve annotations using `get_type_hints()` diff --git a/models/base/__init__.py b/sqlmodels/base/__init__.py similarity index 81% rename from models/base/__init__.py rename to sqlmodels/base/__init__.py index 4744778..91e3cb6 100644 --- a/models/base/__init__.py +++ b/sqlmodels/base/__init__.py @@ -5,8 +5,8 @@ SQLModel 基础模块 - SQLModelBase: 所有 SQLModel 类的基类(真正的基类) 注意: - TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 models.mixin + TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin 为了避免循环导入,此处不再重新导出它们 - 请直接从 models.mixin 导入这些类 + 请直接从 sqlmodels.mixin 导入这些类 """ from .sqlmodel_base import SQLModelBase diff --git a/models/base/sqlmodel_base.py b/sqlmodels/base/sqlmodel_base.py similarity index 99% rename from models/base/sqlmodel_base.py rename to sqlmodels/base/sqlmodel_base.py index 22bee25..e07b90c 100644 --- a/models/base/sqlmodel_base.py +++ b/sqlmodels/base/sqlmodel_base.py @@ -414,7 +414,7 @@ class __DeclarativeMeta(SQLModelMetaclass): def __new__(cls, name, bases, attrs, **kwargs): # 1. 约定优于配置:自动设置 table=True - is_intended_as_table = any(getattr(b, '_is_table_mixin', False) for b in bases) + is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases) if is_intended_as_table and 'table' not in kwargs: kwargs['table'] = True diff --git a/models/color.py b/sqlmodels/color.py similarity index 100% rename from models/color.py rename to sqlmodels/color.py diff --git a/models/database.py b/sqlmodels/database.py similarity index 100% rename from models/database.py rename to sqlmodels/database.py diff --git a/sqlmodels/database_connection.py b/sqlmodels/database_connection.py new file mode 100644 index 0000000..a4efa2e --- /dev/null +++ b/sqlmodels/database_connection.py @@ -0,0 +1,78 @@ +from typing import AsyncGenerator, ClassVar + +from loguru import logger +from sqlalchemy import NullPool, AsyncAdaptedQueuePool +from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine +from sqlalchemy.orm import sessionmaker +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession + + +class DatabaseManager: + engine: ClassVar[AsyncEngine | None] = None + _async_session_factory: ClassVar[sessionmaker | None] = None + + @classmethod + async def get_session(cls) -> AsyncGenerator[AsyncSession]: + assert cls._async_session_factory is not None, "数据库引擎未初始化,请先调用 DatabaseManager.init()" + async with cls._async_session_factory() as session: + yield session + + @classmethod + async def init( + cls, + database_url: str, + debug: bool = False, + ): + """ + 初始化数据库连接引擎。 + + :param database_url: 数据库连接URL + :param debug: 是否开启调试模式 + """ + # 构建引擎参数 + engine_kwargs: dict = { + 'echo': debug, + 'future': True, + } + + if debug: + # Debug 模式使用 NullPool(无连接池,每次创建新连接) + engine_kwargs['poolclass'] = NullPool + else: + # 生产模式使用 AsyncAdaptedQueuePool 连接池 + engine_kwargs.update({ + 'poolclass': AsyncAdaptedQueuePool, + 'pool_size': 40, + 'max_overflow': 80, + 'pool_timeout': 30, + 'pool_recycle': 1800, + 'pool_pre_ping': True, + }) + + # 只在需要时添加 connect_args + if database_url.startswith("sqlite"): + engine_kwargs['connect_args'] = {'check_same_thread': False} + + cls.engine = create_async_engine(database_url, **engine_kwargs) + + cls._async_session_factory = sessionmaker(cls.engine, class_=AsyncSession) + + # 开发阶段直接 create_all 创建表结构 + async with cls.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + logger.info("数据库引擎初始化完成") + + @classmethod + async def close(cls): + """ + 优雅地关闭数据库连接引擎。 + 仅应在应用结束时调用。 + """ + if cls.engine: + logger.info("正在关闭数据库连接引擎...") + await cls.engine.dispose() + logger.info("数据库连接引擎已成功关闭。") + else: + logger.info("数据库连接引擎未初始化,无需关闭。") diff --git a/models/download.py b/sqlmodels/download.py similarity index 100% rename from models/download.py rename to sqlmodels/download.py diff --git a/models/group.py b/sqlmodels/group.py similarity index 100% rename from models/group.py rename to sqlmodels/group.py diff --git a/models/migration.py b/sqlmodels/migration.py similarity index 94% rename from models/migration.py rename to sqlmodels/migration.py index eb81ec3..63ab785 100644 --- a/models/migration.py +++ b/sqlmodels/migration.py @@ -104,9 +104,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti Setting(name="captcha_IsShowSlimeLine", value="1", type=SettingsType.CAPTCHA), Setting(name="captcha_IsShowSineLine", value="0", type=SettingsType.CAPTCHA), Setting(name="captcha_CaptchaLen", value="6", type=SettingsType.CAPTCHA), - Setting(name="captcha_IsUseReCaptcha", value="0", type=SettingsType.CAPTCHA), - Setting(name="captcha_ReCaptchaKey", value="defaultKey", type=SettingsType.CAPTCHA), - Setting(name="captcha_ReCaptchaSecret", value="defaultSecret", type=SettingsType.CAPTCHA), + Setting(name="captcha_type", value="default", type=SettingsType.CAPTCHA), + Setting(name="captcha_ReCaptchaKey", value="", type=SettingsType.CAPTCHA), + Setting(name="captcha_ReCaptchaSecret", value="", type=SettingsType.CAPTCHA), + Setting(name="captcha_CloudflareKey", value="", type=SettingsType.CAPTCHA), + Setting(name="captcha_CloudflareSecret", value="", type=SettingsType.CAPTCHA), Setting(name="thumb_width", value="400", type=SettingsType.THUMB), Setting(name="thumb_height", value="300", type=SettingsType.THUMB), Setting(name="pwa_small_icon", value="/static/img/favicon.ico", type=SettingsType.PWA), @@ -119,11 +121,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti async def init_default_settings() -> None: from .setting import Setting - from .database import get_session + from .database_connection import DatabaseManager log.info('初始化设置...') - async for session in get_session(): + async for session in DatabaseManager.get_session(): # 检查是否已经存在版本设置 ver = await Setting.get( session, @@ -139,11 +141,11 @@ async def init_default_group() -> None: from .group import Group, GroupOptions from .policy import Policy, GroupPolicyLink from .setting import Setting - from .database import get_session + from .database_connection import DatabaseManager log.info('初始化用户组...') - async for session in get_session(): + async for session in DatabaseManager.get_session(): # 获取默认存储策略 default_policy = await Policy.get(session, Policy.name == "本地存储") default_policy_id = default_policy.id if default_policy else None @@ -231,13 +233,20 @@ async def init_default_user() -> None: from .group import Group from .object import Object, ObjectType from .policy import Policy - from .database import get_session + from .database_connection import DatabaseManager log.info('初始化管理员用户...') - async for session in get_session(): - # 检查管理员用户是否存在 - admin_user = await User.get(session, User.username == "admin") + async for session in DatabaseManager.get_session(): + # 检查管理员用户是否存在(通过 Setting 中的 default_admin_id 判断) + admin_id_setting = await Setting.get( + session, + (Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id") + ) + admin_user = None + if admin_id_setting and admin_id_setting.value: + from uuid import UUID + admin_user = await User.get(session, User.id == UUID(admin_id_setting.value)) if not admin_user: # 获取管理员组 @@ -256,18 +265,24 @@ async def init_default_user() -> None: hashed_admin_password = Password.hash(admin_password) admin_user = User( - username="admin", + email="admin@disknext.local", nickname="admin", group_id=admin_group.id, password=hashed_admin_password, ) admin_user_id = admin_user.id # 在 save 前保存 UUID - admin_username = admin_user.username await admin_user.save(session) - # 为管理员创建根目录(使用用户名作为目录名) + # 记录默认管理员 ID 到 Setting + await Setting( + name="default_admin_id", + value=str(admin_user_id), + type=SettingsType.AUTH, + ).save(session) + + # 为管理员创建根目录 await Object( - name=admin_username, + name="/", type=ObjectType.FOLDER, owner_id=admin_user_id, parent_id=None, @@ -275,18 +290,18 @@ async def init_default_user() -> None: ).save(session) log.warning('请注意,账号密码仅显示一次,请妥善保管') - log.info(f'初始管理员账号: admin') + log.info(f'初始管理员邮箱: admin@disknext.local') log.info(f'初始管理员密码: {admin_password}') async def init_default_policy() -> None: from .policy import Policy, PolicyType - from .database import get_session + from .database_connection import DatabaseManager from service.storage import LocalStorageService log.info('初始化默认存储策略...') - async for session in get_session(): + async for session in DatabaseManager.get_session(): # 检查默认存储策略是否存在 default_policy = await Policy.get(session, Policy.name == "本地存储") diff --git a/models/mixin/README.md b/sqlmodels/mixin/README.md similarity index 100% rename from models/mixin/README.md rename to sqlmodels/mixin/README.md diff --git a/models/mixin/__init__.py b/sqlmodels/mixin/__init__.py similarity index 65% rename from models/mixin/__init__.py rename to sqlmodels/mixin/__init__.py index 1ad01e7..832828a 100644 --- a/models/mixin/__init__.py +++ b/sqlmodels/mixin/__init__.py @@ -5,42 +5,58 @@ SQLModel Mixin模块 包含: - polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin) +- optimistic_lock: 乐观锁(OptimisticLockMixin, OptimisticLockError) - table: 表基类(TableBaseMixin, UUIDTableBaseMixin) - table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest) +- relation_preload: 关系预加载(RelationPreloadMixin, requires_relations) - jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入 - info_response: InfoResponse DTO的id/时间戳Mixin 导入顺序很重要,避免循环导入: 1. polymorphic(只依赖 SQLModelBase) -2. table(依赖 polymorphic) +2. optimistic_lock(只依赖 SQLAlchemy) +3. table(依赖 polymorphic 和 optimistic_lock) +4. relation_preload(只依赖 SQLModelBase) 注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig, 而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。 """ # polymorphic 必须先导入 from .polymorphic import ( - create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin, + create_subclass_id_mixin, + register_sti_column_properties_for_all_subclasses, + register_sti_columns_for_all_subclasses, ) -# table 依赖 polymorphic +# optimistic_lock 只依赖 SQLAlchemy,必须在 table 之前 +from .optimistic_lock import ( + OptimisticLockError, + OptimisticLockMixin, +) +# table 依赖 polymorphic 和 optimistic_lock from .table import ( - TableBaseMixin, - UUIDTableBaseMixin, - TimeFilterRequest, - PaginationRequest, - TableViewRequest, ListResponse, + PaginationRequest, T, + TableBaseMixin, + TableViewRequest, + TimeFilterRequest, + UUIDTableBaseMixin, now, now_date, ) +# relation_preload 只依赖 SQLModelBase +from .relation_preload import ( + RelationPreloadMixin, + requires_relations, +) # jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt) # 需要时直接从 sqlmodels.mixin.jwt 导入 from .info_response import ( - IntIdInfoMixin, - UUIDIdInfoMixin, DatetimeInfoMixin, IntIdDatetimeInfoMixin, + IntIdInfoMixin, UUIDIdDatetimeInfoMixin, + UUIDIdInfoMixin, ) diff --git a/models/mixin/info_response.py b/sqlmodels/mixin/info_response.py similarity index 96% rename from models/mixin/info_response.py rename to sqlmodels/mixin/info_response.py index 647b9a3..f1e053e 100644 --- a/models/mixin/info_response.py +++ b/sqlmodels/mixin/info_response.py @@ -12,7 +12,7 @@ InfoResponse DTO Mixin模块 from datetime import datetime from uuid import UUID -from models.base import SQLModelBase +from sqlmodels.base import SQLModelBase class IntIdInfoMixin(SQLModelBase): diff --git a/sqlmodels/mixin/optimistic_lock.py b/sqlmodels/mixin/optimistic_lock.py new file mode 100644 index 0000000..c9b7da5 --- /dev/null +++ b/sqlmodels/mixin/optimistic_lock.py @@ -0,0 +1,90 @@ +""" +乐观锁 Mixin + +提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。 + +乐观锁适用场景: +- 涉及"状态转换"的表(如:待支付 -> 已支付) +- 涉及"数值变动"的表(如:余额、库存) + +不适用场景: +- 日志表、纯插入表、低价值统计表 +- 能用 UPDATE table SET col = col + 1 解决的简单计数问题 + +使用示例: + class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True): + status: OrderStatusEnum + amount: Decimal + + # save/update 时自动检查版本号 + # 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError + try: + order = await order.save(session) + except OptimisticLockError as e: + # 处理冲突:重新查询并重试,或报错给用户 + l.warning(f"乐观锁冲突: {e}") +""" +from typing import ClassVar + +from sqlalchemy.orm.exc import StaleDataError + + +class OptimisticLockError(Exception): + """ + 乐观锁冲突异常 + + 当 save/update 操作检测到版本号不匹配时抛出。 + 这意味着在读取和写入之间,其他事务已经修改了该记录。 + + Attributes: + model_class: 发生冲突的模型类名 + record_id: 记录 ID(如果可用) + expected_version: 期望的版本号(如果可用) + original_error: 原始的 StaleDataError + """ + + def __init__( + self, + message: str, + model_class: str | None = None, + record_id: str | None = None, + expected_version: int | None = None, + original_error: StaleDataError | None = None, + ): + super().__init__(message) + self.model_class = model_class + self.record_id = record_id + self.expected_version = expected_version + self.original_error = original_error + + +class OptimisticLockMixin: + """ + 乐观锁 Mixin + + 使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。 + 每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改), + session.commit() 会抛出 StaleDataError,被 save/update 方法捕获并转换为 OptimisticLockError。 + + 原理: + 1. 每条记录有一个 version 字段,初始值为 0 + 2. 每次 UPDATE 时,SQLAlchemy 生成的 SQL 类似: + UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ? + 3. 如果 WHERE 条件不匹配(version 已被其他事务修改), + UPDATE 影响 0 行,SQLAlchemy 抛出 StaleDataError + + 继承顺序: + OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前: + class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True): + ... + + 配套重试: + 如果加了乐观锁,业务层需要处理 OptimisticLockError: + - 报错给用户:"数据已被修改,请刷新后重试" + - 自动重试:重新查询最新数据并再次尝试 + """ + _has_optimistic_lock: ClassVar[bool] = True + """标记此类启用了乐观锁""" + + version: int = 0 + """乐观锁版本号,每次更新自动递增""" diff --git a/sqlmodels/mixin/polymorphic.py b/sqlmodels/mixin/polymorphic.py new file mode 100644 index 0000000..ba67275 --- /dev/null +++ b/sqlmodels/mixin/polymorphic.py @@ -0,0 +1,710 @@ +""" +联表继承(Joined Table Inheritance)的通用工具 + +提供用于简化SQLModel多态表设计的辅助函数和Mixin。 + +Usage Example: + + from sqlmodels.base import SQLModelBase + from sqlmodels.mixin import UUIDTableBaseMixin + from sqlmodels.mixin.polymorphic import ( + PolymorphicBaseMixin, + create_subclass_id_mixin, + AutoPolymorphicIdentityMixin + ) + + # 1. 定义Base类(只有字段,无表) + class ASRBase(SQLModelBase): + name: str + \"\"\"配置名称\"\"\" + + base_url: str + \"\"\"服务地址\"\"\" + + # 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin + class ASR( + ASRBase, + UUIDTableBaseMixin, + PolymorphicBaseMixin, + ABC + ): + \"\"\"ASR配置的抽象基类\"\"\" + # PolymorphicBaseMixin 自动提供: + # - _polymorphic_name 字段 + # - polymorphic_on='_polymorphic_name' + # - polymorphic_abstract=True(当有抽象方法时) + + # 3. 为第二层子类创建ID Mixin + ASRSubclassIdMixin = create_subclass_id_mixin('asr') + + # 4. 创建第二层抽象类(如果需要) + class FunASR( + ASRSubclassIdMixin, + ASR, + AutoPolymorphicIdentityMixin, + polymorphic_abstract=True + ): + \"\"\"FunASR的抽象基类,可能有多个实现\"\"\" + pass + + # 5. 创建具体实现类 + class FunASRLocal(FunASR, table=True): + \"\"\"FunASR本地部署版本\"\"\" + # polymorphic_identity 会自动设置为 'asr.funasrlocal' + pass + + # 6. 获取所有具体子类(用于 selectin_polymorphic) + concrete_asrs = ASR.get_concrete_subclasses() + # 返回 [FunASRLocal, ...] +""" +import uuid +from abc import ABC +from uuid import UUID + +from loguru import logger as l +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined +from sqlalchemy import Column, String, inspect +from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column +from sqlalchemy.orm.attributes import InstrumentedAttribute +from sqlmodel import Field +from sqlmodel.main import get_column_from_field + +from sqlmodels.base.sqlmodel_base import SQLModelBase + +# 用于延迟注册 STI 子类列的队列 +# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理 +_sti_subclasses_to_register: list[type] = [] + + +def register_sti_columns_for_all_subclasses() -> None: + """ + 为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表) + + 此函数应在 configure_mappers() 之前调用。 + 将 STI 子类的字段添加到父表的 metadata 中。 + 同时修复被 Column 对象污染的 model_fields。 + """ + for cls in _sti_subclasses_to_register: + try: + cls._register_sti_columns() + except Exception as e: + l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}") + + # 修复被 Column 对象污染的 model_fields + # 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生 + try: + _fix_polluted_model_fields(cls) + except Exception as e: + l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}") + + +def register_sti_column_properties_for_all_subclasses() -> None: + """ + 为所有已注册的 STI 子类添加列属性到 mapper(第二阶段) + + 此函数应在 configure_mappers() 之后调用。 + 将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。 + """ + for cls in _sti_subclasses_to_register: + try: + cls._register_sti_column_properties() + except Exception as e: + l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}") + + # 清空队列 + _sti_subclasses_to_register.clear() + + +def _fix_polluted_model_fields(cls: type) -> None: + """ + 修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields + + 当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute + 或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields + 时错误地使用这些 SQLAlchemy 对象作为默认值。 + + 此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。 + + :param cls: 要修复的类 + """ + if not hasattr(cls, 'model_fields'): + return + + def find_original_field_info(field_name: str) -> FieldInfo | None: + """从 MRO 中查找字段的原始定义(未被污染的)""" + for base in cls.__mro__[1:]: # 跳过自己 + if hasattr(base, 'model_fields') and field_name in base.model_fields: + field_info = base.model_fields[field_name] + # 跳过被 InstrumentedAttribute 或 Column 污染的 + if not isinstance(field_info.default, (InstrumentedAttribute, Column)): + return field_info + return None + + for field_name, current_field in cls.model_fields.items(): + # 检查是否被污染(default 是 InstrumentedAttribute 或 Column) + # Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True + # 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值 + if not isinstance(current_field.default, (InstrumentedAttribute, Column)): + continue # 未被污染,跳过 + + # 从父类查找原始定义 + original = find_original_field_info(field_name) + if original is None: + continue # 找不到原始定义,跳过 + + # 根据原始定义的 default/default_factory 来修复 + if original.default_factory: + # 有 default_factory(如 uuid.uuid4, now) + new_field = FieldInfo( + default_factory=original.default_factory, + annotation=current_field.annotation, + json_schema_extra=current_field.json_schema_extra, + ) + elif original.default is not PydanticUndefined: + # 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined + # PydanticUndefined 表示字段没有默认值(必填) + new_field = FieldInfo( + default=original.default, + annotation=current_field.annotation, + json_schema_extra=current_field.json_schema_extra, + ) + else: + continue # 既没有 default_factory 也没有有效的 default,跳过 + + # 复制 SQLModel 特有的属性 + if hasattr(current_field, 'foreign_key'): + new_field.foreign_key = current_field.foreign_key + if hasattr(current_field, 'primary_key'): + new_field.primary_key = current_field.primary_key + + cls.model_fields[field_name] = new_field + + +def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']: + """ + 动态创建SubclassIdMixin类 + + 在联表继承中,子类需要一个外键指向父表的主键。 + 此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。 + + Args: + parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function') + + Returns: + 一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4) + + Example: + >>> ASRSubclassIdMixin = create_subclass_id_mixin('asr') + >>> class FunASR(ASRSubclassIdMixin, ASR, table=True): + ... pass + + Note: + - 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id + - 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase) + - 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin) + """ + if not parent_table_name: + raise ValueError("parent_table_name 不能为空") + + # 转换为PascalCase作为类名 + class_name_parts = parent_table_name.split('_') + class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin' + + # 使用闭包捕获parent_table_name + _parent_table_name = parent_table_name + + # 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields + class SubclassIdMixin(SQLModelBase): + # 定义id字段 + id: UUID = Field( + default_factory=uuid.uuid4, + foreign_key=f'{_parent_table_name}.id', + primary_key=True, + ) + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs): + """ + Pydantic v2 的子类初始化钩子,在模型完全构建后调用 + + 修复联表继承中子类字段的 default_factory 丢失问题。 + SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段, + 导致 INSERT 语句中出现 `table.column` 引用而非实际值。 + """ + super().__pydantic_init_subclass__(**kwargs) + _fix_polluted_model_fields(cls) + + # 设置类名和文档 + SubclassIdMixin.__name__ = class_name + SubclassIdMixin.__qualname__ = class_name + SubclassIdMixin.__doc__ = f""" + {parent_table_name}子类的ID Mixin + + 用于{parent_table_name}的子类,提供外键指向父表。 + 通过MRO确保此id字段覆盖继承的id字段。 + """ + + return SubclassIdMixin + + +class AutoPolymorphicIdentityMixin: + """ + 自动生成polymorphic_identity的Mixin,并支持STI子类列注册 + + 使用此Mixin的类会自动根据类名生成polymorphic_identity。 + 格式:{parent_polymorphic_identity}.{classname_lowercase} + + 如果没有父类的polymorphic_identity,则直接使用类名小写。 + + **重要:数据库迁移注意事项** + + 编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀! + + 例如,对于以下继承链:: + + LLM (polymorphic_on='_polymorphic_name') + └── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm') + └── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm') + + 迁移脚本中设置 _polymorphic_name 时:: + + # ❌ 错误:缺少父类前缀 + UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id + + # ✅ 正确:包含完整的继承链前缀 + UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id + + **STI(单表继承)支持**: + 当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段 + 添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被 + 注册到父表的问题。 + + Example (JTI): + >>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True): + ... __polymorphic_name: str + ... + >>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True): + ... pass + ... # polymorphic_identity 会自动设置为 'function' + ... + >>> class CodeInterpreterFunction(Function, table=True): + ... pass + ... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction' + + Example (STI): + >>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True): + ... user_id: UUID + ... + >>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True): + ... upload_deadline: datetime | None = None # 自动添加到 userfile 表 + ... # polymorphic_identity 会自动设置为 'pendingfile' + + Note: + - 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留 + - 此Mixin应该在继承列表中靠后的位置(在表基类之前) + - STI模式下,新字段会在类定义时自动添加到父表的metadata中 + """ + + def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs): + """ + 子类化钩子,自动生成polymorphic_identity并处理STI列注册 + + Args: + polymorphic_identity: 如果手动指定,则使用指定的值 + **kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True) + """ + super().__init_subclass__(**kwargs) + + # 如果手动指定了polymorphic_identity,使用指定的值 + if polymorphic_identity is not None: + identity = polymorphic_identity + else: + # 自动生成polymorphic_identity + class_name = cls.__name__.lower() + + # 尝试从父类获取polymorphic_identity作为前缀 + parent_identity = None + for base in cls.__mro__[1:]: # 跳过自己 + if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict): + parent_identity = base.__mapper_args__.get('polymorphic_identity') + if parent_identity: + break + + # 构建identity + if parent_identity: + identity = f'{parent_identity}.{class_name}' + else: + identity = class_name + + # 设置到__mapper_args__ + if '__mapper_args__' not in cls.__dict__: + cls.__mapper_args__ = {} + + # 只在尚未设置polymorphic_identity时设置 + if 'polymorphic_identity' not in cls.__mapper_args__: + cls.__mapper_args__['polymorphic_identity'] = identity + + # 注册 STI 子类列的延迟执行 + # 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整 + # 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses() + _sti_subclasses_to_register.append(cls) + + @classmethod + def __pydantic_init_subclass__(cls, **kwargs): + """ + Pydantic v2 的子类初始化钩子,在模型完全构建后调用 + + 修复 STI 继承中子类字段被 Column 对象污染的问题。 + 当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时, + SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。 + """ + super().__pydantic_init_subclass__(**kwargs) + _fix_polluted_model_fields(cls) + + @classmethod + def _register_sti_columns(cls) -> None: + """ + 将STI子类的新字段注册到父表的列定义中 + + 检测当前类是否是STI子类(与父类共用同一张表), + 如果是,则将子类定义的新字段添加到父表的metadata中。 + + JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。 + """ + # 查找父表(在 MRO 中找到第一个有 __table__ 的父类) + parent_table = None + parent_fields: set[str] = set() + + for base in cls.__mro__[1:]: + if hasattr(base, '__table__') and base.__table__ is not None: + parent_table = base.__table__ + # 收集父类的所有字段名 + if hasattr(base, 'model_fields'): + parent_fields.update(base.model_fields.keys()) + break + + if parent_table is None: + return # 没有找到父表,可能是根类 + + # JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI + # JTI 类有自己独立的表,不需要将列注册到父表 + if hasattr(cls, '__table__') and cls.__table__ is not None: + if cls.__table__.name != parent_table.name: + return # JTI,跳过 STI 列注册 + + # 获取当前类的新字段(不在父类中的字段) + if not hasattr(cls, 'model_fields'): + return + + existing_columns = {col.name for col in parent_table.columns} + + for field_name, field_info in cls.model_fields.items(): + # 跳过从父类继承的字段 + if field_name in parent_fields: + continue + + # 跳过私有字段和ClassVar + if field_name.startswith('_'): + continue + + # 跳过已存在的列 + if field_name in existing_columns: + continue + + # 使用 SQLModel 的内置 API 创建列 + try: + column = get_column_from_field(field_info) + column.name = field_name + column.key = field_name + # STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值 + # Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段) + column.nullable = True + + # 将列添加到父表 + parent_table.append_column(column) + except Exception as e: + l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}") + + @classmethod + def _register_sti_column_properties(cls) -> None: + """ + 将 STI 子类的列作为 ColumnProperty 添加到 mapper + + 此方法在 configure_mappers() 之后调用,将已添加到表中的列 + 注册为 mapper 的属性,使 ORM 查询能正确识别这些列。 + + **重要**:子类的列属性会同时注册到子类和父类的 mapper 上。 + 这确保了查询父类时,SELECT 语句包含所有 STI 子类的列, + 避免在响应序列化时触发懒加载(MissingGreenlet 错误)。 + + JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。 + """ + # 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类) + parent_table = None + parent_class = None + for base in cls.__mro__[1:]: + if hasattr(base, '__table__') and base.__table__ is not None: + parent_table = base.__table__ + parent_class = base + break + + if parent_table is None: + return # 没有找到父表,可能是根类 + + # JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI + # JTI 类有自己独立的表,不需要将列属性注册到 mapper + if hasattr(cls, '__table__') and cls.__table__ is not None: + if cls.__table__.name != parent_table.name: + return # JTI,跳过 STI 列属性注册 + + # 获取子类和父类的 mapper + child_mapper = inspect(cls).mapper + parent_mapper = inspect(parent_class).mapper + local_table = child_mapper.local_table + + # 查找父类的所有字段名 + parent_fields: set[str] = set() + if hasattr(parent_class, 'model_fields'): + parent_fields.update(parent_class.model_fields.keys()) + + if not hasattr(cls, 'model_fields'): + return + + # 获取两个 mapper 已有的列属性 + child_existing_props = {p.key for p in child_mapper.column_attrs} + parent_existing_props = {p.key for p in parent_mapper.column_attrs} + + for field_name in cls.model_fields: + # 跳过从父类继承的字段 + if field_name in parent_fields: + continue + + # 跳过私有字段 + if field_name.startswith('_'): + continue + + # 检查表中是否有这个列 + if field_name not in local_table.columns: + continue + + column = local_table.columns[field_name] + + # 添加到子类的 mapper(如果尚不存在) + if field_name not in child_existing_props: + try: + prop = ColumnProperty(column) + child_mapper.add_property(field_name, prop) + except Exception as e: + l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}") + + # 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列) + if field_name not in parent_existing_props: + try: + prop = ColumnProperty(column) + parent_mapper.add_property(field_name, prop) + except Exception as e: + l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}") + + +class PolymorphicBaseMixin: + """ + 为联表继承链中的基类自动配置 polymorphic 设置的 Mixin + + 此 Mixin 自动设置以下内容: + - `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器 + - `_polymorphic_name: str`: 定义多态鉴别器字段(带索引) + - `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类 + + 使用场景: + 适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。 + + 用法示例: + ```python + from abc import ABC + from sqlmodels.mixin import UUIDTableBaseMixin + from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin + + # 定义基类 + class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC): + __tablename__ = 'mytool' + + # 不需要手动定义 _polymorphic_name + # 不需要手动设置 polymorphic_on + # 不需要手动设置 polymorphic_abstract + + # 定义子类 + class SpecificTool(MyTool): + __tablename__ = 'specifictool' + + # 会自动继承 polymorphic 配置 + ``` + + 自动行为: + 1. 定义 `_polymorphic_name: str` 字段(带索引) + 2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'` + 3. 自动检测抽象类: + - 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True + - 否则设置为 False + + 手动覆盖: + 可以在类定义时手动指定参数来覆盖自动行为: + ```python + class MyTool( + UUIDTableBaseMixin, + PolymorphicBaseMixin, + ABC, + polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name + polymorphic_abstract=False # 强制不设为抽象类 + ): + pass + ``` + + 注意事项: + - 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用 + - 适用于联表继承(joined table inheritance)场景 + - 子类会自动继承 _polymorphic_name 字段定义 + - 使用单下划线前缀是因为: + * SQLAlchemy 会映射单下划线字段为数据库列 + * Pydantic 将其视为私有属性,不参与序列化 + * 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列 + """ + + # 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段 + # + # 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column + # + # 为什么这样做: + # 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改 + # 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀) + # 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用 + # 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema) + # 5. 内部代码仍然可以正常访问和修改此字段 + # + # 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md + _polymorphic_name: Mapped[str] = mapped_column(String, index=True) + """ + 多态鉴别器字段,用于标识具体的子类类型 + + 注意:此字段使用单下划线前缀,表示内部使用。 + - ✅ 存储到数据库 + - ✅ 不出现在 API 序列化中 + - ✅ 防止外部直接修改 + """ + + def __init_subclass__( + cls, + polymorphic_on: str | None = None, + polymorphic_abstract: bool | None = None, + **kwargs + ): + """ + 在子类定义时自动配置 polymorphic 设置 + + Args: + polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。 + 设置为其他值可以使用不同的字段作为多态鉴别器。 + polymorphic_abstract: 是否为抽象类。 + - None: 自动检测(默认) + - True: 强制设为抽象类 + - False: 强制设为非抽象类 + **kwargs: 传递给父类的其他参数 + """ + super().__init_subclass__(**kwargs) + + # 初始化 __mapper_args__(如果还没有) + if '__mapper_args__' not in cls.__dict__: + cls.__mapper_args__ = {} + + # 设置 polymorphic_on(默认为 _polymorphic_name) + if 'polymorphic_on' not in cls.__mapper_args__: + cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name' + + # 自动检测或设置 polymorphic_abstract + if 'polymorphic_abstract' not in cls.__mapper_args__: + if polymorphic_abstract is None: + # 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类 + has_abc = ABC in cls.__mro__ + has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set())) + polymorphic_abstract = has_abc and has_abstract_methods + + cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract + + @classmethod + def _is_joined_table_inheritance(cls) -> bool: + """ + 检测当前类是否使用联表继承(Joined Table Inheritance) + + 通过检查子类是否有独立的表来判断: + - JTI: 子类有独立的 local_table(与父类不同) + - STI: 子类与父类共用同一个 local_table + + :return: True 表示 JTI,False 表示 STI 或无子类 + """ + mapper = inspect(cls) + base_table_name = mapper.local_table.name + + # 检查所有直接子类 + for subclass in cls.__subclasses__(): + sub_mapper = inspect(subclass) + # 如果任何子类有不同的表名,说明是 JTI + if sub_mapper.local_table.name != base_table_name: + return True + + return False + + @classmethod + def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]: + """ + 递归获取当前类的所有具体(非抽象)子类 + + 用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。 + 可在任意多态基类上调用,返回该类的所有非抽象子类。 + + :return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类) + """ + result: list[type[PolymorphicBaseMixin]] = [] + for subclass in cls.__subclasses__(): + # 使用 inspect() 获取 mapper 的公开属性 + # 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811) + mapper = inspect(subclass) + if not mapper.polymorphic_abstract: + result.append(subclass) + # 无论是否抽象,都需要递归(抽象类可能有具体子类) + if hasattr(subclass, 'get_concrete_subclasses'): + result.extend(subclass.get_concrete_subclasses()) + return result + + @classmethod + def get_polymorphic_discriminator(cls) -> str: + """ + 获取多态鉴别字段名 + + 使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。 + + :return: 多态鉴别字段名(如 '_polymorphic_name') + :raises ValueError: 如果类未配置 polymorphic_on + """ + polymorphic_on = inspect(cls).polymorphic_on + if polymorphic_on is None: + raise ValueError( + f"{cls.__name__} 未配置 polymorphic_on," + f"请确保正确继承 PolymorphicBaseMixin" + ) + return polymorphic_on.key + + @classmethod + def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]: + """ + 获取 polymorphic_identity 到具体子类的映射 + + 包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。 + + :return: identity 到子类的映射字典 + """ + result: dict[str, type[PolymorphicBaseMixin]] = {} + for subclass in cls.get_concrete_subclasses(): + identity = inspect(subclass).polymorphic_identity + if identity: + result[identity] = subclass + return result diff --git a/sqlmodels/mixin/relation_preload.py b/sqlmodels/mixin/relation_preload.py new file mode 100644 index 0000000..624018f --- /dev/null +++ b/sqlmodels/mixin/relation_preload.py @@ -0,0 +1,470 @@ +""" +关系预加载 Mixin + +提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。 + +设计原则: +- 按需加载:只加载被调用方法需要的关系 +- 增量加载:已加载的关系不重复加载 +- 查询最优:相同关系只查询一次,不同关系增量查询 +- 零侵入:调用方无需任何改动 +- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire + +使用方式: + from sqlmodels.mixin import RelationPreloadMixin, requires_relations + + class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True): + kling_video_generator: KlingO1Generator = Relationship(...) + + @requires_relations('kling_video_generator', KlingO1Generator.kling_o1) + async def cost(self, params, context, session) -> ToolCost: + # 自动加载,可以安全访问 + price = self.kling_video_generator.kling_o1.pro_price_per_second + ... + + # 调用方 - 无需任何改动 + await tool.cost(params, context, session) # 自动加载 cost 需要的关系 + await tool._call(...) # 关系相同则跳过,否则增量加载 + +支持 AsyncGenerator: + @requires_relations('twitter_api') + async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]: + yield ToolResponse(...) # 装饰器正确处理 async generator +""" +import inspect as python_inspect +from functools import wraps +from typing import Callable, TypeVar, ParamSpec, Any + +from loguru import logger as l +from sqlalchemy import inspect as sa_inspect +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlmodel.main import RelationshipInfo + +P = ParamSpec('P') +R = TypeVar('R') + + +def _extract_session( + func: Callable, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> AsyncSession | None: + """ + 从方法参数中提取 AsyncSession + + 按以下顺序查找: + 1. kwargs 中名为 'session' 的参数 + 2. 根据函数签名定位 'session' 参数的位置,从 args 提取 + 3. kwargs 中类型为 AsyncSession 的参数 + """ + # 1. 优先从 kwargs 查找 + if 'session' in kwargs: + return kwargs['session'] + + # 2. 从函数签名定位位置参数 + try: + sig = python_inspect.signature(func) + param_names = list(sig.parameters.keys()) + + if 'session' in param_names: + # 计算位置(减去 self) + idx = param_names.index('session') - 1 + if 0 <= idx < len(args): + return args[idx] + except (ValueError, TypeError): + pass + + # 3. 遍历 kwargs 找 AsyncSession 类型 + for value in kwargs.values(): + if isinstance(value, AsyncSession): + return value + + return None + + +def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool: + """ + 检查对象的关系是否已加载(独立函数版本) + + Args: + obj: 要检查的对象 + rel_name: 关系属性名 + + Returns: + True 如果关系已加载,False 如果未加载或已过期 + """ + try: + state = sa_inspect(obj) + return rel_name not in state.unloaded + except Exception: + return False + + +def _find_relation_to_class(from_class: type, to_class: type) -> str | None: + """ + 在类中查找指向目标类的关系属性名 + + Args: + from_class: 源类 + to_class: 目标类 + + Returns: + 关系属性名,如果找不到则返回 None + + Example: + _find_relation_to_class(KlingO1VideoFunction, KlingO1Generator) + # 返回 'kling_video_generator' + """ + for attr_name in dir(from_class): + try: + attr = getattr(from_class, attr_name, None) + if attr is None: + continue + # 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性) + # parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类 + if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'): + target_class = attr.property.mapper.class_ + if target_class == to_class: + return attr_name + except AttributeError: + continue + return None + + +def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]: + """ + 装饰器:声明方法需要的关系,自动按需增量加载 + + 参数格式: + - 字符串:本类属性名,如 'kling_video_generator' + - RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1 + + 行为: + - 方法调用时自动检查关系是否已加载 + - 未加载的关系会被增量加载(单次查询) + - 已加载的关系直接跳过 + + 支持: + - 普通 async 方法:`async def cost(...) -> ToolCost` + - AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]` + + Example: + @requires_relations('kling_video_generator', KlingO1Generator.kling_o1) + async def cost(self, params, context, session) -> ToolCost: + # self.kling_video_generator.kling_o1 已自动加载 + ... + + @requires_relations('twitter_api') + async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]: + yield ToolResponse(...) # AsyncGenerator 正确处理 + + 验证: + - 字符串格式的关系名在类创建时(__init_subclass__)验证 + - 拼写错误会在导入时抛出 AttributeError + """ + def decorator(func: Callable[P, R]) -> Callable[P, R]: + # 检测是否是 async generator 函数 + is_async_gen = python_inspect.isasyncgenfunction(func) + + if is_async_gen: + # AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator + @wraps(func) + async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R: + session = _extract_session(func, args, kwargs) + if session is not None: + await self._ensure_relations_loaded(session, relations) + # 委托给原始 async generator,逐个 yield 值 + async for item in func(self, *args, **kwargs): + yield item # type: ignore + else: + # 普通 async 函数:await 并返回结果 + @wraps(func) + async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R: + session = _extract_session(func, args, kwargs) + if session is not None: + await self._ensure_relations_loaded(session, relations) + return await func(self, *args, **kwargs) + + # 保存关系声明供验证和内省使用 + wrapper._required_relations = relations # type: ignore + return wrapper + + return decorator + + +class RelationPreloadMixin: + """ + 关系预加载 Mixin + + 提供按需增量加载能力,确保 SQL 查询数理论最优。 + + 特性: + - 按需加载:只加载被调用方法需要的关系 + - 增量加载:已加载的关系不重复加载 + - 原地更新:直接修改 self,无需替换实例 + - 导入时验证:字符串关系名在类创建时验证 + - Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire + """ + + def __init_subclass__(cls, **kwargs) -> None: + """类创建时验证所有 @requires_relations 声明""" + super().__init_subclass__(**kwargs) + + # 收集类及其父类的所有注解(包含普通字段) + all_annotations: set[str] = set() + for klass in cls.__mro__: + if hasattr(klass, '__annotations__'): + all_annotations.update(klass.__annotations__.keys()) + + # 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__) + sqlmodel_relationships: set[str] = set() + for klass in cls.__mro__: + if hasattr(klass, '__sqlmodel_relationships__'): + sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys()) + + # 合并所有可用的属性名 + all_available_names = all_annotations | sqlmodel_relationships + + for method_name in dir(cls): + if method_name.startswith('__'): + continue + + try: + method = getattr(cls, method_name, None) + except AttributeError: + continue + + if method is None or not hasattr(method, '_required_relations'): + continue + + # 验证字符串格式的关系名 + for spec in method._required_relations: + if isinstance(spec, str): + # 检查注解、Relationship 或已有属性 + if spec not in all_available_names and not hasattr(cls, spec): + raise AttributeError( + f"{cls.__name__}.{method_name} 声明了关系 '{spec}'," + f"但 {cls.__name__} 没有此属性" + ) + + def _is_relation_loaded(self, rel_name: str) -> bool: + """ + 检查关系是否真正已加载(基于 SQLAlchemy inspect) + + 使用 SQLAlchemy 的 inspect 检测真实加载状态, + 自动处理 commit 导致的 expire 问题。 + + Args: + rel_name: 关系属性名 + + Returns: + True 如果关系已加载,False 如果未加载或已过期 + """ + try: + state = sa_inspect(self) + # unloaded 包含未加载的关系属性名 + return rel_name not in state.unloaded + except Exception: + # 对象可能未被 SQLAlchemy 管理 + return False + + async def _ensure_relations_loaded( + self, + session: AsyncSession, + relations: tuple[str | RelationshipInfo, ...], + ) -> None: + """ + 确保指定关系已加载,只加载未加载的部分 + + 基于 SQLAlchemy inspect 检测真实状态,自动处理: + - 首次访问的关系 + - commit 后 expire 的关系 + - 嵌套关系(如 KlingO1Generator.kling_o1) + + Args: + session: 数据库会话 + relations: 需要的关系规格 + """ + # 找出真正未加载的关系(基于 SQLAlchemy inspect) + to_load: list[str | RelationshipInfo] = [] + # 区分直接关系和嵌套关系的 key + direct_keys: set[str] = set() # 本类的直接关系属性名 + nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名 + + for rel in relations: + if isinstance(rel, str): + # 直接关系:检查本类的关系是否已加载 + if not self._is_relation_loaded(rel): + to_load.append(rel) + direct_keys.add(rel) + else: + # 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1 + # 1. 查找指向父类的关系属性 + parent_class = rel.parent.class_ + parent_attr = _find_relation_to_class(self.__class__, parent_class) + + if parent_attr is None: + # 找不到路径,可能是配置错误,但仍尝试加载 + l.warning( + f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径," + f"无法检查 {rel.key} 是否已加载" + ) + to_load.append(rel) + continue + + # 2. 检查父对象是否已加载 + if not self._is_relation_loaded(parent_attr): + # 父对象未加载,需要同时加载父对象和嵌套关系 + if parent_attr not in direct_keys and parent_attr not in nested_parent_keys: + to_load.append(parent_attr) + nested_parent_keys.add(parent_attr) + to_load.append(rel) + else: + # 3. 父对象已加载,检查嵌套关系是否已加载 + parent_obj = getattr(self, parent_attr) + if not _is_obj_relation_loaded(parent_obj, rel.key): + # 嵌套关系未加载:需要同时传递父关系和嵌套关系 + # 因为 _build_load_chains 需要完整的链来构建 selectinload + if parent_attr not in direct_keys and parent_attr not in nested_parent_keys: + to_load.append(parent_attr) + nested_parent_keys.add(parent_attr) + to_load.append(rel) + + if not to_load: + return # 全部已加载,跳过 + + # 构建 load 参数 + load_options = self._specs_to_load_options(to_load) + if not load_options: + return + + # 安全地获取主键值(避免触发懒加载) + state = sa_inspect(self) + pk_tuple = state.key[1] if state.key else None + if pk_tuple is None: + l.warning(f"无法获取 {self.__class__.__name__} 的主键值") + return + # 主键是元组,取第一个值(假设单列主键) + pk_value = pk_tuple[0] + + # 单次查询加载缺失的关系 + fresh = await self.__class__.get( + session, + self.__class__.id == pk_value, + load=load_options, + ) + + if fresh is None: + l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在") + return + + # 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问) + all_direct_keys = direct_keys | nested_parent_keys + for key in all_direct_keys: + value = getattr(fresh, key, None) + object.__setattr__(self, key, value) + + def _specs_to_load_options( + self, + specs: list[str | RelationshipInfo], + ) -> list[RelationshipInfo]: + """ + 将关系规格转换为 load 参数 + + - 字符串 → cls.{name} + - RelationshipInfo → 直接使用 + """ + result: list[RelationshipInfo] = [] + + for spec in specs: + if isinstance(spec, str): + rel = getattr(self.__class__, spec, None) + if rel is not None: + result.append(rel) + else: + l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在") + else: + result.append(spec) + + return result + + # ==================== 可选的手动预加载 API ==================== + + @classmethod + def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]: + """ + 获取指定方法声明的关系(用于外部预加载场景) + + Args: + method_name: 方法名 + + Returns: + RelationshipInfo 列表 + """ + method = getattr(cls, method_name, None) + if method is None or not hasattr(method, '_required_relations'): + return [] + + result: list[RelationshipInfo] = [] + for spec in method._required_relations: + if isinstance(spec, str): + rel = getattr(cls, spec, None) + if rel: + result.append(rel) + else: + result.append(spec) + + return result + + @classmethod + def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]: + """ + 获取多个方法的关系并去重(用于批量预加载场景) + + Args: + method_names: 方法名列表 + + Returns: + 去重后的 RelationshipInfo 列表 + """ + seen: set[str] = set() + result: list[RelationshipInfo] = [] + + for method_name in method_names: + for rel in cls.get_relations_for_method(method_name): + key = rel.key + if key not in seen: + seen.add(key) + result.append(rel) + + return result + + async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin': + """ + 手动预加载指定方法的关系(可选优化 API) + + 当需要确保在调用方法前完成所有加载时使用。 + 通常情况下不需要调用此方法,装饰器会自动处理。 + + Args: + session: 数据库会话 + method_names: 方法名列表 + + Returns: + self(支持链式调用) + + Example: + # 可选:显式预加载(通常不需要) + tool = await tool.preload_for(session, 'cost', '_call') + """ + all_relations: list[str | RelationshipInfo] = [] + + for method_name in method_names: + method = getattr(self.__class__, method_name, None) + if method and hasattr(method, '_required_relations'): + all_relations.extend(method._required_relations) + + if all_relations: + await self._ensure_relations_loaded(session, tuple(all_relations)) + + return self diff --git a/models/mixin/table.py b/sqlmodels/mixin/table.py similarity index 64% rename from models/mixin/table.py rename to sqlmodels/mixin/table.py index 419f298..cd6c830 100644 --- a/models/mixin/table.py +++ b/sqlmodels/mixin/table.py @@ -12,7 +12,14 @@ mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin ↓ base/__init__.py ← 从 mixin 重新导出(保持向后兼容) + +维护须知: + 增删功能时必须更新 __version__ 字段(遵循语义化版本) + +版本历史: + 0.1.0 - delete() 方法支持条件删除(condition 参数) """ +__version__ = "0.1.0" import uuid from datetime import datetime from typing import TypeVar, Literal, override, Any, ClassVar, Generic @@ -26,16 +33,19 @@ from typing import TypeVar, Literal, override, Any, ClassVar, Generic # 未来: PR #1275合并后可改回继承SQLModelBase from pydantic import BaseModel, ConfigDict from fastapi import HTTPException -from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct +from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct, delete as sql_delete, inspect from sqlalchemy.orm import selectinload, Relationship, with_polymorphic +from sqlalchemy.orm.exc import StaleDataError from sqlmodel import Field, select + +from .optimistic_lock import OptimisticLockError from sqlmodel.ext.asyncio.session import AsyncSession from sqlalchemy.sql._typing import _OnClauseArgument from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel.main import RelationshipInfo from .polymorphic import PolymorphicBaseMixin -from models.base.sqlmodel_base import SQLModelBase +from sqlmodels.base.sqlmodel_base import SQLModelBase # Type variables for generic type hints, improving code completion and analysis. T = TypeVar("T", bound="TableBaseMixin") @@ -196,8 +206,8 @@ class TableBaseMixin(AsyncAttrs): created_at (datetime): 记录创建时的时间戳, 自动设置. updated_at (datetime): 记录每次更新时的时间戳, 自动更新. """ - _is_table_mixin: ClassVar[bool] = True - """标记此类为表混入类的内部属性""" + _has_table_mixin: ClassVar[bool] = True + """标记此类继承了表混入类的内部属性""" def __init_subclass__(cls, **kwargs): """ @@ -218,7 +228,7 @@ class TableBaseMixin(AsyncAttrs): ) @classmethod - async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]: + async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]: """ 向数据库中添加一个新的或多个新的记录. @@ -230,8 +240,6 @@ class TableBaseMixin(AsyncAttrs): session (AsyncSession): 用于数据库操作的异步会话对象. instances (T | list[T]): 要添加的单个模型实例或模型实例列表. refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True. - commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数, - 之后需要手动调用 `session.commit()`。默认为 True. Returns: T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例. @@ -246,11 +254,6 @@ class TableBaseMixin(AsyncAttrs): # 添加单个实例 item3 = Item(name="Cherry") added_item = await Item.add(session, item3) - - # 批量操作,减少提交次数 - await Item.add(session, [item1, item2], commit=False) - await Item.add(session, [item3, item4], commit=False) - await session.commit() """ is_list = False if isinstance(instances, list): @@ -259,10 +262,7 @@ class TableBaseMixin(AsyncAttrs): else: session.add(instances) - if commit: - await session.commit() - else: - await session.flush() + await session.commit() if refresh: if is_list: @@ -278,14 +278,16 @@ class TableBaseMixin(AsyncAttrs): session: AsyncSession, load: RelationshipInfo | list[RelationshipInfo] | None = None, refresh: bool = True, - commit: bool = True + commit: bool = True, + jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, + optimistic_retry_count: int = 0, ) -> T: """ 保存(插入或更新)当前模型实例到数据库. 这是一个实例方法,它将当前对象添加到会话中并提交更改。 可以用于创建新记录或更新现有记录。还可以选择在保存后 - 预加载(eager load)一个或多个关联关系. + 预加载(eager load)一个关联关系. **重要**:调用此方法后,session中的所有对象都会过期(expired)。 如果需要继续使用该对象,必须使用返回值: @@ -298,13 +300,17 @@ class TableBaseMixin(AsyncAttrs): # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 await client.save(session, refresh=False) - # ✅ 正确:批量操作,减少提交次数 - await item1.save(session, commit=False) - await item2.save(session, commit=False) + # ✅ 正确:批量操作时延迟提交 + for item in items: + item = await item.save(session, commit=False) await session.commit() - # ✅ 正确:批量操作并预加载多个关联关系 - user = await user.save(session, load=[User.group, User.tags]) + # ✅ 正确:保存后需要访问多态关系时 + tool_set = await tool_set.save(session, load=ToolSet.tools, jti_subclasses='all') + return tool_set # tools 关系已正确加载子类数据 + + # ✅ 正确:启用乐观锁自动重试 + order = await order.save(session, optimistic_retry_count=3) # ❌ 错误:需要返回值但未使用 await client.save(session) @@ -313,34 +319,77 @@ class TableBaseMixin(AsyncAttrs): Args: session (AsyncSession): 用于数据库操作的异步会话对象. - load (Relationship | list[Relationship] | None): 可选的,指定在保存和刷新后要预加载的关联属性. - 可以是单个关系或关系列表. - 例如 `User.posts` 或 `[User.group, User.tags]`. + load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性. + 例如 `User.posts`. refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值, 设为 False 可节省一次数据库查询。默认为 True. - commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数, - 之后需要手动调用 `session.commit()`。默认为 True. + commit (bool): 是否在保存后提交事务。如果为 False,只会 flush 获取 ID + 但不提交,适用于批量操作场景。默认为 True. + jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。 + - list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表 + - 'all': 两阶段查询,只加载实际关联的子类 + - None(默认): 不使用多态加载 + optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0(不重试)。 + 重试时会重新查询最新数据,将当前修改合并后再次保存。 Returns: T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self. + + Raises: + OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽 """ - session.add(self) - if commit: - await session.commit() - else: - await session.flush() + cls = type(self) + instance = self + retries_remaining = optimistic_retry_count + current_data: dict[str, Any] | None = None # 延迟计算,仅在需要重试时 + + while True: + session.add(instance) + try: + if commit: + await session.commit() + else: + await session.flush() + break # 成功,退出循环 + except StaleDataError as e: + await session.rollback() + if retries_remaining <= 0: + raise OptimisticLockError( + message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改", + model_class=cls.__name__, + record_id=str(getattr(instance, 'id', None)), + expected_version=getattr(instance, 'version', None), + original_error=e, + ) from e + + # 失败后重试:重新查询最新数据并合并修改 + retries_remaining -= 1 + if current_data is None: + current_data = self.model_dump(exclude={'id', 'version', 'created_at', 'updated_at'}) + + fresh = await cls.get(session, cls.id == self.id) + if fresh is None: + raise OptimisticLockError( + message=f"{cls.__name__} 重试失败:记录已被删除", + model_class=cls.__name__, + record_id=str(getattr(self, 'id', None)), + original_error=e, + ) from e + + for key, value in current_data.items(): + if hasattr(fresh, key): + setattr(fresh, key, value) + instance = fresh if not refresh: - return self + return instance if load is not None: - cls = type(self) - await session.refresh(self) - # 如果指定了 load, 重新获取实例并加载关联关系 - return await cls.get(session, cls.id == self.id, load=load) + await session.refresh(instance) + return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses) else: - await session.refresh(self) - return self + await session.refresh(instance) + return instance async def update( self: T, @@ -351,7 +400,9 @@ class TableBaseMixin(AsyncAttrs): exclude: set[str] | None = None, load: RelationshipInfo | list[RelationshipInfo] | None = None, refresh: bool = True, - commit: bool = True + commit: bool = True, + jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, + optimistic_retry_count: int = 0, ) -> T: """ 使用另一个模型实例或字典中的数据来更新当前实例. @@ -371,16 +422,20 @@ class TableBaseMixin(AsyncAttrs): user = await user.update(session, update_data, load=User.permission) return user + # ✅ 正确:更新后需要访问多态关系时 + tool_set = await tool_set.update(session, data, load=ToolSet.tools, jti_subclasses='all') + return tool_set # tools 关系已正确加载子类数据 + # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 await client.update(session, update_data, refresh=False) - # ✅ 正确:批量操作,减少提交次数 - await user1.update(session, data1, commit=False) - await user2.update(session, data2, commit=False) + # ✅ 正确:批量操作时延迟提交 + for item in items: + item = await item.update(session, data, commit=False) await session.commit() - # ✅ 正确:批量操作并预加载多个关联关系 - user = await user.update(session, data, load=[User.group, User.tags]) + # ✅ 正确:启用乐观锁自动重试 + order = await order.update(session, update_data, optimistic_retry_count=3) # ❌ 错误:需要返回值但未使用 await client.update(session, update_data) @@ -394,111 +449,134 @@ class TableBaseMixin(AsyncAttrs): exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供) 的字段将被忽略. 默认为 True. exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}. - load (Relationship | list[Relationship] | None): 可选的,指定在更新和刷新后要预加载的关联属性. - 可以是单个关系或关系列表. - 例如 `User.permission` 或 `[User.group, User.tags]`. + load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性. + 例如 `User.permission`. refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值, 设为 False 可节省一次数据库查询。默认为 True. - commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数, - 之后需要手动调用 `session.commit()`。默认为 True. + commit (bool): 是否在更新后提交事务。如果为 False,只会 flush + 但不提交,适用于批量操作场景。默认为 True. + jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。 + - list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表 + - 'all': 两阶段查询,只加载实际关联的子类 + - None(默认): 不使用多态加载 + optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0(不重试)。 + 重试时会重新查询最新数据,将 other 的更新重新应用后再次保存。 Returns: T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self. - """ - self.sqlmodel_update( - other.model_dump(exclude_unset=exclude_unset, exclude=exclude), - update=extra_data - ) - session.add(self) - if commit: - await session.commit() - else: - await session.flush() + Raises: + OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽 + """ + cls = type(self) + update_data = other.model_dump(exclude_unset=exclude_unset, exclude=exclude) + instance = self + retries_remaining = optimistic_retry_count + + while True: + instance.sqlmodel_update(update_data, update=extra_data) + session.add(instance) + + try: + if commit: + await session.commit() + else: + await session.flush() + break # 成功,退出循环 + except StaleDataError as e: + await session.rollback() + if retries_remaining <= 0: + raise OptimisticLockError( + message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改", + model_class=cls.__name__, + record_id=str(getattr(instance, 'id', None)), + expected_version=getattr(instance, 'version', None), + original_error=e, + ) from e + + # 失败后重试:重新查询最新数据并重新应用更新 + retries_remaining -= 1 + fresh = await cls.get(session, cls.id == self.id) + if fresh is None: + raise OptimisticLockError( + message=f"{cls.__name__} 重试失败:记录已被删除", + model_class=cls.__name__, + record_id=str(getattr(self, 'id', None)), + original_error=e, + ) from e + instance = fresh if not refresh: - return self + return instance if load is not None: - cls = type(self) - await session.refresh(self) - return await cls.get(session, cls.id == self.id, load=load) + await session.refresh(instance) + return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses) else: - await session.refresh(self) - return self + await session.refresh(instance) + return instance @classmethod async def delete( - cls: type[T], - session: AsyncSession, - instances: T | list[T] | None = None, - *, - condition: BinaryExpression | ClauseElement | None = None, - commit: bool = True + cls: type[T], + session: AsyncSession, + instances: T | list[T] | None = None, + *, + condition: BinaryExpression | ClauseElement | None = None, + commit: bool = True, ) -> int: """ - 从数据库中删除记录. - - 支持两种删除方式: - 1. 实例删除:传入 instances 参数,先加载再删除 - 2. 条件删除:传入 condition 参数,直接 SQL 删除(更高效) + 从数据库中删除记录,支持实例删除和条件删除两种模式。 Args: - session (AsyncSession): 用于数据库操作的异步会话对象. - instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表(可选). - condition (BinaryExpression | ClauseElement | None): 删除条件(可选,与 instances 二选一). - commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数, - 之后需要手动调用 `session.commit()`。默认为 True. + session: 用于数据库操作的异步会话对象 + instances: 要删除的单个模型实例或模型实例列表(实例删除模式) + condition: WHERE 条件表达式(条件删除模式,直接执行 SQL DELETE) + commit: 是否在删除后提交事务。默认为 True Returns: - int: 删除的记录数量 + 删除的记录数(条件删除模式返回实际删除数,实例删除模式返回实例数) + + Raises: + ValueError: 同时提供 instances 和 condition,或两者都未提供 Usage: - # 实例删除 - item_to_delete = await Item.get(session, Item.id == 1) - if item_to_delete: - deleted_count = await Item.delete(session, item_to_delete) + # 实例删除模式 + item = await Item.get(session, Item.id == 1) + if item: + await Item.delete(session, item) - # 条件删除(更高效,无需加载实例) + items = await Item.get(session, Item.name.in_(["A", "B"]), fetch_mode="all") + if items: + await Item.delete(session, items) + + # 条件删除模式(高效批量删除,不加载实例到内存) deleted_count = await Item.delete( session, - condition=(Item.status == "inactive") & (Item.created_at < cutoff_date) + condition=(Item.user_id == user_id) & (Item.status == "expired"), ) - - # 批量删除后手动提交 - await Item.delete(session, item1, commit=False) - await Item.delete(session, item2, commit=False) - await session.commit() """ - # 条件删除模式 - if condition is not None: - from sqlmodel import delete as sql_delete - - if instances is not None: - raise ValueError("不能同时指定 instances 和 condition") - - # 执行条件删除 - stmt = sql_delete(cls).where(condition) - result = await session.exec(stmt) - deleted_count = result.rowcount - - if commit: - await session.commit() - - return deleted_count - - # 实例删除模式(原有逻辑) - if instances is None: - raise ValueError("必须指定 instances 或 condition") + if instances is not None and condition is not None: + raise ValueError("不能同时提供 instances 和 condition 参数") + if instances is None and condition is None: + raise ValueError("必须提供 instances 或 condition 参数之一") deleted_count = 0 - if isinstance(instances, list): - for instance in instances: - await session.delete(instance) - deleted_count += 1 + + if condition is not None: + # 条件删除模式:直接执行 SQL DELETE + stmt = sql_delete(cls).where(condition) + result = await session.execute(stmt) + deleted_count = result.rowcount else: - await session.delete(instances) - deleted_count = 1 + # 实例删除模式 + if isinstance(instances, list): + for instance in instances: + await session.delete(instance) + deleted_count = len(instances) + else: + await session.delete(instances) + deleted_count = 1 if commit: await session.commit() @@ -552,7 +630,8 @@ class TableBaseMixin(AsyncAttrs): filter: BinaryExpression | ClauseElement | None = None, with_for_update: bool = False, table_view: TableViewRequest | None = None, - load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, + jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, + populate_existing: bool = False, created_before_datetime: datetime | None = None, created_after_datetime: datetime | None = None, updated_before_datetime: datetime | None = None, @@ -581,8 +660,10 @@ class TableBaseMixin(AsyncAttrs): options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据, 例如 `[selectinload(User.posts)]`. load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系. - 可以是单个关系或关系列表. - 例如 `User.profile` 或 `[User.group, User.tags]`. + 可以是单个关系或关系列表。支持嵌套关系预加载: + 当传入多个关系时,会自动检测依赖关系并构建链式 selectinload。 + 例如 `[NodeGroupNode.element_links, NodeGroupElementLink.node]` + 会自动构建 `selectinload(element_links).selectinload(node)`。 order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表. 例如 `[User.name.asc(), User.created_at.desc()]`. filter (BinaryExpression | ClauseElement | None): 附加的过滤条件. @@ -593,11 +674,16 @@ class TableBaseMixin(AsyncAttrs): 会覆盖offset、limit、order_by及时间筛选参数。 这是推荐的分页排序方式,统一了所有LIST端点的参数格式。 - load_polymorphic: 多态子类加载选项,需要与 load 参数配合使用。 + jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。 - list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表 - 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益) - None(默认): 不使用多态加载 + populate_existing (bool): 如果为 True,强制用数据库数据覆盖 session 中已存在的对象(identity map)。 + 用于批量刷新对象,避免循环调用 session.refresh() 导致的 N 次查询。 + 注意:只刷新标量字段,不影响运行时属性(_开头的属性)。 + 对于 STI(单表继承)对象,推荐按子类分组查询以包含子类字段。默认为 False。 + created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录 created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录 updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录 @@ -607,7 +693,7 @@ class TableBaseMixin(AsyncAttrs): T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`. Raises: - ValueError: 如果提供了无效的 `fetch_mode` 值,或 load_polymorphic 未与 load 配合使用. + ValueError: 如果提供了无效的 `fetch_mode` 值,或 jti_subclasses 未与 load 配合使用. Examples: # 使用table_view参数(推荐) @@ -621,13 +707,13 @@ class TableBaseMixin(AsyncAttrs): session, ToolSet.id == tool_set_id, load=ToolSet.tools, - load_polymorphic='all' # 只加载实际关联的子类 + jti_subclasses='all' # 只加载实际关联的子类 ) """ - # 参数验证:load_polymorphic 需要与 load 配合使用 - if load_polymorphic is not None and load is None: + # 参数验证:jti_subclasses 需要与 load 配合使用 + if jti_subclasses is not None and load is None: raise ValueError( - "load_polymorphic 参数需要与 load 参数配合使用," + "jti_subclasses 参数需要与 load 参数配合使用," "请同时指定要加载的关系" ) @@ -656,13 +742,34 @@ class TableBaseMixin(AsyncAttrs): # 对于多态基类,使用 with_polymorphic 预加载所有子类的列 # 这避免了在响应序列化时的延迟加载问题(MissingGreenlet 错误) - if issubclass(cls, PolymorphicBaseMixin): + polymorphic_cls = None # 保存多态实体,用于子类关系预加载 + is_polymorphic = issubclass(cls, PolymorphicBaseMixin) + is_jti = is_polymorphic and cls._is_joined_table_inheritance() + is_sti = is_polymorphic and not cls._is_joined_table_inheritance() + + # JTI 模式:总是使用 with_polymorphic(避免 N+1 查询) + # STI 模式:不使用 with_polymorphic(批量刷新时请按子类分组查询) + if is_jti: # '*' 表示加载所有子类 polymorphic_cls = with_polymorphic(cls, '*') statement = select(polymorphic_cls) else: statement = select(cls) + # 对于 STI(单表继承)子类,自动添加多态过滤条件 + # SQLAlchemy/SQLModel 在 STI 模式下不会自动添加 WHERE discriminator = 'identity' 过滤 + # 这是已知行为,参考: + # - https://github.com/sqlalchemy/sqlalchemy/issues/5018 (bulk operations 不自动添加多态过滤) + # - https://github.com/fastapi/sqlmodel/issues/488 (SQLModel STI 支持不完整) + # 社区最佳实践是显式添加多态过滤条件 + if issubclass(cls, PolymorphicBaseMixin) and not cls._is_joined_table_inheritance(): + mapper = inspect(cls) + # 检查是否有 polymorphic_identity 且不是抽象类 + if mapper.polymorphic_identity is not None and not mapper.polymorphic_abstract: + poly_on = mapper.polymorphic_on + if poly_on is not None: + statement = statement.where(poly_on == mapper.polymorphic_identity) + if condition is not None: statement = statement.where(condition) @@ -688,12 +795,19 @@ class TableBaseMixin(AsyncAttrs): # 标准化为列表 load_list = load if isinstance(load, list) else [load] - # 处理多态加载 - if load_polymorphic is not None: - # 多态加载只支持单个关系 - if len(load_list) > 1: - raise ValueError("load_polymorphic 仅支持单个关系") - target_class = load_list[0].property.mapper.class_ + # 构建链式 selectinload(支持嵌套关系预加载) + # 例如:load=[NodeGroupNode.element_links, NodeGroupElementLink.node] + # 会构建:selectinload(element_links).selectinload(node) + load_chains = cls._build_load_chains(load_list) + + # 处理多态加载(仅支持单链且只有一个关系) + if jti_subclasses is not None: + if len(load_chains) > 1 or len(load_chains[0]) > 1: + raise ValueError( + "jti_subclasses 仅支持单个关系(无嵌套链),请不要传入多个关系" + ) + single_load = load_chains[0][0] + target_class = single_load.property.mapper.class_ # 检查目标类是否继承自 PolymorphicBaseMixin if not issubclass(target_class, PolymorphicBaseMixin): @@ -702,26 +816,48 @@ class TableBaseMixin(AsyncAttrs): f"请确保其继承自 PolymorphicBaseMixin" ) - if load_polymorphic == 'all': + if jti_subclasses == 'all': # 两阶段查询:获取实际关联的多态类型 subclasses_to_load = await cls._resolve_polymorphic_subclasses( - session, condition, load_list[0], target_class + session, condition, single_load, target_class ) else: - subclasses_to_load = load_polymorphic + subclasses_to_load = jti_subclasses if subclasses_to_load: # 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项 # 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading statement = statement.options( - selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load) + selectinload(single_load).selectin_polymorphic(subclasses_to_load) ) else: - statement = statement.options(selectinload(load_list[0])) + statement = statement.options(selectinload(single_load)) else: - # 为每个关系添加 selectinload - for rel in load_list: - statement = statement.options(selectinload(rel)) + # 为每条链构建链式 selectinload + for chain in load_chains: + # 获取第一个关系并检查是否需要通过多态实体访问 + first_rel = chain[0] + first_rel_parent = first_rel.property.parent.class_ + + # 如果关系的 parent_class 是当前类的子类(不是 cls 本身), + # 且当前是多态查询,则需要通过 polymorphic_cls.SubclassName 访问 + if ( + polymorphic_cls is not None + and first_rel_parent is not cls + and issubclass(first_rel_parent, cls) + ): + # 通过多态实体访问子类的关系属性 + # 例如:polymorphic_cls.NodeGroupNode.element_links + subclass_alias = getattr(polymorphic_cls, first_rel_parent.__name__) + rel_name = first_rel.key + first_rel_via_poly = getattr(subclass_alias, rel_name) + loader = selectinload(first_rel_via_poly) + else: + loader = selectinload(first_rel) + + for rel in chain[1:]: + loader = loader.selectinload(rel) + statement = statement.options(loader) if order_by is not None: statement = statement.order_by(*order_by) @@ -736,7 +872,17 @@ class TableBaseMixin(AsyncAttrs): statement = statement.filter(filter) if with_for_update: - statement = statement.with_for_update() + # 对于联表继承的多态模型,使用 FOR UPDATE OF <主表> 来避免 PostgreSQL 的限制 + # PostgreSQL 不支持在 LEFT OUTER JOIN 的可空侧使用 FOR UPDATE + if issubclass(cls, PolymorphicBaseMixin): + statement = statement.with_for_update(of=cls) + else: + statement = statement.with_for_update() + + if populate_existing: + # 强制用数据库数据覆盖 identity map 中的对象 + # 用于批量刷新,避免循环 refresh() 的 N 次查询 + statement = statement.execution_options(populate_existing=True) result = await session.exec(statement) @@ -749,6 +895,79 @@ class TableBaseMixin(AsyncAttrs): else: raise ValueError(f"无效的 fetch_mode: {fetch_mode}") + @staticmethod + def _build_load_chains(load_list: list[RelationshipInfo]) -> list[list[RelationshipInfo]]: + """ + 将关系列表构建为链式加载结构 + + 自动检测关系之间的依赖关系,构建嵌套预加载链。 + 例如:[NodeGroupNode.element_links, NodeGroupElementLink.node] + 会构建:[[element_links, node]](一条链) + + 算法: + 1. 获取每个关系的 parent class 和 target class + 2. 如果关系 B 的 parent class 等于关系 A 的 target class,则 B 链在 A 后面 + 3. 独立的关系各自成为一条链 + + Args: + load_list: 关系属性列表 + + Returns: + 链式关系列表,每条链是一个关系列表 + """ + if not load_list: + return [] + + # 构建关系信息:{关系: (parent_class, target_class)} + rel_info: dict[RelationshipInfo, tuple[type, type]] = {} + for rel in load_list: + parent_class = rel.property.parent.class_ + target_class = rel.property.mapper.class_ + rel_info[rel] = (parent_class, target_class) + + # 构建依赖图:{关系: 其前置关系} + predecessors: dict[RelationshipInfo, RelationshipInfo | None] = {rel: None for rel in load_list} + for rel_b in load_list: + parent_b, _ = rel_info[rel_b] + for rel_a in load_list: + if rel_a is rel_b: + continue + _, target_a = rel_info[rel_a] + # 如果 B 的 parent 精确等于 A 的 target,则 B 链在 A 后面 + # 使用精确匹配避免继承关系导致的误判(如 NodeGroupNode 是 CanvasNode 子类) + if parent_b is target_a: + predecessors[rel_b] = rel_a + break + + # 找出所有链的起点(没有前置关系的) + roots = [rel for rel, pred in predecessors.items() if pred is None] + + # 构建链 + chains: list[list[RelationshipInfo]] = [] + used: set[RelationshipInfo] = set() + + for root in roots: + chain = [root] + used.add(root) + # 找后续节点 + current = root + while True: + # 找以 current 的 target 为 parent 的关系 + _, current_target = rel_info[current] + next_rel = None + for rel, (parent, _) in rel_info.items(): + if rel not in used and parent is current_target: + next_rel = rel + break + if next_rel is None: + break + chain.append(next_rel) + used.add(next_rel) + current = next_rel + chains.append(chain) + + return chains + @classmethod async def _resolve_polymorphic_subclasses( cls: type[T], @@ -791,12 +1010,15 @@ class TableBaseMixin(AsyncAttrs): )) ) else: - # 一对多关系:通过外键查询 - foreign_key_col = relationship_property.local_remote_pairs[0][1] + # 多对一/一对多关系:通过外键查询 + # local_remote_pairs[0] = (local_fk_col, remote_pk_col) + # 对于多对一:local 是当前类的外键,remote 是目标类的主键 + local_fk_col = relationship_property.local_remote_pairs[0][0] + remote_pk_col = relationship_property.local_remote_pairs[0][1] type_query = ( select(distinct(poly_name_col)) - .where(foreign_key_col.in_( - select(cls.id).where(condition) if condition is not None else select(cls.id) + .where(remote_pk_col.in_( + select(local_fk_col).where(condition) if condition is not None else select(local_fk_col) )) ) @@ -898,7 +1120,7 @@ class TableBaseMixin(AsyncAttrs): order_by: list[ClauseElement] | None = None, filter: BinaryExpression | ClauseElement | None = None, table_view: TableViewRequest | None = None, - load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, + jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, ) -> 'ListResponse[T]': """ 获取分页列表及总数,直接返回 ListResponse @@ -918,7 +1140,7 @@ class TableBaseMixin(AsyncAttrs): order_by: 排序子句 filter: 附加过滤条件 table_view: 分页排序参数(推荐使用) - load_polymorphic: 多态子类加载选项 + jti_subclasses: 多态子类加载选项 Returns: ListResponse[T]: 包含 count 和 items 的分页响应 @@ -957,7 +1179,7 @@ class TableBaseMixin(AsyncAttrs): order_by=order_by, filter=filter, table_view=table_view, - load_polymorphic=load_polymorphic, + jti_subclasses=jti_subclasses, ) return ListResponse(count=total_count, items=items) @@ -973,8 +1195,7 @@ class TableBaseMixin(AsyncAttrs): Args: session (AsyncSession): 用于数据库操作的异步会话对象. id (int): 要查找的记录的主键 ID. - load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性. - 可以是单个关系或关系列表. + load (Relationship | None): 可选的,用于预加载的关联属性. Returns: T: 找到的模型实例. @@ -1002,7 +1223,7 @@ class UUIDTableBaseMixin(TableBaseMixin): @override @classmethod - async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T: + async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T: """ 根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常. @@ -1012,8 +1233,7 @@ class UUIDTableBaseMixin(TableBaseMixin): Args: session (AsyncSession): 用于数据库操作的异步会话对象. id (uuid.UUID): 要查找的记录的 UUID 主键. - load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性. - 可以是单个关系或关系列表. + load (Relationship | None): 可选的,用于预加载的关联属性. Returns: T: 找到的模型实例. diff --git a/models/model_base.py b/sqlmodels/model_base.py similarity index 98% rename from models/model_base.py rename to sqlmodels/model_base.py index 913a4e5..0851f8e 100644 --- a/models/model_base.py +++ b/sqlmodels/model_base.py @@ -119,4 +119,5 @@ class MCPResponseBase(MCPBase): """MCP 响应模型基础类""" result: str - """方法返回结果""" \ No newline at end of file + """方法返回结果""" + \ No newline at end of file diff --git a/models/node.py b/sqlmodels/node.py similarity index 100% rename from models/node.py rename to sqlmodels/node.py diff --git a/models/object.py b/sqlmodels/object.py similarity index 81% rename from models/object.py rename to sqlmodels/object.py index 7951702..114a800 100644 --- a/models/object.py +++ b/sqlmodels/object.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal from uuid import UUID from enum import StrEnum +from sqlalchemy import BigInteger from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text from .base import SQLModelBase @@ -15,6 +16,7 @@ if TYPE_CHECKING: from .source_link import SourceLink from .share import Share from .physical_file import PhysicalFile + from .uri import DiskNextURI class ObjectType(StrEnum): @@ -103,7 +105,7 @@ class ObjectMoveRequest(SQLModelBase): class ObjectDeleteRequest(SQLModelBase): """删除对象请求 DTO""" - ids: UUID | list[UUID] + ids: list[UUID] """待删除对象UUID列表""" @@ -116,12 +118,12 @@ class ObjectResponse(ObjectBase): thumb: bool = False """是否有缩略图""" - date: datetime - """对象修改时间""" - - create_date: datetime + created_at: datetime """对象创建时间""" + updated_at: datetime + """对象修改时间""" + source_enabled: bool = False """是否启用离线下载源""" @@ -138,7 +140,7 @@ class PolicyResponse(SQLModelBase): type: StorageType """存储类型""" - max_size: int = Field(ge=0, default=0) + max_size: int = Field(ge=0, default=0, sa_type=BigInteger) """单文件最大限制,单位字节,0表示不限制""" file_type: list[str] | None = None @@ -186,18 +188,18 @@ class Object(ObjectBase, UUIDTableBaseMixin): 合并了原有的 File 和 Folder 模型,通过 type 字段区分文件和目录。 根目录规则: - - 每个用户有一个显式根目录对象(name=用户的username, parent_id=NULL) + - 每个用户有一个显式根目录对象(name="/", parent_id=NULL) - 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id - 根目录的 policy_id 指定用户默认存储策略 - - 路径格式:/username/path/to/file(如 /admin/docs/readme.md) + - 路径格式:/path/to/file(如 /docs/readme.md),不包含用户名前缀 """ __table_args__ = ( # 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况) UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"), - # 名称不能包含斜杠 ([TODO] 还有特殊字符) + # 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/") CheckConstraint( - "name NOT LIKE '%/%' AND name NOT LIKE '%\\%'", + "parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')", name="ck_object_name_no_slash", ), # 性能索引 @@ -220,7 +222,7 @@ class Object(ObjectBase, UUIDTableBaseMixin): # ==================== 文件专属字段 ==================== - size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) + size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"}) """文件大小(字节),目录为 0""" upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True) @@ -374,15 +376,16 @@ class Object(ObjectBase, UUIDTableBaseMixin): session, user_id: UUID, path: str, - username: str, ) -> "Object | None": """ 根据路径获取对象 + 路径从用户根目录开始,不包含用户名前缀。 + 如 "/" 表示根目录,"/docs/images" 表示根目录下的 docs/images。 + :param session: 数据库会话 :param user_id: 用户UUID - :param path: 路径,如 "/username" 或 "/username/docs/images" - :param username: 用户名,用于识别根目录 + :param path: 路径,如 "/" 或 "/docs/images" :return: Object 或 None """ path = path.strip() @@ -403,16 +406,7 @@ class Object(ObjectBase, UUIDTableBaseMixin): if not parts: return root - # 检查第一部分是否是用户名(根目录名) - if parts[0] == username: - # 路径以用户名开头,如 /admin/docs - if len(parts) == 1: - # 只有用户名,返回根目录 - return root - # 去掉用户名部分,从第二个部分开始遍历 - parts = parts[1:] - - # 从根目录开始遍历剩余路径 + # 从根目录开始遍历路径 current = root for part in parts: if not current: @@ -443,6 +437,77 @@ class Object(ObjectBase, UUIDTableBaseMixin): fetch_mode="all" ) + @classmethod + async def resolve_uri( + cls, + session, + uri: "DiskNextURI", + requesting_user_id: UUID | None = None, + ) -> "Object": + """ + 将 URI 解析为 Object 实例 + + 分派逻辑(类似 Cloudreve 的 getNavigator): + - MY → user_id = uri.id(str(requesting_user_id)) + 验证权限(自己的或管理员),然后 get_by_path + - SHARE → 通过 uri.fs_id 查 Share 表,验证密码和有效期 + 获取 share.object,然后沿 uri.path 遍历子对象 + - TRASH → 延后实现 + + :param session: 数据库会话 + :param uri: DiskNextURI 实例 + :param requesting_user_id: 请求用户UUID + :return: Object 实例 + :raises ValueError: URI 无法解析 + :raises PermissionError: 权限不足 + :raises NotImplementedError: 不支持的命名空间 + """ + from .uri import FileSystemNamespace + + if uri.namespace == FileSystemNamespace.MY: + # 确定目标用户 + target_user_id_str = uri.id(str(requesting_user_id) if requesting_user_id else None) + if not target_user_id_str: + raise ValueError("MY 命名空间需要提供 fs_id 或 requesting_user_id") + + target_user_id = UUID(target_user_id_str) + + # 权限检查:只能访问自己的空间(管理员权限由路由层判断) + if requesting_user_id and target_user_id != requesting_user_id: + raise PermissionError("无权访问其他用户的文件空间") + + obj = await cls.get_by_path(session, target_user_id, uri.path) + if not obj: + raise ValueError(f"路径不存在: {uri.path}") + return obj + + elif uri.namespace == FileSystemNamespace.SHARE: + raise NotImplementedError("分享空间解析尚未实现") + + elif uri.namespace == FileSystemNamespace.TRASH: + raise NotImplementedError("回收站解析尚未实现") + + else: + raise ValueError(f"未知的命名空间: {uri.namespace}") + + async def get_full_path(self, session) -> str: + """ + 从当前对象沿 parent_id 向上遍历到根目录,返回完整路径 + + :param session: 数据库会话 + :return: 完整路径,如 "/docs/images/photo.jpg" + """ + parts: list[str] = [] + current: Object | None = self + + while current and current.parent_id is not None: + parts.append(current.name) + current = await Object.get(session, Object.id == current.parent_id) + + # 反转顺序(从根到当前) + parts.reverse() + return "/" + "/".join(parts) + # ==================== 上传会话模型 ==================== @@ -452,10 +517,10 @@ class UploadSessionBase(SQLModelBase): file_name: str = Field(max_length=255) """原始文件名""" - file_size: int = Field(ge=0) + file_size: int = Field(ge=0, sa_type=BigInteger) """文件总大小(字节)""" - chunk_size: int = Field(ge=1) + chunk_size: int = Field(ge=1, sa_type=BigInteger) """分片大小(字节)""" total_chunks: int = Field(ge=1) @@ -474,7 +539,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin): uploaded_chunks: int = 0 """已上传分片数""" - uploaded_size: int = 0 + uploaded_size: int = Field(default=0, sa_type=BigInteger) """已上传大小(字节)""" storage_path: str | None = Field(default=None, max_length=512) @@ -680,8 +745,8 @@ class AdminFileResponse(ObjectResponse): owner_id: UUID """所有者UUID""" - owner_username: str - """所有者用户名""" + owner_email: str + """所有者邮箱""" policy_name: str """存储策略名称""" @@ -709,12 +774,12 @@ class AdminFileResponse(ObjectResponse): # ObjectResponse 字段 id=obj.id, thumb=False, - date=obj.updated_at, - create_date=obj.created_at, + created_at=obj.created_at, + updated_at=obj.updated_at, source_enabled=False, # AdminFileResponse 字段 owner_id=obj.owner_id, - owner_username=owner.username if owner else "unknown", + owner_email=owner.email if owner else "unknown", policy_name=policy.name if policy else "unknown", is_banned=obj.is_banned, banned_at=obj.banned_at, @@ -725,7 +790,7 @@ class AdminFileResponse(ObjectResponse): class FileBanRequest(SQLModelBase): """文件封禁请求 DTO""" - is_banned: bool = True + ban: bool = True """是否封禁""" reason: str | None = Field(default=None, max_length=500) diff --git a/models/order.py b/sqlmodels/order.py similarity index 100% rename from models/order.py rename to sqlmodels/order.py diff --git a/models/physical_file.py b/sqlmodels/physical_file.py similarity index 96% rename from models/physical_file.py rename to sqlmodels/physical_file.py index 49fd5e5..187039b 100644 --- a/models/physical_file.py +++ b/sqlmodels/physical_file.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING from uuid import UUID +from sqlalchemy import BigInteger from sqlmodel import Field, Relationship, Index from .base import SQLModelBase @@ -28,7 +29,7 @@ class PhysicalFileBase(SQLModelBase): storage_path: str = Field(max_length=512) """物理存储路径(相对于存储策略根目录)""" - size: int = 0 + size: int = Field(default=0, sa_type=BigInteger) """文件大小(字节)""" checksum_md5: str | None = Field(default=None, max_length=32) diff --git a/models/policy.py b/sqlmodels/policy.py similarity index 100% rename from models/policy.py rename to sqlmodels/policy.py diff --git a/models/redeem.py b/sqlmodels/redeem.py similarity index 100% rename from models/redeem.py rename to sqlmodels/redeem.py diff --git a/models/report.py b/sqlmodels/report.py similarity index 100% rename from models/report.py rename to sqlmodels/report.py diff --git a/models/setting.py b/sqlmodels/setting.py similarity index 85% rename from models/setting.py rename to sqlmodels/setting.py index 5e6b5cf..a0d0a1c 100644 --- a/models/setting.py +++ b/sqlmodels/setting.py @@ -20,16 +20,10 @@ class SiteConfigResponse(SQLModelBase): title: str = "DiskNext" """网站标题""" - # themes: dict[str, str] = {} - # """网站主题配置""" - - # default_theme: dict[str, str] = {} - # """默认主题RGB色号""" - site_notice: str | None = None """网站公告""" - user: UserResponse + user: UserResponse | None = None """用户信息""" logo_light: str | None = None @@ -38,11 +32,23 @@ class SiteConfigResponse(SQLModelBase): logo_dark: str | None = None """网站Logo URL(深色模式)""" - captcha_type: CaptchaType | None = None + register_enabled: bool = True + """是否允许注册""" + + login_captcha: bool = False + """登录是否需要验证码""" + + reg_captcha: bool = False + """注册是否需要验证码""" + + forget_captcha: bool = False + """找回密码是否需要验证码""" + + captcha_type: CaptchaType = CaptchaType.DEFAULT """验证码类型""" captcha_key: str | None = None - """验证码密钥""" + """验证码 public key(DEFAULT 类型时为 None)""" # ==================== 管理员设置 DTO ==================== diff --git a/models/share.py b/sqlmodels/share.py similarity index 98% rename from models/share.py rename to sqlmodels/share.py index b50fba0..a3cd946 100644 --- a/models/share.py +++ b/sqlmodels/share.py @@ -215,6 +215,6 @@ class AdminShareListItem(ShareListItemBase): """从 Share ORM 对象构建""" return cls( **ShareListItemBase.model_validate(share, from_attributes=True).model_dump(), - username=user.username if user else None, + username=user.email if user else None, object_name=obj.name if obj else None, ) diff --git a/models/source_link.py b/sqlmodels/source_link.py similarity index 100% rename from models/source_link.py rename to sqlmodels/source_link.py diff --git a/models/storage_pack.py b/sqlmodels/storage_pack.py similarity index 100% rename from models/storage_pack.py rename to sqlmodels/storage_pack.py diff --git a/models/tag.py b/sqlmodels/tag.py similarity index 100% rename from models/tag.py rename to sqlmodels/tag.py diff --git a/models/task.py b/sqlmodels/task.py similarity index 98% rename from models/task.py rename to sqlmodels/task.py index b541235..c1cf261 100644 --- a/models/task.py +++ b/sqlmodels/task.py @@ -73,7 +73,7 @@ class TaskSummary(TaskSummaryBase): """从 Task ORM 对象构建""" return cls( **TaskSummaryBase.model_validate(task, from_attributes=True).model_dump(), - username=user.username if user else None, + username=user.email if user else None, ) diff --git a/sqlmodels/uri.py b/sqlmodels/uri.py new file mode 100644 index 0000000..3d4075c --- /dev/null +++ b/sqlmodels/uri.py @@ -0,0 +1,258 @@ + +from enum import StrEnum +from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote + +from .base import SQLModelBase + + +class FileSystemNamespace(StrEnum): + """文件系统命名空间""" + MY = "my" + """用户个人空间""" + + SHARE = "share" + """分享空间""" + + TRASH = "trash" + """回收站""" + + +class DiskNextURI(SQLModelBase): + """ + DiskNext 文件 URI + + URI 格式: disknext://[fs_id[:password]@]namespace[/path][?query] + + fs_id 可省略: + - my/trash 命名空间省略时默认当前用户 + - share 命名空间必须提供 fs_id(Share.code) + """ + + fs_id: str | None = None + """文件系统标识符,可省略""" + + namespace: FileSystemNamespace + """命名空间""" + + path: str = "/" + """路径""" + + password: str | None = None + """访问密码(用于有密码的分享)""" + + query: dict[str, str] | None = None + """查询参数""" + + # === 属性 === + + @property + def path_parts(self) -> list[str]: + """路径分割为列表(过滤空串)""" + return [p for p in self.path.split("/") if p] + + @property + def is_root(self) -> bool: + """是否指向根目录""" + return self.path.strip("/") == "" + + # === 核心方法 === + + def id(self, default_id: str | None = None) -> str | None: + """ + 获取 fs_id,省略时返回 default_id + + 参考 Cloudreve URI.ID(defaultUid) 方法 + + :param default_id: 默认值(通常为当前用户 ID) + :return: fs_id 或 default_id + """ + return self.fs_id if self.fs_id else default_id + + # === 类方法 === + + @classmethod + def parse(cls, uri: str) -> "DiskNextURI": + """ + 解析 URI 字符串 + + 实现方式:替换 disknext:// 为 http:// 后用 urllib.parse.urlparse 解析 + - hostname → namespace + - username → fs_id + - password → password + - path → path + - query → query dict + + :param uri: URI 字符串,如 "disknext://my/docs/readme.md" + :return: DiskNextURI 实例 + :raises ValueError: URI 格式无效 + """ + if not uri.startswith("disknext://"): + raise ValueError(f"URI 必须以 disknext:// 开头: {uri}") + + # 替换协议为 http:// 以利用 urllib.parse 解析 + http_uri = "http://" + uri[len("disknext://"):] + parsed = urlparse(http_uri) + + # 解析 namespace + hostname = parsed.hostname + if not hostname: + raise ValueError(f"URI 缺少命名空间: {uri}") + + try: + namespace = FileSystemNamespace(hostname) + except ValueError: + raise ValueError(f"无效的命名空间 '{hostname}',有效值: {[e.value for e in FileSystemNamespace]}") + + # 解析 fs_id 和 password + fs_id = unquote(parsed.username) if parsed.username else None + password = unquote(parsed.password) if parsed.password else None + + # 解析 path + path = unquote(parsed.path) if parsed.path else "/" + if not path: + path = "/" + + # 解析 query + query: dict[str, str] | None = None + if parsed.query: + raw_query = parse_qs(parsed.query, keep_blank_values=True) + query = {k: v[0] for k, v in raw_query.items()} + + return cls( + fs_id=fs_id, + namespace=namespace, + path=path, + password=password, + query=query, + ) + + @classmethod + def build( + cls, + namespace: FileSystemNamespace, + path: str = "/", + fs_id: str | None = None, + password: str | None = None, + ) -> "DiskNextURI": + """ + 构建 URI 实例 + + :param namespace: 命名空间 + :param path: 路径 + :param fs_id: 文件系统标识符 + :param password: 访问密码 + :return: DiskNextURI 实例 + """ + # 确保 path 以 / 开头 + if not path.startswith("/"): + path = "/" + path + + return cls( + fs_id=fs_id, + namespace=namespace, + path=path, + password=password, + ) + + # === 实例方法 === + + def to_string(self) -> str: + """ + 序列化为 URI 字符串 + + :return: URI 字符串,如 "disknext://my/docs/readme.md" + """ + result = "disknext://" + + # fs_id 和 password + if self.fs_id: + result += quote(self.fs_id, safe="") + if self.password: + result += ":" + quote(self.password, safe="") + result += "@" + + # namespace + result += self.namespace.value + + # path + result += self.path + + # query + if self.query: + result += "?" + urlencode(self.query) + + return result + + def join(self, *elements: str) -> "DiskNextURI": + """ + 拼接路径元素,返回新 URI + + :param elements: 路径元素 + :return: 新的 DiskNextURI 实例 + """ + base = self.path.rstrip("/") + for element in elements: + element = element.strip("/") + if element: + base += "/" + element + + if not base: + base = "/" + + return DiskNextURI( + fs_id=self.fs_id, + namespace=self.namespace, + path=base, + password=self.password, + query=self.query, + ) + + def dir_uri(self) -> "DiskNextURI": + """ + 返回父目录的 URI + + :return: 父目录的 DiskNextURI 实例 + """ + parts = self.path_parts + if not parts: + # 已经是根目录 + return self.root() + + parent_path = "/" + "/".join(parts[:-1]) + if not parent_path.endswith("/"): + parent_path += "/" + + return DiskNextURI( + fs_id=self.fs_id, + namespace=self.namespace, + path=parent_path, + password=self.password, + ) + + def root(self) -> "DiskNextURI": + """ + 返回根目录的 URI(保留 namespace 和 fs_id) + + :return: 根目录的 DiskNextURI 实例 + """ + return DiskNextURI( + fs_id=self.fs_id, + namespace=self.namespace, + path="/", + password=self.password, + ) + + def name(self) -> str: + """ + 返回路径的最后一段(文件名或目录名) + + :return: 文件名或目录名,根目录返回空字符串 + """ + parts = self.path_parts + return parts[-1] if parts else "" + + def __str__(self) -> str: + return self.to_string() + + def __repr__(self) -> str: + return f"DiskNextURI({self.to_string()!r})" diff --git a/models/user.py b/sqlmodels/user.py similarity index 84% rename from models/user.py rename to sqlmodels/user.py index 287b502..afdeecb 100644 --- a/models/user.py +++ b/sqlmodels/user.py @@ -60,8 +60,8 @@ class UserFilterParams(SQLModelBase): group_id: UUID | None = None """按用户组UUID筛选""" - username_contains: str | None = Field(default=None, max_length=50) - """用户名包含(不区分大小写的模糊搜索)""" + email_contains: str | None = Field(default=None, max_length=50) + """邮箱包含(不区分大小写的模糊搜索)""" nickname_contains: str | None = Field(default=None, max_length=50) """昵称包含(不区分大小写的模糊搜索)""" @@ -75,8 +75,8 @@ class UserFilterParams(SQLModelBase): class UserBase(SQLModelBase): """用户基础字段,供数据库模型和 DTO 共享""" - username: str - """用户名""" + email: str + """用户邮箱""" status: UserStatus = UserStatus.ACTIVE """用户状态""" @@ -90,8 +90,8 @@ class UserBase(SQLModelBase): class LoginRequest(SQLModelBase): """登录请求 DTO""" - username: str - """用户名或邮箱""" + email: str + """用户邮箱""" password: str """用户密码""" @@ -106,8 +106,8 @@ class LoginRequest(SQLModelBase): class RegisterRequest(SQLModelBase): """注册请求 DTO""" - username: str - """用户名,唯一,一经注册不可更改""" + email: str + """用户邮箱,唯一""" password: str """用户密码""" @@ -116,6 +116,20 @@ class RegisterRequest(SQLModelBase): """验证码""" +class BatchDeleteRequest(SQLModelBase): + """批量删除请求 DTO""" + + ids: list[UUID] + """待删除 UUID 列表""" + + +class RefreshTokenRequest(SQLModelBase): + """刷新令牌请求 DTO""" + + refresh_token: str + """刷新令牌""" + + class WebAuthnInfo(SQLModelBase): """WebAuthn 信息 DTO""" @@ -166,6 +180,9 @@ class UserResponse(ResponseBase): id: UUID """用户UUID""" + email: str + """用户邮箱""" + nickname: str | None = None """用户昵称""" @@ -184,11 +201,23 @@ class UserResponse(ResponseBase): tags: list[str] = [] """用户标签列表""" +class UserStorageResponse(SQLModelBase): + """用户存储信息 DTO""" + + used: int + """已用存储空间(字节)""" + + free: int + """剩余存储空间(字节)""" + + total: int + """总存储空间(字节)""" + class UserPublic(UserBase): """用户公开信息 DTO,用于 API 响应""" - id: UUID | None = None + id: UUID """用户UUID""" nickname: str | None = None @@ -206,6 +235,9 @@ class UserPublic(UserBase): group_id: UUID | None = None """所属用户组UUID""" + group_name: str | None = None + """用户组名称""" + two_factor: str | None = None """两步验证密钥(32位字符串,null 表示未启用)""" @@ -219,29 +251,63 @@ class UserPublic(UserBase): class UserSettingResponse(SQLModelBase): """用户设置响应 DTO""" - authn: "AuthnResponse | None" = None + id: UUID + """用户UUID""" + + email: str + """用户邮箱""" + + nickname: str | None = None + """昵称""" + + created_at: datetime + """用户注册时间""" + + group_name: str + """用户所属用户组名称""" + + language: str + """语言偏好""" + + timezone: int + """时区""" + + authn: "list[AuthnResponse] | None" = None """认证信息""" group_expires: datetime | None = None """用户组过期时间""" - prefer_theme: str = "#5898d4" - """用户首选主题""" - - themes: dict[str, str] = {} - """用户主题配置""" - two_factor: bool = False """是否启用两步验证""" - uid: UUID | None = None - """用户UUID""" - # ==================== 管理员用户管理 DTO ==================== +class UserAdminCreateRequest(SQLModelBase): + """管理员创建用户请求 DTO""" + + email: str = Field(max_length=50) + """用户邮箱""" + + password: str + """用户密码(明文,由服务端加密)""" + + nickname: str | None = Field(default=None, max_length=50) + """昵称""" + + group_id: UUID + """所属用户组UUID""" + + status: UserStatus = UserStatus.ACTIVE + """用户状态""" + + class UserAdminUpdateRequest(SQLModelBase): """管理员更新用户请求 DTO""" + + email: str = Field(max_length=50) + """邮箱""" nickname: str | None = Field(default=None, max_length=50) """昵称""" @@ -317,8 +383,8 @@ UserSettingResponse.model_rebuild() class User(UserBase, UUIDTableBaseMixin): """用户模型""" - username: str = Field(max_length=50, unique=True, index=True) - """用户名,唯一,一经注册不可更改""" + email: str = Field(max_length=50, unique=True, index=True) + """用户邮箱,唯一""" nickname: str | None = Field(default=None, max_length=50) """用于公开展示的名字,可使用真实姓名或昵称""" @@ -426,8 +492,10 @@ class User(UserBase, UUIDTableBaseMixin): ) def to_public(self) -> "UserPublic": - """转换为公开 DTO,排除敏感字段""" - return UserPublic.model_validate(self) + """转换为公开 DTO,排除敏感字段。需要预加载 group 关系。""" + data = UserPublic.model_validate(self) + data.group_name = self.group.name + return data @classmethod async def get_with_count( @@ -457,8 +525,8 @@ class User(UserBase, UUIDTableBaseMixin): if filter_params.group_id is not None: filter_conditions.append(cls.group_id == filter_params.group_id) - if filter_params.username_contains is not None: - filter_conditions.append(cls.username.ilike(f"%{filter_params.username_contains}%")) + if filter_params.email_contains is not None: + filter_conditions.append(cls.email.ilike(f"%{filter_params.email_contains}%")) if filter_params.nickname_contains is not None: filter_conditions.append(cls.nickname.ilike(f"%{filter_params.nickname_contains}%")) @@ -482,4 +550,5 @@ class User(UserBase, UUIDTableBaseMixin): order_by=order_by, filter=filter, table_view=table_view, - ) \ No newline at end of file + ) + \ No newline at end of file diff --git a/models/user_authn.py b/sqlmodels/user_authn.py similarity index 100% rename from models/user_authn.py rename to sqlmodels/user_authn.py diff --git a/models/webdav.py b/sqlmodels/webdav.py similarity index 100% rename from models/webdav.py rename to sqlmodels/webdav.py diff --git a/tests/check_imports.py b/tests/check_imports.py index 537e2f1..9f640c1 100644 --- a/tests/check_imports.py +++ b/tests/check_imports.py @@ -49,13 +49,13 @@ def main(): ("itsdangerous", "签名工具"), # 项目模块 - ("models", "数据库模型"), - ("models.user", "用户模型"), - ("models.group", "用户组模型"), - ("models.object", "对象模型"), - ("models.setting", "设置模型"), - ("models.policy", "策略模型"), - ("models.database", "数据库连接"), + ("sqlmodels", "数据库模型"), + ("sqlmodels.user", "用户模型"), + ("sqlmodels.group", "用户组模型"), + ("sqlmodels.object", "对象模型"), + ("sqlmodels.setting", "设置模型"), + ("sqlmodels.policy", "策略模型"), + ("sqlmodels.database", "数据库连接"), ("utils.password.pwd", "密码工具"), ("utils.JWT.JWT", "JWT 工具"), ("service.user.login", "登录服务"), diff --git a/tests/conftest.py b/tests/conftest.py index e452fc6..89d36d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,12 +23,12 @@ from sqlalchemy.orm import sessionmaker sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) from main import app -from models.database import get_session -from models.group import Group, GroupOptions -from models.migration import migration -from models.object import Object, ObjectType -from models.policy import Policy, PolicyType -from models.user import User +from sqlmodels.database import get_session +from sqlmodels.group import Group, GroupOptions +from sqlmodels.migration import migration +from sqlmodels.object import Object, ObjectType +from sqlmodels.policy import Policy, PolicyType +from sqlmodels.user import User from utils.JWT.JWT import create_access_token from utils.password.pwd import Password @@ -153,7 +153,7 @@ def override_get_session(db_session: AsyncSession): @pytest_asyncio.fixture(scope="function") async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]: """ - 创建测试用户并返回 {id, username, password, token} + 创建测试用户并返回 {id, email, password, token} 创建一个普通用户,包含用户组、存储策略和根目录。 """ @@ -190,7 +190,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]: # 创建测试用户 password = "test_password_123" user = User( - username="testuser", + email="testuser@test.local", nickname="测试用户", password=Password.hash(password), status=True, @@ -202,7 +202,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]: # 创建用户根目录 root_folder = Object( - name=user.username, + name="/", type=ObjectType.FOLDER, parent_id=None, owner_id=user.id, @@ -216,7 +216,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]: return { "id": user.id, - "username": user.username, + "email": user.email, "password": password, "token": access_token, "group_id": group.id, @@ -227,7 +227,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]: @pytest_asyncio.fixture(scope="function") async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]: """ - 获取管理员用户 {id, username, token} + 获取管理员用户 {id, email, token} 创建具有管理员权限的用户。 """ @@ -267,7 +267,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]: # 创建管理员用户 password = "admin_password_456" admin = User( - username="admin", + email="admin@disknext.local", nickname="管理员", password=Password.hash(password), status=True, @@ -279,7 +279,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]: # 创建管理员根目录 root_folder = Object( - name=admin.username, + name="/", type=ObjectType.FOLDER, parent_id=None, owner_id=admin.id, @@ -293,7 +293,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]: return { "id": admin.id, - "username": admin.username, + "email": admin.email, "password": password, "token": access_token, "group_id": admin_group.id, diff --git a/tests/example_test.py b/tests/example_test.py index 3a77f8d..94a1096 100644 --- a/tests/example_test.py +++ b/tests/example_test.py @@ -8,9 +8,9 @@ from uuid import UUID from sqlmodel.ext.asyncio.session import AsyncSession -from models.user import User -from models.group import Group -from models.object import Object, ObjectType +from sqlmodels.user import User +from sqlmodels.group import Group +from sqlmodels.object import Object, ObjectType from tests.fixtures import UserFactory, GroupFactory, ObjectFactory @@ -24,13 +24,13 @@ async def test_user_factory(db_session: AsyncSession): user = await UserFactory.create( db_session, group_id=group.id, - username="testuser", + email="testuser@test.local", password="password123" ) # 验证 assert user.id is not None - assert user.username == "testuser" + assert user.email == "testuser@test.local" assert user.group_id == group.id assert user.status is True @@ -51,7 +51,7 @@ async def test_group_factory(db_session: AsyncSession): async def test_object_factory(db_session: AsyncSession): """测试对象工厂的基本功能""" # 准备依赖 - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = await GroupFactory.create(db_session) user = await UserFactory.create(db_session, group_id=group.id) @@ -102,7 +102,7 @@ async def test_conftest_fixtures( """测试 conftest.py 中的 fixtures""" # 验证 test_user fixture assert test_user["id"] is not None - assert test_user["username"] == "testuser" + assert test_user["email"] == "testuser@test.local" assert test_user["token"] is not None # 验证 auth_headers fixture @@ -112,7 +112,7 @@ async def test_conftest_fixtures( # 验证用户在数据库中存在 user = await User.get(db_session, User.id == test_user["id"]) assert user is not None - assert user.username == test_user["username"] + assert user.email == test_user["email"] @pytest.mark.integration @@ -145,7 +145,7 @@ async def test_test_directory_fixture( @pytest.mark.integration async def test_nested_structure_factory(db_session: AsyncSession): """测试嵌套结构工厂""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType # 准备依赖 group = await GroupFactory.create(db_session) diff --git a/tests/fixtures/groups.py b/tests/fixtures/groups.py index 3198c59..dd1fc1c 100644 --- a/tests/fixtures/groups.py +++ b/tests/fixtures/groups.py @@ -5,7 +5,7 @@ """ from sqlmodel.ext.asyncio.session import AsyncSession -from models.group import Group, GroupOptions +from sqlmodels.group import Group, GroupOptions class GroupFactory: diff --git a/tests/fixtures/objects.py b/tests/fixtures/objects.py index 340c0f3..1183b01 100644 --- a/tests/fixtures/objects.py +++ b/tests/fixtures/objects.py @@ -7,8 +7,8 @@ from uuid import UUID from sqlmodel.ext.asyncio.session import AsyncSession -from models.object import Object, ObjectType -from models.user import User +from sqlmodels.object import Object, ObjectType +from sqlmodels.user import User class ObjectFactory: @@ -119,7 +119,7 @@ class ObjectFactory: Object: 创建的根目录实例 """ root = Object( - name=user.username, + name="/", type=ObjectType.FOLDER, parent_id=None, owner_id=user.id, diff --git a/tests/fixtures/users.py b/tests/fixtures/users.py index 838dcf9..e4dfa20 100644 --- a/tests/fixtures/users.py +++ b/tests/fixtures/users.py @@ -7,7 +7,7 @@ from uuid import UUID from sqlmodel.ext.asyncio.session import AsyncSession -from models.user import User +from sqlmodels.user import User from utils.password.pwd import Password @@ -18,7 +18,7 @@ class UserFactory: async def create( session: AsyncSession, group_id: UUID, - username: str | None = None, + email: str | None = None, password: str | None = None, **kwargs ) -> User: @@ -28,7 +28,7 @@ class UserFactory: 参数: session: 数据库会话 group_id: 用户组UUID - username: 用户名(默认: test_user_{随机}) + email: 用户邮箱(默认: test_user_{随机}@test.local) password: 明文密码(默认: password123) **kwargs: 其他用户字段 @@ -37,15 +37,15 @@ class UserFactory: """ import uuid - if username is None: - username = f"test_user_{uuid.uuid4().hex[:8]}" + if email is None: + email = f"test_user_{uuid.uuid4().hex[:8]}@test.local" if password is None: password = "password123" user = User( - username=username, - nickname=kwargs.get("nickname", username), + email=email, + nickname=kwargs.get("nickname", email), password=Password.hash(password), status=kwargs.get("status", True), storage=kwargs.get("storage", 0), @@ -67,7 +67,7 @@ class UserFactory: async def create_admin( session: AsyncSession, admin_group_id: UUID, - username: str | None = None, + email: str | None = None, password: str | None = None ) -> User: """ @@ -76,7 +76,7 @@ class UserFactory: 参数: session: 数据库会话 admin_group_id: 管理员组UUID - username: 用户名(默认: admin_{随机}) + email: 用户邮箱(默认: admin_{随机}@disknext.local) password: 明文密码(默认: admin_password) 返回: @@ -84,15 +84,15 @@ class UserFactory: """ import uuid - if username is None: - username = f"admin_{uuid.uuid4().hex[:8]}" + if email is None: + email = f"admin_{uuid.uuid4().hex[:8]}@disknext.local" if password is None: password = "admin_password" admin = User( - username=username, - nickname=f"管理员 {username}", + email=email, + nickname=f"管理员 {email}", password=Password.hash(password), status=True, storage=0, @@ -108,7 +108,7 @@ class UserFactory: async def create_banned( session: AsyncSession, group_id: UUID, - username: str | None = None + email: str | None = None ) -> User: """ 创建被封禁用户 @@ -116,19 +116,19 @@ class UserFactory: 参数: session: 数据库会话 group_id: 用户组UUID - username: 用户名(默认: banned_user_{随机}) + email: 用户邮箱(默认: banned_user_{随机}@test.local) 返回: User: 创建的被封禁用户实例 """ import uuid - if username is None: - username = f"banned_user_{uuid.uuid4().hex[:8]}" + if email is None: + email = f"banned_user_{uuid.uuid4().hex[:8]}@test.local" banned_user = User( - username=username, - nickname=f"封禁用户 {username}", + email=email, + nickname=f"封禁用户 {email}", password=Password.hash("banned_password"), status=False, # 封禁状态 storage=0, @@ -145,7 +145,7 @@ class UserFactory: session: AsyncSession, group_id: UUID, storage_bytes: int, - username: str | None = None + email: str | None = None ) -> User: """ 创建已使用指定存储空间的用户 @@ -154,19 +154,19 @@ class UserFactory: session: 数据库会话 group_id: 用户组UUID storage_bytes: 已使用的存储空间(字节) - username: 用户名(默认: storage_user_{随机}) + email: 用户邮箱(默认: storage_user_{随机}@test.local) 返回: User: 创建的用户实例 """ import uuid - if username is None: - username = f"storage_user_{uuid.uuid4().hex[:8]}" + if email is None: + email = f"storage_user_{uuid.uuid4().hex[:8]}@test.local" user = User( - username=username, - nickname=username, + email=email, + nickname=email, password=Password.hash("password123"), status=True, storage=storage_bytes, diff --git a/tests/integration/api/test_admin.py b/tests/integration/api/test_admin.py index f5ae705..76d4895 100644 --- a/tests/integration/api/test_admin.py +++ b/tests/integration/api/test_admin.py @@ -124,7 +124,7 @@ async def test_admin_get_user_list_contains_user_data( if len(users) > 0: user = users[0] assert "id" in user - assert "username" in user + assert "email" in user @pytest.mark.asyncio @@ -132,7 +132,7 @@ async def test_admin_create_user_requires_auth(async_client: AsyncClient): """测试创建用户需要认证""" response = await async_client.post( "/api/admin/user/create", - json={"username": "newadminuser", "password": "pass123"} + json={"email": "newadminuser@test.local", "password": "pass123"} ) assert response.status_code == 401 @@ -146,7 +146,7 @@ async def test_admin_create_user_requires_admin( response = await async_client.post( "/api/admin/user/create", headers=auth_headers, - json={"username": "newadminuser", "password": "pass123"} + json={"email": "newadminuser@test.local", "password": "pass123"} ) assert response.status_code == 403 diff --git a/tests/integration/api/test_directory.py b/tests/integration/api/test_directory.py index e89f1ae..7beee5e 100644 --- a/tests/integration/api/test_directory.py +++ b/tests/integration/api/test_directory.py @@ -11,7 +11,7 @@ from uuid import UUID @pytest.mark.asyncio async def test_directory_requires_auth(async_client: AsyncClient): """测试获取目录需要认证""" - response = await async_client.get("/api/directory/testuser") + response = await async_client.get("/api/directory/") assert response.status_code == 401 @@ -24,7 +24,7 @@ async def test_directory_get_root( ): """测试获取用户根目录""" response = await async_client.get( - "/api/directory/testuser", + "/api/directory/", headers=auth_headers ) assert response.status_code == 200 @@ -45,7 +45,7 @@ async def test_directory_get_nested( ): """测试获取嵌套目录""" response = await async_client.get( - "/api/directory/testuser/docs", + "/api/directory/docs", headers=auth_headers ) assert response.status_code == 200 @@ -63,7 +63,7 @@ async def test_directory_get_contains_children( ): """测试目录包含子对象""" response = await async_client.get( - "/api/directory/testuser/docs", + "/api/directory/docs", headers=auth_headers ) assert response.status_code == 200 @@ -75,19 +75,6 @@ async def test_directory_get_contains_children( assert len(objects) >= 1 -@pytest.mark.asyncio -async def test_directory_forbidden_other_user( - async_client: AsyncClient, - auth_headers: dict[str, str] -): - """测试访问他人目录返回 403""" - response = await async_client.get( - "/api/directory/admin", - headers=auth_headers - ) - assert response.status_code == 403 - - @pytest.mark.asyncio async def test_directory_not_found( async_client: AsyncClient, @@ -95,23 +82,23 @@ async def test_directory_not_found( ): """测试目录不存在返回 404""" response = await async_client.get( - "/api/directory/testuser/nonexistent", + "/api/directory/nonexistent", headers=auth_headers ) assert response.status_code == 404 @pytest.mark.asyncio -async def test_directory_empty_path_returns_400( +async def test_directory_root_returns_200( async_client: AsyncClient, auth_headers: dict[str, str] ): - """测试空路径返回 400""" + """测试根目录端点返回 200""" response = await async_client.get( "/api/directory/", headers=auth_headers ) - assert response.status_code == 400 + assert response.status_code == 200 @pytest.mark.asyncio @@ -121,7 +108,7 @@ async def test_directory_response_includes_policy( ): """测试目录响应包含存储策略""" response = await async_client.get( - "/api/directory/testuser", + "/api/directory/", headers=auth_headers ) assert response.status_code == 200 @@ -284,7 +271,7 @@ async def test_directory_create_other_user_parent( """测试在他人目录下创建目录返回 404""" # 先用管理员账号获取管理员的根目录ID admin_response = await async_client.get( - "/api/directory/admin", + "/api/directory/", headers=admin_headers ) assert admin_response.status_code == 200 diff --git a/tests/integration/api/test_user.py b/tests/integration/api/test_user.py index 2fce4c6..c851a8d 100644 --- a/tests/integration/api/test_user.py +++ b/tests/integration/api/test_user.py @@ -16,7 +16,7 @@ async def test_user_login_success( response = await async_client.post( "/api/user/session", data={ - "username": test_user_info["username"], + "username": test_user_info["email"], "password": test_user_info["password"], } ) @@ -38,7 +38,7 @@ async def test_user_login_wrong_password( response = await async_client.post( "/api/user/session", data={ - "username": test_user_info["username"], + "username": test_user_info["email"], "password": "wrongpassword", } ) @@ -51,7 +51,7 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient): response = await async_client.post( "/api/user/session", data={ - "username": "nonexistent", + "username": "nonexistent@test.local", "password": "anypassword", } ) @@ -67,7 +67,7 @@ async def test_user_login_user_banned( response = await async_client.post( "/api/user/session", data={ - "username": banned_user_info["username"], + "username": banned_user_info["email"], "password": banned_user_info["password"], } ) @@ -82,7 +82,7 @@ async def test_user_register_success(async_client: AsyncClient): response = await async_client.post( "/api/user/", json={ - "username": "newuser", + "email": "newuser@test.local", "password": "newpass123", } ) @@ -91,20 +91,20 @@ async def test_user_register_success(async_client: AsyncClient): data = response.json() assert "data" in data assert "user_id" in data["data"] - assert "username" in data["data"] - assert data["data"]["username"] == "newuser" + assert "email" in data["data"] + assert data["data"]["email"] == "newuser@test.local" @pytest.mark.asyncio -async def test_user_register_duplicate_username( +async def test_user_register_duplicate_email( async_client: AsyncClient, test_user_info: dict[str, str] ): - """测试重复用户名返回 400""" + """测试重复邮箱返回 400""" response = await async_client.post( "/api/user/", json={ - "username": test_user_info["username"], + "email": test_user_info["email"], "password": "anypassword", } ) @@ -143,8 +143,8 @@ async def test_user_me_returns_user_info( assert "data" in data user_data = data["data"] assert "id" in user_data - assert "username" in user_data - assert user_data["username"] == "testuser" + assert "email" in user_data + assert user_data["email"] == "testuser@test.local" assert "group" in user_data assert "tags" in user_data diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index a0961fa..bfcba69 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -22,7 +22,7 @@ from sqlalchemy.orm import sessionmaker sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) from main import app -from models import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User +from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User from utils import Password from utils.JWT import create_access_token from utils.JWT import JWT @@ -92,6 +92,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: Setting(type=SettingsType.VIEW, name="home_view_method", value="list"), Setting(type=SettingsType.VIEW, name="share_view_method", value="grid"), Setting(type=SettingsType.AUTHN, name="authn_enabled", value="0"), + Setting(type=SettingsType.CAPTCHA, name="captcha_type", value="default"), Setting(type=SettingsType.CAPTCHA, name="captcha_ReCaptchaKey", value=""), Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""), Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"), @@ -180,7 +181,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: # 6. 创建测试用户 test_user = User( id=uuid4(), - username="testuser", + email="testuser@test.local", password=Password.hash("testpass123"), nickname="测试用户", status=True, @@ -194,7 +195,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: admin_user = User( id=uuid4(), - username="admin", + email="admin@disknext.local", password=Password.hash("adminpass123"), nickname="管理员", status=True, @@ -208,7 +209,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: banned_user = User( id=uuid4(), - username="banneduser", + email="banneduser@test.local", password=Password.hash("banned123"), nickname="封禁用户", status=False, # 封禁状态 @@ -230,7 +231,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: # 7. 创建用户根目录 test_user_root = Object( id=uuid4(), - name=test_user.username, + name="/", type=ObjectType.FOLDER, owner_id=test_user.id, parent_id=None, @@ -241,7 +242,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: admin_user_root = Object( id=uuid4(), - name=admin_user.username, + name="/", type=ObjectType.FOLDER, owner_id=admin_user.id, parent_id=None, @@ -264,7 +265,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession: def test_user_info() -> dict[str, str]: """测试用户信息""" return { - "username": "testuser", + "email": "testuser@test.local", "password": "testpass123", } @@ -273,7 +274,7 @@ def test_user_info() -> dict[str, str]: def admin_user_info() -> dict[str, str]: """管理员用户信息""" return { - "username": "admin", + "email": "admin@disknext.local", "password": "adminpass123", } @@ -282,7 +283,7 @@ def admin_user_info() -> dict[str, str]: def banned_user_info() -> dict[str, str]: """封禁用户信息""" return { - "username": "banneduser", + "email": "banneduser@test.local", "password": "banned123", } @@ -293,7 +294,7 @@ def banned_user_info() -> dict[str, str]: def test_user_token(test_user_info: dict[str, str]) -> str: """生成测试用户的JWT token""" token, _ = JWT.create_access_token( - data={"sub": test_user_info["username"]}, + data={"sub": test_user_info["email"]}, expires_delta=timedelta(hours=1), ) return token @@ -303,7 +304,7 @@ def test_user_token(test_user_info: dict[str, str]) -> str: def admin_user_token(admin_user_info: dict[str, str]) -> str: """生成管理员的JWT token""" token, _ = JWT.create_access_token( - data={"sub": admin_user_info["username"]}, + data={"sub": admin_user_info["email"]}, expires_delta=timedelta(hours=1), ) return token @@ -313,7 +314,7 @@ def admin_user_token(admin_user_info: dict[str, str]) -> str: def expired_token() -> str: """生成过期的JWT token""" token, _ = JWT.create_access_token( - data={"sub": "testuser"}, + data={"sub": "testuser@test.local"}, expires_delta=timedelta(seconds=-1), # 已过期 ) return token @@ -362,7 +363,7 @@ async def test_directory_structure(initialized_db: AsyncSession) -> dict[str, UU """创建测试目录结构""" # 获取测试用户和根目录 - test_user = await User.get(initialized_db, User.username == "testuser") + test_user = await User.get(initialized_db, User.email == "testuser@test.local") test_user_root = await Object.get_root(initialized_db, test_user.id) default_policy = await Policy.get(initialized_db, Policy.name == "本地存储") diff --git a/tests/integration/middleware/test_auth.py b/tests/integration/middleware/test_auth.py index 52b0382..17f17b7 100644 --- a/tests/integration/middleware/test_auth.py +++ b/tests/integration/middleware/test_auth.py @@ -83,7 +83,7 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient): async def test_auth_required_nonexistent_user_token(async_client: AsyncClient): """测试用户不存在的token返回 401""" token, _ = JWT.create_access_token( - data={"sub": "nonexistent_user"}, + data={"sub": "nonexistent_user@test.local"}, expires_delta=timedelta(hours=1) ) @@ -178,12 +178,12 @@ async def test_auth_on_directory_endpoint( ): """测试目录端点应用认证""" # 无认证 - response_no_auth = await async_client.get("/api/directory/testuser") + response_no_auth = await async_client.get("/api/directory/") assert response_no_auth.status_code == 401 # 有认证 response_with_auth = await async_client.get( - "/api/directory/testuser", + "/api/directory/", headers=auth_headers ) assert response_with_auth.status_code == 200 @@ -235,7 +235,7 @@ async def test_auth_on_storage_endpoint( async def test_refresh_token_format(test_user_info: dict[str, str]): """测试刷新token格式正确""" refresh_token, _ = JWT.create_refresh_token( - data={"sub": test_user_info["username"]}, + data={"sub": test_user_info["email"]}, expires_delta=timedelta(days=7) ) @@ -247,7 +247,7 @@ async def test_refresh_token_format(test_user_info: dict[str, str]): async def test_access_token_format(test_user_info: dict[str, str]): """测试访问token格式正确""" access_token, expires = JWT.create_access_token( - data={"sub": test_user_info["username"]}, + data={"sub": test_user_info["email"]}, expires_delta=timedelta(hours=1) ) diff --git a/tests/test_database.py b/tests/test_database.py index d6e48f7..0077e97 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -3,14 +3,14 @@ import pytest @pytest.mark.asyncio async def test_initialize_db(): """测试创建数据库结构""" - from models import database + from sqlmodels import database await database.init_db(url='sqlite:///:memory:') @pytest.fixture async def db_session(): """测试获取数据库连接Session""" - from models import database + from sqlmodels import database await database.init_db(url='sqlite:///:memory:') @@ -20,8 +20,8 @@ async def db_session(): @pytest.mark.asyncio async def test_migration(): """测试数据库创建并初始化配置""" - from models import migration - from models import database + from sqlmodels import migration + from sqlmodels import database await database.init_db(url='sqlite:///:memory:') diff --git a/tests/test_db_group.py b/tests/test_db_group.py index d88ef2a..9dd3494 100644 --- a/tests/test_db_group.py +++ b/tests/test_db_group.py @@ -3,8 +3,8 @@ import pytest @pytest.mark.asyncio async def test_group_curd(): """测试数据库的增删改查""" - from models import database, migration - from models.group import Group + from sqlmodels import database, migration + from sqlmodels.group import Group await database.init_db(url='sqlite+aiosqlite:///:memory:') diff --git a/tests/test_db_settings.py b/tests/test_db_settings.py index 62c73c4..3a79944 100644 --- a/tests/test_db_settings.py +++ b/tests/test_db_settings.py @@ -3,8 +3,8 @@ import pytest @pytest.mark.asyncio async def test_settings_curd(): """测试数据库的增删改查""" - from models import database - from models.setting import Setting + from sqlmodels import database + from sqlmodels.setting import Setting await database.init_db(url='sqlite:///:memory:') diff --git a/tests/test_db_user.py b/tests/test_db_user.py index 3b343b6..0b6a728 100644 --- a/tests/test_db_user.py +++ b/tests/test_db_user.py @@ -3,9 +3,9 @@ import pytest @pytest.mark.asyncio async def test_user_curd(): """测试数据库的增删改查""" - from models import database, migration - from models.group import Group - from models.user import User + from sqlmodels import database, migration + from sqlmodels.group import Group + from sqlmodels.user import User await database.init_db(url='sqlite+aiosqlite:///:memory:') @@ -17,7 +17,7 @@ async def test_user_curd(): created_group = await test_user_group.save(session) test_user = User( - username='test_user', + email='test_user@test.local', password='test_password', group_id=created_group.id ) @@ -27,7 +27,7 @@ async def test_user_curd(): # 验证用户是否存在 assert created_user.id is not None - assert created_user.username == 'test_user' + assert created_user.email == 'test_user@test.local' assert created_user.password == 'test_password' assert created_user.group_id == created_group.id @@ -35,18 +35,18 @@ async def test_user_curd(): fetched_user = await User.get(session, User.id == created_user.id) assert fetched_user is not None - assert fetched_user.username == 'test_user' + assert fetched_user.email == 'test_user@test.local' assert fetched_user.password == 'test_password' assert fetched_user.group_id == created_group.id # 测试改 Update updated_user = await fetched_user.update( session, - {"username": "updated_user", "password": "updated_password"} + {"email": "updated_user@test.local", "password": "updated_password"} ) assert updated_user is not None - assert updated_user.username == 'updated_user' + assert updated_user.email == 'updated_user@test.local' assert updated_user.password == 'updated_password' # 测试删除 Delete diff --git a/tests/unit/models/test_base.py b/tests/unit/models/test_base.py index 3be8298..d9765b8 100644 --- a/tests/unit/models/test_base.py +++ b/tests/unit/models/test_base.py @@ -8,8 +8,8 @@ import pytest from fastapi import HTTPException from sqlmodel.ext.asyncio.session import AsyncSession -from models.user import User -from models.group import Group +from sqlmodels.user import User +from sqlmodels.group import Group @pytest.mark.asyncio @@ -62,7 +62,7 @@ async def test_table_base_update(db_session: AsyncSession): group = await group.save(db_session) # 更新数据 - from models.group import GroupBase + from sqlmodels.group import GroupBase update_data = GroupBase(name="更新后名称") updated_group = await group.update(db_session, update_data) @@ -200,7 +200,7 @@ async def test_timestamps_auto_update(db_session: AsyncSession): await asyncio.sleep(0.1) # 更新记录 - from models.group import GroupBase + from sqlmodels.group import GroupBase update_data = GroupBase(name="更新后的名称") group = await group.update(db_session, update_data) diff --git a/tests/unit/models/test_group.py b/tests/unit/models/test_group.py index 4bea9a3..385e722 100644 --- a/tests/unit/models/test_group.py +++ b/tests/unit/models/test_group.py @@ -4,7 +4,7 @@ Group 和 GroupOptions 模型的单元测试 import pytest from sqlmodel.ext.asyncio.session import AsyncSession -from models.group import Group, GroupOptions, GroupResponse +from sqlmodels.group import Group, GroupOptions, GroupResponse @pytest.mark.asyncio diff --git a/tests/unit/models/test_object.py b/tests/unit/models/test_object.py index d240c89..928b95f 100644 --- a/tests/unit/models/test_object.py +++ b/tests/unit/models/test_object.py @@ -5,21 +5,21 @@ import pytest from sqlalchemy.exc import IntegrityError from sqlmodel.ext.asyncio.session import AsyncSession -from models.object import Object, ObjectType -from models.user import User -from models.group import Group +from sqlmodels.object import Object, ObjectType +from sqlmodels.user import User +from sqlmodels.group import Group @pytest.mark.asyncio async def test_object_create_folder(db_session: AsyncSession): """测试创建目录""" # 创建必要的依赖数据 - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="testuser", password="password", group_id=group.id) + user = User(email="testuser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy( @@ -48,12 +48,12 @@ async def test_object_create_folder(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_create_file(db_session: AsyncSession): """测试创建文件""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="testuser", password="password", group_id=group.id) + user = User(email="testuser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy( @@ -65,7 +65,7 @@ async def test_object_create_file(db_session: AsyncSession): # 创建根目录 root = Object( - name=user.username, + name="/", type=ObjectType.FOLDER, parent_id=None, owner_id=user.id, @@ -81,7 +81,6 @@ async def test_object_create_file(db_session: AsyncSession): owner_id=user.id, policy_id=policy.id, size=1024, - source_name="test_source.txt" ) file = await file.save(db_session) @@ -89,18 +88,17 @@ async def test_object_create_file(db_session: AsyncSession): assert file.name == "test.txt" assert file.type == ObjectType.FILE assert file.size == 1024 - assert file.source_name == "test_source.txt" @pytest.mark.asyncio async def test_object_is_file_property(db_session: AsyncSession): """测试 is_file 属性""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="testuser", password="password", group_id=group.id) + user = User(email="testuser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -122,12 +120,12 @@ async def test_object_is_file_property(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_is_folder_property(db_session: AsyncSession): """测试 is_folder 属性""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="testuser", password="password", group_id=group.id) + user = User(email="testuser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -148,12 +146,12 @@ async def test_object_is_folder_property(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_get_root(db_session: AsyncSession): """测试 get_root() 方法""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="rootuser", password="password", group_id=group.id) + user = User(email="rootuser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -161,7 +159,7 @@ async def test_object_get_root(db_session: AsyncSession): # 创建根目录 root = Object( - name=user.username, + name="/", type=ObjectType.FOLDER, parent_id=None, owner_id=user.id, @@ -180,12 +178,12 @@ async def test_object_get_root(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_get_by_path_root(db_session: AsyncSession): """测试获取根目录""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="pathuser", password="password", group_id=group.id) + user = User(email="pathuser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -193,7 +191,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession): # 创建根目录 root = Object( - name=user.username, + name="/", type=ObjectType.FOLDER, parent_id=None, owner_id=user.id, @@ -202,7 +200,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession): root = await root.save(db_session) # 通过路径获取根目录 - result = await Object.get_by_path(db_session, user.id, "/pathuser", user.username) + result = await Object.get_by_path(db_session, user.id, "/") assert result is not None assert result.id == root.id @@ -211,12 +209,12 @@ async def test_object_get_by_path_root(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_get_by_path_nested(db_session: AsyncSession): """测试获取嵌套路径""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="nesteduser", password="password", group_id=group.id) + user = User(email="nesteduser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -224,7 +222,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession): # 创建目录结构: root -> docs -> work -> project root = Object( - name=user.username, + name="/", type=ObjectType.FOLDER, parent_id=None, owner_id=user.id, @@ -263,8 +261,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession): result = await Object.get_by_path( db_session, user.id, - "/nesteduser/docs/work/project", - user.username + "/docs/work/project", ) assert result is not None @@ -275,12 +272,12 @@ async def test_object_get_by_path_nested(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_get_by_path_not_found(db_session: AsyncSession): """测试路径不存在""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="notfounduser", password="password", group_id=group.id) + user = User(email="notfounduser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -288,7 +285,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession): # 创建根目录 root = Object( - name=user.username, + name="/", type=ObjectType.FOLDER, parent_id=None, owner_id=user.id, @@ -300,8 +297,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession): result = await Object.get_by_path( db_session, user.id, - "/notfounduser/nonexistent", - user.username + "/nonexistent", ) assert result is None @@ -310,12 +306,12 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_get_children(db_session: AsyncSession): """测试 get_children() 方法""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="childrenuser", password="password", group_id=group.id) + user = User(email="childrenuser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -362,12 +358,12 @@ async def test_object_get_children(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_parent_child_relationship(db_session: AsyncSession): """测试父子关系""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="reluser", password="password", group_id=group.id) + user = User(email="reluser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -407,12 +403,12 @@ async def test_object_parent_child_relationship(db_session: AsyncSession): @pytest.mark.asyncio async def test_object_unique_constraint(db_session: AsyncSession): """测试同目录名称唯一约束""" - from models.policy import Policy, PolicyType + from sqlmodels.policy import Policy, PolicyType group = Group(name="测试组") group = await group.save(db_session) - user = User(username="uniqueuser", password="password", group_id=group.id) + user = User(email="uniqueuser", password="password", group_id=group.id) user = await user.save(db_session) policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") @@ -450,3 +446,64 @@ async def test_object_unique_constraint(db_session: AsyncSession): with pytest.raises(IntegrityError): await file2.save(db_session) + + +@pytest.mark.asyncio +async def test_object_get_full_path(db_session: AsyncSession): + """测试 get_full_path() 方法""" + from sqlmodels.policy import Policy, PolicyType + + group = Group(name="测试组") + group = await group.save(db_session) + + user = User(email="pathuser", password="password", group_id=group.id) + user = await user.save(db_session) + + policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") + policy = await policy.save(db_session) + + # 创建目录结构: root -> docs -> images -> photo.jpg + root = Object( + name="/", + type=ObjectType.FOLDER, + parent_id=None, + owner_id=user.id, + policy_id=policy.id + ) + root = await root.save(db_session) + + docs = Object( + name="docs", + type=ObjectType.FOLDER, + parent_id=root.id, + owner_id=user.id, + policy_id=policy.id + ) + docs = await docs.save(db_session) + + images = Object( + name="images", + type=ObjectType.FOLDER, + parent_id=docs.id, + owner_id=user.id, + policy_id=policy.id + ) + images = await images.save(db_session) + + photo = Object( + name="photo.jpg", + type=ObjectType.FILE, + parent_id=images.id, + owner_id=user.id, + policy_id=policy.id, + size=2048 + ) + photo = await photo.save(db_session) + + # 测试完整路径 + full_path = await photo.get_full_path(db_session) + assert full_path == "/docs/images/photo.jpg" + + # 测试根目录的 full_path + root_path = await root.get_full_path(db_session) + assert root_path == "/" diff --git a/tests/unit/models/test_setting.py b/tests/unit/models/test_setting.py index aee9330..2a3aa26 100644 --- a/tests/unit/models/test_setting.py +++ b/tests/unit/models/test_setting.py @@ -5,7 +5,7 @@ import pytest from sqlalchemy.exc import IntegrityError from sqlmodel.ext.asyncio.session import AsyncSession -from models.setting import Setting, SettingsType +from sqlmodels.setting import Setting, SettingsType @pytest.mark.asyncio @@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession): setting = await setting.save(db_session) # 更新值 - from models.base import SQLModelBase + from sqlmodels.base import SQLModelBase class SettingUpdate(SQLModelBase): value: str | None = None diff --git a/tests/unit/models/test_uri.py b/tests/unit/models/test_uri.py new file mode 100644 index 0000000..23a40a1 --- /dev/null +++ b/tests/unit/models/test_uri.py @@ -0,0 +1,273 @@ +""" +DiskNextURI 模型的单元测试 +""" +import pytest + +from sqlmodels.uri import DiskNextURI, FileSystemNamespace + + +class TestDiskNextURIParse: + """测试 URI 解析""" + + def test_parse_my_root(self): + """测试解析个人空间根目录""" + uri = DiskNextURI.parse("disknext://my/") + assert uri.namespace == FileSystemNamespace.MY + assert uri.path == "/" + assert uri.fs_id is None + assert uri.password is None + assert uri.is_root is True + + def test_parse_my_with_path(self): + """测试解析个人空间带路径""" + uri = DiskNextURI.parse("disknext://my/docs/readme.md") + assert uri.namespace == FileSystemNamespace.MY + assert uri.path == "/docs/readme.md" + assert uri.fs_id is None + assert uri.path_parts == ["docs", "readme.md"] + assert uri.is_root is False + + def test_parse_my_with_fs_id(self): + """测试解析带 fs_id 的个人空间""" + uri = DiskNextURI.parse("disknext://some-uuid@my/docs") + assert uri.namespace == FileSystemNamespace.MY + assert uri.fs_id == "some-uuid" + assert uri.path == "/docs" + + def test_parse_share_with_code(self): + """测试解析分享链接""" + uri = DiskNextURI.parse("disknext://abc123@share/") + assert uri.namespace == FileSystemNamespace.SHARE + assert uri.fs_id == "abc123" + assert uri.path == "/" + assert uri.password is None + + def test_parse_share_with_password(self): + """测试解析带密码的分享链接""" + uri = DiskNextURI.parse("disknext://abc123:mypass@share/sub/dir") + assert uri.namespace == FileSystemNamespace.SHARE + assert uri.fs_id == "abc123" + assert uri.password == "mypass" + assert uri.path == "/sub/dir" + + def test_parse_trash(self): + """测试解析回收站""" + uri = DiskNextURI.parse("disknext://trash/") + assert uri.namespace == FileSystemNamespace.TRASH + assert uri.is_root is True + + def test_parse_with_query(self): + """测试解析带查询参数的 URI""" + uri = DiskNextURI.parse("disknext://my/?name=report&type=file") + assert uri.namespace == FileSystemNamespace.MY + assert uri.query is not None + assert uri.query["name"] == "report" + assert uri.query["type"] == "file" + + def test_parse_invalid_scheme(self): + """测试无效的协议前缀""" + with pytest.raises(ValueError, match="disknext://"): + DiskNextURI.parse("http://my/docs") + + def test_parse_invalid_namespace(self): + """测试无效的命名空间""" + with pytest.raises(ValueError, match="无效的命名空间"): + DiskNextURI.parse("disknext://invalid/docs") + + def test_parse_no_namespace(self): + """测试缺少命名空间""" + with pytest.raises(ValueError): + DiskNextURI.parse("disknext://") + + +class TestDiskNextURIBuild: + """测试 URI 构建""" + + def test_build_simple(self): + """测试简单构建""" + uri = DiskNextURI.build(FileSystemNamespace.MY) + assert uri.namespace == FileSystemNamespace.MY + assert uri.path == "/" + assert uri.fs_id is None + + def test_build_with_path(self): + """测试带路径构建""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md") + assert uri.path == "/docs/readme.md" + + def test_build_path_auto_prefix(self): + """测试路径自动添加 / 前缀""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="docs/readme.md") + assert uri.path == "/docs/readme.md" + + def test_build_with_fs_id(self): + """测试带 fs_id 构建""" + uri = DiskNextURI.build( + FileSystemNamespace.SHARE, + fs_id="abc123", + password="secret", + ) + assert uri.fs_id == "abc123" + assert uri.password == "secret" + + +class TestDiskNextURIToString: + """测试 URI 序列化""" + + def test_to_string_simple(self): + """测试简单序列化""" + uri = DiskNextURI.build(FileSystemNamespace.MY) + assert uri.to_string() == "disknext://my/" + + def test_to_string_with_path(self): + """测试带路径序列化""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md") + assert uri.to_string() == "disknext://my/docs/readme.md" + + def test_to_string_with_fs_id(self): + """测试带 fs_id 序列化""" + uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="uuid-123") + assert uri.to_string() == "disknext://uuid-123@my/" + + def test_to_string_with_password(self): + """测试带密码序列化""" + uri = DiskNextURI.build( + FileSystemNamespace.SHARE, + fs_id="code", + password="pass", + ) + assert uri.to_string() == "disknext://code:pass@share/" + + def test_to_string_roundtrip(self): + """测试序列化-反序列化往返""" + original = "disknext://abc123:pass@share/sub/dir" + uri = DiskNextURI.parse(original) + result = uri.to_string() + assert result == original + + +class TestDiskNextURIId: + """测试 id() 方法""" + + def test_id_with_fs_id(self): + """测试有 fs_id 时返回 fs_id""" + uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="my-uuid") + assert uri.id("default") == "my-uuid" + + def test_id_without_fs_id(self): + """测试无 fs_id 时返回默认值""" + uri = DiskNextURI.build(FileSystemNamespace.MY) + assert uri.id("default-uuid") == "default-uuid" + + def test_id_without_fs_id_no_default(self): + """测试无 fs_id 且无默认值时返回 None""" + uri = DiskNextURI.build(FileSystemNamespace.MY) + assert uri.id() is None + + +class TestDiskNextURIJoin: + """测试 join() 方法""" + + def test_join_single(self): + """测试拼接单个路径元素""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs") + joined = uri.join("readme.md") + assert joined.path == "/docs/readme.md" + + def test_join_multiple(self): + """测试拼接多个路径元素""" + uri = DiskNextURI.build(FileSystemNamespace.MY) + joined = uri.join("docs", "work", "report.pdf") + assert joined.path == "/docs/work/report.pdf" + + def test_join_preserves_metadata(self): + """测试 join 保留 namespace 和 fs_id""" + uri = DiskNextURI.build(FileSystemNamespace.SHARE, fs_id="code123") + joined = uri.join("sub") + assert joined.namespace == FileSystemNamespace.SHARE + assert joined.fs_id == "code123" + + +class TestDiskNextURIDirUri: + """测试 dir_uri() 方法""" + + def test_dir_uri_file(self): + """测试获取文件的父目录 URI""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md") + parent = uri.dir_uri() + assert parent.path == "/docs/" + + def test_dir_uri_root(self): + """测试根目录的 dir_uri 返回自身""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/") + parent = uri.dir_uri() + assert parent.path == "/" + + +class TestDiskNextURIRoot: + """测试 root() 方法""" + + def test_root_resets_path(self): + """测试 root 重置路径""" + uri = DiskNextURI.build( + FileSystemNamespace.MY, + path="/docs/work/report.pdf", + fs_id="uuid-123", + ) + root = uri.root() + assert root.path == "/" + assert root.fs_id == "uuid-123" + assert root.namespace == FileSystemNamespace.MY + + +class TestDiskNextURIName: + """测试 name() 方法""" + + def test_name_file(self): + """测试获取文件名""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md") + assert uri.name() == "readme.md" + + def test_name_directory(self): + """测试获取目录名""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work") + assert uri.name() == "work" + + def test_name_root(self): + """测试根目录的 name 返回空字符串""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/") + assert uri.name() == "" + + +class TestDiskNextURIProperties: + """测试属性方法""" + + def test_path_parts(self): + """测试路径分割""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work/report.pdf") + assert uri.path_parts == ["docs", "work", "report.pdf"] + + def test_path_parts_root(self): + """测试根路径分割""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/") + assert uri.path_parts == [] + + def test_is_root_true(self): + """测试 is_root 为真""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/") + assert uri.is_root is True + + def test_is_root_false(self): + """测试 is_root 为假""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs") + assert uri.is_root is False + + def test_str_representation(self): + """测试字符串表示""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs") + assert str(uri) == "disknext://my/docs" + + def test_repr(self): + """测试 repr""" + uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs") + assert "disknext://my/docs" in repr(uri) diff --git a/tests/unit/models/test_user.py b/tests/unit/models/test_user.py index 99038e1..56d8bfb 100644 --- a/tests/unit/models/test_user.py +++ b/tests/unit/models/test_user.py @@ -5,8 +5,8 @@ import pytest from sqlalchemy.exc import IntegrityError from sqlmodel.ext.asyncio.session import AsyncSession -from models.user import User, ThemeType, UserPublic -from models.group import Group +from sqlmodels.user import User, ThemeType, UserPublic +from sqlmodels.group import Group @pytest.mark.asyncio @@ -18,7 +18,7 @@ async def test_user_create(db_session: AsyncSession): # 创建用户 user = User( - username="testuser", + email="testuser@test.local", nickname="测试用户", password="hashed_password", group_id=group.id @@ -26,7 +26,7 @@ async def test_user_create(db_session: AsyncSession): user = await user.save(db_session) assert user.id is not None - assert user.username == "testuser" + assert user.email == "testuser@test.local" assert user.nickname == "测试用户" assert user.status is True assert user.storage == 0 @@ -34,15 +34,15 @@ async def test_user_create(db_session: AsyncSession): @pytest.mark.asyncio -async def test_user_unique_username(db_session: AsyncSession): - """测试用户名唯一约束""" +async def test_user_unique_email(db_session: AsyncSession): + """测试邮箱唯一约束""" # 创建用户组 group = Group(name="默认组") group = await group.save(db_session) # 创建第一个用户 user1 = User( - username="duplicate", + email="duplicate@test.local", password="password1", group_id=group.id ) @@ -50,7 +50,7 @@ async def test_user_unique_username(db_session: AsyncSession): # 尝试创建同名用户 user2 = User( - username="duplicate", + email="duplicate@test.local", password="password2", group_id=group.id ) @@ -68,7 +68,7 @@ async def test_user_to_public(db_session: AsyncSession): # 创建用户 user = User( - username="publicuser", + email="publicuser@test.local", nickname="公开用户", password="secret_password", storage=1024, @@ -82,7 +82,7 @@ async def test_user_to_public(db_session: AsyncSession): assert isinstance(public_user, UserPublic) assert public_user.id == user.id - assert public_user.username == "publicuser" + assert public_user.email == "publicuser@test.local" # 注意: UserPublic.nick 字段名与 User.nickname 不同, # model_validate 不会自动映射,所以 nick 为 None # 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段 @@ -101,7 +101,7 @@ async def test_user_group_relationship(db_session: AsyncSession): # 创建用户 user = User( - username="vipuser", + email="vipuser@test.local", password="password", group_id=group.id ) @@ -125,7 +125,7 @@ async def test_user_status_default(db_session: AsyncSession): group = await group.save(db_session) user = User( - username="defaultuser", + email="defaultuser@test.local", password="password", group_id=group.id ) @@ -141,7 +141,7 @@ async def test_user_storage_default(db_session: AsyncSession): group = await group.save(db_session) user = User( - username="storageuser", + email="storageuser@test.local", password="password", group_id=group.id ) @@ -158,7 +158,7 @@ async def test_user_theme_enum(db_session: AsyncSession): # 测试默认值 user1 = User( - username="user1", + email="user1@test.local", password="password", group_id=group.id ) @@ -167,7 +167,7 @@ async def test_user_theme_enum(db_session: AsyncSession): # 测试设置为 LIGHT user2 = User( - username="user2", + email="user2@test.local", password="password", theme=ThemeType.LIGHT, group_id=group.id @@ -177,7 +177,7 @@ async def test_user_theme_enum(db_session: AsyncSession): # 测试设置为 DARK user3 = User( - username="user3", + email="user3@test.local", password="password", theme=ThemeType.DARK, group_id=group.id diff --git a/tests/unit/service/test_login.py b/tests/unit/service/test_login.py index e57f34b..a9dde15 100644 --- a/tests/unit/service/test_login.py +++ b/tests/unit/service/test_login.py @@ -4,8 +4,8 @@ Login 服务的单元测试 import pytest from sqlmodel.ext.asyncio.session import AsyncSession -from models.user import User, LoginRequest, TokenResponse -from models.group import Group +from sqlmodels.user import User, LoginRequest, TokenResponse +from sqlmodels.group import Group from service.user.login import login from utils.password.pwd import Password @@ -20,7 +20,7 @@ async def setup_user(db_session: AsyncSession): # 创建正常用户 plain_password = "secure_password_123" user = User( - username="loginuser", + email="loginuser@test.local", password=Password.hash(plain_password), status=True, group_id=group.id @@ -41,7 +41,7 @@ async def setup_banned_user(db_session: AsyncSession): group = await group.save(db_session) user = User( - username="banneduser", + email="banneduser@test.local", password=Password.hash("password"), status=False, # 封禁状态 group_id=group.id @@ -61,7 +61,7 @@ async def setup_2fa_user(db_session: AsyncSession): secret = pyotp.random_base32() user = User( - username="2fauser", + email="2fauser@test.local", password=Password.hash("password"), status=True, two_factor=secret, @@ -82,7 +82,7 @@ async def test_login_success(db_session: AsyncSession, setup_user): user_data = setup_user login_request = LoginRequest( - username="loginuser", + email="loginuser@test.local", password=user_data["password"] ) @@ -99,7 +99,7 @@ async def test_login_success(db_session: AsyncSession, setup_user): async def test_login_user_not_found(db_session: AsyncSession): """测试用户不存在""" login_request = LoginRequest( - username="nonexistent_user", + email="nonexistent@test.local", password="any_password" ) @@ -112,7 +112,7 @@ async def test_login_user_not_found(db_session: AsyncSession): async def test_login_wrong_password(db_session: AsyncSession, setup_user): """测试密码错误""" login_request = LoginRequest( - username="loginuser", + email="loginuser@test.local", password="wrong_password" ) @@ -125,7 +125,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user): async def test_login_user_banned(db_session: AsyncSession, setup_banned_user): """测试用户被封禁""" login_request = LoginRequest( - username="banneduser", + email="banneduser@test.local", password="password" ) @@ -140,7 +140,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user): user_data = setup_2fa_user login_request = LoginRequest( - username="2fauser", + email="2fauser@test.local", password=user_data["password"] # 未提供 two_fa_code ) @@ -156,7 +156,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user): user_data = setup_2fa_user login_request = LoginRequest( - username="2fauser", + email="2fauser@test.local", password=user_data["password"], two_fa_code="000000" # 错误的验证码 ) @@ -179,7 +179,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user): valid_code = totp.now() login_request = LoginRequest( - username="2fauser", + email="2fauser@test.local", password=user_data["password"], two_fa_code=valid_code ) @@ -198,7 +198,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user): user_data = setup_user login_request = LoginRequest( - username="loginuser", + email="loginuser@test.local", password=user_data["password"] ) @@ -217,17 +217,17 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user): @pytest.mark.asyncio -async def test_login_case_sensitive_username(db_session: AsyncSession, setup_user): - """测试用户名大小写敏感""" +async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user): + """测试邮箱大小写敏感""" user_data = setup_user - # 使用大写用户名登录(如果数据库是 loginuser) + # 使用大写邮箱登录 login_request = LoginRequest( - username="LOGINUSER", + email="LOGINUSER@TEST.LOCAL", password=user_data["password"] ) result = await login(db_session, login_request) - # 应该失败,因为用户名大小写不匹配 + # 应该失败,因为邮箱大小写不匹配 assert result is None diff --git a/tests/unit/utils/test_password.py b/tests/unit/utils/test_password.py index 057254c..7bf9975 100644 --- a/tests/unit/utils/test_password.py +++ b/tests/unit/utils/test_password.py @@ -72,9 +72,9 @@ def test_password_verify_expired(): @pytest.mark.asyncio async def test_totp_generate(): """测试 TOTP 密钥生成""" - username = "testuser" + email = "testuser@test.local" - response = await Password.generate_totp(username) + response = await Password.generate_totp(email) assert response.setup_token is not None assert response.uri is not None @@ -82,7 +82,7 @@ async def test_totp_generate(): assert isinstance(response.uri, str) # TOTP URI 格式: otpauth://totp/... assert response.uri.startswith("otpauth://totp/") - assert username in response.uri + assert email in response.uri def test_totp_verify_valid(): diff --git a/utils/JWT/__init__.py b/utils/JWT/__init__.py index 408114f..1e6db0f 100644 --- a/utils/JWT/__init__.py +++ b/utils/JWT/__init__.py @@ -4,7 +4,7 @@ from uuid import UUID, uuid4 import jwt from fastapi.security import OAuth2PasswordBearer -from models import AccessTokenBase, RefreshTokenBase +from sqlmodels import AccessTokenBase, RefreshTokenBase, TokenResponse oauth2_scheme = OAuth2PasswordBearer( scheme_name='获取 JWT Bearer 令牌', @@ -21,8 +21,8 @@ async def load_secret_key() -> None: 从数据库读取 JWT 的密钥。 """ # 延迟导入以避免循环依赖 - from models.database import get_session - from models.setting import Setting + from sqlmodels.database import get_session + from sqlmodels.setting import Setting global SECRET_KEY async for session in get_session(): @@ -69,19 +69,29 @@ def build_token_payload( # 访问令牌 def create_access_token( - data: dict, + sub: UUID, + jti: UUID, expires_delta: timedelta | None = None, - algorithm: str = "HS256" + algorithm: str = "HS256", + **kwargs ) -> AccessTokenBase: """ 生成访问令牌,默认有效期 3 小时。 - :param data: 需要放进 JWT Payload 的字段。 + :param sub: 令牌的主题,通常是用户 ID。 + :param jti: 令牌的唯一标识符,通常是一个 UUID。 :param expires_delta: 过期时间, 缺省时为 3 小时。 :param algorithm: JWT 密钥强度,缺省时为 HS256 + :param kwargs: 需要放进 JWT Payload 的字段。 :return: 包含密钥本身和过期时间的 `AccessTokenBase` """ + + data = {"sub": str(sub), "jti": str(jti)} + + # 将额外的字段添加到 Payload 中 + for key, value in kwargs.items(): + data[key] = value access_token, expire_at = build_token_payload( data, @@ -97,19 +107,29 @@ def create_access_token( # 刷新令牌 def create_refresh_token( - data: dict, + sub: UUID, + jti: UUID, expires_delta: timedelta | None = None, - algorithm: str = "HS256" + algorithm: str = "HS256", + **kwargs, ) -> RefreshTokenBase: """ 生成刷新令牌,默认有效期 30 天。 - :param data: 需要放进 JWT Payload 的字段。 + :param sub: 令牌的主题,通常是用户 ID。 + :param jti: 令牌的唯一标识符,通常是一个 UUID。 :param expires_delta: 过期时间, 缺省时为 30 天。 :param algorithm: JWT 密钥强度,缺省时为 HS256 + :param kwargs: 需要放进 JWT Payload 的字段。 :return: 包含密钥本身和过期时间的 `RefreshTokenBase` """ + + data = {"sub": str(sub), "jti": str(jti)} + + # 将额外的字段添加到 Payload 中 + for key, value in kwargs.items(): + data[key] = value refresh_token, expire_at = build_token_payload( data, diff --git a/utils/http/http_exceptions.py b/utils/http/http_exceptions.py index 42ad969..e76f83a 100644 --- a/utils/http/http_exceptions.py +++ b/utils/http/http_exceptions.py @@ -28,6 +28,10 @@ def raise_forbidden(detail: str | None = None, *args, **kwargs) -> NoReturn: """Raises an HTTP 403 Forbidden exception.""" raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs) +def raise_banned(detail: str = "此文件已被管理员封禁,仅允许删除操作", *args, **kwargs) -> NoReturn: + """Raises an HTTP 403 Forbidden exception for banned objects.""" + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs) + def raise_not_found(detail: str | None = None, *args, **kwargs) -> NoReturn: """Raises an HTTP 404 Not Found exception.""" raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail, *args, **kwargs) diff --git a/utils/password/pwd.py b/utils/password/pwd.py index eb96844..9b5a33d 100644 --- a/utils/password/pwd.py +++ b/utils/password/pwd.py @@ -73,6 +73,8 @@ class Password: :param length: 密码长度 :type length: int + :param url_safe: 是否生成 URL 安全的密码 + :type url_safe: bool :return: 随机密码 :rtype: str """