feat: add models for physical files, policies, and user management

- Implement PhysicalFile model to manage physical file references and reference counting.
- Create Policy model with associated options and group links for storage policies.
- Introduce Redeem and Report models for handling redeem codes and reports.
- Add Settings model for site configuration and user settings management.
- Develop Share model for sharing objects with unique codes and associated metadata.
- Implement SourceLink model for managing download links associated with objects.
- Create StoragePack model for managing user storage packages.
- Add Tag model for user-defined tags with manual and automatic types.
- Implement Task model for managing background tasks with status tracking.
- Develop User model with comprehensive user management features including authentication.
- Introduce UserAuthn model for managing WebAuthn credentials.
- Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
2026-02-10 16:25:49 +08:00
parent 62c671e07b
commit 209cb24ab4
92 changed files with 3640 additions and 1444 deletions

View File

@@ -3,7 +3,9 @@
"allow": [ "allow": [
"Bash(git rev-parse:*)", "Bash(git rev-parse:*)",
"Bash(findstr:*)", "Bash(findstr:*)",
"Bash(find:*)" "Bash(find:*)",
"Bash(yarn tsc:*)",
"Bash(dir:*)"
] ]
} }
} }

12
main.py
View File

@@ -1,25 +1,29 @@
from typing import NoReturn from typing import NoReturn
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from utils.conf import appmeta from utils.conf import appmeta
from utils.http.http_exceptions import raise_internal_error from utils.http.http_exceptions import raise_internal_error
from utils.lifespan import lifespan from utils.lifespan import lifespan
from models.database import init_db from sqlmodels.database_connection import DatabaseManager
from models.migration import migration from sqlmodels.migration import migration
from utils import JWT from utils import JWT
from routers import router from routers import router
from service.redis import RedisManager from service.redis import RedisManager
from loguru import logger as l from loguru import logger as l
async def _init_db() -> None:
"""初始化数据库连接引擎"""
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
# 添加初始化数据库启动项 # 添加初始化数据库启动项
lifespan.add_startup(init_db) lifespan.add_startup(_init_db)
lifespan.add_startup(migration) lifespan.add_startup(migration)
lifespan.add_startup(JWT.load_secret_key) lifespan.add_startup(JWT.load_secret_key)
lifespan.add_startup(RedisManager.connect) lifespan.add_startup(RedisManager.connect)
# 添加关闭项 # 添加关闭项
lifespan.add_shutdown(DatabaseManager.close)
lifespan.add_shutdown(RedisManager.disconnect) lifespan.add_shutdown(RedisManager.disconnect)
# 创建应用实例并设置元数据 # 创建应用实例并设置元数据

View File

