feat: add models for physical files, policies, and user management
- Implement PhysicalFile model to manage physical file references and reference counting. - Create Policy model with associated options and group links for storage policies. - Introduce Redeem and Report models for handling redeem codes and reports. - Add Settings model for site configuration and user settings management. - Develop Share model for sharing objects with unique codes and associated metadata. - Implement SourceLink model for managing download links associated with objects. - Create StoragePack model for managing user storage packages. - Add Tag model for user-defined tags with manual and automatic types. - Implement Task model for managing background tasks with status tracking. - Develop User model with comprehensive user management features including authentication. - Introduce UserAuthn model for managing WebAuthn credentials. - Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
470
sqlmodels/mixin/relation_preload.py
Normal file
470
sqlmodels/mixin/relation_preload.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
|
||||
|
||||
设计原则:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 查询最优:相同关系只查询一次,不同关系增量查询
|
||||
- 零侵入:调用方无需任何改动
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
|
||||
|
||||
使用方式:
|
||||
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
|
||||
|
||||
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
|
||||
kling_video_generator: KlingO1Generator = Relationship(...)
|
||||
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# 自动加载,可以安全访问
|
||||
price = self.kling_video_generator.kling_o1.pro_price_per_second
|
||||
...
|
||||
|
||||
# 调用方 - 无需任何改动
|
||||
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
|
||||
await tool._call(...) # 关系相同则跳过,否则增量加载
|
||||
|
||||
支持 AsyncGenerator:
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # 装饰器正确处理 async generator
|
||||
"""
|
||||
import inspect as python_inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar, ParamSpec, Any
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
P = ParamSpec('P')
|
||||
R = TypeVar('R')
|
||||
|
||||
|
||||
def _extract_session(
|
||||
func: Callable,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> AsyncSession | None:
|
||||
"""
|
||||
从方法参数中提取 AsyncSession
|
||||
|
||||
按以下顺序查找:
|
||||
1. kwargs 中名为 'session' 的参数
|
||||
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
|
||||
3. kwargs 中类型为 AsyncSession 的参数
|
||||
"""
|
||||
# 1. 优先从 kwargs 查找
|
||||
if 'session' in kwargs:
|
||||
return kwargs['session']
|
||||
|
||||
# 2. 从函数签名定位位置参数
|
||||
try:
|
||||
sig = python_inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
if 'session' in param_names:
|
||||
# 计算位置(减去 self)
|
||||
idx = param_names.index('session') - 1
|
||||
if 0 <= idx < len(args):
|
||||
return args[idx]
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. 遍历 kwargs 找 AsyncSession 类型
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, AsyncSession):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
|
||||
"""
|
||||
检查对象的关系是否已加载(独立函数版本)
|
||||
|
||||
Args:
|
||||
obj: 要检查的对象
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(obj)
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
|
||||
"""
|
||||
在类中查找指向目标类的关系属性名
|
||||
|
||||
Args:
|
||||
from_class: 源类
|
||||
to_class: 目标类
|
||||
|
||||
Returns:
|
||||
关系属性名,如果找不到则返回 None
|
||||
|
||||
Example:
|
||||
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
|
||||
# 返回 'kling_video_generator'
|
||||
"""
|
||||
for attr_name in dir(from_class):
|
||||
try:
|
||||
attr = getattr(from_class, attr_name, None)
|
||||
if attr is None:
|
||||
continue
|
||||
# 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
|
||||
# parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
|
||||
target_class = attr.property.mapper.class_
|
||||
if target_class == to_class:
|
||||
return attr_name
|
||||
except AttributeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""
|
||||
装饰器:声明方法需要的关系,自动按需增量加载
|
||||
|
||||
参数格式:
|
||||
- 字符串:本类属性名,如 'kling_video_generator'
|
||||
- RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
|
||||
|
||||
行为:
|
||||
- 方法调用时自动检查关系是否已加载
|
||||
- 未加载的关系会被增量加载(单次查询)
|
||||
- 已加载的关系直接跳过
|
||||
|
||||
支持:
|
||||
- 普通 async 方法:`async def cost(...) -> ToolCost`
|
||||
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
|
||||
|
||||
Example:
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# self.kling_video_generator.kling_o1 已自动加载
|
||||
...
|
||||
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # AsyncGenerator 正确处理
|
||||
|
||||
验证:
|
||||
- 字符串格式的关系名在类创建时(__init_subclass__)验证
|
||||
- 拼写错误会在导入时抛出 AttributeError
|
||||
"""
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
# 检测是否是 async generator 函数
|
||||
is_async_gen = python_inspect.isasyncgenfunction(func)
|
||||
|
||||
if is_async_gen:
|
||||
# AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
# 委托给原始 async generator,逐个 yield 值
|
||||
async for item in func(self, *args, **kwargs):
|
||||
yield item # type: ignore
|
||||
else:
|
||||
# 普通 async 函数:await 并返回结果
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
# 保存关系声明供验证和内省使用
|
||||
wrapper._required_relations = relations # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RelationPreloadMixin:
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供按需增量加载能力,确保 SQL 查询数理论最优。
|
||||
|
||||
特性:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 原地更新:直接修改 self,无需替换实例
|
||||
- 导入时验证:字符串关系名在类创建时验证
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""类创建时验证所有 @requires_relations 声明"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 收集类及其父类的所有注解(包含普通字段)
|
||||
all_annotations: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__annotations__'):
|
||||
all_annotations.update(klass.__annotations__.keys())
|
||||
|
||||
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
|
||||
sqlmodel_relationships: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__sqlmodel_relationships__'):
|
||||
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
|
||||
|
||||
# 合并所有可用的属性名
|
||||
all_available_names = all_annotations | sqlmodel_relationships
|
||||
|
||||
for method_name in dir(cls):
|
||||
if method_name.startswith('__'):
|
||||
continue
|
||||
|
||||
try:
|
||||
method = getattr(cls, method_name, None)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
continue
|
||||
|
||||
# 验证字符串格式的关系名
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
# 检查注解、Relationship 或已有属性
|
||||
if spec not in all_available_names and not hasattr(cls, spec):
|
||||
raise AttributeError(
|
||||
f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
|
||||
f"但 {cls.__name__} 没有此属性"
|
||||
)
|
||||
|
||||
def _is_relation_loaded(self, rel_name: str) -> bool:
|
||||
"""
|
||||
检查关系是否真正已加载(基于 SQLAlchemy inspect)
|
||||
|
||||
使用 SQLAlchemy 的 inspect 检测真实加载状态,
|
||||
自动处理 commit 导致的 expire 问题。
|
||||
|
||||
Args:
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(self)
|
||||
# unloaded 包含未加载的关系属性名
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
# 对象可能未被 SQLAlchemy 管理
|
||||
return False
|
||||
|
||||
async def _ensure_relations_loaded(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
relations: tuple[str | RelationshipInfo, ...],
|
||||
) -> None:
|
||||
"""
|
||||
确保指定关系已加载,只加载未加载的部分
|
||||
|
||||
基于 SQLAlchemy inspect 检测真实状态,自动处理:
|
||||
- 首次访问的关系
|
||||
- commit 后 expire 的关系
|
||||
- 嵌套关系(如 KlingO1Generator.kling_o1)
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
relations: 需要的关系规格
|
||||
"""
|
||||
# 找出真正未加载的关系(基于 SQLAlchemy inspect)
|
||||
to_load: list[str | RelationshipInfo] = []
|
||||
# 区分直接关系和嵌套关系的 key
|
||||
direct_keys: set[str] = set() # 本类的直接关系属性名
|
||||
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
|
||||
|
||||
for rel in relations:
|
||||
if isinstance(rel, str):
|
||||
# 直接关系:检查本类的关系是否已加载
|
||||
if not self._is_relation_loaded(rel):
|
||||
to_load.append(rel)
|
||||
direct_keys.add(rel)
|
||||
else:
|
||||
# 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
|
||||
# 1. 查找指向父类的关系属性
|
||||
parent_class = rel.parent.class_
|
||||
parent_attr = _find_relation_to_class(self.__class__, parent_class)
|
||||
|
||||
if parent_attr is None:
|
||||
# 找不到路径,可能是配置错误,但仍尝试加载
|
||||
l.warning(
|
||||
f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
|
||||
f"无法检查 {rel.key} 是否已加载"
|
||||
)
|
||||
to_load.append(rel)
|
||||
continue
|
||||
|
||||
# 2. 检查父对象是否已加载
|
||||
if not self._is_relation_loaded(parent_attr):
|
||||
# 父对象未加载,需要同时加载父对象和嵌套关系
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
else:
|
||||
# 3. 父对象已加载,检查嵌套关系是否已加载
|
||||
parent_obj = getattr(self, parent_attr)
|
||||
if not _is_obj_relation_loaded(parent_obj, rel.key):
|
||||
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
|
||||
# 因为 _build_load_chains 需要完整的链来构建 selectinload
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
|
||||
if not to_load:
|
||||
return # 全部已加载,跳过
|
||||
|
||||
# 构建 load 参数
|
||||
load_options = self._specs_to_load_options(to_load)
|
||||
if not load_options:
|
||||
return
|
||||
|
||||
# 安全地获取主键值(避免触发懒加载)
|
||||
state = sa_inspect(self)
|
||||
pk_tuple = state.key[1] if state.key else None
|
||||
if pk_tuple is None:
|
||||
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
|
||||
return
|
||||
# 主键是元组,取第一个值(假设单列主键)
|
||||
pk_value = pk_tuple[0]
|
||||
|
||||
# 单次查询加载缺失的关系
|
||||
fresh = await self.__class__.get(
|
||||
session,
|
||||
self.__class__.id == pk_value,
|
||||
load=load_options,
|
||||
)
|
||||
|
||||
if fresh is None:
|
||||
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
|
||||
return
|
||||
|
||||
# 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
|
||||
all_direct_keys = direct_keys | nested_parent_keys
|
||||
for key in all_direct_keys:
|
||||
value = getattr(fresh, key, None)
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def _specs_to_load_options(
|
||||
self,
|
||||
specs: list[str | RelationshipInfo],
|
||||
) -> list[RelationshipInfo]:
|
||||
"""
|
||||
将关系规格转换为 load 参数
|
||||
|
||||
- 字符串 → cls.{name}
|
||||
- RelationshipInfo → 直接使用
|
||||
"""
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for spec in specs:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(self.__class__, spec, None)
|
||||
if rel is not None:
|
||||
result.append(rel)
|
||||
else:
|
||||
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
# ==================== 可选的手动预加载 API ====================
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取指定方法声明的关系(用于外部预加载场景)
|
||||
|
||||
Args:
|
||||
method_name: 方法名
|
||||
|
||||
Returns:
|
||||
RelationshipInfo 列表
|
||||
"""
|
||||
method = getattr(cls, method_name, None)
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
return []
|
||||
|
||||
result: list[RelationshipInfo] = []
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(cls, spec, None)
|
||||
if rel:
|
||||
result.append(rel)
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取多个方法的关系并去重(用于批量预加载场景)
|
||||
|
||||
Args:
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
去重后的 RelationshipInfo 列表
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
for rel in cls.get_relations_for_method(method_name):
|
||||
key = rel.key
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
result.append(rel)
|
||||
|
||||
return result
|
||||
|
||||
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
|
||||
"""
|
||||
手动预加载指定方法的关系(可选优化 API)
|
||||
|
||||
当需要确保在调用方法前完成所有加载时使用。
|
||||
通常情况下不需要调用此方法,装饰器会自动处理。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Example:
|
||||
# 可选:显式预加载(通常不需要)
|
||||
tool = await tool.preload_for(session, 'cost', '_call')
|
||||
"""
|
||||
all_relations: list[str | RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
method = getattr(self.__class__, method_name, None)
|
||||
if method and hasattr(method, '_required_relations'):
|
||||
all_relations.extend(method._required_relations)
|
||||
|
||||
if all_relations:
|
||||
await self._ensure_relations_loaded(session, tuple(all_relations))
|
||||
|
||||
return self
|
||||
Reference in New Issue
Block a user