- Implement PhysicalFile model to manage physical file references and reference counting. - Create Policy model with associated options and group links for storage policies. - Introduce Redeem and Report models for handling redeem codes and reports. - Add Settings model for site configuration and user settings management. - Develop Share model for sharing objects with unique codes and associated metadata. - Implement SourceLink model for managing download links associated with objects. - Create StoragePack model for managing user storage packages. - Add Tag model for user-defined tags with manual and automatic types. - Implement Task model for managing background tasks with status tracking. - Develop User model with comprehensive user management features including authentication. - Introduce UserAuthn model for managing WebAuthn credentials. - Create WebDAV model for managing WebDAV accounts associated with users.
471 lines
17 KiB
Python
471 lines
17 KiB
Python
"""
|
||
关系预加载 Mixin
|
||
|
||
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
|
||
|
||
设计原则:
|
||
- 按需加载:只加载被调用方法需要的关系
|
||
- 增量加载:已加载的关系不重复加载
|
||
- 查询最优:相同关系只查询一次,不同关系增量查询
|
||
- 零侵入:调用方无需任何改动
|
||
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
|
||
|
||
使用方式:
|
||
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
|
||
|
||
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
|
||
kling_video_generator: KlingO1Generator = Relationship(...)
|
||
|
||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||
async def cost(self, params, context, session) -> ToolCost:
|
||
# 自动加载,可以安全访问
|
||
price = self.kling_video_generator.kling_o1.pro_price_per_second
|
||
...
|
||
|
||
# 调用方 - 无需任何改动
|
||
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
|
||
await tool._call(...) # 关系相同则跳过,否则增量加载
|
||
|
||
支持 AsyncGenerator:
|
||
@requires_relations('twitter_api')
|
||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||
yield ToolResponse(...) # 装饰器正确处理 async generator
|
||
"""
|
||
import inspect as python_inspect
|
||
from functools import wraps
|
||
from typing import Callable, TypeVar, ParamSpec, Any
|
||
|
||
from loguru import logger as l
|
||
from sqlalchemy import inspect as sa_inspect
|
||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||
from sqlmodel.main import RelationshipInfo
|
||
|
||
P = ParamSpec('P')
|
||
R = TypeVar('R')
|
||
|
||
|
||
def _extract_session(
|
||
func: Callable,
|
||
args: tuple[Any, ...],
|
||
kwargs: dict[str, Any],
|
||
) -> AsyncSession | None:
|
||
"""
|
||
从方法参数中提取 AsyncSession
|
||
|
||
按以下顺序查找:
|
||
1. kwargs 中名为 'session' 的参数
|
||
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
|
||
3. kwargs 中类型为 AsyncSession 的参数
|
||
"""
|
||
# 1. 优先从 kwargs 查找
|
||
if 'session' in kwargs:
|
||
return kwargs['session']
|
||
|
||
# 2. 从函数签名定位位置参数
|
||
try:
|
||
sig = python_inspect.signature(func)
|
||
param_names = list(sig.parameters.keys())
|
||
|
||
if 'session' in param_names:
|
||
# 计算位置(减去 self)
|
||
idx = param_names.index('session') - 1
|
||
if 0 <= idx < len(args):
|
||
return args[idx]
|
||
except (ValueError, TypeError):
|
||
pass
|
||
|
||
# 3. 遍历 kwargs 找 AsyncSession 类型
|
||
for value in kwargs.values():
|
||
if isinstance(value, AsyncSession):
|
||
return value
|
||
|
||
return None
|
||
|
||
|
||
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
|
||
"""
|
||
检查对象的关系是否已加载(独立函数版本)
|
||
|
||
Args:
|
||
obj: 要检查的对象
|
||
rel_name: 关系属性名
|
||
|
||
Returns:
|
||
True 如果关系已加载,False 如果未加载或已过期
|
||
"""
|
||
try:
|
||
state = sa_inspect(obj)
|
||
return rel_name not in state.unloaded
|
||
except Exception:
|
||
return False
|
||
|
||
|
||
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
|
||
"""
|
||
在类中查找指向目标类的关系属性名
|
||
|
||
Args:
|
||
from_class: 源类
|
||
to_class: 目标类
|
||
|
||
Returns:
|
||
关系属性名,如果找不到则返回 None
|
||
|
||
Example:
|
||
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
|
||
# 返回 'kling_video_generator'
|
||
"""
|
||
for attr_name in dir(from_class):
|
||
try:
|
||
attr = getattr(from_class, attr_name, None)
|
||
if attr is None:
|
||
continue
|
||
# 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
|
||
# parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
|
||
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
|
||
target_class = attr.property.mapper.class_
|
||
if target_class == to_class:
|
||
return attr_name
|
||
except AttributeError:
|
||
continue
|
||
return None
|
||
|
||
|
||
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||
"""
|
||
装饰器:声明方法需要的关系,自动按需增量加载
|
||
|
||
参数格式:
|
||
- 字符串:本类属性名,如 'kling_video_generator'
|
||
- RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
|
||
|
||
行为:
|
||
- 方法调用时自动检查关系是否已加载
|
||
- 未加载的关系会被增量加载(单次查询)
|
||
- 已加载的关系直接跳过
|
||
|
||
支持:
|
||
- 普通 async 方法:`async def cost(...) -> ToolCost`
|
||
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
|
||
|
||
Example:
|
||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||
async def cost(self, params, context, session) -> ToolCost:
|
||
# self.kling_video_generator.kling_o1 已自动加载
|
||
...
|
||
|
||
@requires_relations('twitter_api')
|
||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||
yield ToolResponse(...) # AsyncGenerator 正确处理
|
||
|
||
验证:
|
||
- 字符串格式的关系名在类创建时(__init_subclass__)验证
|
||
- 拼写错误会在导入时抛出 AttributeError
|
||
"""
|
||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||
# 检测是否是 async generator 函数
|
||
is_async_gen = python_inspect.isasyncgenfunction(func)
|
||
|
||
if is_async_gen:
|
||
# AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
|
||
@wraps(func)
|
||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||
session = _extract_session(func, args, kwargs)
|
||
if session is not None:
|
||
await self._ensure_relations_loaded(session, relations)
|
||
# 委托给原始 async generator,逐个 yield 值
|
||
async for item in func(self, *args, **kwargs):
|
||
yield item # type: ignore
|
||
else:
|
||
# 普通 async 函数:await 并返回结果
|
||
@wraps(func)
|
||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||
session = _extract_session(func, args, kwargs)
|
||
if session is not None:
|
||
await self._ensure_relations_loaded(session, relations)
|
||
return await func(self, *args, **kwargs)
|
||
|
||
# 保存关系声明供验证和内省使用
|
||
wrapper._required_relations = relations # type: ignore
|
||
return wrapper
|
||
|
||
return decorator
|
||
|
||
|
||
class RelationPreloadMixin:
|
||
"""
|
||
关系预加载 Mixin
|
||
|
||
提供按需增量加载能力,确保 SQL 查询数理论最优。
|
||
|
||
特性:
|
||
- 按需加载:只加载被调用方法需要的关系
|
||
- 增量加载:已加载的关系不重复加载
|
||
- 原地更新:直接修改 self,无需替换实例
|
||
- 导入时验证:字符串关系名在类创建时验证
|
||
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
|
||
"""
|
||
|
||
def __init_subclass__(cls, **kwargs) -> None:
|
||
"""类创建时验证所有 @requires_relations 声明"""
|
||
super().__init_subclass__(**kwargs)
|
||
|
||
# 收集类及其父类的所有注解(包含普通字段)
|
||
all_annotations: set[str] = set()
|
||
for klass in cls.__mro__:
|
||
if hasattr(klass, '__annotations__'):
|
||
all_annotations.update(klass.__annotations__.keys())
|
||
|
||
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
|
||
sqlmodel_relationships: set[str] = set()
|
||
for klass in cls.__mro__:
|
||
if hasattr(klass, '__sqlmodel_relationships__'):
|
||
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
|
||
|
||
# 合并所有可用的属性名
|
||
all_available_names = all_annotations | sqlmodel_relationships
|
||
|
||
for method_name in dir(cls):
|
||
if method_name.startswith('__'):
|
||
continue
|
||
|
||
try:
|
||
method = getattr(cls, method_name, None)
|
||
except AttributeError:
|
||
continue
|
||
|
||
if method is None or not hasattr(method, '_required_relations'):
|
||
continue
|
||
|
||
# 验证字符串格式的关系名
|
||
for spec in method._required_relations:
|
||
if isinstance(spec, str):
|
||
# 检查注解、Relationship 或已有属性
|
||
if spec not in all_available_names and not hasattr(cls, spec):
|
||
raise AttributeError(
|
||
f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
|
||
f"但 {cls.__name__} 没有此属性"
|
||
)
|
||
|
||
def _is_relation_loaded(self, rel_name: str) -> bool:
|
||
"""
|
||
检查关系是否真正已加载(基于 SQLAlchemy inspect)
|
||
|
||
使用 SQLAlchemy 的 inspect 检测真实加载状态,
|
||
自动处理 commit 导致的 expire 问题。
|
||
|
||
Args:
|
||
rel_name: 关系属性名
|
||
|
||
Returns:
|
||
True 如果关系已加载,False 如果未加载或已过期
|
||
"""
|
||
try:
|
||
state = sa_inspect(self)
|
||
# unloaded 包含未加载的关系属性名
|
||
return rel_name not in state.unloaded
|
||
except Exception:
|
||
# 对象可能未被 SQLAlchemy 管理
|
||
return False
|
||
|
||
async def _ensure_relations_loaded(
|
||
self,
|
||
session: AsyncSession,
|
||
relations: tuple[str | RelationshipInfo, ...],
|
||
) -> None:
|
||
"""
|
||
确保指定关系已加载,只加载未加载的部分
|
||
|
||
基于 SQLAlchemy inspect 检测真实状态,自动处理:
|
||
- 首次访问的关系
|
||
- commit 后 expire 的关系
|
||
- 嵌套关系(如 KlingO1Generator.kling_o1)
|
||
|
||
Args:
|
||
session: 数据库会话
|
||
relations: 需要的关系规格
|
||
"""
|
||
# 找出真正未加载的关系(基于 SQLAlchemy inspect)
|
||
to_load: list[str | RelationshipInfo] = []
|
||
# 区分直接关系和嵌套关系的 key
|
||
direct_keys: set[str] = set() # 本类的直接关系属性名
|
||
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
|
||
|
||
for rel in relations:
|
||
if isinstance(rel, str):
|
||
# 直接关系:检查本类的关系是否已加载
|
||
if not self._is_relation_loaded(rel):
|
||
to_load.append(rel)
|
||
direct_keys.add(rel)
|
||
else:
|
||
# 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
|
||
# 1. 查找指向父类的关系属性
|
||
parent_class = rel.parent.class_
|
||
parent_attr = _find_relation_to_class(self.__class__, parent_class)
|
||
|
||
if parent_attr is None:
|
||
# 找不到路径,可能是配置错误,但仍尝试加载
|
||
l.warning(
|
||
f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
|
||
f"无法检查 {rel.key} 是否已加载"
|
||
)
|
||
to_load.append(rel)
|
||
continue
|
||
|
||
# 2. 检查父对象是否已加载
|
||
if not self._is_relation_loaded(parent_attr):
|
||
# 父对象未加载,需要同时加载父对象和嵌套关系
|
||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||
to_load.append(parent_attr)
|
||
nested_parent_keys.add(parent_attr)
|
||
to_load.append(rel)
|
||
else:
|
||
# 3. 父对象已加载,检查嵌套关系是否已加载
|
||
parent_obj = getattr(self, parent_attr)
|
||
if not _is_obj_relation_loaded(parent_obj, rel.key):
|
||
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
|
||
# 因为 _build_load_chains 需要完整的链来构建 selectinload
|
||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||
to_load.append(parent_attr)
|
||
nested_parent_keys.add(parent_attr)
|
||
to_load.append(rel)
|
||
|
||
if not to_load:
|
||
return # 全部已加载,跳过
|
||
|
||
# 构建 load 参数
|
||
load_options = self._specs_to_load_options(to_load)
|
||
if not load_options:
|
||
return
|
||
|
||
# 安全地获取主键值(避免触发懒加载)
|
||
state = sa_inspect(self)
|
||
pk_tuple = state.key[1] if state.key else None
|
||
if pk_tuple is None:
|
||
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
|
||
return
|
||
# 主键是元组,取第一个值(假设单列主键)
|
||
pk_value = pk_tuple[0]
|
||
|
||
# 单次查询加载缺失的关系
|
||
fresh = await self.__class__.get(
|
||
session,
|
||
self.__class__.id == pk_value,
|
||
load=load_options,
|
||
)
|
||
|
||
if fresh is None:
|
||
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
|
||
return
|
||
|
||
# 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
|
||
all_direct_keys = direct_keys | nested_parent_keys
|
||
for key in all_direct_keys:
|
||
value = getattr(fresh, key, None)
|
||
object.__setattr__(self, key, value)
|
||
|
||
def _specs_to_load_options(
|
||
self,
|
||
specs: list[str | RelationshipInfo],
|
||
) -> list[RelationshipInfo]:
|
||
"""
|
||
将关系规格转换为 load 参数
|
||
|
||
- 字符串 → cls.{name}
|
||
- RelationshipInfo → 直接使用
|
||
"""
|
||
result: list[RelationshipInfo] = []
|
||
|
||
for spec in specs:
|
||
if isinstance(spec, str):
|
||
rel = getattr(self.__class__, spec, None)
|
||
if rel is not None:
|
||
result.append(rel)
|
||
else:
|
||
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
|
||
else:
|
||
result.append(spec)
|
||
|
||
return result
|
||
|
||
# ==================== 可选的手动预加载 API ====================
|
||
|
||
@classmethod
|
||
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
|
||
"""
|
||
获取指定方法声明的关系(用于外部预加载场景)
|
||
|
||
Args:
|
||
method_name: 方法名
|
||
|
||
Returns:
|
||
RelationshipInfo 列表
|
||
"""
|
||
method = getattr(cls, method_name, None)
|
||
if method is None or not hasattr(method, '_required_relations'):
|
||
return []
|
||
|
||
result: list[RelationshipInfo] = []
|
||
for spec in method._required_relations:
|
||
if isinstance(spec, str):
|
||
rel = getattr(cls, spec, None)
|
||
if rel:
|
||
result.append(rel)
|
||
else:
|
||
result.append(spec)
|
||
|
||
return result
|
||
|
||
@classmethod
|
||
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
|
||
"""
|
||
获取多个方法的关系并去重(用于批量预加载场景)
|
||
|
||
Args:
|
||
method_names: 方法名列表
|
||
|
||
Returns:
|
||
去重后的 RelationshipInfo 列表
|
||
"""
|
||
seen: set[str] = set()
|
||
result: list[RelationshipInfo] = []
|
||
|
||
for method_name in method_names:
|
||
for rel in cls.get_relations_for_method(method_name):
|
||
key = rel.key
|
||
if key not in seen:
|
||
seen.add(key)
|
||
result.append(rel)
|
||
|
||
return result
|
||
|
||
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
|
||
"""
|
||
手动预加载指定方法的关系(可选优化 API)
|
||
|
||
当需要确保在调用方法前完成所有加载时使用。
|
||
通常情况下不需要调用此方法,装饰器会自动处理。
|
||
|
||
Args:
|
||
session: 数据库会话
|
||
method_names: 方法名列表
|
||
|
||
Returns:
|
||
self(支持链式调用)
|
||
|
||
Example:
|
||
# 可选:显式预加载(通常不需要)
|
||
tool = await tool.preload_for(session, 'cost', '_call')
|
||
"""
|
||
all_relations: list[str | RelationshipInfo] = []
|
||
|
||
for method_name in method_names:
|
||
method = getattr(self.__class__, method_name, None)
|
||
if method and hasattr(method, '_required_relations'):
|
||
all_relations.extend(method._required_relations)
|
||
|
||
if all_relations:
|
||
await self._ensure_relations_loaded(session, tuple(all_relations))
|
||
|
||
return self
|