@@ -4,7 +4,7 @@ from uuid import UUID
from fastapi import Depends from fastapi import Depends
import jwt import jwt
from models.user import User from sqlmodels.user import User
from utils import JWT from utils import JWT
from .dependencies import SessionDep from .dependencies import SessionDep
from utils import http_exceptions from utils import http_exceptions
@@ -25,8 +25,8 @@ async def auth_required(
user_id = UUID(user_id) user_id = UUID(user_id)
# 从数据库获取用户信息 # 从数据库获取用户信息(预加载 group 关系)
user = await User.get(session, User.id == user_id) user = await User.get(session, User.id == user_id, load=User.group)
if not user: if not user:
http_exceptions.raise_unauthorized("账号或密码错误") http_exceptions.raise_unauthorized("账号或密码错误")
@@ -44,8 +44,7 @@ async def admin_required(
使用方法: 使用方法:
>>> APIRouter(dependencies=[Depends(admin_required)]) >>> APIRouter(dependencies=[Depends(admin_required)])
""" """
group = await user.awaitable_attrs.group if user.group.admin:
if group.admin:
return user return user
raise http_exceptions.raise_forbidden("Admin Required") raise http_exceptions.raise_forbidden("Admin Required")

View File

@@ -14,14 +14,14 @@ from uuid import UUID
from fastapi import Depends, Query from fastapi import Depends, Query
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.database import get_session from sqlmodels.database_connection import DatabaseManager
from models.mixin import TimeFilterRequest, TableViewRequest from sqlmodels.mixin import TimeFilterRequest, TableViewRequest
from models.user import UserFilterParams, UserStatus from sqlmodels.user import UserFilterParams, UserStatus
# --- 数据库会话依赖 --- # --- 数据库会话依赖 ---
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(get_session)] SessionDep: TypeAlias = Annotated[AsyncSession, Depends(DatabaseManager.get_session)]
"""数据库会话依赖,用于路由函数中获取数据库会话""" """数据库会话依赖,用于路由函数中获取数据库会话"""
@@ -79,14 +79,14 @@ TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends(_get_table_
async def _get_user_filter_params( async def _get_user_filter_params(
group_id: Annotated[UUID | None, Query(description="按用户组UUID筛选")] = None, group_id: Annotated[UUID | None, Query(description="按用户组UUID筛选")] = None,
username: Annotated[str | None, Query(max_length=50, description="用户名模糊搜索")] = None, email: Annotated[str | None, Query(max_length=50, description="邮箱模糊搜索")] = None,
nickname: Annotated[str | None, Query(max_length=50, description="按昵称模糊搜索")] = None, nickname: Annotated[str | None, Query(max_length=50, description="按昵称模糊搜索")] = None,
status: Annotated[UserStatus | None, Query(description="按用户状态筛选")] = None, status: Annotated[UserStatus | None, Query(description="按用户状态筛选")] = None,
) -> UserFilterParams: ) -> UserFilterParams:
"""解析用户过滤查询参数""" """解析用户过滤查询参数"""
return UserFilterParams( return UserFilterParams(
group_id=group_id, group_id=group_id,
username_contains=username, email_contains=email,
nickname_contains=nickname, nickname_contains=nickname,
status=status, status=status,
) )

View File

@@ -1,456 +0,0 @@
"""
联表继承Joined Table Inheritance的通用工具
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
Usage Example:
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import (
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin
)
# 1. 定义Base类只有字段无表
class ASRBase(SQLModelBase):
name: str
\"\"\"配置名称\"\"\"
base_url: str
\"\"\"服务地址\"\"\"
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
class ASR(
ASRBase,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC
):
\"\"\"ASR配置的抽象基类\"\"\"
# PolymorphicBaseMixin 自动提供:
# - _polymorphic_name 字段
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True当有抽象方法时
# 3. 为第二层子类创建ID Mixin
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. 创建第二层抽象类(如果需要)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
\"\"\"FunASR的抽象基类可能有多个实现\"\"\"
pass
# 5. 创建具体实现类
class FunASRLocal(FunASR, table=True):
\"\"\"FunASR本地部署版本\"\"\"
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
pass
# 6. 获取所有具体子类(用于 selectin_polymorphic
concrete_asrs = ASR.get_concrete_subclasses()
# 返回 [FunASRLocal, ...]
"""
import uuid
from abc import ABC
from uuid import UUID
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from sqlalchemy import String, inspect
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlmodel import Field
from models.base.sqlmodel_base import SQLModelBase
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
"""
动态创建SubclassIdMixin类
在联表继承中,子类需要一个外键指向父表的主键。
此函数生成一个Mixin类提供这个外键字段并自动生成UUID。
Args:
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function'
Returns:
一个Mixin类包含id字段外键 + 主键 + default_factory=uuid.uuid4
Example:
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
... pass
Note:
- 生成的Mixin应该放在继承列表的第一位确保通过MRO覆盖UUIDTableBaseMixin的id
- 生成的类名为 {ParentTableName}SubclassIdMixinPascalCase
- 本项目所有联表继承均使用UUID主键UUIDTableBaseMixin
"""
if not parent_table_name:
raise ValueError("parent_table_name 不能为空")
# 转换为PascalCase作为类名
class_name_parts = parent_table_name.split('_')
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
# 使用闭包捕获parent_table_name
_parent_table_name = parent_table_name
# 创建带有__init_subclass__的mixin类用于在子类定义后修复model_fields
class SubclassIdMixin(SQLModelBase):
# 定义id字段
id: UUID = Field(
default_factory=uuid.uuid4,
foreign_key=f'{_parent_table_name}.id',
primary_key=True,
)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复联表继承中子类字段的default_factory丢失问题。
SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段,
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory
遵循单一真相原则(不硬编码 default_factory
需要修复的字段:
- id: 主键(从父类获取 default_factory
- created_at: 创建时间戳(从父类获取 default_factory
- updated_at: 更新时间戳(从父类获取 default_factory
"""
super().__pydantic_init_subclass__(**kwargs)
if not hasattr(cls, 'model_fields'):
return
def find_original_field_info(field_name: str) -> FieldInfo | None:
"""从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)"""
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, 'model_fields') and field_name in base.model_fields:
field_info = base.model_fields[field_name]
# 跳过被 InstrumentedAttribute 污染的
if not isinstance(field_info.default, InstrumentedAttribute):
return field_info
return None
# 动态检测所有需要修复的字段
# 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断:
# 1. default 是 InstrumentedAttribute被 SQLAlchemy 污染)
# 2. 原始定义有 default_factory 或明确的 default 值
#
# 覆盖场景:
# - UUID主键UUIDTableBaseMixinid 有 default_factory=uuid.uuid4需要修复
# - int主键TableBaseMixinid 用 default=None不需要修复数据库自增
# - created_at/updated_at有 default_factory=now需要修复
# - 外键字段created_by_id等有 default=None需要修复
# - 普通字段name, temperature等无 default_factory不需要修复
#
# MRO 查找保证:
# - 在多重继承场景下MRO 顺序是确定性的
# - find_original_field_info 会找到第一个未被污染且有该字段的父类
for field_name, current_field in cls.model_fields.items():
# 检查是否被污染default 是 InstrumentedAttribute
if not isinstance(current_field.default, InstrumentedAttribute):
continue # 未被污染,跳过
# 从父类查找原始定义
original = find_original_field_info(field_name)
if original is None:
continue # 找不到原始定义,跳过
# 根据原始定义的 default/default_factory 来修复
if original.default_factory:
# 有 default_factory如 uuid.uuid4, now
new_field = FieldInfo(
default_factory=original.default_factory,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
elif original.default is not PydanticUndefined:
# 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined
# PydanticUndefined 表示字段没有默认值(必填)
new_field = FieldInfo(
default=original.default,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
else:
continue # 既没有 default_factory 也没有有效的 default跳过
# 复制SQLModel特有的属性
if hasattr(current_field, 'foreign_key'):
new_field.foreign_key = current_field.foreign_key
if hasattr(current_field, 'primary_key'):
new_field.primary_key = current_field.primary_key
cls.model_fields[field_name] = new_field
# 设置类名和文档
SubclassIdMixin.__name__ = class_name
SubclassIdMixin.__qualname__ = class_name
SubclassIdMixin.__doc__ = f"""
{parent_table_name}子类的ID Mixin
用于{parent_table_name}的子类,提供外键指向父表。
通过MRO确保此id字段覆盖继承的id字段。
"""
return SubclassIdMixin
class AutoPolymorphicIdentityMixin:
"""
自动生成polymorphic_identity的Mixin
使用此Mixin的类会自动根据类名生成polymorphic_identity。
格式:{parent_polymorphic_identity}.{classname_lowercase}
如果没有父类的polymorphic_identity则直接使用类名小写。
Example:
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
... __polymorphic_name: str
...
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
... pass
... # polymorphic_identity 会自动设置为 'function'
...
>>> class CodeInterpreterFunction(Function, table=True):
... pass
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
Note:
- 如果手动在__mapper_args__中指定了polymorphic_identity会被保留
- 此Mixin应该在继承列表中靠后的位置在表基类之前
"""
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
"""
子类化钩子自动生成polymorphic_identity
Args:
polymorphic_identity: 如果手动指定,则使用指定的值
**kwargs: 其他SQLModel参数如table=True, polymorphic_abstract=True
"""
super().__init_subclass__(**kwargs)
# 如果手动指定了polymorphic_identity使用指定的值
if polymorphic_identity is not None:
identity = polymorphic_identity
else:
# 自动生成polymorphic_identity
class_name = cls.__name__.lower()
# 尝试从父类获取polymorphic_identity作为前缀
parent_identity = None
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
parent_identity = base.__mapper_args__.get('polymorphic_identity')
if parent_identity:
break
# 构建identity
if parent_identity:
identity = f'{parent_identity}.{class_name}'
else:
identity = class_name
# 设置到__mapper_args__
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 只在尚未设置polymorphic_identity时设置
if 'polymorphic_identity' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_identity'] = identity
class PolymorphicBaseMixin:
"""
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
此 Mixin 自动设置以下内容:
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
使用场景:
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
用法示例:
```python
from abc import ABC
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
# 定义基类
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
__tablename__ = 'mytool'
# 不需要手动定义 _polymorphic_name
# 不需要手动设置 polymorphic_on
# 不需要手动设置 polymorphic_abstract
# 定义子类
class SpecificTool(MyTool):
__tablename__ = 'specifictool'
# 会自动继承 polymorphic 配置
```
自动行为:
1. 定义 `_polymorphic_name: str` 字段(带索引)
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
3. 自动检测抽象类:
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
- 否则设置为 False
手动覆盖:
可以在类定义时手动指定参数来覆盖自动行为:
```python
class MyTool(
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC,
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
polymorphic_abstract=False # 强制不设为抽象类
):
pass
```
注意事项:
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
- 适用于联表继承joined table inheritance场景
- 子类会自动继承 _polymorphic_name 字段定义
- 使用单下划线前缀是因为:
* SQLAlchemy 会映射单下划线字段为数据库列
* Pydantic 将其视为私有属性,不参与序列化
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
"""
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
#
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
#
# 为什么这样做:
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
# 4. 字段不出现在 Pydantic 序列化中model_dump() 和 JSON schema
# 5. 内部代码仍然可以正常访问和修改此字段
#
# 详细说明请参考sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
"""
多态鉴别器字段,用于标识具体的子类类型
注意:此字段使用单下划线前缀,表示内部使用。
- ✅ 存储到数据库
- ✅ 不出现在 API 序列化中
- ✅ 防止外部直接修改
"""
def __init_subclass__(
cls,
polymorphic_on: str | None = None,
polymorphic_abstract: bool | None = None,
**kwargs
):
"""
在子类定义时自动配置 polymorphic 设置
Args:
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'
设置为其他值可以使用不同的字段作为多态鉴别器。
polymorphic_abstract: 是否为抽象类。
- None: 自动检测(默认)
- True: 强制设为抽象类
- False: 强制设为非抽象类
**kwargs: 传递给父类的其他参数
"""
super().__init_subclass__(**kwargs)
# 初始化 __mapper_args__如果还没有
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 设置 polymorphic_on默认为 _polymorphic_name
if 'polymorphic_on' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
# 自动检测或设置 polymorphic_abstract
if 'polymorphic_abstract' not in cls.__mapper_args__:
if polymorphic_abstract is None:
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
has_abc = ABC in cls.__mro__
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
polymorphic_abstract = has_abc and has_abstract_methods
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
@classmethod
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
"""
递归获取当前类的所有具体(非抽象)子类
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
可在任意多态基类上调用,返回该类的所有非抽象子类。
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
"""
result: list[type[PolymorphicBaseMixin]] = []
for subclass in cls.__subclasses__():
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
mapper = inspect(subclass)
if not mapper.polymorphic_abstract:
result.append(subclass)
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
if hasattr(subclass, 'get_concrete_subclasses'):
result.extend(subclass.get_concrete_subclasses())
return result
@classmethod
def get_polymorphic_discriminator(cls) -> str:
"""
获取多态鉴别字段名
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
:return: 多态鉴别字段名(如 '_polymorphic_name'
:raises ValueError: 如果类未配置 polymorphic_on
"""
polymorphic_on = inspect(cls).polymorphic_on
if polymorphic_on is None:
raise ValueError(
f"{cls.__name__} 未配置 polymorphic_on"
f"请确保正确继承 PolymorphicBaseMixin"
)
return polymorphic_on.key
@classmethod
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
"""
获取 polymorphic_identity 到具体子类的映射
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
:return: identity 到子类的映射字典
"""
result: dict[str, type[PolymorphicBaseMixin]] = {}
for subclass in cls.get_concrete_subclasses():
identity = inspect(subclass).polymorphic_identity
if identity:
result[identity] = subclass
return result

View File

@@ -5,15 +5,15 @@ from loguru import logger as l
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ( from sqlmodels import (
User, ResponseBase, User, ResponseBase,
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo, Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
) )
from models.base import SQLModelBase from sqlmodels.base import SQLModelBase
from models.setting import ( from sqlmodels.setting import (
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse, SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
) )
from models.setting import SettingsType from sqlmodels.setting import SettingsType
from utils import http_exceptions from utils import http_exceptions
from utils.conf import appmeta from utils.conf import appmeta
from .file import admin_file_router from .file import admin_file_router

View File

@@ -5,14 +5,60 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from loguru import logger as l from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from models import ( from sqlmodels import (
Policy, PolicyType, User, ResponseBase, ListResponse, Policy, PolicyType, User, ListResponse,
Object, ObjectType, AdminFileResponse, FileBanRequest, ) Object, ObjectType, AdminFileResponse, FileBanRequest, )
from service.storage import LocalStorageService from service.storage import LocalStorageService
async def _set_ban_recursive(
session: AsyncSession,
obj: Object,
ban: bool,
admin_id: UUID,
reason: str | None,
) -> int:
"""
递归设置封禁状态,返回受影响对象数量。
:param session: 数据库会话
:param obj: 要封禁/解禁的对象
:param ban: True=封禁, False=解禁
:param admin_id: 管理员UUID
:param reason: 封禁原因
:return: 受影响的对象数量
"""
count = 0
# 如果是文件夹,先递归处理子对象
if obj.is_folder:
children = await Object.get(
session,
Object.parent_id == obj.id,
fetch_mode="all",
)
for child in children:
count += await _set_ban_recursive(session, child, ban, admin_id, reason)
# 设置当前对象
obj.is_banned = ban
if ban:
obj.banned_at = datetime.now()
obj.banned_by = admin_id
obj.ban_reason = reason
else:
obj.banned_at = None
obj.banned_by = None
obj.ban_reason = None
await obj.save(session)
count += 1
return count
admin_file_router = APIRouter( admin_file_router = APIRouter(
prefix="/file", prefix="/file",
tags=["admin", "admin_file"], tags=["admin", "admin_file"],
@@ -119,15 +165,17 @@ async def router_admin_preview_file(
summary='封禁/解禁文件', summary='封禁/解禁文件',
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.', description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
dependencies=[Depends(admin_required)], dependencies=[Depends(admin_required)],
status_code=204,
) )
async def router_admin_ban_file( async def router_admin_ban_file(
session: SessionDep, session: SessionDep,
file_id: UUID, file_id: UUID,
request: FileBanRequest, request: FileBanRequest,
admin: Annotated[User, Depends(admin_required)], admin: Annotated[User, Depends(admin_required)],
) -> ResponseBase: ) -> None:
""" """
封禁或解禁文件。封禁后用户无法访问该文件。 封禁或解禁文件/文件夹。封禁后用户无法访问该文件。
封禁文件夹时会级联封禁所有子对象。
:param session: 数据库会话 :param session: 数据库会话
:param file_id: 文件UUID :param file_id: 文件UUID
@@ -139,24 +187,10 @@ async def router_admin_ban_file(
if not file_obj: if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在") raise HTTPException(status_code=404, detail="文件不存在")
file_obj.is_banned = request.is_banned count = await _set_ban_recursive(session, file_obj, request.ban, admin.id, request.reason)
if request.is_banned:
file_obj.banned_at = datetime.now()
file_obj.banned_by = admin.id
file_obj.ban_reason = request.reason
else:
file_obj.banned_at = None
file_obj.banned_by = None
file_obj.ban_reason = None
file_obj = await file_obj.save(session) action = "封禁" if request.ban else "解禁"
l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象")
action = "封禁" if request.is_banned else "解禁"
l.info(f"管理员{action}了文件: {file_obj.name}")
return ResponseBase(data={
"id": str(file_obj.id),
"is_banned": file_obj.is_banned,
})
@admin_file_router.delete( @admin_file_router.delete(
@@ -164,12 +198,13 @@ async def router_admin_ban_file(
summary='删除文件', summary='删除文件',
description='Delete file by ID', description='Delete file by ID',
dependencies=[Depends(admin_required)], dependencies=[Depends(admin_required)],
status_code=204,
) )
async def router_admin_delete_file( async def router_admin_delete_file(
session: SessionDep, session: SessionDep,
file_id: UUID, file_id: UUID,
delete_physical: bool = True, delete_physical: bool = True,
) -> ResponseBase: ) -> None:
""" """
删除文件。 删除文件。
@@ -212,4 +247,3 @@ async def router_admin_delete_file(
await Object.delete(session, condition=Object.id == file_obj.id) await Object.delete(session, condition=Object.id == file_obj.id)
l.info(f"管理员删除了文件: {file_name}") l.info(f"管理员删除了文件: {file_name}")
return ResponseBase(data={"deleted": True})

View File

@@ -5,12 +5,12 @@ from loguru import logger as l
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from models import ( from sqlmodels import (
User, ResponseBase, UserPublic, ListResponse, User, ResponseBase, UserPublic, ListResponse,
Group, GroupOptions, ) Group, GroupOptions, )
from models.group import ( from sqlmodels.group import (
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, ) GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, )
from models.policy import GroupPolicyLink from sqlmodels.policy import GroupPolicyLink
admin_group_router = APIRouter( admin_group_router = APIRouter(
prefix="/group", prefix="/group",
@@ -113,11 +113,12 @@ async def router_admin_get_group_members(
summary='创建用户组', summary='创建用户组',
description='Create a new user group', description='Create a new user group',
dependencies=[Depends(admin_required)], dependencies=[Depends(admin_required)],
status_code=204,
) )
async def router_admin_create_group( async def router_admin_create_group(
session: SessionDep, session: SessionDep,
request: GroupCreateRequest, request: GroupCreateRequest,
) -> ResponseBase: ) -> None:
""" """
创建新的用户组。 创建新的用户组。
@@ -164,7 +165,6 @@ async def router_admin_create_group(
await session.commit() await session.commit()
l.info(f"管理员创建了用户组: {group.name}") l.info(f"管理员创建了用户组: {group.name}")
return ResponseBase(data={"id": str(group.id), "name": group.name})
@admin_group_router.patch( @admin_group_router.patch(
@@ -172,12 +172,13 @@ async def router_admin_create_group(
summary='更新用户组信息', summary='更新用户组信息',
description='Update user group information by ID', description='Update user group information by ID',
dependencies=[Depends(admin_required)], dependencies=[Depends(admin_required)],
status_code=204,
) )
async def router_admin_update_group( async def router_admin_update_group(
session: SessionDep, session: SessionDep,
group_id: UUID, group_id: UUID,
request: GroupUpdateRequest, request: GroupUpdateRequest,
) -> ResponseBase: ) -> None:
""" """
根据用户组ID更新用户组信息。 根据用户组ID更新用户组信息。
@@ -233,8 +234,7 @@ async def router_admin_update_group(
session.add(link) session.add(link)
await session.commit() await session.commit()
l.info(f"管理员更新了用户组: {group.name}") l.info(f"管理员更新了用户组: {group_id}")
return ResponseBase(data={"id": str(group.id)})
@admin_group_router.delete( @admin_group_router.delete(
@@ -242,11 +242,12 @@ async def router_admin_update_group(
summary='删除用户组', summary='删除用户组',
description='Delete user group by ID', description='Delete user group by ID',
dependencies=[Depends(admin_required)], dependencies=[Depends(admin_required)],
status_code=204,
) )
async def router_admin_delete_group( async def router_admin_delete_group(
session: SessionDep, session: SessionDep,
group_id: UUID, group_id: UUID,
) -> ResponseBase: ) -> None:
""" """
根据用户组ID删除用户组。 根据用户组ID删除用户组。
@@ -271,5 +272,4 @@ async def router_admin_delete_group(
group_name = group.name group_name = group.name
await Group.delete(session, group) await Group.delete(session, group)
l.info(f"管理员删除了用户组: {group_name}") l.info(f"管理员删除了用户组: {group_id}")
return ResponseBase(data={"deleted": True})

View File

@@ -6,10 +6,10 @@ from sqlmodel import Field
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from models import ( from sqlmodels import (
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase, Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
ListResponse, Object, ) ListResponse, Object, )
from models.base import SQLModelBase from sqlmodels.base import SQLModelBase
from service.storage import DirectoryCreationError, LocalStorageService from service.storage import DirectoryCreationError, LocalStorageService
admin_policy_router = APIRouter( admin_policy_router = APIRouter(

View File

@@ -5,7 +5,7 @@ from loguru import logger as l
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from models import ( from sqlmodels import (
ResponseBase, ListResponse, ResponseBase, ListResponse,
Share, AdminShareListItem, ) Share, AdminShareListItem, )
@@ -80,7 +80,7 @@ async def router_admin_get_share(
"score": share.score, "score": share.score,
"has_password": bool(share.password), "has_password": bool(share.password),
"user_id": str(share.user_id), "user_id": str(share.user_id),
"username": user.username if user else None, "username": user.email if user else None,
"object": { "object": {
"id": str(obj.id), "id": str(obj.id),
"name": obj.name, "name": obj.name,

View File

@@ -5,7 +5,7 @@ from loguru import logger as l
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
from models import ( from sqlmodels import (
ResponseBase, ListResponse, ResponseBase, ListResponse,
Task, TaskSummary, Task, TaskSummary,
) )
@@ -89,7 +89,7 @@ async def router_admin_get_task(
"progress": task.progress, "progress": task.progress,
"error": task.error, "error": task.error,
"user_id": str(task.user_id), "user_id": str(task.user_id),
"username": user.username if user else None, "username": user.email if user else None,
"props": props.model_dump() if props else None, "props": props.model_dump() if props else None,
"created_at": task.created_at.isoformat(), "created_at": task.created_at.isoformat(),
"updated_at": task.updated_at.isoformat(), "updated_at": task.updated_at.isoformat(),

View File

@@ -6,11 +6,13 @@ from sqlalchemy import func
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep
from models import ( from sqlmodels import (
User, ResponseBase, UserPublic, ListResponse, User, ResponseBase, UserPublic, ListResponse,
Group, Object, ObjectType, ) Group, Object, ObjectType, Setting, SettingsType,
from models.user import ( BatchDeleteRequest,
UserAdminUpdateRequest, UserCalibrateResponse, )
from sqlmodels.user import (
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse,
) )
from utils import Password, http_exceptions from utils import Password, http_exceptions
@@ -26,19 +28,19 @@ admin_user_router = APIRouter(
description='Get user information by ID', description='Get user information by ID',
dependencies=[Depends(admin_required)], dependencies=[Depends(admin_required)],
) )
async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase: async def router_admin_get_user(session: SessionDep, user_id: UUID) -> UserPublic:
""" """
根据用户ID获取用户信息包括用户名、邮箱、注册时间等。 根据用户ID获取用户信息包括用户名、邮箱、注册时间等。
Args: Args:
session(SessionDep): 数据库会话依赖项。 session(SessionDep): 数据库会话依赖项。
user_id (int): 用户ID。 user_id (UUID): 用户ID。
Returns: Returns:
ResponseBase: 包含用户信息的响应模型。 ResponseBase: 包含用户信息的响应模型。
""" """
user = await User.get_exist_one(session, user_id) user = await User.get_exist_one(session, user_id)
return ResponseBase(data=user.to_public().model_dump()) return user.to_public()
@admin_user_router.get( @admin_user_router.get(
@@ -60,7 +62,7 @@ async def router_admin_get_users(
:param filter_params: 用户筛选参数(用户组、用户名、昵称、状态) :param filter_params: 用户筛选参数(用户组、用户名、昵称、状态)
:return: 分页用户列表 :return: 分页用户列表
""" """
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view) result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view, load=User.group)
return ListResponse( return ListResponse(
items=[user.to_public() for user in result.items], items=[user.to_public() for user in result.items],
count=result.count, count=result.count,
@@ -75,22 +77,33 @@ async def router_admin_get_users(
) )
async def router_admin_create_user( async def router_admin_create_user(
session: SessionDep, session: SessionDep,
user: User, request: UserAdminCreateRequest,
) -> ResponseBase: ) -> UserPublic:
""" """
创建一个新的用户,设置用户名、密码等信息。 创建一个新的用户,设置邮箱、密码、用户组等信息。
Returns: :param session: 数据库会话
ResponseBase: 包含创建结果的响应模型。 :param request: 创建用户请求 DTO
:return: 创建结果
""" """
existing_user = await User.get(session, User.username == user.username) existing_user = await User.get(session, User.email == request.email)
if existing_user: if existing_user:
return ResponseBase( raise HTTPException(status_code=409, detail="该邮箱已被注册")
code=400,
msg="User with this username already exists." # 验证用户组存在
group = await Group.get(session, Group.id == request.group_id)
if not group:
raise HTTPException(status_code=400, detail="目标用户组不存在")
user = User(
email=request.email,
password=Password.hash(request.password),
nickname=request.nickname,
group_id=request.group_id,
status=request.status,
) )
user = await user.save(session) user = await user.save(session)
return ResponseBase(data=user.to_public().model_dump()) return user.to_public()
@admin_user_router.patch( @admin_user_router.patch(
@@ -98,12 +111,13 @@ async def router_admin_create_user(
summary='更新用户信息', summary='更新用户信息',
description='Update user information by ID', description='Update user information by ID',
dependencies=[Depends(admin_required)], dependencies=[Depends(admin_required)],
status_code=204
) )
async def router_admin_update_user( async def router_admin_update_user(
session: SessionDep, session: SessionDep,
user_id: UUID, user_id: UUID,
request: UserAdminUpdateRequest, request: UserAdminUpdateRequest,
) -> ResponseBase: ) -> None:
""" """
根据用户ID更新用户信息。 根据用户ID更新用户信息。
@@ -116,8 +130,15 @@ async def router_admin_update_user(
if not user: if not user:
raise HTTPException(status_code=404, detail="用户不存在") raise HTTPException(status_code=404, detail="用户不存在")
# 默认管理员(用户名为 admin不允许更改用户组 # 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
if request.group_id and user.username == "admin" and request.group_id != user.group_id: default_admin_setting = await Setting.get(
session,
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
)
if (request.group_id
and default_admin_setting
and default_admin_setting.value == str(user_id)
and request.group_id != user.group_id):
http_exceptions.raise_forbidden("默认管理员不允许更改用户组") http_exceptions.raise_forbidden("默认管理员不允许更改用户组")
# 如果更新用户组,验证新组存在 # 如果更新用户组,验证新组存在
@@ -143,38 +164,35 @@ async def router_admin_update_user(
setattr(user, key, value) setattr(user, key, value)
user = await user.save(session) user = await user.save(session)
l.info(f"管理员更新了用户: {user.username}") l.info(f"管理员更新了用户: {request.email}")
return ResponseBase(data=user.to_public().model_dump())
@admin_user_router.delete( @admin_user_router.delete(
path='/{user_id}', path='/',
summary='删除用户', summary='删除用户(支持批量)',
description='Delete user by ID', description='Delete users by ID list',
dependencies=[Depends(admin_required)], dependencies=[Depends(admin_required)],
status_code=204,
) )
async def router_admin_delete_user( async def router_admin_delete_users(
session: SessionDep, session: SessionDep,
user_id: UUID, request: BatchDeleteRequest,
) -> ResponseBase: ) -> None:
""" """
根据用户ID删除用户及其所有数据。 批量删除用户及其所有数据。
注意: 这是一个危险操作,会级联删除用户的所有文件、分享、任务等。 注意: 这是一个危险操作,会级联删除用户的所有文件、分享、任务等。
:param session: 数据库会话 :param session: 数据库会话
:param user_id: 用户UUID :param request: 批量删除请求,包含待删除用户的 UUID 列表
:return: 删除结果 :return: 删除结果(已删除数 / 总请求数)
""" """
user = await User.get(session, User.id == user_id) deleted = 0
if not user: for uid in request.ids:
raise HTTPException(status_code=404, detail="用户不存在") user = await User.get(session, User.id == uid)
if user:
username = user.username
await User.delete(session, user) await User.delete(session, user)
l.info(f"管理员删除了用户: {user.email}")
l.info(f"管理员删除了用户: {username}")
return ResponseBase(data={"deleted": True})
@admin_user_router.post( @admin_user_router.post(
@@ -186,7 +204,7 @@ async def router_admin_delete_user(
async def router_admin_calibrate_storage( async def router_admin_calibrate_storage(
session: SessionDep, session: SessionDep,
user_id: UUID, user_id: UUID,
) -> ResponseBase: ) -> UserCalibrateResponse:
""" """
重新计算用户的已用存储空间。 重新计算用户的已用存储空间。
@@ -228,5 +246,5 @@ async def router_admin_calibrate_storage(
file_count=file_count, file_count=file_count,
) )
l.info(f"管理员校准了用户存储: {user.username}, 差值: {actual_storage - previous_storage}") l.info(f"管理员校准了用户存储: {user.email}, 差值: {actual_storage - previous_storage}")
return ResponseBase(data=response.model_dump()) return response

View File

@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ( from sqlmodels import (
ResponseBase, ResponseBase,
) )

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from models import ResponseBase from sqlmodels import ResponseBase
import service.oauth import service.oauth
from utils import http_exceptions from utils import http_exceptions

View File

@@ -1,10 +1,12 @@
from typing import Annotated from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ( from sqlmodels import (
DirectoryCreateRequest, DirectoryCreateRequest,
DirectoryResponse, DirectoryResponse,
Object, Object,
@@ -14,50 +16,28 @@ from models import (
User, User,
ResponseBase, ResponseBase,
) )
from utils import http_exceptions
directory_router = APIRouter( directory_router = APIRouter(
prefix="/directory", prefix="/directory",
tags=["directory"] tags=["directory"]
) )
@directory_router.get(
path="/{path:path}", async def _get_directory_response(
summary="获取目录内容", session: AsyncSession,
) user_id: UUID,
async def router_directory_get( folder: Object,
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
path: str
) -> DirectoryResponse: ) -> DirectoryResponse:
""" """
获取目录内容 构建目录响应 DTO
路径必须以用户名或 `.crash` 开头,如 /api/directory/admin 或 /api/directory/admin/docs
`.crash` 代表回收站,也就意味着用户名禁止为 `.crash`
:param session: 数据库会话 :param session: 数据库会话
:param user: 当前登录用户 :param user_id: 用户UUID
:param path: 目录路径(必须以用户名开头) :param folder: 目录对象
:return: 目录内容 :return: DirectoryResponse
""" """
# 路径必须以用户名开头 children = await Object.get_children(session, user_id, folder.id)
path = path.strip("/")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空,请使用 /{username} 格式")
path_parts = path.split("/")
if path_parts[0] != user.username:
raise HTTPException(status_code=403, detail="无权访问其他用户的目录")
folder = await Object.get_by_path(session, user.id, "/" + path, user.username)
if not folder:
raise HTTPException(status_code=404, detail="目录不存在")
if not folder.is_folder:
raise HTTPException(status_code=400, detail="指定路径不是目录")
children = await Object.get_children(session, user.id, folder.id)
policy = await folder.awaitable_attrs.policy policy = await folder.awaitable_attrs.policy
objects = [ objects = [
@@ -67,8 +47,8 @@ async def router_directory_get(
thumb=False, thumb=False,
size=child.size, size=child.size,
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE, type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
date=child.updated_at, created_at=child.created_at,
create_date=child.created_at, updated_at=child.updated_at,
source_enabled=False, source_enabled=False,
) )
for child in children for child in children
@@ -89,7 +69,74 @@ async def router_directory_get(
) )
@directory_router.put( @directory_router.get(
path="/",
summary="获取根目录内容",
)
async def router_directory_root(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> DirectoryResponse:
"""
获取当前用户的根目录内容
:param session: 数据库会话
:param user: 当前登录用户
:return: 根目录内容
"""
root = await Object.get_root(session, user.id)
if not root:
raise HTTPException(status_code=404, detail="根目录不存在")
if root.is_banned:
http_exceptions.raise_banned()
return await _get_directory_response(session, user.id, root)
@directory_router.get(
path="/{path:path}",
summary="获取目录内容",
)
async def router_directory_get(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
path: str
) -> DirectoryResponse:
"""
获取目录内容
路径从用户根目录开始,不包含用户名前缀。
如 /api/v1/directory/docs 表示根目录下的 docs 目录。
:param session: 数据库会话
:param user: 当前登录用户
:param path: 目录路径(从根目录开始的相对路径)
:return: 目录内容
"""
path = path.strip("/")
if not path:
# 空路径交给根目录端点处理(理论上不会到达这里)
root = await Object.get_root(session, user.id)
if not root:
raise HTTPException(status_code=404, detail="根目录不存在")
return await _get_directory_response(session, user.id, root)
folder = await Object.get_by_path(session, user.id, "/" + path)
if not folder:
raise HTTPException(status_code=404, detail="目录不存在")
if not folder.is_folder:
raise HTTPException(status_code=400, detail="指定路径不是目录")
if folder.is_banned:
http_exceptions.raise_banned()
return await _get_directory_response(session, user.id, folder)
@directory_router.post(
path="/", path="/",
summary="创建目录", summary="创建目录",
) )
@@ -123,6 +170,9 @@ async def router_directory_create(
if not parent.is_folder: if not parent.is_folder:
raise HTTPException(status_code=400, detail="父路径不是目录") raise HTTPException(status_code=400, detail="父路径不是目录")
if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 检查是否已存在同名对象 # 检查是否已存在同名对象
existing = await Object.get( existing = await Object.get(
session, session,

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import auth_required from middleware.auth import auth_required
from models import ResponseBase from sqlmodels import ResponseBase
from utils import http_exceptions from utils import http_exceptions
download_router = APIRouter( download_router = APIRouter(

View File

@@ -18,7 +18,7 @@ from loguru import logger as l
from middleware.auth import auth_required, verify_download_token from middleware.auth import auth_required, verify_download_token
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ( from sqlmodels import (
CreateFileRequest, CreateFileRequest,
CreateUploadSessionRequest, CreateUploadSessionRequest,
Object, Object,
@@ -91,6 +91,9 @@ async def create_upload_session(
if not parent.is_folder: if not parent.is_folder:
raise HTTPException(status_code=400, detail="父对象不是目录") raise HTTPException(status_code=400, detail="父对象不是目录")
if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 确定存储策略 # 确定存储策略
policy_id = request.policy_id or parent.policy_id policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id) policy = await Policy.get(session, Policy.id == policy_id)
@@ -100,7 +103,7 @@ async def create_upload_session(
# 验证文件大小限制 # 验证文件大小限制
if policy.max_size > 0 and request.file_size > policy.max_size: if policy.max_size > 0 and request.file_size > policy.max_size:
raise HTTPException( raise HTTPException(
status_code=400, status_code=413,
detail=f"文件大小超过限制 ({policy.max_size} bytes)" detail=f"文件大小超过限制 ({policy.max_size} bytes)"
) )
@@ -221,30 +224,40 @@ async def upload_chunk(
upload_session.uploaded_size += len(content) upload_session.uploaded_size += len(content)
upload_session = await upload_session.save(session) upload_session = await upload_session.save(session)
# 检查是否完成 # 在后续可能的 commit 前保存需要的属性
is_complete = upload_session.is_complete is_complete = upload_session.is_complete
uploaded_chunks = upload_session.uploaded_chunks
total_chunks = upload_session.total_chunks
file_object_id: UUID | None = None file_object_id: UUID | None = None
if is_complete: if is_complete:
# 保存 upload_session 属性commit 后会过期)
file_name = upload_session.file_name
uploaded_size = upload_session.uploaded_size
storage_path = upload_session.storage_path
upload_session_id = upload_session.id
parent_id = upload_session.parent_id
policy_id = upload_session.policy_id
# 创建 PhysicalFile 记录 # 创建 PhysicalFile 记录
physical_file = PhysicalFile( physical_file = PhysicalFile(
storage_path=upload_session.storage_path, storage_path=storage_path,
size=upload_session.uploaded_size, size=uploaded_size,
policy_id=upload_session.policy_id, policy_id=policy_id,
reference_count=1, reference_count=1,
) )
physical_file = await physical_file.save(session, commit=False) physical_file = await physical_file.save(session, commit=False)
# 创建 Object 记录 # 创建 Object 记录
file_object = Object( file_object = Object(
name=upload_session.file_name, name=file_name,
type=ObjectType.FILE, type=ObjectType.FILE,
size=upload_session.uploaded_size, size=uploaded_size,
physical_file_id=physical_file.id, physical_file_id=physical_file.id,
upload_session_id=str(upload_session.id), upload_session_id=str(upload_session_id),
parent_id=upload_session.parent_id, parent_id=parent_id,
owner_id=user_id, owner_id=user_id,
policy_id=upload_session.policy_id, policy_id=policy_id,
) )
file_object = await file_object.save(session, commit=False) file_object = await file_object.save(session, commit=False)
file_object_id = file_object.id file_object_id = file_object.id
@@ -252,18 +265,18 @@ async def upload_chunk(
# 删除上传会话(使用条件删除) # 删除上传会话(使用条件删除)
await UploadSession.delete( await UploadSession.delete(
session, session,
condition=UploadSession.id == upload_session.id, condition=UploadSession.id == upload_session_id,
commit=False commit=False
) )
# 统一提交所有更改 # 统一提交所有更改
await session.commit() await session.commit()
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}") l.info(f"文件上传完成: {file_name}, size={uploaded_size}, id={file_object_id}")
return UploadChunkResponse( return UploadChunkResponse(
uploaded_chunks=upload_session.uploaded_chunks if not is_complete else upload_session.total_chunks, uploaded_chunks=uploaded_chunks if not is_complete else total_chunks,
total_chunks=upload_session.total_chunks, total_chunks=total_chunks,
is_complete=is_complete, is_complete=is_complete,
object_id=file_object_id, object_id=file_object_id,
) )
@@ -368,6 +381,9 @@ async def create_download_token_endpoint(
if not file_obj.is_file: if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件") raise HTTPException(status_code=400, detail="对象不是文件")
if file_obj.is_banned:
http_exceptions.raise_banned()
token = create_download_token(file_id, user.id) token = create_download_token(file_id, user.id)
l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}") l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}")
@@ -410,6 +426,9 @@ async def download_file(
if not file_obj.is_file: if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件") raise HTTPException(status_code=400, detail="对象不是文件")
if file_obj.is_banned:
http_exceptions.raise_banned()
# 预加载 physical_file 关系以获取存储路径 # 预加载 physical_file 关系以获取存储路径
physical_file = await file_obj.awaitable_attrs.physical_file physical_file = await file_obj.awaitable_attrs.physical_file
if not physical_file or not physical_file.storage_path: if not physical_file or not physical_file.storage_path:
@@ -470,6 +489,9 @@ async def create_empty_file(
if not parent.is_folder: if not parent.is_folder:
raise HTTPException(status_code=400, detail="父对象不是目录") raise HTTPException(status_code=400, detail="父对象不是目录")
if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 检查是否已存在同名文件 # 检查是否已存在同名文件
existing = await Object.get( existing = await Object.get(
session, session,

View File

@@ -1,19 +0,0 @@
from fastapi import APIRouter
from models import MCPRequestBase, MCPResponseBase, MCPMethod
# MCP 路由
MCP_router = APIRouter(
prefix='/mcp',
tags=["mcp"],
)
@MCP_router.get(
"/",
)
async def mcp_root(
param: MCPRequestBase
):
match param.method:
case MCPMethod.PING:
return MCPResponseBase(result="pong", **param.model_dump())

View File

@@ -14,7 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ( from sqlmodels import (
CreateFileRequest,
Object, Object,
ObjectCopyRequest, ObjectCopyRequest,
ObjectDeleteRequest, ObjectDeleteRequest,
@@ -26,10 +27,11 @@ from models import (
PhysicalFile, PhysicalFile,
Policy, Policy,
PolicyType, PolicyType,
ResponseBase,
User, User,
) )
from models import ResponseBase
from service.storage import LocalStorageService from service.storage import LocalStorageService
from utils import http_exceptions
object_router = APIRouter( object_router = APIRouter(
prefix="/object", prefix="/object",
@@ -59,15 +61,22 @@ async def _delete_object_recursive(
""" """
deleted_count = 0 deleted_count = 0
if obj.is_folder: # 在任何数据库操作前保存所有需要的属性,避免 commit 后对象过期导致懒加载失败
obj_id = obj.id
obj_name = obj.name
obj_is_folder = obj.is_folder
obj_is_file = obj.is_file
obj_physical_file_id = obj.physical_file_id
if obj_is_folder:
# 递归删除子对象 # 递归删除子对象
children = await Object.get_children(session, user_id, obj.id) children = await Object.get_children(session, user_id, obj_id)
for child in children: for child in children:
deleted_count += await _delete_object_recursive(session, child, user_id) deleted_count += await _delete_object_recursive(session, child, user_id)
# 如果是文件,处理物理文件引用 # 如果是文件,处理物理文件引用
if obj.is_file and obj.physical_file_id: if obj_is_file and obj_physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id) physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj_physical_file_id)
if physical_file: if physical_file:
# 减少引用计数 # 减少引用计数
new_count = physical_file.decrement_reference() new_count = physical_file.decrement_reference()
@@ -81,11 +90,11 @@ async def _delete_object_recursive(
await storage_service.move_to_trash( await storage_service.move_to_trash(
source_path=physical_file.storage_path, source_path=physical_file.storage_path,
user_id=user_id, user_id=user_id,
object_id=obj.id, object_id=obj_id,
) )
l.debug(f"物理文件已移动到回收站: {obj.name}") l.debug(f"物理文件已移动到回收站: {obj_name}")
except Exception as e: except Exception as e:
l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}") l.warning(f"移动物理文件到回收站失败: {obj_name}, 错误: {e}")
# 删除 PhysicalFile 记录 # 删除 PhysicalFile 记录
await PhysicalFile.delete(session, physical_file) await PhysicalFile.delete(session, physical_file)
@@ -95,8 +104,8 @@ async def _delete_object_recursive(
await physical_file.save(session) await physical_file.save(session)
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}") l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
# 删除数据库记录 # 使用条件删除,避免访问过期的 obj 实例
await Object.delete(session, obj) await Object.delete(session, condition=Object.id == obj_id)
deleted_count += 1 deleted_count += 1
return deleted_count return deleted_count
@@ -168,6 +177,97 @@ async def _copy_object_recursive(
return copied_count, new_ids return copied_count, new_ids
@object_router.post(
path='/',
summary='创建空白文件',
description='在指定目录下创建空白文件。',
)
async def router_object_create(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: CreateFileRequest,
) -> ResponseBase:
"""
创建空白文件端点
:param session: 数据库会话
:param user: 当前登录用户
:param request: 创建文件请求parent_id, name
:return: 创建结果
"""
user_id = user.id
# 验证文件名
if not request.name or '/' in request.name or '\\' in request.name:
raise HTTPException(status_code=400, detail="无效的文件名")
# 验证父目录
parent = await Object.get(session, Object.id == request.parent_id)
if not parent or parent.owner_id != user_id:
raise HTTPException(status_code=404, detail="父目录不存在")
if not parent.is_folder:
raise HTTPException(status_code=400, detail="父对象不是目录")
if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 检查是否已存在同名文件
existing = await Object.get(
session,
(Object.owner_id == user_id) &
(Object.parent_id == parent.id) &
(Object.name == request.name)
)
if existing:
raise HTTPException(status_code=409, detail="同名文件已存在")
# 确定存储策略
policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
parent_id = parent.id
# 生成存储路径并创建空文件
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path(
user_id=user_id,
original_filename=request.name,
)
await storage_service.create_empty_file(full_path)
storage_path = full_path
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
storage_path=storage_path,
size=0,
policy_id=policy_id,
reference_count=1,
)
physical_file = await physical_file.save(session)
# 创建 Object 记录
file_object = Object(
name=request.name,
type=ObjectType.FILE,
size=0,
physical_file_id=physical_file.id,
parent_id=parent_id,
owner_id=user_id,
policy_id=policy_id,
)
await file_object.save(session)
l.info(f"创建空白文件: {request.name}")
return ResponseBase()
@object_router.delete( @object_router.delete(
path='/', path='/',
summary='删除对象', summary='删除对象',
@@ -197,10 +297,7 @@ async def router_object_delete(
user_id = user.id user_id = user.id
deleted_count = 0 deleted_count = 0
# 处理单个 UUID 或 UUID 列表 for obj_id in request.ids:
ids = request.ids if isinstance(request.ids, list) else [request.ids]
for obj_id in ids:
obj = await Object.get(session, Object.id == obj_id) obj = await Object.get(session, Object.id == obj_id)
if not obj or obj.owner_id != user_id: if not obj or obj.owner_id != user_id:
continue continue
@@ -219,7 +316,7 @@ async def router_object_delete(
return ResponseBase( return ResponseBase(
data={ data={
"deleted": deleted_count, "deleted": deleted_count,
"total": len(ids), "total": len(request.ids),
} }
) )
@@ -253,6 +350,9 @@ async def router_object_move(
if not dst.is_folder: if not dst.is_folder:
raise HTTPException(status_code=400, detail="目标不是有效文件夹") raise HTTPException(status_code=400, detail="目标不是有效文件夹")
if dst.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 存储 dst 的属性,避免后续数据库操作导致 dst 过期后无法访问 # 存储 dst 的属性,避免后续数据库操作导致 dst 过期后无法访问
dst_id = dst.id dst_id = dst.id
dst_parent_id = dst.parent_id dst_parent_id = dst.parent_id
@@ -264,6 +364,9 @@ async def router_object_move(
if not src or src.owner_id != user_id: if not src or src.owner_id != user_id:
continue continue
if src.is_banned:
continue
# 不能移动根目录 # 不能移动根目录
if src.parent_id is None: if src.parent_id is None:
continue continue
@@ -348,6 +451,9 @@ async def router_object_copy(
if not dst.is_folder: if not dst.is_folder:
raise HTTPException(status_code=400, detail="目标不是有效文件夹") raise HTTPException(status_code=400, detail="目标不是有效文件夹")
if dst.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
copied_count = 0 copied_count = 0
new_ids: list[UUID] = [] new_ids: list[UUID] = []
@@ -356,6 +462,9 @@ async def router_object_copy(
if not src or src.owner_id != user_id: if not src or src.owner_id != user_id:
continue continue
if src.is_banned:
continue
# 不能复制根目录 # 不能复制根目录
if src.parent_id is None: if src.parent_id is None:
continue continue
@@ -438,6 +547,9 @@ async def router_object_rename(
if obj.owner_id != user_id: if obj.owner_id != user_id:
raise HTTPException(status_code=403, detail="无权操作此对象") raise HTTPException(status_code=403, detail="无权操作此对象")
if obj.is_banned:
http_exceptions.raise_banned()
# 不能重命名根目录 # 不能重命名根目录
if obj.parent_id is None: if obj.parent_id is None:
raise HTTPException(status_code=400, detail="无法重命名根目录") raise HTTPException(status_code=400, detail="无法重命名根目录")
@@ -543,7 +655,7 @@ async def router_object_property_detail(
policy_name = policy.name if policy else None policy_name = policy.name if policy else None
# 获取分享统计 # 获取分享统计
from models import Share from sqlmodels import Share
shares = await Share.get( shares = await Share.get(
session, session,
Share.object_id == obj.id, Share.object_id == obj.id,

View File

@@ -7,11 +7,11 @@ from loguru import logger as l
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ResponseBase from sqlmodels import ResponseBase
from models.user import User from sqlmodels.user import User
from models.share import Share, ShareCreateRequest, ShareResponse from sqlmodels.share import Share, ShareCreateRequest, ShareResponse
from models.object import Object from sqlmodels.object import Object
from models.mixin import ListResponse, TableViewRequest from sqlmodels.mixin import ListResponse, TableViewRequest
from utils import http_exceptions from utils import http_exceptions
from utils.password.pwd import Password from utils.password.pwd import Password
@@ -72,23 +72,6 @@ def router_share_preview(id: str) -> ResponseBase:
""" """
http_exceptions.raise_not_implemented() http_exceptions.raise_not_implemented()
@share_router.get(
path='/doc/{id}',
summary='取得Office文档预览地址',
description='Get Office document preview URL by ID.',
)
def router_share_doc(id: str) -> ResponseBase:
"""
Get Office document preview URL by ID.
Args:
id (str): The ID of the Office document.
Returns:
dict: A dictionary containing the document preview URL.
"""
http_exceptions.raise_not_implemented()
@share_router.get( @share_router.get(
path='/content/{id}', path='/content/{id}',
summary='获取文本文件内容', summary='获取文本文件内容',
@@ -261,6 +244,9 @@ async def router_share_create(
if not obj or obj.owner_id != user.id: if not obj or obj.owner_id != user.id:
raise HTTPException(status_code=404, detail="对象不存在或无权限") raise HTTPException(status_code=404, detail="对象不存在或无权限")
if obj.is_banned:
http_exceptions.raise_banned()
# 生成分享码 # 生成分享码
code = str(uuid4()) code = str(uuid4())

View File

@@ -1,7 +1,8 @@
from fastapi import APIRouter from fastapi import APIRouter
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse from sqlmodels import ResponseBase, Setting, SettingsType, SiteConfigResponse
from sqlmodels.setting import CaptchaType
from utils import http_exceptions from utils import http_exceptions
site_router = APIRouter( site_router = APIRouter(
@@ -43,16 +44,43 @@ def router_site_captcha():
@site_router.get( @site_router.get(
path='/config', path='/config',
summary='站点全局配置', summary='站点全局配置',
description='Get the configuration file.', description='获取站点全局配置,包括验证码设置、注册开关等。',
response_model=ResponseBase,
) )
async def router_site_config(session: SessionDep) -> SiteConfigResponse: async def router_site_config(session: SessionDep) -> SiteConfigResponse:
""" """
Get the configuration file. 获取站点全局配置
Returns: 无需认证。前端在初始化时调用此端点获取验证码类型、
dict: The site configuration. 登录/注册/找回密码是否需要验证码等配置。
""" """
return SiteConfigResponse( # 批量查询所需设置
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")), settings: list[Setting] = await Setting.get(
session,
(Setting.type == SettingsType.BASIC) |
(Setting.type == SettingsType.LOGIN) |
(Setting.type == SettingsType.REGISTER) |
(Setting.type == SettingsType.CAPTCHA),
fetch_mode="all",
)
# 构建 name→value 映射
s: dict[str, str | None] = {item.name: item.value for item in settings}
# 根据 captcha_type 选择对应的 public key
captcha_type_str = s.get("captcha_type", "default")
captcha_type = CaptchaType(captcha_type_str) if captcha_type_str else CaptchaType.DEFAULT
captcha_key: str | None = None
if captcha_type == CaptchaType.GCAPTCHA:
captcha_key = s.get("captcha_ReCaptchaKey") or None
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
captcha_key = s.get("captcha_CloudflareKey") or None
return SiteConfigResponse(
title=s.get("siteName") or "DiskNext",
register_enabled=s.get("register_enabled") == "1",
login_captcha=s.get("login_captcha") == "1",
reg_captcha=s.get("reg_captcha") == "1",
forget_captcha=s.get("forget_captcha") == "1",
captcha_type=captcha_type,
captcha_key=captcha_key,
) )

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from middleware.auth import auth_required from middleware.auth import auth_required
from models import ResponseBase from sqlmodels import ResponseBase
from utils import http_exceptions from utils import http_exceptions
slave_router = APIRouter( slave_router = APIRouter(

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import auth_required from middleware.auth import auth_required
from models import ResponseBase from sqlmodels import ResponseBase
from utils import http_exceptions from utils import http_exceptions
tag_router = APIRouter( tag_router = APIRouter(

View File

@@ -1,30 +1,26 @@
from typing import Annotated, Literal from typing import Annotated, Literal
from uuid import UUID from uuid import UUID, uuid4
import jwt
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from loguru import logger
from webauthn import generate_registration_options from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict from webauthn.helpers import options_to_json_dict
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from loguru import logger
import models
import service import service
import sqlmodels
from middleware.auth import auth_required from middleware.auth import auth_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from utils.JWT import SECRET_KEY from utils import JWT, Password, http_exceptions
from utils import Password, http_exceptions from .settings import user_settings_router
user_router = APIRouter( user_router = APIRouter(
prefix="/user", prefix="/user",
tags=["user"], tags=["user"],
) )
user_settings_router = APIRouter( user_router.include_router(user_settings_router)
prefix='/user/settings',
tags=["user", "user_settings"],
dependencies=[Depends(auth_required)],
)
@user_router.post( @user_router.post(
path='/session', path='/session',
@@ -34,7 +30,7 @@ user_settings_router = APIRouter(
async def router_user_session( async def router_user_session(
session: SessionDep, session: SessionDep,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> models.TokenResponse: ) -> sqlmodels.TokenResponse:
""" """
用户登录端点。 用户登录端点。
@@ -43,7 +39,7 @@ async def router_user_session(
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码 OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
""" """
username = form_data.username email = form_data.username # OAuth2 表单字段名为 username实际传入的是 email
password = form_data.password password = form_data.password
# 从 scopes 中提取 OTP 验证码OAuth2.1 扩展方式) # 从 scopes 中提取 OTP 验证码OAuth2.1 扩展方式)
@@ -59,8 +55,8 @@ async def router_user_session(
result = await service.user.login( result = await service.user.login(
session, session,
models.LoginRequest( sqlmodels.LoginRequest(
username=username, email=email,
password=password, password=password,
two_fa_code=otp_code, two_fa_code=otp_code,
), ),
@@ -75,19 +71,70 @@ async def router_user_session(
) )
async def router_user_session_refresh( async def router_user_session_refresh(
session: SessionDep, session: SessionDep,
request, # RefreshTokenRequest request: sqlmodels.RefreshTokenRequest,
) -> models.TokenResponse: ) -> sqlmodels.TokenResponse:
http_exceptions.raise_not_implemented() """
使用 refresh_token 签发新的 access_token 和 refresh_token。
流程:
1. 解码 refresh_token JWT
2. 验证 token_type 为 refresh
3. 验证用户存在且状态正常
4. 签发新的 access_token + refresh_token
:param session: 数据库会话
:param request: 刷新令牌请求
:return: 新的 TokenResponse
"""
try:
payload = jwt.decode(request.refresh_token, JWT.SECRET_KEY, algorithms=["HS256"])
except jwt.InvalidTokenError:
http_exceptions.raise_unauthorized("刷新令牌无效或已过期")
# 验证是 refresh token
if payload.get("token_type") != "refresh":
http_exceptions.raise_unauthorized("非刷新令牌")
user_id_str = payload.get("sub")
if not user_id_str:
http_exceptions.raise_unauthorized("令牌缺少用户标识")
user_id = UUID(user_id_str)
user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
if not user.status:
http_exceptions.raise_forbidden("账户已被禁用")
# 签发新令牌
access_token = JWT.create_access_token(
sub=user.id,
jti=uuid4(),
)
refresh_token = JWT.create_refresh_token(
sub=user.id,
jti=uuid4(),
)
return sqlmodels.TokenResponse(
access_token=access_token.access_token,
access_expires=access_token.access_expires,
refresh_token=refresh_token.refresh_token,
refresh_expires=refresh_token.refresh_expires,
)
@user_router.post( @user_router.post(
path='/', path='/',
summary='用户注册', summary='用户注册',
description='User registration endpoint.', description='User registration endpoint.',
status_code=204,
) )
async def router_user_register( async def router_user_register(
session: SessionDep, session: SessionDep,
request: models.RegisterRequest, request: sqlmodels.RegisterRequest,
) -> models.ResponseBase: ) -> None:
""" """
用户注册端点 用户注册端点
@@ -95,7 +142,7 @@ async def router_user_register(
1. 验证用户名唯一性 1. 验证用户名唯一性
2. 获取默认用户组 2. 获取默认用户组
3. 创建用户记录 3. 创建用户记录
4. 创建用户名命名的根目录 4. 创建用户根目录name="/"
:param session: 数据库会话 :param session: 数据库会话
:param request: 注册请求 :param request: 注册请求
@@ -103,62 +150,53 @@ async def router_user_register(
:raises HTTPException 400: 用户名已存在 :raises HTTPException 400: 用户名已存在
:raises HTTPException 500: 默认用户组或存储策略不存在 :raises HTTPException 500: 默认用户组或存储策略不存在
""" """
# 1. 验证用户名唯一性 # 1. 验证邮箱唯一性
existing_user = await models.User.get( existing_user = await sqlmodels.User.get(
session, session,
models.User.username == request.username sqlmodels.User.email == request.email
) )
if existing_user: if existing_user:
raise HTTPException(status_code=400, detail="用户名已存在") raise HTTPException(status_code=400, detail="邮箱已存在")
# 2. 获取默认用户组(从设置中读取 UUID # 2. 获取默认用户组(从设置中读取 UUID
default_group_setting: models.Setting | None = await models.Setting.get( default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
session, session,
(models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group") (sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
) )
if default_group_setting is None or not default_group_setting.value: if default_group_setting is None or not default_group_setting.value:
logger.error("默认用户组不存在") logger.error("默认用户组不存在")
http_exceptions.raise_internal_error() http_exceptions.raise_internal_error()
default_group_id = UUID(default_group_setting.value) default_group_id = UUID(default_group_setting.value)
default_group = await models.Group.get(session, models.Group.id == default_group_id) default_group = await sqlmodels.Group.get(session, sqlmodels.Group.id == default_group_id)
if not default_group: if not default_group:
logger.error("默认用户组不存在") logger.error("默认用户组不存在")
http_exceptions.raise_internal_error() http_exceptions.raise_internal_error()
# 3. 创建用户 # 3. 创建用户
hashed_password = Password.hash(request.password) hashed_password = Password.hash(request.password)
new_user = models.User( new_user = sqlmodels.User(
username=request.username, email=request.email,
password=hashed_password, password=hashed_password,
group_id=default_group.id, group_id=default_group.id,
) )
new_user_id = new_user.id # 在 save 前保存 UUID new_user_id = new_user.id
new_user_username = new_user.username
await new_user.save(session) await new_user.save(session)
# 4. 创建用户名命名的根目录 # 4. 创建用户根目录
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储") default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
if not default_policy: if not default_policy:
logger.error("默认存储策略不存在") logger.error("默认存储策略不存在")
http_exceptions.raise_internal_error() http_exceptions.raise_internal_error()
await models.Object( await sqlmodels.Object(
name=new_user_username, name="/",
type=models.ObjectType.FOLDER, type=sqlmodels.ObjectType.FOLDER,
owner_id=new_user_id, owner_id=new_user_id,
parent_id=None, parent_id=None,
policy_id=default_policy.id, policy_id=default_policy.id,
).save(session) ).save(session)
return models.ResponseBase(
data={
"user_id": new_user_id,
"username": new_user_username,
},
msg="注册成功",
)
@user_router.post( @user_router.post(
path='/code', path='/code',
summary='发送验证码邮件', summary='发送验证码邮件',
@@ -166,7 +204,7 @@ async def router_user_register(
) )
def router_user_email_code( def router_user_email_code(
reason: Literal['register', 'reset'] = 'register', reason: Literal['register', 'reset'] = 'register',
) -> models.ResponseBase: ) -> sqlmodels.ResponseBase:
""" """
Send a verification code email. Send a verification code email.
@@ -180,7 +218,7 @@ def router_user_email_code(
summary='初始化QQ登录', summary='初始化QQ登录',
description='Initialize QQ login for a user.', description='Initialize QQ login for a user.',
) )
def router_user_qq() -> models.ResponseBase: def router_user_qq() -> sqlmodels.ResponseBase:
""" """
Initialize QQ login for a user. Initialize QQ login for a user.
@@ -194,7 +232,7 @@ def router_user_qq() -> models.ResponseBase:
summary='WebAuthn登录初始化', summary='WebAuthn登录初始化',
description='Initialize WebAuthn login for a user.', description='Initialize WebAuthn login for a user.',
) )
async def router_user_authn(username: str) -> models.ResponseBase: async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
http_exceptions.raise_not_implemented() http_exceptions.raise_not_implemented()
@@ -203,7 +241,7 @@ async def router_user_authn(username: str) -> models.ResponseBase:
summary='WebAuthn登录', summary='WebAuthn登录',
description='Finish WebAuthn login for a user.', description='Finish WebAuthn login for a user.',
) )
def router_user_authn_finish(username: str) -> models.ResponseBase: def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
""" """
Finish WebAuthn login for a user. Finish WebAuthn login for a user.
@@ -220,7 +258,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase:
summary='获取用户主页展示用分享', summary='获取用户主页展示用分享',
description='Get user profile for display.', description='Get user profile for display.',
) )
def router_user_profile(id: str) -> models.ResponseBase: def router_user_profile(id: str) -> sqlmodels.ResponseBase:
""" """
Get user profile for display. Get user profile for display.
@@ -237,7 +275,7 @@ def router_user_profile(id: str) -> models.ResponseBase:
summary='获取用户头像', summary='获取用户头像',
description='Get user avatar by ID and size.', description='Get user avatar by ID and size.',
) )
def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase: def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
""" """
Get user avatar by ID and size. Get user avatar by ID and size.
@@ -259,12 +297,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
summary='获取用户信息', summary='获取用户信息',
description='Get user information.', description='Get user information.',
dependencies=[Depends(dependency=auth_required)], dependencies=[Depends(dependency=auth_required)],
response_model=models.UserResponse, response_model=sqlmodels.UserResponse,
) )
async def router_user_me( async def router_user_me(
session: SessionDep, session: SessionDep,
user: Annotated[models.User, Depends(auth_required)], user: Annotated[sqlmodels.User, Depends(auth_required)],
) -> models.ResponseBase: ) -> sqlmodels.UserResponse:
""" """
获取用户信息. 获取用户信息.
@@ -272,10 +310,10 @@ async def router_user_me(
:rtype: ResponseBase :rtype: ResponseBase
""" """
# 加载 group 及其 options 关系 # 加载 group 及其 options 关系
group = await models.Group.get( group = await sqlmodels.Group.get(
session, session,
models.Group.id == user.group_id, sqlmodels.Group.id == user.group_id,
load=models.Group.options load=sqlmodels.Group.options
) )
# 构建 GroupResponse # 构建 GroupResponse
@@ -284,9 +322,9 @@ async def router_user_me(
# 异步加载 tags 关系 # 异步加载 tags 关系
user_tags = await user.awaitable_attrs.tags user_tags = await user.awaitable_attrs.tags
return models.UserResponse( return sqlmodels.UserResponse(
id=user.id, id=user.id,
username=user.username, email=user.email,
status=user.status, status=user.status,
score=user.score, score=user.score,
nickname=user.nickname, nickname=user.nickname,
@@ -304,30 +342,26 @@ async def router_user_me(
) )
async def router_user_storage( async def router_user_storage(
session: SessionDep, session: SessionDep,
user: Annotated[models.user.User, Depends(auth_required)], user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> models.ResponseBase: ) -> sqlmodels.UserStorageResponse:
""" """
获取用户存储空间信息。 获取用户存储空间信息。
返回值:
- used: 已使用空间(字节)
- free: 剩余空间(字节)
- total: 总容量(字节)= 用户组容量
""" """
# 获取用户组的基础存储容量 # 获取用户组的基础存储容量
group = await models.Group.get(session, models.Group.id == user.group_id) group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
if not group: if not group:
raise HTTPException(status_code=500, detail="用户组不存在") raise HTTPException(status_code=404, detail="用户组不存在")
# [TODO] 总空间加上用户购买的额外空间
total: int = group.max_storage total: int = group.max_storage
used: int = user.storage used: int = user.storage
free: int = max(0, total - used) free: int = max(0, total - used)
return models.ResponseBase( return sqlmodels.UserStorageResponse(
data={ used=used,
"used": used, free=free,
"free": free, total=total,
"total": total,
}
) )
@user_router.put( @user_router.put(
@@ -338,8 +372,8 @@ async def router_user_storage(
) )
async def router_user_authn_start( async def router_user_authn_start(
session: SessionDep, session: SessionDep,
user: Annotated[models.user.User, Depends(auth_required)], user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> models.ResponseBase: ) -> sqlmodels.ResponseBase:
""" """
Initialize WebAuthn login for a user. Initialize WebAuthn login for a user.
@@ -347,30 +381,30 @@ async def router_user_authn_start(
dict: A dictionary containing WebAuthn initialization information. dict: A dictionary containing WebAuthn initialization information.
""" """
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等 # TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
authn_setting = await models.Setting.get( authn_setting = await sqlmodels.Setting.get(
session, session,
(models.Setting.type == "authn") & (models.Setting.name == "authn_enabled") (sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
) )
if not authn_setting or authn_setting.value != "1": if not authn_setting or authn_setting.value != "1":
raise HTTPException(status_code=400, detail="WebAuthn is not enabled") raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
site_url_setting = await models.Setting.get( site_url_setting = await sqlmodels.Setting.get(
session, session,
(models.Setting.type == "basic") & (models.Setting.name == "siteURL") (sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteURL")
) )
site_title_setting = await models.Setting.get( site_title_setting = await sqlmodels.Setting.get(
session, session,
(models.Setting.type == "basic") & (models.Setting.name == "siteTitle") (sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteTitle")
) )
options = generate_registration_options( options = generate_registration_options(
rp_id=site_url_setting.value if site_url_setting else "", rp_id=site_url_setting.value if site_url_setting else "",
rp_name=site_title_setting.value if site_title_setting else "", rp_name=site_title_setting.value if site_title_setting else "",
user_name=user.username, user_name=user.email,
user_display_name=user.nick or user.username, user_display_name=user.nickname or user.email,
) )
return models.ResponseBase(data=options_to_json_dict(options)) return sqlmodels.ResponseBase(data=options_to_json_dict(options))
@user_router.put( @user_router.put(
path='/authn/finish', path='/authn/finish',
@@ -378,7 +412,7 @@ async def router_user_authn_start(
description='Finish WebAuthn login for a user.', description='Finish WebAuthn login for a user.',
dependencies=[Depends(auth_required)], dependencies=[Depends(auth_required)],
) )
def router_user_authn_finish() -> models.ResponseBase: def router_user_authn_finish() -> sqlmodels.ResponseBase:
""" """
Finish WebAuthn login for a user. Finish WebAuthn login for a user.
@@ -386,171 +420,3 @@ def router_user_authn_finish() -> models.ResponseBase:
dict: A dictionary containing WebAuthn login information. dict: A dictionary containing WebAuthn login information.
""" """
http_exceptions.raise_not_implemented() http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/policies',
summary='获取用户可选存储策略',
description='Get user selectable storage policies.',
)
def router_user_settings_policies() -> models.ResponseBase:
"""
Get user selectable storage policies.
Returns:
dict: A dictionary containing available storage policies for the user.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/nodes',
summary='获取用户可选节点',
description='Get user selectable nodes.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_nodes() -> models.ResponseBase:
"""
Get user selectable nodes.
Returns:
dict: A dictionary containing available nodes for the user.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/tasks',
summary='任务队列',
description='Get user task queue.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_tasks() -> models.ResponseBase:
"""
Get user task queue.
Returns:
dict: A dictionary containing the user's task queue information.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/',
summary='获取当前用户设定',
description='Get current user settings.',
dependencies=[Depends(auth_required)],
)
def router_user_settings() -> models.ResponseBase:
"""
Get current user settings.
Returns:
dict: A dictionary containing the current user settings.
"""
return models.ResponseBase(data=models.UserSettingResponse().model_dump())
@user_settings_router.post(
path='/avatar',
summary='从文件上传头像',
description='Upload user avatar from file.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar() -> models.ResponseBase:
"""
Upload user avatar from file.
Returns:
dict: A dictionary containing the result of the avatar upload.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.put(
path='/avatar',
summary='设定为Gravatar头像',
description='Set user avatar to Gravatar.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar_gravatar() -> models.ResponseBase:
"""
Set user avatar to Gravatar.
Returns:
dict: A dictionary containing the result of setting the Gravatar avatar.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.patch(
path='/{option}',
summary='更新用户设定',
description='Update user settings.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_patch(option: str) -> models.ResponseBase:
"""
Update user settings.
Args:
option (str): The setting option to update.
Returns:
dict: A dictionary containing the result of the settings update.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/2fa',
summary='获取两步验证初始化信息',
description='Get two-factor authentication initialization information.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa(
user: Annotated[models.user.User, Depends(auth_required)],
) -> models.ResponseBase:
"""
Get two-factor authentication initialization information.
Returns:
dict: A dictionary containing two-factor authentication setup information.
"""
return models.ResponseBase(
data=await Password.generate_totp(user.username)
)
@user_settings_router.post(
path='/2fa',
summary='启用两步验证',
description='Enable two-factor authentication.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa_enable(
session: SessionDep,
user: Annotated[models.user.User, Depends(auth_required)],
setup_token: str,
code: str,
) -> models.ResponseBase:
"""
Enable two-factor authentication for the user.
Returns:
dict: A dictionary containing the result of enabling two-factor authentication.
"""
serializer = URLSafeTimedSerializer(SECRET_KEY)
try:
# 1. 解包 Token设置有效期例如 600秒
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
except SignatureExpired:
raise HTTPException(status_code=400, detail="Setup session expired")
except BadSignature:
raise HTTPException(status_code=400, detail="Invalid token")
# 2. 验证用户输入的 6 位验证码
if not Password.verify_totp(secret, code):
raise HTTPException(status_code=400, detail="Invalid OTP code")
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
user.two_factor = secret
user = await user.save(session)
return models.ResponseBase(
data={"message": "Two-factor authentication enabled successfully"}
)

View File

@@ -0,0 +1,203 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import sqlmodels
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from utils import JWT, Password, http_exceptions
user_settings_router = APIRouter(
prefix='/settings',
tags=["user", "user_settings"],
dependencies=[Depends(auth_required)],
)
@user_settings_router.get(
path='/policies',
summary='获取用户可选存储策略',
description='Get user selectable storage policies.',
)
def router_user_settings_policies() -> sqlmodels.ResponseBase:
"""
Get user selectable storage policies.
Returns:
dict: A dictionary containing available storage policies for the user.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/nodes',
summary='获取用户可选节点',
description='Get user selectable nodes.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_nodes() -> sqlmodels.ResponseBase:
"""
Get user selectable nodes.
Returns:
dict: A dictionary containing available nodes for the user.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/tasks',
summary='任务队列',
description='Get user task queue.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_tasks() -> sqlmodels.ResponseBase:
"""
Get user task queue.
Returns:
dict: A dictionary containing the user's task queue information.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/',
summary='获取当前用户设定',
description='Get current user settings.',
)
def router_user_settings(
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.UserSettingResponse:
"""
Get current user settings.
Returns:
dict: A dictionary containing the current user settings.
"""
return sqlmodels.UserSettingResponse(
id=user.id,
email=user.email,
nickname=user.nickname,
created_at=user.created_at,
group_name=user.group.name,
language=user.language,
timezone=user.timezone,
group_expires=user.group_expires,
two_factor=user.two_factor is not None,
)
@user_settings_router.post(
path='/avatar',
summary='从文件上传头像',
description='Upload user avatar from file.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
"""
Upload user avatar from file.
Returns:
dict: A dictionary containing the result of the avatar upload.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.put(
path='/avatar',
summary='设定为Gravatar头像',
description='Set user avatar to Gravatar.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar_gravatar() -> sqlmodels.ResponseBase:
"""
Set user avatar to Gravatar.
Returns:
dict: A dictionary containing the result of setting the Gravatar avatar.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.patch(
path='/{option}',
summary='更新用户设定',
description='Update user settings.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase:
"""
Update user settings.
Args:
option (str): The setting option to update.
Returns:
dict: A dictionary containing the result of the settings update.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/2fa',
summary='获取两步验证初始化信息',
description='Get two-factor authentication initialization information.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa(
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.ResponseBase:
"""
Get two-factor authentication initialization information.
Returns:
dict: A dictionary containing two-factor authentication setup information.
"""
return sqlmodels.ResponseBase(
data=await Password.generate_totp(user.email)
)
@user_settings_router.post(
path='/2fa',
summary='启用两步验证',
description='Enable two-factor authentication.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa_enable(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
setup_token: str,
code: str,
) -> sqlmodels.ResponseBase:
"""
Enable two-factor authentication for the user.
Returns:
dict: A dictionary containing the result of enabling two-factor authentication.
"""
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
try:
# 1. 解包 Token设置有效期例如 600秒
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
except SignatureExpired:
raise HTTPException(status_code=400, detail="Setup session expired")
except BadSignature:
raise HTTPException(status_code=400, detail="Invalid token")
# 2. 验证用户输入的 6 位验证码
if not Password.verify_totp(secret, code):
raise HTTPException(status_code=400, detail="Invalid OTP code")
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
user.two_factor = secret
user = await user.save(session)
return sqlmodels.ResponseBase(
data={"message": "Two-factor authentication enabled successfully"}
)

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import auth_required from middleware.auth import auth_required
from models import ResponseBase from sqlmodels import ResponseBase
from utils import http_exceptions from utils import http_exceptions
vas_router = APIRouter( vas_router = APIRouter(

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import auth_required from middleware.auth import auth_required
from models import ResponseBase from sqlmodels import ResponseBase
from utils import http_exceptions from utils import http_exceptions
# WebDAV 管理路由 # WebDAV 管理路由

View File

@@ -15,7 +15,7 @@ import aiofiles
import aiofiles.os import aiofiles.os
from loguru import logger as l from loguru import logger as l
from models.policy import Policy from sqlmodels.policy import Policy
from .exceptions import ( from .exceptions import (
DirectoryCreationError, DirectoryCreationError,
FileReadError, FileReadError,

View File

@@ -23,7 +23,7 @@ import string
from datetime import datetime from datetime import datetime
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from models.base import SQLModelBase from sqlmodels.base import SQLModelBase
class NamingContext(SQLModelBase): class NamingContext(SQLModelBase):

View File

@@ -3,7 +3,7 @@ from uuid import uuid4
from loguru import logger from loguru import logger
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import LoginRequest, TokenResponse, User from sqlmodels import LoginRequest, TokenResponse, User
from utils import http_exceptions from utils import http_exceptions
from utils.JWT import create_access_token, create_refresh_token from utils.JWT import create_access_token, create_refresh_token
from utils.password.pwd import Password, PasswordStatus from utils.password.pwd import Password, PasswordStatus
@@ -30,17 +30,17 @@ async def login(
# is_captcha_required = captcha_setting and captcha_setting.value == "1" # is_captcha_required = captcha_setting and captcha_setting.value == "1"
# 获取用户信息 # 获取用户信息
current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore current_user: User = await User.get(session, User.email == login_request.email, fetch_mode="first") #type: ignore
# 验证用户是否存在 # 验证用户是否存在
if not current_user: if not current_user:
logger.debug(f"Cannot find user with username: {login_request.username}") logger.debug(f"Cannot find user with email: {login_request.email}")
http_exceptions.raise_unauthorized("Invalid username or password") http_exceptions.raise_unauthorized("Invalid email or password")
# 验证密码是否正确 # 验证密码是否正确
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID: if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
logger.debug(f"Password verification failed for user: {login_request.username}") logger.debug(f"Password verification failed for user: {login_request.email}")
http_exceptions.raise_unauthorized("Invalid username or password") http_exceptions.raise_unauthorized("Invalid email or password")
# 验证用户是否可登录 # 验证用户是否可登录
if not current_user.status: if not current_user.status:
@@ -50,23 +50,23 @@ async def login(
if current_user.two_factor: if current_user.two_factor:
# 用户已启用两步验证 # 用户已启用两步验证
if not login_request.two_fa_code: if not login_request.two_fa_code:
logger.debug(f"2FA required for user: {login_request.username}") logger.debug(f"2FA required for user: {login_request.email}")
http_exceptions.raise_precondition_required("2FA required") http_exceptions.raise_precondition_required("2FA required")
# 验证 OTP 码 # 验证 OTP 码
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID: if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
logger.debug(f"Invalid 2FA code for user: {login_request.username}") logger.debug(f"Invalid 2FA code for user: {login_request.email}")
http_exceptions.raise_unauthorized("Invalid 2FA code") http_exceptions.raise_unauthorized("Invalid 2FA code")
# 创建令牌 # 创建令牌
access_token = create_access_token(data={ access_token = create_access_token(
'sub': str(current_user.id), sub=current_user.id,
'jti': str(uuid4()) jti=uuid4()
}) )
refresh_token = create_refresh_token(data={ refresh_token = create_refresh_token(
'sub': str(current_user.id), sub=current_user.id,
'jti': str(uuid4()) jti=uuid4()
}) )
return TokenResponse( return TokenResponse(
access_token=access_token.access_token, access_token=access_token.access_token,

View File

@@ -1,11 +1,14 @@
from .user import ( from .user import (
BatchDeleteRequest,
LoginRequest, LoginRequest,
RefreshTokenRequest,
RegisterRequest, RegisterRequest,
AccessTokenBase, AccessTokenBase,
RefreshTokenBase, RefreshTokenBase,
TokenResponse, TokenResponse,
User, User,
UserBase, UserBase,
UserStorageResponse,
UserPublic, UserPublic,
UserResponse, UserResponse,
UserSettingResponse, UserSettingResponse,
@@ -66,6 +69,7 @@ from .object import (
FileBanRequest, FileBanRequest,
) )
from .physical_file import PhysicalFile, PhysicalFileBase from .physical_file import PhysicalFile, PhysicalFileBase
from .uri import DiskNextURI, FileSystemNamespace
from .order import Order, OrderStatus, OrderType from .order import Order, OrderStatus, OrderType
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
from .redeem import Redeem, RedeemType from .redeem import Redeem, RedeemType
@@ -82,7 +86,7 @@ from .tag import Tag, TagType
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
from .webdav import WebDAV from .webdav import WebDAV
from .database import engine, get_session from .database_connection import DatabaseManager
from .model_base import ( from .model_base import (
MCPBase, MCPBase,

View File

@@ -630,7 +630,7 @@ For developers modifying this module:
- Handles Python 3.14 annotations via `get_type_hints()` - Handles Python 3.14 annotations via `get_type_hints()`
**Metaclass processing order**: **Metaclass processing order**:
1. Check if class should be a table (`_is_table_mixin`) 1. Check if class should be a table (`_has_table_mixin`)
2. Collect `__mapper_args__` from kwargs and explicit dict 2. Collect `__mapper_args__` from kwargs and explicit dict
3. Process `table_args`, `table_name`, `abstract` parameters 3. Process `table_args`, `table_name`, `abstract` parameters
4. Resolve annotations using `get_type_hints()` 4. Resolve annotations using `get_type_hints()`

View File

@@ -5,8 +5,8 @@ SQLModel 基础模块
- SQLModelBase: 所有 SQLModel 类的基类真正的基类 - SQLModelBase: 所有 SQLModel 类的基类真正的基类
注意 注意
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 models.mixin TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
为了避免循环导入此处不再重新导出它们 为了避免循环导入此处不再重新导出它们
请直接从 models.mixin 导入这些类 请直接从 sqlmodels.mixin 导入这些类
""" """
from .sqlmodel_base import SQLModelBase from .sqlmodel_base import SQLModelBase

View File

@@ -414,7 +414,7 @@ class __DeclarativeMeta(SQLModelMetaclass):
def __new__(cls, name, bases, attrs, **kwargs): def __new__(cls, name, bases, attrs, **kwargs):
# 1. 约定优于配置:自动设置 table=True # 1. 约定优于配置:自动设置 table=True
is_intended_as_table = any(getattr(b, '_is_table_mixin', False) for b in bases) is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
if is_intended_as_table and 'table' not in kwargs: if is_intended_as_table and 'table' not in kwargs:
kwargs['table'] = True kwargs['table'] = True

View File

@@ -0,0 +1,78 @@
from typing import AsyncGenerator, ClassVar
from loguru import logger
from sqlalchemy import NullPool, AsyncAdaptedQueuePool
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
class DatabaseManager:
engine: ClassVar[AsyncEngine | None] = None
_async_session_factory: ClassVar[sessionmaker | None] = None
@classmethod
async def get_session(cls) -> AsyncGenerator[AsyncSession]:
assert cls._async_session_factory is not None, "数据库引擎未初始化,请先调用 DatabaseManager.init()"
async with cls._async_session_factory() as session:
yield session
@classmethod
async def init(
cls,
database_url: str,
debug: bool = False,
):
"""
初始化数据库连接引擎。
:param database_url: 数据库连接URL
:param debug: 是否开启调试模式
"""
# 构建引擎参数
engine_kwargs: dict = {
'echo': debug,
'future': True,
}
if debug:
# Debug 模式使用 NullPool无连接池每次创建新连接
engine_kwargs['poolclass'] = NullPool
else:
# 生产模式使用 AsyncAdaptedQueuePool 连接池
engine_kwargs.update({
'poolclass': AsyncAdaptedQueuePool,
'pool_size': 40,
'max_overflow': 80,
'pool_timeout': 30,
'pool_recycle': 1800,
'pool_pre_ping': True,
})
# 只在需要时添加 connect_args
if database_url.startswith("sqlite"):
engine_kwargs['connect_args'] = {'check_same_thread': False}
cls.engine = create_async_engine(database_url, **engine_kwargs)
cls._async_session_factory = sessionmaker(cls.engine, class_=AsyncSession)
# 开发阶段直接 create_all 创建表结构
async with cls.engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
logger.info("数据库引擎初始化完成")
@classmethod
async def close(cls):
"""
优雅地关闭数据库连接引擎。
仅应在应用结束时调用。
"""
if cls.engine:
logger.info("正在关闭数据库连接引擎...")
await cls.engine.dispose()
logger.info("数据库连接引擎已成功关闭。")
else:
logger.info("数据库连接引擎未初始化,无需关闭。")

View File

@@ -104,9 +104,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
Setting(name="captcha_IsShowSlimeLine", value="1", type=SettingsType.CAPTCHA), Setting(name="captcha_IsShowSlimeLine", value="1", type=SettingsType.CAPTCHA),
Setting(name="captcha_IsShowSineLine", value="0", type=SettingsType.CAPTCHA), Setting(name="captcha_IsShowSineLine", value="0", type=SettingsType.CAPTCHA),
Setting(name="captcha_CaptchaLen", value="6", type=SettingsType.CAPTCHA), Setting(name="captcha_CaptchaLen", value="6", type=SettingsType.CAPTCHA),
Setting(name="captcha_IsUseReCaptcha", value="0", type=SettingsType.CAPTCHA), Setting(name="captcha_type", value="default", type=SettingsType.CAPTCHA),
Setting(name="captcha_ReCaptchaKey", value="defaultKey", type=SettingsType.CAPTCHA), Setting(name="captcha_ReCaptchaKey", value="", type=SettingsType.CAPTCHA),
Setting(name="captcha_ReCaptchaSecret", value="defaultSecret", type=SettingsType.CAPTCHA), Setting(name="captcha_ReCaptchaSecret", value="", type=SettingsType.CAPTCHA),
Setting(name="captcha_CloudflareKey", value="", type=SettingsType.CAPTCHA),
Setting(name="captcha_CloudflareSecret", value="", type=SettingsType.CAPTCHA),
Setting(name="thumb_width", value="400", type=SettingsType.THUMB), Setting(name="thumb_width", value="400", type=SettingsType.THUMB),
Setting(name="thumb_height", value="300", type=SettingsType.THUMB), Setting(name="thumb_height", value="300", type=SettingsType.THUMB),
Setting(name="pwa_small_icon", value="/static/img/favicon.ico", type=SettingsType.PWA), Setting(name="pwa_small_icon", value="/static/img/favicon.ico", type=SettingsType.PWA),
@@ -119,11 +121,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
async def init_default_settings() -> None: async def init_default_settings() -> None:
from .setting import Setting from .setting import Setting
from .database import get_session from .database_connection import DatabaseManager
log.info('初始化设置...') log.info('初始化设置...')
async for session in get_session(): async for session in DatabaseManager.get_session():
# 检查是否已经存在版本设置 # 检查是否已经存在版本设置
ver = await Setting.get( ver = await Setting.get(
session, session,
@@ -139,11 +141,11 @@ async def init_default_group() -> None:
from .group import Group, GroupOptions from .group import Group, GroupOptions
from .policy import Policy, GroupPolicyLink from .policy import Policy, GroupPolicyLink
from .setting import Setting from .setting import Setting
from .database import get_session from .database_connection import DatabaseManager
log.info('初始化用户组...') log.info('初始化用户组...')
async for session in get_session(): async for session in DatabaseManager.get_session():
# 获取默认存储策略 # 获取默认存储策略
default_policy = await Policy.get(session, Policy.name == "本地存储") default_policy = await Policy.get(session, Policy.name == "本地存储")
default_policy_id = default_policy.id if default_policy else None default_policy_id = default_policy.id if default_policy else None
@@ -231,13 +233,20 @@ async def init_default_user() -> None:
from .group import Group from .group import Group
from .object import Object, ObjectType from .object import Object, ObjectType
from .policy import Policy from .policy import Policy
from .database import get_session from .database_connection import DatabaseManager
log.info('初始化管理员用户...') log.info('初始化管理员用户...')
async for session in get_session(): async for session in DatabaseManager.get_session():
# 检查管理员用户是否存在 # 检查管理员用户是否存在(通过 Setting 中的 default_admin_id 判断)
admin_user = await User.get(session, User.username == "admin") admin_id_setting = await Setting.get(
session,
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
)
admin_user = None
if admin_id_setting and admin_id_setting.value:
from uuid import UUID
admin_user = await User.get(session, User.id == UUID(admin_id_setting.value))
if not admin_user: if not admin_user:
# 获取管理员组 # 获取管理员组
@@ -256,18 +265,24 @@ async def init_default_user() -> None:
hashed_admin_password = Password.hash(admin_password) hashed_admin_password = Password.hash(admin_password)
admin_user = User( admin_user = User(
username="admin", email="admin@disknext.local",
nickname="admin", nickname="admin",
group_id=admin_group.id, group_id=admin_group.id,
password=hashed_admin_password, password=hashed_admin_password,
) )
admin_user_id = admin_user.id # 在 save 前保存 UUID admin_user_id = admin_user.id # 在 save 前保存 UUID
admin_username = admin_user.username
await admin_user.save(session) await admin_user.save(session)
# 为管理员创建根目录(使用用户名作为目录名) # 记录默认管理员 ID 到 Setting
await Setting(
name="default_admin_id",
value=str(admin_user_id),
type=SettingsType.AUTH,
).save(session)
# 为管理员创建根目录
await Object( await Object(
name=admin_username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
owner_id=admin_user_id, owner_id=admin_user_id,
parent_id=None, parent_id=None,
@@ -275,18 +290,18 @@ async def init_default_user() -> None:
).save(session) ).save(session)
log.warning('请注意,账号密码仅显示一次,请妥善保管') log.warning('请注意,账号密码仅显示一次,请妥善保管')
log.info(f'初始管理员账号: admin') log.info(f'初始管理员邮箱: admin@disknext.local')
log.info(f'初始管理员密码: {admin_password}') log.info(f'初始管理员密码: {admin_password}')
async def init_default_policy() -> None: async def init_default_policy() -> None:
from .policy import Policy, PolicyType from .policy import Policy, PolicyType
from .database import get_session from .database_connection import DatabaseManager
from service.storage import LocalStorageService from service.storage import LocalStorageService
log.info('初始化默认存储策略...') log.info('初始化默认存储策略...')
async for session in get_session(): async for session in DatabaseManager.get_session():
# 检查默认存储策略是否存在 # 检查默认存储策略是否存在
default_policy = await Policy.get(session, Policy.name == "本地存储") default_policy = await Policy.get(session, Policy.name == "本地存储")

View File

@@ -5,42 +5,58 @@ SQLModel Mixin模块
包含 包含
- polymorphic: 联表继承工具create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin - polymorphic: 联表继承工具create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin
- optimistic_lock: 乐观锁OptimisticLockMixin, OptimisticLockError
- table: 表基类TableBaseMixin, UUIDTableBaseMixin - table: 表基类TableBaseMixin, UUIDTableBaseMixin
- table: 查询参数类TimeFilterRequest, PaginationRequest, TableViewRequest - table: 查询参数类TimeFilterRequest, PaginationRequest, TableViewRequest
- relation_preload: 关系预加载RelationPreloadMixin, requires_relations
- jwt/: JWT认证相关JWTAuthMixin, JWTManager, JWTKey等- 需要时直接从 .jwt 导入 - jwt/: JWT认证相关JWTAuthMixin, JWTManager, JWTKey等- 需要时直接从 .jwt 导入
- info_response: InfoResponse DTO的id/时间戳Mixin - info_response: InfoResponse DTO的id/时间戳Mixin
导入顺序很重要避免循环导入 导入顺序很重要避免循环导入
1. polymorphic只依赖 SQLModelBase 1. polymorphic只依赖 SQLModelBase
2. table依赖 polymorphic 2. optimistic_lock依赖 SQLAlchemy
3. table依赖 polymorphic optimistic_lock
4. relation_preload只依赖 SQLModelBase
注意jwt 模块不在此处导入因为 jwt/manager.py 导入 ServerConfig 注意jwt 模块不在此处导入因为 jwt/manager.py 导入 ServerConfig
ServerConfig 导入本模块会形成循环需要 jwt 功能时请直接从 .jwt 导入 ServerConfig 导入本模块会形成循环需要 jwt 功能时请直接从 .jwt 导入
""" """
# polymorphic 必须先导入 # polymorphic 必须先导入
from .polymorphic import ( from .polymorphic import (
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin, AutoPolymorphicIdentityMixin,
PolymorphicBaseMixin, PolymorphicBaseMixin,
create_subclass_id_mixin,
register_sti_column_properties_for_all_subclasses,
register_sti_columns_for_all_subclasses,
) )
# table 依赖 polymorphic # optimistic_lock 只依赖 SQLAlchemy必须在 table 之前
from .optimistic_lock import (
OptimisticLockError,
OptimisticLockMixin,
)
# table 依赖 polymorphic 和 optimistic_lock
from .table import ( from .table import (
TableBaseMixin,
UUIDTableBaseMixin,
TimeFilterRequest,
PaginationRequest,
TableViewRequest,
ListResponse, ListResponse,
PaginationRequest,
T, T,
TableBaseMixin,
TableViewRequest,
TimeFilterRequest,
UUIDTableBaseMixin,
now, now,
now_date, now_date,
) )
# relation_preload 只依赖 SQLModelBase
from .relation_preload import (
RelationPreloadMixin,
requires_relations,
)
# jwt 不在此处导入避免循环jwt/manager.py → ServerConfig → mixin → jwt # jwt 不在此处导入避免循环jwt/manager.py → ServerConfig → mixin → jwt
# 需要时直接从 sqlmodels.mixin.jwt 导入 # 需要时直接从 sqlmodels.mixin.jwt 导入
from .info_response import ( from .info_response import (
IntIdInfoMixin,
UUIDIdInfoMixin,
DatetimeInfoMixin, DatetimeInfoMixin,
IntIdDatetimeInfoMixin, IntIdDatetimeInfoMixin,
IntIdInfoMixin,
UUIDIdDatetimeInfoMixin, UUIDIdDatetimeInfoMixin,
UUIDIdInfoMixin,
) )

View File

@@ -12,7 +12,7 @@ InfoResponse DTO Mixin模块
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from models.base import SQLModelBase from sqlmodels.base import SQLModelBase
class IntIdInfoMixin(SQLModelBase): class IntIdInfoMixin(SQLModelBase):

View File

@@ -0,0 +1,90 @@
"""
乐观锁 Mixin
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
乐观锁适用场景:
- 涉及"状态转换"的表(如:待支付 -> 已支付)
- 涉及"数值变动"的表(如:余额、库存)
不适用场景:
- 日志表、纯插入表、低价值统计表
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
使用示例:
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
status: OrderStatusEnum
amount: Decimal
# save/update 时自动检查版本号
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
try:
order = await order.save(session)
except OptimisticLockError as e:
# 处理冲突:重新查询并重试,或报错给用户
l.warning(f"乐观锁冲突: {e}")
"""
from typing import ClassVar
from sqlalchemy.orm.exc import StaleDataError
class OptimisticLockError(Exception):
"""
乐观锁冲突异常
当 save/update 操作检测到版本号不匹配时抛出。
这意味着在读取和写入之间,其他事务已经修改了该记录。
Attributes:
model_class: 发生冲突的模型类名
record_id: 记录 ID如果可用
expected_version: 期望的版本号(如果可用)
original_error: 原始的 StaleDataError
"""
def __init__(
self,
message: str,
model_class: str | None = None,
record_id: str | None = None,
expected_version: int | None = None,
original_error: StaleDataError | None = None,
):
super().__init__(message)
self.model_class = model_class
self.record_id = record_id
self.expected_version = expected_version
self.original_error = original_error
class OptimisticLockMixin:
"""
乐观锁 Mixin
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
session.commit() 会抛出 StaleDataError被 save/update 方法捕获并转换为 OptimisticLockError。
原理:
1. 每条记录有一个 version 字段,初始值为 0
2. 每次 UPDATE 时SQLAlchemy 生成的 SQL 类似:
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
3. 如果 WHERE 条件不匹配version 已被其他事务修改),
UPDATE 影响 0 行SQLAlchemy 抛出 StaleDataError
继承顺序:
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
...
配套重试:
如果加了乐观锁,业务层需要处理 OptimisticLockError
- 报错给用户:"数据已被修改,请刷新后重试"
- 自动重试:重新查询最新数据并再次尝试
"""
_has_optimistic_lock: ClassVar[bool] = True
"""标记此类启用了乐观锁"""
version: int = 0
"""乐观锁版本号,每次更新自动递增"""

View File

@@ -0,0 +1,710 @@
"""
联表继承Joined Table Inheritance的通用工具
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
Usage Example:
from sqlmodels.base import SQLModelBase
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import (
PolymorphicBaseMixin,
create_subclass_id_mixin,
AutoPolymorphicIdentityMixin
)
# 1. 定义Base类只有字段无表
class ASRBase(SQLModelBase):
name: str
\"\"\"配置名称\"\"\"
base_url: str
\"\"\"服务地址\"\"\"
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
class ASR(
ASRBase,
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC
):
\"\"\"ASR配置的抽象基类\"\"\"
# PolymorphicBaseMixin 自动提供:
# - _polymorphic_name 字段
# - polymorphic_on='_polymorphic_name'
# - polymorphic_abstract=True当有抽象方法时
# 3. 为第二层子类创建ID Mixin
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
# 4. 创建第二层抽象类(如果需要)
class FunASR(
ASRSubclassIdMixin,
ASR,
AutoPolymorphicIdentityMixin,
polymorphic_abstract=True
):
\"\"\"FunASR的抽象基类可能有多个实现\"\"\"
pass
# 5. 创建具体实现类
class FunASRLocal(FunASR, table=True):
\"\"\"FunASR本地部署版本\"\"\"
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
pass
# 6. 获取所有具体子类(用于 selectin_polymorphic
concrete_asrs = ASR.get_concrete_subclasses()
# 返回 [FunASRLocal, ...]
"""
import uuid
from abc import ABC
from uuid import UUID
from loguru import logger as l
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from sqlalchemy import Column, String, inspect
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlmodel import Field
from sqlmodel.main import get_column_from_field
from sqlmodels.base.sqlmodel_base import SQLModelBase
# 用于延迟注册 STI 子类列的队列
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
_sti_subclasses_to_register: list[type] = []
def register_sti_columns_for_all_subclasses() -> None:
"""
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
此函数应在 configure_mappers() 之前调用。
将 STI 子类的字段添加到父表的 metadata 中。
同时修复被 Column 对象污染的 model_fields。
"""
for cls in _sti_subclasses_to_register:
try:
cls._register_sti_columns()
except Exception as e:
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
# 修复被 Column 对象污染的 model_fields
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
try:
_fix_polluted_model_fields(cls)
except Exception as e:
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
def register_sti_column_properties_for_all_subclasses() -> None:
"""
为所有已注册的 STI 子类添加列属性到 mapper第二阶段
此函数应在 configure_mappers() 之后调用。
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
"""
for cls in _sti_subclasses_to_register:
try:
cls._register_sti_column_properties()
except Exception as e:
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
# 清空队列
_sti_subclasses_to_register.clear()
def _fix_polluted_model_fields(cls: type) -> None:
"""
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
当 SQLModel 类继承有表的父类时SQLAlchemy 会在类上创建 InstrumentedAttribute
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
时错误地使用这些 SQLAlchemy 对象作为默认值。
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
:param cls: 要修复的类
"""
if not hasattr(cls, 'model_fields'):
return
def find_original_field_info(field_name: str) -> FieldInfo | None:
"""从 MRO 中查找字段的原始定义(未被污染的)"""
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, 'model_fields') and field_name in base.model_fields:
field_info = base.model_fields[field_name]
# 跳过被 InstrumentedAttribute 或 Column 污染的
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
return field_info
return None
for field_name, current_field in cls.model_fields.items():
# 检查是否被污染default 是 InstrumentedAttribute 或 Column
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
# 被继承到有表的子类时SQLModel 会创建一个 Column 对象替换原始默认值
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
continue # 未被污染,跳过
# 从父类查找原始定义
original = find_original_field_info(field_name)
if original is None:
continue # 找不到原始定义,跳过
# 根据原始定义的 default/default_factory 来修复
if original.default_factory:
# 有 default_factory如 uuid.uuid4, now
new_field = FieldInfo(
default_factory=original.default_factory,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
elif original.default is not PydanticUndefined:
# 有明确的 default 值(如 None, 0, True且不是 PydanticUndefined
# PydanticUndefined 表示字段没有默认值(必填)
new_field = FieldInfo(
default=original.default,
annotation=current_field.annotation,
json_schema_extra=current_field.json_schema_extra,
)
else:
continue # 既没有 default_factory 也没有有效的 default跳过
# 复制 SQLModel 特有的属性
if hasattr(current_field, 'foreign_key'):
new_field.foreign_key = current_field.foreign_key
if hasattr(current_field, 'primary_key'):
new_field.primary_key = current_field.primary_key
cls.model_fields[field_name] = new_field
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
"""
动态创建SubclassIdMixin类
在联表继承中,子类需要一个外键指向父表的主键。
此函数生成一个Mixin类提供这个外键字段并自动生成UUID。
Args:
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function'
Returns:
一个Mixin类包含id字段外键 + 主键 + default_factory=uuid.uuid4
Example:
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
... pass
Note:
- 生成的Mixin应该放在继承列表的第一位确保通过MRO覆盖UUIDTableBaseMixin的id
- 生成的类名为 {ParentTableName}SubclassIdMixinPascalCase
- 本项目所有联表继承均使用UUID主键UUIDTableBaseMixin
"""
if not parent_table_name:
raise ValueError("parent_table_name 不能为空")
# 转换为PascalCase作为类名
class_name_parts = parent_table_name.split('_')
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
# 使用闭包捕获parent_table_name
_parent_table_name = parent_table_name
# 创建带有__init_subclass__的mixin类用于在子类定义后修复model_fields
class SubclassIdMixin(SQLModelBase):
# 定义id字段
id: UUID = Field(
default_factory=uuid.uuid4,
foreign_key=f'{_parent_table_name}.id',
primary_key=True,
)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复联表继承中子类字段的 default_factory 丢失问题。
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
"""
super().__pydantic_init_subclass__(**kwargs)
_fix_polluted_model_fields(cls)
# 设置类名和文档
SubclassIdMixin.__name__ = class_name
SubclassIdMixin.__qualname__ = class_name
SubclassIdMixin.__doc__ = f"""
{parent_table_name}子类的ID Mixin
用于{parent_table_name}的子类,提供外键指向父表。
通过MRO确保此id字段覆盖继承的id字段。
"""
return SubclassIdMixin
class AutoPolymorphicIdentityMixin:
"""
自动生成polymorphic_identity的Mixin并支持STI子类列注册
使用此Mixin的类会自动根据类名生成polymorphic_identity。
格式:{parent_polymorphic_identity}.{classname_lowercase}
如果没有父类的polymorphic_identity则直接使用类名小写。
**重要:数据库迁移注意事项**
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
例如,对于以下继承链::
LLM (polymorphic_on='_polymorphic_name')
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
迁移脚本中设置 _polymorphic_name 时::
# ❌ 错误:缺少父类前缀
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
# ✅ 正确:包含完整的继承链前缀
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
**STI单表继承支持**
当子类与父类共用同一张表STI模式此Mixin会自动将子类的新字段
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
注册到父表的问题。
Example (JTI):
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
... __polymorphic_name: str
...
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
... pass
... # polymorphic_identity 会自动设置为 'function'
...
>>> class CodeInterpreterFunction(Function, table=True):
... pass
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
Example (STI):
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
... user_id: UUID
...
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
... # polymorphic_identity 会自动设置为 'pendingfile'
Note:
- 如果手动在__mapper_args__中指定了polymorphic_identity会被保留
- 此Mixin应该在继承列表中靠后的位置在表基类之前
- STI模式下新字段会在类定义时自动添加到父表的metadata中
"""
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
"""
子类化钩子自动生成polymorphic_identity并处理STI列注册
Args:
polymorphic_identity: 如果手动指定,则使用指定的值
**kwargs: 其他SQLModel参数如table=True, polymorphic_abstract=True
"""
super().__init_subclass__(**kwargs)
# 如果手动指定了polymorphic_identity使用指定的值
if polymorphic_identity is not None:
identity = polymorphic_identity
else:
# 自动生成polymorphic_identity
class_name = cls.__name__.lower()
# 尝试从父类获取polymorphic_identity作为前缀
parent_identity = None
for base in cls.__mro__[1:]: # 跳过自己
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
parent_identity = base.__mapper_args__.get('polymorphic_identity')
if parent_identity:
break
# 构建identity
if parent_identity:
identity = f'{parent_identity}.{class_name}'
else:
identity = class_name
# 设置到__mapper_args__
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 只在尚未设置polymorphic_identity时设置
if 'polymorphic_identity' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_identity'] = identity
# 注册 STI 子类列的延迟执行
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
_sti_subclasses_to_register.append(cls)
@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
修复 STI 继承中子类字段被 Column 对象污染的问题。
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
"""
super().__pydantic_init_subclass__(**kwargs)
_fix_polluted_model_fields(cls)
@classmethod
def _register_sti_columns(cls) -> None:
"""
将STI子类的新字段注册到父表的列定义中
检测当前类是否是STI子类与父类共用同一张表
如果是则将子类定义的新字段添加到父表的metadata中。
JTI联表继承类会被自动跳过因为它们有自己独立的表。
"""
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
parent_table = None
parent_fields: set[str] = set()
for base in cls.__mro__[1:]:
if hasattr(base, '__table__') and base.__table__ is not None:
parent_table = base.__table__
# 收集父类的所有字段名
if hasattr(base, 'model_fields'):
parent_fields.update(base.model_fields.keys())
break
if parent_table is None:
return # 没有找到父表,可能是根类
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
# JTI 类有自己独立的表,不需要将列注册到父表
if hasattr(cls, '__table__') and cls.__table__ is not None:
if cls.__table__.name != parent_table.name:
return # JTI跳过 STI 列注册
# 获取当前类的新字段(不在父类中的字段)
if not hasattr(cls, 'model_fields'):
return
existing_columns = {col.name for col in parent_table.columns}
for field_name, field_info in cls.model_fields.items():
# 跳过从父类继承的字段
if field_name in parent_fields:
continue
# 跳过私有字段和ClassVar
if field_name.startswith('_'):
continue
# 跳过已存在的列
if field_name in existing_columns:
continue
# 使用 SQLModel 的内置 API 创建列
try:
column = get_column_from_field(field_info)
column.name = field_name
column.key = field_name
# STI子类字段在数据库层面必须可空因为其他子类的行不会有这些字段的值
# Pydantic层面的约束仍然有效创建特定子类时会验证必填字段
column.nullable = True
# 将列添加到父表
parent_table.append_column(column)
except Exception as e:
l.warning(f"{cls.__name__} 创建列 {field_name} 失败: {e}")
@classmethod
def _register_sti_column_properties(cls) -> None:
"""
将 STI 子类的列作为 ColumnProperty 添加到 mapper
此方法在 configure_mappers() 之后调用,将已添加到表中的列
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
这确保了查询父类时SELECT 语句包含所有 STI 子类的列,
避免在响应序列化时触发懒加载MissingGreenlet 错误)。
JTI联表继承类会被自动跳过因为它们有自己独立的表。
"""
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
parent_table = None
parent_class = None
for base in cls.__mro__[1:]:
if hasattr(base, '__table__') and base.__table__ is not None:
parent_table = base.__table__
parent_class = base
break
if parent_table is None:
return # 没有找到父表,可能是根类
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
if hasattr(cls, '__table__') and cls.__table__ is not None:
if cls.__table__.name != parent_table.name:
return # JTI跳过 STI 列属性注册
# 获取子类和父类的 mapper
child_mapper = inspect(cls).mapper
parent_mapper = inspect(parent_class).mapper
local_table = child_mapper.local_table
# 查找父类的所有字段名
parent_fields: set[str] = set()
if hasattr(parent_class, 'model_fields'):
parent_fields.update(parent_class.model_fields.keys())
if not hasattr(cls, 'model_fields'):
return
# 获取两个 mapper 已有的列属性
child_existing_props = {p.key for p in child_mapper.column_attrs}
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
for field_name in cls.model_fields:
# 跳过从父类继承的字段
if field_name in parent_fields:
continue
# 跳过私有字段
if field_name.startswith('_'):
continue
# 检查表中是否有这个列
if field_name not in local_table.columns:
continue
column = local_table.columns[field_name]
# 添加到子类的 mapper如果尚不存在
if field_name not in child_existing_props:
try:
prop = ColumnProperty(column)
child_mapper.add_property(field_name, prop)
except Exception as e:
l.warning(f"{cls.__name__} 添加列属性 {field_name} 失败: {e}")
# 同时添加到父类的 mapper确保查询父类时 SELECT 包含所有 STI 子类的列)
if field_name not in parent_existing_props:
try:
prop = ColumnProperty(column)
parent_mapper.add_property(field_name, prop)
except Exception as e:
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
class PolymorphicBaseMixin:
"""
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
此 Mixin 自动设置以下内容:
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
使用场景:
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
用法示例:
```python
from abc import ABC
from sqlmodels.mixin import UUIDTableBaseMixin
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
# 定义基类
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
__tablename__ = 'mytool'
# 不需要手动定义 _polymorphic_name
# 不需要手动设置 polymorphic_on
# 不需要手动设置 polymorphic_abstract
# 定义子类
class SpecificTool(MyTool):
__tablename__ = 'specifictool'
# 会自动继承 polymorphic 配置
```
自动行为:
1. 定义 `_polymorphic_name: str` 字段(带索引)
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
3. 自动检测抽象类:
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
- 否则设置为 False
手动覆盖:
可以在类定义时手动指定参数来覆盖自动行为:
```python
class MyTool(
UUIDTableBaseMixin,
PolymorphicBaseMixin,
ABC,
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
polymorphic_abstract=False # 强制不设为抽象类
):
pass
```
注意事项:
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
- 适用于联表继承joined table inheritance场景
- 子类会自动继承 _polymorphic_name 字段定义
- 使用单下划线前缀是因为:
* SQLAlchemy 会映射单下划线字段为数据库列
* Pydantic 将其视为私有属性,不参与序列化
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
"""
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
#
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
#
# 为什么这样做:
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
# 4. 字段不出现在 Pydantic 序列化中model_dump() 和 JSON schema
# 5. 内部代码仍然可以正常访问和修改此字段
#
# 详细说明请参考sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
"""
多态鉴别器字段,用于标识具体的子类类型
注意:此字段使用单下划线前缀,表示内部使用。
- ✅ 存储到数据库
- ✅ 不出现在 API 序列化中
- ✅ 防止外部直接修改
"""
def __init_subclass__(
cls,
polymorphic_on: str | None = None,
polymorphic_abstract: bool | None = None,
**kwargs
):
"""
在子类定义时自动配置 polymorphic 设置
Args:
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'
设置为其他值可以使用不同的字段作为多态鉴别器。
polymorphic_abstract: 是否为抽象类。
- None: 自动检测(默认)
- True: 强制设为抽象类
- False: 强制设为非抽象类
**kwargs: 传递给父类的其他参数
"""
super().__init_subclass__(**kwargs)
# 初始化 __mapper_args__如果还没有
if '__mapper_args__' not in cls.__dict__:
cls.__mapper_args__ = {}
# 设置 polymorphic_on默认为 _polymorphic_name
if 'polymorphic_on' not in cls.__mapper_args__:
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
# 自动检测或设置 polymorphic_abstract
if 'polymorphic_abstract' not in cls.__mapper_args__:
if polymorphic_abstract is None:
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
has_abc = ABC in cls.__mro__
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
polymorphic_abstract = has_abc and has_abstract_methods
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
@classmethod
def _is_joined_table_inheritance(cls) -> bool:
"""
检测当前类是否使用联表继承Joined Table Inheritance
通过检查子类是否有独立的表来判断:
- JTI: 子类有独立的 local_table与父类不同
- STI: 子类与父类共用同一个 local_table
:return: True 表示 JTIFalse 表示 STI 或无子类
"""
mapper = inspect(cls)
base_table_name = mapper.local_table.name
# 检查所有直接子类
for subclass in cls.__subclasses__():
sub_mapper = inspect(subclass)
# 如果任何子类有不同的表名,说明是 JTI
if sub_mapper.local_table.name != base_table_name:
return True
return False
@classmethod
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
"""
递归获取当前类的所有具体(非抽象)子类
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
可在任意多态基类上调用,返回该类的所有非抽象子类。
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
"""
result: list[type[PolymorphicBaseMixin]] = []
for subclass in cls.__subclasses__():
# 使用 inspect() 获取 mapper 的公开属性
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
mapper = inspect(subclass)
if not mapper.polymorphic_abstract:
result.append(subclass)
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
if hasattr(subclass, 'get_concrete_subclasses'):
result.extend(subclass.get_concrete_subclasses())
return result
@classmethod
def get_polymorphic_discriminator(cls) -> str:
"""
获取多态鉴别字段名
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
:return: 多态鉴别字段名(如 '_polymorphic_name'
:raises ValueError: 如果类未配置 polymorphic_on
"""
polymorphic_on = inspect(cls).polymorphic_on
if polymorphic_on is None:
raise ValueError(
f"{cls.__name__} 未配置 polymorphic_on"
f"请确保正确继承 PolymorphicBaseMixin"
)
return polymorphic_on.key
@classmethod
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
"""
获取 polymorphic_identity 到具体子类的映射
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
:return: identity 到子类的映射字典
"""
result: dict[str, type[PolymorphicBaseMixin]] = {}
for subclass in cls.get_concrete_subclasses():
identity = inspect(subclass).polymorphic_identity
if identity:
result[identity] = subclass
return result

View File

@@ -0,0 +1,470 @@
"""
关系预加载 Mixin
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
设计原则:
- 按需加载:只加载被调用方法需要的关系
- 增量加载:已加载的关系不重复加载
- 查询最优:相同关系只查询一次,不同关系增量查询
- 零侵入:调用方无需任何改动
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
使用方式:
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
kling_video_generator: KlingO1Generator = Relationship(...)
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
async def cost(self, params, context, session) -> ToolCost:
# 自动加载,可以安全访问
price = self.kling_video_generator.kling_o1.pro_price_per_second
...
# 调用方 - 无需任何改动
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
await tool._call(...) # 关系相同则跳过,否则增量加载
支持 AsyncGenerator
@requires_relations('twitter_api')
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
yield ToolResponse(...) # 装饰器正确处理 async generator
"""
import inspect as python_inspect
from functools import wraps
from typing import Callable, TypeVar, ParamSpec, Any
from loguru import logger as l
from sqlalchemy import inspect as sa_inspect
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import RelationshipInfo
P = ParamSpec('P')
R = TypeVar('R')
def _extract_session(
func: Callable,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> AsyncSession | None:
"""
从方法参数中提取 AsyncSession
按以下顺序查找:
1. kwargs 中名为 'session' 的参数
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
3. kwargs 中类型为 AsyncSession 的参数
"""
# 1. 优先从 kwargs 查找
if 'session' in kwargs:
return kwargs['session']
# 2. 从函数签名定位位置参数
try:
sig = python_inspect.signature(func)
param_names = list(sig.parameters.keys())
if 'session' in param_names:
# 计算位置(减去 self
idx = param_names.index('session') - 1
if 0 <= idx < len(args):
return args[idx]
except (ValueError, TypeError):
pass
# 3. 遍历 kwargs 找 AsyncSession 类型
for value in kwargs.values():
if isinstance(value, AsyncSession):
return value
return None
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
"""
检查对象的关系是否已加载(独立函数版本)
Args:
obj: 要检查的对象
rel_name: 关系属性名
Returns:
True 如果关系已加载False 如果未加载或已过期
"""
try:
state = sa_inspect(obj)
return rel_name not in state.unloaded
except Exception:
return False
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
"""
在类中查找指向目标类的关系属性名
Args:
from_class: 源类
to_class: 目标类
Returns:
关系属性名,如果找不到则返回 None
Example:
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
# 返回 'kling_video_generator'
"""
for attr_name in dir(from_class):
try:
attr = getattr(from_class, attr_name, None)
if attr is None:
continue
# 检查是否是 SQLAlchemy InstrumentedAttribute关系属性
# parent.class_ 是关系所在的类property.mapper.class_ 是关系指向的目标类
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
target_class = attr.property.mapper.class_
if target_class == to_class:
return attr_name
except AttributeError:
continue
return None
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
装饰器:声明方法需要的关系,自动按需增量加载
参数格式:
- 字符串:本类属性名,如 'kling_video_generator'
- RelationshipInfo外部类属性如 KlingO1Generator.kling_o1
行为:
- 方法调用时自动检查关系是否已加载
- 未加载的关系会被增量加载(单次查询)
- 已加载的关系直接跳过
支持:
- 普通 async 方法:`async def cost(...) -> ToolCost`
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
Example:
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
async def cost(self, params, context, session) -> ToolCost:
# self.kling_video_generator.kling_o1 已自动加载
...
@requires_relations('twitter_api')
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
yield ToolResponse(...) # AsyncGenerator 正确处理
验证:
- 字符串格式的关系名在类创建时__init_subclass__验证
- 拼写错误会在导入时抛出 AttributeError
"""
def decorator(func: Callable[P, R]) -> Callable[P, R]:
# 检测是否是 async generator 函数
is_async_gen = python_inspect.isasyncgenfunction(func)
if is_async_gen:
# AsyncGenerator 需要特殊处理wrapper 也必须是 async generator
@wraps(func)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
session = _extract_session(func, args, kwargs)
if session is not None:
await self._ensure_relations_loaded(session, relations)
# 委托给原始 async generator逐个 yield 值
async for item in func(self, *args, **kwargs):
yield item # type: ignore
else:
# 普通 async 函数await 并返回结果
@wraps(func)
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
session = _extract_session(func, args, kwargs)
if session is not None:
await self._ensure_relations_loaded(session, relations)
return await func(self, *args, **kwargs)
# 保存关系声明供验证和内省使用
wrapper._required_relations = relations # type: ignore
return wrapper
return decorator
class RelationPreloadMixin:
"""
关系预加载 Mixin
提供按需增量加载能力,确保 SQL 查询数理论最优。
特性:
- 按需加载:只加载被调用方法需要的关系
- 增量加载:已加载的关系不重复加载
- 原地更新:直接修改 self无需替换实例
- 导入时验证:字符串关系名在类创建时验证
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
"""
def __init_subclass__(cls, **kwargs) -> None:
"""类创建时验证所有 @requires_relations 声明"""
super().__init_subclass__(**kwargs)
# 收集类及其父类的所有注解(包含普通字段)
all_annotations: set[str] = set()
for klass in cls.__mro__:
if hasattr(klass, '__annotations__'):
all_annotations.update(klass.__annotations__.keys())
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__
sqlmodel_relationships: set[str] = set()
for klass in cls.__mro__:
if hasattr(klass, '__sqlmodel_relationships__'):
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
# 合并所有可用的属性名
all_available_names = all_annotations | sqlmodel_relationships
for method_name in dir(cls):
if method_name.startswith('__'):
continue
try:
method = getattr(cls, method_name, None)
except AttributeError:
continue
if method is None or not hasattr(method, '_required_relations'):
continue
# 验证字符串格式的关系名
for spec in method._required_relations:
if isinstance(spec, str):
# 检查注解、Relationship 或已有属性
if spec not in all_available_names and not hasattr(cls, spec):
raise AttributeError(
f"{cls.__name__}.{method_name} 声明了关系 '{spec}'"
f"{cls.__name__} 没有此属性"
)
def _is_relation_loaded(self, rel_name: str) -> bool:
"""
检查关系是否真正已加载(基于 SQLAlchemy inspect
使用 SQLAlchemy 的 inspect 检测真实加载状态,
自动处理 commit 导致的 expire 问题。
Args:
rel_name: 关系属性名
Returns:
True 如果关系已加载False 如果未加载或已过期
"""
try:
state = sa_inspect(self)
# unloaded 包含未加载的关系属性名
return rel_name not in state.unloaded
except Exception:
# 对象可能未被 SQLAlchemy 管理
return False
async def _ensure_relations_loaded(
self,
session: AsyncSession,
relations: tuple[str | RelationshipInfo, ...],
) -> None:
"""
确保指定关系已加载,只加载未加载的部分
基于 SQLAlchemy inspect 检测真实状态,自动处理:
- 首次访问的关系
- commit 后 expire 的关系
- 嵌套关系(如 KlingO1Generator.kling_o1
Args:
session: 数据库会话
relations: 需要的关系规格
"""
# 找出真正未加载的关系(基于 SQLAlchemy inspect
to_load: list[str | RelationshipInfo] = []
# 区分直接关系和嵌套关系的 key
direct_keys: set[str] = set() # 本类的直接关系属性名
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
for rel in relations:
if isinstance(rel, str):
# 直接关系:检查本类的关系是否已加载
if not self._is_relation_loaded(rel):
to_load.append(rel)
direct_keys.add(rel)
else:
# 嵌套关系InstrumentedAttribute如 KlingO1Generator.kling_o1
# 1. 查找指向父类的关系属性
parent_class = rel.parent.class_
parent_attr = _find_relation_to_class(self.__class__, parent_class)
if parent_attr is None:
# 找不到路径,可能是配置错误,但仍尝试加载
l.warning(
f"无法找到从 {self.__class__.__name__}{parent_class.__name__} 的关系路径,"
f"无法检查 {rel.key} 是否已加载"
)
to_load.append(rel)
continue
# 2. 检查父对象是否已加载
if not self._is_relation_loaded(parent_attr):
# 父对象未加载,需要同时加载父对象和嵌套关系
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
to_load.append(parent_attr)
nested_parent_keys.add(parent_attr)
to_load.append(rel)
else:
# 3. 父对象已加载,检查嵌套关系是否已加载
parent_obj = getattr(self, parent_attr)
if not _is_obj_relation_loaded(parent_obj, rel.key):
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
# 因为 _build_load_chains 需要完整的链来构建 selectinload
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
to_load.append(parent_attr)
nested_parent_keys.add(parent_attr)
to_load.append(rel)
if not to_load:
return # 全部已加载,跳过
# 构建 load 参数
load_options = self._specs_to_load_options(to_load)
if not load_options:
return
# 安全地获取主键值(避免触发懒加载)
state = sa_inspect(self)
pk_tuple = state.key[1] if state.key else None
if pk_tuple is None:
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
return
# 主键是元组,取第一个值(假设单列主键)
pk_value = pk_tuple[0]
# 单次查询加载缺失的关系
fresh = await self.__class__.get(
session,
self.__class__.id == pk_value,
load=load_options,
)
if fresh is None:
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
return
# 原地复制到 self只复制直接关系嵌套关系通过父关系自动可访问
all_direct_keys = direct_keys | nested_parent_keys
for key in all_direct_keys:
value = getattr(fresh, key, None)
object.__setattr__(self, key, value)
def _specs_to_load_options(
self,
specs: list[str | RelationshipInfo],
) -> list[RelationshipInfo]:
"""
将关系规格转换为 load 参数
- 字符串 → cls.{name}
- RelationshipInfo → 直接使用
"""
result: list[RelationshipInfo] = []
for spec in specs:
if isinstance(spec, str):
rel = getattr(self.__class__, spec, None)
if rel is not None:
result.append(rel)
else:
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
else:
result.append(spec)
return result
# ==================== 可选的手动预加载 API ====================
@classmethod
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
"""
获取指定方法声明的关系(用于外部预加载场景)
Args:
method_name: 方法名
Returns:
RelationshipInfo 列表
"""
method = getattr(cls, method_name, None)
if method is None or not hasattr(method, '_required_relations'):
return []
result: list[RelationshipInfo] = []
for spec in method._required_relations:
if isinstance(spec, str):
rel = getattr(cls, spec, None)
if rel:
result.append(rel)
else:
result.append(spec)
return result
@classmethod
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
"""
获取多个方法的关系并去重(用于批量预加载场景)
Args:
method_names: 方法名列表
Returns:
去重后的 RelationshipInfo 列表
"""
seen: set[str] = set()
result: list[RelationshipInfo] = []
for method_name in method_names:
for rel in cls.get_relations_for_method(method_name):
key = rel.key
if key not in seen:
seen.add(key)
result.append(rel)
return result
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
"""
手动预加载指定方法的关系(可选优化 API
当需要确保在调用方法前完成所有加载时使用。
通常情况下不需要调用此方法,装饰器会自动处理。
Args:
session: 数据库会话
method_names: 方法名列表
Returns:
self支持链式调用
Example:
# 可选:显式预加载(通常不需要)
tool = await tool.preload_for(session, 'cost', '_call')
"""
all_relations: list[str | RelationshipInfo] = []
for method_name in method_names:
method = getattr(self.__class__, method_name, None)
if method and hasattr(method, '_required_relations'):
all_relations.extend(method._required_relations)
if all_relations:
await self._ensure_relations_loaded(session, tuple(all_relations))
return self

View File

@@ -12,7 +12,14 @@
mixin/table.py 当前文件导入 PolymorphicBaseMixin mixin/table.py 当前文件导入 PolymorphicBaseMixin
base/__init__.py mixin 重新导出保持向后兼容 base/__init__.py mixin 重新导出保持向后兼容
维护须知
增删功能时必须更新 __version__ 字段遵循语义化版本
版本历史
0.1.0 - delete() 方法支持条件删除condition 参数
""" """
__version__ = "0.1.0"
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import TypeVar, Literal, override, Any, ClassVar, Generic from typing import TypeVar, Literal, override, Any, ClassVar, Generic
@@ -26,16 +33,19 @@ from typing import TypeVar, Literal, override, Any, ClassVar, Generic
# 未来: PR #1275合并后可改回继承SQLModelBase # 未来: PR #1275合并后可改回继承SQLModelBase
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct, delete as sql_delete, inspect
from sqlalchemy.orm import selectinload, Relationship, with_polymorphic from sqlalchemy.orm import selectinload, Relationship, with_polymorphic
from sqlalchemy.orm.exc import StaleDataError
from sqlmodel import Field, select from sqlmodel import Field, select
from .optimistic_lock import OptimisticLockError
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.sql._typing import _OnClauseArgument from sqlalchemy.sql._typing import _OnClauseArgument
from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlmodel.main import RelationshipInfo from sqlmodel.main import RelationshipInfo
from .polymorphic import PolymorphicBaseMixin from .polymorphic import PolymorphicBaseMixin
from models.base.sqlmodel_base import SQLModelBase from sqlmodels.base.sqlmodel_base import SQLModelBase
# Type variables for generic type hints, improving code completion and analysis. # Type variables for generic type hints, improving code completion and analysis.
T = TypeVar("T", bound="TableBaseMixin") T = TypeVar("T", bound="TableBaseMixin")
@@ -196,8 +206,8 @@ class TableBaseMixin(AsyncAttrs):
created_at (datetime): 记录创建时的时间戳, 自动设置. created_at (datetime): 记录创建时的时间戳, 自动设置.
updated_at (datetime): 记录每次更新时的时间戳, 自动更新. updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
""" """
_is_table_mixin: ClassVar[bool] = True _has_table_mixin: ClassVar[bool] = True
"""标记此类表混入类的内部属性""" """标记此类继承了表混入类的内部属性"""
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
""" """
@@ -218,7 +228,7 @@ class TableBaseMixin(AsyncAttrs):
) )
@classmethod @classmethod
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]: async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
""" """
向数据库中添加一个新的或多个新的记录. 向数据库中添加一个新的或多个新的记录.
@@ -230,8 +240,6 @@ class TableBaseMixin(AsyncAttrs):
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要添加的单个模型实例或模型实例列表. instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True. refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
commit (bool): 是否提交事务设为 False 可在批量操作时减少提交次数
之后需要手动调用 `session.commit()`默认为 True.
Returns: Returns:
T | list[T]: 已添加并可选地刷新的一个或多个模型实例. T | list[T]: 已添加并可选地刷新的一个或多个模型实例.
@@ -246,11 +254,6 @@ class TableBaseMixin(AsyncAttrs):
# 添加单个实例 # 添加单个实例
item3 = Item(name="Cherry") item3 = Item(name="Cherry")
added_item = await Item.add(session, item3) added_item = await Item.add(session, item3)
# 批量操作,减少提交次数
await Item.add(session, [item1, item2], commit=False)
await Item.add(session, [item3, item4], commit=False)
await session.commit()
""" """
is_list = False is_list = False
if isinstance(instances, list): if isinstance(instances, list):
@@ -259,10 +262,7 @@ class TableBaseMixin(AsyncAttrs):
else: else:
session.add(instances) session.add(instances)
if commit:
await session.commit() await session.commit()
else:
await session.flush()
if refresh: if refresh:
if is_list: if is_list:
@@ -278,14 +278,16 @@ class TableBaseMixin(AsyncAttrs):
session: AsyncSession, session: AsyncSession,
load: RelationshipInfo | list[RelationshipInfo] | None = None, load: RelationshipInfo | list[RelationshipInfo] | None = None,
refresh: bool = True, refresh: bool = True,
commit: bool = True commit: bool = True,
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
optimistic_retry_count: int = 0,
) -> T: ) -> T:
""" """
保存插入或更新当前模型实例到数据库. 保存插入或更新当前模型实例到数据库.
这是一个实例方法它将当前对象添加到会话中并提交更改 这是一个实例方法它将当前对象添加到会话中并提交更改
可以用于创建新记录或更新现有记录还可以选择在保存后 可以用于创建新记录或更新现有记录还可以选择在保存后
预加载eager load一个或多个关联关系. 预加载eager load一个关联关系.
**重要**调用此方法后session中的所有对象都会过期expired **重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象必须使用返回值 如果需要继续使用该对象必须使用返回值
@@ -298,13 +300,17 @@ class TableBaseMixin(AsyncAttrs):
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.save(session, refresh=False) await client.save(session, refresh=False)
# ✅ 正确:批量操作,减少提交次数 # ✅ 正确:批量操作时延迟提交
await item1.save(session, commit=False) for item in items:
await item2.save(session, commit=False) item = await item.save(session, commit=False)
await session.commit() await session.commit()
# ✅ 正确:批量操作并预加载多个关联关系 # ✅ 正确:保存后需要访问多态关系
user = await user.save(session, load=[User.group, User.tags]) tool_set = await tool_set.save(session, load=ToolSet.tools, jti_subclasses='all')
return tool_set # tools 关系已正确加载子类数据
# ✅ 正确:启用乐观锁自动重试
order = await order.save(session, optimistic_retry_count=3)
# ❌ 错误:需要返回值但未使用 # ❌ 错误:需要返回值但未使用
await client.save(session) await client.save(session)
@@ -313,34 +319,77 @@ class TableBaseMixin(AsyncAttrs):
Args: Args:
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
load (Relationship | list[Relationship] | None): 可选的指定在保存和刷新后要预加载的关联属性. load (Relationship | None): 可选的指定在保存和刷新后要预加载的关联属性.
可以是单个关系或关系列表. 例如 `User.posts`.
例如 `User.posts` `[User.group, User.tags]`.
refresh (bool): 是否在保存后刷新对象如果不需要使用返回值 refresh (bool): 是否在保存后刷新对象如果不需要使用返回值
设为 False 可节省一次数据库查询默认为 True. 设为 False 可节省一次数据库查询默认为 True.
commit (bool): 是否提交事务 False 可在批量操作时减少提交次数 commit (bool): 是否在保存后提交事务如果 False只会 flush 获取 ID
之后需要手动调用 `session.commit()`默认为 True. 但不提交适用于批量操作场景默认为 True.
jti_subclasses: 多态子类加载选项需要与 load 参数配合使用
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- 'all': 两阶段查询只加载实际关联的子类
- None默认: 不使用多态加载
optimistic_retry_count (int): 乐观锁冲突时的自动重试次数默认为 0不重试
重试时会重新查询最新数据将当前修改合并后再次保存
Returns: Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self. T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
Raises:
OptimisticLockError: 如果启用了乐观锁且版本号不匹配且重试次数已耗尽
""" """
session.add(self) cls = type(self)
instance = self
retries_remaining = optimistic_retry_count
current_data: dict[str, Any] | None = None # 延迟计算,仅在需要重试时
while True:
session.add(instance)
try:
if commit: if commit:
await session.commit() await session.commit()
else: else:
await session.flush() await session.flush()
break # 成功,退出循环
except StaleDataError as e:
await session.rollback()
if retries_remaining <= 0:
raise OptimisticLockError(
message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
model_class=cls.__name__,
record_id=str(getattr(instance, 'id', None)),
expected_version=getattr(instance, 'version', None),
original_error=e,
) from e
# 失败后重试:重新查询最新数据并合并修改
retries_remaining -= 1
if current_data is None:
current_data = self.model_dump(exclude={'id', 'version', 'created_at', 'updated_at'})
fresh = await cls.get(session, cls.id == self.id)
if fresh is None:
raise OptimisticLockError(
message=f"{cls.__name__} 重试失败:记录已被删除",
model_class=cls.__name__,
record_id=str(getattr(self, 'id', None)),
original_error=e,
) from e
for key, value in current_data.items():
if hasattr(fresh, key):
setattr(fresh, key, value)
instance = fresh
if not refresh: if not refresh:
return self return instance
if load is not None: if load is not None:
cls = type(self) await session.refresh(instance)
await session.refresh(self) return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
# 如果指定了 load, 重新获取实例并加载关联关系
return await cls.get(session, cls.id == self.id, load=load)
else: else:
await session.refresh(self) await session.refresh(instance)
return self return instance
async def update( async def update(
self: T, self: T,
@@ -351,7 +400,9 @@ class TableBaseMixin(AsyncAttrs):
exclude: set[str] | None = None, exclude: set[str] | None = None,
load: RelationshipInfo | list[RelationshipInfo] | None = None, load: RelationshipInfo | list[RelationshipInfo] | None = None,
refresh: bool = True, refresh: bool = True,
commit: bool = True commit: bool = True,
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
optimistic_retry_count: int = 0,
) -> T: ) -> T:
""" """
使用另一个模型实例或字典中的数据来更新当前实例. 使用另一个模型实例或字典中的数据来更新当前实例.
@@ -371,16 +422,20 @@ class TableBaseMixin(AsyncAttrs):
user = await user.update(session, update_data, load=User.permission) user = await user.update(session, update_data, load=User.permission)
return user return user
# ✅ 正确:更新后需要访问多态关系时
tool_set = await tool_set.update(session, data, load=ToolSet.tools, jti_subclasses='all')
return tool_set # tools 关系已正确加载子类数据
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.update(session, update_data, refresh=False) await client.update(session, update_data, refresh=False)
# ✅ 正确:批量操作,减少提交次数 # ✅ 正确:批量操作时延迟提交
await user1.update(session, data1, commit=False) for item in items:
await user2.update(session, data2, commit=False) item = await item.update(session, data, commit=False)
await session.commit() await session.commit()
# ✅ 正确:批量操作并预加载多个关联关系 # ✅ 正确:启用乐观锁自动重试
user = await user.update(session, data, load=[User.group, User.tags]) order = await order.update(session, update_data, optimistic_retry_count=3)
# ❌ 错误:需要返回值但未使用 # ❌ 错误:需要返回值但未使用
await client.update(session, update_data) await client.update(session, update_data)
@@ -394,38 +449,72 @@ class TableBaseMixin(AsyncAttrs):
exclude_unset (bool): 如果为 True, `other` 对象中未设置即值为 None 或未提供 exclude_unset (bool): 如果为 True, `other` 对象中未设置即值为 None 或未提供
的字段将被忽略. 默认为 True. 的字段将被忽略. 默认为 True.
exclude (set[str] | None): 要从更新中排除的字段名集合例如 {'permission'}. exclude (set[str] | None): 要从更新中排除的字段名集合例如 {'permission'}.
load (Relationship | list[Relationship] | None): 可选的指定在更新和刷新后要预加载的关联属性. load (RelationshipInfo | None): 可选的指定在更新和刷新后要预加载的关联属性.
可以是单个关系或关系列表. 例如 `User.permission`.
例如 `User.permission` `[User.group, User.tags]`.
refresh (bool): 是否在更新后刷新对象如果不需要使用返回值 refresh (bool): 是否在更新后刷新对象如果不需要使用返回值
设为 False 可节省一次数据库查询默认为 True. 设为 False 可节省一次数据库查询默认为 True.
commit (bool): 是否提交事务 False 可在批量操作时减少提交次数 commit (bool): 是否在更新后提交事务如果 False只会 flush
之后需要手动调用 `session.commit()`默认为 True. 但不提交适用于批量操作场景默认为 True.
jti_subclasses: 多态子类加载选项需要与 load 参数配合使用
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- 'all': 两阶段查询只加载实际关联的子类
- None默认: 不使用多态加载
optimistic_retry_count (int): 乐观锁冲突时的自动重试次数默认为 0不重试
重试时会重新查询最新数据 other 的更新重新应用后再次保存
Returns: Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self. T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
"""
self.sqlmodel_update(
other.model_dump(exclude_unset=exclude_unset, exclude=exclude),
update=extra_data
)
session.add(self) Raises:
OptimisticLockError: 如果启用了乐观锁且版本号不匹配且重试次数已耗尽
"""
cls = type(self)
update_data = other.model_dump(exclude_unset=exclude_unset, exclude=exclude)
instance = self
retries_remaining = optimistic_retry_count
while True:
instance.sqlmodel_update(update_data, update=extra_data)
session.add(instance)
try:
if commit: if commit:
await session.commit() await session.commit()
else: else:
await session.flush() await session.flush()
break # 成功,退出循环
except StaleDataError as e:
await session.rollback()
if retries_remaining <= 0:
raise OptimisticLockError(
message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
model_class=cls.__name__,
record_id=str(getattr(instance, 'id', None)),
expected_version=getattr(instance, 'version', None),
original_error=e,
) from e
# 失败后重试:重新查询最新数据并重新应用更新
retries_remaining -= 1
fresh = await cls.get(session, cls.id == self.id)
if fresh is None:
raise OptimisticLockError(
message=f"{cls.__name__} 重试失败:记录已被删除",
model_class=cls.__name__,
record_id=str(getattr(self, 'id', None)),
original_error=e,
) from e
instance = fresh
if not refresh: if not refresh:
return self return instance
if load is not None: if load is not None:
cls = type(self) await session.refresh(instance)
await session.refresh(self) return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
return await cls.get(session, cls.id == self.id, load=load)
else: else:
await session.refresh(self) await session.refresh(instance)
return self return instance
@classmethod @classmethod
async def delete( async def delete(
@@ -434,68 +523,57 @@ class TableBaseMixin(AsyncAttrs):
instances: T | list[T] | None = None, instances: T | list[T] | None = None,
*, *,
condition: BinaryExpression | ClauseElement | None = None, condition: BinaryExpression | ClauseElement | None = None,
commit: bool = True commit: bool = True,
) -> int: ) -> int:
""" """
从数据库中删除记录. 从数据库中删除记录支持实例删除和条件删除两种模式
支持两种删除方式
1. 实例删除传入 instances 参数先加载再删除
2. 条件删除传入 condition 参数直接 SQL 删除更高效
Args: Args:
session (AsyncSession): 用于数据库操作的异步会话对象. session: 用于数据库操作的异步会话对象
instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表可选. instances: 要删除的单个模型实例或模型实例列表实例删除模式
condition (BinaryExpression | ClauseElement | None): 删除条件可选 instances 二选一. condition: WHERE 条件表达式条件删除模式直接执行 SQL DELETE
commit (bool): 是否提交事务 False 可在批量操作时减少提交次数 commit: 是否在删除后提交事务默认 True
之后需要手动调用 `session.commit()`默认为 True.
Returns: Returns:
int: 删除的记录数量 删除的记录数条件删除模式返回实际删除数实例删除模式返回实例数
Raises:
ValueError: 同时提供 instances condition或两者都未提供
Usage: Usage:
# 实例删除 # 实例删除模式
item_to_delete = await Item.get(session, Item.id == 1) item = await Item.get(session, Item.id == 1)
if item_to_delete: if item:
deleted_count = await Item.delete(session, item_to_delete) await Item.delete(session, item)
# 条件删除(更高效,无需加载实例) items = await Item.get(session, Item.name.in_(["A", "B"]), fetch_mode="all")
if items:
await Item.delete(session, items)
# 条件删除模式(高效批量删除,不加载实例到内存)
deleted_count = await Item.delete( deleted_count = await Item.delete(
session, session,
condition=(Item.status == "inactive") & (Item.created_at < cutoff_date) condition=(Item.user_id == user_id) & (Item.status == "expired"),
) )
# 批量删除后手动提交
await Item.delete(session, item1, commit=False)
await Item.delete(session, item2, commit=False)
await session.commit()
""" """
# 条件删除模式 if instances is not None and condition is not None:
if condition is not None: raise ValueError("不能同时提供 instances 和 condition 参数")
from sqlmodel import delete as sql_delete if instances is None and condition is None:
raise ValueError("必须提供 instances 或 condition 参数之一")
if instances is not None:
raise ValueError("不能同时指定 instances 和 condition")
# 执行条件删除
stmt = sql_delete(cls).where(condition)
result = await session.exec(stmt)
deleted_count = result.rowcount
if commit:
await session.commit()
return deleted_count
# 实例删除模式(原有逻辑)
if instances is None:
raise ValueError("必须指定 instances 或 condition")
deleted_count = 0 deleted_count = 0
if condition is not None:
# 条件删除模式:直接执行 SQL DELETE
stmt = sql_delete(cls).where(condition)
result = await session.execute(stmt)
deleted_count = result.rowcount
else:
# 实例删除模式
if isinstance(instances, list): if isinstance(instances, list):
for instance in instances: for instance in instances:
await session.delete(instance) await session.delete(instance)
deleted_count += 1 deleted_count = len(instances)
else: else:
await session.delete(instances) await session.delete(instances)
deleted_count = 1 deleted_count = 1
@@ -552,7 +630,8 @@ class TableBaseMixin(AsyncAttrs):
filter: BinaryExpression | ClauseElement | None = None, filter: BinaryExpression | ClauseElement | None = None,
with_for_update: bool = False, with_for_update: bool = False,
table_view: TableViewRequest | None = None, table_view: TableViewRequest | None = None,
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
populate_existing: bool = False,
created_before_datetime: datetime | None = None, created_before_datetime: datetime | None = None,
created_after_datetime: datetime | None = None, created_after_datetime: datetime | None = None,
updated_before_datetime: datetime | None = None, updated_before_datetime: datetime | None = None,
@@ -581,8 +660,10 @@ class TableBaseMixin(AsyncAttrs):
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据, options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
例如 `[selectinload(User.posts)]`. 例如 `[selectinload(User.posts)]`.
load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式用于预加载关联关系. load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式用于预加载关联关系.
可以是单个关系或关系列表. 可以是单个关系或关系列表支持嵌套关系预加载
例如 `User.profile` `[User.group, User.tags]`. 当传入多个关系时会自动检测依赖关系并构建链式 selectinload
例如 `[NodeGroupNode.element_links, NodeGroupElementLink.node]`
会自动构建 `selectinload(element_links).selectinload(node)`
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表. order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
例如 `[User.name.asc(), User.created_at.desc()]`. 例如 `[User.name.asc(), User.created_at.desc()]`.
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件. filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
@@ -593,11 +674,16 @@ class TableBaseMixin(AsyncAttrs):
会覆盖offsetlimitorder_by及时间筛选参数 会覆盖offsetlimitorder_by及时间筛选参数
这是推荐的分页排序方式统一了所有LIST端点的参数格式 这是推荐的分页排序方式统一了所有LIST端点的参数格式
load_polymorphic: 多态子类加载选项需要与 load 参数配合使用 jti_subclasses: 多态子类加载选项需要与 load 参数配合使用
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表 - list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
- 'all': 两阶段查询只加载实际关联的子类对于 > 10 个子类的场景有明显性能收益 - 'all': 两阶段查询只加载实际关联的子类对于 > 10 个子类的场景有明显性能收益
- None默认: 不使用多态加载 - None默认: 不使用多态加载
populate_existing (bool): 如果为 True强制用数据库数据覆盖 session 中已存在的对象identity map
用于批量刷新对象避免循环调用 session.refresh() 导致的 N 次查询
注意只刷新标量字段不影响运行时属性_开头的属性
对于 STI单表继承对象推荐按子类分组查询以包含子类字段默认为 False
created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录 created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录 created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录 updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
@@ -607,7 +693,7 @@ class TableBaseMixin(AsyncAttrs):
T | list[T] | None: 根据 `fetch_mode` 的设置返回单个实例实例列表或 `None`. T | list[T] | None: 根据 `fetch_mode` 的设置返回单个实例实例列表或 `None`.
Raises: Raises:
ValueError: 如果提供了无效的 `fetch_mode` load_polymorphic 未与 load 配合使用. ValueError: 如果提供了无效的 `fetch_mode` jti_subclasses 未与 load 配合使用.
Examples: Examples:
# 使用table_view参数推荐 # 使用table_view参数推荐
@@ -621,13 +707,13 @@ class TableBaseMixin(AsyncAttrs):
session, session,
ToolSet.id == tool_set_id, ToolSet.id == tool_set_id,
load=ToolSet.tools, load=ToolSet.tools,
load_polymorphic='all' # 只加载实际关联的子类 jti_subclasses='all' # 只加载实际关联的子类
) )
""" """
# 参数验证:load_polymorphic 需要与 load 配合使用 # 参数验证:jti_subclasses 需要与 load 配合使用
if load_polymorphic is not None and load is None: if jti_subclasses is not None and load is None:
raise ValueError( raise ValueError(
"load_polymorphic 参数需要与 load 参数配合使用," "jti_subclasses 参数需要与 load 参数配合使用,"
"请同时指定要加载的关系" "请同时指定要加载的关系"
) )
@@ -656,13 +742,34 @@ class TableBaseMixin(AsyncAttrs):
# 对于多态基类,使用 with_polymorphic 预加载所有子类的列 # 对于多态基类,使用 with_polymorphic 预加载所有子类的列
# 这避免了在响应序列化时的延迟加载问题MissingGreenlet 错误) # 这避免了在响应序列化时的延迟加载问题MissingGreenlet 错误)
if issubclass(cls, PolymorphicBaseMixin): polymorphic_cls = None # 保存多态实体,用于子类关系预加载
is_polymorphic = issubclass(cls, PolymorphicBaseMixin)
is_jti = is_polymorphic and cls._is_joined_table_inheritance()
is_sti = is_polymorphic and not cls._is_joined_table_inheritance()
# JTI 模式:总是使用 with_polymorphic避免 N+1 查询)
# STI 模式:不使用 with_polymorphic批量刷新时请按子类分组查询
if is_jti:
# '*' 表示加载所有子类 # '*' 表示加载所有子类
polymorphic_cls = with_polymorphic(cls, '*') polymorphic_cls = with_polymorphic(cls, '*')
statement = select(polymorphic_cls) statement = select(polymorphic_cls)
else: else:
statement = select(cls) statement = select(cls)
# 对于 STI单表继承子类自动添加多态过滤条件
# SQLAlchemy/SQLModel 在 STI 模式下不会自动添加 WHERE discriminator = 'identity' 过滤
# 这是已知行为,参考:
# - https://github.com/sqlalchemy/sqlalchemy/issues/5018 (bulk operations 不自动添加多态过滤)
# - https://github.com/fastapi/sqlmodel/issues/488 (SQLModel STI 支持不完整)
# 社区最佳实践是显式添加多态过滤条件
if issubclass(cls, PolymorphicBaseMixin) and not cls._is_joined_table_inheritance():
mapper = inspect(cls)
# 检查是否有 polymorphic_identity 且不是抽象类
if mapper.polymorphic_identity is not None and not mapper.polymorphic_abstract:
poly_on = mapper.polymorphic_on
if poly_on is not None:
statement = statement.where(poly_on == mapper.polymorphic_identity)
if condition is not None: if condition is not None:
statement = statement.where(condition) statement = statement.where(condition)
@@ -688,12 +795,19 @@ class TableBaseMixin(AsyncAttrs):
# 标准化为列表 # 标准化为列表
load_list = load if isinstance(load, list) else [load] load_list = load if isinstance(load, list) else [load]
# 处理多态加载 # 构建链式 selectinload支持嵌套关系预加载
if load_polymorphic is not None: # 例如load=[NodeGroupNode.element_links, NodeGroupElementLink.node]
# 多态加载只支持单个关系 # 会构建selectinload(element_links).selectinload(node)
if len(load_list) > 1: load_chains = cls._build_load_chains(load_list)
raise ValueError("load_polymorphic 仅支持单个关系")
target_class = load_list[0].property.mapper.class_ # 处理多态加载(仅支持单链且只有一个关系)
if jti_subclasses is not None:
if len(load_chains) > 1 or len(load_chains[0]) > 1:
raise ValueError(
"jti_subclasses 仅支持单个关系(无嵌套链),请不要传入多个关系"
)
single_load = load_chains[0][0]
target_class = single_load.property.mapper.class_
# 检查目标类是否继承自 PolymorphicBaseMixin # 检查目标类是否继承自 PolymorphicBaseMixin
if not issubclass(target_class, PolymorphicBaseMixin): if not issubclass(target_class, PolymorphicBaseMixin):
@@ -702,26 +816,48 @@ class TableBaseMixin(AsyncAttrs):
f"请确保其继承自 PolymorphicBaseMixin" f"请确保其继承自 PolymorphicBaseMixin"
) )
if load_polymorphic == 'all': if jti_subclasses == 'all':
# 两阶段查询:获取实际关联的多态类型 # 两阶段查询:获取实际关联的多态类型
subclasses_to_load = await cls._resolve_polymorphic_subclasses( subclasses_to_load = await cls._resolve_polymorphic_subclasses(
session, condition, load_list[0], target_class session, condition, single_load, target_class
) )
else: else:
subclasses_to_load = load_polymorphic subclasses_to_load = jti_subclasses
if subclasses_to_load: if subclasses_to_load:
# 关键selectin_polymorphic 必须作为 selectinload 的链式子选项 # 关键selectin_polymorphic 必须作为 selectinload 的链式子选项
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading # 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
statement = statement.options( statement = statement.options(
selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load) selectinload(single_load).selectin_polymorphic(subclasses_to_load)
) )
else: else:
statement = statement.options(selectinload(load_list[0])) statement = statement.options(selectinload(single_load))
else: else:
# 为每个关系添加 selectinload # 为每条链构建链式 selectinload
for rel in load_list: for chain in load_chains:
statement = statement.options(selectinload(rel)) # 获取第一个关系并检查是否需要通过多态实体访问
first_rel = chain[0]
first_rel_parent = first_rel.property.parent.class_
# 如果关系的 parent_class 是当前类的子类(不是 cls 本身),
# 且当前是多态查询,则需要通过 polymorphic_cls.SubclassName 访问
if (
polymorphic_cls is not None
and first_rel_parent is not cls
and issubclass(first_rel_parent, cls)
):
# 通过多态实体访问子类的关系属性
# 例如polymorphic_cls.NodeGroupNode.element_links
subclass_alias = getattr(polymorphic_cls, first_rel_parent.__name__)
rel_name = first_rel.key
first_rel_via_poly = getattr(subclass_alias, rel_name)
loader = selectinload(first_rel_via_poly)
else:
loader = selectinload(first_rel)
for rel in chain[1:]:
loader = loader.selectinload(rel)
statement = statement.options(loader)
if order_by is not None: if order_by is not None:
statement = statement.order_by(*order_by) statement = statement.order_by(*order_by)
@@ -736,8 +872,18 @@ class TableBaseMixin(AsyncAttrs):
statement = statement.filter(filter) statement = statement.filter(filter)
if with_for_update: if with_for_update:
# 对于联表继承的多态模型,使用 FOR UPDATE OF <主表> 来避免 PostgreSQL 的限制
# PostgreSQL 不支持在 LEFT OUTER JOIN 的可空侧使用 FOR UPDATE
if issubclass(cls, PolymorphicBaseMixin):
statement = statement.with_for_update(of=cls)
else:
statement = statement.with_for_update() statement = statement.with_for_update()
if populate_existing:
# 强制用数据库数据覆盖 identity map 中的对象
# 用于批量刷新,避免循环 refresh() 的 N 次查询
statement = statement.execution_options(populate_existing=True)
result = await session.exec(statement) result = await session.exec(statement)
if fetch_mode == "one": if fetch_mode == "one":
@@ -749,6 +895,79 @@ class TableBaseMixin(AsyncAttrs):
else: else:
raise ValueError(f"无效的 fetch_mode: {fetch_mode}") raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
@staticmethod
def _build_load_chains(load_list: list[RelationshipInfo]) -> list[list[RelationshipInfo]]:
"""
将关系列表构建为链式加载结构
自动检测关系之间的依赖关系构建嵌套预加载链
例如[NodeGroupNode.element_links, NodeGroupElementLink.node]
会构建[[element_links, node]]一条链
算法
1. 获取每个关系的 parent class target class
2. 如果关系 B parent class 等于关系 A target class B 链在 A 后面
3. 独立的关系各自成为一条链
Args:
load_list: 关系属性列表
Returns:
链式关系列表每条链是一个关系列表
"""
if not load_list:
return []
# 构建关系信息:{关系: (parent_class, target_class)}
rel_info: dict[RelationshipInfo, tuple[type, type]] = {}
for rel in load_list:
parent_class = rel.property.parent.class_
target_class = rel.property.mapper.class_
rel_info[rel] = (parent_class, target_class)
# 构建依赖图:{关系: 其前置关系}
predecessors: dict[RelationshipInfo, RelationshipInfo | None] = {rel: None for rel in load_list}
for rel_b in load_list:
parent_b, _ = rel_info[rel_b]
for rel_a in load_list:
if rel_a is rel_b:
continue
_, target_a = rel_info[rel_a]
# 如果 B 的 parent 精确等于 A 的 target则 B 链在 A 后面
# 使用精确匹配避免继承关系导致的误判(如 NodeGroupNode 是 CanvasNode 子类)
if parent_b is target_a:
predecessors[rel_b] = rel_a
break
# 找出所有链的起点(没有前置关系的)
roots = [rel for rel, pred in predecessors.items() if pred is None]
# 构建链
chains: list[list[RelationshipInfo]] = []
used: set[RelationshipInfo] = set()
for root in roots:
chain = [root]
used.add(root)
# 找后续节点
current = root
while True:
# 找以 current 的 target 为 parent 的关系
_, current_target = rel_info[current]
next_rel = None
for rel, (parent, _) in rel_info.items():
if rel not in used and parent is current_target:
next_rel = rel
break
if next_rel is None:
break
chain.append(next_rel)
used.add(next_rel)
current = next_rel
chains.append(chain)
return chains
@classmethod @classmethod
async def _resolve_polymorphic_subclasses( async def _resolve_polymorphic_subclasses(
cls: type[T], cls: type[T],
@@ -791,12 +1010,15 @@ class TableBaseMixin(AsyncAttrs):
)) ))
) )
else: else:
# 一对多关系:通过外键查询 # 多对一/一对多关系:通过外键查询
foreign_key_col = relationship_property.local_remote_pairs[0][1] # local_remote_pairs[0] = (local_fk_col, remote_pk_col)
# 对于多对一local 是当前类的外键remote 是目标类的主键
local_fk_col = relationship_property.local_remote_pairs[0][0]
remote_pk_col = relationship_property.local_remote_pairs[0][1]
type_query = ( type_query = (
select(distinct(poly_name_col)) select(distinct(poly_name_col))
.where(foreign_key_col.in_( .where(remote_pk_col.in_(
select(cls.id).where(condition) if condition is not None else select(cls.id) select(local_fk_col).where(condition) if condition is not None else select(local_fk_col)
)) ))
) )
@@ -898,7 +1120,7 @@ class TableBaseMixin(AsyncAttrs):
order_by: list[ClauseElement] | None = None, order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None, filter: BinaryExpression | ClauseElement | None = None,
table_view: TableViewRequest | None = None, table_view: TableViewRequest | None = None,
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None, jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
) -> 'ListResponse[T]': ) -> 'ListResponse[T]':
""" """
获取分页列表及总数直接返回 ListResponse 获取分页列表及总数直接返回 ListResponse
@@ -918,7 +1140,7 @@ class TableBaseMixin(AsyncAttrs):
order_by: 排序子句 order_by: 排序子句
filter: 附加过滤条件 filter: 附加过滤条件
table_view: 分页排序参数推荐使用 table_view: 分页排序参数推荐使用
load_polymorphic: 多态子类加载选项 jti_subclasses: 多态子类加载选项
Returns: Returns:
ListResponse[T]: 包含 count items 的分页响应 ListResponse[T]: 包含 count items 的分页响应
@@ -957,7 +1179,7 @@ class TableBaseMixin(AsyncAttrs):
order_by=order_by, order_by=order_by,
filter=filter, filter=filter,
table_view=table_view, table_view=table_view,
load_polymorphic=load_polymorphic, jti_subclasses=jti_subclasses,
) )
return ListResponse(count=total_count, items=items) return ListResponse(count=total_count, items=items)
@@ -973,8 +1195,7 @@ class TableBaseMixin(AsyncAttrs):
Args: Args:
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
id (int): 要查找的记录的主键 ID. id (int): 要查找的记录的主键 ID.
load (Relationship | list[Relationship] | None): 可选的用于预加载的关联属性. load (Relationship | None): 可选的用于预加载的关联属性.
可以是单个关系或关系列表.
Returns: Returns:
T: 找到的模型实例. T: 找到的模型实例.
@@ -1002,7 +1223,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
@override @override
@classmethod @classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T: async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
""" """
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常. 根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
@@ -1012,8 +1233,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
Args: Args:
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
id (uuid.UUID): 要查找的记录的 UUID 主键. id (uuid.UUID): 要查找的记录的 UUID 主键.
load (Relationship | list[Relationship] | None): 可选的用于预加载的关联属性. load (Relationship | None): 可选的用于预加载的关联属性.
可以是单个关系或关系列表.
Returns: Returns:
T: 找到的模型实例. T: 找到的模型实例.

