Compare commits
2 Commits
62c671e07b
...
a99091ea7a
| Author | SHA1 | Date | |
|---|---|---|---|
| a99091ea7a | |||
| 209cb24ab4 |
@@ -3,7 +3,9 @@
|
||||
"allow": [
|
||||
"Bash(git rev-parse:*)",
|
||||
"Bash(findstr:*)",
|
||||
"Bash(find:*)"
|
||||
"Bash(find:*)",
|
||||
"Bash(yarn tsc:*)",
|
||||
"Bash(dir:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
92
docs/CLA.md
Normal file
92
docs/CLA.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# DiskNext Contributor License Agreement
|
||||
|
||||
Thank you for your interest in contributing to the DiskNext project ("We", "Us", or "Our"). This Contributor License Agreement ("Agreement") is for our mutual protection. It clarifies the intellectual property rights You grant to Us for Your Contributions.
|
||||
|
||||
By signing this Agreement, You accept its terms and conditions.
|
||||
|
||||
## 1. The Purpose of This Agreement
|
||||
|
||||
The DiskNext project is developed with a dual-licensing strategy. We maintain a free, open-source community edition alongside a commercial Pro edition. This model allows Us to support a vibrant community while also funding the project's sustainable development.
|
||||
|
||||
To make this model work, We require broad rights to use the code You contribute. This Agreement ensures that We can include Your Contributions in all editions of DiskNext under their respective licenses. By signing this Agreement, You grant Us the rights needed to manage the project effectively, including the right to incorporate Your Contribution into Our commercial products and to transfer the project to another entity.
|
||||
|
||||
## 2. Definitions
|
||||
|
||||
**"You"** means the individual copyright owner who Submits a Contribution to Us.
|
||||
|
||||
**"Contribution"** means any original work of authorship, including any modifications or additions to an existing work, that you intentionally Submit to Us for inclusion in the Material.
|
||||
|
||||
**"Material"** means the software and documentation We make available to third parties. Your Contribution may be included in the Material.
|
||||
|
||||
**"Submit"** means any form of communication sent to Us (e.g., via a pull request, issue tracker, or email) that is managed by Us for the purpose of discussing and improving the Material, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution."
|
||||
|
||||
**"Copyright"** means all rights protecting works of authorship, including copyright, moral rights, and neighboring rights, for the full term of their existence.
|
||||
|
||||
## 3. Copyright License Grant
|
||||
|
||||
Subject to the terms and conditions of this Agreement, You hereby grant to Us a worldwide, royalty-free, **non-exclusive**, perpetual, and irrevocable license under the Copyright covering your Contribution. This license includes the right to sublicense and to assign Your Contribution.
|
||||
|
||||
This license allows Us to use, reproduce, prepare derivative works of, publicly display, publicly perform, distribute, and publish your Contribution and such derivative works in any form. This includes, without limitation, the right to sell and distribute the Contribution as part of a commercial product under a proprietary license.
|
||||
|
||||
You retain full ownership of the Copyright in Your Contribution. Nothing in this Agreement shall be construed to restrict or transfer Your rights to use Your own Contribution for any purpose.
|
||||
|
||||
## 4. Patent License Grant
|
||||
|
||||
You hereby grant to Us and to recipients of the Material a worldwide, royalty-free, non-exclusive, perpetual, and irrevocable patent license to make, have made, use, sell, offer for sale, import, and otherwise transfer Your Contribution. This license applies to all patents owned or controlled by You, now or in the future, that would be infringed by Your Contribution alone or in combination with the Material.
|
||||
|
||||
## 5. Your Representations
|
||||
|
||||
You represent and warrant that:
|
||||
|
||||
1. The Contribution is Your original work.
|
||||
2. You are legally entitled to grant the licenses in this Agreement.
|
||||
3. If Your employer has rights to intellectual property that You create, You have either (i) received permission from Your employer to make the Contribution on behalf of that employer, or (ii) Your employer has waived such rights for the Contribution.
|
||||
4. To the best of Your knowledge, the Contribution does not violate any third-party rights, including copyright, patent, trademark, or trade secret.
|
||||
|
||||
You agree to notify Us of any facts or circumstances of which you become aware that would make these representations inaccurate in any respect.
|
||||
|
||||
## 6. Our Licensing Rights
|
||||
|
||||
You acknowledge that We may license the Material, including Your Contribution, under different license terms. We intend to distribute a community edition of DiskNext under a free and open-source license. We also reserve the right to distribute a Pro edition and other commercial versions of the Material, including Your Contribution, under a proprietary license at Our sole discretion.
|
||||
|
||||
## 7. Disclaimer of Warranty
|
||||
|
||||
THE CONTRIBUTION IS PROVIDED "AS IS" AND WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
|
||||
## 8. Limitation of Liability
|
||||
|
||||
TO THE MAXIMUM EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT WILL YOU OR WE BE LIABLE FOR ANY LOSS OF PROFITS, LOSS OF ANTICIPATED SAVINGS, LOSS OF DATA, OR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF THIS AGREEMENT, REGARDLESS OF THE LEGAL THEORY UPON WHICH THE CLAIM IS BASED.
|
||||
|
||||
## 9. Term
|
||||
|
||||
This Agreement is effective on the date You accept it and shall continue for the full term of the copyrights and patents licensed herein. This Agreement is irrevocable.
|
||||
|
||||
## 10. Miscellaneous
|
||||
|
||||
**10.1 Governing Law:** This Agreement shall be governed by the laws of the People's Republic of China, excluding its conflict of law provisions.
|
||||
|
||||
**10.2 Entire Agreement:** This Agreement sets out the entire agreement between You and Us for Your Contributions and supersedes all prior communications and understandings.
|
||||
|
||||
**10.3 Assignment:** We may assign Our rights and obligations under this Agreement at Our sole discretion. This Agreement will be binding upon and will inure to the benefit of the parties, their successors, and permitted assigns.
|
||||
|
||||
**10.4 Severability:** If any provision of this Agreement is found to be void or unenforceable, it will be replaced with a provision that comes closest to the meaning of the original and is enforceable.
|
||||
|
||||
---
|
||||
|
||||
## To Accept This Agreement
|
||||
|
||||
Please provide the following information to signify your acceptance.
|
||||
|
||||
### Contributor ("You"):
|
||||
|
||||
- **Date:**
|
||||
- **Full Name:**
|
||||
- **Address:**
|
||||
- **Email:**
|
||||
- **GitHub Username (if applicable):**
|
||||
|
||||
### For DiskNext ("Us"):
|
||||
|
||||
- **Date:**
|
||||
- **[NAME]**
|
||||
- **Owner of DiskNext Org**
|
||||
12
main.py
12
main.py
@@ -1,25 +1,29 @@
|
||||
from typing import NoReturn
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from utils.conf import appmeta
|
||||
from utils.http.http_exceptions import raise_internal_error
|
||||
from utils.lifespan import lifespan
|
||||
from models.database import init_db
|
||||
from models.migration import migration
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.migration import migration
|
||||
from utils import JWT
|
||||
from routers import router
|
||||
from service.redis import RedisManager
|
||||
from loguru import logger as l
|
||||
|
||||
async def _init_db() -> None:
|
||||
"""初始化数据库连接引擎"""
|
||||
await DatabaseManager.init(appmeta.database_url, debug=appmeta.debug)
|
||||
|
||||
# 添加初始化数据库启动项
|
||||
lifespan.add_startup(init_db)
|
||||
lifespan.add_startup(_init_db)
|
||||
lifespan.add_startup(migration)
|
||||
lifespan.add_startup(JWT.load_secret_key)
|
||||
lifespan.add_startup(RedisManager.connect)
|
||||
|
||||
# 添加关闭项
|
||||
lifespan.add_shutdown(DatabaseManager.close)
|
||||
lifespan.add_shutdown(RedisManager.disconnect)
|
||||
|
||||
# 创建应用实例并设置元数据
|
||||
|
||||
@@ -4,50 +4,79 @@ from uuid import UUID
|
||||
from fastapi import Depends
|
||||
import jwt
|
||||
|
||||
from models.user import User
|
||||
from sqlmodels.user import JWTPayload, User, UserStatus
|
||||
from utils import JWT
|
||||
from .dependencies import SessionDep
|
||||
from utils import http_exceptions
|
||||
from service.redis import RedisManager
|
||||
from service.redis.user_ban_store import UserBanStore
|
||||
|
||||
async def auth_required(
|
||||
|
||||
async def jwt_required(
|
||||
session: SessionDep,
|
||||
token: Annotated[str, Depends(JWT.oauth2_scheme)],
|
||||
) -> User:
|
||||
) -> JWTPayload:
|
||||
"""
|
||||
AuthRequired 需要登录
|
||||
验证 JWT 并返回 claims。
|
||||
|
||||
封禁检查策略:
|
||||
1. JWT 内嵌 status 检查(签发时快照)
|
||||
2. Redis 黑名单检查(即时封禁,如果 Redis 可用)
|
||||
3. Redis 不可用时查库检查 status(降级方案)
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
|
||||
user_id = payload.get("sub")
|
||||
|
||||
if user_id is None:
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
user_id = UUID(user_id)
|
||||
|
||||
# 从数据库获取用户信息
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("账号或密码错误")
|
||||
|
||||
return user
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
claims = JWTPayload(
|
||||
sub=payload["sub"],
|
||||
jti=payload["jti"],
|
||||
status=payload["status"],
|
||||
group=payload["group"],
|
||||
)
|
||||
except (jwt.InvalidTokenError, KeyError, ValueError):
|
||||
http_exceptions.raise_unauthorized("凭据过期或无效")
|
||||
|
||||
# 1. JWT 内嵌 status 检查
|
||||
if claims.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 2. 即时封禁检查
|
||||
user_id_str = str(claims.sub)
|
||||
if RedisManager.is_available():
|
||||
# Redis 可用:查黑名单
|
||||
if await UserBanStore.is_banned(user_id_str):
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
else:
|
||||
# Redis 不可用:查库(仅 status 字段,不加载关系)
|
||||
user = await User.get(session, User.id == claims.sub)
|
||||
if not user or user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
return claims
|
||||
|
||||
|
||||
async def admin_required(
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> User:
|
||||
claims: Annotated[JWTPayload, Depends(jwt_required)],
|
||||
) -> JWTPayload:
|
||||
"""
|
||||
验证是否为管理员。
|
||||
验证管理员权限(仅读取 JWT claims,不查库)。
|
||||
|
||||
使用方法:
|
||||
>>> APIRouter(dependencies=[Depends(admin_required)])
|
||||
"""
|
||||
group = await user.awaitable_attrs.group
|
||||
if group.admin:
|
||||
return user
|
||||
raise http_exceptions.raise_forbidden("Admin Required")
|
||||
if not claims.group.admin:
|
||||
http_exceptions.raise_forbidden("Admin Required")
|
||||
return claims
|
||||
|
||||
|
||||
async def auth_required(
|
||||
session: SessionDep,
|
||||
claims: Annotated[JWTPayload, Depends(jwt_required)],
|
||||
) -> User:
|
||||
"""验证 JWT + 从数据库加载完整 User(含 group 关系)"""
|
||||
user = await User.get(session, User.id == claims.sub, load=User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
return user
|
||||
|
||||
|
||||
def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
|
||||
@@ -66,4 +95,4 @@ def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
|
||||
http_exceptions.raise_unauthorized("Download token required")
|
||||
return jti, UUID(payload["file_id"]), UUID(payload["owner_id"])
|
||||
except jwt.InvalidTokenError:
|
||||
http_exceptions.raise_unauthorized("Download token required")
|
||||
http_exceptions.raise_unauthorized("Download token required")
|
||||
|
||||
@@ -6,22 +6,24 @@ FastAPI 依赖注入
|
||||
- TimeFilterRequestDep: 时间筛选查询依赖(用于 count 等统计接口)
|
||||
- TableViewRequestDep: 分页排序查询依赖(包含时间筛选 + 分页排序)
|
||||
- UserFilterParamsDep: 用户筛选参数依赖(用于管理员用户列表)
|
||||
- require_captcha: 验证码校验依赖注入工厂
|
||||
"""
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Query
|
||||
from fastapi import Depends, Form, Query
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.database import get_session
|
||||
from models.mixin import TimeFilterRequest, TableViewRequest
|
||||
from models.user import UserFilterParams, UserStatus
|
||||
from sqlmodels.database_connection import DatabaseManager
|
||||
from sqlmodels.mixin import TimeFilterRequest, TableViewRequest
|
||||
from sqlmodels.user import UserFilterParams, UserStatus
|
||||
|
||||
|
||||
# --- 数据库会话依赖 ---
|
||||
|
||||
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(get_session)]
|
||||
SessionDep: TypeAlias = Annotated[AsyncSession, Depends(DatabaseManager.get_session)]
|
||||
"""数据库会话依赖,用于路由函数中获取数据库会话"""
|
||||
|
||||
|
||||
@@ -79,14 +81,14 @@ TableViewRequestDep: TypeAlias = Annotated[TableViewRequest, Depends(_get_table_
|
||||
|
||||
async def _get_user_filter_params(
|
||||
group_id: Annotated[UUID | None, Query(description="按用户组UUID筛选")] = None,
|
||||
username: Annotated[str | None, Query(max_length=50, description="按用户名模糊搜索")] = None,
|
||||
email: Annotated[str | None, Query(max_length=50, description="按邮箱模糊搜索")] = None,
|
||||
nickname: Annotated[str | None, Query(max_length=50, description="按昵称模糊搜索")] = None,
|
||||
status: Annotated[UserStatus | None, Query(description="按用户状态筛选")] = None,
|
||||
) -> UserFilterParams:
|
||||
"""解析用户过滤查询参数"""
|
||||
return UserFilterParams(
|
||||
group_id=group_id,
|
||||
username_contains=username,
|
||||
email_contains=email,
|
||||
nickname_contains=nickname,
|
||||
status=status,
|
||||
)
|
||||
@@ -94,3 +96,30 @@ async def _get_user_filter_params(
|
||||
|
||||
UserFilterParamsDep: TypeAlias = Annotated[UserFilterParams, Depends(_get_user_filter_params)]
|
||||
"""获取用户筛选参数的依赖(用于管理员用户列表)"""
|
||||
|
||||
|
||||
# --- 验证码校验依赖 ---
|
||||
|
||||
def require_captcha(scene: 'CaptchaScene') -> Callable[..., Awaitable[None]]:
|
||||
"""
|
||||
验证码校验依赖注入工厂。
|
||||
|
||||
根据场景查询数据库设置,判断是否需要验证码。
|
||||
需要则校验前端提交的 captcha_code,失败则抛出异常。
|
||||
|
||||
使用方式::
|
||||
|
||||
@router.post('/session', dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))])
|
||||
async def login(...): ...
|
||||
|
||||
:param scene: 验证码使用场景(LOGIN / REGISTER / FORGET)
|
||||
"""
|
||||
from service.captcha import CaptchaScene, verify_captcha_if_needed
|
||||
|
||||
async def _verify_captcha(
|
||||
session: SessionDep,
|
||||
captcha_code: Annotated[str | None, Form()] = None,
|
||||
) -> None:
|
||||
await verify_captcha_if_needed(session, scene, captcha_code)
|
||||
|
||||
return _verify_captcha
|
||||
|
||||
@@ -1,456 +0,0 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import String, inspect
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
|
||||
from models.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的default_factory丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
|
||||
通过从 MRO 中查找父类的原始字段定义来获取正确的 default_factory,
|
||||
遵循单一真相原则(不硬编码 default_factory)。
|
||||
|
||||
需要修复的字段:
|
||||
- id: 主键(从父类获取 default_factory)
|
||||
- created_at: 创建时间戳(从父类获取 default_factory)
|
||||
- updated_at: 更新时间戳(从父类获取 default_factory)
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被 InstrumentedAttribute 污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 污染的
|
||||
if not isinstance(field_info.default, InstrumentedAttribute):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
# 动态检测所有需要修复的字段
|
||||
# 遵循单一真相原则:不硬编码字段列表,而是通过以下条件判断:
|
||||
# 1. default 是 InstrumentedAttribute(被 SQLAlchemy 污染)
|
||||
# 2. 原始定义有 default_factory 或明确的 default 值
|
||||
#
|
||||
# 覆盖场景:
|
||||
# - UUID主键(UUIDTableBaseMixin):id 有 default_factory=uuid.uuid4,需要修复
|
||||
# - int主键(TableBaseMixin):id 用 default=None,不需要修复(数据库自增)
|
||||
# - created_at/updated_at:有 default_factory=now,需要修复
|
||||
# - 外键字段(created_by_id等):有 default=None,需要修复
|
||||
# - 普通字段(name, temperature等):无 default_factory,不需要修复
|
||||
#
|
||||
# MRO 查找保证:
|
||||
# - 在多重继承场景下,MRO 顺序是确定性的
|
||||
# - find_original_field_info 会找到第一个未被污染且有该字段的父类
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute)
|
||||
if not isinstance(current_field.default, InstrumentedAttribute):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, ""),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制SQLModel特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
Example:
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
@@ -5,15 +5,15 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase,
|
||||
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
||||
)
|
||||
from models.base import SQLModelBase
|
||||
from models.setting import (
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.setting import (
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
from models.setting import SettingsType
|
||||
from sqlmodels.setting import SettingsType
|
||||
from utils import http_exceptions
|
||||
from utils.conf import appmeta
|
||||
from .file import admin_file_router
|
||||
|
||||
@@ -5,14 +5,60 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
Policy, PolicyType, User, ResponseBase, ListResponse,
|
||||
from sqlmodels import (
|
||||
JWTPayload, Policy, PolicyType, User, ListResponse,
|
||||
Object, ObjectType, AdminFileResponse, FileBanRequest, )
|
||||
from service.storage import LocalStorageService
|
||||
|
||||
async def _set_ban_recursive(
|
||||
session: AsyncSession,
|
||||
obj: Object,
|
||||
ban: bool,
|
||||
admin_id: UUID,
|
||||
reason: str | None,
|
||||
) -> int:
|
||||
"""
|
||||
递归设置封禁状态,返回受影响对象数量。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param obj: 要封禁/解禁的对象
|
||||
:param ban: True=封禁, False=解禁
|
||||
:param admin_id: 管理员UUID
|
||||
:param reason: 封禁原因
|
||||
:return: 受影响的对象数量
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# 如果是文件夹,先递归处理子对象
|
||||
if obj.is_folder:
|
||||
children = await Object.get(
|
||||
session,
|
||||
Object.parent_id == obj.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
for child in children:
|
||||
count += await _set_ban_recursive(session, child, ban, admin_id, reason)
|
||||
|
||||
# 设置当前对象
|
||||
obj.is_banned = ban
|
||||
if ban:
|
||||
obj.banned_at = datetime.now()
|
||||
obj.banned_by = admin_id
|
||||
obj.ban_reason = reason
|
||||
else:
|
||||
obj.banned_at = None
|
||||
obj.banned_by = None
|
||||
obj.ban_reason = None
|
||||
|
||||
await obj.save(session)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
admin_file_router = APIRouter(
|
||||
prefix="/file",
|
||||
tags=["admin", "admin_file"],
|
||||
@@ -118,45 +164,32 @@ async def router_admin_preview_file(
|
||||
path='/ban/{file_id}',
|
||||
summary='封禁/解禁文件',
|
||||
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_ban_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
request: FileBanRequest,
|
||||
admin: Annotated[User, Depends(admin_required)],
|
||||
) -> ResponseBase:
|
||||
claims: Annotated[JWTPayload, Depends(admin_required)],
|
||||
) -> None:
|
||||
"""
|
||||
封禁或解禁文件。封禁后用户无法访问该文件。
|
||||
封禁或解禁文件/文件夹。封禁后用户无法访问该文件。
|
||||
封禁文件夹时会级联封禁所有子对象。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param file_id: 文件UUID
|
||||
:param request: 封禁请求
|
||||
:param admin: 当前管理员
|
||||
:param claims: 当前管理员 JWT claims
|
||||
:return: 封禁结果
|
||||
"""
|
||||
file_obj = await Object.get(session, Object.id == file_id)
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
file_obj.is_banned = request.is_banned
|
||||
if request.is_banned:
|
||||
file_obj.banned_at = datetime.now()
|
||||
file_obj.banned_by = admin.id
|
||||
file_obj.ban_reason = request.reason
|
||||
else:
|
||||
file_obj.banned_at = None
|
||||
file_obj.banned_by = None
|
||||
file_obj.ban_reason = None
|
||||
count = await _set_ban_recursive(session, file_obj, request.ban, claims.sub, request.reason)
|
||||
|
||||
file_obj = await file_obj.save(session)
|
||||
|
||||
action = "封禁" if request.is_banned else "解禁"
|
||||
l.info(f"管理员{action}了文件: {file_obj.name}")
|
||||
return ResponseBase(data={
|
||||
"id": str(file_obj.id),
|
||||
"is_banned": file_obj.is_banned,
|
||||
})
|
||||
action = "封禁" if request.ban else "解禁"
|
||||
l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象")
|
||||
|
||||
|
||||
@admin_file_router.delete(
|
||||
@@ -164,12 +197,13 @@ async def router_admin_ban_file(
|
||||
summary='删除文件',
|
||||
description='Delete file by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
delete_physical: bool = True,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
删除文件。
|
||||
|
||||
@@ -211,5 +245,4 @@ async def router_admin_delete_file(
|
||||
# 使用条件删除
|
||||
await Object.delete(session, condition=Object.id == file_obj.id)
|
||||
|
||||
l.info(f"管理员删除了文件: {file_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
l.info(f"管理员删除了文件: {file_name}")
|
||||
@@ -5,12 +5,12 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, GroupOptions, )
|
||||
from models.group import (
|
||||
from sqlmodels.group import (
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, )
|
||||
from models.policy import GroupPolicyLink
|
||||
from sqlmodels.policy import GroupPolicyLink
|
||||
|
||||
admin_group_router = APIRouter(
|
||||
prefix="/group",
|
||||
@@ -113,11 +113,12 @@ async def router_admin_get_group_members(
|
||||
summary='创建用户组',
|
||||
description='Create a new user group',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_create_group(
|
||||
session: SessionDep,
|
||||
request: GroupCreateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
创建新的用户组。
|
||||
|
||||
@@ -164,7 +165,6 @@ async def router_admin_create_group(
|
||||
await session.commit()
|
||||
|
||||
l.info(f"管理员创建了用户组: {group.name}")
|
||||
return ResponseBase(data={"id": str(group.id), "name": group.name})
|
||||
|
||||
|
||||
@admin_group_router.patch(
|
||||
@@ -172,12 +172,13 @@ async def router_admin_create_group(
|
||||
summary='更新用户组信息',
|
||||
description='Update user group information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_update_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
request: GroupUpdateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户组ID更新用户组信息。
|
||||
|
||||
@@ -233,8 +234,7 @@ async def router_admin_update_group(
|
||||
session.add(link)
|
||||
await session.commit()
|
||||
|
||||
l.info(f"管理员更新了用户组: {group.name}")
|
||||
return ResponseBase(data={"id": str(group.id)})
|
||||
l.info(f"管理员更新了用户组: {group_id}")
|
||||
|
||||
|
||||
@admin_group_router.delete(
|
||||
@@ -242,11 +242,12 @@ async def router_admin_update_group(
|
||||
summary='删除用户组',
|
||||
description='Delete user group by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户组ID删除用户组。
|
||||
|
||||
@@ -271,5 +272,4 @@ async def router_admin_delete_group(
|
||||
group_name = group.name
|
||||
await Group.delete(session, group)
|
||||
|
||||
l.info(f"管理员删除了用户组: {group_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
l.info(f"管理员删除了用户组: {group_id}")
|
||||
@@ -6,10 +6,10 @@ from sqlmodel import Field
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||
ListResponse, Object, )
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from service.storage import DirectoryCreationError, LocalStorageService
|
||||
|
||||
admin_policy_router = APIRouter(
|
||||
|
||||
@@ -5,7 +5,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase, ListResponse,
|
||||
Share, AdminShareListItem, )
|
||||
|
||||
@@ -80,7 +80,7 @@ async def router_admin_get_share(
|
||||
"score": share.score,
|
||||
"has_password": bool(share.password),
|
||||
"user_id": str(share.user_id),
|
||||
"username": user.username if user else None,
|
||||
"username": user.email if user else None,
|
||||
"object": {
|
||||
"id": str(obj.id),
|
||||
"name": obj.name,
|
||||
|
||||
@@ -5,7 +5,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase, ListResponse,
|
||||
Task, TaskSummary,
|
||||
)
|
||||
@@ -89,7 +89,7 @@ async def router_admin_get_task(
|
||||
"progress": task.progress,
|
||||
"error": task.error,
|
||||
"user_id": str(task.user_id),
|
||||
"username": user.username if user else None,
|
||||
"username": user.email if user else None,
|
||||
"props": props.model_dump() if props else None,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
"updated_at": task.updated_at.isoformat(),
|
||||
|
||||
@@ -6,11 +6,14 @@ from sqlalchemy import func
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep
|
||||
from models import (
|
||||
from service.redis.user_ban_store import UserBanStore
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, Object, ObjectType, )
|
||||
from models.user import (
|
||||
UserAdminUpdateRequest, UserCalibrateResponse,
|
||||
Group, Object, ObjectType, Setting, SettingsType,
|
||||
BatchDeleteRequest,
|
||||
)
|
||||
from sqlmodels.user import (
|
||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse, UserStatus,
|
||||
)
|
||||
from utils import Password, http_exceptions
|
||||
|
||||
@@ -26,19 +29,19 @@ admin_user_router = APIRouter(
|
||||
description='Get user information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase:
|
||||
async def router_admin_get_user(session: SessionDep, user_id: UUID) -> UserPublic:
|
||||
"""
|
||||
根据用户ID获取用户信息,包括用户名、邮箱、注册时间等。
|
||||
|
||||
Args:
|
||||
session(SessionDep): 数据库会话依赖项。
|
||||
user_id (int): 用户ID。
|
||||
user_id (UUID): 用户ID。
|
||||
|
||||
Returns:
|
||||
ResponseBase: 包含用户信息的响应模型。
|
||||
"""
|
||||
user = await User.get_exist_one(session, user_id)
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@admin_user_router.get(
|
||||
@@ -60,7 +63,7 @@ async def router_admin_get_users(
|
||||
:param filter_params: 用户筛选参数(用户组、用户名、昵称、状态)
|
||||
:return: 分页用户列表
|
||||
"""
|
||||
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view)
|
||||
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view, load=User.group)
|
||||
return ListResponse(
|
||||
items=[user.to_public() for user in result.items],
|
||||
count=result.count,
|
||||
@@ -75,22 +78,33 @@ async def router_admin_get_users(
|
||||
)
|
||||
async def router_admin_create_user(
|
||||
session: SessionDep,
|
||||
user: User,
|
||||
) -> ResponseBase:
|
||||
request: UserAdminCreateRequest,
|
||||
) -> UserPublic:
|
||||
"""
|
||||
创建一个新的用户,设置用户名、密码等信息。
|
||||
创建一个新的用户,设置邮箱、密码、用户组等信息。
|
||||
|
||||
Returns:
|
||||
ResponseBase: 包含创建结果的响应模型。
|
||||
:param session: 数据库会话
|
||||
:param request: 创建用户请求 DTO
|
||||
:return: 创建结果
|
||||
"""
|
||||
existing_user = await User.get(session, User.username == user.username)
|
||||
existing_user = await User.get(session, User.email == request.email)
|
||||
if existing_user:
|
||||
return ResponseBase(
|
||||
code=400,
|
||||
msg="User with this username already exists."
|
||||
)
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 验证用户组存在
|
||||
group = await Group.get(session, Group.id == request.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
||||
|
||||
user = User(
|
||||
email=request.email,
|
||||
password=Password.hash(request.password),
|
||||
nickname=request.nickname,
|
||||
group_id=request.group_id,
|
||||
status=request.status,
|
||||
)
|
||||
user = await user.save(session)
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@admin_user_router.patch(
|
||||
@@ -98,12 +112,13 @@ async def router_admin_create_user(
|
||||
summary='更新用户信息',
|
||||
description='Update user information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204
|
||||
)
|
||||
async def router_admin_update_user(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
request: UserAdminUpdateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户ID更新用户信息。
|
||||
|
||||
@@ -116,8 +131,15 @@ async def router_admin_update_user(
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 默认管理员(用户名为 admin)不允许更改用户组
|
||||
if request.group_id and user.username == "admin" and request.group_id != user.group_id:
|
||||
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
||||
default_admin_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||
)
|
||||
if (request.group_id
|
||||
and default_admin_setting
|
||||
and default_admin_setting.value == str(user_id)
|
||||
and request.group_id != user.group_id):
|
||||
http_exceptions.raise_forbidden("默认管理员不允许更改用户组")
|
||||
|
||||
# 如果更新用户组,验证新组存在
|
||||
@@ -138,43 +160,50 @@ async def router_admin_update_user(
|
||||
if len(update_data['two_factor']) != 32:
|
||||
raise HTTPException(status_code=400, detail="两步验证密钥必须为32位字符串")
|
||||
|
||||
# 记录旧 status 以便检测变更
|
||||
old_status = user.status
|
||||
|
||||
# 更新字段
|
||||
for key, value in update_data.items():
|
||||
setattr(user, key, value)
|
||||
user = await user.save(session)
|
||||
|
||||
l.info(f"管理员更新了用户: {user.username}")
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
# 封禁状态变更 → 更新 BanStore
|
||||
new_status = user.status
|
||||
if old_status == UserStatus.ACTIVE and new_status != UserStatus.ACTIVE:
|
||||
await UserBanStore.ban(str(user_id))
|
||||
elif old_status != UserStatus.ACTIVE and new_status == UserStatus.ACTIVE:
|
||||
await UserBanStore.unban(str(user_id))
|
||||
|
||||
l.info(f"管理员更新了用户: {request.email}")
|
||||
|
||||
|
||||
@admin_user_router.delete(
|
||||
path='/{user_id}',
|
||||
summary='删除用户',
|
||||
description='Delete user by ID',
|
||||
path='/',
|
||||
summary='删除用户(支持批量)',
|
||||
description='Delete users by ID list',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_user(
|
||||
async def router_admin_delete_users(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
) -> ResponseBase:
|
||||
request: BatchDeleteRequest,
|
||||
) -> None:
|
||||
"""
|
||||
根据用户ID删除用户及其所有数据。
|
||||
批量删除用户及其所有数据。
|
||||
|
||||
注意: 这是一个危险操作,会级联删除用户的所有文件、分享、任务等。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:return: 删除结果
|
||||
:param request: 批量删除请求,包含待删除用户的 UUID 列表
|
||||
:return: 删除结果(已删除数 / 总请求数)
|
||||
"""
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
username = user.username
|
||||
await User.delete(session, user)
|
||||
|
||||
l.info(f"管理员删除了用户: {username}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
deleted = 0
|
||||
for uid in request.ids:
|
||||
user = await User.get(session, User.id == uid)
|
||||
if user:
|
||||
await User.delete(session, user)
|
||||
l.info(f"管理员删除了用户: {user.email}")
|
||||
|
||||
|
||||
@admin_user_router.post(
|
||||
@@ -186,7 +215,7 @@ async def router_admin_delete_user(
|
||||
async def router_admin_calibrate_storage(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> UserCalibrateResponse:
|
||||
"""
|
||||
重新计算用户的已用存储空间。
|
||||
|
||||
@@ -228,5 +257,5 @@ async def router_admin_calibrate_storage(
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
l.info(f"管理员校准了用户存储: {user.username}, 差值: {actual_storage - previous_storage}")
|
||||
return ResponseBase(data=response.model_dump())
|
||||
l.info(f"管理员校准了用户存储: {user.email}, 差值: {actual_storage - previous_storage}")
|
||||
return response
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
import service.oauth
|
||||
from utils import http_exceptions
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
DirectoryCreateRequest,
|
||||
DirectoryResponse,
|
||||
Object,
|
||||
@@ -14,50 +16,28 @@ from models import (
|
||||
User,
|
||||
ResponseBase,
|
||||
)
|
||||
from utils import http_exceptions
|
||||
|
||||
directory_router = APIRouter(
|
||||
prefix="/directory",
|
||||
tags=["directory"]
|
||||
)
|
||||
|
||||
@directory_router.get(
|
||||
path="/{path:path}",
|
||||
summary="获取目录内容",
|
||||
)
|
||||
async def router_directory_get(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
path: str
|
||||
|
||||
async def _get_directory_response(
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
folder: Object,
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取目录内容
|
||||
|
||||
路径必须以用户名或 `.crash` 开头,如 /api/directory/admin 或 /api/directory/admin/docs
|
||||
`.crash` 代表回收站,也就意味着用户名禁止为 `.crash`
|
||||
构建目录响应 DTO
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param path: 目录路径(必须以用户名开头)
|
||||
:return: 目录内容
|
||||
:param user_id: 用户UUID
|
||||
:param folder: 目录对象
|
||||
:return: DirectoryResponse
|
||||
"""
|
||||
# 路径必须以用户名开头
|
||||
path = path.strip("/")
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空,请使用 /{username} 格式")
|
||||
|
||||
path_parts = path.split("/")
|
||||
if path_parts[0] != user.username:
|
||||
raise HTTPException(status_code=403, detail="无权访问其他用户的目录")
|
||||
|
||||
folder = await Object.get_by_path(session, user.id, "/" + path, user.username)
|
||||
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="目录不存在")
|
||||
|
||||
if not folder.is_folder:
|
||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||
|
||||
children = await Object.get_children(session, user.id, folder.id)
|
||||
children = await Object.get_children(session, user_id, folder.id)
|
||||
policy = await folder.awaitable_attrs.policy
|
||||
|
||||
objects = [
|
||||
@@ -67,8 +47,8 @@ async def router_directory_get(
|
||||
thumb=False,
|
||||
size=child.size,
|
||||
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
|
||||
date=child.updated_at,
|
||||
create_date=child.created_at,
|
||||
created_at=child.created_at,
|
||||
updated_at=child.updated_at,
|
||||
source_enabled=False,
|
||||
)
|
||||
for child in children
|
||||
@@ -89,7 +69,74 @@ async def router_directory_get(
|
||||
)
|
||||
|
||||
|
||||
@directory_router.put(
|
||||
@directory_router.get(
|
||||
path="/",
|
||||
summary="获取根目录内容",
|
||||
)
|
||||
async def router_directory_root(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取当前用户的根目录内容
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:return: 根目录内容
|
||||
"""
|
||||
root = await Object.get_root(session, user.id)
|
||||
if not root:
|
||||
raise HTTPException(status_code=404, detail="根目录不存在")
|
||||
|
||||
if root.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
return await _get_directory_response(session, user.id, root)
|
||||
|
||||
|
||||
@directory_router.get(
|
||||
path="/{path:path}",
|
||||
summary="获取目录内容",
|
||||
)
|
||||
async def router_directory_get(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
path: str
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取目录内容
|
||||
|
||||
路径从用户根目录开始,不包含用户名前缀。
|
||||
如 /api/v1/directory/docs 表示根目录下的 docs 目录。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param path: 目录路径(从根目录开始的相对路径)
|
||||
:return: 目录内容
|
||||
"""
|
||||
path = path.strip("/")
|
||||
if not path:
|
||||
# 空路径交给根目录端点处理(理论上不会到达这里)
|
||||
root = await Object.get_root(session, user.id)
|
||||
if not root:
|
||||
raise HTTPException(status_code=404, detail="根目录不存在")
|
||||
return await _get_directory_response(session, user.id, root)
|
||||
|
||||
folder = await Object.get_by_path(session, user.id, "/" + path)
|
||||
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="目录不存在")
|
||||
|
||||
if not folder.is_folder:
|
||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||
|
||||
if folder.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
return await _get_directory_response(session, user.id, folder)
|
||||
|
||||
|
||||
@directory_router.post(
|
||||
path="/",
|
||||
summary="创建目录",
|
||||
)
|
||||
@@ -123,6 +170,9 @@ async def router_directory_create(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父路径不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名对象
|
||||
existing = await Object.get(
|
||||
session,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
download_router = APIRouter(
|
||||
|
||||
@@ -18,7 +18,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required, verify_download_token
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
CreateUploadSessionRequest,
|
||||
Object,
|
||||
@@ -91,6 +91,9 @@ async def create_upload_session(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
@@ -100,7 +103,7 @@ async def create_upload_session(
|
||||
# 验证文件大小限制
|
||||
if policy.max_size > 0 and request.file_size > policy.max_size:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
status_code=413,
|
||||
detail=f"文件大小超过限制 ({policy.max_size} bytes)"
|
||||
)
|
||||
|
||||
@@ -221,30 +224,40 @@ async def upload_chunk(
|
||||
upload_session.uploaded_size += len(content)
|
||||
upload_session = await upload_session.save(session)
|
||||
|
||||
# 检查是否完成
|
||||
# 在后续可能的 commit 前保存需要的属性
|
||||
is_complete = upload_session.is_complete
|
||||
uploaded_chunks = upload_session.uploaded_chunks
|
||||
total_chunks = upload_session.total_chunks
|
||||
file_object_id: UUID | None = None
|
||||
|
||||
if is_complete:
|
||||
# 保存 upload_session 属性(commit 后会过期)
|
||||
file_name = upload_session.file_name
|
||||
uploaded_size = upload_session.uploaded_size
|
||||
storage_path = upload_session.storage_path
|
||||
upload_session_id = upload_session.id
|
||||
parent_id = upload_session.parent_id
|
||||
policy_id = upload_session.policy_id
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=upload_session.storage_path,
|
||||
size=upload_session.uploaded_size,
|
||||
policy_id=upload_session.policy_id,
|
||||
storage_path=storage_path,
|
||||
size=uploaded_size,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
physical_file = await physical_file.save(session, commit=False)
|
||||
|
||||
# 创建 Object 记录
|
||||
file_object = Object(
|
||||
name=upload_session.file_name,
|
||||
name=file_name,
|
||||
type=ObjectType.FILE,
|
||||
size=upload_session.uploaded_size,
|
||||
size=uploaded_size,
|
||||
physical_file_id=physical_file.id,
|
||||
upload_session_id=str(upload_session.id),
|
||||
parent_id=upload_session.parent_id,
|
||||
upload_session_id=str(upload_session_id),
|
||||
parent_id=parent_id,
|
||||
owner_id=user_id,
|
||||
policy_id=upload_session.policy_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
file_object = await file_object.save(session, commit=False)
|
||||
file_object_id = file_object.id
|
||||
@@ -252,18 +265,18 @@ async def upload_chunk(
|
||||
# 删除上传会话(使用条件删除)
|
||||
await UploadSession.delete(
|
||||
session,
|
||||
condition=UploadSession.id == upload_session.id,
|
||||
condition=UploadSession.id == upload_session_id,
|
||||
commit=False
|
||||
)
|
||||
|
||||
# 统一提交所有更改
|
||||
await session.commit()
|
||||
|
||||
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")
|
||||
l.info(f"文件上传完成: {file_name}, size={uploaded_size}, id={file_object_id}")
|
||||
|
||||
return UploadChunkResponse(
|
||||
uploaded_chunks=upload_session.uploaded_chunks if not is_complete else upload_session.total_chunks,
|
||||
total_chunks=upload_session.total_chunks,
|
||||
uploaded_chunks=uploaded_chunks if not is_complete else total_chunks,
|
||||
total_chunks=total_chunks,
|
||||
is_complete=is_complete,
|
||||
object_id=file_object_id,
|
||||
)
|
||||
@@ -368,6 +381,9 @@ async def create_download_token_endpoint(
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
|
||||
if file_obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
token = create_download_token(file_id, user.id)
|
||||
|
||||
l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}")
|
||||
@@ -410,6 +426,9 @@ async def download_file(
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
|
||||
if file_obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 预加载 physical_file 关系以获取存储路径
|
||||
physical_file = await file_obj.awaitable_attrs.physical_file
|
||||
if not physical_file or not physical_file.storage_path:
|
||||
@@ -470,6 +489,9 @@ async def create_empty_file(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名文件
|
||||
existing = await Object.get(
|
||||
session,
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from models import MCPRequestBase, MCPResponseBase, MCPMethod
|
||||
|
||||
# MCP 路由
|
||||
MCP_router = APIRouter(
|
||||
prefix='/mcp',
|
||||
tags=["mcp"],
|
||||
)
|
||||
|
||||
@MCP_router.get(
|
||||
"/",
|
||||
)
|
||||
async def mcp_root(
|
||||
param: MCPRequestBase
|
||||
):
|
||||
match param.method:
|
||||
case MCPMethod.PING:
|
||||
return MCPResponseBase(result="pong", **param.model_dump())
|
||||
@@ -14,7 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
Object,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
@@ -26,10 +27,11 @@ from models import (
|
||||
PhysicalFile,
|
||||
Policy,
|
||||
PolicyType,
|
||||
ResponseBase,
|
||||
User,
|
||||
)
|
||||
from models import ResponseBase
|
||||
from service.storage import LocalStorageService
|
||||
from utils import http_exceptions
|
||||
|
||||
object_router = APIRouter(
|
||||
prefix="/object",
|
||||
@@ -59,15 +61,22 @@ async def _delete_object_recursive(
|
||||
"""
|
||||
deleted_count = 0
|
||||
|
||||
if obj.is_folder:
|
||||
# 在任何数据库操作前保存所有需要的属性,避免 commit 后对象过期导致懒加载失败
|
||||
obj_id = obj.id
|
||||
obj_name = obj.name
|
||||
obj_is_folder = obj.is_folder
|
||||
obj_is_file = obj.is_file
|
||||
obj_physical_file_id = obj.physical_file_id
|
||||
|
||||
if obj_is_folder:
|
||||
# 递归删除子对象
|
||||
children = await Object.get_children(session, user_id, obj.id)
|
||||
children = await Object.get_children(session, user_id, obj_id)
|
||||
for child in children:
|
||||
deleted_count += await _delete_object_recursive(session, child, user_id)
|
||||
|
||||
# 如果是文件,处理物理文件引用
|
||||
if obj.is_file and obj.physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
||||
if obj_is_file and obj_physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj_physical_file_id)
|
||||
if physical_file:
|
||||
# 减少引用计数
|
||||
new_count = physical_file.decrement_reference()
|
||||
@@ -81,11 +90,11 @@ async def _delete_object_recursive(
|
||||
await storage_service.move_to_trash(
|
||||
source_path=physical_file.storage_path,
|
||||
user_id=user_id,
|
||||
object_id=obj.id,
|
||||
object_id=obj_id,
|
||||
)
|
||||
l.debug(f"物理文件已移动到回收站: {obj.name}")
|
||||
l.debug(f"物理文件已移动到回收站: {obj_name}")
|
||||
except Exception as e:
|
||||
l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}")
|
||||
l.warning(f"移动物理文件到回收站失败: {obj_name}, 错误: {e}")
|
||||
|
||||
# 删除 PhysicalFile 记录
|
||||
await PhysicalFile.delete(session, physical_file)
|
||||
@@ -95,8 +104,8 @@ async def _delete_object_recursive(
|
||||
await physical_file.save(session)
|
||||
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
|
||||
|
||||
# 删除数据库记录
|
||||
await Object.delete(session, obj)
|
||||
# 使用条件删除,避免访问过期的 obj 实例
|
||||
await Object.delete(session, condition=Object.id == obj_id)
|
||||
deleted_count += 1
|
||||
|
||||
return deleted_count
|
||||
@@ -168,6 +177,97 @@ async def _copy_object_recursive(
|
||||
return copied_count, new_ids
|
||||
|
||||
|
||||
@object_router.post(
|
||||
path='/',
|
||||
summary='创建空白文件',
|
||||
description='在指定目录下创建空白文件。',
|
||||
)
|
||||
async def router_object_create(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: CreateFileRequest,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
创建空白文件端点
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param request: 创建文件请求(parent_id, name)
|
||||
:return: 创建结果
|
||||
"""
|
||||
user_id = user.id
|
||||
|
||||
# 验证文件名
|
||||
if not request.name or '/' in request.name or '\\' in request.name:
|
||||
raise HTTPException(status_code=400, detail="无效的文件名")
|
||||
|
||||
# 验证父目录
|
||||
parent = await Object.get(session, Object.id == request.parent_id)
|
||||
if not parent or parent.owner_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="父目录不存在")
|
||||
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名文件
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user_id) &
|
||||
(Object.parent_id == parent.id) &
|
||||
(Object.name == request.name)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="同名文件已存在")
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
|
||||
parent_id = parent.id
|
||||
|
||||
# 生成存储路径并创建空文件
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
||||
user_id=user_id,
|
||||
original_filename=request.name,
|
||||
)
|
||||
await storage_service.create_empty_file(full_path)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
size=0,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
physical_file = await physical_file.save(session)
|
||||
|
||||
# 创建 Object 记录
|
||||
file_object = Object(
|
||||
name=request.name,
|
||||
type=ObjectType.FILE,
|
||||
size=0,
|
||||
physical_file_id=physical_file.id,
|
||||
parent_id=parent_id,
|
||||
owner_id=user_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
await file_object.save(session)
|
||||
|
||||
l.info(f"创建空白文件: {request.name}")
|
||||
|
||||
return ResponseBase()
|
||||
|
||||
|
||||
@object_router.delete(
|
||||
path='/',
|
||||
summary='删除对象',
|
||||
@@ -197,10 +297,7 @@ async def router_object_delete(
|
||||
user_id = user.id
|
||||
deleted_count = 0
|
||||
|
||||
# 处理单个 UUID 或 UUID 列表
|
||||
ids = request.ids if isinstance(request.ids, list) else [request.ids]
|
||||
|
||||
for obj_id in ids:
|
||||
for obj_id in request.ids:
|
||||
obj = await Object.get(session, Object.id == obj_id)
|
||||
if not obj or obj.owner_id != user_id:
|
||||
continue
|
||||
@@ -219,7 +316,7 @@ async def router_object_delete(
|
||||
return ResponseBase(
|
||||
data={
|
||||
"deleted": deleted_count,
|
||||
"total": len(ids),
|
||||
"total": len(request.ids),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -253,6 +350,9 @@ async def router_object_move(
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
if dst.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 存储 dst 的属性,避免后续数据库操作导致 dst 过期后无法访问
|
||||
dst_id = dst.id
|
||||
dst_parent_id = dst.parent_id
|
||||
@@ -264,6 +364,9 @@ async def router_object_move(
|
||||
if not src or src.owner_id != user_id:
|
||||
continue
|
||||
|
||||
if src.is_banned:
|
||||
continue
|
||||
|
||||
# 不能移动根目录
|
||||
if src.parent_id is None:
|
||||
continue
|
||||
@@ -348,6 +451,9 @@ async def router_object_copy(
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
if dst.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
copied_count = 0
|
||||
new_ids: list[UUID] = []
|
||||
|
||||
@@ -356,6 +462,9 @@ async def router_object_copy(
|
||||
if not src or src.owner_id != user_id:
|
||||
continue
|
||||
|
||||
if src.is_banned:
|
||||
continue
|
||||
|
||||
# 不能复制根目录
|
||||
if src.parent_id is None:
|
||||
continue
|
||||
@@ -438,6 +547,9 @@ async def router_object_rename(
|
||||
if obj.owner_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此对象")
|
||||
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 不能重命名根目录
|
||||
if obj.parent_id is None:
|
||||
raise HTTPException(status_code=400, detail="无法重命名根目录")
|
||||
@@ -543,7 +655,7 @@ async def router_object_property_detail(
|
||||
policy_name = policy.name if policy else None
|
||||
|
||||
# 获取分享统计
|
||||
from models import Share
|
||||
from sqlmodels import Share
|
||||
shares = await Share.get(
|
||||
session,
|
||||
Share.object_id == obj.id,
|
||||
|
||||
@@ -7,11 +7,11 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase
|
||||
from models.user import User
|
||||
from models.share import Share, ShareCreateRequest, ShareResponse
|
||||
from models.object import Object
|
||||
from models.mixin import ListResponse, TableViewRequest
|
||||
from sqlmodels import ResponseBase
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.share import Share, ShareCreateRequest, ShareResponse
|
||||
from sqlmodels.object import Object
|
||||
from sqlmodels.mixin import ListResponse, TableViewRequest
|
||||
from utils import http_exceptions
|
||||
from utils.password.pwd import Password
|
||||
|
||||
@@ -72,23 +72,6 @@ def router_share_preview(id: str) -> ResponseBase:
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/doc/{id}',
|
||||
summary='取得Office文档预览地址',
|
||||
description='Get Office document preview URL by ID.',
|
||||
)
|
||||
def router_share_doc(id: str) -> ResponseBase:
|
||||
"""
|
||||
Get Office document preview URL by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the Office document.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the document preview URL.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/content/{id}',
|
||||
summary='获取文本文件内容',
|
||||
@@ -261,6 +244,9 @@ async def router_share_create(
|
||||
if not obj or obj.owner_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="对象不存在或无权限")
|
||||
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 生成分享码
|
||||
code = str(uuid4())
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
from sqlmodels import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
from sqlmodels.setting import CaptchaType
|
||||
from utils import http_exceptions
|
||||
|
||||
site_router = APIRouter(
|
||||
@@ -43,16 +44,43 @@ def router_site_captcha():
|
||||
@site_router.get(
|
||||
path='/config',
|
||||
summary='站点全局配置',
|
||||
description='Get the configuration file.',
|
||||
response_model=ResponseBase,
|
||||
description='获取站点全局配置,包括验证码设置、注册开关等。',
|
||||
)
|
||||
async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
"""
|
||||
Get the configuration file.
|
||||
获取站点全局配置
|
||||
|
||||
Returns:
|
||||
dict: The site configuration.
|
||||
无需认证。前端在初始化时调用此端点获取验证码类型、
|
||||
登录/注册/找回密码是否需要验证码等配置。
|
||||
"""
|
||||
# 批量查询所需设置
|
||||
settings: list[Setting] = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) |
|
||||
(Setting.type == SettingsType.LOGIN) |
|
||||
(Setting.type == SettingsType.REGISTER) |
|
||||
(Setting.type == SettingsType.CAPTCHA),
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
# 构建 name→value 映射
|
||||
s: dict[str, str | None] = {item.name: item.value for item in settings}
|
||||
|
||||
# 根据 captcha_type 选择对应的 public key
|
||||
captcha_type_str = s.get("captcha_type", "default")
|
||||
captcha_type = CaptchaType(captcha_type_str) if captcha_type_str else CaptchaType.DEFAULT
|
||||
captcha_key: str | None = None
|
||||
if captcha_type == CaptchaType.GCAPTCHA:
|
||||
captcha_key = s.get("captcha_ReCaptchaKey") or None
|
||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||
captcha_key = s.get("captcha_CloudflareKey") or None
|
||||
|
||||
return SiteConfigResponse(
|
||||
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
|
||||
title=s.get("siteName") or "DiskNext",
|
||||
register_enabled=s.get("register_enabled") == "1",
|
||||
login_captcha=s.get("login_captcha") == "1",
|
||||
reg_captcha=s.get("reg_captcha") == "1",
|
||||
forget_captcha=s.get("forget_captcha") == "1",
|
||||
captcha_type=captcha_type,
|
||||
captcha_key=captcha_key,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
slave_router = APIRouter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import auth_required
|
||||
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
tag_router = APIRouter(
|
||||
|
||||
@@ -1,73 +1,82 @@
|
||||
from typing import Annotated, Literal
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||||
from loguru import logger
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from loguru import logger
|
||||
|
||||
import models
|
||||
import service
|
||||
import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from utils.JWT import SECRET_KEY
|
||||
from utils import Password, http_exceptions
|
||||
from middleware.dependencies import SessionDep, require_captcha
|
||||
from service.captcha import CaptchaScene
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from .settings import user_settings_router
|
||||
|
||||
user_router = APIRouter(
|
||||
prefix="/user",
|
||||
tags=["user"],
|
||||
)
|
||||
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/user/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
user_router.include_router(user_settings_router)
|
||||
|
||||
class OAuth2PasswordWithExtrasForm:
|
||||
"""
|
||||
扩展 OAuth2 密码表单。
|
||||
|
||||
在标准 username/password 基础上添加 otp_code 字段。
|
||||
captcha_code 由 require_captcha 依赖注入单独处理。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
username: Annotated[str, Form()],
|
||||
password: Annotated[str, Form()],
|
||||
otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
|
||||
):
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.otp_code = otp_code
|
||||
|
||||
|
||||
@user_router.post(
|
||||
path='/session',
|
||||
summary='用户登录',
|
||||
description='User login endpoint. 当用户启用两步验证时,需要传入 otp 参数。',
|
||||
description='用户登录端点,支持验证码校验和两步验证。',
|
||||
dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
|
||||
)
|
||||
async def router_user_session(
|
||||
session: SessionDep,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> models.TokenResponse:
|
||||
form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
用户登录端点。
|
||||
用户登录端点
|
||||
|
||||
根据 OAuth2.1 规范,使用 password grant type 进行登录。
|
||||
当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。
|
||||
表单字段:
|
||||
- username: 用户邮箱
|
||||
- password: 用户密码
|
||||
- captcha_code: 验证码 token(可选,由 require_captcha 依赖校验)
|
||||
- otp_code: 两步验证码(可选,仅在用户启用 2FA 时需要)
|
||||
|
||||
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
|
||||
错误处理:
|
||||
- 400: 需要验证码但未提供
|
||||
- 401: 邮箱/密码错误,或 2FA 验证码错误
|
||||
- 403: 账户已禁用 / 验证码验证失败
|
||||
- 428: 需要两步验证但未提供 otp_code
|
||||
"""
|
||||
username = form_data.username
|
||||
password = form_data.password
|
||||
|
||||
# 从 scopes 中提取 OTP 验证码(OAuth2.1 扩展方式)
|
||||
# scopes 格式可以是 ["otp:123456"] 或 ["123456"]
|
||||
otp_code: str | None = None
|
||||
for scope in form_data.scopes:
|
||||
if scope.startswith("otp:"):
|
||||
otp_code = scope[4:]
|
||||
break
|
||||
elif scope.isdigit() and len(scope) == 6:
|
||||
otp_code = scope
|
||||
break
|
||||
|
||||
result = await service.user.login(
|
||||
return await service.user.login(
|
||||
session,
|
||||
models.LoginRequest(
|
||||
username=username,
|
||||
password=password,
|
||||
two_fa_code=otp_code,
|
||||
sqlmodels.LoginRequest(
|
||||
email=form_data.username,
|
||||
password=form_data.password,
|
||||
two_fa_code=form_data.otp_code,
|
||||
),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@user_router.post(
|
||||
path='/session/refresh',
|
||||
summary="用刷新令牌刷新会话",
|
||||
@@ -75,19 +84,80 @@ async def router_user_session(
|
||||
)
|
||||
async def router_user_session_refresh(
|
||||
session: SessionDep,
|
||||
request, # RefreshTokenRequest
|
||||
) -> models.TokenResponse:
|
||||
http_exceptions.raise_not_implemented()
|
||||
request: sqlmodels.RefreshTokenRequest,
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
使用 refresh_token 签发新的 access_token 和 refresh_token。
|
||||
|
||||
流程:
|
||||
1. 解码 refresh_token JWT
|
||||
2. 验证 token_type 为 refresh
|
||||
3. 验证用户存在且状态正常
|
||||
4. 签发新的 access_token + refresh_token
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 刷新令牌请求
|
||||
:return: 新的 TokenResponse
|
||||
"""
|
||||
|
||||
try:
|
||||
payload = jwt.decode(request.refresh_token, JWT.SECRET_KEY, algorithms=["HS256"])
|
||||
except jwt.InvalidTokenError:
|
||||
http_exceptions.raise_unauthorized("刷新令牌无效或已过期")
|
||||
|
||||
# 验证是 refresh token
|
||||
if payload.get("token_type") != "refresh":
|
||||
http_exceptions.raise_unauthorized("非刷新令牌")
|
||||
|
||||
user_id_str = payload.get("sub")
|
||||
if not user_id_str:
|
||||
http_exceptions.raise_unauthorized("令牌缺少用户标识")
|
||||
|
||||
user_id = UUID(user_id_str)
|
||||
user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id, load=sqlmodels.User.group)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
|
||||
if user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 加载 GroupOptions(获取最新权限)
|
||||
group_options = await sqlmodels.GroupOptions.get(
|
||||
session,
|
||||
sqlmodels.GroupOptions.group_id == user.group_id,
|
||||
)
|
||||
user.group.options = group_options
|
||||
group_claims = sqlmodels.GroupClaims.from_group(user.group)
|
||||
|
||||
# 签发新令牌
|
||||
access_token = JWT.create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
refresh_token = JWT.create_refresh_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
)
|
||||
|
||||
return sqlmodels.TokenResponse(
|
||||
access_token=access_token.access_token,
|
||||
access_expires=access_token.access_expires,
|
||||
refresh_token=refresh_token.refresh_token,
|
||||
refresh_expires=refresh_token.refresh_expires,
|
||||
)
|
||||
|
||||
@user_router.post(
|
||||
path='/',
|
||||
summary='用户注册',
|
||||
description='User registration endpoint.',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_register(
|
||||
session: SessionDep,
|
||||
request: models.RegisterRequest,
|
||||
) -> models.ResponseBase:
|
||||
request: sqlmodels.RegisterRequest,
|
||||
) -> None:
|
||||
"""
|
||||
用户注册端点
|
||||
|
||||
@@ -95,7 +165,7 @@ async def router_user_register(
|
||||
1. 验证用户名唯一性
|
||||
2. 获取默认用户组
|
||||
3. 创建用户记录
|
||||
4. 创建以用户名命名的根目录
|
||||
4. 创建用户根目录(name="/")
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 注册请求
|
||||
@@ -103,62 +173,53 @@ async def router_user_register(
|
||||
:raises HTTPException 400: 用户名已存在
|
||||
:raises HTTPException 500: 默认用户组或存储策略不存在
|
||||
"""
|
||||
# 1. 验证用户名唯一性
|
||||
existing_user = await models.User.get(
|
||||
# 1. 验证邮箱唯一性
|
||||
existing_user = await sqlmodels.User.get(
|
||||
session,
|
||||
models.User.username == request.username
|
||||
sqlmodels.User.email == request.email
|
||||
)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
raise HTTPException(status_code=400, detail="邮箱已存在")
|
||||
|
||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||
default_group_setting: models.Setting | None = await models.Setting.get(
|
||||
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group")
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
default_group_id = UUID(default_group_setting.value)
|
||||
default_group = await models.Group.get(session, models.Group.id == default_group_id)
|
||||
default_group = await sqlmodels.Group.get(session, sqlmodels.Group.id == default_group_id)
|
||||
if not default_group:
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 3. 创建用户
|
||||
hashed_password = Password.hash(request.password)
|
||||
new_user = models.User(
|
||||
username=request.username,
|
||||
new_user = sqlmodels.User(
|
||||
email=request.email,
|
||||
password=hashed_password,
|
||||
group_id=default_group.id,
|
||||
)
|
||||
new_user_id = new_user.id # 在 save 前保存 UUID
|
||||
new_user_username = new_user.username
|
||||
new_user_id = new_user.id
|
||||
await new_user.save(session)
|
||||
|
||||
# 4. 创建以用户名命名的根目录
|
||||
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储")
|
||||
# 4. 创建用户根目录
|
||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
logger.error("默认存储策略不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
await models.Object(
|
||||
name=new_user_username,
|
||||
type=models.ObjectType.FOLDER,
|
||||
await sqlmodels.Object(
|
||||
name="/",
|
||||
type=sqlmodels.ObjectType.FOLDER,
|
||||
owner_id=new_user_id,
|
||||
parent_id=None,
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={
|
||||
"user_id": new_user_id,
|
||||
"username": new_user_username,
|
||||
},
|
||||
msg="注册成功",
|
||||
)
|
||||
|
||||
@user_router.post(
|
||||
path='/code',
|
||||
summary='发送验证码邮件',
|
||||
@@ -166,7 +227,7 @@ async def router_user_register(
|
||||
)
|
||||
def router_user_email_code(
|
||||
reason: Literal['register', 'reset'] = 'register',
|
||||
) -> models.ResponseBase:
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Send a verification code email.
|
||||
|
||||
@@ -180,7 +241,7 @@ def router_user_email_code(
|
||||
summary='初始化QQ登录',
|
||||
description='Initialize QQ login for a user.',
|
||||
)
|
||||
def router_user_qq() -> models.ResponseBase:
|
||||
def router_user_qq() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize QQ login for a user.
|
||||
|
||||
@@ -194,7 +255,7 @@ def router_user_qq() -> models.ResponseBase:
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
)
|
||||
async def router_user_authn(username: str) -> models.ResponseBase:
|
||||
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
|
||||
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@@ -203,7 +264,7 @@ async def router_user_authn(username: str) -> models.ResponseBase:
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
)
|
||||
def router_user_authn_finish(username: str) -> models.ResponseBase:
|
||||
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
@@ -220,7 +281,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase:
|
||||
summary='获取用户主页展示用分享',
|
||||
description='Get user profile for display.',
|
||||
)
|
||||
def router_user_profile(id: str) -> models.ResponseBase:
|
||||
def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user profile for display.
|
||||
|
||||
@@ -237,7 +298,7 @@ def router_user_profile(id: str) -> models.ResponseBase:
|
||||
summary='获取用户头像',
|
||||
description='Get user avatar by ID and size.',
|
||||
)
|
||||
def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
|
||||
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user avatar by ID and size.
|
||||
|
||||
@@ -259,12 +320,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
|
||||
summary='获取用户信息',
|
||||
description='Get user information.',
|
||||
dependencies=[Depends(dependency=auth_required)],
|
||||
response_model=models.UserResponse,
|
||||
response_model=sqlmodels.UserResponse,
|
||||
)
|
||||
async def router_user_me(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserResponse:
|
||||
"""
|
||||
获取用户信息.
|
||||
|
||||
@@ -272,10 +333,10 @@ async def router_user_me(
|
||||
:rtype: ResponseBase
|
||||
"""
|
||||
# 加载 group 及其 options 关系
|
||||
group = await models.Group.get(
|
||||
group = await sqlmodels.Group.get(
|
||||
session,
|
||||
models.Group.id == user.group_id,
|
||||
load=models.Group.options
|
||||
sqlmodels.Group.id == user.group_id,
|
||||
load=sqlmodels.Group.options
|
||||
)
|
||||
|
||||
# 构建 GroupResponse
|
||||
@@ -284,9 +345,9 @@ async def router_user_me(
|
||||
# 异步加载 tags 关系
|
||||
user_tags = await user.awaitable_attrs.tags
|
||||
|
||||
return models.UserResponse(
|
||||
return sqlmodels.UserResponse(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
status=user.status,
|
||||
score=user.score,
|
||||
nickname=user.nickname,
|
||||
@@ -304,30 +365,26 @@ async def router_user_me(
|
||||
)
|
||||
async def router_user_storage(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserStorageResponse:
|
||||
"""
|
||||
获取用户存储空间信息。
|
||||
|
||||
返回值:
|
||||
- used: 已使用空间(字节)
|
||||
- free: 剩余空间(字节)
|
||||
- total: 总容量(字节)= 用户组容量
|
||||
"""
|
||||
# 获取用户组的基础存储容量
|
||||
group = await models.Group.get(session, models.Group.id == user.group_id)
|
||||
group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=500, detail="用户组不存在")
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
|
||||
# [TODO] 总空间加上用户购买的额外空间
|
||||
|
||||
total: int = group.max_storage
|
||||
used: int = user.storage
|
||||
free: int = max(0, total - used)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={
|
||||
"used": used,
|
||||
"free": free,
|
||||
"total": total,
|
||||
}
|
||||
return sqlmodels.UserStorageResponse(
|
||||
used=used,
|
||||
free=free,
|
||||
total=total,
|
||||
)
|
||||
|
||||
@user_router.put(
|
||||
@@ -338,8 +395,8 @@ async def router_user_storage(
|
||||
)
|
||||
async def router_user_authn_start(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize WebAuthn login for a user.
|
||||
|
||||
@@ -347,30 +404,30 @@ async def router_user_authn_start(
|
||||
dict: A dictionary containing WebAuthn initialization information.
|
||||
"""
|
||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
||||
authn_setting = await models.Setting.get(
|
||||
authn_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "authn") & (models.Setting.name == "authn_enabled")
|
||||
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
|
||||
)
|
||||
if not authn_setting or authn_setting.value != "1":
|
||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
||||
|
||||
site_url_setting = await models.Setting.get(
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "basic") & (models.Setting.name == "siteURL")
|
||||
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteURL")
|
||||
)
|
||||
site_title_setting = await models.Setting.get(
|
||||
site_title_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "basic") & (models.Setting.name == "siteTitle")
|
||||
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteTitle")
|
||||
)
|
||||
|
||||
options = generate_registration_options(
|
||||
rp_id=site_url_setting.value if site_url_setting else "",
|
||||
rp_name=site_title_setting.value if site_title_setting else "",
|
||||
user_name=user.username,
|
||||
user_display_name=user.nick or user.username,
|
||||
user_name=user.email,
|
||||
user_display_name=user.nickname or user.email,
|
||||
)
|
||||
|
||||
return models.ResponseBase(data=options_to_json_dict(options))
|
||||
return sqlmodels.ResponseBase(data=options_to_json_dict(options))
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/finish',
|
||||
@@ -378,179 +435,11 @@ async def router_user_authn_start(
|
||||
description='Finish WebAuthn login for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_authn_finish() -> models.ResponseBase:
|
||||
def router_user_authn_finish() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
summary='获取用户可选存储策略',
|
||||
description='Get user selectable storage policies.',
|
||||
)
|
||||
def router_user_settings_policies() -> models.ResponseBase:
|
||||
"""
|
||||
Get user selectable storage policies.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/nodes',
|
||||
summary='获取用户可选节点',
|
||||
description='Get user selectable nodes.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_nodes() -> models.ResponseBase:
|
||||
"""
|
||||
Get user selectable nodes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available nodes for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/tasks',
|
||||
summary='任务队列',
|
||||
description='Get user task queue.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_tasks() -> models.ResponseBase:
|
||||
"""
|
||||
Get user task queue.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the user's task queue information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/',
|
||||
summary='获取当前用户设定',
|
||||
description='Get current user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings() -> models.ResponseBase:
|
||||
"""
|
||||
Get current user settings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the current user settings.
|
||||
"""
|
||||
return models.ResponseBase(data=models.UserSettingResponse().model_dump())
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/avatar',
|
||||
summary='从文件上传头像',
|
||||
description='Upload user avatar from file.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar() -> models.ResponseBase:
|
||||
"""
|
||||
Upload user avatar from file.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the avatar upload.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.put(
|
||||
path='/avatar',
|
||||
summary='设定为Gravatar头像',
|
||||
description='Set user avatar to Gravatar.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar_gravatar() -> models.ResponseBase:
|
||||
"""
|
||||
Set user avatar to Gravatar.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.patch(
|
||||
path='/{option}',
|
||||
summary='更新用户设定',
|
||||
description='Update user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_patch(option: str) -> models.ResponseBase:
|
||||
"""
|
||||
Update user settings.
|
||||
|
||||
Args:
|
||||
option (str): The setting option to update.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the settings update.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/2fa',
|
||||
summary='获取两步验证初始化信息',
|
||||
description='Get two-factor authentication initialization information.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa(
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
Get two-factor authentication initialization information.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing two-factor authentication setup information.
|
||||
"""
|
||||
|
||||
return models.ResponseBase(
|
||||
data=await Password.generate_totp(user.username)
|
||||
)
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/2fa',
|
||||
summary='启用两步验证',
|
||||
description='Enable two-factor authentication.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa_enable(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
setup_token: str,
|
||||
code: str,
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
Enable two-factor authentication for the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of enabling two-factor authentication.
|
||||
"""
|
||||
|
||||
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||
|
||||
try:
|
||||
# 1. 解包 Token,设置有效期(例如 600秒)
|
||||
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||
except BadSignature:
|
||||
raise HTTPException(status_code=400, detail="Invalid token")
|
||||
|
||||
# 2. 验证用户输入的 6 位验证码
|
||||
if not Password.verify_totp(secret, code):
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={"message": "Two-factor authentication enabled successfully"}
|
||||
)
|
||||
http_exceptions.raise_not_implemented()
|
||||
203
routers/api/v1/user/settings/__init__.py
Normal file
203
routers/api/v1/user/settings/__init__.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
|
||||
import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from utils import JWT, Password, http_exceptions
|
||||
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
summary='获取用户可选存储策略',
|
||||
description='Get user selectable storage policies.',
|
||||
)
|
||||
def router_user_settings_policies() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user selectable storage policies.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/nodes',
|
||||
summary='获取用户可选节点',
|
||||
description='Get user selectable nodes.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_nodes() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user selectable nodes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available nodes for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/tasks',
|
||||
summary='任务队列',
|
||||
description='Get user task queue.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_tasks() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user task queue.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the user's task queue information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/',
|
||||
summary='获取当前用户设定',
|
||||
description='Get current user settings.',
|
||||
)
|
||||
def router_user_settings(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserSettingResponse:
|
||||
"""
|
||||
Get current user settings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the current user settings.
|
||||
"""
|
||||
return sqlmodels.UserSettingResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
nickname=user.nickname,
|
||||
created_at=user.created_at,
|
||||
group_name=user.group.name,
|
||||
language=user.language,
|
||||
timezone=user.timezone,
|
||||
group_expires=user.group_expires,
|
||||
two_factor=user.two_factor is not None,
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/avatar',
|
||||
summary='从文件上传头像',
|
||||
description='Upload user avatar from file.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Upload user avatar from file.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the avatar upload.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.put(
|
||||
path='/avatar',
|
||||
summary='设定为Gravatar头像',
|
||||
description='Set user avatar to Gravatar.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar_gravatar() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Set user avatar to Gravatar.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.patch(
|
||||
path='/{option}',
|
||||
summary='更新用户设定',
|
||||
description='Update user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Update user settings.
|
||||
|
||||
Args:
|
||||
option (str): The setting option to update.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the settings update.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/2fa',
|
||||
summary='获取两步验证初始化信息',
|
||||
description='Get two-factor authentication initialization information.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get two-factor authentication initialization information.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing two-factor authentication setup information.
|
||||
"""
|
||||
|
||||
return sqlmodels.ResponseBase(
|
||||
data=await Password.generate_totp(user.email)
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/2fa',
|
||||
summary='启用两步验证',
|
||||
description='Enable two-factor authentication.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa_enable(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
setup_token: str,
|
||||
code: str,
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Enable two-factor authentication for the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of enabling two-factor authentication.
|
||||
"""
|
||||
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
|
||||
try:
|
||||
# 1. 解包 Token,设置有效期(例如 600秒)
|
||||
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||
except BadSignature:
|
||||
raise HTTPException(status_code=400, detail="Invalid token")
|
||||
|
||||
# 2. 验证用户输入的 6 位验证码
|
||||
if not Password.verify_totp(secret, code):
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
|
||||
return sqlmodels.ResponseBase(
|
||||
data={"message": "Two-factor authentication enabled successfully"}
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
vas_router = APIRouter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
# WebDAV 管理路由
|
||||
|
||||
@@ -1,18 +1,20 @@
|
||||
import abc
|
||||
from enum import StrEnum
|
||||
|
||||
import aiohttp
|
||||
|
||||
from loguru import logger as l
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .gcaptcha import GCaptcha
|
||||
from .turnstile import TurnstileCaptcha
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class CaptchaRequestBase(BaseModel):
|
||||
"""验证码验证请求"""
|
||||
token: str
|
||||
"""验证 token"""
|
||||
|
||||
response: str
|
||||
"""用户的验证码 response token"""
|
||||
|
||||
secret: str
|
||||
"""验证密钥"""
|
||||
"""服务端密钥"""
|
||||
|
||||
|
||||
class CaptchaBase(abc.ABC):
|
||||
@@ -30,10 +32,89 @@ class CaptchaBase(abc.ABC):
|
||||
"""
|
||||
payload = request.model_dump()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.verify_url, data=payload) as response:
|
||||
if response.status != 200:
|
||||
async with aiohttp.ClientSession() as client_session:
|
||||
async with client_session.post(self.verify_url, data=payload) as resp:
|
||||
if resp.status != 200:
|
||||
return False
|
||||
|
||||
result = await response.json()
|
||||
return result.get('success', False)
|
||||
result = await resp.json()
|
||||
return result.get('success', False)
|
||||
|
||||
|
||||
# 子类导入必须在 CaptchaBase 定义之后(gcaptcha.py / turnstile.py 依赖 CaptchaBase)
|
||||
from .gcaptcha import GCaptcha # noqa: E402
|
||||
from .turnstile import TurnstileCaptcha # noqa: E402
|
||||
|
||||
|
||||
class CaptchaScene(StrEnum):
|
||||
"""验证码使用场景,value 对应 Setting 表中的 name"""
|
||||
|
||||
LOGIN = "login_captcha"
|
||||
REGISTER = "reg_captcha"
|
||||
FORGET = "forget_captcha"
|
||||
|
||||
|
||||
async def verify_captcha_if_needed(
|
||||
session: AsyncSession,
|
||||
scene: CaptchaScene,
|
||||
captcha_code: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
通用验证码校验:查询设置判断是否需要,需要则校验。
|
||||
|
||||
:param session: 数据库异步会话
|
||||
:param scene: 验证码使用场景
|
||||
:param captcha_code: 用户提交的验证码 response token
|
||||
:raises HTTPException 400: 需要验证码但未提供
|
||||
:raises HTTPException 403: 验证码验证失败
|
||||
:raises HTTPException 500: 验证码密钥未配置
|
||||
"""
|
||||
from sqlmodels import Setting, SettingsType
|
||||
from sqlmodels.setting import CaptchaType
|
||||
from utils import http_exceptions
|
||||
|
||||
# 1. 查询该场景是否需要验证码
|
||||
scene_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.LOGIN) & (Setting.name == scene.value),
|
||||
)
|
||||
if not scene_setting or scene_setting.value != "1":
|
||||
return
|
||||
|
||||
# 2. 需要但未提供
|
||||
if not captcha_code:
|
||||
http_exceptions.raise_bad_request(detail="请完成验证码验证")
|
||||
|
||||
# 3. 查询验证码类型和密钥
|
||||
captcha_settings: list[Setting] = await Setting.get(
|
||||
session, Setting.type == SettingsType.CAPTCHA, fetch_mode="all",
|
||||
)
|
||||
s: dict[str, str | None] = {item.name: item.value for item in captcha_settings}
|
||||
captcha_type = CaptchaType(s.get("captcha_type") or "default")
|
||||
|
||||
# 4. DEFAULT 图片验证码尚未实现,跳过
|
||||
if captcha_type == CaptchaType.DEFAULT:
|
||||
l.warning("DEFAULT 图片验证码尚未实现,跳过验证")
|
||||
return
|
||||
|
||||
# 5. 选择验证器和密钥
|
||||
if captcha_type == CaptchaType.GCAPTCHA:
|
||||
secret = s.get("captcha_ReCaptchaSecret")
|
||||
verifier: CaptchaBase = GCaptcha()
|
||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||
secret = s.get("captcha_CloudflareSecret")
|
||||
verifier = TurnstileCaptcha()
|
||||
else:
|
||||
l.error(f"未知的验证码类型: {captcha_type}")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
if not secret:
|
||||
l.error(f"验证码密钥未配置: captcha_type={captcha_type}")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 6. 调用第三方 API 校验
|
||||
is_valid = await verifier.verify_captcha(
|
||||
CaptchaRequestBase(response=captcha_code, secret=secret)
|
||||
)
|
||||
if not is_valid:
|
||||
http_exceptions.raise_forbidden(detail="验证码验证失败")
|
||||
|
||||
72
service/redis/user_ban_store.py
Normal file
72
service/redis/user_ban_store.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
用户封禁状态存储
|
||||
|
||||
用于 JWT 模式下的即时封禁生效。
|
||||
支持 Redis(首选)和内存缓存(降级)两种存储后端。
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from cachetools import TTLCache
|
||||
from loguru import logger as l
|
||||
|
||||
from . import RedisManager
|
||||
|
||||
# access_token 有效期(秒)
|
||||
_BAN_TTL: int = 3600
|
||||
|
||||
|
||||
class UserBanStore:
|
||||
"""
|
||||
用户封禁状态存储
|
||||
|
||||
管理员封禁用户时调用 ban(),jwt_required 每次请求调用 is_banned() 检查。
|
||||
TTL 与 access_token 有效期一致(1h),过期后旧 token 自然失效,无需继续记录。
|
||||
"""
|
||||
|
||||
_memory_cache: ClassVar[TTLCache[str, bool]] = TTLCache(maxsize=10000, ttl=_BAN_TTL)
|
||||
"""内存缓存降级方案"""
|
||||
|
||||
@classmethod
|
||||
async def ban(cls, user_id: str) -> None:
|
||||
"""
|
||||
标记用户为已封禁。
|
||||
|
||||
:param user_id: 用户 UUID 字符串
|
||||
"""
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
key = f"user_ban:{user_id}"
|
||||
await client.set(key, "1", ex=_BAN_TTL)
|
||||
else:
|
||||
cls._memory_cache[user_id] = True
|
||||
l.info(f"用户 {user_id} 已加入封禁黑名单")
|
||||
|
||||
@classmethod
|
||||
async def unban(cls, user_id: str) -> None:
|
||||
"""
|
||||
移除用户封禁标记(解封时调用)。
|
||||
|
||||
:param user_id: 用户 UUID 字符串
|
||||
"""
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
key = f"user_ban:{user_id}"
|
||||
await client.delete(key)
|
||||
else:
|
||||
cls._memory_cache.pop(user_id, None)
|
||||
l.info(f"用户 {user_id} 已从封禁黑名单移除")
|
||||
|
||||
@classmethod
|
||||
async def is_banned(cls, user_id: str) -> bool:
|
||||
"""
|
||||
检查用户是否在封禁黑名单中。
|
||||
|
||||
:param user_id: 用户 UUID 字符串
|
||||
:return: True 表示已封禁
|
||||
"""
|
||||
client = RedisManager.get_client()
|
||||
if client is not None:
|
||||
key = f"user_ban:{user_id}"
|
||||
return await client.exists(key) > 0
|
||||
else:
|
||||
return user_id in cls._memory_cache
|
||||
@@ -15,7 +15,7 @@ import aiofiles
|
||||
import aiofiles.os
|
||||
from loguru import logger as l
|
||||
|
||||
from models.policy import Policy
|
||||
from sqlmodels.policy import Policy
|
||||
from .exceptions import (
|
||||
DirectoryCreationError,
|
||||
FileReadError,
|
||||
|
||||
@@ -23,7 +23,7 @@ import string
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
|
||||
class NamingContext(SQLModelBase):
|
||||
|
||||
@@ -3,7 +3,9 @@ from uuid import uuid4
|
||||
from loguru import logger
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import LoginRequest, TokenResponse, User
|
||||
from sqlmodels import LoginRequest, TokenResponse, User
|
||||
from sqlmodels.group import GroupClaims, GroupOptions
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import http_exceptions
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
from utils.password.pwd import Password, PasswordStatus
|
||||
@@ -22,55 +24,65 @@ async def login(
|
||||
|
||||
:return: TokenResponse 对象或状态码或 None
|
||||
"""
|
||||
# TODO: 验证码校验
|
||||
# captcha_setting = await Setting.get(
|
||||
# session,
|
||||
# (Setting.type == "auth") & (Setting.name == "login_captcha")
|
||||
# )
|
||||
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
||||
|
||||
# 获取用户信息
|
||||
current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore
|
||||
# 获取用户信息(预加载 group 关系)
|
||||
current_user: User = await User.get(
|
||||
session,
|
||||
User.email == login_request.email,
|
||||
fetch_mode="first",
|
||||
load=User.group,
|
||||
) #type: ignore
|
||||
|
||||
# 验证用户是否存在
|
||||
if not current_user:
|
||||
logger.debug(f"Cannot find user with username: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid username or password")
|
||||
logger.debug(f"Cannot find user with email: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
|
||||
# 验证密码是否正确
|
||||
if Password.verify(current_user.password, login_request.password) != PasswordStatus.VALID:
|
||||
logger.debug(f"Password verification failed for user: {login_request.username}")
|
||||
http_exceptions.raise_unauthorized("Invalid username or password")
|
||||
logger.debug(f"Password verification failed for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid email or password")
|
||||
|
||||
# 验证用户是否可登录
|
||||
if not current_user.status:
|
||||
# 验证用户是否可登录(修复:显式枚举比较,StrEnum 永远 truthy)
|
||||
if current_user.status != UserStatus.ACTIVE:
|
||||
http_exceptions.raise_forbidden("Your account is disabled")
|
||||
|
||||
# 检查两步验证
|
||||
if current_user.two_factor:
|
||||
# 用户已启用两步验证
|
||||
if not login_request.two_fa_code:
|
||||
logger.debug(f"2FA required for user: {login_request.username}")
|
||||
logger.debug(f"2FA required for user: {login_request.email}")
|
||||
http_exceptions.raise_precondition_required("2FA required")
|
||||
|
||||
# 验证 OTP 码
|
||||
if Password.verify_totp(current_user.two_factor, login_request.two_fa_code) != PasswordStatus.VALID:
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.username}")
|
||||
logger.debug(f"Invalid 2FA code for user: {login_request.email}")
|
||||
http_exceptions.raise_unauthorized("Invalid 2FA code")
|
||||
|
||||
# 加载 GroupOptions
|
||||
group_options: GroupOptions | None = await GroupOptions.get(
|
||||
session,
|
||||
GroupOptions.group_id == current_user.group_id,
|
||||
)
|
||||
|
||||
# 构建权限快照
|
||||
current_user.group.options = group_options
|
||||
group_claims = GroupClaims.from_group(current_user.group)
|
||||
|
||||
# 创建令牌
|
||||
access_token = create_access_token(data={
|
||||
'sub': str(current_user.id),
|
||||
'jti': str(uuid4())
|
||||
})
|
||||
refresh_token = create_refresh_token(data={
|
||||
'sub': str(current_user.id),
|
||||
'jti': str(uuid4())
|
||||
})
|
||||
access_token = create_access_token(
|
||||
sub=current_user.id,
|
||||
jti=uuid4(),
|
||||
status=current_user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
sub=current_user.id,
|
||||
jti=uuid4()
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token.access_token,
|
||||
access_expires=access_token.access_expires,
|
||||
refresh_token=refresh_token.refresh_token,
|
||||
refresh_expires=refresh_token.refresh_expires,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
from .user import (
|
||||
BatchDeleteRequest,
|
||||
JWTPayload,
|
||||
LoginRequest,
|
||||
RefreshTokenRequest,
|
||||
RegisterRequest,
|
||||
AccessTokenBase,
|
||||
RefreshTokenBase,
|
||||
TokenResponse,
|
||||
User,
|
||||
UserBase,
|
||||
UserStorageResponse,
|
||||
UserPublic,
|
||||
UserResponse,
|
||||
UserSettingResponse,
|
||||
@@ -34,7 +38,7 @@ from .node import (
|
||||
NodeType,
|
||||
)
|
||||
from .group import (
|
||||
Group, GroupBase, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
|
||||
Group, GroupBase, GroupClaims, GroupOptions, GroupOptionsBase, GroupAllOptionsBase, GroupResponse,
|
||||
# 管理员DTO
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, GroupListResponse,
|
||||
)
|
||||
@@ -66,6 +70,7 @@ from .object import (
|
||||
FileBanRequest,
|
||||
)
|
||||
from .physical_file import PhysicalFile, PhysicalFileBase
|
||||
from .uri import DiskNextURI, FileSystemNamespace
|
||||
from .order import Order, OrderStatus, OrderType
|
||||
from .policy import Policy, PolicyBase, PolicyOptions, PolicyOptionsBase, PolicyType, PolicySummary
|
||||
from .redeem import Redeem, RedeemType
|
||||
@@ -82,7 +87,7 @@ from .tag import Tag, TagType
|
||||
from .task import Task, TaskProps, TaskPropsBase, TaskStatus, TaskType, TaskSummary
|
||||
from .webdav import WebDAV
|
||||
|
||||
from .database import engine, get_session
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
from .model_base import (
|
||||
MCPBase,
|
||||
@@ -630,7 +630,7 @@ For developers modifying this module:
|
||||
- Handles Python 3.14 annotations via `get_type_hints()`
|
||||
|
||||
**Metaclass processing order**:
|
||||
1. Check if class should be a table (`_is_table_mixin`)
|
||||
1. Check if class should be a table (`_has_table_mixin`)
|
||||
2. Collect `__mapper_args__` from kwargs and explicit dict
|
||||
3. Process `table_args`, `table_name`, `abstract` parameters
|
||||
4. Resolve annotations using `get_type_hints()`
|
||||
@@ -5,8 +5,8 @@ SQLModel 基础模块
|
||||
- SQLModelBase: 所有 SQLModel 类的基类(真正的基类)
|
||||
|
||||
注意:
|
||||
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 models.mixin
|
||||
TableBase, UUIDTableBase, PolymorphicBaseMixin 已迁移到 sqlmodels.mixin
|
||||
为了避免循环导入,此处不再重新导出它们
|
||||
请直接从 models.mixin 导入这些类
|
||||
请直接从 sqlmodels.mixin 导入这些类
|
||||
"""
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
@@ -414,7 +414,7 @@ class __DeclarativeMeta(SQLModelMetaclass):
|
||||
|
||||
def __new__(cls, name, bases, attrs, **kwargs):
|
||||
# 1. 约定优于配置:自动设置 table=True
|
||||
is_intended_as_table = any(getattr(b, '_is_table_mixin', False) for b in bases)
|
||||
is_intended_as_table = any(getattr(b, '_has_table_mixin', False) for b in bases)
|
||||
if is_intended_as_table and 'table' not in kwargs:
|
||||
kwargs['table'] = True
|
||||
|
||||
78
sqlmodels/database_connection.py
Normal file
78
sqlmodels/database_connection.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import AsyncGenerator, ClassVar
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import NullPool, AsyncAdaptedQueuePool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
engine: ClassVar[AsyncEngine | None] = None
|
||||
_async_session_factory: ClassVar[sessionmaker | None] = None
|
||||
|
||||
@classmethod
|
||||
async def get_session(cls) -> AsyncGenerator[AsyncSession]:
|
||||
assert cls._async_session_factory is not None, "数据库引擎未初始化,请先调用 DatabaseManager.init()"
|
||||
async with cls._async_session_factory() as session:
|
||||
yield session
|
||||
|
||||
@classmethod
|
||||
async def init(
|
||||
cls,
|
||||
database_url: str,
|
||||
debug: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化数据库连接引擎。
|
||||
|
||||
:param database_url: 数据库连接URL
|
||||
:param debug: 是否开启调试模式
|
||||
"""
|
||||
# 构建引擎参数
|
||||
engine_kwargs: dict = {
|
||||
'echo': debug,
|
||||
'future': True,
|
||||
}
|
||||
|
||||
if debug:
|
||||
# Debug 模式使用 NullPool(无连接池,每次创建新连接)
|
||||
engine_kwargs['poolclass'] = NullPool
|
||||
else:
|
||||
# 生产模式使用 AsyncAdaptedQueuePool 连接池
|
||||
engine_kwargs.update({
|
||||
'poolclass': AsyncAdaptedQueuePool,
|
||||
'pool_size': 40,
|
||||
'max_overflow': 80,
|
||||
'pool_timeout': 30,
|
||||
'pool_recycle': 1800,
|
||||
'pool_pre_ping': True,
|
||||
})
|
||||
|
||||
# 只在需要时添加 connect_args
|
||||
if database_url.startswith("sqlite"):
|
||||
engine_kwargs['connect_args'] = {'check_same_thread': False}
|
||||
|
||||
cls.engine = create_async_engine(database_url, **engine_kwargs)
|
||||
|
||||
cls._async_session_factory = sessionmaker(cls.engine, class_=AsyncSession)
|
||||
|
||||
# 开发阶段直接 create_all 创建表结构
|
||||
async with cls.engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
logger.info("数据库引擎初始化完成")
|
||||
|
||||
@classmethod
|
||||
async def close(cls):
|
||||
"""
|
||||
优雅地关闭数据库连接引擎。
|
||||
仅应在应用结束时调用。
|
||||
"""
|
||||
if cls.engine:
|
||||
logger.info("正在关闭数据库连接引擎...")
|
||||
await cls.engine.dispose()
|
||||
logger.info("数据库连接引擎已成功关闭。")
|
||||
else:
|
||||
logger.info("数据库连接引擎未初始化,无需关闭。")
|
||||
@@ -188,6 +188,28 @@ class GroupListResponse(SQLModelBase):
|
||||
"""总数"""
|
||||
|
||||
|
||||
class GroupClaims(GroupCoreBase, GroupAllOptionsBase):
|
||||
"""
|
||||
JWT 中的用户组权限快照。
|
||||
|
||||
复用 GroupCoreBase(id, name, max_storage, share_enabled, web_dav_enabled, admin, speed_limit)
|
||||
和 GroupAllOptionsBase(share_download, share_free, ... 共 11 个功能开关)。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_group(cls, group: "Group") -> "GroupClaims":
|
||||
"""
|
||||
从 Group ORM 对象(需预加载 options 关系)构建权限快照。
|
||||
|
||||
:param group: 已加载 options 的 Group 对象
|
||||
"""
|
||||
opts = group.options
|
||||
return cls(
|
||||
**GroupCoreBase.model_validate(group, from_attributes=True).model_dump(),
|
||||
**(GroupAllOptionsBase.model_validate(opts, from_attributes=True).model_dump() if opts else {}),
|
||||
)
|
||||
|
||||
|
||||
class GroupResponse(GroupBase, GroupOptionsBase):
|
||||
"""用户组响应 DTO"""
|
||||
|
||||
@@ -29,6 +29,10 @@ default_settings: list[Setting] = [
|
||||
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
|
||||
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
|
||||
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
|
||||
Setting(name="site_notice", value="", type=SettingsType.BASIC),
|
||||
Setting(name="footer_code", value="", type=SettingsType.BASIC),
|
||||
Setting(name="tos_url", value="", type=SettingsType.BASIC),
|
||||
Setting(name="privacy_url", value="", type=SettingsType.BASIC),
|
||||
Setting(name="fromName", value="DiskNext", type=SettingsType.MAIL),
|
||||
Setting(name="mail_keepalive", value="30", type=SettingsType.MAIL),
|
||||
Setting(name="fromAdress", value="no-reply@yxqi.cn", type=SettingsType.MAIL),
|
||||
@@ -104,9 +108,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
Setting(name="captcha_IsShowSlimeLine", value="1", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsShowSineLine", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CaptchaLen", value="6", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_IsUseReCaptcha", value="0", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaKey", value="defaultKey", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaSecret", value="defaultSecret", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_type", value="default", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaKey", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_ReCaptchaSecret", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CloudflareKey", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="captcha_CloudflareSecret", value="", type=SettingsType.CAPTCHA),
|
||||
Setting(name="thumb_width", value="400", type=SettingsType.THUMB),
|
||||
Setting(name="thumb_height", value="300", type=SettingsType.THUMB),
|
||||
Setting(name="pwa_small_icon", value="/static/img/favicon.ico", type=SettingsType.PWA),
|
||||
@@ -119,11 +125,11 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
|
||||
async def init_default_settings() -> None:
|
||||
from .setting import Setting
|
||||
from .database import get_session
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化设置...')
|
||||
|
||||
async for session in get_session():
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查是否已经存在版本设置
|
||||
ver = await Setting.get(
|
||||
session,
|
||||
@@ -139,11 +145,11 @@ async def init_default_group() -> None:
|
||||
from .group import Group, GroupOptions
|
||||
from .policy import Policy, GroupPolicyLink
|
||||
from .setting import Setting
|
||||
from .database import get_session
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化用户组...')
|
||||
|
||||
async for session in get_session():
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 获取默认存储策略
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
default_policy_id = default_policy.id if default_policy else None
|
||||
@@ -231,13 +237,20 @@ async def init_default_user() -> None:
|
||||
from .group import Group
|
||||
from .object import Object, ObjectType
|
||||
from .policy import Policy
|
||||
from .database import get_session
|
||||
from .database_connection import DatabaseManager
|
||||
|
||||
log.info('初始化管理员用户...')
|
||||
|
||||
async for session in get_session():
|
||||
# 检查管理员用户是否存在
|
||||
admin_user = await User.get(session, User.username == "admin")
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查管理员用户是否存在(通过 Setting 中的 default_admin_id 判断)
|
||||
admin_id_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||
)
|
||||
admin_user = None
|
||||
if admin_id_setting and admin_id_setting.value:
|
||||
from uuid import UUID
|
||||
admin_user = await User.get(session, User.id == UUID(admin_id_setting.value))
|
||||
|
||||
if not admin_user:
|
||||
# 获取管理员组
|
||||
@@ -256,18 +269,24 @@ async def init_default_user() -> None:
|
||||
hashed_admin_password = Password.hash(admin_password)
|
||||
|
||||
admin_user = User(
|
||||
username="admin",
|
||||
email="admin@disknext.local",
|
||||
nickname="admin",
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
)
|
||||
admin_user_id = admin_user.id # 在 save 前保存 UUID
|
||||
admin_username = admin_user.username
|
||||
await admin_user.save(session)
|
||||
|
||||
# 为管理员创建根目录(使用用户名作为目录名)
|
||||
# 记录默认管理员 ID 到 Setting
|
||||
await Setting(
|
||||
name="default_admin_id",
|
||||
value=str(admin_user_id),
|
||||
type=SettingsType.AUTH,
|
||||
).save(session)
|
||||
|
||||
# 为管理员创建根目录
|
||||
await Object(
|
||||
name=admin_username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=admin_user_id,
|
||||
parent_id=None,
|
||||
@@ -275,18 +294,18 @@ async def init_default_user() -> None:
|
||||
).save(session)
|
||||
|
||||
log.warning('请注意,账号密码仅显示一次,请妥善保管')
|
||||
log.info(f'初始管理员账号: admin')
|
||||
log.info(f'初始管理员邮箱: admin@disknext.local')
|
||||
log.info(f'初始管理员密码: {admin_password}')
|
||||
|
||||
|
||||
async def init_default_policy() -> None:
|
||||
from .policy import Policy, PolicyType
|
||||
from .database import get_session
|
||||
from .database_connection import DatabaseManager
|
||||
from service.storage import LocalStorageService
|
||||
|
||||
log.info('初始化默认存储策略...')
|
||||
|
||||
async for session in get_session():
|
||||
async for session in DatabaseManager.get_session():
|
||||
# 检查默认存储策略是否存在
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
|
||||
@@ -5,42 +5,58 @@ SQLModel Mixin模块
|
||||
|
||||
包含:
|
||||
- polymorphic: 联表继承工具(create_subclass_id_mixin, AutoPolymorphicIdentityMixin, PolymorphicBaseMixin)
|
||||
- optimistic_lock: 乐观锁(OptimisticLockMixin, OptimisticLockError)
|
||||
- table: 表基类(TableBaseMixin, UUIDTableBaseMixin)
|
||||
- table: 查询参数类(TimeFilterRequest, PaginationRequest, TableViewRequest)
|
||||
- relation_preload: 关系预加载(RelationPreloadMixin, requires_relations)
|
||||
- jwt/: JWT认证相关(JWTAuthMixin, JWTManager, JWTKey等)- 需要时直接从 .jwt 导入
|
||||
- info_response: InfoResponse DTO的id/时间戳Mixin
|
||||
|
||||
导入顺序很重要,避免循环导入:
|
||||
1. polymorphic(只依赖 SQLModelBase)
|
||||
2. table(依赖 polymorphic)
|
||||
2. optimistic_lock(只依赖 SQLAlchemy)
|
||||
3. table(依赖 polymorphic 和 optimistic_lock)
|
||||
4. relation_preload(只依赖 SQLModelBase)
|
||||
|
||||
注意:jwt 模块不在此处导入,因为 jwt/manager.py 导入 ServerConfig,
|
||||
而 ServerConfig 导入本模块,会形成循环。需要 jwt 功能时请直接从 .jwt 导入。
|
||||
"""
|
||||
# polymorphic 必须先导入
|
||||
from .polymorphic import (
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
register_sti_column_properties_for_all_subclasses,
|
||||
register_sti_columns_for_all_subclasses,
|
||||
)
|
||||
# table 依赖 polymorphic
|
||||
# optimistic_lock 只依赖 SQLAlchemy,必须在 table 之前
|
||||
from .optimistic_lock import (
|
||||
OptimisticLockError,
|
||||
OptimisticLockMixin,
|
||||
)
|
||||
# table 依赖 polymorphic 和 optimistic_lock
|
||||
from .table import (
|
||||
TableBaseMixin,
|
||||
UUIDTableBaseMixin,
|
||||
TimeFilterRequest,
|
||||
PaginationRequest,
|
||||
TableViewRequest,
|
||||
ListResponse,
|
||||
PaginationRequest,
|
||||
T,
|
||||
TableBaseMixin,
|
||||
TableViewRequest,
|
||||
TimeFilterRequest,
|
||||
UUIDTableBaseMixin,
|
||||
now,
|
||||
now_date,
|
||||
)
|
||||
# relation_preload 只依赖 SQLModelBase
|
||||
from .relation_preload import (
|
||||
RelationPreloadMixin,
|
||||
requires_relations,
|
||||
)
|
||||
# jwt 不在此处导入(避免循环:jwt/manager.py → ServerConfig → mixin → jwt)
|
||||
# 需要时直接从 sqlmodels.mixin.jwt 导入
|
||||
from .info_response import (
|
||||
IntIdInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
DatetimeInfoMixin,
|
||||
IntIdDatetimeInfoMixin,
|
||||
IntIdInfoMixin,
|
||||
UUIDIdDatetimeInfoMixin,
|
||||
UUIDIdInfoMixin,
|
||||
)
|
||||
@@ -12,7 +12,7 @@ InfoResponse DTO Mixin模块
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
|
||||
class IntIdInfoMixin(SQLModelBase):
|
||||
90
sqlmodels/mixin/optimistic_lock.py
Normal file
90
sqlmodels/mixin/optimistic_lock.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
提供基于 SQLAlchemy version_id_col 机制的乐观锁支持。
|
||||
|
||||
乐观锁适用场景:
|
||||
- 涉及"状态转换"的表(如:待支付 -> 已支付)
|
||||
- 涉及"数值变动"的表(如:余额、库存)
|
||||
|
||||
不适用场景:
|
||||
- 日志表、纯插入表、低价值统计表
|
||||
- 能用 UPDATE table SET col = col + 1 解决的简单计数问题
|
||||
|
||||
使用示例:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
status: OrderStatusEnum
|
||||
amount: Decimal
|
||||
|
||||
# save/update 时自动检查版本号
|
||||
# 如果版本号不匹配(其他事务已修改),会抛出 OptimisticLockError
|
||||
try:
|
||||
order = await order.save(session)
|
||||
except OptimisticLockError as e:
|
||||
# 处理冲突:重新查询并重试,或报错给用户
|
||||
l.warning(f"乐观锁冲突: {e}")
|
||||
"""
|
||||
from typing import ClassVar
|
||||
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
|
||||
|
||||
class OptimisticLockError(Exception):
|
||||
"""
|
||||
乐观锁冲突异常
|
||||
|
||||
当 save/update 操作检测到版本号不匹配时抛出。
|
||||
这意味着在读取和写入之间,其他事务已经修改了该记录。
|
||||
|
||||
Attributes:
|
||||
model_class: 发生冲突的模型类名
|
||||
record_id: 记录 ID(如果可用)
|
||||
expected_version: 期望的版本号(如果可用)
|
||||
original_error: 原始的 StaleDataError
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
model_class: str | None = None,
|
||||
record_id: str | None = None,
|
||||
expected_version: int | None = None,
|
||||
original_error: StaleDataError | None = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.model_class = model_class
|
||||
self.record_id = record_id
|
||||
self.expected_version = expected_version
|
||||
self.original_error = original_error
|
||||
|
||||
|
||||
class OptimisticLockMixin:
|
||||
"""
|
||||
乐观锁 Mixin
|
||||
|
||||
使用 SQLAlchemy 的 version_id_col 机制实现乐观锁。
|
||||
每次 UPDATE 时自动检查并增加版本号,如果版本号不匹配(即其他事务已修改),
|
||||
session.commit() 会抛出 StaleDataError,被 save/update 方法捕获并转换为 OptimisticLockError。
|
||||
|
||||
原理:
|
||||
1. 每条记录有一个 version 字段,初始值为 0
|
||||
2. 每次 UPDATE 时,SQLAlchemy 生成的 SQL 类似:
|
||||
UPDATE table SET ..., version = version + 1 WHERE id = ? AND version = ?
|
||||
3. 如果 WHERE 条件不匹配(version 已被其他事务修改),
|
||||
UPDATE 影响 0 行,SQLAlchemy 抛出 StaleDataError
|
||||
|
||||
继承顺序:
|
||||
OptimisticLockMixin 必须放在 TableBaseMixin/UUIDTableBaseMixin 之前:
|
||||
class Order(OptimisticLockMixin, UUIDTableBaseMixin, table=True):
|
||||
...
|
||||
|
||||
配套重试:
|
||||
如果加了乐观锁,业务层需要处理 OptimisticLockError:
|
||||
- 报错给用户:"数据已被修改,请刷新后重试"
|
||||
- 自动重试:重新查询最新数据并再次尝试
|
||||
"""
|
||||
_has_optimistic_lock: ClassVar[bool] = True
|
||||
"""标记此类启用了乐观锁"""
|
||||
|
||||
version: int = 0
|
||||
"""乐观锁版本号,每次更新自动递增"""
|
||||
710
sqlmodels/mixin/polymorphic.py
Normal file
710
sqlmodels/mixin/polymorphic.py
Normal file
@@ -0,0 +1,710 @@
|
||||
"""
|
||||
联表继承(Joined Table Inheritance)的通用工具
|
||||
|
||||
提供用于简化SQLModel多态表设计的辅助函数和Mixin。
|
||||
|
||||
Usage Example:
|
||||
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import (
|
||||
PolymorphicBaseMixin,
|
||||
create_subclass_id_mixin,
|
||||
AutoPolymorphicIdentityMixin
|
||||
)
|
||||
|
||||
# 1. 定义Base类(只有字段,无表)
|
||||
class ASRBase(SQLModelBase):
|
||||
name: str
|
||||
\"\"\"配置名称\"\"\"
|
||||
|
||||
base_url: str
|
||||
\"\"\"服务地址\"\"\"
|
||||
|
||||
# 2. 定义抽象父类(有表),使用 PolymorphicBaseMixin
|
||||
class ASR(
|
||||
ASRBase,
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC
|
||||
):
|
||||
\"\"\"ASR配置的抽象基类\"\"\"
|
||||
# PolymorphicBaseMixin 自动提供:
|
||||
# - _polymorphic_name 字段
|
||||
# - polymorphic_on='_polymorphic_name'
|
||||
# - polymorphic_abstract=True(当有抽象方法时)
|
||||
|
||||
# 3. 为第二层子类创建ID Mixin
|
||||
ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
|
||||
# 4. 创建第二层抽象类(如果需要)
|
||||
class FunASR(
|
||||
ASRSubclassIdMixin,
|
||||
ASR,
|
||||
AutoPolymorphicIdentityMixin,
|
||||
polymorphic_abstract=True
|
||||
):
|
||||
\"\"\"FunASR的抽象基类,可能有多个实现\"\"\"
|
||||
pass
|
||||
|
||||
# 5. 创建具体实现类
|
||||
class FunASRLocal(FunASR, table=True):
|
||||
\"\"\"FunASR本地部署版本\"\"\"
|
||||
# polymorphic_identity 会自动设置为 'asr.funasrlocal'
|
||||
pass
|
||||
|
||||
# 6. 获取所有具体子类(用于 selectin_polymorphic)
|
||||
concrete_asrs = ASR.get_concrete_subclasses()
|
||||
# 返回 [FunASRLocal, ...]
|
||||
"""
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from uuid import UUID
|
||||
|
||||
from loguru import logger as l
|
||||
from pydantic.fields import FieldInfo
|
||||
from pydantic_core import PydanticUndefined
|
||||
from sqlalchemy import Column, String, inspect
|
||||
from sqlalchemy.orm import ColumnProperty, Mapped, mapped_column
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlmodel import Field
|
||||
from sqlmodel.main import get_column_from_field
|
||||
|
||||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# 用于延迟注册 STI 子类列的队列
|
||||
# 在所有模型加载完成后,调用 register_sti_columns_for_all_subclasses() 处理
|
||||
_sti_subclasses_to_register: list[type] = []
|
||||
|
||||
|
||||
def register_sti_columns_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类执行列注册(第一阶段:添加列到表)
|
||||
|
||||
此函数应在 configure_mappers() 之前调用。
|
||||
将 STI 子类的字段添加到父表的 metadata 中。
|
||||
同时修复被 Column 对象污染的 model_fields。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_columns()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列时出错: {e}")
|
||||
|
||||
# 修复被 Column 对象污染的 model_fields
|
||||
# 必须在列注册后立即修复,因为 Column 污染在类定义时就已发生
|
||||
try:
|
||||
_fix_polluted_model_fields(cls)
|
||||
except Exception as e:
|
||||
l.warning(f"修复 STI 子类 {cls.__name__} 的 model_fields 时出错: {e}")
|
||||
|
||||
|
||||
def register_sti_column_properties_for_all_subclasses() -> None:
|
||||
"""
|
||||
为所有已注册的 STI 子类添加列属性到 mapper(第二阶段)
|
||||
|
||||
此函数应在 configure_mappers() 之后调用。
|
||||
将 STI 子类的字段作为 ColumnProperty 添加到 mapper 中。
|
||||
"""
|
||||
for cls in _sti_subclasses_to_register:
|
||||
try:
|
||||
cls._register_sti_column_properties()
|
||||
except Exception as e:
|
||||
l.warning(f"注册 STI 子类 {cls.__name__} 的列属性时出错: {e}")
|
||||
|
||||
# 清空队列
|
||||
_sti_subclasses_to_register.clear()
|
||||
|
||||
|
||||
def _fix_polluted_model_fields(cls: type) -> None:
|
||||
"""
|
||||
修复被 SQLAlchemy InstrumentedAttribute 或 Column 污染的 model_fields
|
||||
|
||||
当 SQLModel 类继承有表的父类时,SQLAlchemy 会在类上创建 InstrumentedAttribute
|
||||
或 Column 对象替换原始的字段默认值。这会导致 Pydantic 在构建子类 model_fields
|
||||
时错误地使用这些 SQLAlchemy 对象作为默认值。
|
||||
|
||||
此函数从 MRO 中查找原始的字段定义,并修复被污染的 model_fields。
|
||||
|
||||
:param cls: 要修复的类
|
||||
"""
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
def find_original_field_info(field_name: str) -> FieldInfo | None:
|
||||
"""从 MRO 中查找字段的原始定义(未被污染的)"""
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, 'model_fields') and field_name in base.model_fields:
|
||||
field_info = base.model_fields[field_name]
|
||||
# 跳过被 InstrumentedAttribute 或 Column 污染的
|
||||
if not isinstance(field_info.default, (InstrumentedAttribute, Column)):
|
||||
return field_info
|
||||
return None
|
||||
|
||||
for field_name, current_field in cls.model_fields.items():
|
||||
# 检查是否被污染(default 是 InstrumentedAttribute 或 Column)
|
||||
# Column 污染发生在 STI 继承链中:当 FunctionBase.show_arguments = True
|
||||
# 被继承到有表的子类时,SQLModel 会创建一个 Column 对象替换原始默认值
|
||||
if not isinstance(current_field.default, (InstrumentedAttribute, Column)):
|
||||
continue # 未被污染,跳过
|
||||
|
||||
# 从父类查找原始定义
|
||||
original = find_original_field_info(field_name)
|
||||
if original is None:
|
||||
continue # 找不到原始定义,跳过
|
||||
|
||||
# 根据原始定义的 default/default_factory 来修复
|
||||
if original.default_factory:
|
||||
# 有 default_factory(如 uuid.uuid4, now)
|
||||
new_field = FieldInfo(
|
||||
default_factory=original.default_factory,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
elif original.default is not PydanticUndefined:
|
||||
# 有明确的 default 值(如 None, 0, True),且不是 PydanticUndefined
|
||||
# PydanticUndefined 表示字段没有默认值(必填)
|
||||
new_field = FieldInfo(
|
||||
default=original.default,
|
||||
annotation=current_field.annotation,
|
||||
json_schema_extra=current_field.json_schema_extra,
|
||||
)
|
||||
else:
|
||||
continue # 既没有 default_factory 也没有有效的 default,跳过
|
||||
|
||||
# 复制 SQLModel 特有的属性
|
||||
if hasattr(current_field, 'foreign_key'):
|
||||
new_field.foreign_key = current_field.foreign_key
|
||||
if hasattr(current_field, 'primary_key'):
|
||||
new_field.primary_key = current_field.primary_key
|
||||
|
||||
cls.model_fields[field_name] = new_field
|
||||
|
||||
|
||||
def create_subclass_id_mixin(parent_table_name: str) -> type['SQLModelBase']:
|
||||
"""
|
||||
动态创建SubclassIdMixin类
|
||||
|
||||
在联表继承中,子类需要一个外键指向父表的主键。
|
||||
此函数生成一个Mixin类,提供这个外键字段,并自动生成UUID。
|
||||
|
||||
Args:
|
||||
parent_table_name: 父表名称(如'asr', 'tts', 'tool', 'function')
|
||||
|
||||
Returns:
|
||||
一个Mixin类,包含id字段(外键 + 主键 + default_factory=uuid.uuid4)
|
||||
|
||||
Example:
|
||||
>>> ASRSubclassIdMixin = create_subclass_id_mixin('asr')
|
||||
>>> class FunASR(ASRSubclassIdMixin, ASR, table=True):
|
||||
... pass
|
||||
|
||||
Note:
|
||||
- 生成的Mixin应该放在继承列表的第一位,确保通过MRO覆盖UUIDTableBaseMixin的id
|
||||
- 生成的类名为 {ParentTableName}SubclassIdMixin(PascalCase)
|
||||
- 本项目所有联表继承均使用UUID主键(UUIDTableBaseMixin)
|
||||
"""
|
||||
if not parent_table_name:
|
||||
raise ValueError("parent_table_name 不能为空")
|
||||
|
||||
# 转换为PascalCase作为类名
|
||||
class_name_parts = parent_table_name.split('_')
|
||||
class_name = ''.join(part.capitalize() for part in class_name_parts) + 'SubclassIdMixin'
|
||||
|
||||
# 使用闭包捕获parent_table_name
|
||||
_parent_table_name = parent_table_name
|
||||
|
||||
# 创建带有__init_subclass__的mixin类,用于在子类定义后修复model_fields
|
||||
class SubclassIdMixin(SQLModelBase):
|
||||
# 定义id字段
|
||||
id: UUID = Field(
|
||||
default_factory=uuid.uuid4,
|
||||
foreign_key=f'{_parent_table_name}.id',
|
||||
primary_key=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复联表继承中子类字段的 default_factory 丢失问题。
|
||||
SQLAlchemy 的 InstrumentedAttribute 或 Column 会污染从父类继承的字段,
|
||||
导致 INSERT 语句中出现 `table.column` 引用而非实际值。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
# 设置类名和文档
|
||||
SubclassIdMixin.__name__ = class_name
|
||||
SubclassIdMixin.__qualname__ = class_name
|
||||
SubclassIdMixin.__doc__ = f"""
|
||||
{parent_table_name}子类的ID Mixin
|
||||
|
||||
用于{parent_table_name}的子类,提供外键指向父表。
|
||||
通过MRO确保此id字段覆盖继承的id字段。
|
||||
"""
|
||||
|
||||
return SubclassIdMixin
|
||||
|
||||
|
||||
class AutoPolymorphicIdentityMixin:
|
||||
"""
|
||||
自动生成polymorphic_identity的Mixin,并支持STI子类列注册
|
||||
|
||||
使用此Mixin的类会自动根据类名生成polymorphic_identity。
|
||||
格式:{parent_polymorphic_identity}.{classname_lowercase}
|
||||
|
||||
如果没有父类的polymorphic_identity,则直接使用类名小写。
|
||||
|
||||
**重要:数据库迁移注意事项**
|
||||
|
||||
编写数据迁移脚本时,必须使用完整的 polymorphic_identity 格式,包括父类前缀!
|
||||
|
||||
例如,对于以下继承链::
|
||||
|
||||
LLM (polymorphic_on='_polymorphic_name')
|
||||
└── AnthropicCompatibleLLM (polymorphic_identity='anthropiccompatiblellm')
|
||||
└── TuziAnthropicLLM (polymorphic_identity='anthropiccompatiblellm.tuzianthropicllm')
|
||||
|
||||
迁移脚本中设置 _polymorphic_name 时::
|
||||
|
||||
# ❌ 错误:缺少父类前缀
|
||||
UPDATE llm SET _polymorphic_name = 'tuzianthropicllm' WHERE id = :id
|
||||
|
||||
# ✅ 正确:包含完整的继承链前缀
|
||||
UPDATE llm SET _polymorphic_name = 'anthropiccompatiblellm.tuzianthropicllm' WHERE id = :id
|
||||
|
||||
**STI(单表继承)支持**:
|
||||
当子类与父类共用同一张表(STI模式)时,此Mixin会自动将子类的新字段
|
||||
添加到父表的列定义中。这解决了SQLModel在STI模式下子类字段不被
|
||||
注册到父表的问题。
|
||||
|
||||
Example (JTI):
|
||||
>>> class Tool(UUIDTableBaseMixin, polymorphic_on='__polymorphic_name', polymorphic_abstract=True):
|
||||
... __polymorphic_name: str
|
||||
...
|
||||
>>> class Function(Tool, AutoPolymorphicIdentityMixin, polymorphic_abstract=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function'
|
||||
...
|
||||
>>> class CodeInterpreterFunction(Function, table=True):
|
||||
... pass
|
||||
... # polymorphic_identity 会自动设置为 'function.codeinterpreterfunction'
|
||||
|
||||
Example (STI):
|
||||
>>> class UserFile(UUIDTableBaseMixin, PolymorphicBaseMixin, table=True, polymorphic_abstract=True):
|
||||
... user_id: UUID
|
||||
...
|
||||
>>> class PendingFile(UserFile, AutoPolymorphicIdentityMixin, table=True):
|
||||
... upload_deadline: datetime | None = None # 自动添加到 userfile 表
|
||||
... # polymorphic_identity 会自动设置为 'pendingfile'
|
||||
|
||||
Note:
|
||||
- 如果手动在__mapper_args__中指定了polymorphic_identity,会被保留
|
||||
- 此Mixin应该在继承列表中靠后的位置(在表基类之前)
|
||||
- STI模式下,新字段会在类定义时自动添加到父表的metadata中
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, polymorphic_identity: str | None = None, **kwargs):
|
||||
"""
|
||||
子类化钩子,自动生成polymorphic_identity并处理STI列注册
|
||||
|
||||
Args:
|
||||
polymorphic_identity: 如果手动指定,则使用指定的值
|
||||
**kwargs: 其他SQLModel参数(如table=True, polymorphic_abstract=True)
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 如果手动指定了polymorphic_identity,使用指定的值
|
||||
if polymorphic_identity is not None:
|
||||
identity = polymorphic_identity
|
||||
else:
|
||||
# 自动生成polymorphic_identity
|
||||
class_name = cls.__name__.lower()
|
||||
|
||||
# 尝试从父类获取polymorphic_identity作为前缀
|
||||
parent_identity = None
|
||||
for base in cls.__mro__[1:]: # 跳过自己
|
||||
if hasattr(base, '__mapper_args__') and isinstance(base.__mapper_args__, dict):
|
||||
parent_identity = base.__mapper_args__.get('polymorphic_identity')
|
||||
if parent_identity:
|
||||
break
|
||||
|
||||
# 构建identity
|
||||
if parent_identity:
|
||||
identity = f'{parent_identity}.{class_name}'
|
||||
else:
|
||||
identity = class_name
|
||||
|
||||
# 设置到__mapper_args__
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 只在尚未设置polymorphic_identity时设置
|
||||
if 'polymorphic_identity' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_identity'] = identity
|
||||
|
||||
# 注册 STI 子类列的延迟执行
|
||||
# 由于 __init_subclass__ 在类定义过程中被调用,此时 model_fields 还不完整
|
||||
# 需要在模块加载完成后调用 register_sti_columns_for_all_subclasses()
|
||||
_sti_subclasses_to_register.append(cls)
|
||||
|
||||
@classmethod
|
||||
def __pydantic_init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
Pydantic v2 的子类初始化钩子,在模型完全构建后调用
|
||||
|
||||
修复 STI 继承中子类字段被 Column 对象污染的问题。
|
||||
当 FunctionBase.show_arguments = True 等字段被继承到有表的子类时,
|
||||
SQLModel 会创建一个 Column 对象替换原始默认值,导致实例化时字段值不正确。
|
||||
"""
|
||||
super().__pydantic_init_subclass__(**kwargs)
|
||||
_fix_polluted_model_fields(cls)
|
||||
|
||||
@classmethod
|
||||
def _register_sti_columns(cls) -> None:
|
||||
"""
|
||||
将STI子类的新字段注册到父表的列定义中
|
||||
|
||||
检测当前类是否是STI子类(与父类共用同一张表),
|
||||
如果是,则将子类定义的新字段添加到父表的metadata中。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_fields: set[str] = set()
|
||||
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
# 收集父类的所有字段名
|
||||
if hasattr(base, 'model_fields'):
|
||||
parent_fields.update(base.model_fields.keys())
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列注册到父表
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列注册
|
||||
|
||||
# 获取当前类的新字段(不在父类中的字段)
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
existing_columns = {col.name for col in parent_table.columns}
|
||||
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段和ClassVar
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 跳过已存在的列
|
||||
if field_name in existing_columns:
|
||||
continue
|
||||
|
||||
# 使用 SQLModel 的内置 API 创建列
|
||||
try:
|
||||
column = get_column_from_field(field_info)
|
||||
column.name = field_name
|
||||
column.key = field_name
|
||||
# STI子类字段在数据库层面必须可空,因为其他子类的行不会有这些字段的值
|
||||
# Pydantic层面的约束仍然有效(创建特定子类时会验证必填字段)
|
||||
column.nullable = True
|
||||
|
||||
# 将列添加到父表
|
||||
parent_table.append_column(column)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 创建列 {field_name} 失败: {e}")
|
||||
|
||||
@classmethod
|
||||
def _register_sti_column_properties(cls) -> None:
|
||||
"""
|
||||
将 STI 子类的列作为 ColumnProperty 添加到 mapper
|
||||
|
||||
此方法在 configure_mappers() 之后调用,将已添加到表中的列
|
||||
注册为 mapper 的属性,使 ORM 查询能正确识别这些列。
|
||||
|
||||
**重要**:子类的列属性会同时注册到子类和父类的 mapper 上。
|
||||
这确保了查询父类时,SELECT 语句包含所有 STI 子类的列,
|
||||
避免在响应序列化时触发懒加载(MissingGreenlet 错误)。
|
||||
|
||||
JTI(联表继承)类会被自动跳过,因为它们有自己独立的表。
|
||||
"""
|
||||
# 查找父表和父类(在 MRO 中找到第一个有 __table__ 的父类)
|
||||
parent_table = None
|
||||
parent_class = None
|
||||
for base in cls.__mro__[1:]:
|
||||
if hasattr(base, '__table__') and base.__table__ is not None:
|
||||
parent_table = base.__table__
|
||||
parent_class = base
|
||||
break
|
||||
|
||||
if parent_table is None:
|
||||
return # 没有找到父表,可能是根类
|
||||
|
||||
# JTI 检测:如果当前类有自己的表且与父表不同,则是 JTI
|
||||
# JTI 类有自己独立的表,不需要将列属性注册到 mapper
|
||||
if hasattr(cls, '__table__') and cls.__table__ is not None:
|
||||
if cls.__table__.name != parent_table.name:
|
||||
return # JTI,跳过 STI 列属性注册
|
||||
|
||||
# 获取子类和父类的 mapper
|
||||
child_mapper = inspect(cls).mapper
|
||||
parent_mapper = inspect(parent_class).mapper
|
||||
local_table = child_mapper.local_table
|
||||
|
||||
# 查找父类的所有字段名
|
||||
parent_fields: set[str] = set()
|
||||
if hasattr(parent_class, 'model_fields'):
|
||||
parent_fields.update(parent_class.model_fields.keys())
|
||||
|
||||
if not hasattr(cls, 'model_fields'):
|
||||
return
|
||||
|
||||
# 获取两个 mapper 已有的列属性
|
||||
child_existing_props = {p.key for p in child_mapper.column_attrs}
|
||||
parent_existing_props = {p.key for p in parent_mapper.column_attrs}
|
||||
|
||||
for field_name in cls.model_fields:
|
||||
# 跳过从父类继承的字段
|
||||
if field_name in parent_fields:
|
||||
continue
|
||||
|
||||
# 跳过私有字段
|
||||
if field_name.startswith('_'):
|
||||
continue
|
||||
|
||||
# 检查表中是否有这个列
|
||||
if field_name not in local_table.columns:
|
||||
continue
|
||||
|
||||
column = local_table.columns[field_name]
|
||||
|
||||
# 添加到子类的 mapper(如果尚不存在)
|
||||
if field_name not in child_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
child_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为 {cls.__name__} 添加列属性 {field_name} 失败: {e}")
|
||||
|
||||
# 同时添加到父类的 mapper(确保查询父类时 SELECT 包含所有 STI 子类的列)
|
||||
if field_name not in parent_existing_props:
|
||||
try:
|
||||
prop = ColumnProperty(column)
|
||||
parent_mapper.add_property(field_name, prop)
|
||||
except Exception as e:
|
||||
l.warning(f"为父类 {parent_class.__name__} 添加子类 {cls.__name__} 的列属性 {field_name} 失败: {e}")
|
||||
|
||||
|
||||
class PolymorphicBaseMixin:
|
||||
"""
|
||||
为联表继承链中的基类自动配置 polymorphic 设置的 Mixin
|
||||
|
||||
此 Mixin 自动设置以下内容:
|
||||
- `polymorphic_on='_polymorphic_name'`: 使用 _polymorphic_name 字段作为多态鉴别器
|
||||
- `_polymorphic_name: str`: 定义多态鉴别器字段(带索引)
|
||||
- `polymorphic_abstract=True`: 当类继承自 ABC 且有抽象方法时,自动标记为抽象类
|
||||
|
||||
使用场景:
|
||||
适用于需要 joined table inheritance 的基类,例如 Tool、ASR、TTS 等。
|
||||
|
||||
用法示例:
|
||||
```python
|
||||
from abc import ABC
|
||||
from sqlmodels.mixin import UUIDTableBaseMixin
|
||||
from sqlmodels.mixin.polymorphic import PolymorphicBaseMixin
|
||||
|
||||
# 定义基类
|
||||
class MyTool(UUIDTableBaseMixin, PolymorphicBaseMixin, ABC):
|
||||
__tablename__ = 'mytool'
|
||||
|
||||
# 不需要手动定义 _polymorphic_name
|
||||
# 不需要手动设置 polymorphic_on
|
||||
# 不需要手动设置 polymorphic_abstract
|
||||
|
||||
# 定义子类
|
||||
class SpecificTool(MyTool):
|
||||
__tablename__ = 'specifictool'
|
||||
|
||||
# 会自动继承 polymorphic 配置
|
||||
```
|
||||
|
||||
自动行为:
|
||||
1. 定义 `_polymorphic_name: str` 字段(带索引)
|
||||
2. 设置 `__mapper_args__['polymorphic_on'] = '_polymorphic_name'`
|
||||
3. 自动检测抽象类:
|
||||
- 如果类继承了 ABC 且有未实现的抽象方法,设置 polymorphic_abstract=True
|
||||
- 否则设置为 False
|
||||
|
||||
手动覆盖:
|
||||
可以在类定义时手动指定参数来覆盖自动行为:
|
||||
```python
|
||||
class MyTool(
|
||||
UUIDTableBaseMixin,
|
||||
PolymorphicBaseMixin,
|
||||
ABC,
|
||||
polymorphic_on='custom_field', # 覆盖默认的 _polymorphic_name
|
||||
polymorphic_abstract=False # 强制不设为抽象类
|
||||
):
|
||||
pass
|
||||
```
|
||||
|
||||
注意事项:
|
||||
- 此 Mixin 应该与 UUIDTableBaseMixin 或 TableBaseMixin 配合使用
|
||||
- 适用于联表继承(joined table inheritance)场景
|
||||
- 子类会自动继承 _polymorphic_name 字段定义
|
||||
- 使用单下划线前缀是因为:
|
||||
* SQLAlchemy 会映射单下划线字段为数据库列
|
||||
* Pydantic 将其视为私有属性,不参与序列化
|
||||
* 双下划线字段会被 SQLAlchemy 排除,不映射为数据库列
|
||||
"""
|
||||
|
||||
# 定义 _polymorphic_name 字段,所有使用此 mixin 的类都会有这个字段
|
||||
#
|
||||
# 设计选择:使用单下划线前缀 + Mapped[str] + mapped_column
|
||||
#
|
||||
# 为什么这样做:
|
||||
# 1. 单下划线前缀表示"内部实现细节",防止外部通过 API 直接修改
|
||||
# 2. Mapped + mapped_column 绕过 Pydantic v2 的字段名限制(不允许下划线前缀)
|
||||
# 3. 字段仍然被 SQLAlchemy 映射到数据库,供多态查询使用
|
||||
# 4. 字段不出现在 Pydantic 序列化中(model_dump() 和 JSON schema)
|
||||
# 5. 内部代码仍然可以正常访问和修改此字段
|
||||
#
|
||||
# 详细说明请参考:sqlmodels/base/POLYMORPHIC_NAME_DESIGN.md
|
||||
_polymorphic_name: Mapped[str] = mapped_column(String, index=True)
|
||||
"""
|
||||
多态鉴别器字段,用于标识具体的子类类型
|
||||
|
||||
注意:此字段使用单下划线前缀,表示内部使用。
|
||||
- ✅ 存储到数据库
|
||||
- ✅ 不出现在 API 序列化中
|
||||
- ✅ 防止外部直接修改
|
||||
"""
|
||||
|
||||
def __init_subclass__(
|
||||
cls,
|
||||
polymorphic_on: str | None = None,
|
||||
polymorphic_abstract: bool | None = None,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
在子类定义时自动配置 polymorphic 设置
|
||||
|
||||
Args:
|
||||
polymorphic_on: polymorphic_on 字段名,默认为 '_polymorphic_name'。
|
||||
设置为其他值可以使用不同的字段作为多态鉴别器。
|
||||
polymorphic_abstract: 是否为抽象类。
|
||||
- None: 自动检测(默认)
|
||||
- True: 强制设为抽象类
|
||||
- False: 强制设为非抽象类
|
||||
**kwargs: 传递给父类的其他参数
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 初始化 __mapper_args__(如果还没有)
|
||||
if '__mapper_args__' not in cls.__dict__:
|
||||
cls.__mapper_args__ = {}
|
||||
|
||||
# 设置 polymorphic_on(默认为 _polymorphic_name)
|
||||
if 'polymorphic_on' not in cls.__mapper_args__:
|
||||
cls.__mapper_args__['polymorphic_on'] = polymorphic_on or '_polymorphic_name'
|
||||
|
||||
# 自动检测或设置 polymorphic_abstract
|
||||
if 'polymorphic_abstract' not in cls.__mapper_args__:
|
||||
if polymorphic_abstract is None:
|
||||
# 自动检测:如果继承了 ABC 且有抽象方法,则为抽象类
|
||||
has_abc = ABC in cls.__mro__
|
||||
has_abstract_methods = bool(getattr(cls, '__abstractmethods__', set()))
|
||||
polymorphic_abstract = has_abc and has_abstract_methods
|
||||
|
||||
cls.__mapper_args__['polymorphic_abstract'] = polymorphic_abstract
|
||||
|
||||
@classmethod
|
||||
def _is_joined_table_inheritance(cls) -> bool:
|
||||
"""
|
||||
检测当前类是否使用联表继承(Joined Table Inheritance)
|
||||
|
||||
通过检查子类是否有独立的表来判断:
|
||||
- JTI: 子类有独立的 local_table(与父类不同)
|
||||
- STI: 子类与父类共用同一个 local_table
|
||||
|
||||
:return: True 表示 JTI,False 表示 STI 或无子类
|
||||
"""
|
||||
mapper = inspect(cls)
|
||||
base_table_name = mapper.local_table.name
|
||||
|
||||
# 检查所有直接子类
|
||||
for subclass in cls.__subclasses__():
|
||||
sub_mapper = inspect(subclass)
|
||||
# 如果任何子类有不同的表名,说明是 JTI
|
||||
if sub_mapper.local_table.name != base_table_name:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_concrete_subclasses(cls) -> list[type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
递归获取当前类的所有具体(非抽象)子类
|
||||
|
||||
用于 selectin_polymorphic 加载策略,自动检测联表继承的所有具体子类。
|
||||
可在任意多态基类上调用,返回该类的所有非抽象子类。
|
||||
|
||||
:return: 所有具体子类的列表(不包含 polymorphic_abstract=True 的抽象类)
|
||||
"""
|
||||
result: list[type[PolymorphicBaseMixin]] = []
|
||||
for subclass in cls.__subclasses__():
|
||||
# 使用 inspect() 获取 mapper 的公开属性
|
||||
# 源码确认: mapper.polymorphic_abstract 是公开属性 (mapper.py:811)
|
||||
mapper = inspect(subclass)
|
||||
if not mapper.polymorphic_abstract:
|
||||
result.append(subclass)
|
||||
# 无论是否抽象,都需要递归(抽象类可能有具体子类)
|
||||
if hasattr(subclass, 'get_concrete_subclasses'):
|
||||
result.extend(subclass.get_concrete_subclasses())
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_polymorphic_discriminator(cls) -> str:
|
||||
"""
|
||||
获取多态鉴别字段名
|
||||
|
||||
使用 SQLAlchemy inspect 从 mapper 获取,支持从子类调用。
|
||||
|
||||
:return: 多态鉴别字段名(如 '_polymorphic_name')
|
||||
:raises ValueError: 如果类未配置 polymorphic_on
|
||||
"""
|
||||
polymorphic_on = inspect(cls).polymorphic_on
|
||||
if polymorphic_on is None:
|
||||
raise ValueError(
|
||||
f"{cls.__name__} 未配置 polymorphic_on,"
|
||||
f"请确保正确继承 PolymorphicBaseMixin"
|
||||
)
|
||||
return polymorphic_on.key
|
||||
|
||||
@classmethod
|
||||
def get_identity_to_class_map(cls) -> dict[str, type['PolymorphicBaseMixin']]:
|
||||
"""
|
||||
获取 polymorphic_identity 到具体子类的映射
|
||||
|
||||
包含所有层级的具体子类(如 Function 和 ModelSwitchFunction 都会被包含)。
|
||||
|
||||
:return: identity 到子类的映射字典
|
||||
"""
|
||||
result: dict[str, type[PolymorphicBaseMixin]] = {}
|
||||
for subclass in cls.get_concrete_subclasses():
|
||||
identity = inspect(subclass).polymorphic_identity
|
||||
if identity:
|
||||
result[identity] = subclass
|
||||
return result
|
||||
470
sqlmodels/mixin/relation_preload.py
Normal file
470
sqlmodels/mixin/relation_preload.py
Normal file
@@ -0,0 +1,470 @@
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供方法级别的关系声明和按需增量加载,避免 MissingGreenlet 错误,同时保证 SQL 查询数理论最优。
|
||||
|
||||
设计原则:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 查询最优:相同关系只查询一次,不同关系增量查询
|
||||
- 零侵入:调用方无需任何改动
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实加载状态,自动处理 expire
|
||||
|
||||
使用方式:
|
||||
from sqlmodels.mixin import RelationPreloadMixin, requires_relations
|
||||
|
||||
class KlingO1VideoFunction(RelationPreloadMixin, Function, table=True):
|
||||
kling_video_generator: KlingO1Generator = Relationship(...)
|
||||
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# 自动加载,可以安全访问
|
||||
price = self.kling_video_generator.kling_o1.pro_price_per_second
|
||||
...
|
||||
|
||||
# 调用方 - 无需任何改动
|
||||
await tool.cost(params, context, session) # 自动加载 cost 需要的关系
|
||||
await tool._call(...) # 关系相同则跳过,否则增量加载
|
||||
|
||||
支持 AsyncGenerator:
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # 装饰器正确处理 async generator
|
||||
"""
|
||||
import inspect as python_inspect
|
||||
from functools import wraps
|
||||
from typing import Callable, TypeVar, ParamSpec, Any
|
||||
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
P = ParamSpec('P')
|
||||
R = TypeVar('R')
|
||||
|
||||
|
||||
def _extract_session(
|
||||
func: Callable,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> AsyncSession | None:
|
||||
"""
|
||||
从方法参数中提取 AsyncSession
|
||||
|
||||
按以下顺序查找:
|
||||
1. kwargs 中名为 'session' 的参数
|
||||
2. 根据函数签名定位 'session' 参数的位置,从 args 提取
|
||||
3. kwargs 中类型为 AsyncSession 的参数
|
||||
"""
|
||||
# 1. 优先从 kwargs 查找
|
||||
if 'session' in kwargs:
|
||||
return kwargs['session']
|
||||
|
||||
# 2. 从函数签名定位位置参数
|
||||
try:
|
||||
sig = python_inspect.signature(func)
|
||||
param_names = list(sig.parameters.keys())
|
||||
|
||||
if 'session' in param_names:
|
||||
# 计算位置(减去 self)
|
||||
idx = param_names.index('session') - 1
|
||||
if 0 <= idx < len(args):
|
||||
return args[idx]
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# 3. 遍历 kwargs 找 AsyncSession 类型
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, AsyncSession):
|
||||
return value
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _is_obj_relation_loaded(obj: Any, rel_name: str) -> bool:
|
||||
"""
|
||||
检查对象的关系是否已加载(独立函数版本)
|
||||
|
||||
Args:
|
||||
obj: 要检查的对象
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(obj)
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _find_relation_to_class(from_class: type, to_class: type) -> str | None:
|
||||
"""
|
||||
在类中查找指向目标类的关系属性名
|
||||
|
||||
Args:
|
||||
from_class: 源类
|
||||
to_class: 目标类
|
||||
|
||||
Returns:
|
||||
关系属性名,如果找不到则返回 None
|
||||
|
||||
Example:
|
||||
_find_relation_to_class(KlingO1VideoFunction, KlingO1Generator)
|
||||
# 返回 'kling_video_generator'
|
||||
"""
|
||||
for attr_name in dir(from_class):
|
||||
try:
|
||||
attr = getattr(from_class, attr_name, None)
|
||||
if attr is None:
|
||||
continue
|
||||
# 检查是否是 SQLAlchemy InstrumentedAttribute(关系属性)
|
||||
# parent.class_ 是关系所在的类,property.mapper.class_ 是关系指向的目标类
|
||||
if hasattr(attr, 'property') and hasattr(attr.property, 'mapper'):
|
||||
target_class = attr.property.mapper.class_
|
||||
if target_class == to_class:
|
||||
return attr_name
|
||||
except AttributeError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def requires_relations(*relations: str | RelationshipInfo) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
"""
|
||||
装饰器:声明方法需要的关系,自动按需增量加载
|
||||
|
||||
参数格式:
|
||||
- 字符串:本类属性名,如 'kling_video_generator'
|
||||
- RelationshipInfo:外部类属性,如 KlingO1Generator.kling_o1
|
||||
|
||||
行为:
|
||||
- 方法调用时自动检查关系是否已加载
|
||||
- 未加载的关系会被增量加载(单次查询)
|
||||
- 已加载的关系直接跳过
|
||||
|
||||
支持:
|
||||
- 普通 async 方法:`async def cost(...) -> ToolCost`
|
||||
- AsyncGenerator 方法:`async def _call(...) -> AsyncGenerator[ToolResponse, None]`
|
||||
|
||||
Example:
|
||||
@requires_relations('kling_video_generator', KlingO1Generator.kling_o1)
|
||||
async def cost(self, params, context, session) -> ToolCost:
|
||||
# self.kling_video_generator.kling_o1 已自动加载
|
||||
...
|
||||
|
||||
@requires_relations('twitter_api')
|
||||
async def _call(self, ...) -> AsyncGenerator[ToolResponse, None]:
|
||||
yield ToolResponse(...) # AsyncGenerator 正确处理
|
||||
|
||||
验证:
|
||||
- 字符串格式的关系名在类创建时(__init_subclass__)验证
|
||||
- 拼写错误会在导入时抛出 AttributeError
|
||||
"""
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
# 检测是否是 async generator 函数
|
||||
is_async_gen = python_inspect.isasyncgenfunction(func)
|
||||
|
||||
if is_async_gen:
|
||||
# AsyncGenerator 需要特殊处理:wrapper 也必须是 async generator
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
# 委托给原始 async generator,逐个 yield 值
|
||||
async for item in func(self, *args, **kwargs):
|
||||
yield item # type: ignore
|
||||
else:
|
||||
# 普通 async 函数:await 并返回结果
|
||||
@wraps(func)
|
||||
async def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
session = _extract_session(func, args, kwargs)
|
||||
if session is not None:
|
||||
await self._ensure_relations_loaded(session, relations)
|
||||
return await func(self, *args, **kwargs)
|
||||
|
||||
# 保存关系声明供验证和内省使用
|
||||
wrapper._required_relations = relations # type: ignore
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RelationPreloadMixin:
|
||||
"""
|
||||
关系预加载 Mixin
|
||||
|
||||
提供按需增量加载能力,确保 SQL 查询数理论最优。
|
||||
|
||||
特性:
|
||||
- 按需加载:只加载被调用方法需要的关系
|
||||
- 增量加载:已加载的关系不重复加载
|
||||
- 原地更新:直接修改 self,无需替换实例
|
||||
- 导入时验证:字符串关系名在类创建时验证
|
||||
- Commit 安全:基于 SQLAlchemy inspect 检测真实状态,自动处理 expire
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs) -> None:
|
||||
"""类创建时验证所有 @requires_relations 声明"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
# 收集类及其父类的所有注解(包含普通字段)
|
||||
all_annotations: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__annotations__'):
|
||||
all_annotations.update(klass.__annotations__.keys())
|
||||
|
||||
# 收集 SQLModel 的 Relationship 字段(存储在 __sqlmodel_relationships__)
|
||||
sqlmodel_relationships: set[str] = set()
|
||||
for klass in cls.__mro__:
|
||||
if hasattr(klass, '__sqlmodel_relationships__'):
|
||||
sqlmodel_relationships.update(klass.__sqlmodel_relationships__.keys())
|
||||
|
||||
# 合并所有可用的属性名
|
||||
all_available_names = all_annotations | sqlmodel_relationships
|
||||
|
||||
for method_name in dir(cls):
|
||||
if method_name.startswith('__'):
|
||||
continue
|
||||
|
||||
try:
|
||||
method = getattr(cls, method_name, None)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
continue
|
||||
|
||||
# 验证字符串格式的关系名
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
# 检查注解、Relationship 或已有属性
|
||||
if spec not in all_available_names and not hasattr(cls, spec):
|
||||
raise AttributeError(
|
||||
f"{cls.__name__}.{method_name} 声明了关系 '{spec}',"
|
||||
f"但 {cls.__name__} 没有此属性"
|
||||
)
|
||||
|
||||
def _is_relation_loaded(self, rel_name: str) -> bool:
|
||||
"""
|
||||
检查关系是否真正已加载(基于 SQLAlchemy inspect)
|
||||
|
||||
使用 SQLAlchemy 的 inspect 检测真实加载状态,
|
||||
自动处理 commit 导致的 expire 问题。
|
||||
|
||||
Args:
|
||||
rel_name: 关系属性名
|
||||
|
||||
Returns:
|
||||
True 如果关系已加载,False 如果未加载或已过期
|
||||
"""
|
||||
try:
|
||||
state = sa_inspect(self)
|
||||
# unloaded 包含未加载的关系属性名
|
||||
return rel_name not in state.unloaded
|
||||
except Exception:
|
||||
# 对象可能未被 SQLAlchemy 管理
|
||||
return False
|
||||
|
||||
async def _ensure_relations_loaded(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
relations: tuple[str | RelationshipInfo, ...],
|
||||
) -> None:
|
||||
"""
|
||||
确保指定关系已加载,只加载未加载的部分
|
||||
|
||||
基于 SQLAlchemy inspect 检测真实状态,自动处理:
|
||||
- 首次访问的关系
|
||||
- commit 后 expire 的关系
|
||||
- 嵌套关系(如 KlingO1Generator.kling_o1)
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
relations: 需要的关系规格
|
||||
"""
|
||||
# 找出真正未加载的关系(基于 SQLAlchemy inspect)
|
||||
to_load: list[str | RelationshipInfo] = []
|
||||
# 区分直接关系和嵌套关系的 key
|
||||
direct_keys: set[str] = set() # 本类的直接关系属性名
|
||||
nested_parent_keys: set[str] = set() # 嵌套关系所需的父关系属性名
|
||||
|
||||
for rel in relations:
|
||||
if isinstance(rel, str):
|
||||
# 直接关系:检查本类的关系是否已加载
|
||||
if not self._is_relation_loaded(rel):
|
||||
to_load.append(rel)
|
||||
direct_keys.add(rel)
|
||||
else:
|
||||
# 嵌套关系(InstrumentedAttribute):如 KlingO1Generator.kling_o1
|
||||
# 1. 查找指向父类的关系属性
|
||||
parent_class = rel.parent.class_
|
||||
parent_attr = _find_relation_to_class(self.__class__, parent_class)
|
||||
|
||||
if parent_attr is None:
|
||||
# 找不到路径,可能是配置错误,但仍尝试加载
|
||||
l.warning(
|
||||
f"无法找到从 {self.__class__.__name__} 到 {parent_class.__name__} 的关系路径,"
|
||||
f"无法检查 {rel.key} 是否已加载"
|
||||
)
|
||||
to_load.append(rel)
|
||||
continue
|
||||
|
||||
# 2. 检查父对象是否已加载
|
||||
if not self._is_relation_loaded(parent_attr):
|
||||
# 父对象未加载,需要同时加载父对象和嵌套关系
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
else:
|
||||
# 3. 父对象已加载,检查嵌套关系是否已加载
|
||||
parent_obj = getattr(self, parent_attr)
|
||||
if not _is_obj_relation_loaded(parent_obj, rel.key):
|
||||
# 嵌套关系未加载:需要同时传递父关系和嵌套关系
|
||||
# 因为 _build_load_chains 需要完整的链来构建 selectinload
|
||||
if parent_attr not in direct_keys and parent_attr not in nested_parent_keys:
|
||||
to_load.append(parent_attr)
|
||||
nested_parent_keys.add(parent_attr)
|
||||
to_load.append(rel)
|
||||
|
||||
if not to_load:
|
||||
return # 全部已加载,跳过
|
||||
|
||||
# 构建 load 参数
|
||||
load_options = self._specs_to_load_options(to_load)
|
||||
if not load_options:
|
||||
return
|
||||
|
||||
# 安全地获取主键值(避免触发懒加载)
|
||||
state = sa_inspect(self)
|
||||
pk_tuple = state.key[1] if state.key else None
|
||||
if pk_tuple is None:
|
||||
l.warning(f"无法获取 {self.__class__.__name__} 的主键值")
|
||||
return
|
||||
# 主键是元组,取第一个值(假设单列主键)
|
||||
pk_value = pk_tuple[0]
|
||||
|
||||
# 单次查询加载缺失的关系
|
||||
fresh = await self.__class__.get(
|
||||
session,
|
||||
self.__class__.id == pk_value,
|
||||
load=load_options,
|
||||
)
|
||||
|
||||
if fresh is None:
|
||||
l.warning(f"无法加载关系:{self.__class__.__name__} id={self.id} 不存在")
|
||||
return
|
||||
|
||||
# 原地复制到 self(只复制直接关系,嵌套关系通过父关系自动可访问)
|
||||
all_direct_keys = direct_keys | nested_parent_keys
|
||||
for key in all_direct_keys:
|
||||
value = getattr(fresh, key, None)
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def _specs_to_load_options(
|
||||
self,
|
||||
specs: list[str | RelationshipInfo],
|
||||
) -> list[RelationshipInfo]:
|
||||
"""
|
||||
将关系规格转换为 load 参数
|
||||
|
||||
- 字符串 → cls.{name}
|
||||
- RelationshipInfo → 直接使用
|
||||
"""
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for spec in specs:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(self.__class__, spec, None)
|
||||
if rel is not None:
|
||||
result.append(rel)
|
||||
else:
|
||||
l.warning(f"关系 '{spec}' 在类 {self.__class__.__name__} 中不存在")
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
# ==================== 可选的手动预加载 API ====================
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_method(cls, method_name: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取指定方法声明的关系(用于外部预加载场景)
|
||||
|
||||
Args:
|
||||
method_name: 方法名
|
||||
|
||||
Returns:
|
||||
RelationshipInfo 列表
|
||||
"""
|
||||
method = getattr(cls, method_name, None)
|
||||
if method is None or not hasattr(method, '_required_relations'):
|
||||
return []
|
||||
|
||||
result: list[RelationshipInfo] = []
|
||||
for spec in method._required_relations:
|
||||
if isinstance(spec, str):
|
||||
rel = getattr(cls, spec, None)
|
||||
if rel:
|
||||
result.append(rel)
|
||||
else:
|
||||
result.append(spec)
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_relations_for_methods(cls, *method_names: str) -> list[RelationshipInfo]:
|
||||
"""
|
||||
获取多个方法的关系并去重(用于批量预加载场景)
|
||||
|
||||
Args:
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
去重后的 RelationshipInfo 列表
|
||||
"""
|
||||
seen: set[str] = set()
|
||||
result: list[RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
for rel in cls.get_relations_for_method(method_name):
|
||||
key = rel.key
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
result.append(rel)
|
||||
|
||||
return result
|
||||
|
||||
async def preload_for(self, session: AsyncSession, *method_names: str) -> 'RelationPreloadMixin':
|
||||
"""
|
||||
手动预加载指定方法的关系(可选优化 API)
|
||||
|
||||
当需要确保在调用方法前完成所有加载时使用。
|
||||
通常情况下不需要调用此方法,装饰器会自动处理。
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
method_names: 方法名列表
|
||||
|
||||
Returns:
|
||||
self(支持链式调用)
|
||||
|
||||
Example:
|
||||
# 可选:显式预加载(通常不需要)
|
||||
tool = await tool.preload_for(session, 'cost', '_call')
|
||||
"""
|
||||
all_relations: list[str | RelationshipInfo] = []
|
||||
|
||||
for method_name in method_names:
|
||||
method = getattr(self.__class__, method_name, None)
|
||||
if method and hasattr(method, '_required_relations'):
|
||||
all_relations.extend(method._required_relations)
|
||||
|
||||
if all_relations:
|
||||
await self._ensure_relations_loaded(session, tuple(all_relations))
|
||||
|
||||
return self
|
||||
@@ -12,7 +12,14 @@
|
||||
mixin/table.py ← 当前文件,导入 PolymorphicBaseMixin
|
||||
↓
|
||||
base/__init__.py ← 从 mixin 重新导出(保持向后兼容)
|
||||
|
||||
维护须知:
|
||||
增删功能时必须更新 __version__ 字段(遵循语义化版本)
|
||||
|
||||
版本历史:
|
||||
0.1.0 - delete() 方法支持条件删除(condition 参数)
|
||||
"""
|
||||
__version__ = "0.1.0"
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TypeVar, Literal, override, Any, ClassVar, Generic
|
||||
@@ -26,16 +33,19 @@ from typing import TypeVar, Literal, override, Any, ClassVar, Generic
|
||||
# 未来: PR #1275合并后可改回继承SQLModelBase
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement, desc, asc, func, distinct, delete as sql_delete, inspect
|
||||
from sqlalchemy.orm import selectinload, Relationship, with_polymorphic
|
||||
from sqlalchemy.orm.exc import StaleDataError
|
||||
from sqlmodel import Field, select
|
||||
|
||||
from .optimistic_lock import OptimisticLockError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel.main import RelationshipInfo
|
||||
|
||||
from .polymorphic import PolymorphicBaseMixin
|
||||
from models.base.sqlmodel_base import SQLModelBase
|
||||
from sqlmodels.base.sqlmodel_base import SQLModelBase
|
||||
|
||||
# Type variables for generic type hints, improving code completion and analysis.
|
||||
T = TypeVar("T", bound="TableBaseMixin")
|
||||
@@ -196,8 +206,8 @@ class TableBaseMixin(AsyncAttrs):
|
||||
created_at (datetime): 记录创建时的时间戳, 自动设置.
|
||||
updated_at (datetime): 记录每次更新时的时间戳, 自动更新.
|
||||
"""
|
||||
_is_table_mixin: ClassVar[bool] = True
|
||||
"""标记此类为表混入类的内部属性"""
|
||||
_has_table_mixin: ClassVar[bool] = True
|
||||
"""标记此类继承了表混入类的内部属性"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
"""
|
||||
@@ -218,7 +228,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]:
|
||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
|
||||
"""
|
||||
向数据库中添加一个新的或多个新的记录.
|
||||
|
||||
@@ -230,8 +240,6 @@ class TableBaseMixin(AsyncAttrs):
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
|
||||
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
|
||||
Returns:
|
||||
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
|
||||
@@ -246,11 +254,6 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# 添加单个实例
|
||||
item3 = Item(name="Cherry")
|
||||
added_item = await Item.add(session, item3)
|
||||
|
||||
# 批量操作,减少提交次数
|
||||
await Item.add(session, [item1, item2], commit=False)
|
||||
await Item.add(session, [item3, item4], commit=False)
|
||||
await session.commit()
|
||||
"""
|
||||
is_list = False
|
||||
if isinstance(instances, list):
|
||||
@@ -259,10 +262,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
else:
|
||||
session.add(instances)
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
|
||||
if refresh:
|
||||
if is_list:
|
||||
@@ -278,14 +278,16 @@ class TableBaseMixin(AsyncAttrs):
|
||||
session: AsyncSession,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
refresh: bool = True,
|
||||
commit: bool = True
|
||||
commit: bool = True,
|
||||
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
optimistic_retry_count: int = 0,
|
||||
) -> T:
|
||||
"""
|
||||
保存(插入或更新)当前模型实例到数据库.
|
||||
|
||||
这是一个实例方法,它将当前对象添加到会话中并提交更改。
|
||||
可以用于创建新记录或更新现有记录。还可以选择在保存后
|
||||
预加载(eager load)一个或多个关联关系.
|
||||
预加载(eager load)一个关联关系.
|
||||
|
||||
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
||||
如果需要继续使用该对象,必须使用返回值:
|
||||
@@ -298,13 +300,17 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.save(session, refresh=False)
|
||||
|
||||
# ✅ 正确:批量操作,减少提交次数
|
||||
await item1.save(session, commit=False)
|
||||
await item2.save(session, commit=False)
|
||||
# ✅ 正确:批量操作时延迟提交
|
||||
for item in items:
|
||||
item = await item.save(session, commit=False)
|
||||
await session.commit()
|
||||
|
||||
# ✅ 正确:批量操作并预加载多个关联关系
|
||||
user = await user.save(session, load=[User.group, User.tags])
|
||||
# ✅ 正确:保存后需要访问多态关系时
|
||||
tool_set = await tool_set.save(session, load=ToolSet.tools, jti_subclasses='all')
|
||||
return tool_set # tools 关系已正确加载子类数据
|
||||
|
||||
# ✅ 正确:启用乐观锁自动重试
|
||||
order = await order.save(session, optimistic_retry_count=3)
|
||||
|
||||
# ❌ 错误:需要返回值但未使用
|
||||
await client.save(session)
|
||||
@@ -313,34 +319,77 @@ class TableBaseMixin(AsyncAttrs):
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
load (Relationship | list[Relationship] | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.posts` 或 `[User.group, User.tags]`.
|
||||
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||
例如 `User.posts`.
|
||||
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
commit (bool): 是否在保存后提交事务。如果为 False,只会 flush 获取 ID
|
||||
但不提交,适用于批量操作场景。默认为 True.
|
||||
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
|
||||
- 'all': 两阶段查询,只加载实际关联的子类
|
||||
- None(默认): 不使用多态加载
|
||||
optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0(不重试)。
|
||||
重试时会重新查询最新数据,将当前修改合并后再次保存。
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
|
||||
Raises:
|
||||
OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽
|
||||
"""
|
||||
session.add(self)
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
cls = type(self)
|
||||
instance = self
|
||||
retries_remaining = optimistic_retry_count
|
||||
current_data: dict[str, Any] | None = None # 延迟计算,仅在需要重试时
|
||||
|
||||
while True:
|
||||
session.add(instance)
|
||||
try:
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
break # 成功,退出循环
|
||||
except StaleDataError as e:
|
||||
await session.rollback()
|
||||
if retries_remaining <= 0:
|
||||
raise OptimisticLockError(
|
||||
message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
|
||||
model_class=cls.__name__,
|
||||
record_id=str(getattr(instance, 'id', None)),
|
||||
expected_version=getattr(instance, 'version', None),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
# 失败后重试:重新查询最新数据并合并修改
|
||||
retries_remaining -= 1
|
||||
if current_data is None:
|
||||
current_data = self.model_dump(exclude={'id', 'version', 'created_at', 'updated_at'})
|
||||
|
||||
fresh = await cls.get(session, cls.id == self.id)
|
||||
if fresh is None:
|
||||
raise OptimisticLockError(
|
||||
message=f"{cls.__name__} 重试失败:记录已被删除",
|
||||
model_class=cls.__name__,
|
||||
record_id=str(getattr(self, 'id', None)),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
for key, value in current_data.items():
|
||||
if hasattr(fresh, key):
|
||||
setattr(fresh, key, value)
|
||||
instance = fresh
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
return instance
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
await session.refresh(self)
|
||||
# 如果指定了 load, 重新获取实例并加载关联关系
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
await session.refresh(instance)
|
||||
return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
async def update(
|
||||
self: T,
|
||||
@@ -351,7 +400,9 @@ class TableBaseMixin(AsyncAttrs):
|
||||
exclude: set[str] | None = None,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
refresh: bool = True,
|
||||
commit: bool = True
|
||||
commit: bool = True,
|
||||
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
optimistic_retry_count: int = 0,
|
||||
) -> T:
|
||||
"""
|
||||
使用另一个模型实例或字典中的数据来更新当前实例.
|
||||
@@ -371,16 +422,20 @@ class TableBaseMixin(AsyncAttrs):
|
||||
user = await user.update(session, update_data, load=User.permission)
|
||||
return user
|
||||
|
||||
# ✅ 正确:更新后需要访问多态关系时
|
||||
tool_set = await tool_set.update(session, data, load=ToolSet.tools, jti_subclasses='all')
|
||||
return tool_set # tools 关系已正确加载子类数据
|
||||
|
||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.update(session, update_data, refresh=False)
|
||||
|
||||
# ✅ 正确:批量操作,减少提交次数
|
||||
await user1.update(session, data1, commit=False)
|
||||
await user2.update(session, data2, commit=False)
|
||||
# ✅ 正确:批量操作时延迟提交
|
||||
for item in items:
|
||||
item = await item.update(session, data, commit=False)
|
||||
await session.commit()
|
||||
|
||||
# ✅ 正确:批量操作并预加载多个关联关系
|
||||
user = await user.update(session, data, load=[User.group, User.tags])
|
||||
# ✅ 正确:启用乐观锁自动重试
|
||||
order = await order.update(session, update_data, optimistic_retry_count=3)
|
||||
|
||||
# ❌ 错误:需要返回值但未使用
|
||||
await client.update(session, update_data)
|
||||
@@ -394,111 +449,134 @@ class TableBaseMixin(AsyncAttrs):
|
||||
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
|
||||
的字段将被忽略. 默认为 True.
|
||||
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
|
||||
load (Relationship | list[Relationship] | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.permission` 或 `[User.group, User.tags]`.
|
||||
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||
例如 `User.permission`.
|
||||
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
commit (bool): 是否在更新后提交事务。如果为 False,只会 flush
|
||||
但不提交,适用于批量操作场景。默认为 True.
|
||||
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
|
||||
- 'all': 两阶段查询,只加载实际关联的子类
|
||||
- None(默认): 不使用多态加载
|
||||
optimistic_retry_count (int): 乐观锁冲突时的自动重试次数。默认为 0(不重试)。
|
||||
重试时会重新查询最新数据,将 other 的更新重新应用后再次保存。
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
"""
|
||||
self.sqlmodel_update(
|
||||
other.model_dump(exclude_unset=exclude_unset, exclude=exclude),
|
||||
update=extra_data
|
||||
)
|
||||
|
||||
session.add(self)
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
Raises:
|
||||
OptimisticLockError: 如果启用了乐观锁且版本号不匹配,且重试次数已耗尽
|
||||
"""
|
||||
cls = type(self)
|
||||
update_data = other.model_dump(exclude_unset=exclude_unset, exclude=exclude)
|
||||
instance = self
|
||||
retries_remaining = optimistic_retry_count
|
||||
|
||||
while True:
|
||||
instance.sqlmodel_update(update_data, update=extra_data)
|
||||
session.add(instance)
|
||||
|
||||
try:
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
break # 成功,退出循环
|
||||
except StaleDataError as e:
|
||||
await session.rollback()
|
||||
if retries_remaining <= 0:
|
||||
raise OptimisticLockError(
|
||||
message=f"{cls.__name__} 乐观锁冲突:记录已被其他事务修改",
|
||||
model_class=cls.__name__,
|
||||
record_id=str(getattr(instance, 'id', None)),
|
||||
expected_version=getattr(instance, 'version', None),
|
||||
original_error=e,
|
||||
) from e
|
||||
|
||||
# 失败后重试:重新查询最新数据并重新应用更新
|
||||
retries_remaining -= 1
|
||||
fresh = await cls.get(session, cls.id == self.id)
|
||||
if fresh is None:
|
||||
raise OptimisticLockError(
|
||||
message=f"{cls.__name__} 重试失败:记录已被删除",
|
||||
model_class=cls.__name__,
|
||||
record_id=str(getattr(self, 'id', None)),
|
||||
original_error=e,
|
||||
) from e
|
||||
instance = fresh
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
return instance
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
await session.refresh(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
await session.refresh(instance)
|
||||
return await cls.get(session, cls.id == instance.id, load=load, jti_subclasses=jti_subclasses)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
await session.refresh(instance)
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
instances: T | list[T] | None = None,
|
||||
*,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
commit: bool = True
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
instances: T | list[T] | None = None,
|
||||
*,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
commit: bool = True,
|
||||
) -> int:
|
||||
"""
|
||||
从数据库中删除记录.
|
||||
|
||||
支持两种删除方式:
|
||||
1. 实例删除:传入 instances 参数,先加载再删除
|
||||
2. 条件删除:传入 condition 参数,直接 SQL 删除(更高效)
|
||||
从数据库中删除记录,支持实例删除和条件删除两种模式。
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表(可选).
|
||||
condition (BinaryExpression | ClauseElement | None): 删除条件(可选,与 instances 二选一).
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
session: 用于数据库操作的异步会话对象
|
||||
instances: 要删除的单个模型实例或模型实例列表(实例删除模式)
|
||||
condition: WHERE 条件表达式(条件删除模式,直接执行 SQL DELETE)
|
||||
commit: 是否在删除后提交事务。默认为 True
|
||||
|
||||
Returns:
|
||||
int: 删除的记录数量
|
||||
删除的记录数(条件删除模式返回实际删除数,实例删除模式返回实例数)
|
||||
|
||||
Raises:
|
||||
ValueError: 同时提供 instances 和 condition,或两者都未提供
|
||||
|
||||
Usage:
|
||||
# 实例删除
|
||||
item_to_delete = await Item.get(session, Item.id == 1)
|
||||
if item_to_delete:
|
||||
deleted_count = await Item.delete(session, item_to_delete)
|
||||
# 实例删除模式
|
||||
item = await Item.get(session, Item.id == 1)
|
||||
if item:
|
||||
await Item.delete(session, item)
|
||||
|
||||
# 条件删除(更高效,无需加载实例)
|
||||
items = await Item.get(session, Item.name.in_(["A", "B"]), fetch_mode="all")
|
||||
if items:
|
||||
await Item.delete(session, items)
|
||||
|
||||
# 条件删除模式(高效批量删除,不加载实例到内存)
|
||||
deleted_count = await Item.delete(
|
||||
session,
|
||||
condition=(Item.status == "inactive") & (Item.created_at < cutoff_date)
|
||||
condition=(Item.user_id == user_id) & (Item.status == "expired"),
|
||||
)
|
||||
|
||||
# 批量删除后手动提交
|
||||
await Item.delete(session, item1, commit=False)
|
||||
await Item.delete(session, item2, commit=False)
|
||||
await session.commit()
|
||||
"""
|
||||
# 条件删除模式
|
||||
if condition is not None:
|
||||
from sqlmodel import delete as sql_delete
|
||||
|
||||
if instances is not None:
|
||||
raise ValueError("不能同时指定 instances 和 condition")
|
||||
|
||||
# 执行条件删除
|
||||
stmt = sql_delete(cls).where(condition)
|
||||
result = await session.exec(stmt)
|
||||
deleted_count = result.rowcount
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
|
||||
return deleted_count
|
||||
|
||||
# 实例删除模式(原有逻辑)
|
||||
if instances is None:
|
||||
raise ValueError("必须指定 instances 或 condition")
|
||||
if instances is not None and condition is not None:
|
||||
raise ValueError("不能同时提供 instances 和 condition 参数")
|
||||
if instances is None and condition is None:
|
||||
raise ValueError("必须提供 instances 或 condition 参数之一")
|
||||
|
||||
deleted_count = 0
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
deleted_count += 1
|
||||
|
||||
if condition is not None:
|
||||
# 条件删除模式:直接执行 SQL DELETE
|
||||
stmt = sql_delete(cls).where(condition)
|
||||
result = await session.execute(stmt)
|
||||
deleted_count = result.rowcount
|
||||
else:
|
||||
await session.delete(instances)
|
||||
deleted_count = 1
|
||||
# 实例删除模式
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
deleted_count = len(instances)
|
||||
else:
|
||||
await session.delete(instances)
|
||||
deleted_count = 1
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
@@ -552,7 +630,8 @@ class TableBaseMixin(AsyncAttrs):
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
with_for_update: bool = False,
|
||||
table_view: TableViewRequest | None = None,
|
||||
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
populate_existing: bool = False,
|
||||
created_before_datetime: datetime | None = None,
|
||||
created_after_datetime: datetime | None = None,
|
||||
updated_before_datetime: datetime | None = None,
|
||||
@@ -581,8 +660,10 @@ class TableBaseMixin(AsyncAttrs):
|
||||
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
|
||||
例如 `[selectinload(User.posts)]`.
|
||||
load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.profile` 或 `[User.group, User.tags]`.
|
||||
可以是单个关系或关系列表。支持嵌套关系预加载:
|
||||
当传入多个关系时,会自动检测依赖关系并构建链式 selectinload。
|
||||
例如 `[NodeGroupNode.element_links, NodeGroupElementLink.node]`
|
||||
会自动构建 `selectinload(element_links).selectinload(node)`。
|
||||
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
|
||||
例如 `[User.name.asc(), User.created_at.desc()]`.
|
||||
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
|
||||
@@ -593,11 +674,16 @@ class TableBaseMixin(AsyncAttrs):
|
||||
会覆盖offset、limit、order_by及时间筛选参数。
|
||||
这是推荐的分页排序方式,统一了所有LIST端点的参数格式。
|
||||
|
||||
load_polymorphic: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
jti_subclasses: 多态子类加载选项,需要与 load 参数配合使用。
|
||||
- list[type[PolymorphicBaseMixin]]: 指定要加载的子类列表
|
||||
- 'all': 两阶段查询,只加载实际关联的子类(对于 > 10 个子类的场景有明显性能收益)
|
||||
- None(默认): 不使用多态加载
|
||||
|
||||
populate_existing (bool): 如果为 True,强制用数据库数据覆盖 session 中已存在的对象(identity map)。
|
||||
用于批量刷新对象,避免循环调用 session.refresh() 导致的 N 次查询。
|
||||
注意:只刷新标量字段,不影响运行时属性(_开头的属性)。
|
||||
对于 STI(单表继承)对象,推荐按子类分组查询以包含子类字段。默认为 False。
|
||||
|
||||
created_before_datetime (datetime | None): 筛选 created_at < datetime 的记录
|
||||
created_after_datetime (datetime | None): 筛选 created_at >= datetime 的记录
|
||||
updated_before_datetime (datetime | None): 筛选 updated_at < datetime 的记录
|
||||
@@ -607,7 +693,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
T | list[T] | None: 根据 `fetch_mode` 的设置,返回单个实例、实例列表或 `None`.
|
||||
|
||||
Raises:
|
||||
ValueError: 如果提供了无效的 `fetch_mode` 值,或 load_polymorphic 未与 load 配合使用.
|
||||
ValueError: 如果提供了无效的 `fetch_mode` 值,或 jti_subclasses 未与 load 配合使用.
|
||||
|
||||
Examples:
|
||||
# 使用table_view参数(推荐)
|
||||
@@ -621,13 +707,13 @@ class TableBaseMixin(AsyncAttrs):
|
||||
session,
|
||||
ToolSet.id == tool_set_id,
|
||||
load=ToolSet.tools,
|
||||
load_polymorphic='all' # 只加载实际关联的子类
|
||||
jti_subclasses='all' # 只加载实际关联的子类
|
||||
)
|
||||
"""
|
||||
# 参数验证:load_polymorphic 需要与 load 配合使用
|
||||
if load_polymorphic is not None and load is None:
|
||||
# 参数验证:jti_subclasses 需要与 load 配合使用
|
||||
if jti_subclasses is not None and load is None:
|
||||
raise ValueError(
|
||||
"load_polymorphic 参数需要与 load 参数配合使用,"
|
||||
"jti_subclasses 参数需要与 load 参数配合使用,"
|
||||
"请同时指定要加载的关系"
|
||||
)
|
||||
|
||||
@@ -656,13 +742,34 @@ class TableBaseMixin(AsyncAttrs):
|
||||
|
||||
# 对于多态基类,使用 with_polymorphic 预加载所有子类的列
|
||||
# 这避免了在响应序列化时的延迟加载问题(MissingGreenlet 错误)
|
||||
if issubclass(cls, PolymorphicBaseMixin):
|
||||
polymorphic_cls = None # 保存多态实体,用于子类关系预加载
|
||||
is_polymorphic = issubclass(cls, PolymorphicBaseMixin)
|
||||
is_jti = is_polymorphic and cls._is_joined_table_inheritance()
|
||||
is_sti = is_polymorphic and not cls._is_joined_table_inheritance()
|
||||
|
||||
# JTI 模式:总是使用 with_polymorphic(避免 N+1 查询)
|
||||
# STI 模式:不使用 with_polymorphic(批量刷新时请按子类分组查询)
|
||||
if is_jti:
|
||||
# '*' 表示加载所有子类
|
||||
polymorphic_cls = with_polymorphic(cls, '*')
|
||||
statement = select(polymorphic_cls)
|
||||
else:
|
||||
statement = select(cls)
|
||||
|
||||
# 对于 STI(单表继承)子类,自动添加多态过滤条件
|
||||
# SQLAlchemy/SQLModel 在 STI 模式下不会自动添加 WHERE discriminator = 'identity' 过滤
|
||||
# 这是已知行为,参考:
|
||||
# - https://github.com/sqlalchemy/sqlalchemy/issues/5018 (bulk operations 不自动添加多态过滤)
|
||||
# - https://github.com/fastapi/sqlmodel/issues/488 (SQLModel STI 支持不完整)
|
||||
# 社区最佳实践是显式添加多态过滤条件
|
||||
if issubclass(cls, PolymorphicBaseMixin) and not cls._is_joined_table_inheritance():
|
||||
mapper = inspect(cls)
|
||||
# 检查是否有 polymorphic_identity 且不是抽象类
|
||||
if mapper.polymorphic_identity is not None and not mapper.polymorphic_abstract:
|
||||
poly_on = mapper.polymorphic_on
|
||||
if poly_on is not None:
|
||||
statement = statement.where(poly_on == mapper.polymorphic_identity)
|
||||
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
@@ -688,12 +795,19 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# 标准化为列表
|
||||
load_list = load if isinstance(load, list) else [load]
|
||||
|
||||
# 处理多态加载
|
||||
if load_polymorphic is not None:
|
||||
# 多态加载只支持单个关系
|
||||
if len(load_list) > 1:
|
||||
raise ValueError("load_polymorphic 仅支持单个关系")
|
||||
target_class = load_list[0].property.mapper.class_
|
||||
# 构建链式 selectinload(支持嵌套关系预加载)
|
||||
# 例如:load=[NodeGroupNode.element_links, NodeGroupElementLink.node]
|
||||
# 会构建:selectinload(element_links).selectinload(node)
|
||||
load_chains = cls._build_load_chains(load_list)
|
||||
|
||||
# 处理多态加载(仅支持单链且只有一个关系)
|
||||
if jti_subclasses is not None:
|
||||
if len(load_chains) > 1 or len(load_chains[0]) > 1:
|
||||
raise ValueError(
|
||||
"jti_subclasses 仅支持单个关系(无嵌套链),请不要传入多个关系"
|
||||
)
|
||||
single_load = load_chains[0][0]
|
||||
target_class = single_load.property.mapper.class_
|
||||
|
||||
# 检查目标类是否继承自 PolymorphicBaseMixin
|
||||
if not issubclass(target_class, PolymorphicBaseMixin):
|
||||
@@ -702,26 +816,48 @@ class TableBaseMixin(AsyncAttrs):
|
||||
f"请确保其继承自 PolymorphicBaseMixin"
|
||||
)
|
||||
|
||||
if load_polymorphic == 'all':
|
||||
if jti_subclasses == 'all':
|
||||
# 两阶段查询:获取实际关联的多态类型
|
||||
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
|
||||
session, condition, load_list[0], target_class
|
||||
session, condition, single_load, target_class
|
||||
)
|
||||
else:
|
||||
subclasses_to_load = load_polymorphic
|
||||
subclasses_to_load = jti_subclasses
|
||||
|
||||
if subclasses_to_load:
|
||||
# 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项
|
||||
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
|
||||
statement = statement.options(
|
||||
selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load)
|
||||
selectinload(single_load).selectin_polymorphic(subclasses_to_load)
|
||||
)
|
||||
else:
|
||||
statement = statement.options(selectinload(load_list[0]))
|
||||
statement = statement.options(selectinload(single_load))
|
||||
else:
|
||||
# 为每个关系添加 selectinload
|
||||
for rel in load_list:
|
||||
statement = statement.options(selectinload(rel))
|
||||
# 为每条链构建链式 selectinload
|
||||
for chain in load_chains:
|
||||
# 获取第一个关系并检查是否需要通过多态实体访问
|
||||
first_rel = chain[0]
|
||||
first_rel_parent = first_rel.property.parent.class_
|
||||
|
||||
# 如果关系的 parent_class 是当前类的子类(不是 cls 本身),
|
||||
# 且当前是多态查询,则需要通过 polymorphic_cls.SubclassName 访问
|
||||
if (
|
||||
polymorphic_cls is not None
|
||||
and first_rel_parent is not cls
|
||||
and issubclass(first_rel_parent, cls)
|
||||
):
|
||||
# 通过多态实体访问子类的关系属性
|
||||
# 例如:polymorphic_cls.NodeGroupNode.element_links
|
||||
subclass_alias = getattr(polymorphic_cls, first_rel_parent.__name__)
|
||||
rel_name = first_rel.key
|
||||
first_rel_via_poly = getattr(subclass_alias, rel_name)
|
||||
loader = selectinload(first_rel_via_poly)
|
||||
else:
|
||||
loader = selectinload(first_rel)
|
||||
|
||||
for rel in chain[1:]:
|
||||
loader = loader.selectinload(rel)
|
||||
statement = statement.options(loader)
|
||||
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
@@ -736,7 +872,17 @@ class TableBaseMixin(AsyncAttrs):
|
||||
statement = statement.filter(filter)
|
||||
|
||||
if with_for_update:
|
||||
statement = statement.with_for_update()
|
||||
# 对于联表继承的多态模型,使用 FOR UPDATE OF <主表> 来避免 PostgreSQL 的限制
|
||||
# PostgreSQL 不支持在 LEFT OUTER JOIN 的可空侧使用 FOR UPDATE
|
||||
if issubclass(cls, PolymorphicBaseMixin):
|
||||
statement = statement.with_for_update(of=cls)
|
||||
else:
|
||||
statement = statement.with_for_update()
|
||||
|
||||
if populate_existing:
|
||||
# 强制用数据库数据覆盖 identity map 中的对象
|
||||
# 用于批量刷新,避免循环 refresh() 的 N 次查询
|
||||
statement = statement.execution_options(populate_existing=True)
|
||||
|
||||
result = await session.exec(statement)
|
||||
|
||||
@@ -749,6 +895,79 @@ class TableBaseMixin(AsyncAttrs):
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@staticmethod
|
||||
def _build_load_chains(load_list: list[RelationshipInfo]) -> list[list[RelationshipInfo]]:
|
||||
"""
|
||||
将关系列表构建为链式加载结构
|
||||
|
||||
自动检测关系之间的依赖关系,构建嵌套预加载链。
|
||||
例如:[NodeGroupNode.element_links, NodeGroupElementLink.node]
|
||||
会构建:[[element_links, node]](一条链)
|
||||
|
||||
算法:
|
||||
1. 获取每个关系的 parent class 和 target class
|
||||
2. 如果关系 B 的 parent class 等于关系 A 的 target class,则 B 链在 A 后面
|
||||
3. 独立的关系各自成为一条链
|
||||
|
||||
Args:
|
||||
load_list: 关系属性列表
|
||||
|
||||
Returns:
|
||||
链式关系列表,每条链是一个关系列表
|
||||
"""
|
||||
if not load_list:
|
||||
return []
|
||||
|
||||
# 构建关系信息:{关系: (parent_class, target_class)}
|
||||
rel_info: dict[RelationshipInfo, tuple[type, type]] = {}
|
||||
for rel in load_list:
|
||||
parent_class = rel.property.parent.class_
|
||||
target_class = rel.property.mapper.class_
|
||||
rel_info[rel] = (parent_class, target_class)
|
||||
|
||||
# 构建依赖图:{关系: 其前置关系}
|
||||
predecessors: dict[RelationshipInfo, RelationshipInfo | None] = {rel: None for rel in load_list}
|
||||
for rel_b in load_list:
|
||||
parent_b, _ = rel_info[rel_b]
|
||||
for rel_a in load_list:
|
||||
if rel_a is rel_b:
|
||||
continue
|
||||
_, target_a = rel_info[rel_a]
|
||||
# 如果 B 的 parent 精确等于 A 的 target,则 B 链在 A 后面
|
||||
# 使用精确匹配避免继承关系导致的误判(如 NodeGroupNode 是 CanvasNode 子类)
|
||||
if parent_b is target_a:
|
||||
predecessors[rel_b] = rel_a
|
||||
break
|
||||
|
||||
# 找出所有链的起点(没有前置关系的)
|
||||
roots = [rel for rel, pred in predecessors.items() if pred is None]
|
||||
|
||||
# 构建链
|
||||
chains: list[list[RelationshipInfo]] = []
|
||||
used: set[RelationshipInfo] = set()
|
||||
|
||||
for root in roots:
|
||||
chain = [root]
|
||||
used.add(root)
|
||||
# 找后续节点
|
||||
current = root
|
||||
while True:
|
||||
# 找以 current 的 target 为 parent 的关系
|
||||
_, current_target = rel_info[current]
|
||||
next_rel = None
|
||||
for rel, (parent, _) in rel_info.items():
|
||||
if rel not in used and parent is current_target:
|
||||
next_rel = rel
|
||||
break
|
||||
if next_rel is None:
|
||||
break
|
||||
chain.append(next_rel)
|
||||
used.add(next_rel)
|
||||
current = next_rel
|
||||
chains.append(chain)
|
||||
|
||||
return chains
|
||||
|
||||
@classmethod
|
||||
async def _resolve_polymorphic_subclasses(
|
||||
cls: type[T],
|
||||
@@ -791,12 +1010,15 @@ class TableBaseMixin(AsyncAttrs):
|
||||
))
|
||||
)
|
||||
else:
|
||||
# 一对多关系:通过外键查询
|
||||
foreign_key_col = relationship_property.local_remote_pairs[0][1]
|
||||
# 多对一/一对多关系:通过外键查询
|
||||
# local_remote_pairs[0] = (local_fk_col, remote_pk_col)
|
||||
# 对于多对一:local 是当前类的外键,remote 是目标类的主键
|
||||
local_fk_col = relationship_property.local_remote_pairs[0][0]
|
||||
remote_pk_col = relationship_property.local_remote_pairs[0][1]
|
||||
type_query = (
|
||||
select(distinct(poly_name_col))
|
||||
.where(foreign_key_col.in_(
|
||||
select(cls.id).where(condition) if condition is not None else select(cls.id)
|
||||
.where(remote_pk_col.in_(
|
||||
select(local_fk_col).where(condition) if condition is not None else select(local_fk_col)
|
||||
))
|
||||
)
|
||||
|
||||
@@ -898,7 +1120,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
table_view: TableViewRequest | None = None,
|
||||
load_polymorphic: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
jti_subclasses: list[type[PolymorphicBaseMixin]] | Literal['all'] | None = None,
|
||||
) -> 'ListResponse[T]':
|
||||
"""
|
||||
获取分页列表及总数,直接返回 ListResponse
|
||||
@@ -918,7 +1140,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
order_by: 排序子句
|
||||
filter: 附加过滤条件
|
||||
table_view: 分页排序参数(推荐使用)
|
||||
load_polymorphic: 多态子类加载选项
|
||||
jti_subclasses: 多态子类加载选项
|
||||
|
||||
Returns:
|
||||
ListResponse[T]: 包含 count 和 items 的分页响应
|
||||
@@ -957,7 +1179,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
order_by=order_by,
|
||||
filter=filter,
|
||||
table_view=table_view,
|
||||
load_polymorphic=load_polymorphic,
|
||||
jti_subclasses=jti_subclasses,
|
||||
)
|
||||
|
||||
return ListResponse(count=total_count, items=items)
|
||||
@@ -973,8 +1195,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (int): 要查找的记录的主键 ID.
|
||||
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
@@ -1002,7 +1223,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T:
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
|
||||
"""
|
||||
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||
|
||||
@@ -1012,8 +1233,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (uuid.UUID): 要查找的记录的 UUID 主键.
|
||||
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
@@ -119,4 +119,5 @@ class MCPResponseBase(MCPBase):
|
||||
"""MCP 响应模型基础类"""
|
||||
|
||||
result: str
|
||||
"""方法返回结果"""
|
||||
"""方法返回结果"""
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from enum import StrEnum
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index, text
|
||||
|
||||
from .base import SQLModelBase
|
||||
@@ -15,6 +16,7 @@ if TYPE_CHECKING:
|
||||
from .source_link import SourceLink
|
||||
from .share import Share
|
||||
from .physical_file import PhysicalFile
|
||||
from .uri import DiskNextURI
|
||||
|
||||
|
||||
class ObjectType(StrEnum):
|
||||
@@ -103,7 +105,7 @@ class ObjectMoveRequest(SQLModelBase):
|
||||
class ObjectDeleteRequest(SQLModelBase):
|
||||
"""删除对象请求 DTO"""
|
||||
|
||||
ids: UUID | list[UUID]
|
||||
ids: list[UUID]
|
||||
"""待删除对象UUID列表"""
|
||||
|
||||
|
||||
@@ -116,12 +118,12 @@ class ObjectResponse(ObjectBase):
|
||||
thumb: bool = False
|
||||
"""是否有缩略图"""
|
||||
|
||||
date: datetime
|
||||
"""对象修改时间"""
|
||||
|
||||
create_date: datetime
|
||||
created_at: datetime
|
||||
"""对象创建时间"""
|
||||
|
||||
updated_at: datetime
|
||||
"""对象修改时间"""
|
||||
|
||||
source_enabled: bool = False
|
||||
"""是否启用离线下载源"""
|
||||
|
||||
@@ -138,7 +140,7 @@ class PolicyResponse(SQLModelBase):
|
||||
type: StorageType
|
||||
"""存储类型"""
|
||||
|
||||
max_size: int = Field(ge=0, default=0)
|
||||
max_size: int = Field(ge=0, default=0, sa_type=BigInteger)
|
||||
"""单文件最大限制,单位字节,0表示不限制"""
|
||||
|
||||
file_type: list[str] | None = None
|
||||
@@ -186,18 +188,18 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
合并了原有的 File 和 Folder 模型,通过 type 字段区分文件和目录。
|
||||
|
||||
根目录规则:
|
||||
- 每个用户有一个显式根目录对象(name=用户的username, parent_id=NULL)
|
||||
- 每个用户有一个显式根目录对象(name="/", parent_id=NULL)
|
||||
- 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id
|
||||
- 根目录的 policy_id 指定用户默认存储策略
|
||||
- 路径格式:/username/path/to/file(如 /admin/docs/readme.md)
|
||||
- 路径格式:/path/to/file(如 /docs/readme.md),不包含用户名前缀
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
# 同一父目录下名称唯一(包括 parent_id 为 NULL 的情况)
|
||||
UniqueConstraint("owner_id", "parent_id", "name", name="uq_object_parent_name"),
|
||||
# 名称不能包含斜杠 ([TODO] 还有特殊字符)
|
||||
# 名称不能包含斜杠(根目录 parent_id IS NULL 除外,因为根目录 name="/")
|
||||
CheckConstraint(
|
||||
"name NOT LIKE '%/%' AND name NOT LIKE '%\\%'",
|
||||
"parent_id IS NULL OR (name NOT LIKE '%/%' AND name NOT LIKE '%\\%')",
|
||||
name="ck_object_name_no_slash",
|
||||
),
|
||||
# 性能索引
|
||||
@@ -220,7 +222,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
|
||||
# ==================== 文件专属字段 ====================
|
||||
|
||||
size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
size: int = Field(default=0, sa_type=BigInteger, sa_column_kwargs={"server_default": "0"})
|
||||
"""文件大小(字节),目录为 0"""
|
||||
|
||||
upload_session_id: str | None = Field(default=None, max_length=255, unique=True, index=True)
|
||||
@@ -374,15 +376,16 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
session,
|
||||
user_id: UUID,
|
||||
path: str,
|
||||
username: str,
|
||||
) -> "Object | None":
|
||||
"""
|
||||
根据路径获取对象
|
||||
|
||||
路径从用户根目录开始,不包含用户名前缀。
|
||||
如 "/" 表示根目录,"/docs/images" 表示根目录下的 docs/images。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:param path: 路径,如 "/username" 或 "/username/docs/images"
|
||||
:param username: 用户名,用于识别根目录
|
||||
:param path: 路径,如 "/" 或 "/docs/images"
|
||||
:return: Object 或 None
|
||||
"""
|
||||
path = path.strip()
|
||||
@@ -403,16 +406,7 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
if not parts:
|
||||
return root
|
||||
|
||||
# 检查第一部分是否是用户名(根目录名)
|
||||
if parts[0] == username:
|
||||
# 路径以用户名开头,如 /admin/docs
|
||||
if len(parts) == 1:
|
||||
# 只有用户名,返回根目录
|
||||
return root
|
||||
# 去掉用户名部分,从第二个部分开始遍历
|
||||
parts = parts[1:]
|
||||
|
||||
# 从根目录开始遍历剩余路径
|
||||
# 从根目录开始遍历路径
|
||||
current = root
|
||||
for part in parts:
|
||||
if not current:
|
||||
@@ -443,6 +437,77 @@ class Object(ObjectBase, UUIDTableBaseMixin):
|
||||
fetch_mode="all"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def resolve_uri(
|
||||
cls,
|
||||
session,
|
||||
uri: "DiskNextURI",
|
||||
requesting_user_id: UUID | None = None,
|
||||
) -> "Object":
|
||||
"""
|
||||
将 URI 解析为 Object 实例
|
||||
|
||||
分派逻辑(类似 Cloudreve 的 getNavigator):
|
||||
- MY → user_id = uri.id(str(requesting_user_id))
|
||||
验证权限(自己的或管理员),然后 get_by_path
|
||||
- SHARE → 通过 uri.fs_id 查 Share 表,验证密码和有效期
|
||||
获取 share.object,然后沿 uri.path 遍历子对象
|
||||
- TRASH → 延后实现
|
||||
|
||||
:param session: 数据库会话
|
||||
:param uri: DiskNextURI 实例
|
||||
:param requesting_user_id: 请求用户UUID
|
||||
:return: Object 实例
|
||||
:raises ValueError: URI 无法解析
|
||||
:raises PermissionError: 权限不足
|
||||
:raises NotImplementedError: 不支持的命名空间
|
||||
"""
|
||||
from .uri import FileSystemNamespace
|
||||
|
||||
if uri.namespace == FileSystemNamespace.MY:
|
||||
# 确定目标用户
|
||||
target_user_id_str = uri.id(str(requesting_user_id) if requesting_user_id else None)
|
||||
if not target_user_id_str:
|
||||
raise ValueError("MY 命名空间需要提供 fs_id 或 requesting_user_id")
|
||||
|
||||
target_user_id = UUID(target_user_id_str)
|
||||
|
||||
# 权限检查:只能访问自己的空间(管理员权限由路由层判断)
|
||||
if requesting_user_id and target_user_id != requesting_user_id:
|
||||
raise PermissionError("无权访问其他用户的文件空间")
|
||||
|
||||
obj = await cls.get_by_path(session, target_user_id, uri.path)
|
||||
if not obj:
|
||||
raise ValueError(f"路径不存在: {uri.path}")
|
||||
return obj
|
||||
|
||||
elif uri.namespace == FileSystemNamespace.SHARE:
|
||||
raise NotImplementedError("分享空间解析尚未实现")
|
||||
|
||||
elif uri.namespace == FileSystemNamespace.TRASH:
|
||||
raise NotImplementedError("回收站解析尚未实现")
|
||||
|
||||
else:
|
||||
raise ValueError(f"未知的命名空间: {uri.namespace}")
|
||||
|
||||
async def get_full_path(self, session) -> str:
|
||||
"""
|
||||
从当前对象沿 parent_id 向上遍历到根目录,返回完整路径
|
||||
|
||||
:param session: 数据库会话
|
||||
:return: 完整路径,如 "/docs/images/photo.jpg"
|
||||
"""
|
||||
parts: list[str] = []
|
||||
current: Object | None = self
|
||||
|
||||
while current and current.parent_id is not None:
|
||||
parts.append(current.name)
|
||||
current = await Object.get(session, Object.id == current.parent_id)
|
||||
|
||||
# 反转顺序(从根到当前)
|
||||
parts.reverse()
|
||||
return "/" + "/".join(parts)
|
||||
|
||||
|
||||
# ==================== 上传会话模型 ====================
|
||||
|
||||
@@ -452,10 +517,10 @@ class UploadSessionBase(SQLModelBase):
|
||||
file_name: str = Field(max_length=255)
|
||||
"""原始文件名"""
|
||||
|
||||
file_size: int = Field(ge=0)
|
||||
file_size: int = Field(ge=0, sa_type=BigInteger)
|
||||
"""文件总大小(字节)"""
|
||||
|
||||
chunk_size: int = Field(ge=1)
|
||||
chunk_size: int = Field(ge=1, sa_type=BigInteger)
|
||||
"""分片大小(字节)"""
|
||||
|
||||
total_chunks: int = Field(ge=1)
|
||||
@@ -474,7 +539,7 @@ class UploadSession(UploadSessionBase, UUIDTableBaseMixin):
|
||||
uploaded_chunks: int = 0
|
||||
"""已上传分片数"""
|
||||
|
||||
uploaded_size: int = 0
|
||||
uploaded_size: int = Field(default=0, sa_type=BigInteger)
|
||||
"""已上传大小(字节)"""
|
||||
|
||||
storage_path: str | None = Field(default=None, max_length=512)
|
||||
@@ -680,8 +745,8 @@ class AdminFileResponse(ObjectResponse):
|
||||
owner_id: UUID
|
||||
"""所有者UUID"""
|
||||
|
||||
owner_username: str
|
||||
"""所有者用户名"""
|
||||
owner_email: str
|
||||
"""所有者邮箱"""
|
||||
|
||||
policy_name: str
|
||||
"""存储策略名称"""
|
||||
@@ -709,12 +774,12 @@ class AdminFileResponse(ObjectResponse):
|
||||
# ObjectResponse 字段
|
||||
id=obj.id,
|
||||
thumb=False,
|
||||
date=obj.updated_at,
|
||||
create_date=obj.created_at,
|
||||
created_at=obj.created_at,
|
||||
updated_at=obj.updated_at,
|
||||
source_enabled=False,
|
||||
# AdminFileResponse 字段
|
||||
owner_id=obj.owner_id,
|
||||
owner_username=owner.username if owner else "unknown",
|
||||
owner_email=owner.email if owner else "unknown",
|
||||
policy_name=policy.name if policy else "unknown",
|
||||
is_banned=obj.is_banned,
|
||||
banned_at=obj.banned_at,
|
||||
@@ -725,7 +790,7 @@ class AdminFileResponse(ObjectResponse):
|
||||
class FileBanRequest(SQLModelBase):
|
||||
"""文件封禁请求 DTO"""
|
||||
|
||||
is_banned: bool = True
|
||||
ban: bool = True
|
||||
"""是否封禁"""
|
||||
|
||||
reason: str | None = Field(default=None, max_length=500)
|
||||
@@ -12,6 +12,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import BigInteger
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from .base import SQLModelBase
|
||||
@@ -28,7 +29,7 @@ class PhysicalFileBase(SQLModelBase):
|
||||
storage_path: str = Field(max_length=512)
|
||||
"""物理存储路径(相对于存储策略根目录)"""
|
||||
|
||||
size: int = 0
|
||||
size: int = Field(default=0, sa_type=BigInteger)
|
||||
"""文件大小(字节)"""
|
||||
|
||||
checksum_md5: str | None = Field(default=None, max_length=32)
|
||||
@@ -20,16 +20,10 @@ class SiteConfigResponse(SQLModelBase):
|
||||
title: str = "DiskNext"
|
||||
"""网站标题"""
|
||||
|
||||
# themes: dict[str, str] = {}
|
||||
# """网站主题配置"""
|
||||
|
||||
# default_theme: dict[str, str] = {}
|
||||
# """默认主题RGB色号"""
|
||||
|
||||
site_notice: str | None = None
|
||||
"""网站公告"""
|
||||
|
||||
user: UserResponse
|
||||
user: UserResponse | None = None
|
||||
"""用户信息"""
|
||||
|
||||
logo_light: str | None = None
|
||||
@@ -38,11 +32,23 @@ class SiteConfigResponse(SQLModelBase):
|
||||
logo_dark: str | None = None
|
||||
"""网站Logo URL(深色模式)"""
|
||||
|
||||
captcha_type: CaptchaType | None = None
|
||||
register_enabled: bool = True
|
||||
"""是否允许注册"""
|
||||
|
||||
login_captcha: bool = False
|
||||
"""登录是否需要验证码"""
|
||||
|
||||
reg_captcha: bool = False
|
||||
"""注册是否需要验证码"""
|
||||
|
||||
forget_captcha: bool = False
|
||||
"""找回密码是否需要验证码"""
|
||||
|
||||
captcha_type: CaptchaType = CaptchaType.DEFAULT
|
||||
"""验证码类型"""
|
||||
|
||||
captcha_key: str | None = None
|
||||
"""验证码密钥"""
|
||||
"""验证码 public key(DEFAULT 类型时为 None)"""
|
||||
|
||||
|
||||
# ==================== 管理员设置 DTO ====================
|
||||
@@ -215,6 +215,6 @@ class AdminShareListItem(ShareListItemBase):
|
||||
"""从 Share ORM 对象构建"""
|
||||
return cls(
|
||||
**ShareListItemBase.model_validate(share, from_attributes=True).model_dump(),
|
||||
username=user.username if user else None,
|
||||
username=user.email if user else None,
|
||||
object_name=obj.name if obj else None,
|
||||
)
|
||||
@@ -73,7 +73,7 @@ class TaskSummary(TaskSummaryBase):
|
||||
"""从 Task ORM 对象构建"""
|
||||
return cls(
|
||||
**TaskSummaryBase.model_validate(task, from_attributes=True).model_dump(),
|
||||
username=user.username if user else None,
|
||||
username=user.email if user else None,
|
||||
)
|
||||
|
||||
|
||||
258
sqlmodels/uri.py
Normal file
258
sqlmodels/uri.py
Normal file
@@ -0,0 +1,258 @@
|
||||
|
||||
from enum import StrEnum
|
||||
from urllib.parse import urlparse, parse_qs, urlencode, quote, unquote
|
||||
|
||||
from .base import SQLModelBase
|
||||
|
||||
|
||||
class FileSystemNamespace(StrEnum):
|
||||
"""文件系统命名空间"""
|
||||
MY = "my"
|
||||
"""用户个人空间"""
|
||||
|
||||
SHARE = "share"
|
||||
"""分享空间"""
|
||||
|
||||
TRASH = "trash"
|
||||
"""回收站"""
|
||||
|
||||
|
||||
class DiskNextURI(SQLModelBase):
|
||||
"""
|
||||
DiskNext 文件 URI
|
||||
|
||||
URI 格式: disknext://[fs_id[:password]@]namespace[/path][?query]
|
||||
|
||||
fs_id 可省略:
|
||||
- my/trash 命名空间省略时默认当前用户
|
||||
- share 命名空间必须提供 fs_id(Share.code)
|
||||
"""
|
||||
|
||||
fs_id: str | None = None
|
||||
"""文件系统标识符,可省略"""
|
||||
|
||||
namespace: FileSystemNamespace
|
||||
"""命名空间"""
|
||||
|
||||
path: str = "/"
|
||||
"""路径"""
|
||||
|
||||
password: str | None = None
|
||||
"""访问密码(用于有密码的分享)"""
|
||||
|
||||
query: dict[str, str] | None = None
|
||||
"""查询参数"""
|
||||
|
||||
# === 属性 ===
|
||||
|
||||
@property
|
||||
def path_parts(self) -> list[str]:
|
||||
"""路径分割为列表(过滤空串)"""
|
||||
return [p for p in self.path.split("/") if p]
|
||||
|
||||
@property
|
||||
def is_root(self) -> bool:
|
||||
"""是否指向根目录"""
|
||||
return self.path.strip("/") == ""
|
||||
|
||||
# === 核心方法 ===
|
||||
|
||||
def id(self, default_id: str | None = None) -> str | None:
|
||||
"""
|
||||
获取 fs_id,省略时返回 default_id
|
||||
|
||||
参考 Cloudreve URI.ID(defaultUid) 方法
|
||||
|
||||
:param default_id: 默认值(通常为当前用户 ID)
|
||||
:return: fs_id 或 default_id
|
||||
"""
|
||||
return self.fs_id if self.fs_id else default_id
|
||||
|
||||
# === 类方法 ===
|
||||
|
||||
@classmethod
|
||||
def parse(cls, uri: str) -> "DiskNextURI":
|
||||
"""
|
||||
解析 URI 字符串
|
||||
|
||||
实现方式:替换 disknext:// 为 http:// 后用 urllib.parse.urlparse 解析
|
||||
- hostname → namespace
|
||||
- username → fs_id
|
||||
- password → password
|
||||
- path → path
|
||||
- query → query dict
|
||||
|
||||
:param uri: URI 字符串,如 "disknext://my/docs/readme.md"
|
||||
:return: DiskNextURI 实例
|
||||
:raises ValueError: URI 格式无效
|
||||
"""
|
||||
if not uri.startswith("disknext://"):
|
||||
raise ValueError(f"URI 必须以 disknext:// 开头: {uri}")
|
||||
|
||||
# 替换协议为 http:// 以利用 urllib.parse 解析
|
||||
http_uri = "http://" + uri[len("disknext://"):]
|
||||
parsed = urlparse(http_uri)
|
||||
|
||||
# 解析 namespace
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
raise ValueError(f"URI 缺少命名空间: {uri}")
|
||||
|
||||
try:
|
||||
namespace = FileSystemNamespace(hostname)
|
||||
except ValueError:
|
||||
raise ValueError(f"无效的命名空间 '{hostname}',有效值: {[e.value for e in FileSystemNamespace]}")
|
||||
|
||||
# 解析 fs_id 和 password
|
||||
fs_id = unquote(parsed.username) if parsed.username else None
|
||||
password = unquote(parsed.password) if parsed.password else None
|
||||
|
||||
# 解析 path
|
||||
path = unquote(parsed.path) if parsed.path else "/"
|
||||
if not path:
|
||||
path = "/"
|
||||
|
||||
# 解析 query
|
||||
query: dict[str, str] | None = None
|
||||
if parsed.query:
|
||||
raw_query = parse_qs(parsed.query, keep_blank_values=True)
|
||||
query = {k: v[0] for k, v in raw_query.items()}
|
||||
|
||||
return cls(
|
||||
fs_id=fs_id,
|
||||
namespace=namespace,
|
||||
path=path,
|
||||
password=password,
|
||||
query=query,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
namespace: FileSystemNamespace,
|
||||
path: str = "/",
|
||||
fs_id: str | None = None,
|
||||
password: str | None = None,
|
||||
) -> "DiskNextURI":
|
||||
"""
|
||||
构建 URI 实例
|
||||
|
||||
:param namespace: 命名空间
|
||||
:param path: 路径
|
||||
:param fs_id: 文件系统标识符
|
||||
:param password: 访问密码
|
||||
:return: DiskNextURI 实例
|
||||
"""
|
||||
# 确保 path 以 / 开头
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
|
||||
return cls(
|
||||
fs_id=fs_id,
|
||||
namespace=namespace,
|
||||
path=path,
|
||||
password=password,
|
||||
)
|
||||
|
||||
# === 实例方法 ===
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""
|
||||
序列化为 URI 字符串
|
||||
|
||||
:return: URI 字符串,如 "disknext://my/docs/readme.md"
|
||||
"""
|
||||
result = "disknext://"
|
||||
|
||||
# fs_id 和 password
|
||||
if self.fs_id:
|
||||
result += quote(self.fs_id, safe="")
|
||||
if self.password:
|
||||
result += ":" + quote(self.password, safe="")
|
||||
result += "@"
|
||||
|
||||
# namespace
|
||||
result += self.namespace.value
|
||||
|
||||
# path
|
||||
result += self.path
|
||||
|
||||
# query
|
||||
if self.query:
|
||||
result += "?" + urlencode(self.query)
|
||||
|
||||
return result
|
||||
|
||||
def join(self, *elements: str) -> "DiskNextURI":
|
||||
"""
|
||||
拼接路径元素,返回新 URI
|
||||
|
||||
:param elements: 路径元素
|
||||
:return: 新的 DiskNextURI 实例
|
||||
"""
|
||||
base = self.path.rstrip("/")
|
||||
for element in elements:
|
||||
element = element.strip("/")
|
||||
if element:
|
||||
base += "/" + element
|
||||
|
||||
if not base:
|
||||
base = "/"
|
||||
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path=base,
|
||||
password=self.password,
|
||||
query=self.query,
|
||||
)
|
||||
|
||||
def dir_uri(self) -> "DiskNextURI":
|
||||
"""
|
||||
返回父目录的 URI
|
||||
|
||||
:return: 父目录的 DiskNextURI 实例
|
||||
"""
|
||||
parts = self.path_parts
|
||||
if not parts:
|
||||
# 已经是根目录
|
||||
return self.root()
|
||||
|
||||
parent_path = "/" + "/".join(parts[:-1])
|
||||
if not parent_path.endswith("/"):
|
||||
parent_path += "/"
|
||||
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path=parent_path,
|
||||
password=self.password,
|
||||
)
|
||||
|
||||
def root(self) -> "DiskNextURI":
|
||||
"""
|
||||
返回根目录的 URI(保留 namespace 和 fs_id)
|
||||
|
||||
:return: 根目录的 DiskNextURI 实例
|
||||
"""
|
||||
return DiskNextURI(
|
||||
fs_id=self.fs_id,
|
||||
namespace=self.namespace,
|
||||
path="/",
|
||||
password=self.password,
|
||||
)
|
||||
|
||||
def name(self) -> str:
|
||||
"""
|
||||
返回路径的最后一段(文件名或目录名)
|
||||
|
||||
:return: 文件名或目录名,根目录返回空字符串
|
||||
"""
|
||||
parts = self.path_parts
|
||||
return parts[-1] if parts else ""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.to_string()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DiskNextURI({self.to_string()!r})"
|
||||
@@ -60,8 +60,8 @@ class UserFilterParams(SQLModelBase):
|
||||
group_id: UUID | None = None
|
||||
"""按用户组UUID筛选"""
|
||||
|
||||
username_contains: str | None = Field(default=None, max_length=50)
|
||||
"""用户名包含(不区分大小写的模糊搜索)"""
|
||||
email_contains: str | None = Field(default=None, max_length=50)
|
||||
"""邮箱包含(不区分大小写的模糊搜索)"""
|
||||
|
||||
nickname_contains: str | None = Field(default=None, max_length=50)
|
||||
"""昵称包含(不区分大小写的模糊搜索)"""
|
||||
@@ -75,8 +75,8 @@ class UserFilterParams(SQLModelBase):
|
||||
class UserBase(SQLModelBase):
|
||||
"""用户基础字段,供数据库模型和 DTO 共享"""
|
||||
|
||||
username: str
|
||||
"""用户名"""
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
@@ -90,8 +90,8 @@ class UserBase(SQLModelBase):
|
||||
class LoginRequest(SQLModelBase):
|
||||
"""登录请求 DTO"""
|
||||
|
||||
username: str
|
||||
"""用户名或邮箱"""
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
@@ -99,15 +99,15 @@ class LoginRequest(SQLModelBase):
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
|
||||
two_fa_code: int | None = Field(min_length=6, max_length=6)
|
||||
two_fa_code: int | None = Field(default=None, min_length=6, max_length=6)
|
||||
"""两步验证代码"""
|
||||
|
||||
|
||||
class RegisterRequest(SQLModelBase):
|
||||
"""注册请求 DTO"""
|
||||
|
||||
username: str
|
||||
"""用户名,唯一,一经注册不可更改"""
|
||||
email: str
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
@@ -116,6 +116,20 @@ class RegisterRequest(SQLModelBase):
|
||||
"""验证码"""
|
||||
|
||||
|
||||
class BatchDeleteRequest(SQLModelBase):
|
||||
"""批量删除请求 DTO"""
|
||||
|
||||
ids: list[UUID]
|
||||
"""待删除 UUID 列表"""
|
||||
|
||||
|
||||
class RefreshTokenRequest(SQLModelBase):
|
||||
"""刷新令牌请求 DTO"""
|
||||
|
||||
refresh_token: str
|
||||
"""刷新令牌"""
|
||||
|
||||
|
||||
class WebAuthnInfo(SQLModelBase):
|
||||
"""WebAuthn 信息 DTO"""
|
||||
|
||||
@@ -137,6 +151,22 @@ class WebAuthnInfo(SQLModelBase):
|
||||
transports: list[str]
|
||||
"""支持的传输方式"""
|
||||
|
||||
class JWTPayload(SQLModelBase):
|
||||
"""JWT 访问令牌解析后的 claims"""
|
||||
|
||||
sub: UUID
|
||||
"""用户 ID"""
|
||||
|
||||
jti: UUID
|
||||
"""令牌唯一标识符"""
|
||||
|
||||
status: UserStatus
|
||||
"""用户状态"""
|
||||
|
||||
group: "GroupClaims"
|
||||
"""用户组权限快照"""
|
||||
|
||||
|
||||
class AccessTokenBase(BaseModel):
|
||||
"""访问令牌响应 DTO"""
|
||||
|
||||
@@ -166,6 +196,9 @@ class UserResponse(ResponseBase):
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""用户昵称"""
|
||||
|
||||
@@ -184,11 +217,23 @@ class UserResponse(ResponseBase):
|
||||
tags: list[str] = []
|
||||
"""用户标签列表"""
|
||||
|
||||
class UserStorageResponse(SQLModelBase):
|
||||
"""用户存储信息 DTO"""
|
||||
|
||||
used: int
|
||||
"""已用存储空间(字节)"""
|
||||
|
||||
free: int
|
||||
"""剩余存储空间(字节)"""
|
||||
|
||||
total: int
|
||||
"""总存储空间(字节)"""
|
||||
|
||||
|
||||
class UserPublic(UserBase):
|
||||
"""用户公开信息 DTO,用于 API 响应"""
|
||||
|
||||
id: UUID | None = None
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
nickname: str | None = None
|
||||
@@ -206,6 +251,9 @@ class UserPublic(UserBase):
|
||||
group_id: UUID | None = None
|
||||
"""所属用户组UUID"""
|
||||
|
||||
group_name: str | None = None
|
||||
"""用户组名称"""
|
||||
|
||||
two_factor: str | None = None
|
||||
"""两步验证密钥(32位字符串,null 表示未启用)"""
|
||||
|
||||
@@ -219,29 +267,63 @@ class UserPublic(UserBase):
|
||||
class UserSettingResponse(SQLModelBase):
|
||||
"""用户设置响应 DTO"""
|
||||
|
||||
authn: "AuthnResponse | None" = None
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
email: str
|
||||
"""用户邮箱"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""昵称"""
|
||||
|
||||
created_at: datetime
|
||||
"""用户注册时间"""
|
||||
|
||||
group_name: str
|
||||
"""用户所属用户组名称"""
|
||||
|
||||
language: str
|
||||
"""语言偏好"""
|
||||
|
||||
timezone: int
|
||||
"""时区"""
|
||||
|
||||
authn: "list[AuthnResponse] | None" = None
|
||||
"""认证信息"""
|
||||
|
||||
group_expires: datetime | None = None
|
||||
"""用户组过期时间"""
|
||||
|
||||
prefer_theme: str = "#5898d4"
|
||||
"""用户首选主题"""
|
||||
|
||||
themes: dict[str, str] = {}
|
||||
"""用户主题配置"""
|
||||
|
||||
two_factor: bool = False
|
||||
"""是否启用两步验证"""
|
||||
|
||||
uid: UUID | None = None
|
||||
"""用户UUID"""
|
||||
|
||||
|
||||
# ==================== 管理员用户管理 DTO ====================
|
||||
|
||||
class UserAdminCreateRequest(SQLModelBase):
|
||||
"""管理员创建用户请求 DTO"""
|
||||
|
||||
email: str = Field(max_length=50)
|
||||
"""用户邮箱"""
|
||||
|
||||
password: str
|
||||
"""用户密码(明文,由服务端加密)"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
|
||||
group_id: UUID
|
||||
"""所属用户组UUID"""
|
||||
|
||||
status: UserStatus = UserStatus.ACTIVE
|
||||
"""用户状态"""
|
||||
|
||||
|
||||
class UserAdminUpdateRequest(SQLModelBase):
|
||||
"""管理员更新用户请求 DTO"""
|
||||
|
||||
email: str = Field(max_length=50)
|
||||
"""邮箱"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""昵称"""
|
||||
@@ -304,10 +386,11 @@ class UserAdminDetailResponse(UserPublic):
|
||||
|
||||
|
||||
# 前向引用导入
|
||||
from .group import GroupResponse # noqa: E402
|
||||
from .group import GroupClaims, GroupResponse # noqa: E402
|
||||
from .user_authn import AuthnResponse # noqa: E402
|
||||
|
||||
# 更新前向引用
|
||||
JWTPayload.model_rebuild()
|
||||
UserResponse.model_rebuild()
|
||||
UserSettingResponse.model_rebuild()
|
||||
|
||||
@@ -317,8 +400,8 @@ UserSettingResponse.model_rebuild()
|
||||
class User(UserBase, UUIDTableBaseMixin):
|
||||
"""用户模型"""
|
||||
|
||||
username: str = Field(max_length=50, unique=True, index=True)
|
||||
"""用户名,唯一,一经注册不可更改"""
|
||||
email: str = Field(max_length=50, unique=True, index=True)
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
nickname: str | None = Field(default=None, max_length=50)
|
||||
"""用于公开展示的名字,可使用真实姓名或昵称"""
|
||||
@@ -426,8 +509,10 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
)
|
||||
|
||||
def to_public(self) -> "UserPublic":
|
||||
"""转换为公开 DTO,排除敏感字段"""
|
||||
return UserPublic.model_validate(self)
|
||||
"""转换为公开 DTO,排除敏感字段。需要预加载 group 关系。"""
|
||||
data = UserPublic.model_validate(self)
|
||||
data.group_name = self.group.name
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def get_with_count(
|
||||
@@ -457,8 +542,8 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
if filter_params.group_id is not None:
|
||||
filter_conditions.append(cls.group_id == filter_params.group_id)
|
||||
|
||||
if filter_params.username_contains is not None:
|
||||
filter_conditions.append(cls.username.ilike(f"%{filter_params.username_contains}%"))
|
||||
if filter_params.email_contains is not None:
|
||||
filter_conditions.append(cls.email.ilike(f"%{filter_params.email_contains}%"))
|
||||
|
||||
if filter_params.nickname_contains is not None:
|
||||
filter_conditions.append(cls.nickname.ilike(f"%{filter_params.nickname_contains}%"))
|
||||
@@ -482,4 +567,5 @@ class User(UserBase, UUIDTableBaseMixin):
|
||||
order_by=order_by,
|
||||
filter=filter,
|
||||
table_view=table_view,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -49,13 +49,13 @@ def main():
|
||||
("itsdangerous", "签名工具"),
|
||||
|
||||
# 项目模块
|
||||
("models", "数据库模型"),
|
||||
("models.user", "用户模型"),
|
||||
("models.group", "用户组模型"),
|
||||
("models.object", "对象模型"),
|
||||
("models.setting", "设置模型"),
|
||||
("models.policy", "策略模型"),
|
||||
("models.database", "数据库连接"),
|
||||
("sqlmodels", "数据库模型"),
|
||||
("sqlmodels.user", "用户模型"),
|
||||
("sqlmodels.group", "用户组模型"),
|
||||
("sqlmodels.object", "对象模型"),
|
||||
("sqlmodels.setting", "设置模型"),
|
||||
("sqlmodels.policy", "策略模型"),
|
||||
("sqlmodels.database", "数据库连接"),
|
||||
("utils.password.pwd", "密码工具"),
|
||||
("utils.JWT.JWT", "JWT 工具"),
|
||||
("service.user.login", "登录服务"),
|
||||
|
||||
@@ -23,13 +23,13 @@ from sqlalchemy.orm import sessionmaker
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from main import app
|
||||
from models.database import get_session
|
||||
from models.group import Group, GroupOptions
|
||||
from models.migration import migration
|
||||
from models.object import Object, ObjectType
|
||||
from models.policy import Policy, PolicyType
|
||||
from models.user import User
|
||||
from utils.JWT.JWT import create_access_token
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.group import Group, GroupClaims, GroupOptions
|
||||
from sqlmodels.migration import migration
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
from sqlmodels.user import User, UserStatus
|
||||
from utils.JWT import create_access_token
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ def override_get_session(db_session: AsyncSession):
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
"""
|
||||
创建测试用户并返回 {id, username, password, token}
|
||||
创建测试用户并返回 {id, email, password, token}
|
||||
|
||||
创建一个普通用户,包含用户组、存储策略和根目录。
|
||||
"""
|
||||
@@ -190,10 +190,10 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
# 创建测试用户
|
||||
password = "test_password_123"
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=100,
|
||||
group_id=group.id,
|
||||
@@ -202,7 +202,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
# 创建用户根目录
|
||||
root_folder = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -211,14 +211,24 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
await root_folder.save(db_session)
|
||||
|
||||
# 构建权限快照
|
||||
group.options = group_options
|
||||
group_claims = GroupClaims.from_group(group)
|
||||
|
||||
# 生成访问令牌
|
||||
access_token, _ = create_access_token({"sub": str(user.id)})
|
||||
from uuid import uuid4
|
||||
access_token_obj = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"token": access_token_obj.access_token,
|
||||
"group_id": group.id,
|
||||
"policy_id": policy.id,
|
||||
}
|
||||
@@ -227,7 +237,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
"""
|
||||
获取管理员用户 {id, username, token}
|
||||
获取管理员用户 {id, email, token}
|
||||
|
||||
创建具有管理员权限的用户。
|
||||
"""
|
||||
@@ -267,10 +277,10 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
# 创建管理员用户
|
||||
password = "admin_password_456"
|
||||
admin = User(
|
||||
username="admin",
|
||||
email="admin@disknext.local",
|
||||
nickname="管理员",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=9999,
|
||||
group_id=admin_group.id,
|
||||
@@ -279,7 +289,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
# 创建管理员根目录
|
||||
root_folder = Object(
|
||||
name=admin.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=admin.id,
|
||||
@@ -288,14 +298,24 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
)
|
||||
await root_folder.save(db_session)
|
||||
|
||||
# 构建权限快照
|
||||
admin_group.options = admin_group_options
|
||||
admin_group_claims = GroupClaims.from_group(admin_group)
|
||||
|
||||
# 生成访问令牌
|
||||
access_token, _ = create_access_token({"sub": str(admin.id)})
|
||||
from uuid import uuid4
|
||||
access_token_obj = create_access_token(
|
||||
sub=admin.id,
|
||||
jti=uuid4(),
|
||||
status=admin.status.value,
|
||||
group=admin_group_claims,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": admin.id,
|
||||
"username": admin.username,
|
||||
"email": admin.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"token": access_token_obj.access_token,
|
||||
"group_id": admin_group.id,
|
||||
"policy_id": policy.id,
|
||||
}
|
||||
|
||||
@@ -8,9 +8,9 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from models.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from tests.fixtures import UserFactory, GroupFactory, ObjectFactory
|
||||
|
||||
|
||||
@@ -24,13 +24,13 @@ async def test_user_factory(db_session: AsyncSession):
|
||||
user = await UserFactory.create(
|
||||
db_session,
|
||||
group_id=group.id,
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
password="password123"
|
||||
)
|
||||
|
||||
# 验证
|
||||
assert user.id is not None
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "testuser@test.local"
|
||||
assert user.group_id == group.id
|
||||
assert user.status is True
|
||||
|
||||
@@ -51,7 +51,7 @@ async def test_group_factory(db_session: AsyncSession):
|
||||
async def test_object_factory(db_session: AsyncSession):
|
||||
"""测试对象工厂的基本功能"""
|
||||
# 准备依赖
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = await GroupFactory.create(db_session)
|
||||
user = await UserFactory.create(db_session, group_id=group.id)
|
||||
@@ -102,7 +102,7 @@ async def test_conftest_fixtures(
|
||||
"""测试 conftest.py 中的 fixtures"""
|
||||
# 验证 test_user fixture
|
||||
assert test_user["id"] is not None
|
||||
assert test_user["username"] == "testuser"
|
||||
assert test_user["email"] == "testuser@test.local"
|
||||
assert test_user["token"] is not None
|
||||
|
||||
# 验证 auth_headers fixture
|
||||
@@ -112,7 +112,7 @@ async def test_conftest_fixtures(
|
||||
# 验证用户在数据库中存在
|
||||
user = await User.get(db_session, User.id == test_user["id"])
|
||||
assert user is not None
|
||||
assert user.username == test_user["username"]
|
||||
assert user.email == test_user["email"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@@ -145,7 +145,7 @@ async def test_test_directory_fixture(
|
||||
@pytest.mark.integration
|
||||
async def test_nested_structure_factory(db_session: AsyncSession):
|
||||
"""测试嵌套结构工厂"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
# 准备依赖
|
||||
group = await GroupFactory.create(db_session)
|
||||
|
||||
2
tests/fixtures/groups.py
vendored
2
tests/fixtures/groups.py
vendored
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.group import Group, GroupOptions
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
|
||||
|
||||
class GroupFactory:
|
||||
|
||||
6
tests/fixtures/objects.py
vendored
6
tests/fixtures/objects.py
vendored
@@ -7,8 +7,8 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.object import Object, ObjectType
|
||||
from models.user import User
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
|
||||
|
||||
class ObjectFactory:
|
||||
@@ -119,7 +119,7 @@ class ObjectFactory:
|
||||
Object: 创建的根目录实例
|
||||
"""
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
|
||||
50
tests/fixtures/users.py
vendored
50
tests/fixtures/users.py
vendored
@@ -7,7 +7,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from sqlmodels.user import User
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class UserFactory:
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
username: str | None = None,
|
||||
email: str | None = None,
|
||||
password: str | None = None,
|
||||
**kwargs
|
||||
) -> User:
|
||||
@@ -28,7 +28,7 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
username: 用户名(默认: test_user_{随机})
|
||||
email: 用户邮箱(默认: test_user_{随机}@test.local)
|
||||
password: 明文密码(默认: password123)
|
||||
**kwargs: 其他用户字段
|
||||
|
||||
@@ -37,15 +37,15 @@ class UserFactory:
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"test_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"test_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
if password is None:
|
||||
password = "password123"
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
nickname=kwargs.get("nickname", username),
|
||||
email=email,
|
||||
nickname=kwargs.get("nickname", email),
|
||||
password=Password.hash(password),
|
||||
status=kwargs.get("status", True),
|
||||
storage=kwargs.get("storage", 0),
|
||||
@@ -67,7 +67,7 @@ class UserFactory:
|
||||
async def create_admin(
|
||||
session: AsyncSession,
|
||||
admin_group_id: UUID,
|
||||
username: str | None = None,
|
||||
email: str | None = None,
|
||||
password: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
@@ -76,7 +76,7 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
admin_group_id: 管理员组UUID
|
||||
username: 用户名(默认: admin_{随机})
|
||||
email: 用户邮箱(默认: admin_{随机}@disknext.local)
|
||||
password: 明文密码(默认: admin_password)
|
||||
|
||||
返回:
|
||||
@@ -84,15 +84,15 @@ class UserFactory:
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"admin_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"admin_{uuid.uuid4().hex[:8]}@disknext.local"
|
||||
|
||||
if password is None:
|
||||
password = "admin_password"
|
||||
|
||||
admin = User(
|
||||
username=username,
|
||||
nickname=f"管理员 {username}",
|
||||
email=email,
|
||||
nickname=f"管理员 {email}",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
storage=0,
|
||||
@@ -108,7 +108,7 @@ class UserFactory:
|
||||
async def create_banned(
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
username: str | None = None
|
||||
email: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
创建被封禁用户
|
||||
@@ -116,19 +116,19 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
username: 用户名(默认: banned_user_{随机})
|
||||
email: 用户邮箱(默认: banned_user_{随机}@test.local)
|
||||
|
||||
返回:
|
||||
User: 创建的被封禁用户实例
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"banned_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"banned_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
banned_user = User(
|
||||
username=username,
|
||||
nickname=f"封禁用户 {username}",
|
||||
email=email,
|
||||
nickname=f"封禁用户 {email}",
|
||||
password=Password.hash("banned_password"),
|
||||
status=False, # 封禁状态
|
||||
storage=0,
|
||||
@@ -145,7 +145,7 @@ class UserFactory:
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
storage_bytes: int,
|
||||
username: str | None = None
|
||||
email: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
创建已使用指定存储空间的用户
|
||||
@@ -154,19 +154,19 @@ class UserFactory:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
storage_bytes: 已使用的存储空间(字节)
|
||||
username: 用户名(默认: storage_user_{随机})
|
||||
email: 用户邮箱(默认: storage_user_{随机}@test.local)
|
||||
|
||||
返回:
|
||||
User: 创建的用户实例
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"storage_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"storage_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
email=email,
|
||||
nickname=email,
|
||||
password=Password.hash("password123"),
|
||||
status=True,
|
||||
storage=storage_bytes,
|
||||
|
||||
@@ -124,7 +124,7 @@ async def test_admin_get_user_list_contains_user_data(
|
||||
if len(users) > 0:
|
||||
user = users[0]
|
||||
assert "id" in user
|
||||
assert "username" in user
|
||||
assert "email" in user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -132,7 +132,7 @@ async def test_admin_create_user_requires_auth(async_client: AsyncClient):
|
||||
"""测试创建用户需要认证"""
|
||||
response = await async_client.post(
|
||||
"/api/admin/user/create",
|
||||
json={"username": "newadminuser", "password": "pass123"}
|
||||
json={"email": "newadminuser@test.local", "password": "pass123"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -146,7 +146,7 @@ async def test_admin_create_user_requires_admin(
|
||||
response = await async_client.post(
|
||||
"/api/admin/user/create",
|
||||
headers=auth_headers,
|
||||
json={"username": "newadminuser", "password": "pass123"}
|
||||
json={"email": "newadminuser@test.local", "password": "pass123"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from uuid import UUID
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_requires_auth(async_client: AsyncClient):
|
||||
"""测试获取目录需要认证"""
|
||||
response = await async_client.get("/api/directory/testuser")
|
||||
response = await async_client.get("/api/directory/")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ async def test_directory_get_root(
|
||||
):
|
||||
"""测试获取用户根目录"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -45,7 +45,7 @@ async def test_directory_get_nested(
|
||||
):
|
||||
"""测试获取嵌套目录"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/docs",
|
||||
"/api/directory/docs",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -63,7 +63,7 @@ async def test_directory_get_contains_children(
|
||||
):
|
||||
"""测试目录包含子对象"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/docs",
|
||||
"/api/directory/docs",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -75,19 +75,6 @@ async def test_directory_get_contains_children(
|
||||
assert len(objects) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_forbidden_other_user(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str]
|
||||
):
|
||||
"""测试访问他人目录返回 403"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/admin",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_not_found(
|
||||
async_client: AsyncClient,
|
||||
@@ -95,23 +82,23 @@ async def test_directory_not_found(
|
||||
):
|
||||
"""测试目录不存在返回 404"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/nonexistent",
|
||||
"/api/directory/nonexistent",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_empty_path_returns_400(
|
||||
async def test_directory_root_returns_200(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str]
|
||||
):
|
||||
"""测试空路径返回 400"""
|
||||
"""测试根目录端点返回 200"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -121,7 +108,7 @@ async def test_directory_response_includes_policy(
|
||||
):
|
||||
"""测试目录响应包含存储策略"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -284,7 +271,7 @@ async def test_directory_create_other_user_parent(
|
||||
"""测试在他人目录下创建目录返回 404"""
|
||||
# 先用管理员账号获取管理员的根目录ID
|
||||
admin_response = await async_client.get(
|
||||
"/api/directory/admin",
|
||||
"/api/directory/",
|
||||
headers=admin_headers
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
|
||||
@@ -16,7 +16,7 @@ async def test_user_login_success(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["username"],
|
||||
"username": test_user_info["email"],
|
||||
"password": test_user_info["password"],
|
||||
}
|
||||
)
|
||||
@@ -38,7 +38,7 @@ async def test_user_login_wrong_password(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["username"],
|
||||
"username": test_user_info["email"],
|
||||
"password": "wrongpassword",
|
||||
}
|
||||
)
|
||||
@@ -51,7 +51,7 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient):
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": "nonexistent",
|
||||
"username": "nonexistent@test.local",
|
||||
"password": "anypassword",
|
||||
}
|
||||
)
|
||||
@@ -67,7 +67,7 @@ async def test_user_login_user_banned(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": banned_user_info["username"],
|
||||
"username": banned_user_info["email"],
|
||||
"password": banned_user_info["password"],
|
||||
}
|
||||
)
|
||||
@@ -82,7 +82,7 @@ async def test_user_register_success(async_client: AsyncClient):
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"username": "newuser",
|
||||
"email": "newuser@test.local",
|
||||
"password": "newpass123",
|
||||
}
|
||||
)
|
||||
@@ -91,20 +91,20 @@ async def test_user_register_success(async_client: AsyncClient):
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "user_id" in data["data"]
|
||||
assert "username" in data["data"]
|
||||
assert data["data"]["username"] == "newuser"
|
||||
assert "email" in data["data"]
|
||||
assert data["data"]["email"] == "newuser@test.local"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_register_duplicate_username(
|
||||
async def test_user_register_duplicate_email(
|
||||
async_client: AsyncClient,
|
||||
test_user_info: dict[str, str]
|
||||
):
|
||||
"""测试重复用户名返回 400"""
|
||||
"""测试重复邮箱返回 400"""
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"username": test_user_info["username"],
|
||||
"email": test_user_info["email"],
|
||||
"password": "anypassword",
|
||||
}
|
||||
)
|
||||
@@ -143,8 +143,8 @@ async def test_user_me_returns_user_info(
|
||||
assert "data" in data
|
||||
user_data = data["data"]
|
||||
assert "id" in user_data
|
||||
assert "username" in user_data
|
||||
assert user_data["username"] == "testuser"
|
||||
assert "email" in user_data
|
||||
assert user_data["email"] == "testuser@test.local"
|
||||
assert "group" in user_data
|
||||
assert "tags" in user_data
|
||||
|
||||
|
||||
@@ -22,10 +22,11 @@ from sqlalchemy.orm import sessionmaker
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
|
||||
from main import app
|
||||
from models import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels import Group, GroupClaims, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels.user import UserStatus
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
from utils.JWT import JWT
|
||||
import utils.JWT as JWT
|
||||
|
||||
|
||||
# ==================== 事件循环配置 ====================
|
||||
@@ -92,6 +93,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
Setting(type=SettingsType.VIEW, name="home_view_method", value="list"),
|
||||
Setting(type=SettingsType.VIEW, name="share_view_method", value="grid"),
|
||||
Setting(type=SettingsType.AUTHN, name="authn_enabled", value="0"),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_type", value="default"),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_ReCaptchaKey", value=""),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""),
|
||||
Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"),
|
||||
@@ -180,43 +182,40 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 6. 创建测试用户
|
||||
test_user = User(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
password=Password.hash("testpass123"),
|
||||
nickname="测试用户",
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=default_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(test_user)
|
||||
|
||||
admin_user = User(
|
||||
id=uuid4(),
|
||||
username="admin",
|
||||
email="admin@disknext.local",
|
||||
password=Password.hash("adminpass123"),
|
||||
nickname="管理员",
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=admin_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(admin_user)
|
||||
|
||||
banned_user = User(
|
||||
id=uuid4(),
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("banned123"),
|
||||
nickname="封禁用户",
|
||||
status=False, # 封禁状态
|
||||
status=UserStatus.ADMIN_BANNED,
|
||||
storage=0,
|
||||
score=0,
|
||||
group_id=default_group.id,
|
||||
avatar="default",
|
||||
theme="system",
|
||||
)
|
||||
test_session.add(banned_user)
|
||||
|
||||
@@ -230,7 +229,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 7. 创建用户根目录
|
||||
test_user_root = Object(
|
||||
id=uuid4(),
|
||||
name=test_user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=test_user.id,
|
||||
parent_id=None,
|
||||
@@ -241,7 +240,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
admin_user_root = Object(
|
||||
id=uuid4(),
|
||||
name=admin_user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=admin_user.id,
|
||||
parent_id=None,
|
||||
@@ -255,6 +254,10 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 8. 设置JWT密钥(从数据库加载)
|
||||
JWT.SECRET_KEY = "test_secret_key_for_jwt_token_generation"
|
||||
|
||||
# 刷新 group options
|
||||
await test_session.refresh(default_group_options)
|
||||
await test_session.refresh(admin_group_options)
|
||||
|
||||
return test_session
|
||||
|
||||
|
||||
@@ -264,7 +267,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
def test_user_info() -> dict[str, str]:
|
||||
"""测试用户信息"""
|
||||
return {
|
||||
"username": "testuser",
|
||||
"email": "testuser@test.local",
|
||||
"password": "testpass123",
|
||||
}
|
||||
|
||||
@@ -273,7 +276,7 @@ def test_user_info() -> dict[str, str]:
|
||||
def admin_user_info() -> dict[str, str]:
|
||||
"""管理员用户信息"""
|
||||
return {
|
||||
"username": "admin",
|
||||
"email": "admin@disknext.local",
|
||||
"password": "adminpass123",
|
||||
}
|
||||
|
||||
@@ -282,41 +285,75 @@ def admin_user_info() -> dict[str, str]:
|
||||
def banned_user_info() -> dict[str, str]:
|
||||
"""封禁用户信息"""
|
||||
return {
|
||||
"username": "banneduser",
|
||||
"email": "banneduser@test.local",
|
||||
"password": "banned123",
|
||||
}
|
||||
|
||||
|
||||
# ==================== JWT Token ====================
|
||||
|
||||
@pytest.fixture
|
||||
def test_user_token(test_user_info: dict[str, str]) -> str:
|
||||
def _build_group_claims(group: Group, group_options: GroupOptions | None) -> GroupClaims:
|
||||
"""从 Group 对象构建 GroupClaims"""
|
||||
group.options = group_options
|
||||
return GroupClaims.from_group(group)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def test_user_token(initialized_db: AsyncSession) -> str:
|
||||
"""生成测试用户的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
group = await Group.get(initialized_db, Group.id == user.group_id)
|
||||
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
|
||||
group_claims = _build_group_claims(group, group_options)
|
||||
|
||||
result = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
return result.access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user_token(admin_user_info: dict[str, str]) -> str:
|
||||
@pytest_asyncio.fixture
|
||||
async def admin_user_token(initialized_db: AsyncSession) -> str:
|
||||
"""生成管理员的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": admin_user_info["username"]},
|
||||
user = await User.get(initialized_db, User.email == "admin@disknext.local")
|
||||
group = await Group.get(initialized_db, Group.id == user.group_id)
|
||||
group_options = await GroupOptions.get(initialized_db, GroupOptions.group_id == group.id)
|
||||
group_claims = _build_group_claims(group, group_options)
|
||||
|
||||
result = create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
status=user.status.value,
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
return result.access_token
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expired_token() -> str:
|
||||
"""生成过期的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "testuser"},
|
||||
expires_delta=timedelta(seconds=-1), # 已过期
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
return token
|
||||
result = create_access_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
return result.access_token
|
||||
|
||||
|
||||
# ==================== 认证头 ====================
|
||||
@@ -362,7 +399,7 @@ async def test_directory_structure(initialized_db: AsyncSession) -> dict[str, UU
|
||||
"""创建测试目录结构"""
|
||||
|
||||
# 获取测试用户和根目录
|
||||
test_user = await User.get(initialized_db, User.username == "testuser")
|
||||
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
test_user_root = await Object.get_root(initialized_db, test_user.id)
|
||||
|
||||
default_policy = await Policy.get(initialized_db, Policy.name == "本地存储")
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
"""
|
||||
认证中间件集成测试
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from datetime import timedelta
|
||||
|
||||
from utils.JWT import JWT
|
||||
from sqlmodels.group import GroupClaims
|
||||
from utils.JWT import create_access_token, create_refresh_token
|
||||
import utils.JWT as JWT
|
||||
|
||||
|
||||
# ==================== AuthRequired 测试 ====================
|
||||
@@ -66,11 +70,14 @@ async def test_auth_required_valid_token(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_token_without_sub(async_client: AsyncClient):
|
||||
"""测试缺少sub字段的token返回 401"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"other_field": "value"},
|
||||
expires_delta=timedelta(hours=1)
|
||||
)
|
||||
"""测试缺少必要字段的token返回 401"""
|
||||
import jwt as pyjwt
|
||||
# 手动构建一个缺少 status 和 group 的 token
|
||||
payload = {
|
||||
"other_field": "value",
|
||||
"exp": int((__import__('datetime').datetime.now(__import__('datetime').timezone.utc) + timedelta(hours=1)).timestamp()),
|
||||
}
|
||||
token = pyjwt.encode(payload, JWT.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/user/me",
|
||||
@@ -81,16 +88,29 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
|
||||
"""测试用户不存在的token返回 401"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "nonexistent_user"},
|
||||
expires_delta=timedelta(hours=1)
|
||||
"""测试用户不存在的token返回 403 或 401(取决于 Redis 可用性)"""
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
result = create_access_token(
|
||||
sub=uuid4(), # 不存在的用户 UUID
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
|
||||
response = await async_client.get(
|
||||
"/api/user/me",
|
||||
headers={"Authorization": f"Bearer {token}"}
|
||||
headers={"Authorization": f"Bearer {result.access_token}"}
|
||||
)
|
||||
# auth_required 会查库,用户不存在时返回 401
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@@ -178,12 +198,12 @@ async def test_auth_on_directory_endpoint(
|
||||
):
|
||||
"""测试目录端点应用认证"""
|
||||
# 无认证
|
||||
response_no_auth = await async_client.get("/api/directory/testuser")
|
||||
response_no_auth = await async_client.get("/api/directory/")
|
||||
assert response_no_auth.status_code == 401
|
||||
|
||||
# 有认证
|
||||
response_with_auth = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response_with_auth.status_code == 200
|
||||
@@ -234,23 +254,36 @@ async def test_auth_on_storage_endpoint(
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_token_format(test_user_info: dict[str, str]):
|
||||
"""测试刷新token格式正确"""
|
||||
refresh_token, _ = JWT.create_refresh_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
expires_delta=timedelta(days=7)
|
||||
result = create_refresh_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
expires_delta=timedelta(days=7),
|
||||
)
|
||||
|
||||
assert isinstance(refresh_token, str)
|
||||
assert len(refresh_token) > 0
|
||||
assert isinstance(result.refresh_token, str)
|
||||
assert len(result.refresh_token) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_access_token_format(test_user_info: dict[str, str]):
|
||||
"""测试访问token格式正确"""
|
||||
access_token, expires = JWT.create_access_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
expires_delta=timedelta(hours=1)
|
||||
group_claims = GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=0,
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
admin=False,
|
||||
speed_limit=0,
|
||||
)
|
||||
result = create_access_token(
|
||||
sub=uuid4(),
|
||||
jti=uuid4(),
|
||||
status="active",
|
||||
group=group_claims,
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
|
||||
assert isinstance(access_token, str)
|
||||
assert len(access_token) > 0
|
||||
assert expires is not None
|
||||
assert isinstance(result.access_token, str)
|
||||
assert len(result.access_token) > 0
|
||||
assert result.access_expires is not None
|
||||
|
||||
@@ -3,14 +3,14 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_db():
|
||||
"""测试创建数据库结构"""
|
||||
from models import database
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
"""测试获取数据库连接Session"""
|
||||
from models import database
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
@@ -20,8 +20,8 @@ async def db_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration():
|
||||
"""测试数据库创建并初始化配置"""
|
||||
from models import migration
|
||||
from models import database
|
||||
from sqlmodels import migration
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database, migration
|
||||
from models.group import Group
|
||||
from sqlmodels import database, migration
|
||||
from sqlmodels.group import Group
|
||||
|
||||
await database.init_db(url='sqlite+aiosqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database
|
||||
from models.setting import Setting
|
||||
from sqlmodels import database
|
||||
from sqlmodels.setting import Setting
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database, migration
|
||||
from models.group import Group
|
||||
from models.user import User
|
||||
from sqlmodels import database, migration
|
||||
from sqlmodels.group import Group
|
||||
from sqlmodels.user import User
|
||||
|
||||
await database.init_db(url='sqlite+aiosqlite:///:memory:')
|
||||
|
||||
@@ -17,7 +17,7 @@ async def test_user_curd():
|
||||
created_group = await test_user_group.save(session)
|
||||
|
||||
test_user = User(
|
||||
username='test_user',
|
||||
email='test_user@test.local',
|
||||
password='test_password',
|
||||
group_id=created_group.id
|
||||
)
|
||||
@@ -27,7 +27,7 @@ async def test_user_curd():
|
||||
|
||||
# 验证用户是否存在
|
||||
assert created_user.id is not None
|
||||
assert created_user.username == 'test_user'
|
||||
assert created_user.email == 'test_user@test.local'
|
||||
assert created_user.password == 'test_password'
|
||||
assert created_user.group_id == created_group.id
|
||||
|
||||
@@ -35,18 +35,18 @@ async def test_user_curd():
|
||||
fetched_user = await User.get(session, User.id == created_user.id)
|
||||
|
||||
assert fetched_user is not None
|
||||
assert fetched_user.username == 'test_user'
|
||||
assert fetched_user.email == 'test_user@test.local'
|
||||
assert fetched_user.password == 'test_password'
|
||||
assert fetched_user.group_id == created_group.id
|
||||
|
||||
# 测试改 Update
|
||||
updated_user = await fetched_user.update(
|
||||
session,
|
||||
{"username": "updated_user", "password": "updated_password"}
|
||||
{"email": "updated_user@test.local", "password": "updated_password"}
|
||||
)
|
||||
|
||||
assert updated_user is not None
|
||||
assert updated_user.username == 'updated_user'
|
||||
assert updated_user.email == 'updated_user@test.local'
|
||||
assert updated_user.password == 'updated_password'
|
||||
|
||||
# 测试删除 Delete
|
||||
|
||||
@@ -8,8 +8,8 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -62,7 +62,7 @@ async def test_table_base_update(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
# 更新数据
|
||||
from models.group import GroupBase
|
||||
from sqlmodels.group import GroupBase
|
||||
update_data = GroupBase(name="更新后名称")
|
||||
updated_group = await group.update(db_session, update_data)
|
||||
|
||||
@@ -200,7 +200,7 @@ async def test_timestamps_auto_update(db_session: AsyncSession):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 更新记录
|
||||
from models.group import GroupBase
|
||||
from sqlmodels.group import GroupBase
|
||||
update_data = GroupBase(name="更新后的名称")
|
||||
group = await group.update(db_session, update_data)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ Group 和 GroupOptions 模型的单元测试
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.group import Group, GroupOptions, GroupResponse
|
||||
from sqlmodels.group import Group, GroupOptions, GroupResponse
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -5,21 +5,21 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.object import Object, ObjectType
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_create_folder(db_session: AsyncSession):
|
||||
"""测试创建目录"""
|
||||
# 创建必要的依赖数据
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -48,12 +48,12 @@ async def test_object_create_folder(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_create_file(db_session: AsyncSession):
|
||||
"""测试创建文件"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -65,7 +65,7 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -81,7 +81,6 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id,
|
||||
size=1024,
|
||||
source_name="test_source.txt"
|
||||
)
|
||||
file = await file.save(db_session)
|
||||
|
||||
@@ -89,18 +88,17 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
assert file.name == "test.txt"
|
||||
assert file.type == ObjectType.FILE
|
||||
assert file.size == 1024
|
||||
assert file.source_name == "test_source.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_is_file_property(db_session: AsyncSession):
|
||||
"""测试 is_file 属性"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -122,12 +120,12 @@ async def test_object_is_file_property(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_is_folder_property(db_session: AsyncSession):
|
||||
"""测试 is_folder 属性"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -148,12 +146,12 @@ async def test_object_is_folder_property(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_root(db_session: AsyncSession):
|
||||
"""测试 get_root() 方法"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="rootuser", password="password", group_id=group.id)
|
||||
user = User(email="rootuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -161,7 +159,7 @@ async def test_object_get_root(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -180,12 +178,12 @@ async def test_object_get_root(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
"""测试获取根目录"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="pathuser", password="password", group_id=group.id)
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -193,7 +191,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -202,7 +200,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
root = await root.save(db_session)
|
||||
|
||||
# 通过路径获取根目录
|
||||
result = await Object.get_by_path(db_session, user.id, "/pathuser", user.username)
|
||||
result = await Object.get_by_path(db_session, user.id, "/")
|
||||
|
||||
assert result is not None
|
||||
assert result.id == root.id
|
||||
@@ -211,12 +209,12 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
"""测试获取嵌套路径"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="nesteduser", password="password", group_id=group.id)
|
||||
user = User(email="nesteduser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -224,7 +222,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
|
||||
# 创建目录结构: root -> docs -> work -> project
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -263,8 +261,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
result = await Object.get_by_path(
|
||||
db_session,
|
||||
user.id,
|
||||
"/nesteduser/docs/work/project",
|
||||
user.username
|
||||
"/docs/work/project",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
@@ -275,12 +272,12 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
"""测试路径不存在"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="notfounduser", password="password", group_id=group.id)
|
||||
user = User(email="notfounduser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -288,7 +285,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -300,8 +297,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
result = await Object.get_by_path(
|
||||
db_session,
|
||||
user.id,
|
||||
"/notfounduser/nonexistent",
|
||||
user.username
|
||||
"/nonexistent",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
@@ -310,12 +306,12 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_children(db_session: AsyncSession):
|
||||
"""测试 get_children() 方法"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="childrenuser", password="password", group_id=group.id)
|
||||
user = User(email="childrenuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -362,12 +358,12 @@ async def test_object_get_children(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_parent_child_relationship(db_session: AsyncSession):
|
||||
"""测试父子关系"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="reluser", password="password", group_id=group.id)
|
||||
user = User(email="reluser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -407,12 +403,12 @@ async def test_object_parent_child_relationship(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_unique_constraint(db_session: AsyncSession):
|
||||
"""测试同目录名称唯一约束"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="uniqueuser", password="password", group_id=group.id)
|
||||
user = User(email="uniqueuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -450,3 +446,64 @@ async def test_object_unique_constraint(db_session: AsyncSession):
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
await file2.save(db_session)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_full_path(db_session: AsyncSession):
|
||||
"""测试 get_full_path() 方法"""
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
policy = await policy.save(db_session)
|
||||
|
||||
# 创建目录结构: root -> docs -> images -> photo.jpg
|
||||
root = Object(
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
root = await root.save(db_session)
|
||||
|
||||
docs = Object(
|
||||
name="docs",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=root.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
docs = await docs.save(db_session)
|
||||
|
||||
images = Object(
|
||||
name="images",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=docs.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
images = await images.save(db_session)
|
||||
|
||||
photo = Object(
|
||||
name="photo.jpg",
|
||||
type=ObjectType.FILE,
|
||||
parent_id=images.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id,
|
||||
size=2048
|
||||
)
|
||||
photo = await photo.save(db_session)
|
||||
|
||||
# 测试完整路径
|
||||
full_path = await photo.get_full_path(db_session)
|
||||
assert full_path == "/docs/images/photo.jpg"
|
||||
|
||||
# 测试根目录的 full_path
|
||||
root_path = await root.get_full_path(db_session)
|
||||
assert root_path == "/"
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.setting import Setting, SettingsType
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession):
|
||||
setting = await setting.save(db_session)
|
||||
|
||||
# 更新值
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class SettingUpdate(SQLModelBase):
|
||||
value: str | None = None
|
||||
|
||||
273
tests/unit/models/test_uri.py
Normal file
273
tests/unit/models/test_uri.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
DiskNextURI 模型的单元测试
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from sqlmodels.uri import DiskNextURI, FileSystemNamespace
|
||||
|
||||
|
||||
class TestDiskNextURIParse:
|
||||
"""测试 URI 解析"""
|
||||
|
||||
def test_parse_my_root(self):
|
||||
"""测试解析个人空间根目录"""
|
||||
uri = DiskNextURI.parse("disknext://my/")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/"
|
||||
assert uri.fs_id is None
|
||||
assert uri.password is None
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_parse_my_with_path(self):
|
||||
"""测试解析个人空间带路径"""
|
||||
uri = DiskNextURI.parse("disknext://my/docs/readme.md")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/docs/readme.md"
|
||||
assert uri.fs_id is None
|
||||
assert uri.path_parts == ["docs", "readme.md"]
|
||||
assert uri.is_root is False
|
||||
|
||||
def test_parse_my_with_fs_id(self):
|
||||
"""测试解析带 fs_id 的个人空间"""
|
||||
uri = DiskNextURI.parse("disknext://some-uuid@my/docs")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.fs_id == "some-uuid"
|
||||
assert uri.path == "/docs"
|
||||
|
||||
def test_parse_share_with_code(self):
|
||||
"""测试解析分享链接"""
|
||||
uri = DiskNextURI.parse("disknext://abc123@share/")
|
||||
assert uri.namespace == FileSystemNamespace.SHARE
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.path == "/"
|
||||
assert uri.password is None
|
||||
|
||||
def test_parse_share_with_password(self):
|
||||
"""测试解析带密码的分享链接"""
|
||||
uri = DiskNextURI.parse("disknext://abc123:mypass@share/sub/dir")
|
||||
assert uri.namespace == FileSystemNamespace.SHARE
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.password == "mypass"
|
||||
assert uri.path == "/sub/dir"
|
||||
|
||||
def test_parse_trash(self):
|
||||
"""测试解析回收站"""
|
||||
uri = DiskNextURI.parse("disknext://trash/")
|
||||
assert uri.namespace == FileSystemNamespace.TRASH
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_parse_with_query(self):
|
||||
"""测试解析带查询参数的 URI"""
|
||||
uri = DiskNextURI.parse("disknext://my/?name=report&type=file")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.query is not None
|
||||
assert uri.query["name"] == "report"
|
||||
assert uri.query["type"] == "file"
|
||||
|
||||
def test_parse_invalid_scheme(self):
|
||||
"""测试无效的协议前缀"""
|
||||
with pytest.raises(ValueError, match="disknext://"):
|
||||
DiskNextURI.parse("http://my/docs")
|
||||
|
||||
def test_parse_invalid_namespace(self):
|
||||
"""测试无效的命名空间"""
|
||||
with pytest.raises(ValueError, match="无效的命名空间"):
|
||||
DiskNextURI.parse("disknext://invalid/docs")
|
||||
|
||||
def test_parse_no_namespace(self):
|
||||
"""测试缺少命名空间"""
|
||||
with pytest.raises(ValueError):
|
||||
DiskNextURI.parse("disknext://")
|
||||
|
||||
|
||||
class TestDiskNextURIBuild:
|
||||
"""测试 URI 构建"""
|
||||
|
||||
def test_build_simple(self):
|
||||
"""测试简单构建"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/"
|
||||
assert uri.fs_id is None
|
||||
|
||||
def test_build_with_path(self):
|
||||
"""测试带路径构建"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.path == "/docs/readme.md"
|
||||
|
||||
def test_build_path_auto_prefix(self):
|
||||
"""测试路径自动添加 / 前缀"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="docs/readme.md")
|
||||
assert uri.path == "/docs/readme.md"
|
||||
|
||||
def test_build_with_fs_id(self):
|
||||
"""测试带 fs_id 构建"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.SHARE,
|
||||
fs_id="abc123",
|
||||
password="secret",
|
||||
)
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.password == "secret"
|
||||
|
||||
|
||||
class TestDiskNextURIToString:
|
||||
"""测试 URI 序列化"""
|
||||
|
||||
def test_to_string_simple(self):
|
||||
"""测试简单序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.to_string() == "disknext://my/"
|
||||
|
||||
def test_to_string_with_path(self):
|
||||
"""测试带路径序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.to_string() == "disknext://my/docs/readme.md"
|
||||
|
||||
def test_to_string_with_fs_id(self):
|
||||
"""测试带 fs_id 序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="uuid-123")
|
||||
assert uri.to_string() == "disknext://uuid-123@my/"
|
||||
|
||||
def test_to_string_with_password(self):
|
||||
"""测试带密码序列化"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.SHARE,
|
||||
fs_id="code",
|
||||
password="pass",
|
||||
)
|
||||
assert uri.to_string() == "disknext://code:pass@share/"
|
||||
|
||||
def test_to_string_roundtrip(self):
|
||||
"""测试序列化-反序列化往返"""
|
||||
original = "disknext://abc123:pass@share/sub/dir"
|
||||
uri = DiskNextURI.parse(original)
|
||||
result = uri.to_string()
|
||||
assert result == original
|
||||
|
||||
|
||||
class TestDiskNextURIId:
|
||||
"""测试 id() 方法"""
|
||||
|
||||
def test_id_with_fs_id(self):
|
||||
"""测试有 fs_id 时返回 fs_id"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="my-uuid")
|
||||
assert uri.id("default") == "my-uuid"
|
||||
|
||||
def test_id_without_fs_id(self):
|
||||
"""测试无 fs_id 时返回默认值"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.id("default-uuid") == "default-uuid"
|
||||
|
||||
def test_id_without_fs_id_no_default(self):
|
||||
"""测试无 fs_id 且无默认值时返回 None"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.id() is None
|
||||
|
||||
|
||||
class TestDiskNextURIJoin:
|
||||
"""测试 join() 方法"""
|
||||
|
||||
def test_join_single(self):
|
||||
"""测试拼接单个路径元素"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
joined = uri.join("readme.md")
|
||||
assert joined.path == "/docs/readme.md"
|
||||
|
||||
def test_join_multiple(self):
|
||||
"""测试拼接多个路径元素"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
joined = uri.join("docs", "work", "report.pdf")
|
||||
assert joined.path == "/docs/work/report.pdf"
|
||||
|
||||
def test_join_preserves_metadata(self):
|
||||
"""测试 join 保留 namespace 和 fs_id"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.SHARE, fs_id="code123")
|
||||
joined = uri.join("sub")
|
||||
assert joined.namespace == FileSystemNamespace.SHARE
|
||||
assert joined.fs_id == "code123"
|
||||
|
||||
|
||||
class TestDiskNextURIDirUri:
|
||||
"""测试 dir_uri() 方法"""
|
||||
|
||||
def test_dir_uri_file(self):
|
||||
"""测试获取文件的父目录 URI"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
parent = uri.dir_uri()
|
||||
assert parent.path == "/docs/"
|
||||
|
||||
def test_dir_uri_root(self):
|
||||
"""测试根目录的 dir_uri 返回自身"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
parent = uri.dir_uri()
|
||||
assert parent.path == "/"
|
||||
|
||||
|
||||
class TestDiskNextURIRoot:
|
||||
"""测试 root() 方法"""
|
||||
|
||||
def test_root_resets_path(self):
|
||||
"""测试 root 重置路径"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.MY,
|
||||
path="/docs/work/report.pdf",
|
||||
fs_id="uuid-123",
|
||||
)
|
||||
root = uri.root()
|
||||
assert root.path == "/"
|
||||
assert root.fs_id == "uuid-123"
|
||||
assert root.namespace == FileSystemNamespace.MY
|
||||
|
||||
|
||||
class TestDiskNextURIName:
|
||||
"""测试 name() 方法"""
|
||||
|
||||
def test_name_file(self):
|
||||
"""测试获取文件名"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.name() == "readme.md"
|
||||
|
||||
def test_name_directory(self):
|
||||
"""测试获取目录名"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work")
|
||||
assert uri.name() == "work"
|
||||
|
||||
def test_name_root(self):
|
||||
"""测试根目录的 name 返回空字符串"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.name() == ""
|
||||
|
||||
|
||||
class TestDiskNextURIProperties:
|
||||
"""测试属性方法"""
|
||||
|
||||
def test_path_parts(self):
|
||||
"""测试路径分割"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work/report.pdf")
|
||||
assert uri.path_parts == ["docs", "work", "report.pdf"]
|
||||
|
||||
def test_path_parts_root(self):
|
||||
"""测试根路径分割"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.path_parts == []
|
||||
|
||||
def test_is_root_true(self):
|
||||
"""测试 is_root 为真"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_is_root_false(self):
|
||||
"""测试 is_root 为假"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert uri.is_root is False
|
||||
|
||||
def test_str_representation(self):
|
||||
"""测试字符串表示"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert str(uri) == "disknext://my/docs"
|
||||
|
||||
def test_repr(self):
|
||||
"""测试 repr"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert "disknext://my/docs" in repr(uri)
|
||||
@@ -5,8 +5,8 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User, ThemeType, UserPublic
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User, ThemeType, UserPublic, UserStatus
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -18,7 +18,7 @@ async def test_user_create(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password="hashed_password",
|
||||
group_id=group.id
|
||||
@@ -26,23 +26,23 @@ async def test_user_create(db_session: AsyncSession):
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.id is not None
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "testuser@test.local"
|
||||
assert user.nickname == "测试用户"
|
||||
assert user.status is True
|
||||
assert user.status == UserStatus.ACTIVE
|
||||
assert user.storage == 0
|
||||
assert user.score == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_unique_username(db_session: AsyncSession):
|
||||
"""测试用户名唯一约束"""
|
||||
async def test_user_unique_email(db_session: AsyncSession):
|
||||
"""测试邮箱唯一约束"""
|
||||
# 创建用户组
|
||||
group = Group(name="默认组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
# 创建第一个用户
|
||||
user1 = User(
|
||||
username="duplicate",
|
||||
email="duplicate@test.local",
|
||||
password="password1",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -50,7 +50,7 @@ async def test_user_unique_username(db_session: AsyncSession):
|
||||
|
||||
# 尝试创建同名用户
|
||||
user2 = User(
|
||||
username="duplicate",
|
||||
email="duplicate@test.local",
|
||||
password="password2",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -68,7 +68,7 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="publicuser",
|
||||
email="publicuser@test.local",
|
||||
nickname="公开用户",
|
||||
password="secret_password",
|
||||
storage=1024,
|
||||
@@ -82,7 +82,7 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
|
||||
assert isinstance(public_user, UserPublic)
|
||||
assert public_user.id == user.id
|
||||
assert public_user.username == "publicuser"
|
||||
assert public_user.email == "publicuser@test.local"
|
||||
# 注意: UserPublic.nick 字段名与 User.nickname 不同,
|
||||
# model_validate 不会自动映射,所以 nick 为 None
|
||||
# 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段
|
||||
@@ -101,7 +101,7 @@ async def test_user_group_relationship(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="vipuser",
|
||||
email="vipuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -125,13 +125,13 @@ async def test_user_status_default(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="defaultuser",
|
||||
email="defaultuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.status is True
|
||||
assert user.status == UserStatus.ACTIVE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -141,7 +141,7 @@ async def test_user_storage_default(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="storageuser",
|
||||
email="storageuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -158,7 +158,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试默认值
|
||||
user1 = User(
|
||||
username="user1",
|
||||
email="user1@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -167,7 +167,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试设置为 LIGHT
|
||||
user2 = User(
|
||||
username="user2",
|
||||
email="user2@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.LIGHT,
|
||||
group_id=group.id
|
||||
@@ -177,7 +177,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试设置为 DARK
|
||||
user3 = User(
|
||||
username="user3",
|
||||
email="user3@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.DARK,
|
||||
group_id=group.id
|
||||
|
||||
@@ -4,8 +4,8 @@ Login 服务的单元测试
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User, LoginRequest, TokenResponse
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse, UserStatus
|
||||
from sqlmodels.group import Group
|
||||
from service.user.login import login
|
||||
from utils.password.pwd import Password
|
||||
|
||||
@@ -20,9 +20,9 @@ async def setup_user(db_session: AsyncSession):
|
||||
# 创建正常用户
|
||||
plain_password = "secure_password_123"
|
||||
user = User(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=Password.hash(plain_password),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -41,9 +41,9 @@ async def setup_banned_user(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=False, # 封禁状态
|
||||
status=UserStatus.ADMIN_BANNED, # 封禁状态
|
||||
group_id=group.id
|
||||
)
|
||||
user = await user.save(db_session)
|
||||
@@ -61,9 +61,9 @@ async def setup_2fa_user(db_session: AsyncSession):
|
||||
|
||||
secret = pyotp.random_base32()
|
||||
user = User(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=True,
|
||||
status=UserStatus.ACTIVE,
|
||||
two_factor=secret,
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -82,7 +82,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
async def test_login_user_not_found(db_session: AsyncSession):
|
||||
"""测试用户不存在"""
|
||||
login_request = LoginRequest(
|
||||
username="nonexistent_user",
|
||||
email="nonexistent@test.local",
|
||||
password="any_password"
|
||||
)
|
||||
|
||||
@@ -112,7 +112,7 @@ async def test_login_user_not_found(db_session: AsyncSession):
|
||||
async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
"""测试密码错误"""
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password="wrong_password"
|
||||
)
|
||||
|
||||
@@ -125,7 +125,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
|
||||
"""测试用户被封禁"""
|
||||
login_request = LoginRequest(
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password="password"
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"]
|
||||
# 未提供 two_fa_code
|
||||
)
|
||||
@@ -156,7 +156,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code="000000" # 错误的验证码
|
||||
)
|
||||
@@ -179,7 +179,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
|
||||
valid_code = totp.now()
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code=valid_code
|
||||
)
|
||||
@@ -198,7 +198,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
@@ -217,17 +217,17 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_case_sensitive_username(db_session: AsyncSession, setup_user):
|
||||
"""测试用户名大小写敏感"""
|
||||
async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user):
|
||||
"""测试邮箱大小写敏感"""
|
||||
user_data = setup_user
|
||||
|
||||
# 使用大写用户名登录(如果数据库是 loginuser)
|
||||
# 使用大写邮箱登录
|
||||
login_request = LoginRequest(
|
||||
username="LOGINUSER",
|
||||
email="LOGINUSER@TEST.LOCAL",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
# 应该失败,因为用户名大小写不匹配
|
||||
# 应该失败,因为邮箱大小写不匹配
|
||||
assert result is None
|
||||
|
||||
@@ -1,49 +1,86 @@
|
||||
"""
|
||||
JWT 工具的单元测试
|
||||
"""
|
||||
import time
|
||||
from datetime import timedelta, datetime, timezone
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
|
||||
from utils.JWT.JWT import create_access_token, create_refresh_token, SECRET_KEY
|
||||
from sqlmodels.group import GroupClaims
|
||||
from utils.JWT import create_access_token, create_refresh_token, build_token_payload
|
||||
|
||||
|
||||
# 测试用的 GroupClaims
|
||||
def _make_group_claims(admin: bool = False) -> GroupClaims:
|
||||
return GroupClaims(
|
||||
id=uuid4(),
|
||||
name="测试组",
|
||||
max_storage=1073741824,
|
||||
share_enabled=True,
|
||||
web_dav_enabled=False,
|
||||
admin=admin,
|
||||
speed_limit=0,
|
||||
)
|
||||
|
||||
|
||||
# 设置测试用的密钥
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_secret_key():
|
||||
"""为测试设置密钥"""
|
||||
import utils.JWT.JWT as jwt_module
|
||||
import utils.JWT as jwt_module
|
||||
jwt_module.SECRET_KEY = "test_secret_key_for_unit_tests"
|
||||
yield
|
||||
# 测试后恢复(虽然在单元测试中不太重要)
|
||||
|
||||
|
||||
def test_create_access_token():
|
||||
"""测试访问令牌创建"""
|
||||
data = {"sub": "testuser", "role": "user"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, expire_time = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
assert isinstance(token, str)
|
||||
assert isinstance(expire_time, datetime)
|
||||
assert isinstance(result.access_token, str)
|
||||
assert isinstance(result.access_expires, datetime)
|
||||
|
||||
# 解码验证
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == "testuser"
|
||||
assert decoded["role"] == "user"
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == str(sub)
|
||||
assert decoded["jti"] == str(jti)
|
||||
assert decoded["status"] == "active"
|
||||
assert decoded["group"]["admin"] is False
|
||||
assert "exp" in decoded
|
||||
|
||||
|
||||
def test_create_access_token_custom_expiry():
|
||||
"""测试自定义过期时间"""
|
||||
data = {"sub": "testuser"}
|
||||
custom_expiry = timedelta(hours=1)
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
custom_expiry = timedelta(minutes=30)
|
||||
|
||||
token, expire_time = create_access_token(data, expires_delta=custom_expiry)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group, expires_delta=custom_expiry)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是30分钟后
|
||||
exp_timestamp = decoded["exp"]
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# 允许1秒误差
|
||||
assert abs(exp_timestamp - now_timestamp - 1800) < 1
|
||||
|
||||
|
||||
def test_create_access_token_default_expiry():
|
||||
"""测试访问令牌默认1小时过期"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是1小时后
|
||||
exp_timestamp = decoded["exp"]
|
||||
@@ -55,27 +92,29 @@ def test_create_access_token_custom_expiry():
|
||||
|
||||
def test_create_refresh_token():
|
||||
"""测试刷新令牌创建"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, expire_time = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
assert isinstance(token, str)
|
||||
assert isinstance(expire_time, datetime)
|
||||
assert isinstance(result.refresh_token, str)
|
||||
assert isinstance(result.refresh_expires, datetime)
|
||||
|
||||
# 解码验证
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == "testuser"
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
assert decoded["sub"] == str(sub)
|
||||
assert decoded["token_type"] == "refresh"
|
||||
assert "exp" in decoded
|
||||
|
||||
|
||||
def test_create_refresh_token_default_expiry():
|
||||
"""测试刷新令牌默认30天过期"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, expire_time = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
# 验证过期时间大约是30天后
|
||||
exp_timestamp = decoded["exp"]
|
||||
@@ -86,78 +125,72 @@ def test_create_refresh_token_default_expiry():
|
||||
assert abs(exp_timestamp - now_timestamp - 2592000) < 1
|
||||
|
||||
|
||||
def test_token_decode():
|
||||
"""测试令牌解码"""
|
||||
data = {"sub": "user123", "email": "user@example.com"}
|
||||
def test_access_token_contains_group_claims():
|
||||
"""测试访问令牌包含完整的 group claims"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims(admin=True)
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
# 解码
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["sub"] == "user123"
|
||||
assert decoded["email"] == "user@example.com"
|
||||
|
||||
|
||||
def test_token_expired():
|
||||
"""测试令牌过期"""
|
||||
data = {"sub": "testuser"}
|
||||
|
||||
# 创建一个立即过期的令牌
|
||||
token, _ = create_access_token(data, expires_delta=timedelta(seconds=-1))
|
||||
|
||||
# 尝试解码应该抛出过期异常
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
|
||||
def test_token_invalid_signature():
|
||||
"""测试无效签名"""
|
||||
data = {"sub": "testuser"}
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
|
||||
# 使用错误的密钥解码
|
||||
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||
pyjwt.decode(token, "wrong_secret_key", algorithms=["HS256"])
|
||||
assert decoded["group"]["admin"] is True
|
||||
assert decoded["group"]["name"] == "测试组"
|
||||
assert decoded["group"]["max_storage"] == 1073741824
|
||||
assert decoded["group"]["share_enabled"] is True
|
||||
|
||||
|
||||
def test_access_token_does_not_have_token_type():
|
||||
"""测试访问令牌不包含 token_type"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert "token_type" not in decoded
|
||||
|
||||
|
||||
def test_refresh_token_has_token_type():
|
||||
"""测试刷新令牌包含 token_type"""
|
||||
data = {"sub": "testuser"}
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
|
||||
token, _ = create_refresh_token(data)
|
||||
result = create_refresh_token(sub=sub, jti=jti)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
decoded = pyjwt.decode(result.refresh_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["token_type"] == "refresh"
|
||||
|
||||
|
||||
def test_token_payload_preserved():
|
||||
"""测试自定义负载保留"""
|
||||
data = {
|
||||
"sub": "user123",
|
||||
"name": "Test User",
|
||||
"roles": ["admin", "user"],
|
||||
"metadata": {"key": "value"}
|
||||
}
|
||||
def test_token_expired():
|
||||
"""测试令牌过期"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
token, _ = create_access_token(data)
|
||||
# 创建一个立即过期的令牌
|
||||
result = create_access_token(
|
||||
sub=sub, jti=jti, status="active", group=group,
|
||||
expires_delta=timedelta(seconds=-1),
|
||||
)
|
||||
|
||||
decoded = pyjwt.decode(token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
# 尝试解码应该抛出过期异常
|
||||
with pytest.raises(pyjwt.ExpiredSignatureError):
|
||||
pyjwt.decode(result.access_token, "test_secret_key_for_unit_tests", algorithms=["HS256"])
|
||||
|
||||
assert decoded["sub"] == "user123"
|
||||
assert decoded["name"] == "Test User"
|
||||
assert decoded["roles"] == ["admin", "user"]
|
||||
assert decoded["metadata"] == {"key": "value"}
|
||||
|
||||
def test_token_invalid_signature():
|
||||
"""测试无效签名"""
|
||||
sub = uuid4()
|
||||
jti = uuid4()
|
||||
group = _make_group_claims()
|
||||
|
||||
result = create_access_token(sub=sub, jti=jti, status="active", group=group)
|
||||
|
||||
# 使用错误的密钥解码
|
||||
with pytest.raises(pyjwt.InvalidSignatureError):
|
||||
pyjwt.decode(result.access_token, "wrong_secret_key", algorithms=["HS256"])
|
||||
|
||||
@@ -72,9 +72,9 @@ def test_password_verify_expired():
|
||||
@pytest.mark.asyncio
|
||||
async def test_totp_generate():
|
||||
"""测试 TOTP 密钥生成"""
|
||||
username = "testuser"
|
||||
email = "testuser@test.local"
|
||||
|
||||
response = await Password.generate_totp(username)
|
||||
response = await Password.generate_totp(email)
|
||||
|
||||
assert response.setup_token is not None
|
||||
assert response.uri is not None
|
||||
@@ -82,7 +82,7 @@ async def test_totp_generate():
|
||||
assert isinstance(response.uri, str)
|
||||
# TOTP URI 格式: otpauth://totp/...
|
||||
assert response.uri.startswith("otpauth://totp/")
|
||||
assert username in response.uri
|
||||
assert email in response.uri
|
||||
|
||||
|
||||
def test_totp_verify_valid():
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
|
||||
from models import AccessTokenBase, RefreshTokenBase
|
||||
from sqlmodels import AccessTokenBase, RefreshTokenBase, TokenResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlmodels.group import GroupClaims
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
scheme_name='获取 JWT Bearer 令牌',
|
||||
@@ -21,8 +25,8 @@ async def load_secret_key() -> None:
|
||||
从数据库读取 JWT 的密钥。
|
||||
"""
|
||||
# 延迟导入以避免循环依赖
|
||||
from models.database import get_session
|
||||
from models.setting import Setting
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.setting import Setting
|
||||
|
||||
global SECRET_KEY
|
||||
async for session in get_session():
|
||||
@@ -59,7 +63,7 @@ def build_token_payload(
|
||||
elif is_refresh:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=3)
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
to_encode.update({
|
||||
"iat": int(datetime.now(timezone.utc).timestamp()),
|
||||
"exp": int(expire.timestamp())
|
||||
@@ -69,25 +73,38 @@ def build_token_payload(
|
||||
|
||||
# 访问令牌
|
||||
def create_access_token(
|
||||
data: dict,
|
||||
sub: UUID,
|
||||
jti: UUID,
|
||||
*,
|
||||
status: str,
|
||||
group: "GroupClaims",
|
||||
expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256"
|
||||
algorithm: str = "HS256",
|
||||
) -> AccessTokenBase:
|
||||
"""
|
||||
生成访问令牌,默认有效期 3 小时。
|
||||
生成访问令牌,默认有效期 1 小时。
|
||||
|
||||
:param data: 需要放进 JWT Payload 的字段。
|
||||
:param expires_delta: 过期时间, 缺省时为 3 小时。
|
||||
:param sub: 令牌的主题,通常是用户 ID。
|
||||
:param jti: 令牌的唯一标识符,通常是一个 UUID。
|
||||
:param status: 用户状态字符串。
|
||||
:param group: 用户组权限快照。
|
||||
:param expires_delta: 过期时间, 缺省时为 1 小时。
|
||||
:param algorithm: JWT 密钥强度,缺省时为 HS256
|
||||
|
||||
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
|
||||
"""
|
||||
data = {
|
||||
"sub": str(sub),
|
||||
"jti": str(jti),
|
||||
"status": status,
|
||||
"group": group.model_dump(mode="json"),
|
||||
}
|
||||
|
||||
access_token, expire_at = build_token_payload(
|
||||
data,
|
||||
False,
|
||||
algorithm,
|
||||
expires_delta
|
||||
data,
|
||||
False,
|
||||
algorithm,
|
||||
expires_delta,
|
||||
)
|
||||
return AccessTokenBase(
|
||||
access_token=access_token,
|
||||
@@ -97,19 +114,29 @@ def create_access_token(
|
||||
|
||||
# 刷新令牌
|
||||
def create_refresh_token(
|
||||
data: dict,
|
||||
sub: UUID,
|
||||
jti: UUID,
|
||||
expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256"
|
||||
algorithm: str = "HS256",
|
||||
**kwargs,
|
||||
) -> RefreshTokenBase:
|
||||
"""
|
||||
生成刷新令牌,默认有效期 30 天。
|
||||
|
||||
:param data: 需要放进 JWT Payload 的字段。
|
||||
:param sub: 令牌的主题,通常是用户 ID。
|
||||
:param jti: 令牌的唯一标识符,通常是一个 UUID。
|
||||
:param expires_delta: 过期时间, 缺省时为 30 天。
|
||||
:param algorithm: JWT 密钥强度,缺省时为 HS256
|
||||
:param kwargs: 需要放进 JWT Payload 的字段。
|
||||
|
||||
:return: 包含密钥本身和过期时间的 `RefreshTokenBase`
|
||||
"""
|
||||
|
||||
data = {"sub": str(sub), "jti": str(jti)}
|
||||
|
||||
# 将额外的字段添加到 Payload 中
|
||||
for key, value in kwargs.items():
|
||||
data[key] = value
|
||||
|
||||
refresh_token, expire_at = build_token_payload(
|
||||
data,
|
||||
|
||||
@@ -28,6 +28,10 @@ def raise_forbidden(detail: str | None = None, *args, **kwargs) -> NoReturn:
|
||||
"""Raises an HTTP 403 Forbidden exception."""
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs)
|
||||
|
||||
def raise_banned(detail: str = "此文件已被管理员封禁,仅允许删除操作", *args, **kwargs) -> NoReturn:
|
||||
"""Raises an HTTP 403 Forbidden exception for banned objects."""
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail, *args, **kwargs)
|
||||
|
||||
def raise_not_found(detail: str | None = None, *args, **kwargs) -> NoReturn:
|
||||
"""Raises an HTTP 404 Not Found exception."""
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail, *args, **kwargs)
|
||||
|
||||
@@ -73,6 +73,8 @@ class Password:
|
||||
|
||||
:param length: 密码长度
|
||||
:type length: int
|
||||
:param url_safe: 是否生成 URL 安全的密码
|
||||
:type url_safe: bool
|
||||
:return: 随机密码
|
||||
:rtype: str
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user