View File

@@ -120,3 +120,4 @@ class MCPResponseBase(MCPBase):
result: str result: str
"""方法返回结果""" """方法返回结果"""

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal
from uuid import UUID from uuid import UUID
from enum import StrEnum from enum import StrEnum
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text
from .base import SQLModelBase from .base import SQLModelBase
@@ -15,6 +16,7 @@ if TYPE_CHECKING:
from .source_link import SourceLink from .source_link import SourceLink
from .share import Share from .share import Share
from .physical_file import PhysicalFile from .physical_file import PhysicalFile
from .uri import DiskNextURI
class ObjectType(StrEnum): class ObjectType(StrEnum):
@@ -103,7 +105,7 @@ class ObjectMoveRequest(SQLModelBase):
class ObjectDeleteRequest(SQLModelBase): class ObjectDeleteRequest(SQLModelBase):
"""删除对象请求 DTO""" """删除对象请求 DTO"""
ids: UUID | list[UUID] ids: list[UUID]
"""待删除对象UUID列表""" """待删除对象UUID列表"""
@@ -116,12 +118,12 @@ class ObjectResponse(ObjectBase):
thumb: bool = False thumb: bool = False
"""是否有缩略图""" """是否有缩略图"""
date: datetime created_at: datetime
"""对象修改时间"""
create_date: datetime
"""对象创建时间""" """对象创建时间"""
updated_at: datetime
"""对象修改时间"""
source_enabled: bool = False source_enabled: bool = False
"""是否启用离线下载源""" """是否启用离线下载源"""
@@ -138,7 +140,7 @@ class PolicyResponse(SQLModelBase):
type: StorageType type: StorageType
"""存储类型""" """存储类型"""
max_size: int = Field(ge=0, default=0) max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
"""单文件最大限制单位字节0表示不限制""" """单文件最大限制单位字节0表示不限制"""
file_type: list[str] | None = None file_type: list[str] | None = None
@@ -186,18 +188,18 @@ class Object(ObjectBase, UUIDTableBaseMixin):
合并了原有的 File Folder 模型通过 type 字段区分文件和目录 合并了原有的 File Folder 模型通过 type 字段区分文件和目录
根目录规则 根目录规则
- 每个用户有一个显式根目录对象name=用户的username, parent_id=NULL - 每个用户有一个显式根目录对象name="/", parent_id=NULL
- 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id - 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id
- 根目录的 policy_id 指定用户默认存储策略 - 根目录的 policy_id 指定用户默认存储策略
- 路径格式/username/path/to/file /admin/docs/readme.md - 路径格式/path/to/file /docs/readme.md不包含用户名前缀
""" """
__table_args__ = ( __table_args__ = (
# 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况) # 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况)
UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"), UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"),
# 名称不能包含斜杠 ([TODO] 还有特殊字符) # 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/"
CheckConstraint( CheckConstraint(
"name NOT LIKE '%/%' AND name NOT LIKE '%\\%'", "parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')",
name="ck_object_name_no_slash", name="ck_object_name_no_slash",
), ),
# 性能索引 # 性能索引
@@ -220,7 +222,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
# ==================== 文件专属字段 ==================== # ==================== 文件专属字段 ====================
size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}) size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
"""文件大小(字节),目录为 0""" """文件大小(字节),目录为 0"""
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True) upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
@@ -374,15 +376,16 @@ class Object(ObjectBase, UUIDTableBaseMixin):
session, session,
user_id: UUID, user_id: UUID,
path: str, path: str,
username: str,
) -> "Object | None": ) -> "Object | None":
""" """
根据路径获取对象 根据路径获取对象
路径从用户根目录开始不包含用户名前缀
"/" 表示根目录"/docs/images" 表示根目录下的 docs/images
:param session: 数据库会话 :param session: 数据库会话
:param user_id: 用户UUID :param user_id: 用户UUID
:param path: 路径 "/username" "/username/docs/images" :param path: 路径 "/" "/docs/images"
:param username: 用户名用于识别根目录
:return: Object None :return: Object None
""" """
path = path.strip() path = path.strip()
@@ -403,16 +406,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
if not parts: if not parts:
return root return root
# 检查第一部分是否是用户名(根目录名) # 从根目录开始遍历路径
if parts[0] == username:
# 路径以用户名开头,如 /admin/docs
if len(parts) == 1:
# 只有用户名,返回根目录
return root
# 去掉用户名部分,从第二个部分开始遍历
parts = parts[1:]
# 从根目录开始遍历剩余路径
current = root current = root
for part in parts: for part in parts:
if not current: if not current:
@@ -443,6 +437,77 @@ class Object(ObjectBase, UUIDTableBaseMixin):
fetch_mode="all" fetch_mode="all"
) )
@classmethod
async def resolve_uri(
cls,
session,
uri: "DiskNextURI",
requesting_user_id: UUID | None = None,
) -> "Object":
"""
URI 解析为 Object 实例
分派逻辑类似 Cloudreve getNavigator
- MY user_id = uri.id(str(requesting_user_id))
验证权限自己的或管理员然后 get_by_path
- SHARE 通过 uri.fs_id Share 验证密码和有效期
获取 share.object然后沿 uri.path 遍历子对象
- TRASH 延后实现
:param session: 数据库会话
:param uri: DiskNextURI 实例
:param requesting_user_id: 请求用户UUID
:return: Object 实例
:raises ValueError: URI 无法解析
:raises PermissionError: 权限不足
:raises NotImplementedError: 不支持的命名空间
"""
from .uri import FileSystemNamespace
if uri.namespace == FileSystemNamespace.MY:
# 确定目标用户
target_user_id_str = uri.id(str(requesting_user_id) if requesting_user_id else None)
if not target_user_id_str:
raise ValueError("MY 命名空间需要提供 fs_id 或 requesting_user_id")
target_user_id = UUID(target_user_id_str)
# 权限检查:只能访问自己的空间(管理员权限由路由层判断)
if requesting_user_id and target_user_id != requesting_user_id:
raise PermissionError("无权访问其他用户的文件空间")
obj = await cls.get_by_path(session, target_user_id, uri.path)
if not obj:
raise ValueError(f"路径不存在: {uri.path}")
return obj
elif uri.namespace == FileSystemNamespace.SHARE:
raise NotImplementedError("分享空间解析尚未实现")
elif uri.namespace == FileSystemNamespace.TRASH:
raise NotImplementedError("回收站解析尚未实现")
else:
raise ValueError(f"未知的命名空间: {uri.namespace}")
async def get_full_path(self, session) -> str:
"""
从当前对象沿 parent_id 向上遍历到根目录返回完整路径
:param session: 数据库会话
:return: 完整路径 "/docs/images/photo.jpg"
"""
parts: list[str] = []
current: Object | None = self
while current and current.parent_id is not None:
parts.append(current.name)
current = await Object.get(session, Object.id == current.parent_id)
# 反转顺序(从根到当前)
parts.reverse()
return "/" + "/".join(parts)
# ==================== 上传会话模型 ==================== # ==================== 上传会话模型 ====================
@@ -452,10 +517,10 @@ class UploadSessionBase(SQLModelBase):
file_name: str = Field(max_length=255) file_name: str = Field(max_length=255)
"""原始文件名""" """原始文件名"""
file_size: int = Field(ge=0) file_size: int = Field(ge=0, sa_type=BigInteger)
"""文件总大小(字节)""" """文件总大小(字节)"""
chunk_size: int = Field(ge=1) chunk_size: int = Field(ge=1, sa_type=BigInteger)
"""分片大小(字节)""" """分片大小(字节)"""
total_chunks: int = Field(ge=1) total_chunks: int = Field(ge=1)
@@ -474,7 +539,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
uploaded_chunks: int = 0 uploaded_chunks: int = 0
"""已上传分片数""" """已上传分片数"""
uploaded_size: int = 0 uploaded_size: int = Field(default=0, sa_type=BigInteger)
"""已上传大小(字节)""" """已上传大小(字节)"""
storage_path: str | None = Field(default=None, max_length=512) storage_path: str | None = Field(default=None, max_length=512)
@@ -680,8 +745,8 @@ class AdminFileResponse(ObjectResponse):
owner_id: UUID owner_id: UUID
"""所有者UUID""" """所有者UUID"""
owner_username: str owner_email: str
"""所有者用户名""" """所有者邮箱"""
policy_name: str policy_name: str
"""存储策略名称""" """存储策略名称"""
@@ -709,12 +774,12 @@ class AdminFileResponse(ObjectResponse):
# ObjectResponse 字段 # ObjectResponse 字段
id=obj.id, id=obj.id,
thumb=False, thumb=False,
date=obj.updated_at, created_at=obj.created_at,
create_date=obj.created_at, updated_at=obj.updated_at,
source_enabled=False, source_enabled=False,
# AdminFileResponse 字段 # AdminFileResponse 字段
owner_id=obj.owner_id, owner_id=obj.owner_id,
owner_username=owner.username if owner else "unknown", owner_email=owner.email if owner else "unknown",
policy_name=policy.name if policy else "unknown", policy_name=policy.name if policy else "unknown",
is_banned=obj.is_banned, is_banned=obj.is_banned,
banned_at=obj.banned_at, banned_at=obj.banned_at,
@@ -725,7 +790,7 @@ class AdminFileResponse(ObjectResponse):
class FileBanRequest(SQLModelBase): class FileBanRequest(SQLModelBase):
"""文件封禁请求 DTO""" """文件封禁请求 DTO"""
is_banned: bool = True ban: bool = True
"""是否封禁""" """是否封禁"""
reason: str | None = Field(default=None, max_length=500) reason: str | None = Field(default=None, max_length=500)

View File

@@ -12,6 +12,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlalchemy import BigInteger
from sqlmodel import Field, Relationship, Index from sqlmodel import Field, Relationship, Index
from .base import SQLModelBase from .base import SQLModelBase
@@ -28,7 +29,7 @@ class PhysicalFileBase(SQLModelBase):
storage_path: str = Field(max_length=512) storage_path: str = Field(max_length=512)
"""物理存储路径(相对于存储策略根目录)""" """物理存储路径(相对于存储策略根目录)"""
size: int = 0 size: int = Field(default=0, sa_type=BigInteger)
"""文件大小(字节)""" """文件大小(字节)"""
checksum_md5: str | None = Field(default=None, max_length=32) checksum_md5: str | None = Field(default=None, max_length=32)

View File

@@ -20,16 +20,10 @@ class SiteConfigResponse(SQLModelBase):
title: str = "DiskNext" title: str = "DiskNext"
"""网站标题""" """网站标题"""
# themes: dict[str, str] = {}
# """网站主题配置"""
# default_theme: dict[str, str] = {}
# """默认主题RGB色号"""
site_notice: str | None = None site_notice: str | None = None
"""网站公告""" """网站公告"""
user: UserResponse user: UserResponse | None = None
"""用户信息""" """用户信息"""
logo_light: str | None = None logo_light: str | None = None
@@ -38,11 +32,23 @@ class SiteConfigResponse(SQLModelBase):
logo_dark: str | None = None logo_dark: str | None = None
"""网站Logo URL深色模式""" """网站Logo URL深色模式"""
captcha_type: CaptchaType | None = None register_enabled: bool = True
"""是否允许注册"""
login_captcha: bool = False
"""登录是否需要验证码"""
reg_captcha: bool = False
"""注册是否需要验证码"""
forget_captcha: bool = False
"""找回密码是否需要验证码"""
captcha_type: CaptchaType = CaptchaType.DEFAULT
"""验证码类型""" """验证码类型"""
captcha_key: str | None = None captcha_key: str | None = None
"""验证码密钥""" """验证码 public keyDEFAULT 类型时为 None"""
# ==================== 管理员设置 DTO ==================== # ==================== 管理员设置 DTO ====================

View File

@@ -215,6 +215,6 @@ class AdminShareListItem(ShareListItemBase):
"""从 Share ORM 对象构建""" """从 Share ORM 对象构建"""
return cls( return cls(
**ShareListItemBase.model_validate(share, from_attributes=True).model_dump(), **ShareListItemBase.model_validate(share, from_attributes=True).model_dump(),
username=user.username if user else None, username=user.email if user else None,
object_name=obj.name if obj else None, object_name=obj.name if obj else None,
) )

View File

@@ -73,7 +73,7 @@ class TaskSummary(TaskSummaryBase):
"""从 Task ORM 对象构建""" """从 Task ORM 对象构建"""
return cls( return cls(
**TaskSummaryBase.model_validate(task, from_attributes=True).model_dump(), **TaskSummaryBase.model_validate(task, from_attributes=True).model_dump(),
username=user.username if user else None, username=user.email if user else None,
) )

258
sqlmodels/uri.py Normal file
View File

@@ -0,0 +1,258 @@
from enum import StrEnum
from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote
from .base import SQLModelBase
class FileSystemNamespace(StrEnum):
"""文件系统命名空间"""
MY = "my"
"""用户个人空间"""
SHARE = "share"
"""分享空间"""
TRASH = "trash"
"""回收站"""
class DiskNextURI(SQLModelBase):
"""
DiskNext 文件 URI
URI 格式: disknext://[fs_id[:password]@]namespace[/path][?query]
fs_id 可省略:
- my/trash 命名空间省略时默认当前用户
- share 命名空间必须提供 fs_idShare.code
"""
fs_id: str | None = None
"""文件系统标识符,可省略"""
namespace: FileSystemNamespace
"""命名空间"""
path: str = "/"
"""路径"""
password: str | None = None
"""访问密码(用于有密码的分享)"""
query: dict[str, str] | None = None
"""查询参数"""
# === 属性 ===
@property
def path_parts(self) -> list[str]:
"""路径分割为列表(过滤空串)"""
return [p for p in self.path.split("/") if p]
@property
def is_root(self) -> bool:
"""是否指向根目录"""
return self.path.strip("/") == ""
# === 核心方法 ===
def id(self, default_id: str | None = None) -> str | None:
"""
获取 fs_id省略时返回 default_id
参考 Cloudreve URI.ID(defaultUid) 方法
:param default_id: 默认值(通常为当前用户 ID
:return: fs_id 或 default_id
"""
return self.fs_id if self.fs_id else default_id
# === 类方法 ===
@classmethod
def parse(cls, uri: str) -> "DiskNextURI":
"""
解析 URI 字符串
实现方式:替换 disknext:// 为 http:// 后用 urllib.parse.urlparse 解析
- hostname → namespace
- username → fs_id
- password → password
- path → path
- query → query dict
:param uri: URI 字符串,如 "disknext://my/docs/readme.md"
:return: DiskNextURI 实例
:raises ValueError: URI 格式无效
"""
if not uri.startswith("disknext://"):
raise ValueError(f"URI 必须以 disknext:// 开头: {uri}")
# 替换协议为 http:// 以利用 urllib.parse 解析
http_uri = "http://" + uri[len("disknext://"):]
parsed = urlparse(http_uri)
# 解析 namespace
hostname = parsed.hostname
if not hostname:
raise ValueError(f"URI 缺少命名空间: {uri}")
try:
namespace = FileSystemNamespace(hostname)
except ValueError:
raise ValueError(f"无效的命名空间 '{hostname}',有效值: {[e.value for e in FileSystemNamespace]}")
# 解析 fs_id 和 password
fs_id = unquote(parsed.username) if parsed.username else None
password = unquote(parsed.password) if parsed.password else None
# 解析 path
path = unquote(parsed.path) if parsed.path else "/"
if not path:
path = "/"
# 解析 query
query: dict[str, str] | None = None
if parsed.query:
raw_query = parse_qs(parsed.query, keep_blank_values=True)
query = {k: v[0] for k, v in raw_query.items()}
return cls(
fs_id=fs_id,
namespace=namespace,
path=path,
password=password,
query=query,
)
@classmethod
def build(
cls,
namespace: FileSystemNamespace,
path: str = "/",
fs_id: str | None = None,
password: str | None = None,
) -> "DiskNextURI":
"""
构建 URI 实例
:param namespace: 命名空间
:param path: 路径
:param fs_id: 文件系统标识符
:param password: 访问密码
:return: DiskNextURI 实例
"""
# 确保 path 以 / 开头
if not path.startswith("/"):
path = "/" + path
return cls(
fs_id=fs_id,
namespace=namespace,
path=path,
password=password,
)
# === 实例方法 ===
def to_string(self) -> str:
"""
序列化为 URI 字符串
:return: URI 字符串,如 "disknext://my/docs/readme.md"
"""
result = "disknext://"
# fs_id 和 password
if self.fs_id:
result += quote(self.fs_id, safe="")
if self.password:
result += ":" + quote(self.password, safe="")
result += "@"
# namespace
result += self.namespace.value
# path
result += self.path
# query
if self.query:
result += "?" + urlencode(self.query)
return result
def join(self, *elements: str) -> "DiskNextURI":
"""
拼接路径元素,返回新 URI
:param elements: 路径元素
:return: 新的 DiskNextURI 实例
"""
base = self.path.rstrip("/")
for element in elements:
element = element.strip("/")
if element:
base += "/" + element
if not base:
base = "/"
return DiskNextURI(
fs_id=self.fs_id,
namespace=self.namespace,
path=base,
password=self.password,
query=self.query,
)
def dir_uri(self) -> "DiskNextURI":
"""
返回父目录的 URI
:return: 父目录的 DiskNextURI 实例
"""
parts = self.path_parts
if not parts:
# 已经是根目录
return self.root()
parent_path = "/" + "/".join(parts[:-1])
if not parent_path.endswith("/"):
parent_path += "/"
return DiskNextURI(
fs_id=self.fs_id,
namespace=self.namespace,
path=parent_path,
password=self.password,
)
def root(self) -> "DiskNextURI":
"""
返回根目录的 URI保留 namespace 和 fs_id
:return: 根目录的 DiskNextURI 实例
"""
return DiskNextURI(
fs_id=self.fs_id,
namespace=self.namespace,
path="/",
password=self.password,
)
def name(self) -> str:
"""
返回路径的最后一段(文件名或目录名)
:return: 文件名或目录名,根目录返回空字符串
"""
parts = self.path_parts
return parts[-1] if parts else ""
def __str__(self) -> str:
return self.to_string()
def __repr__(self) -> str:
return f"DiskNextURI({self.to_string()!r})"

View File

@@ -60,8 +60,8 @@ class UserFilterParams(SQLModelBase):
group_id: UUID | None = None group_id: UUID | None = None
"""按用户组UUID筛选""" """按用户组UUID筛选"""
username_contains: str | None = Field(default=None, max_length=50) email_contains: str | None = Field(default=None, max_length=50)
"""用户名包含(不区分大小写的模糊搜索)""" """邮箱包含(不区分大小写的模糊搜索)"""
nickname_contains: str | None = Field(default=None, max_length=50) nickname_contains: str | None = Field(default=None, max_length=50)
"""昵称包含(不区分大小写的模糊搜索)""" """昵称包含(不区分大小写的模糊搜索)"""
@@ -75,8 +75,8 @@ class UserFilterParams(SQLModelBase):
class UserBase(SQLModelBase): class UserBase(SQLModelBase):
"""用户基础字段,供数据库模型和 DTO 共享""" """用户基础字段,供数据库模型和 DTO 共享"""
username: str email: str
"""用户""" """用户邮箱"""
status: UserStatus = UserStatus.ACTIVE status: UserStatus = UserStatus.ACTIVE
"""用户状态""" """用户状态"""
@@ -90,8 +90,8 @@ class UserBase(SQLModelBase):
class LoginRequest(SQLModelBase): class LoginRequest(SQLModelBase):
"""登录请求 DTO""" """登录请求 DTO"""
username: str email: str
"""用户名或邮箱""" """用户邮箱"""
password: str password: str
"""用户密码""" """用户密码"""
@@ -106,8 +106,8 @@ class LoginRequest(SQLModelBase):
class RegisterRequest(SQLModelBase): class RegisterRequest(SQLModelBase):
"""注册请求 DTO""" """注册请求 DTO"""
username: str email: str
"""用户,唯一,一经注册不可更改""" """用户邮箱,唯一"""
password: str password: str
"""用户密码""" """用户密码"""
@@ -116,6 +116,20 @@ class RegisterRequest(SQLModelBase):
"""验证码""" """验证码"""
class BatchDeleteRequest(SQLModelBase):
"""批量删除请求 DTO"""
ids: list[UUID]
"""待删除 UUID 列表"""
class RefreshTokenRequest(SQLModelBase):
"""刷新令牌请求 DTO"""
refresh_token: str
"""刷新令牌"""
class WebAuthnInfo(SQLModelBase): class WebAuthnInfo(SQLModelBase):
"""WebAuthn 信息 DTO""" """WebAuthn 信息 DTO"""
@@ -166,6 +180,9 @@ class UserResponse(ResponseBase):
id: UUID id: UUID
"""用户UUID""" """用户UUID"""
email: str
"""用户邮箱"""
nickname: str | None = None nickname: str | None = None
"""用户昵称""" """用户昵称"""
@@ -184,11 +201,23 @@ class UserResponse(ResponseBase):
tags: list[str] = [] tags: list[str] = []
"""用户标签列表""" """用户标签列表"""
class UserStorageResponse(SQLModelBase):
"""用户存储信息 DTO"""
used: int
"""已用存储空间(字节)"""
free: int
"""剩余存储空间(字节)"""
total: int
"""总存储空间(字节)"""
class UserPublic(UserBase): class UserPublic(UserBase):
"""用户公开信息 DTO用于 API 响应""" """用户公开信息 DTO用于 API 响应"""
id: UUID | None = None id: UUID
"""用户UUID""" """用户UUID"""
nickname: str | None = None nickname: str | None = None
@@ -206,6 +235,9 @@ class UserPublic(UserBase):
group_id: UUID | None = None group_id: UUID | None = None
"""所属用户组UUID""" """所属用户组UUID"""
group_name: str | None = None
"""用户组名称"""
two_factor: str | None = None two_factor: str | None = None
"""两步验证密钥32位字符串null 表示未启用)""" """两步验证密钥32位字符串null 表示未启用)"""
@@ -219,30 +251,64 @@ class UserPublic(UserBase):
class UserSettingResponse(SQLModelBase): class UserSettingResponse(SQLModelBase):
"""用户设置响应 DTO""" """用户设置响应 DTO"""
authn: "AuthnResponse | None" = None id: UUID
"""用户UUID"""
email: str
"""用户邮箱"""
nickname: str | None = None
"""昵称"""
created_at: datetime
"""用户注册时间"""
group_name: str
"""用户所属用户组名称"""
language: str
"""语言偏好"""
timezone: int
"""时区"""
authn: "list[AuthnResponse] | None" = None
"""认证信息""" """认证信息"""
group_expires: datetime | None = None group_expires: datetime | None = None
"""用户组过期时间""" """用户组过期时间"""
prefer_theme: str = "#5898d4"
"""用户首选主题"""
themes: dict[str, str] = {}
"""用户主题配置"""
two_factor: bool = False two_factor: bool = False
"""是否启用两步验证""" """是否启用两步验证"""
uid: UUID | None = None
"""用户UUID"""
# ==================== 管理员用户管理 DTO ==================== # ==================== 管理员用户管理 DTO ====================
class UserAdminCreateRequest(SQLModelBase):
"""管理员创建用户请求 DTO"""
email: str = Field(max_length=50)
"""用户邮箱"""
password: str
"""用户密码(明文,由服务端加密)"""
nickname: str | None = Field(default=None, max_length=50)
"""昵称"""
group_id: UUID
"""所属用户组UUID"""
status: UserStatus = UserStatus.ACTIVE
"""用户状态"""
class UserAdminUpdateRequest(SQLModelBase): class UserAdminUpdateRequest(SQLModelBase):
"""管理员更新用户请求 DTO""" """管理员更新用户请求 DTO"""
email: str = Field(max_length=50)
"""邮箱"""
nickname: str | None = Field(default=None, max_length=50) nickname: str | None = Field(default=None, max_length=50)
"""昵称""" """昵称"""
@@ -317,8 +383,8 @@ UserSettingResponse.model_rebuild()
class User(UserBase, UUIDTableBaseMixin): class User(UserBase, UUIDTableBaseMixin):
"""用户模型""" """用户模型"""
username: str = Field(max_length=50, unique=True, index=True) email: str = Field(max_length=50, unique=True, index=True)
"""用户,唯一,一经注册不可更改""" """用户邮箱,唯一"""
nickname: str | None = Field(default=None, max_length=50) nickname: str | None = Field(default=None, max_length=50)
"""用于公开展示的名字,可使用真实姓名或昵称""" """用于公开展示的名字,可使用真实姓名或昵称"""
@@ -426,8 +492,10 @@ class User(UserBase, UUIDTableBaseMixin):
) )
def to_public(self) -> "UserPublic": def to_public(self) -> "UserPublic":
"""转换为公开 DTO排除敏感字段""" """转换为公开 DTO排除敏感字段。需要预加载 group 关系。"""
return UserPublic.model_validate(self) data = UserPublic.model_validate(self)
data.group_name = self.group.name
return data
@classmethod @classmethod
async def get_with_count( async def get_with_count(
@@ -457,8 +525,8 @@ class User(UserBase, UUIDTableBaseMixin):
if filter_params.group_id is not None: if filter_params.group_id is not None:
filter_conditions.append(cls.group_id == filter_params.group_id) filter_conditions.append(cls.group_id == filter_params.group_id)
if filter_params.username_contains is not None: if filter_params.email_contains is not None:
filter_conditions.append(cls.username.ilike(f"%{filter_params.username_contains}%")) filter_conditions.append(cls.email.ilike(f"%{filter_params.email_contains}%"))
if filter_params.nickname_contains is not None: if filter_params.nickname_contains is not None:
filter_conditions.append(cls.nickname.ilike(f"%{filter_params.nickname_contains}%")) filter_conditions.append(cls.nickname.ilike(f"%{filter_params.nickname_contains}%"))
@@ -483,3 +551,4 @@ class User(UserBase, UUIDTableBaseMixin):
filter=filter, filter=filter,
table_view=table_view, table_view=table_view,
) )

View File

@@ -49,13 +49,13 @@ def main():
("itsdangerous", "签名工具"), ("itsdangerous", "签名工具"),
# 项目模块 # 项目模块
("models", "数据库模型"), ("sqlmodels", "数据库模型"),
("models.user", "用户模型"), ("sqlmodels.user", "用户模型"),
("models.group", "用户组模型"), ("sqlmodels.group", "用户组模型"),
("models.object", "对象模型"), ("sqlmodels.object", "对象模型"),
("models.setting", "设置模型"), ("sqlmodels.setting", "设置模型"),
("models.policy", "策略模型"), ("sqlmodels.policy", "策略模型"),
("models.database", "数据库连接"), ("sqlmodels.database", "数据库连接"),
("utils.password.pwd", "密码工具"), ("utils.password.pwd", "密码工具"),
("utils.JWT.JWT", "JWT 工具"), ("utils.JWT.JWT", "JWT 工具"),
("service.user.login", "登录服务"), ("service.user.login", "登录服务"),

View File

@@ -23,12 +23,12 @@ from sqlalchemy.orm import sessionmaker
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from main import app from main import app
from models.database import get_session from sqlmodels.database import get_session
from models.group import Group, GroupOptions from sqlmodels.group import Group, GroupOptions
from models.migration import migration from sqlmodels.migration import migration
from models.object import Object, ObjectType from sqlmodels.object import Object, ObjectType
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
from models.user import User from sqlmodels.user import User
from utils.JWT.JWT import create_access_token from utils.JWT.JWT import create_access_token
from utils.password.pwd import Password from utils.password.pwd import Password
@@ -153,7 +153,7 @@ def override_get_session(db_session: AsyncSession):
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]: async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
""" """
创建测试用户并返回 {id, username, password, token} 创建测试用户并返回 {id, email, password, token}
创建一个普通用户,包含用户组、存储策略和根目录。 创建一个普通用户,包含用户组、存储策略和根目录。
""" """
@@ -190,7 +190,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
# 创建测试用户 # 创建测试用户
password = "test_password_123" password = "test_password_123"
user = User( user = User(
username="testuser", email="testuser@test.local",
nickname="测试用户", nickname="测试用户",
password=Password.hash(password), password=Password.hash(password),
status=True, status=True,
@@ -202,7 +202,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
# 创建用户根目录 # 创建用户根目录
root_folder = Object( root_folder = Object(
name=user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
parent_id=None, parent_id=None,
owner_id=user.id, owner_id=user.id,
@@ -216,7 +216,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
return { return {
"id": user.id, "id": user.id,
"username": user.username, "email": user.email,
"password": password, "password": password,
"token": access_token, "token": access_token,
"group_id": group.id, "group_id": group.id,
@@ -227,7 +227,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]: async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
""" """
获取管理员用户 {id, username, token} 获取管理员用户 {id, email, token}
创建具有管理员权限的用户。 创建具有管理员权限的用户。
""" """
@@ -267,7 +267,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
# 创建管理员用户 # 创建管理员用户
password = "admin_password_456" password = "admin_password_456"
admin = User( admin = User(
username="admin", email="admin@disknext.local",
nickname="管理员", nickname="管理员",
password=Password.hash(password), password=Password.hash(password),
status=True, status=True,
@@ -279,7 +279,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
# 创建管理员根目录 # 创建管理员根目录
root_folder = Object( root_folder = Object(
name=admin.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
parent_id=None, parent_id=None,
owner_id=admin.id, owner_id=admin.id,
@@ -293,7 +293,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
return { return {
"id": admin.id, "id": admin.id,
"username": admin.username, "email": admin.email,
"password": password, "password": password,
"token": access_token, "token": access_token,
"group_id": admin_group.id, "group_id": admin_group.id,

View File

@@ -8,9 +8,9 @@ from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User from sqlmodels.user import User
from models.group import Group from sqlmodels.group import Group
from models.object import Object, ObjectType from sqlmodels.object import Object, ObjectType
from tests.fixtures import UserFactory, GroupFactory, ObjectFactory from tests.fixtures import UserFactory, GroupFactory, ObjectFactory
@@ -24,13 +24,13 @@ async def test_user_factory(db_session: AsyncSession):
user = await UserFactory.create( user = await UserFactory.create(
db_session, db_session,
group_id=group.id, group_id=group.id,
username="testuser", email="testuser@test.local",
password="password123" password="password123"
) )
# 验证 # 验证
assert user.id is not None assert user.id is not None
assert user.username == "testuser" assert user.email == "testuser@test.local"
assert user.group_id == group.id assert user.group_id == group.id
assert user.status is True assert user.status is True
@@ -51,7 +51,7 @@ async def test_group_factory(db_session: AsyncSession):
async def test_object_factory(db_session: AsyncSession): async def test_object_factory(db_session: AsyncSession):
"""测试对象工厂的基本功能""" """测试对象工厂的基本功能"""
# 准备依赖 # 准备依赖
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = await GroupFactory.create(db_session) group = await GroupFactory.create(db_session)
user = await UserFactory.create(db_session, group_id=group.id) user = await UserFactory.create(db_session, group_id=group.id)
@@ -102,7 +102,7 @@ async def test_conftest_fixtures(
"""测试 conftest.py 中的 fixtures""" """测试 conftest.py 中的 fixtures"""
# 验证 test_user fixture # 验证 test_user fixture
assert test_user["id"] is not None assert test_user["id"] is not None
assert test_user["username"] == "testuser" assert test_user["email"] == "testuser@test.local"
assert test_user["token"] is not None assert test_user["token"] is not None
# 验证 auth_headers fixture # 验证 auth_headers fixture
@@ -112,7 +112,7 @@ async def test_conftest_fixtures(
# 验证用户在数据库中存在 # 验证用户在数据库中存在
user = await User.get(db_session, User.id == test_user["id"]) user = await User.get(db_session, User.id == test_user["id"])
assert user is not None assert user is not None
assert user.username == test_user["username"] assert user.email == test_user["email"]
@pytest.mark.integration @pytest.mark.integration
@@ -145,7 +145,7 @@ async def test_test_directory_fixture(
@pytest.mark.integration @pytest.mark.integration
async def test_nested_structure_factory(db_session: AsyncSession): async def test_nested_structure_factory(db_session: AsyncSession):
"""测试嵌套结构工厂""" """测试嵌套结构工厂"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
# 准备依赖 # 准备依赖
group = await GroupFactory.create(db_session) group = await GroupFactory.create(db_session)

View File

@@ -5,7 +5,7 @@
""" """
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.group import Group, GroupOptions from sqlmodels.group import Group, GroupOptions
class GroupFactory: class GroupFactory:

View File

@@ -7,8 +7,8 @@ from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.object import Object, ObjectType from sqlmodels.object import Object, ObjectType
from models.user import User from sqlmodels.user import User
class ObjectFactory: class ObjectFactory:
@@ -119,7 +119,7 @@ class ObjectFactory:
Object: 创建的根目录实例 Object: 创建的根目录实例
""" """
root = Object( root = Object(
name=user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
parent_id=None, parent_id=None,
owner_id=user.id, owner_id=user.id,

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User from sqlmodels.user import User
from utils.password.pwd import Password from utils.password.pwd import Password
@@ -18,7 +18,7 @@ class UserFactory:
async def create( async def create(
session: AsyncSession, session: AsyncSession,
group_id: UUID, group_id: UUID,
username: str | None = None, email: str | None = None,
password: str | None = None, password: str | None = None,
**kwargs **kwargs
) -> User: ) -> User:
@@ -28,7 +28,7 @@ class UserFactory:
参数: 参数:
session: 数据库会话 session: 数据库会话
group_id: 用户组UUID group_id: 用户组UUID
username: 用户(默认: test_user_{随机} email: 用户邮箱(默认: test_user_{随机}@test.local
password: 明文密码(默认: password123 password: 明文密码(默认: password123
**kwargs: 其他用户字段 **kwargs: 其他用户字段
@@ -37,15 +37,15 @@ class UserFactory:
""" """
import uuid import uuid
if username is None: if email is None:
username = f"test_user_{uuid.uuid4().hex[:8]}" email = f"test_user_{uuid.uuid4().hex[:8]}@test.local"
if password is None: if password is None:
password = "password123" password = "password123"
user = User( user = User(
username=username, email=email,
nickname=kwargs.get("nickname", username), nickname=kwargs.get("nickname", email),
password=Password.hash(password), password=Password.hash(password),
status=kwargs.get("status", True), status=kwargs.get("status", True),
storage=kwargs.get("storage", 0), storage=kwargs.get("storage", 0),
@@ -67,7 +67,7 @@ class UserFactory:
async def create_admin( async def create_admin(
session: AsyncSession, session: AsyncSession,
admin_group_id: UUID, admin_group_id: UUID,
username: str | None = None, email: str | None = None,
password: str | None = None password: str | None = None
) -> User: ) -> User:
""" """
@@ -76,7 +76,7 @@ class UserFactory:
参数: 参数:
session: 数据库会话 session: 数据库会话
admin_group_id: 管理员组UUID admin_group_id: 管理员组UUID
username: 用户(默认: admin_{随机} email: 用户邮箱(默认: admin_{随机}@disknext.local
password: 明文密码(默认: admin_password password: 明文密码(默认: admin_password
返回: 返回:
@@ -84,15 +84,15 @@ class UserFactory:
""" """
import uuid import uuid
if username is None: if email is None:
username = f"admin_{uuid.uuid4().hex[:8]}" email = f"admin_{uuid.uuid4().hex[:8]}@disknext.local"
if password is None: if password is None:
password = "admin_password" password = "admin_password"
admin = User( admin = User(
username=username, email=email,
nickname=f"管理员 {username}", nickname=f"管理员 {email}",
password=Password.hash(password), password=Password.hash(password),
status=True, status=True,
storage=0, storage=0,
@@ -108,7 +108,7 @@ class UserFactory:
async def create_banned( async def create_banned(
session: AsyncSession, session: AsyncSession,
group_id: UUID, group_id: UUID,
username: str | None = None email: str | None = None
) -> User: ) -> User:
""" """
创建被封禁用户 创建被封禁用户
@@ -116,19 +116,19 @@ class UserFactory:
参数: 参数:
session: 数据库会话 session: 数据库会话
group_id: 用户组UUID group_id: 用户组UUID
username: 用户(默认: banned_user_{随机} email: 用户邮箱(默认: banned_user_{随机}@test.local
返回: 返回:
User: 创建的被封禁用户实例 User: 创建的被封禁用户实例
""" """
import uuid import uuid
if username is None: if email is None:
username = f"banned_user_{uuid.uuid4().hex[:8]}" email = f"banned_user_{uuid.uuid4().hex[:8]}@test.local"
banned_user = User( banned_user = User(
username=username, email=email,
nickname=f"封禁用户 {username}", nickname=f"封禁用户 {email}",
password=Password.hash("banned_password"), password=Password.hash("banned_password"),
status=False, # 封禁状态 status=False, # 封禁状态
storage=0, storage=0,
@@ -145,7 +145,7 @@ class UserFactory:
session: AsyncSession, session: AsyncSession,
group_id: UUID, group_id: UUID,
storage_bytes: int, storage_bytes: int,
username: str | None = None email: str | None = None
) -> User: ) -> User:
""" """
创建已使用指定存储空间的用户 创建已使用指定存储空间的用户
@@ -154,19 +154,19 @@ class UserFactory:
session: 数据库会话 session: 数据库会话
group_id: 用户组UUID group_id: 用户组UUID
storage_bytes: 已使用的存储空间(字节) storage_bytes: 已使用的存储空间(字节)
username: 用户(默认: storage_user_{随机} email: 用户邮箱(默认: storage_user_{随机}@test.local
返回: 返回:
User: 创建的用户实例 User: 创建的用户实例
""" """
import uuid import uuid
if username is None: if email is None:
username = f"storage_user_{uuid.uuid4().hex[:8]}" email = f"storage_user_{uuid.uuid4().hex[:8]}@test.local"
user = User( user = User(
username=username, email=email,
nickname=username, nickname=email,
password=Password.hash("password123"), password=Password.hash("password123"),
status=True, status=True,
storage=storage_bytes, storage=storage_bytes,

View File

@@ -124,7 +124,7 @@ async def test_admin_get_user_list_contains_user_data(
if len(users) > 0: if len(users) > 0:
user = users[0] user = users[0]
assert "id" in user assert "id" in user
assert "username" in user assert "email" in user
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -132,7 +132,7 @@ async def test_admin_create_user_requires_auth(async_client: AsyncClient):
"""测试创建用户需要认证""" """测试创建用户需要认证"""
response = await async_client.post( response = await async_client.post(
"/api/admin/user/create", "/api/admin/user/create",
json={"username": "newadminuser", "password": "pass123"} json={"email": "newadminuser@test.local", "password": "pass123"}
) )
assert response.status_code == 401 assert response.status_code == 401
@@ -146,7 +146,7 @@ async def test_admin_create_user_requires_admin(
response = await async_client.post( response = await async_client.post(
"/api/admin/user/create", "/api/admin/user/create",
headers=auth_headers, headers=auth_headers,
json={"username": "newadminuser", "password": "pass123"} json={"email": "newadminuser@test.local", "password": "pass123"}
) )
assert response.status_code == 403 assert response.status_code == 403

View File

@@ -11,7 +11,7 @@ from uuid import UUID
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_directory_requires_auth(async_client: AsyncClient): async def test_directory_requires_auth(async_client: AsyncClient):
"""测试获取目录需要认证""" """测试获取目录需要认证"""
response = await async_client.get("/api/directory/testuser") response = await async_client.get("/api/directory/")
assert response.status_code == 401 assert response.status_code == 401
@@ -24,7 +24,7 @@ async def test_directory_get_root(
): ):
"""测试获取用户根目录""" """测试获取用户根目录"""
response = await async_client.get( response = await async_client.get(
"/api/directory/testuser", "/api/directory/",
headers=auth_headers headers=auth_headers
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -45,7 +45,7 @@ async def test_directory_get_nested(
): ):
"""测试获取嵌套目录""" """测试获取嵌套目录"""
response = await async_client.get( response = await async_client.get(
"/api/directory/testuser/docs", "/api/directory/docs",
headers=auth_headers headers=auth_headers
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -63,7 +63,7 @@ async def test_directory_get_contains_children(
): ):
"""测试目录包含子对象""" """测试目录包含子对象"""
response = await async_client.get( response = await async_client.get(
"/api/directory/testuser/docs", "/api/directory/docs",
headers=auth_headers headers=auth_headers
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -75,19 +75,6 @@ async def test_directory_get_contains_children(
assert len(objects) >= 1 assert len(objects) >= 1
@pytest.mark.asyncio
async def test_directory_forbidden_other_user(
async_client: AsyncClient,
auth_headers: dict[str, str]
):
"""测试访问他人目录返回 403"""
response = await async_client.get(
"/api/directory/admin",
headers=auth_headers
)
assert response.status_code == 403
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_directory_not_found( async def test_directory_not_found(
async_client: AsyncClient, async_client: AsyncClient,
@@ -95,23 +82,23 @@ async def test_directory_not_found(
): ):
"""测试目录不存在返回 404""" """测试目录不存在返回 404"""
response = await async_client.get( response = await async_client.get(
"/api/directory/testuser/nonexistent", "/api/directory/nonexistent",
headers=auth_headers headers=auth_headers
) )
assert response.status_code == 404 assert response.status_code == 404
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_directory_empty_path_returns_400( async def test_directory_root_returns_200(
async_client: AsyncClient, async_client: AsyncClient,
auth_headers: dict[str, str] auth_headers: dict[str, str]
): ):
"""测试空路径返回 400""" """测试根目录端点返回 200"""
response = await async_client.get( response = await async_client.get(
"/api/directory/", "/api/directory/",
headers=auth_headers headers=auth_headers
) )
assert response.status_code == 400 assert response.status_code == 200
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -121,7 +108,7 @@ async def test_directory_response_includes_policy(
): ):
"""测试目录响应包含存储策略""" """测试目录响应包含存储策略"""
response = await async_client.get( response = await async_client.get(
"/api/directory/testuser", "/api/directory/",
headers=auth_headers headers=auth_headers
) )
assert response.status_code == 200 assert response.status_code == 200
@@ -284,7 +271,7 @@ async def test_directory_create_other_user_parent(
"""测试在他人目录下创建目录返回 404""" """测试在他人目录下创建目录返回 404"""
# 先用管理员账号获取管理员的根目录ID # 先用管理员账号获取管理员的根目录ID
admin_response = await async_client.get( admin_response = await async_client.get(
"/api/directory/admin", "/api/directory/",
headers=admin_headers headers=admin_headers
) )
assert admin_response.status_code == 200 assert admin_response.status_code == 200

View File

@@ -16,7 +16,7 @@ async def test_user_login_success(
response = await async_client.post( response = await async_client.post(
"/api/user/session", "/api/user/session",
data={ data={
"username": test_user_info["username"], "username": test_user_info["email"],
"password": test_user_info["password"], "password": test_user_info["password"],
} }
) )
@@ -38,7 +38,7 @@ async def test_user_login_wrong_password(
response = await async_client.post( response = await async_client.post(
"/api/user/session", "/api/user/session",
data={ data={
"username": test_user_info["username"], "username": test_user_info["email"],
"password": "wrongpassword", "password": "wrongpassword",
} }
) )
@@ -51,7 +51,7 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient):
response = await async_client.post( response = await async_client.post(
"/api/user/session", "/api/user/session",
data={ data={
"username": "nonexistent", "username": "nonexistent@test.local",
"password": "anypassword", "password": "anypassword",
} }
) )
@@ -67,7 +67,7 @@ async def test_user_login_user_banned(
response = await async_client.post( response = await async_client.post(
"/api/user/session", "/api/user/session",
data={ data={
"username": banned_user_info["username"], "username": banned_user_info["email"],
"password": banned_user_info["password"], "password": banned_user_info["password"],
} }
) )
@@ -82,7 +82,7 @@ async def test_user_register_success(async_client: AsyncClient):
response = await async_client.post( response = await async_client.post(
"/api/user/", "/api/user/",
json={ json={
"username": "newuser", "email": "newuser@test.local",
"password": "newpass123", "password": "newpass123",
} }
) )
@@ -91,20 +91,20 @@ async def test_user_register_success(async_client: AsyncClient):
data = response.json() data = response.json()
assert "data" in data assert "data" in data
assert "user_id" in data["data"] assert "user_id" in data["data"]
assert "username" in data["data"] assert "email" in data["data"]
assert data["data"]["username"] == "newuser" assert data["data"]["email"] == "newuser@test.local"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_register_duplicate_username( async def test_user_register_duplicate_email(
async_client: AsyncClient, async_client: AsyncClient,
test_user_info: dict[str, str] test_user_info: dict[str, str]
): ):
"""测试重复用户名返回 400""" """测试重复邮箱返回 400"""
response = await async_client.post( response = await async_client.post(
"/api/user/", "/api/user/",
json={ json={
"username": test_user_info["username"], "email": test_user_info["email"],
"password": "anypassword", "password": "anypassword",
} }
) )
@@ -143,8 +143,8 @@ async def test_user_me_returns_user_info(
assert "data" in data assert "data" in data
user_data = data["data"] user_data = data["data"]
assert "id" in user_data assert "id" in user_data
assert "username" in user_data assert "email" in user_data
assert user_data["username"] == "testuser" assert user_data["email"] == "testuser@test.local"
assert "group" in user_data assert "group" in user_data
assert "tags" in user_data assert "tags" in user_data

View File

@@ -22,7 +22,7 @@ from sqlalchemy.orm import sessionmaker
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from main import app from main import app
from models import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from utils import Password from utils import Password
from utils.JWT import create_access_token from utils.JWT import create_access_token
from utils.JWT import JWT from utils.JWT import JWT
@@ -92,6 +92,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
Setting(type=SettingsType.VIEW, name="home_view_method", value="list"), Setting(type=SettingsType.VIEW, name="home_view_method", value="list"),
Setting(type=SettingsType.VIEW, name="share_view_method", value="grid"), Setting(type=SettingsType.VIEW, name="share_view_method", value="grid"),
Setting(type=SettingsType.AUTHN, name="authn_enabled", value="0"), Setting(type=SettingsType.AUTHN, name="authn_enabled", value="0"),
Setting(type=SettingsType.CAPTCHA, name="captcha_type", value="default"),
Setting(type=SettingsType.CAPTCHA, name="captcha_ReCaptchaKey", value=""), Setting(type=SettingsType.CAPTCHA, name="captcha_ReCaptchaKey", value=""),
Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""), Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""),
Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"), Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"),
@@ -180,7 +181,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
# 6. 创建测试用户 # 6. 创建测试用户
test_user = User( test_user = User(
id=uuid4(), id=uuid4(),
username="testuser", email="testuser@test.local",
password=Password.hash("testpass123"), password=Password.hash("testpass123"),
nickname="测试用户", nickname="测试用户",
status=True, status=True,
@@ -194,7 +195,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
admin_user = User( admin_user = User(
id=uuid4(), id=uuid4(),
username="admin", email="admin@disknext.local",
password=Password.hash("adminpass123"), password=Password.hash("adminpass123"),
nickname="管理员", nickname="管理员",
status=True, status=True,
@@ -208,7 +209,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
banned_user = User( banned_user = User(
id=uuid4(), id=uuid4(),
username="banneduser", email="banneduser@test.local",
password=Password.hash("banned123"), password=Password.hash("banned123"),
nickname="封禁用户", nickname="封禁用户",
status=False, # 封禁状态 status=False, # 封禁状态
@@ -230,7 +231,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
# 7. 创建用户根目录 # 7. 创建用户根目录
test_user_root = Object( test_user_root = Object(
id=uuid4(), id=uuid4(),
name=test_user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
owner_id=test_user.id, owner_id=test_user.id,
parent_id=None, parent_id=None,
@@ -241,7 +242,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
admin_user_root = Object( admin_user_root = Object(
id=uuid4(), id=uuid4(),
name=admin_user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
owner_id=admin_user.id, owner_id=admin_user.id,
parent_id=None, parent_id=None,
@@ -264,7 +265,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
def test_user_info() -> dict[str, str]: def test_user_info() -> dict[str, str]:
"""测试用户信息""" """测试用户信息"""
return { return {
"username": "testuser", "email": "testuser@test.local",
"password": "testpass123", "password": "testpass123",
} }
@@ -273,7 +274,7 @@ def test_user_info() -> dict[str, str]:
def admin_user_info() -> dict[str, str]: def admin_user_info() -> dict[str, str]:
"""管理员用户信息""" """管理员用户信息"""
return { return {
"username": "admin", "email": "admin@disknext.local",
"password": "adminpass123", "password": "adminpass123",
} }
@@ -282,7 +283,7 @@ def admin_user_info() -> dict[str, str]:
def banned_user_info() -> dict[str, str]: def banned_user_info() -> dict[str, str]:
"""封禁用户信息""" """封禁用户信息"""
return { return {
"username": "banneduser", "email": "banneduser@test.local",
"password": "banned123", "password": "banned123",
} }
@@ -293,7 +294,7 @@ def banned_user_info() -> dict[str, str]:
def test_user_token(test_user_info: dict[str, str]) -> str: def test_user_token(test_user_info: dict[str, str]) -> str:
"""生成测试用户的JWT token""" """生成测试用户的JWT token"""
token, _ = JWT.create_access_token( token, _ = JWT.create_access_token(
data={"sub": test_user_info["username"]}, data={"sub": test_user_info["email"]},
expires_delta=timedelta(hours=1), expires_delta=timedelta(hours=1),
) )
return token return token
@@ -303,7 +304,7 @@ def test_user_token(test_user_info: dict[str, str]) -> str:
def admin_user_token(admin_user_info: dict[str, str]) -> str: def admin_user_token(admin_user_info: dict[str, str]) -> str:
"""生成管理员的JWT token""" """生成管理员的JWT token"""
token, _ = JWT.create_access_token( token, _ = JWT.create_access_token(
data={"sub": admin_user_info["username"]}, data={"sub": admin_user_info["email"]},
expires_delta=timedelta(hours=1), expires_delta=timedelta(hours=1),
) )
return token return token
@@ -313,7 +314,7 @@ def admin_user_token(admin_user_info: dict[str, str]) -> str:
def expired_token() -> str: def expired_token() -> str:
"""生成过期的JWT token""" """生成过期的JWT token"""
token, _ = JWT.create_access_token( token, _ = JWT.create_access_token(
data={"sub": "testuser"}, data={"sub": "testuser@test.local"},
expires_delta=timedelta(seconds=-1), # 已过期 expires_delta=timedelta(seconds=-1), # 已过期
) )
return token return token
@@ -362,7 +363,7 @@ async def test_directory_structure(initialized_db: AsyncSession) -> dict[str, UU
"""创建测试目录结构""" """创建测试目录结构"""
# 获取测试用户和根目录 # 获取测试用户和根目录
test_user = await User.get(initialized_db, User.username == "testuser") test_user = await User.get(initialized_db, User.email == "testuser@test.local")
test_user_root = await Object.get_root(initialized_db, test_user.id) test_user_root = await Object.get_root(initialized_db, test_user.id)
default_policy = await Policy.get(initialized_db, Policy.name == "本地存储") default_policy = await Policy.get(initialized_db, Policy.name == "本地存储")

View File

@@ -83,7 +83,7 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient): async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
"""测试用户不存在的token返回 401""" """测试用户不存在的token返回 401"""
token, _ = JWT.create_access_token( token, _ = JWT.create_access_token(
data={"sub": "nonexistent_user"}, data={"sub": "nonexistent_user@test.local"},
expires_delta=timedelta(hours=1) expires_delta=timedelta(hours=1)
) )
@@ -178,12 +178,12 @@ async def test_auth_on_directory_endpoint(
): ):
"""测试目录端点应用认证""" """测试目录端点应用认证"""
# 无认证 # 无认证
response_no_auth = await async_client.get("/api/directory/testuser") response_no_auth = await async_client.get("/api/directory/")
assert response_no_auth.status_code == 401 assert response_no_auth.status_code == 401
# 有认证 # 有认证
response_with_auth = await async_client.get( response_with_auth = await async_client.get(
"/api/directory/testuser", "/api/directory/",
headers=auth_headers headers=auth_headers
) )
assert response_with_auth.status_code == 200 assert response_with_auth.status_code == 200
@@ -235,7 +235,7 @@ async def test_auth_on_storage_endpoint(
async def test_refresh_token_format(test_user_info: dict[str, str]): async def test_refresh_token_format(test_user_info: dict[str, str]):
"""测试刷新token格式正确""" """测试刷新token格式正确"""
refresh_token, _ = JWT.create_refresh_token( refresh_token, _ = JWT.create_refresh_token(
data={"sub": test_user_info["username"]}, data={"sub": test_user_info["email"]},
expires_delta=timedelta(days=7) expires_delta=timedelta(days=7)
) )
@@ -247,7 +247,7 @@ async def test_refresh_token_format(test_user_info: dict[str, str]):
async def test_access_token_format(test_user_info: dict[str, str]): async def test_access_token_format(test_user_info: dict[str, str]):
"""测试访问token格式正确""" """测试访问token格式正确"""
access_token, expires = JWT.create_access_token( access_token, expires = JWT.create_access_token(
data={"sub": test_user_info["username"]}, data={"sub": test_user_info["email"]},
expires_delta=timedelta(hours=1) expires_delta=timedelta(hours=1)
) )

View File

@@ -3,14 +3,14 @@ import pytest
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_initialize_db(): async def test_initialize_db():
"""测试创建数据库结构""" """测试创建数据库结构"""
from models import database from sqlmodels import database
await database.init_db(url='sqlite:///:memory:') await database.init_db(url='sqlite:///:memory:')
@pytest.fixture @pytest.fixture
async def db_session(): async def db_session():
"""测试获取数据库连接Session""" """测试获取数据库连接Session"""
from models import database from sqlmodels import database
await database.init_db(url='sqlite:///:memory:') await database.init_db(url='sqlite:///:memory:')
@@ -20,8 +20,8 @@ async def db_session():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_migration(): async def test_migration():
"""测试数据库创建并初始化配置""" """测试数据库创建并初始化配置"""
from models import migration from sqlmodels import migration
from models import database from sqlmodels import database
await database.init_db(url='sqlite:///:memory:') await database.init_db(url='sqlite:///:memory:')

View File

@@ -3,8 +3,8 @@ import pytest
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_group_curd(): async def test_group_curd():
"""测试数据库的增删改查""" """测试数据库的增删改查"""
from models import database, migration from sqlmodels import database, migration
from models.group import Group from sqlmodels.group import Group
await database.init_db(url='sqlite+aiosqlite:///:memory:') await database.init_db(url='sqlite+aiosqlite:///:memory:')

View File

@@ -3,8 +3,8 @@ import pytest
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_settings_curd(): async def test_settings_curd():
"""测试数据库的增删改查""" """测试数据库的增删改查"""
from models import database from sqlmodels import database
from models.setting import Setting from sqlmodels.setting import Setting
await database.init_db(url='sqlite:///:memory:') await database.init_db(url='sqlite:///:memory:')

View File

@@ -3,9 +3,9 @@ import pytest
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_curd(): async def test_user_curd():
"""测试数据库的增删改查""" """测试数据库的增删改查"""
from models import database, migration from sqlmodels import database, migration
from models.group import Group from sqlmodels.group import Group
from models.user import User from sqlmodels.user import User
await database.init_db(url='sqlite+aiosqlite:///:memory:') await database.init_db(url='sqlite+aiosqlite:///:memory:')
@@ -17,7 +17,7 @@ async def test_user_curd():
created_group = await test_user_group.save(session) created_group = await test_user_group.save(session)
test_user = User( test_user = User(
username='test_user', email='test_user@test.local',
password='test_password', password='test_password',
group_id=created_group.id group_id=created_group.id
) )
@@ -27,7 +27,7 @@ async def test_user_curd():
# 验证用户是否存在 # 验证用户是否存在
assert created_user.id is not None assert created_user.id is not None
assert created_user.username == 'test_user' assert created_user.email == 'test_user@test.local'
assert created_user.password == 'test_password' assert created_user.password == 'test_password'
assert created_user.group_id == created_group.id assert created_user.group_id == created_group.id
@@ -35,18 +35,18 @@ async def test_user_curd():
fetched_user = await User.get(session, User.id == created_user.id) fetched_user = await User.get(session, User.id == created_user.id)
assert fetched_user is not None assert fetched_user is not None
assert fetched_user.username == 'test_user' assert fetched_user.email == 'test_user@test.local'
assert fetched_user.password == 'test_password' assert fetched_user.password == 'test_password'
assert fetched_user.group_id == created_group.id assert fetched_user.group_id == created_group.id
# 测试改 Update # 测试改 Update
updated_user = await fetched_user.update( updated_user = await fetched_user.update(
session, session,
{"username": "updated_user", "password": "updated_password"} {"email": "updated_user@test.local", "password": "updated_password"}
) )
assert updated_user is not None assert updated_user is not None
assert updated_user.username == 'updated_user' assert updated_user.email == 'updated_user@test.local'
assert updated_user.password == 'updated_password' assert updated_user.password == 'updated_password'
# 测试删除 Delete # 测试删除 Delete

View File

@@ -8,8 +8,8 @@ import pytest
from fastapi import HTTPException from fastapi import HTTPException
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User from sqlmodels.user import User
from models.group import Group from sqlmodels.group import Group
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -62,7 +62,7 @@ async def test_table_base_update(db_session: AsyncSession):
group = await group.save(db_session) group = await group.save(db_session)
# 更新数据 # 更新数据
from models.group import GroupBase from sqlmodels.group import GroupBase
update_data = GroupBase(name="更新后名称") update_data = GroupBase(name="更新后名称")
updated_group = await group.update(db_session, update_data) updated_group = await group.update(db_session, update_data)
@@ -200,7 +200,7 @@ async def test_timestamps_auto_update(db_session: AsyncSession):
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
# 更新记录 # 更新记录
from models.group import GroupBase from sqlmodels.group import GroupBase
update_data = GroupBase(name="更新后的名称") update_data = GroupBase(name="更新后的名称")
group = await group.update(db_session, update_data) group = await group.update(db_session, update_data)

View File

@@ -4,7 +4,7 @@ Group 和 GroupOptions 模型的单元测试
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.group import Group, GroupOptions, GroupResponse from sqlmodels.group import Group, GroupOptions, GroupResponse
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -5,21 +5,21 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.object import Object, ObjectType from sqlmodels.object import Object, ObjectType
from models.user import User from sqlmodels.user import User
from models.group import Group from sqlmodels.group import Group
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_create_folder(db_session: AsyncSession): async def test_object_create_folder(db_session: AsyncSession):
"""测试创建目录""" """测试创建目录"""
# 创建必要的依赖数据 # 创建必要的依赖数据
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="testuser", password="password", group_id=group.id) user = User(email="testuser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy( policy = Policy(
@@ -48,12 +48,12 @@ async def test_object_create_folder(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_create_file(db_session: AsyncSession): async def test_object_create_file(db_session: AsyncSession):
"""测试创建文件""" """测试创建文件"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="testuser", password="password", group_id=group.id) user = User(email="testuser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy( policy = Policy(
@@ -65,7 +65,7 @@ async def test_object_create_file(db_session: AsyncSession):
# 创建根目录 # 创建根目录
root = Object( root = Object(
name=user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
parent_id=None, parent_id=None,
owner_id=user.id, owner_id=user.id,
@@ -81,7 +81,6 @@ async def test_object_create_file(db_session: AsyncSession):
owner_id=user.id, owner_id=user.id,
policy_id=policy.id, policy_id=policy.id,
size=1024, size=1024,
source_name="test_source.txt"
) )
file = await file.save(db_session) file = await file.save(db_session)
@@ -89,18 +88,17 @@ async def test_object_create_file(db_session: AsyncSession):
assert file.name == "test.txt" assert file.name == "test.txt"
assert file.type == ObjectType.FILE assert file.type == ObjectType.FILE
assert file.size == 1024 assert file.size == 1024
assert file.source_name == "test_source.txt"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_is_file_property(db_session: AsyncSession): async def test_object_is_file_property(db_session: AsyncSession):
"""测试 is_file 属性""" """测试 is_file 属性"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="testuser", password="password", group_id=group.id) user = User(email="testuser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -122,12 +120,12 @@ async def test_object_is_file_property(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_is_folder_property(db_session: AsyncSession): async def test_object_is_folder_property(db_session: AsyncSession):
"""测试 is_folder 属性""" """测试 is_folder 属性"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="testuser", password="password", group_id=group.id) user = User(email="testuser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -148,12 +146,12 @@ async def test_object_is_folder_property(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_get_root(db_session: AsyncSession): async def test_object_get_root(db_session: AsyncSession):
"""测试 get_root() 方法""" """测试 get_root() 方法"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="rootuser", password="password", group_id=group.id) user = User(email="rootuser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -161,7 +159,7 @@ async def test_object_get_root(db_session: AsyncSession):
# 创建根目录 # 创建根目录
root = Object( root = Object(
name=user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
parent_id=None, parent_id=None,
owner_id=user.id, owner_id=user.id,
@@ -180,12 +178,12 @@ async def test_object_get_root(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_get_by_path_root(db_session: AsyncSession): async def test_object_get_by_path_root(db_session: AsyncSession):
"""测试获取根目录""" """测试获取根目录"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="pathuser", password="password", group_id=group.id) user = User(email="pathuser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -193,7 +191,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
# 创建根目录 # 创建根目录
root = Object( root = Object(
name=user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
parent_id=None, parent_id=None,
owner_id=user.id, owner_id=user.id,
@@ -202,7 +200,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
root = await root.save(db_session) root = await root.save(db_session)
# 通过路径获取根目录 # 通过路径获取根目录
result = await Object.get_by_path(db_session, user.id, "/pathuser", user.username) result = await Object.get_by_path(db_session, user.id, "/")
assert result is not None assert result is not None
assert result.id == root.id assert result.id == root.id
@@ -211,12 +209,12 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_get_by_path_nested(db_session: AsyncSession): async def test_object_get_by_path_nested(db_session: AsyncSession):
"""测试获取嵌套路径""" """测试获取嵌套路径"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="nesteduser", password="password", group_id=group.id) user = User(email="nesteduser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -224,7 +222,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
# 创建目录结构: root -> docs -> work -> project # 创建目录结构: root -> docs -> work -> project
root = Object( root = Object(
name=user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
parent_id=None, parent_id=None,
owner_id=user.id, owner_id=user.id,
@@ -263,8 +261,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
result = await Object.get_by_path( result = await Object.get_by_path(
db_session, db_session,
user.id, user.id,
"/nesteduser/docs/work/project", "/docs/work/project",
user.username
) )
assert result is not None assert result is not None
@@ -275,12 +272,12 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_get_by_path_not_found(db_session: AsyncSession): async def test_object_get_by_path_not_found(db_session: AsyncSession):
"""测试路径不存在""" """测试路径不存在"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="notfounduser", password="password", group_id=group.id) user = User(email="notfounduser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -288,7 +285,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
# 创建根目录 # 创建根目录
root = Object( root = Object(
name=user.username, name="/",
type=ObjectType.FOLDER, type=ObjectType.FOLDER,
parent_id=None, parent_id=None,
owner_id=user.id, owner_id=user.id,
@@ -300,8 +297,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
result = await Object.get_by_path( result = await Object.get_by_path(
db_session, db_session,
user.id, user.id,
"/notfounduser/nonexistent", "/nonexistent",
user.username
) )
assert result is None assert result is None
@@ -310,12 +306,12 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_get_children(db_session: AsyncSession): async def test_object_get_children(db_session: AsyncSession):
"""测试 get_children() 方法""" """测试 get_children() 方法"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="childrenuser", password="password", group_id=group.id) user = User(email="childrenuser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -362,12 +358,12 @@ async def test_object_get_children(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_parent_child_relationship(db_session: AsyncSession): async def test_object_parent_child_relationship(db_session: AsyncSession):
"""测试父子关系""" """测试父子关系"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="reluser", password="password", group_id=group.id) user = User(email="reluser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -407,12 +403,12 @@ async def test_object_parent_child_relationship(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_object_unique_constraint(db_session: AsyncSession): async def test_object_unique_constraint(db_session: AsyncSession):
"""测试同目录名称唯一约束""" """测试同目录名称唯一约束"""
from models.policy import Policy, PolicyType from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组") group = Group(name="测试组")
group = await group.save(db_session) group = await group.save(db_session)
user = User(username="uniqueuser", password="password", group_id=group.id) user = User(email="uniqueuser", password="password", group_id=group.id)
user = await user.save(db_session) user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test") policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -450,3 +446,64 @@ async def test_object_unique_constraint(db_session: AsyncSession):
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
await file2.save(db_session) await file2.save(db_session)
@pytest.mark.asyncio
async def test_object_get_full_path(db_session: AsyncSession):
"""测试 get_full_path() 方法"""
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(email="pathuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
policy = await policy.save(db_session)
# 创建目录结构: root -> docs -> images -> photo.jpg
root = Object(
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,
policy_id=policy.id
)
root = await root.save(db_session)
docs = Object(
name="docs",
type=ObjectType.FOLDER,
parent_id=root.id,
owner_id=user.id,
policy_id=policy.id
)
docs = await docs.save(db_session)
images = Object(
name="images",
type=ObjectType.FOLDER,
parent_id=docs.id,
owner_id=user.id,
policy_id=policy.id
)
images = await images.save(db_session)
photo = Object(
name="photo.jpg",
type=ObjectType.FILE,
parent_id=images.id,
owner_id=user.id,
policy_id=policy.id,
size=2048
)
photo = await photo.save(db_session)
# 测试完整路径
full_path = await photo.get_full_path(db_session)
assert full_path == "/docs/images/photo.jpg"
# 测试根目录的 full_path
root_path = await root.get_full_path(db_session)
assert root_path == "/"

View File

@@ -5,7 +5,7 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.setting import Setting, SettingsType from sqlmodels.setting import Setting, SettingsType
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession):
setting = await setting.save(db_session) setting = await setting.save(db_session)
# 更新值 # 更新值
from models.base import SQLModelBase from sqlmodels.base import SQLModelBase
class SettingUpdate(SQLModelBase): class SettingUpdate(SQLModelBase):
value: str | None = None value: str | None = None

View File

@@ -0,0 +1,273 @@
"""
DiskNextURI 模型的单元测试
"""
import pytest
from sqlmodels.uri import DiskNextURI, FileSystemNamespace
class TestDiskNextURIParse:
"""测试 URI 解析"""
def test_parse_my_root(self):
"""测试解析个人空间根目录"""
uri = DiskNextURI.parse("disknext://my/")
assert uri.namespace == FileSystemNamespace.MY
assert uri.path == "/"
assert uri.fs_id is None
assert uri.password is None
assert uri.is_root is True
def test_parse_my_with_path(self):
"""测试解析个人空间带路径"""
uri = DiskNextURI.parse("disknext://my/docs/readme.md")
assert uri.namespace == FileSystemNamespace.MY
assert uri.path == "/docs/readme.md"
assert uri.fs_id is None
assert uri.path_parts == ["docs", "readme.md"]
assert uri.is_root is False
def test_parse_my_with_fs_id(self):
"""测试解析带 fs_id 的个人空间"""
uri = DiskNextURI.parse("disknext://some-uuid@my/docs")
assert uri.namespace == FileSystemNamespace.MY
assert uri.fs_id == "some-uuid"
assert uri.path == "/docs"
def test_parse_share_with_code(self):
"""测试解析分享链接"""
uri = DiskNextURI.parse("disknext://abc123@share/")
assert uri.namespace == FileSystemNamespace.SHARE
assert uri.fs_id == "abc123"
assert uri.path == "/"
assert uri.password is None
def test_parse_share_with_password(self):
"""测试解析带密码的分享链接"""
uri = DiskNextURI.parse("disknext://abc123:mypass@share/sub/dir")
assert uri.namespace == FileSystemNamespace.SHARE
assert uri.fs_id == "abc123"
assert uri.password == "mypass"
assert uri.path == "/sub/dir"
def test_parse_trash(self):
"""测试解析回收站"""
uri = DiskNextURI.parse("disknext://trash/")
assert uri.namespace == FileSystemNamespace.TRASH
assert uri.is_root is True
def test_parse_with_query(self):
"""测试解析带查询参数的 URI"""
uri = DiskNextURI.parse("disknext://my/?name=report&type=file")
assert uri.namespace == FileSystemNamespace.MY
assert uri.query is not None
assert uri.query["name"] == "report"
assert uri.query["type"] == "file"
def test_parse_invalid_scheme(self):
"""测试无效的协议前缀"""
with pytest.raises(ValueError, match="disknext://"):
DiskNextURI.parse("http://my/docs")
def test_parse_invalid_namespace(self):
"""测试无效的命名空间"""
with pytest.raises(ValueError, match="无效的命名空间"):
DiskNextURI.parse("disknext://invalid/docs")
def test_parse_no_namespace(self):
"""测试缺少命名空间"""
with pytest.raises(ValueError):
DiskNextURI.parse("disknext://")
class TestDiskNextURIBuild:
"""测试 URI 构建"""
def test_build_simple(self):
"""测试简单构建"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
assert uri.namespace == FileSystemNamespace.MY
assert uri.path == "/"
assert uri.fs_id is None
def test_build_with_path(self):
"""测试带路径构建"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
assert uri.path == "/docs/readme.md"
def test_build_path_auto_prefix(self):
"""测试路径自动添加 / 前缀"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="docs/readme.md")
assert uri.path == "/docs/readme.md"
def test_build_with_fs_id(self):
"""测试带 fs_id 构建"""
uri = DiskNextURI.build(
FileSystemNamespace.SHARE,
fs_id="abc123",
password="secret",
)
assert uri.fs_id == "abc123"
assert uri.password == "secret"
class TestDiskNextURIToString:
"""测试 URI 序列化"""
def test_to_string_simple(self):
"""测试简单序列化"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
assert uri.to_string() == "disknext://my/"
def test_to_string_with_path(self):
"""测试带路径序列化"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
assert uri.to_string() == "disknext://my/docs/readme.md"
def test_to_string_with_fs_id(self):
"""测试带 fs_id 序列化"""
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="uuid-123")
assert uri.to_string() == "disknext://uuid-123@my/"
def test_to_string_with_password(self):
"""测试带密码序列化"""
uri = DiskNextURI.build(
FileSystemNamespace.SHARE,
fs_id="code",
password="pass",
)
assert uri.to_string() == "disknext://code:pass@share/"
def test_to_string_roundtrip(self):
"""测试序列化-反序列化往返"""
original = "disknext://abc123:pass@share/sub/dir"
uri = DiskNextURI.parse(original)
result = uri.to_string()
assert result == original
class TestDiskNextURIId:
"""测试 id() 方法"""
def test_id_with_fs_id(self):
"""测试有 fs_id 时返回 fs_id"""
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="my-uuid")
assert uri.id("default") == "my-uuid"
def test_id_without_fs_id(self):
"""测试无 fs_id 时返回默认值"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
assert uri.id("default-uuid") == "default-uuid"
def test_id_without_fs_id_no_default(self):
"""测试无 fs_id 且无默认值时返回 None"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
assert uri.id() is None
class TestDiskNextURIJoin:
"""测试 join() 方法"""
def test_join_single(self):
"""测试拼接单个路径元素"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
joined = uri.join("readme.md")
assert joined.path == "/docs/readme.md"
def test_join_multiple(self):
"""测试拼接多个路径元素"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
joined = uri.join("docs", "work", "report.pdf")
assert joined.path == "/docs/work/report.pdf"
def test_join_preserves_metadata(self):
"""测试 join 保留 namespace 和 fs_id"""
uri = DiskNextURI.build(FileSystemNamespace.SHARE, fs_id="code123")
joined = uri.join("sub")
assert joined.namespace == FileSystemNamespace.SHARE
assert joined.fs_id == "code123"
class TestDiskNextURIDirUri:
"""测试 dir_uri() 方法"""
def test_dir_uri_file(self):
"""测试获取文件的父目录 URI"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
parent = uri.dir_uri()
assert parent.path == "/docs/"
def test_dir_uri_root(self):
"""测试根目录的 dir_uri 返回自身"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
parent = uri.dir_uri()
assert parent.path == "/"
class TestDiskNextURIRoot:
"""测试 root() 方法"""
def test_root_resets_path(self):
"""测试 root 重置路径"""
uri = DiskNextURI.build(
FileSystemNamespace.MY,
path="/docs/work/report.pdf",
fs_id="uuid-123",
)
root = uri.root()
assert root.path == "/"
assert root.fs_id == "uuid-123"
assert root.namespace == FileSystemNamespace.MY
class TestDiskNextURIName:
"""测试 name() 方法"""
def test_name_file(self):
"""测试获取文件名"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
assert uri.name() == "readme.md"
def test_name_directory(self):
"""测试获取目录名"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work")
assert uri.name() == "work"
def test_name_root(self):
"""测试根目录的 name 返回空字符串"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
assert uri.name() == ""
class TestDiskNextURIProperties:
"""测试属性方法"""
def test_path_parts(self):
"""测试路径分割"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work/report.pdf")
assert uri.path_parts == ["docs", "work", "report.pdf"]
def test_path_parts_root(self):
"""测试根路径分割"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
assert uri.path_parts == []
def test_is_root_true(self):
"""测试 is_root 为真"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
assert uri.is_root is True
def test_is_root_false(self):
"""测试 is_root 为假"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
assert uri.is_root is False
def test_str_representation(self):
"""测试字符串表示"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
assert str(uri) == "disknext://my/docs"
def test_repr(self):
"""测试 repr"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
assert "disknext://my/docs" in repr(uri)

View File

@@ -5,8 +5,8 @@ import pytest
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User, ThemeType, UserPublic from sqlmodels.user import User, ThemeType, UserPublic
from models.group import Group from sqlmodels.group import Group
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -18,7 +18,7 @@ async def test_user_create(db_session: AsyncSession):
# 创建用户 # 创建用户
user = User( user = User(
username="testuser", email="testuser@test.local",
nickname="测试用户", nickname="测试用户",
password="hashed_password", password="hashed_password",
group_id=group.id group_id=group.id
@@ -26,7 +26,7 @@ async def test_user_create(db_session: AsyncSession):
user = await user.save(db_session) user = await user.save(db_session)
assert user.id is not None assert user.id is not None
assert user.username == "testuser" assert user.email == "testuser@test.local"
assert user.nickname == "测试用户" assert user.nickname == "测试用户"
assert user.status is True assert user.status is True
assert user.storage == 0 assert user.storage == 0
@@ -34,15 +34,15 @@ async def test_user_create(db_session: AsyncSession):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user_unique_username(db_session: AsyncSession): async def test_user_unique_email(db_session: AsyncSession):
"""测试用户名唯一约束""" """测试邮箱唯一约束"""
# 创建用户组 # 创建用户组
group = Group(name="默认组") group = Group(name="默认组")
group = await group.save(db_session) group = await group.save(db_session)
# 创建第一个用户 # 创建第一个用户
user1 = User( user1 = User(
username="duplicate", email="duplicate@test.local",
password="password1", password="password1",
group_id=group.id group_id=group.id
) )
@@ -50,7 +50,7 @@ async def test_user_unique_username(db_session: AsyncSession):
# 尝试创建同名用户 # 尝试创建同名用户
user2 = User( user2 = User(
username="duplicate", email="duplicate@test.local",
password="password2", password="password2",
group_id=group.id group_id=group.id
) )
@@ -68,7 +68,7 @@ async def test_user_to_public(db_session: AsyncSession):
# 创建用户 # 创建用户
user = User( user = User(
username="publicuser", email="publicuser@test.local",
nickname="公开用户", nickname="公开用户",
password="secret_password", password="secret_password",
storage=1024, storage=1024,
@@ -82,7 +82,7 @@ async def test_user_to_public(db_session: AsyncSession):
assert isinstance(public_user, UserPublic) assert isinstance(public_user, UserPublic)
assert public_user.id == user.id assert public_user.id == user.id
assert public_user.username == "publicuser" assert public_user.email == "publicuser@test.local"
# 注意: UserPublic.nick 字段名与 User.nickname 不同, # 注意: UserPublic.nick 字段名与 User.nickname 不同,
# model_validate 不会自动映射,所以 nick 为 None # model_validate 不会自动映射,所以 nick 为 None
# 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段 # 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段
@@ -101,7 +101,7 @@ async def test_user_group_relationship(db_session: AsyncSession):
# 创建用户 # 创建用户
user = User( user = User(
username="vipuser", email="vipuser@test.local",
password="password", password="password",
group_id=group.id group_id=group.id
) )
@@ -125,7 +125,7 @@ async def test_user_status_default(db_session: AsyncSession):
group = await group.save(db_session) group = await group.save(db_session)
user = User( user = User(
username="defaultuser", email="defaultuser@test.local",
password="password", password="password",
group_id=group.id group_id=group.id
) )
@@ -141,7 +141,7 @@ async def test_user_storage_default(db_session: AsyncSession):
group = await group.save(db_session) group = await group.save(db_session)
user = User( user = User(
username="storageuser", email="storageuser@test.local",
password="password", password="password",
group_id=group.id group_id=group.id
) )
@@ -158,7 +158,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试默认值 # 测试默认值
user1 = User( user1 = User(
username="user1", email="user1@test.local",
password="password", password="password",
group_id=group.id group_id=group.id
) )
@@ -167,7 +167,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试设置为 LIGHT # 测试设置为 LIGHT
user2 = User( user2 = User(
username="user2", email="user2@test.local",
password="password", password="password",
theme=ThemeType.LIGHT, theme=ThemeType.LIGHT,
group_id=group.id group_id=group.id
@@ -177,7 +177,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试设置为 DARK # 测试设置为 DARK
user3 = User( user3 = User(
username="user3", email="user3@test.local",
password="password", password="password",
theme=ThemeType.DARK, theme=ThemeType.DARK,
group_id=group.id group_id=group.id

View File

@@ -4,8 +4,8 @@ Login 服务的单元测试
import pytest import pytest
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User, LoginRequest, TokenResponse from sqlmodels.user import User, LoginRequest, TokenResponse
from models.group import Group from sqlmodels.group import Group
from service.user.login import login from service.user.login import login
from utils.password.pwd import Password from utils.password.pwd import Password
@@ -20,7 +20,7 @@ async def setup_user(db_session: AsyncSession):
# 创建正常用户 # 创建正常用户
plain_password = "secure_password_123" plain_password = "secure_password_123"
user = User( user = User(
username="loginuser", email="loginuser@test.local",
password=Password.hash(plain_password), password=Password.hash(plain_password),
status=True, status=True,
group_id=group.id group_id=group.id
@@ -41,7 +41,7 @@ async def setup_banned_user(db_session: AsyncSession):
group = await group.save(db_session) group = await group.save(db_session)
user = User( user = User(
username="banneduser", email="banneduser@test.local",
password=Password.hash("password"), password=Password.hash("password"),
status=False, # 封禁状态 status=False, # 封禁状态
group_id=group.id group_id=group.id
@@ -61,7 +61,7 @@ async def setup_2fa_user(db_session: AsyncSession):
secret = pyotp.random_base32() secret = pyotp.random_base32()
user = User( user = User(
username="2fauser", email="2fauser@test.local",
password=Password.hash("password"), password=Password.hash("password"),
status=True, status=True,
two_factor=secret, two_factor=secret,
@@ -82,7 +82,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
user_data = setup_user user_data = setup_user
login_request = LoginRequest( login_request = LoginRequest(
username="loginuser", email="loginuser@test.local",
password=user_data["password"] password=user_data["password"]
) )
@@ -99,7 +99,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
async def test_login_user_not_found(db_session: AsyncSession): async def test_login_user_not_found(db_session: AsyncSession):
"""测试用户不存在""" """测试用户不存在"""
login_request = LoginRequest( login_request = LoginRequest(
username="nonexistent_user", email="nonexistent@test.local",
password="any_password" password="any_password"
) )
@@ -112,7 +112,7 @@ async def test_login_user_not_found(db_session: AsyncSession):
async def test_login_wrong_password(db_session: AsyncSession, setup_user): async def test_login_wrong_password(db_session: AsyncSession, setup_user):
"""测试密码错误""" """测试密码错误"""
login_request = LoginRequest( login_request = LoginRequest(
username="loginuser", email="loginuser@test.local",
password="wrong_password" password="wrong_password"
) )
@@ -125,7 +125,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user):
async def test_login_user_banned(db_session: AsyncSession, setup_banned_user): async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
"""测试用户被封禁""" """测试用户被封禁"""
login_request = LoginRequest( login_request = LoginRequest(
username="banneduser", email="banneduser@test.local",
password="password" password="password"
) )
@@ -140,7 +140,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
user_data = setup_2fa_user user_data = setup_2fa_user
login_request = LoginRequest( login_request = LoginRequest(
username="2fauser", email="2fauser@test.local",
password=user_data["password"] password=user_data["password"]
# 未提供 two_fa_code # 未提供 two_fa_code
) )
@@ -156,7 +156,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
user_data = setup_2fa_user user_data = setup_2fa_user
login_request = LoginRequest( login_request = LoginRequest(
username="2fauser", email="2fauser@test.local",
password=user_data["password"], password=user_data["password"],
two_fa_code="000000" # 错误的验证码 two_fa_code="000000" # 错误的验证码
) )
@@ -179,7 +179,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
valid_code = totp.now() valid_code = totp.now()
login_request = LoginRequest( login_request = LoginRequest(
username="2fauser", email="2fauser@test.local",
password=user_data["password"], password=user_data["password"],
two_fa_code=valid_code two_fa_code=valid_code
) )
@@ -198,7 +198,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
user_data = setup_user user_data = setup_user
login_request = LoginRequest( login_request = LoginRequest(
username="loginuser", email="loginuser@test.local",
password=user_data["password"] password=user_data["password"]
) )
@@ -217,17 +217,17 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_case_sensitive_username(db_session: AsyncSession, setup_user): async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user):
"""测试用户名大小写敏感""" """测试邮箱大小写敏感"""
user_data = setup_user user_data = setup_user
# 使用大写用户名登录(如果数据库是 loginuser # 使用大写邮箱登录
login_request = LoginRequest( login_request = LoginRequest(
username="LOGINUSER", email="LOGINUSER@TEST.LOCAL",
password=user_data["password"] password=user_data["password"]
) )
result = await login(db_session, login_request) result = await login(db_session, login_request)
# 应该失败,因为用户名大小写不匹配 # 应该失败,因为邮箱大小写不匹配
assert result is None assert result is None

View File

@@ -72,9 +72,9 @@ def test_password_verify_expired():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_totp_generate(): async def test_totp_generate():
"""测试 TOTP 密钥生成""" """测试 TOTP 密钥生成"""
username = "testuser" email = "testuser@test.local"
response = await Password.generate_totp(username) response = await Password.generate_totp(email)
assert response.setup_token is not None assert response.setup_token is not None
assert response.uri is not None assert response.uri is not None
@@ -82,7 +82,7 @@ async def test_totp_generate():
assert isinstance(response.uri, str) assert isinstance(response.uri, str)
# TOTP URI 格式: otpauth://totp/... # TOTP URI 格式: otpauth://totp/...
assert response.uri.startswith("otpauth://totp/") assert response.uri.startswith("otpauth://totp/")
assert username in response.uri assert email in response.uri
def test_totp_verify_valid(): def test_totp_verify_valid():

View File

@@ -4,7 +4,7 @@ from uuid import UUID, uuid4
import jwt import jwt
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from models import AccessTokenBase, RefreshTokenBase from sqlmodels import AccessTokenBase, RefreshTokenBase, TokenResponse
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌', scheme_name='获取 JWT Bearer 令牌',
@@ -21,8 +21,8 @@ async def load_secret_key() -> None:
从数据库读取 JWT 的密钥。 从数据库读取 JWT 的密钥。
""" """
# 延迟导入以避免循环依赖 # 延迟导入以避免循环依赖
from models.database import get_session from sqlmodels.database import get_session
from models.setting import Setting from sqlmodels.setting import Setting
global SECRET_KEY global SECRET_KEY
async for session in get_session(): async for session in get_session():
@@ -69,20 +69,30 @@ def build_token_payload(
# 访问令牌 # 访问令牌
def create_access_token( def create_access_token(
data: dict, sub: UUID,
jti: UUID,
expires_delta: timedelta | None = None, expires_delta: timedelta | None = None,
algorithm: str = "HS256" algorithm: str = "HS256",
**kwargs
) -> AccessTokenBase: ) -> AccessTokenBase:
""" """
生成访问令牌,默认有效期 3 小时。 生成访问令牌,默认有效期 3 小时。
:param data: 需要放进 JWT Payload 的字段 :param sub: 令牌的主题,通常是用户 ID
:param jti: 令牌的唯一标识符,通常是一个 UUID。
:param expires_delta: 过期时间, 缺省时为 3 小时。 :param expires_delta: 过期时间, 缺省时为 3 小时。
:param algorithm: JWT 密钥强度,缺省时为 HS256 :param algorithm: JWT 密钥强度,缺省时为 HS256
:param kwargs: 需要放进 JWT Payload 的字段。
:return: 包含密钥本身和过期时间的 `AccessTokenBase` :return: 包含密钥本身和过期时间的 `AccessTokenBase`
""" """
data = {"sub": str(sub), "jti": str(jti)}
# 将额外的字段添加到 Payload 中
for key, value in kwargs.items():
data[key] = value
access_token, expire_at = build_token_payload( access_token, expire_at = build_token_payload(
data, data,
False, False,
@@ -97,20 +107,30 @@ def create_access_token(
# 刷新令牌 # 刷新令牌
def create_refresh_token( def create_refresh_token(
data: dict, sub: UUID,
jti: UUID,
expires_delta: timedelta | None = None, expires_delta: timedelta | None = None,
algorithm: str = "HS256" algorithm: str = "HS256",
**kwargs,
) -> RefreshTokenBase: ) -> RefreshTokenBase:
""" """
生成刷新令牌,默认有效期 30 天。 生成刷新令牌,默认有效期 30 天。
:param data: 需要放进 JWT Payload 的字段 :param sub: 令牌的主题,通常是用户 ID
:param jti: 令牌的唯一标识符,通常是一个 UUID。
:param expires_delta: 过期时间, 缺省时为 30 天。 :param expires_delta: 过期时间, 缺省时为 30 天。
:param algorithm: JWT 密钥强度,缺省时为 HS256 :param algorithm: JWT 密钥强度,缺省时为 HS256
:param kwargs: 需要放进 JWT Payload 的字段。
:return: 包含密钥本身和过期时间的 `RefreshTokenBase` :return: 包含密钥本身和过期时间的 `RefreshTokenBase`
""" """
data = {"sub": str(sub), "jti": str(jti)}
# 将额外的字段添加到 Payload 中
for key, value in kwargs.items():
data[key] = value
refresh_token, expire_at = build_token_payload( refresh_token, expire_at = build_token_payload(
data, data,
True, True,

View File

@@ -28,6 +28,10 @@ def raise_forbidden(detail: str | None = None, *args, **kwargs) -> NoReturn:
"""Raises an HTTP 403 Forbidden exception.""" """Raises an HTTP 403 Forbidden exception."""
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs)
def raise_banned(detail: str = "此文件已被管理员封禁,仅允许删除操作", *args, **kwargs) -> NoReturn:
"""Raises an HTTP 403 Forbidden exception for banned objects."""
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs)
def raise_not_found(detail: str | None = None, *args, **kwargs) -> NoReturn: def raise_not_found(detail: str | None = None, *args, **kwargs) -> NoReturn:
"""Raises an HTTP 404 Not Found exception.""" """Raises an HTTP 404 Not Found exception."""
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail, *args, **kwargs) raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail, *args, **kwargs)

View File

@@ -73,6 +73,8 @@ class Password:
:param length: 密码长度 :param length: 密码长度
:type length: int :type length: int
:param url_safe: 是否生成 URL 安全的密码
:type url_safe: bool
:return: 随机密码 :return: 随机密码
:rtype: str :rtype: str
""" """