Refactor and enhance OAuth2.0 implementation; update models and routes
- Refactored AdminSummaryData and AdminSummaryResponse classes for better clarity. - Added OAUTH type to SettingsType enum. - Cleaned up imports in webdav.py. - Updated admin router to improve summary data retrieval and response handling. - Enhanced file management routes with better condition handling and user storage updates. - Improved group management routes by optimizing data retrieval. - Refined task management routes for better condition handling. - Updated user management routes to streamline access token retrieval. - Implemented a new captcha verification structure with abstract base class. - Removed deprecated env.md file and replaced with a new structured version. - Introduced a unified OAuth2.0 client base class for GitHub and QQ integrations. - Enhanced password management with improved hashing strategies. - Added detailed comments and documentation throughout the codebase for clarity.
This commit is contained in:
9
.claude/settings.local.json
Normal file
9
.claude/settings.local.json
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"permissions": {
|
||||||
|
"allow": [
|
||||||
|
"Bash(git rev-parse:*)",
|
||||||
|
"Bash(findstr:*)",
|
||||||
|
"Bash(find:*)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -9,4 +9,4 @@
|
|||||||
- `REDIS_PORT`: Redis 端口
|
- `REDIS_PORT`: Redis 端口
|
||||||
- `REDIS_PASSWORD`: Redis 密码
|
- `REDIS_PASSWORD`: Redis 密码
|
||||||
- `REDIS_DB`: Redis 数据库
|
- `REDIS_DB`: Redis 数据库
|
||||||
- `REDIS_PROTOCOL`
|
- `REDIS_PROTOCOL`: Redis 协议
|
||||||
16
main.py
16
main.py
@@ -33,8 +33,22 @@ app = FastAPI(
|
|||||||
openapi_url="/openapi.json" if appmeta.debug else None,
|
openapi_url="/openapi.json" if appmeta.debug else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 添加跨域 CORS 中间件,仅在调试模式下启用,以允许所有来源访问 API
|
||||||
|
if appmeta.debug:
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
@app.exception_handler(Exception)
|
@app.exception_handler(Exception)
|
||||||
async def handle_unexpected_exceptions(request: Request, exc: Exception) -> NoReturn:
|
async def handle_unexpected_exceptions(
|
||||||
|
request: Request,
|
||||||
|
exc: Exception
|
||||||
|
) -> NoReturn:
|
||||||
"""
|
"""
|
||||||
捕获所有未经处理的 FastAPI 异常,防止敏感信息泄露。
|
捕获所有未经处理的 FastAPI 异常,防止敏感信息泄露。
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -60,10 +60,10 @@ def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
|
|||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
|
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
|
||||||
if payload.get("type") != "download":
|
if payload.get("type") != "download":
|
||||||
return None
|
http_exceptions.raise_unauthorized("Download token required")
|
||||||
jti = payload.get("jti")
|
jti = payload.get("jti")
|
||||||
if not jti:
|
if not jti:
|
||||||
return None
|
http_exceptions.raise_unauthorized("Download token required")
|
||||||
return jti, UUID(payload["file_id"]), UUID(payload["owner_id"])
|
return jti, UUID(payload["file_id"]), UUID(payload["owner_id"])
|
||||||
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
except jwt.InvalidTokenError:
|
||||||
return None
|
http_exceptions.raise_unauthorized("Download token required")
|
||||||
@@ -120,7 +120,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
|||||||
async def init_default_settings() -> None:
|
async def init_default_settings() -> None:
|
||||||
from .setting import Setting
|
from .setting import Setting
|
||||||
from .database import get_session
|
from .database import get_session
|
||||||
from sqlalchemy import and_
|
|
||||||
|
|
||||||
log.info('初始化设置...')
|
log.info('初始化设置...')
|
||||||
|
|
||||||
@@ -128,7 +127,7 @@ async def init_default_settings() -> None:
|
|||||||
# 检查是否已经存在版本设置
|
# 检查是否已经存在版本设置
|
||||||
ver = await Setting.get(
|
ver = await Setting.get(
|
||||||
session,
|
session,
|
||||||
and_(Setting.type == SettingsType.VERSION, Setting.name == f"db_version_{BackendVersion}")
|
(Setting.type == SettingsType.VERSION) & (Setting.name == f"db_version_{BackendVersion}")
|
||||||
)
|
)
|
||||||
if ver and ver.value == "installed":
|
if ver and ver.value == "installed":
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -218,7 +218,7 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
|
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]:
|
||||||
"""
|
"""
|
||||||
向数据库中添加一个新的或多个新的记录.
|
向数据库中添加一个新的或多个新的记录.
|
||||||
|
|
||||||
@@ -230,6 +230,8 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||||
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
|
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
|
||||||
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
|
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
|
||||||
|
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||||
|
之后需要手动调用 `session.commit()`。默认为 True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
|
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
|
||||||
@@ -244,6 +246,11 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
# 添加单个实例
|
# 添加单个实例
|
||||||
item3 = Item(name="Cherry")
|
item3 = Item(name="Cherry")
|
||||||
added_item = await Item.add(session, item3)
|
added_item = await Item.add(session, item3)
|
||||||
|
|
||||||
|
# 批量操作,减少提交次数
|
||||||
|
await Item.add(session, [item1, item2], commit=False)
|
||||||
|
await Item.add(session, [item3, item4], commit=False)
|
||||||
|
await session.commit()
|
||||||
"""
|
"""
|
||||||
is_list = False
|
is_list = False
|
||||||
if isinstance(instances, list):
|
if isinstance(instances, list):
|
||||||
@@ -252,7 +259,10 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
else:
|
else:
|
||||||
session.add(instances)
|
session.add(instances)
|
||||||
|
|
||||||
await session.commit()
|
if commit:
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
if refresh:
|
if refresh:
|
||||||
if is_list:
|
if is_list:
|
||||||
@@ -266,15 +276,16 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
async def save(
|
async def save(
|
||||||
self: T,
|
self: T,
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
load: RelationshipInfo | None = None,
|
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||||
refresh: bool = True
|
refresh: bool = True,
|
||||||
|
commit: bool = True
|
||||||
) -> T:
|
) -> T:
|
||||||
"""
|
"""
|
||||||
保存(插入或更新)当前模型实例到数据库.
|
保存(插入或更新)当前模型实例到数据库.
|
||||||
|
|
||||||
这是一个实例方法,它将当前对象添加到会话中并提交更改。
|
这是一个实例方法,它将当前对象添加到会话中并提交更改。
|
||||||
可以用于创建新记录或更新现有记录。还可以选择在保存后
|
可以用于创建新记录或更新现有记录。还可以选择在保存后
|
||||||
预加载(eager load)一个关联关系.
|
预加载(eager load)一个或多个关联关系.
|
||||||
|
|
||||||
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
||||||
如果需要继续使用该对象,必须使用返回值:
|
如果需要继续使用该对象,必须使用返回值:
|
||||||
@@ -287,6 +298,14 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||||
await client.save(session, refresh=False)
|
await client.save(session, refresh=False)
|
||||||
|
|
||||||
|
# ✅ 正确:批量操作,减少提交次数
|
||||||
|
await item1.save(session, commit=False)
|
||||||
|
await item2.save(session, commit=False)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# ✅ 正确:批量操作并预加载多个关联关系
|
||||||
|
user = await user.save(session, load=[User.group, User.tags])
|
||||||
|
|
||||||
# ❌ 错误:需要返回值但未使用
|
# ❌ 错误:需要返回值但未使用
|
||||||
await client.save(session)
|
await client.save(session)
|
||||||
return client # client 对象已过期
|
return client # client 对象已过期
|
||||||
@@ -294,16 +313,22 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||||
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
load (Relationship | list[Relationship] | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||||
例如 `User.posts`.
|
可以是单个关系或关系列表.
|
||||||
|
例如 `User.posts` 或 `[User.group, User.tags]`.
|
||||||
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
|
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
|
||||||
设为 False 可节省一次数据库查询。默认为 True.
|
设为 False 可节省一次数据库查询。默认为 True.
|
||||||
|
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||||
|
之后需要手动调用 `session.commit()`。默认为 True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||||
"""
|
"""
|
||||||
session.add(self)
|
session.add(self)
|
||||||
await session.commit()
|
if commit:
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
if not refresh:
|
if not refresh:
|
||||||
return self
|
return self
|
||||||
@@ -324,8 +349,9 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
extra_data: dict[str, Any] | None = None,
|
extra_data: dict[str, Any] | None = None,
|
||||||
exclude_unset: bool = True,
|
exclude_unset: bool = True,
|
||||||
exclude: set[str] | None = None,
|
exclude: set[str] | None = None,
|
||||||
load: RelationshipInfo | None = None,
|
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||||
refresh: bool = True
|
refresh: bool = True,
|
||||||
|
commit: bool = True
|
||||||
) -> T:
|
) -> T:
|
||||||
"""
|
"""
|
||||||
使用另一个模型实例或字典中的数据来更新当前实例.
|
使用另一个模型实例或字典中的数据来更新当前实例.
|
||||||
@@ -348,6 +374,14 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||||
await client.update(session, update_data, refresh=False)
|
await client.update(session, update_data, refresh=False)
|
||||||
|
|
||||||
|
# ✅ 正确:批量操作,减少提交次数
|
||||||
|
await user1.update(session, data1, commit=False)
|
||||||
|
await user2.update(session, data2, commit=False)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# ✅ 正确:批量操作并预加载多个关联关系
|
||||||
|
user = await user.update(session, data, load=[User.group, User.tags])
|
||||||
|
|
||||||
# ❌ 错误:需要返回值但未使用
|
# ❌ 错误:需要返回值但未使用
|
||||||
await client.update(session, update_data)
|
await client.update(session, update_data)
|
||||||
return client # client 对象已过期
|
return client # client 对象已过期
|
||||||
@@ -360,10 +394,13 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
|
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
|
||||||
的字段将被忽略. 默认为 True.
|
的字段将被忽略. 默认为 True.
|
||||||
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
|
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
|
||||||
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
load (Relationship | list[Relationship] | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||||
例如 `User.permission`.
|
可以是单个关系或关系列表.
|
||||||
|
例如 `User.permission` 或 `[User.group, User.tags]`.
|
||||||
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
|
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
|
||||||
设为 False 可节省一次数据库查询。默认为 True.
|
设为 False 可节省一次数据库查询。默认为 True.
|
||||||
|
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||||
|
之后需要手动调用 `session.commit()`。默认为 True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||||
@@ -374,7 +411,10 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
session.add(self)
|
session.add(self)
|
||||||
await session.commit()
|
if commit:
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
if not refresh:
|
if not refresh:
|
||||||
return self
|
return self
|
||||||
@@ -388,33 +428,82 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None:
|
async def delete(
|
||||||
|
cls: type[T],
|
||||||
|
session: AsyncSession,
|
||||||
|
instances: T | list[T] | None = None,
|
||||||
|
*,
|
||||||
|
condition: BinaryExpression | ClauseElement | None = None,
|
||||||
|
commit: bool = True
|
||||||
|
) -> int:
|
||||||
"""
|
"""
|
||||||
从数据库中删除一个或多个记录.
|
从数据库中删除记录.
|
||||||
|
|
||||||
|
支持两种删除方式:
|
||||||
|
1. 实例删除:传入 instances 参数,先加载再删除
|
||||||
|
2. 条件删除:传入 condition 参数,直接 SQL 删除(更高效)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||||
instances (T | list[T]): 要删除的单个模型实例或模型实例列表.
|
instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表(可选).
|
||||||
|
condition (BinaryExpression | ClauseElement | None): 删除条件(可选,与 instances 二选一).
|
||||||
|
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||||
|
之后需要手动调用 `session.commit()`。默认为 True.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
int: 删除的记录数量
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
|
# 实例删除
|
||||||
item_to_delete = await Item.get(session, Item.id == 1)
|
item_to_delete = await Item.get(session, Item.id == 1)
|
||||||
if item_to_delete:
|
if item_to_delete:
|
||||||
await Item.delete(session, item_to_delete)
|
deleted_count = await Item.delete(session, item_to_delete)
|
||||||
|
|
||||||
items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all")
|
# 条件删除(更高效,无需加载实例)
|
||||||
if items_to_delete:
|
deleted_count = await Item.delete(
|
||||||
await Item.delete(session, items_to_delete)
|
session,
|
||||||
|
condition=(Item.status == "inactive") & (Item.created_at < cutoff_date)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 批量删除后手动提交
|
||||||
|
await Item.delete(session, item1, commit=False)
|
||||||
|
await Item.delete(session, item2, commit=False)
|
||||||
|
await session.commit()
|
||||||
"""
|
"""
|
||||||
|
# 条件删除模式
|
||||||
|
if condition is not None:
|
||||||
|
from sqlmodel import delete as sql_delete
|
||||||
|
|
||||||
|
if instances is not None:
|
||||||
|
raise ValueError("不能同时指定 instances 和 condition")
|
||||||
|
|
||||||
|
# 执行条件删除
|
||||||
|
stmt = sql_delete(cls).where(condition)
|
||||||
|
result = await session.exec(stmt)
|
||||||
|
deleted_count = result.rowcount
|
||||||
|
|
||||||
|
if commit:
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
|
# 实例删除模式(原有逻辑)
|
||||||
|
if instances is None:
|
||||||
|
raise ValueError("必须指定 instances 或 condition")
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
if isinstance(instances, list):
|
if isinstance(instances, list):
|
||||||
for instance in instances:
|
for instance in instances:
|
||||||
await session.delete(instance)
|
await session.delete(instance)
|
||||||
|
deleted_count += 1
|
||||||
else:
|
else:
|
||||||
await session.delete(instances)
|
await session.delete(instances)
|
||||||
|
deleted_count = 1
|
||||||
|
|
||||||
await session.commit()
|
if commit:
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return deleted_count
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _build_time_filters(
|
def _build_time_filters(
|
||||||
@@ -458,7 +547,7 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||||
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
||||||
options: list | None = None,
|
options: list | None = None,
|
||||||
load: RelationshipInfo | None = None,
|
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||||
order_by: list[ClauseElement] | None = None,
|
order_by: list[ClauseElement] | None = None,
|
||||||
filter: BinaryExpression | ClauseElement | None = None,
|
filter: BinaryExpression | ClauseElement | None = None,
|
||||||
with_for_update: bool = False,
|
with_for_update: bool = False,
|
||||||
@@ -491,8 +580,9 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
|
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
|
||||||
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
|
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
|
||||||
例如 `[selectinload(User.posts)]`.
|
例如 `[selectinload(User.posts)]`.
|
||||||
load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系.
|
load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
|
||||||
例如 `User.profile`.
|
可以是单个关系或关系列表.
|
||||||
|
例如 `User.profile` 或 `[User.group, User.tags]`.
|
||||||
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
|
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
|
||||||
例如 `[User.name.asc(), User.created_at.desc()]`.
|
例如 `[User.name.asc(), User.created_at.desc()]`.
|
||||||
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
|
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
|
||||||
@@ -595,9 +685,15 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
statement = statement.options(*options)
|
statement = statement.options(*options)
|
||||||
|
|
||||||
if load:
|
if load:
|
||||||
|
# 标准化为列表
|
||||||
|
load_list = load if isinstance(load, list) else [load]
|
||||||
|
|
||||||
# 处理多态加载
|
# 处理多态加载
|
||||||
if load_polymorphic is not None:
|
if load_polymorphic is not None:
|
||||||
target_class = load.property.mapper.class_
|
# 多态加载只支持单个关系
|
||||||
|
if len(load_list) > 1:
|
||||||
|
raise ValueError("load_polymorphic 仅支持单个关系")
|
||||||
|
target_class = load_list[0].property.mapper.class_
|
||||||
|
|
||||||
# 检查目标类是否继承自 PolymorphicBaseMixin
|
# 检查目标类是否继承自 PolymorphicBaseMixin
|
||||||
if not issubclass(target_class, PolymorphicBaseMixin):
|
if not issubclass(target_class, PolymorphicBaseMixin):
|
||||||
@@ -609,7 +705,7 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
if load_polymorphic == 'all':
|
if load_polymorphic == 'all':
|
||||||
# 两阶段查询:获取实际关联的多态类型
|
# 两阶段查询:获取实际关联的多态类型
|
||||||
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
|
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
|
||||||
session, condition, load, target_class
|
session, condition, load_list[0], target_class
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
subclasses_to_load = load_polymorphic
|
subclasses_to_load = load_polymorphic
|
||||||
@@ -618,12 +714,14 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
# 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项
|
# 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项
|
||||||
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
|
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
|
||||||
statement = statement.options(
|
statement = statement.options(
|
||||||
selectinload(load).selectin_polymorphic(subclasses_to_load)
|
selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
statement = statement.options(selectinload(load))
|
statement = statement.options(selectinload(load_list[0]))
|
||||||
else:
|
else:
|
||||||
statement = statement.options(selectinload(load))
|
# 为每个关系添加 selectinload
|
||||||
|
for rel in load_list:
|
||||||
|
statement = statement.options(selectinload(rel))
|
||||||
|
|
||||||
if order_by is not None:
|
if order_by is not None:
|
||||||
statement = statement.order_by(*order_by)
|
statement = statement.order_by(*order_by)
|
||||||
@@ -796,7 +894,7 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
*,
|
*,
|
||||||
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
||||||
options: list | None = None,
|
options: list | None = None,
|
||||||
load: RelationshipInfo | None = None,
|
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||||
order_by: list[ClauseElement] | None = None,
|
order_by: list[ClauseElement] | None = None,
|
||||||
filter: BinaryExpression | ClauseElement | None = None,
|
filter: BinaryExpression | ClauseElement | None = None,
|
||||||
table_view: TableViewRequest | None = None,
|
table_view: TableViewRequest | None = None,
|
||||||
@@ -865,7 +963,7 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
return ListResponse(count=total_count, items=items)
|
return ListResponse(count=total_count, items=items)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T:
|
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T:
|
||||||
"""
|
"""
|
||||||
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
|
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||||
|
|
||||||
@@ -875,7 +973,8 @@ class TableBaseMixin(AsyncAttrs):
|
|||||||
Args:
|
Args:
|
||||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||||
id (int): 要查找的记录的主键 ID.
|
id (int): 要查找的记录的主键 ID.
|
||||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
|
||||||
|
可以是单个关系或关系列表.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
T: 找到的模型实例.
|
T: 找到的模型实例.
|
||||||
@@ -903,7 +1002,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
|
|||||||
|
|
||||||
@override
|
@override
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
|
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T:
|
||||||
"""
|
"""
|
||||||
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
|
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||||
|
|
||||||
@@ -913,7 +1012,8 @@ class UUIDTableBaseMixin(TableBaseMixin):
|
|||||||
Args:
|
Args:
|
||||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||||
id (uuid.UUID): 要查找的记录的 UUID 主键.
|
id (uuid.UUID): 要查找的记录的 UUID 主键.
|
||||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
|
||||||
|
可以是单个关系或关系列表.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
T: 找到的模型实例.
|
T: 找到的模型实例.
|
||||||
|
|||||||
@@ -79,9 +79,8 @@ class VersionInfo(SQLModelBase):
|
|||||||
commit: str
|
commit: str
|
||||||
"""提交哈希"""
|
"""提交哈希"""
|
||||||
|
|
||||||
|
class AdminSummaryResponse(ResponseBase):
|
||||||
class AdminSummaryData(SQLModelBase):
|
"""管理员概况响应"""
|
||||||
"""管理员概况数据"""
|
|
||||||
|
|
||||||
metrics_summary: MetricsSummary
|
metrics_summary: MetricsSummary
|
||||||
"""统计摘要"""
|
"""统计摘要"""
|
||||||
@@ -95,13 +94,6 @@ class AdminSummaryData(SQLModelBase):
|
|||||||
version: VersionInfo
|
version: VersionInfo
|
||||||
"""版本信息"""
|
"""版本信息"""
|
||||||
|
|
||||||
|
|
||||||
class AdminSummaryResponse(ResponseBase):
|
|
||||||
"""管理员概况响应"""
|
|
||||||
|
|
||||||
data: AdminSummaryData | None = None
|
|
||||||
"""响应数据"""
|
|
||||||
|
|
||||||
class MCPMethod(StrEnum):
|
class MCPMethod(StrEnum):
|
||||||
"""MCP 方法枚举"""
|
"""MCP 方法枚举"""
|
||||||
|
|
||||||
|
|||||||
@@ -104,6 +104,7 @@ class SettingsType(StrEnum):
|
|||||||
MAIL = "mail"
|
MAIL = "mail"
|
||||||
MAIL_TEMPLATE = "mail_template"
|
MAIL_TEMPLATE = "mail_template"
|
||||||
MOBILE = "mobile"
|
MOBILE = "mobile"
|
||||||
|
OAUTH = "oauth"
|
||||||
PATH = "path"
|
PATH = "path"
|
||||||
PREVIEW = "preview"
|
PREVIEW = "preview"
|
||||||
PWA = "pwa"
|
PWA = "pwa"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
|
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||||
|
|
||||||
from .base import SQLModelBase
|
from .base import SQLModelBase
|
||||||
from .mixin import TableBaseMixin
|
from .mixin import TableBaseMixin
|
||||||
|
|||||||
@@ -2,14 +2,12 @@ from datetime import datetime, timedelta
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlalchemy import and_
|
|
||||||
|
|
||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
from models import (
|
from models import (
|
||||||
User, ResponseBase,
|
User, ResponseBase,
|
||||||
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
||||||
AdminSummaryData,
|
|
||||||
)
|
)
|
||||||
from models.base import SQLModelBase
|
from models.base import SQLModelBase
|
||||||
from models.setting import (
|
from models.setting import (
|
||||||
@@ -75,8 +73,8 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
|||||||
Returns:
|
Returns:
|
||||||
AdminSummaryResponse: 包含站点概况信息的响应模型。
|
AdminSummaryResponse: 包含站点概况信息的响应模型。
|
||||||
"""
|
"""
|
||||||
# 统计最近 12 天的数据
|
# 统计最近 14 天的数据
|
||||||
days_count = 12
|
days_count = 14
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
@@ -135,7 +133,7 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
|||||||
site_urls: list[str] = []
|
site_urls: list[str] = []
|
||||||
site_url_setting = await Setting.get(
|
site_url_setting = await Setting.get(
|
||||||
session,
|
session,
|
||||||
and_(Setting.type == SettingsType.BASIC, Setting.name == "siteURL"),
|
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
|
||||||
)
|
)
|
||||||
if site_url_setting and site_url_setting.value:
|
if site_url_setting and site_url_setting.value:
|
||||||
site_urls.append(site_url_setting.value)
|
site_urls.append(site_url_setting.value)
|
||||||
@@ -156,15 +154,13 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
|||||||
commit="dev",
|
commit="dev",
|
||||||
)
|
)
|
||||||
|
|
||||||
data = AdminSummaryData(
|
return AdminSummaryResponse(
|
||||||
metrics_summary=metrics_summary,
|
metrics_summary=metrics_summary,
|
||||||
site_urls=site_urls,
|
site_urls=site_urls,
|
||||||
license=license_info,
|
license=license_info,
|
||||||
version=version_info,
|
version=version_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
return AdminSummaryResponse(data=data)
|
|
||||||
|
|
||||||
@admin_router.get(
|
@admin_router.get(
|
||||||
path='/news',
|
path='/news',
|
||||||
summary='获取社区新闻',
|
summary='获取社区新闻',
|
||||||
@@ -203,7 +199,7 @@ async def router_admin_update_settings(
|
|||||||
for item in request.settings:
|
for item in request.settings:
|
||||||
existing = await Setting.get(
|
existing = await Setting.get(
|
||||||
session,
|
session,
|
||||||
and_(Setting.type == item.type, Setting.name == item.name)
|
(Setting.type == item.type) & (Setting.name == item.name)
|
||||||
)
|
)
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
@@ -245,7 +241,12 @@ async def router_admin_get_settings(
|
|||||||
if name:
|
if name:
|
||||||
conditions.append(Setting.name == name)
|
conditions.append(Setting.name == name)
|
||||||
|
|
||||||
condition = and_(*conditions) if conditions else None
|
if conditions:
|
||||||
|
condition = conditions[0]
|
||||||
|
for c in conditions[1:]:
|
||||||
|
condition = condition & c
|
||||||
|
else:
|
||||||
|
condition = None
|
||||||
|
|
||||||
settings = await Setting.get(session, condition, fetch_mode="all")
|
settings = await Setting.get(session, condition, fetch_mode="all")
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ from uuid import UUID
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlalchemy import and_
|
|
||||||
|
|
||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
@@ -51,7 +50,12 @@ async def router_admin_get_file_list(
|
|||||||
if keyword:
|
if keyword:
|
||||||
conditions.append(Object.name.ilike(f"%{keyword}%"))
|
conditions.append(Object.name.ilike(f"%{keyword}%"))
|
||||||
|
|
||||||
condition = and_(*conditions) if len(conditions) > 1 else conditions[0]
|
if len(conditions) > 1:
|
||||||
|
condition = conditions[0]
|
||||||
|
for c in conditions[1:]:
|
||||||
|
condition = condition & c
|
||||||
|
else:
|
||||||
|
condition = conditions[0]
|
||||||
result = await Object.get_with_count(session, condition, table_view=table_view, load=Object.owner)
|
result = await Object.get_with_count(session, condition, table_view=table_view, load=Object.owner)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
@@ -197,13 +201,15 @@ async def router_admin_delete_file(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.warning(f"删除物理文件失败: {e}")
|
l.warning(f"删除物理文件失败: {e}")
|
||||||
|
|
||||||
# 更新用户存储量
|
# 更新用户存储量(使用 SQL UPDATE 直接更新,无需加载实例)
|
||||||
owner = await User.get(session, User.id == owner_id)
|
from sqlmodel import update as sql_update
|
||||||
if owner:
|
stmt = sql_update(User).where(User.id == owner_id).values(
|
||||||
owner.storage = max(0, owner.storage - file_size)
|
storage=max(0, User.storage - file_size)
|
||||||
await owner.save(session)
|
)
|
||||||
|
await session.exec(stmt)
|
||||||
|
|
||||||
await Object.delete(session, file_obj)
|
# 使用条件删除
|
||||||
|
await Object.delete(session, condition=Object.id == file_obj.id)
|
||||||
|
|
||||||
l.info(f"管理员删除了文件: {file_name}")
|
l.info(f"管理员删除了文件: {file_name}")
|
||||||
return ResponseBase(data={"deleted": True})
|
return ResponseBase(data={"deleted": True})
|
||||||
@@ -63,12 +63,13 @@ async def router_admin_get_group(
|
|||||||
:param group_id: 用户组UUID
|
:param group_id: 用户组UUID
|
||||||
:return: 用户组详情
|
:return: 用户组详情
|
||||||
"""
|
"""
|
||||||
group = await Group.get(session, Group.id == group_id, load=Group.options)
|
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
|
||||||
|
|
||||||
if not group:
|
if not group:
|
||||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||||
|
|
||||||
policies = await group.awaitable_attrs.policies
|
# 直接访问已加载的关系,无需额外查询
|
||||||
|
policies = group.policies
|
||||||
user_count = await User.count(session, User.group_id == group_id)
|
user_count = await User.count(session, User.group_id == group_id)
|
||||||
response = GroupDetailResponse.from_group(group, user_count, policies)
|
response = GroupDetailResponse.from_group(group, user_count, policies)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlalchemy import and_
|
|
||||||
|
|
||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
@@ -43,7 +42,12 @@ async def router_admin_get_task_list(
|
|||||||
if status:
|
if status:
|
||||||
conditions.append(Task.status == status)
|
conditions.append(Task.status == status)
|
||||||
|
|
||||||
condition = and_(*conditions) if conditions else None
|
if conditions:
|
||||||
|
condition = conditions[0]
|
||||||
|
for c in conditions[1:]:
|
||||||
|
condition = condition & c
|
||||||
|
else:
|
||||||
|
condition = None
|
||||||
result = await Task.get_with_count(session, condition, table_view=table_view, load=Task.user)
|
result = await Task.get_with_count(session, condition, table_view=table_view, load=Task.user)
|
||||||
|
|
||||||
items: list[TaskSummary] = []
|
items: list[TaskSummary] = []
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from loguru import logger as l
|
from loguru import logger as l
|
||||||
from sqlalchemy import func, and_
|
from sqlalchemy import func
|
||||||
|
|
||||||
from middleware.auth import admin_required
|
from middleware.auth import admin_required
|
||||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||||
@@ -198,7 +198,7 @@ async def router_admin_calibrate_storage(
|
|||||||
from sqlmodel import select
|
from sqlmodel import select
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(func.sum(Object.size), func.count(Object.id)).where(
|
select(func.sum(Object.size), func.count(Object.id)).where(
|
||||||
and_(Object.owner_id == user_id, Object.type == ObjectType.FILE)
|
(Object.owner_id == user_id) & (Object.type == ObjectType.FILE)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
row = result.one()
|
row = result.one()
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ async def upload_chunk(
|
|||||||
policy_id=upload_session.policy_id,
|
policy_id=upload_session.policy_id,
|
||||||
reference_count=1,
|
reference_count=1,
|
||||||
)
|
)
|
||||||
physical_file = await physical_file.save(session)
|
physical_file = await physical_file.save(session, commit=False)
|
||||||
|
|
||||||
# 创建 Object 记录
|
# 创建 Object 记录
|
||||||
file_object = Object(
|
file_object = Object(
|
||||||
@@ -246,11 +246,18 @@ async def upload_chunk(
|
|||||||
owner_id=user_id,
|
owner_id=user_id,
|
||||||
policy_id=upload_session.policy_id,
|
policy_id=upload_session.policy_id,
|
||||||
)
|
)
|
||||||
file_object = await file_object.save(session)
|
file_object = await file_object.save(session, commit=False)
|
||||||
file_object_id = file_object.id
|
file_object_id = file_object.id
|
||||||
|
|
||||||
# 删除上传会话
|
# 删除上传会话(使用条件删除)
|
||||||
await UploadSession.delete(session, upload_session)
|
await UploadSession.delete(
|
||||||
|
session,
|
||||||
|
condition=UploadSession.id == upload_session.id,
|
||||||
|
commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# 统一提交所有更改
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")
|
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from sqlalchemy import and_
|
|
||||||
|
|
||||||
from middleware.dependencies import SessionDep
|
from middleware.dependencies import SessionDep
|
||||||
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||||
@@ -55,5 +54,5 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
|||||||
dict: The site configuration.
|
dict: The site configuration.
|
||||||
"""
|
"""
|
||||||
return SiteConfigResponse(
|
return SiteConfigResponse(
|
||||||
title=await Setting.get(session, and_(Setting.type == SettingsType.BASIC, Setting.name == "siteName")),
|
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
|
||||||
)
|
)
|
||||||
@@ -3,7 +3,6 @@ from uuid import UUID
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from sqlalchemy import and_
|
|
||||||
from webauthn import generate_registration_options
|
from webauthn import generate_registration_options
|
||||||
from webauthn.helpers import options_to_json_dict
|
from webauthn.helpers import options_to_json_dict
|
||||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||||
@@ -115,7 +114,7 @@ async def router_user_register(
|
|||||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||||
default_group_setting: models.Setting | None = await models.Setting.get(
|
default_group_setting: models.Setting | None = await models.Setting.get(
|
||||||
session,
|
session,
|
||||||
and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group")
|
(models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group")
|
||||||
)
|
)
|
||||||
if default_group_setting is None or not default_group_setting.value:
|
if default_group_setting is None or not default_group_setting.value:
|
||||||
logger.error("默认用户组不存在")
|
logger.error("默认用户组不存在")
|
||||||
@@ -352,18 +351,18 @@ async def router_user_authn_start(
|
|||||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
||||||
authn_setting = await models.Setting.get(
|
authn_setting = await models.Setting.get(
|
||||||
session,
|
session,
|
||||||
and_(models.Setting.type == "authn", models.Setting.name == "authn_enabled")
|
(models.Setting.type == "authn") & (models.Setting.name == "authn_enabled")
|
||||||
)
|
)
|
||||||
if not authn_setting or authn_setting.value != "1":
|
if not authn_setting or authn_setting.value != "1":
|
||||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
||||||
|
|
||||||
site_url_setting = await models.Setting.get(
|
site_url_setting = await models.Setting.get(
|
||||||
session,
|
session,
|
||||||
and_(models.Setting.type == "basic", models.Setting.name == "siteURL")
|
(models.Setting.type == "basic") & (models.Setting.name == "siteURL")
|
||||||
)
|
)
|
||||||
site_title_setting = await models.Setting.get(
|
site_title_setting = await models.Setting.get(
|
||||||
session,
|
session,
|
||||||
and_(models.Setting.type == "basic", models.Setting.name == "siteTitle")
|
(models.Setting.type == "basic") & (models.Setting.name == "siteTitle")
|
||||||
)
|
)
|
||||||
|
|
||||||
options = generate_registration_options(
|
options = generate_registration_options(
|
||||||
|
|||||||
@@ -1,5 +1,39 @@
|
|||||||
|
import abc
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from .gcaptcha import GCaptcha
|
||||||
|
from .turnstile import TurnstileCaptcha
|
||||||
|
|
||||||
|
|
||||||
class CaptchaRequestBase(BaseModel):
|
class CaptchaRequestBase(BaseModel):
|
||||||
|
"""验证码验证请求"""
|
||||||
token: str
|
token: str
|
||||||
secret: str
|
"""验证 token"""
|
||||||
|
secret: str
|
||||||
|
"""验证密钥"""
|
||||||
|
|
||||||
|
|
||||||
|
class CaptchaBase(abc.ABC):
|
||||||
|
"""验证码验证器抽象基类"""
|
||||||
|
|
||||||
|
verify_url: str
|
||||||
|
"""验证 API 地址(子类必须定义)"""
|
||||||
|
|
||||||
|
async def verify_captcha(self, request: CaptchaRequestBase) -> bool:
|
||||||
|
"""
|
||||||
|
验证 token 是否有效。
|
||||||
|
|
||||||
|
:return: 如果验证成功返回 True,否则返回 False
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
payload = request.model_dump()
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(self.verify_url, data=payload) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return False
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
return result.get('success', False)
|
||||||
@@ -1,21 +1,7 @@
|
|||||||
import aiohttp
|
from . import CaptchaBase
|
||||||
|
|
||||||
from . import CaptchaRequestBase
|
|
||||||
|
|
||||||
async def verify_captcha(request: CaptchaRequestBase) -> bool:
|
class GCaptcha(CaptchaBase):
|
||||||
"""
|
"""Google reCAPTCHA v2/v3 验证器"""
|
||||||
验证 Google reCAPTCHA v2/v3 的 token 是否有效。
|
|
||||||
|
verify_url = "https://www.google.com/recaptcha/api/siteverify"
|
||||||
:return: 如果验证成功返回 True,否则返回 False
|
|
||||||
:rtype: bool
|
|
||||||
"""
|
|
||||||
verify_url = "https://www.google.com/recaptcha/api/siteverify"
|
|
||||||
payload = request.model_dump()
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(verify_url, data=payload) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
return False
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
return result.get('success', False)
|
|
||||||
@@ -1,21 +1,7 @@
|
|||||||
import aiohttp
|
from . import CaptchaBase
|
||||||
|
|
||||||
from . import CaptchaRequestBase
|
|
||||||
|
|
||||||
async def verify_captcha(request: CaptchaRequestBase) -> bool:
|
class TurnstileCaptcha(CaptchaBase):
|
||||||
"""
|
"""Cloudflare Turnstile 验证器"""
|
||||||
验证 Turnstile 的 token 是否有效。
|
|
||||||
|
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||||
:return: 如果验证成功返回 True,否则返回 False
|
|
||||||
:rtype: bool
|
|
||||||
"""
|
|
||||||
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
|
||||||
payload = request.model_dump()
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(verify_url, data=payload) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
return False
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
return result.get('success', False)
|
|
||||||
223
service/oauth/__init__.py
Normal file
223
service/oauth/__init__.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""
|
||||||
|
OAuth2.0 认证模块
|
||||||
|
|
||||||
|
提供统一的 OAuth2.0 客户端基类,支持多种第三方登录平台。
|
||||||
|
"""
|
||||||
|
import abc
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 共享数据模型 ====================
|
||||||
|
|
||||||
|
class AccessTokenBase(BaseModel):
|
||||||
|
"""访问令牌基类"""
|
||||||
|
access_token: str
|
||||||
|
"""访问令牌"""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthUserData(BaseModel):
|
||||||
|
"""OAuth 用户数据通用 DTO"""
|
||||||
|
openid: str
|
||||||
|
"""用户唯一标识(GitHub 为 id,QQ 为 openid)"""
|
||||||
|
nickname: str | None
|
||||||
|
"""用户昵称"""
|
||||||
|
avatar_url: str | None
|
||||||
|
"""头像 URL"""
|
||||||
|
email: str | None
|
||||||
|
"""邮箱"""
|
||||||
|
bio: str | None
|
||||||
|
"""个人简介"""
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthUserInfoResponse(BaseModel):
|
||||||
|
"""OAuth 用户信息响应"""
|
||||||
|
code: str
|
||||||
|
"""状态码"""
|
||||||
|
openid: str
|
||||||
|
"""用户唯一标识"""
|
||||||
|
user_data: OAuthUserData
|
||||||
|
"""用户数据"""
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== OAuth2.0 抽象基类 ====================
|
||||||
|
|
||||||
|
class OAuthBase(abc.ABC):
|
||||||
|
"""
|
||||||
|
OAuth2.0 客户端抽象基类
|
||||||
|
|
||||||
|
子类需要定义以下类属性:
|
||||||
|
- access_token_url: 获取 Access Token 的 API 地址
|
||||||
|
- user_info_url: 获取用户信息的 API 地址
|
||||||
|
- http_method: 获取 token 的 HTTP 方法(POST 或 GET)
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 子类必须定义的类属性
|
||||||
|
access_token_url: str
|
||||||
|
"""获取 Access Token 的 API 地址"""
|
||||||
|
|
||||||
|
user_info_url: str
|
||||||
|
"""获取用户信息的 API 地址"""
|
||||||
|
|
||||||
|
http_method: str = "POST"
|
||||||
|
"""获取 token 的 HTTP 方法:POST 或 GET"""
|
||||||
|
|
||||||
|
# 实例属性(构造函数传入)
|
||||||
|
client_id: str
|
||||||
|
client_secret: str
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||||
|
"""
|
||||||
|
初始化 OAuth 客户端
|
||||||
|
|
||||||
|
Args:
|
||||||
|
client_id: 应用 client_id
|
||||||
|
client_secret: 应用 client_secret
|
||||||
|
"""
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
|
||||||
|
async def get_access_token(self, code: str, **kwargs) -> AccessTokenBase:
|
||||||
|
"""
|
||||||
|
通过 Authorization Code 获取 Access Token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: 授权码
|
||||||
|
**kwargs: 额外参数(如 QQ 需要 redirect_uri)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AccessTokenBase: 访问令牌
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
'client_id': self.client_id,
|
||||||
|
'client_secret': self.client_secret,
|
||||||
|
'code': code,
|
||||||
|
}
|
||||||
|
params.update(kwargs)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
if self.http_method == "POST":
|
||||||
|
async with session.post(
|
||||||
|
url=self.access_token_url,
|
||||||
|
params=params,
|
||||||
|
headers={'accept': 'application/json'},
|
||||||
|
) as access_resp:
|
||||||
|
access_data = await access_resp.json()
|
||||||
|
return self._parse_token_response(access_data)
|
||||||
|
else:
|
||||||
|
async with session.get(
|
||||||
|
url=self.access_token_url,
|
||||||
|
params=params,
|
||||||
|
) as access_resp:
|
||||||
|
access_data = await access_resp.json()
|
||||||
|
return self._parse_token_response(access_data)
|
||||||
|
|
||||||
|
async def get_user_info(
|
||||||
|
self,
|
||||||
|
access_token: str | AccessTokenBase,
|
||||||
|
**kwargs
|
||||||
|
) -> OAuthUserInfoResponse:
|
||||||
|
"""
|
||||||
|
获取用户信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: 访问令牌
|
||||||
|
**kwargs: 额外参数(如 QQ 需要 app_id, openid)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuthUserInfoResponse: 用户信息
|
||||||
|
"""
|
||||||
|
if isinstance(access_token, AccessTokenBase):
|
||||||
|
access_token = access_token.access_token
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
url=self.user_info_url,
|
||||||
|
params=self._build_user_info_params(access_token, **kwargs),
|
||||||
|
headers=self._build_user_info_headers(access_token),
|
||||||
|
) as resp:
|
||||||
|
user_data = await resp.json()
|
||||||
|
return self._parse_user_response(user_data)
|
||||||
|
|
||||||
|
# ==================== 钩子方法(子类可覆盖) ====================
|
||||||
|
|
||||||
|
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
|
||||||
|
"""
|
||||||
|
构建获取用户信息的请求参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: 访问令牌
|
||||||
|
**kwargs: 额外参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 请求参数
|
||||||
|
"""
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _build_user_info_headers(self, access_token: str) -> dict:
|
||||||
|
"""
|
||||||
|
构建获取用户信息的请求头
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: 访问令牌
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 请求头
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
'accept': 'application/json',
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_token_response(self, data: dict) -> AccessTokenBase:
|
||||||
|
"""
|
||||||
|
解析 token 响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: API 返回的数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AccessTokenBase: 访问令牌
|
||||||
|
"""
|
||||||
|
return AccessTokenBase(access_token=data.get('access_token'))
|
||||||
|
|
||||||
|
def _parse_user_response(self, data: dict) -> OAuthUserInfoResponse:
|
||||||
|
"""
|
||||||
|
解析用户信息响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: API 返回的数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OAuthUserInfoResponse: 用户信息
|
||||||
|
"""
|
||||||
|
return OAuthUserInfoResponse(
|
||||||
|
code='0',
|
||||||
|
openid='',
|
||||||
|
user_data=OAuthUserData(openid=''),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== 导出 ====================
|
||||||
|
|
||||||
|
from .github import GithubOAuth, GithubAccessToken, GithubUserData
|
||||||
|
from .qq import QQOAuth, QQAccessToken, QQOpenIDResponse, QQUserData
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 共享模型
|
||||||
|
'AccessTokenBase',
|
||||||
|
'OAuthUserData',
|
||||||
|
'OAuthUserInfoResponse',
|
||||||
|
'OAuthBase',
|
||||||
|
|
||||||
|
# GitHub
|
||||||
|
'GithubOAuth',
|
||||||
|
'GithubAccessToken',
|
||||||
|
'GithubUserData',
|
||||||
|
|
||||||
|
# QQ
|
||||||
|
'QQOAuth',
|
||||||
|
'QQAccessToken',
|
||||||
|
'QQOpenIDResponse',
|
||||||
|
'QQUserData',
|
||||||
|
]
|
||||||
@@ -1,77 +1,127 @@
|
|||||||
from pydantic import BaseModel
|
"""GitHub OAuth2.0 认证实现"""
|
||||||
import aiohttp
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
class GithubAccessToken(BaseModel):
|
from pydantic import BaseModel
|
||||||
access_token: str
|
from . import AccessTokenBase, OAuthBase, OAuthUserInfoResponse
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import OAuthUserData
|
||||||
|
|
||||||
|
|
||||||
|
class GithubAccessToken(AccessTokenBase):
|
||||||
|
"""GitHub 访问令牌响应"""
|
||||||
token_type: str
|
token_type: str
|
||||||
|
"""令牌类型"""
|
||||||
scope: str
|
scope: str
|
||||||
|
"""授权范围"""
|
||||||
|
|
||||||
|
|
||||||
class GithubUserData(BaseModel):
|
class GithubUserData(BaseModel):
|
||||||
|
"""GitHub 用户数据"""
|
||||||
login: str
|
login: str
|
||||||
|
"""用户名"""
|
||||||
id: int
|
id: int
|
||||||
|
"""用户 ID"""
|
||||||
node_id: str
|
node_id: str
|
||||||
|
"""节点 ID"""
|
||||||
avatar_url: str
|
avatar_url: str
|
||||||
|
"""头像 URL"""
|
||||||
gravatar_id: str | None
|
gravatar_id: str | None
|
||||||
|
"""Gravatar ID"""
|
||||||
url: str
|
url: str
|
||||||
|
"""API URL"""
|
||||||
html_url: str
|
html_url: str
|
||||||
|
"""主页 URL"""
|
||||||
followers_url: str
|
followers_url: str
|
||||||
|
"""粉丝列表 URL"""
|
||||||
following_url: str
|
following_url: str
|
||||||
|
"""关注列表 URL"""
|
||||||
gists_url: str
|
gists_url: str
|
||||||
|
"""Gists 列表 URL"""
|
||||||
starred_url: str
|
starred_url: str
|
||||||
|
"""星标列表 URL"""
|
||||||
subscriptions_url: str
|
subscriptions_url: str
|
||||||
|
"""订阅列表 URL"""
|
||||||
organizations_url: str
|
organizations_url: str
|
||||||
|
"""组织列表 URL"""
|
||||||
repos_url: str
|
repos_url: str
|
||||||
|
"""仓库列表 URL"""
|
||||||
events_url: str
|
events_url: str
|
||||||
|
"""事件列表 URL"""
|
||||||
received_events_url: str
|
received_events_url: str
|
||||||
|
"""接收的事件列表 URL"""
|
||||||
type: str
|
type: str
|
||||||
|
"""用户类型"""
|
||||||
site_admin: bool
|
site_admin: bool
|
||||||
|
"""是否为站点管理员"""
|
||||||
name: str | None
|
name: str | None
|
||||||
|
"""显示名称"""
|
||||||
company: str | None
|
company: str | None
|
||||||
|
"""公司"""
|
||||||
blog: str | None
|
blog: str | None
|
||||||
|
"""博客"""
|
||||||
location: str | None
|
location: str | None
|
||||||
|
"""位置"""
|
||||||
email: str | None
|
email: str | None
|
||||||
|
"""邮箱"""
|
||||||
hireable: bool | None
|
hireable: bool | None
|
||||||
|
"""是否可雇佣"""
|
||||||
bio: str | None
|
bio: str | None
|
||||||
|
"""个人简介"""
|
||||||
twitter_username: str | None
|
twitter_username: str | None
|
||||||
|
"""Twitter 用户名"""
|
||||||
public_repos: int
|
public_repos: int
|
||||||
|
"""公开仓库数"""
|
||||||
public_gists: int
|
public_gists: int
|
||||||
|
"""公开 Gists 数"""
|
||||||
followers: int
|
followers: int
|
||||||
|
"""粉丝数"""
|
||||||
following: int
|
following: int
|
||||||
created_at: str # ISO 8601 format date-time string
|
"""关注数"""
|
||||||
updated_at: str # ISO 8601 format date-time string
|
created_at: str
|
||||||
|
"""创建时间(ISO 8601 格式)"""
|
||||||
|
updated_at: str
|
||||||
|
"""更新时间(ISO 8601 格式)"""
|
||||||
|
|
||||||
|
|
||||||
class GithubUserInfoResponse(BaseModel):
|
class GithubUserInfoResponse(BaseModel):
|
||||||
|
"""GitHub 用户信息响应"""
|
||||||
code: str
|
code: str
|
||||||
|
"""状态码"""
|
||||||
user_data: GithubUserData
|
user_data: GithubUserData
|
||||||
|
"""用户数据"""
|
||||||
|
|
||||||
async def get_access_token(code: str) -> GithubAccessToken:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(
|
|
||||||
url='https://github.com/login/oauth/access_token',
|
|
||||||
params={
|
|
||||||
'client_id': '',
|
|
||||||
'client_secret': '',
|
|
||||||
'code': code
|
|
||||||
},
|
|
||||||
headers={'accept': 'application/json'},
|
|
||||||
) as access_resp:
|
|
||||||
access_data = await access_resp.json()
|
|
||||||
return GithubAccessToken(
|
|
||||||
access_token=access_data.get('access_token'),
|
|
||||||
token_type=access_data.get('token_type'),
|
|
||||||
scope=access_data.get('scope')
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_user_info(access_token: str | GithubAccessToken) -> GithubUserInfoResponse:
|
class GithubOAuth(OAuthBase):
|
||||||
if isinstance(access_token, GithubAccessToken):
|
"""GitHub OAuth2.0 客户端"""
|
||||||
access_token = access_token.access_token
|
|
||||||
|
access_token_url = "https://github.com/login/oauth/access_token"
|
||||||
async with aiohttp.ClientSession() as session:
|
"""获取 Access Token 的 API 地址"""
|
||||||
async with session.get(
|
|
||||||
url='https://api.github.com/user',
|
user_info_url = "https://api.github.com/user"
|
||||||
headers={
|
"""获取用户信息的 API 地址"""
|
||||||
'accept': 'application/json',
|
|
||||||
'Authorization': f'token {access_token}'},
|
http_method = "POST"
|
||||||
) as resp:
|
"""获取 token 的 HTTP 方法"""
|
||||||
user_data = await resp.json()
|
|
||||||
return GithubUserInfoResponse(**user_data)
|
def _parse_token_response(self, data: dict) -> GithubAccessToken:
|
||||||
|
"""解析 GitHub token 响应"""
|
||||||
|
return GithubAccessToken(
|
||||||
|
access_token=data.get('access_token'),
|
||||||
|
token_type=data.get('token_type'),
|
||||||
|
scope=data.get('scope'),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_user_info_headers(self, access_token: str) -> dict:
|
||||||
|
"""构建 GitHub 用户信息请求头"""
|
||||||
|
return {
|
||||||
|
'accept': 'application/json',
|
||||||
|
'Authorization': f'token {access_token}',
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_user_response(self, data: dict) -> GithubUserInfoResponse:
|
||||||
|
"""解析 GitHub 用户信息响应"""
|
||||||
|
return GithubUserInfoResponse(
|
||||||
|
code='0' if data.get('login') else '1',
|
||||||
|
user_data=GithubUserData(**data),
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,7 +1,158 @@
|
|||||||
from pydantic import BaseModel
|
"""QQ OAuth2.0 认证实现"""
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
async def get_access_token(
|
from pydantic import BaseModel
|
||||||
|
from . import AccessTokenBase, OAuthBase
|
||||||
|
|
||||||
|
|
||||||
|
class QQAccessToken(AccessTokenBase):
|
||||||
|
"""QQ 访问令牌响应"""
|
||||||
|
expires_in: int
|
||||||
|
"""access token 的有效期,单位为秒"""
|
||||||
|
refresh_token: str
|
||||||
|
"""用于刷新 access token 的令牌"""
|
||||||
|
|
||||||
|
|
||||||
|
class QQOpenIDResponse(BaseModel):
|
||||||
|
"""QQ OpenID 响应"""
|
||||||
|
client_id: str
|
||||||
|
"""应用的 appid"""
|
||||||
|
openid: str
|
||||||
|
"""用户的唯一标识"""
|
||||||
|
|
||||||
|
|
||||||
|
class QQUserData(BaseModel):
|
||||||
|
"""QQ 用户数据"""
|
||||||
|
ret: int
|
||||||
|
"""返回码,0 表示成功"""
|
||||||
|
msg: str
|
||||||
|
"""返回信息"""
|
||||||
|
nickname: str | None
|
||||||
|
"""用户昵称"""
|
||||||
|
gender: str | None
|
||||||
|
"""性别"""
|
||||||
|
figureurl: str | None
|
||||||
|
"""头像 URL"""
|
||||||
|
figureurl_1: str | None
|
||||||
|
"""头像 URL(大图)"""
|
||||||
|
figureurl_2: str | None
|
||||||
|
"""头像 URL(更大图)"""
|
||||||
|
figureurl_qq_1: str | None
|
||||||
|
"""QQ 头像 URL(大图)"""
|
||||||
|
figureurl_qq_2: str | None
|
||||||
|
"""QQ 头像 URL(更大图)"""
|
||||||
|
is_yellow_vip: str | None
|
||||||
|
"""是否黄钻用户"""
|
||||||
|
vip: str | None
|
||||||
|
"""是否 VIP 用户"""
|
||||||
|
yellow_vip_level: str | None
|
||||||
|
"""黄钻等级"""
|
||||||
|
level: str | None
|
||||||
|
"""等级"""
|
||||||
|
is_yellow_year_vip: str | None
|
||||||
|
"""是否年费黄钻"""
|
||||||
|
|
||||||
|
|
||||||
|
class QQUserInfoResponse(BaseModel):
|
||||||
|
"""QQ 用户信息响应"""
|
||||||
code: str
|
code: str
|
||||||
):
|
"""状态码"""
|
||||||
...
|
openid: str
|
||||||
|
"""用户 OpenID"""
|
||||||
|
user_data: QQUserData
|
||||||
|
"""用户数据"""
|
||||||
|
|
||||||
|
|
||||||
|
class QQOAuth(OAuthBase):
|
||||||
|
"""QQ OAuth2.0 客户端"""
|
||||||
|
|
||||||
|
access_token_url = "https://graph.qq.com/oauth2.0/token"
|
||||||
|
"""获取 Access Token 的 API 地址"""
|
||||||
|
|
||||||
|
user_info_url = "https://graph.qq.com/user/get_user_info"
|
||||||
|
"""获取用户信息的 API 地址"""
|
||||||
|
|
||||||
|
openid_url = "https://graph.qq.com/oauth2.0/me"
|
||||||
|
"""获取 OpenID 的 API 地址"""
|
||||||
|
|
||||||
|
http_method = "GET"
|
||||||
|
"""获取 token 的 HTTP 方法"""
|
||||||
|
|
||||||
|
async def get_access_token(self, code: str, redirect_uri: str) -> QQAccessToken:
|
||||||
|
"""
|
||||||
|
通过 Authorization Code 获取 Access Token
|
||||||
|
|
||||||
|
Args:
|
||||||
|
code: 授权码
|
||||||
|
redirect_uri: 与授权时传入的 redirect_uri 保持一致,需要 URLEncode
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QQAccessToken: 访问令牌
|
||||||
|
|
||||||
|
文档:
|
||||||
|
https://wiki.connect.qq.com/%E4%BD%BF%E7%94%A8authorization_code%E8%8E%B7%E5%8F%96access_token
|
||||||
|
"""
|
||||||
|
params = {
|
||||||
|
'grant_type': 'authorization_code',
|
||||||
|
'client_id': self.client_id,
|
||||||
|
'client_secret': self.client_secret,
|
||||||
|
'code': code,
|
||||||
|
'redirect_uri': redirect_uri,
|
||||||
|
'fmt': 'json',
|
||||||
|
'need_openid': 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url=self.access_token_url, params=params) as access_resp:
|
||||||
|
access_data = await access_resp.json()
|
||||||
|
return QQAccessToken(
|
||||||
|
access_token=access_data.get('access_token'),
|
||||||
|
expires_in=access_data.get('expires_in'),
|
||||||
|
refresh_token=access_data.get('refresh_token'),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_openid(self, access_token: str) -> QQOpenIDResponse:
|
||||||
|
"""
|
||||||
|
获取用户 OpenID
|
||||||
|
|
||||||
|
注意:如果在 get_access_token 时传入了 need_openid=1,响应中已包含 openid,
|
||||||
|
无需额外调用此接口。此函数用于单独获取 openid 的场景。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
access_token: 访问令牌
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
QQOpenIDResponse: 包含 client_id 和 openid
|
||||||
|
|
||||||
|
文档:
|
||||||
|
https://wiki.connect.qq.com/%E8%8E%B7%E5%8F%96%E7%94%A8%E6%88%B7openid%E7%9A%84oauth2.0%E6%8E%A5%E5%8F%A3
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(
|
||||||
|
url=self.openid_url,
|
||||||
|
params={
|
||||||
|
'access_token': access_token,
|
||||||
|
'fmt': 'json',
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
data = await resp.json()
|
||||||
|
return QQOpenIDResponse(
|
||||||
|
client_id=data.get('client_id'),
|
||||||
|
openid=data.get('openid'),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
|
||||||
|
"""构建 QQ 用户信息请求参数"""
|
||||||
|
return {
|
||||||
|
'access_token': access_token,
|
||||||
|
'oauth_consumer_key': kwargs.get('app_id', self.client_id),
|
||||||
|
'openid': kwargs.get('openid', ''),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_user_response(self, data: dict) -> QQUserInfoResponse:
|
||||||
|
"""解析 QQ 用户信息响应"""
|
||||||
|
return QQUserInfoResponse(
|
||||||
|
code='0' if data.get('ret') == 0 else str(data.get('ret')),
|
||||||
|
openid=data.get('openid', ''),
|
||||||
|
user_data=QQUserData(**data),
|
||||||
|
)
|
||||||
|
|||||||
@@ -25,12 +25,12 @@ async def login(
|
|||||||
# TODO: 验证码校验
|
# TODO: 验证码校验
|
||||||
# captcha_setting = await Setting.get(
|
# captcha_setting = await Setting.get(
|
||||||
# session,
|
# session,
|
||||||
# and_(Setting.type == "auth", Setting.name == "login_captcha")
|
# (Setting.type == "auth") & (Setting.name == "login_captcha")
|
||||||
# )
|
# )
|
||||||
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
||||||
|
|
||||||
# 获取用户信息
|
# 获取用户信息
|
||||||
current_user = await User.get(session, User.username == login_request.username, fetch_mode="first")
|
current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore
|
||||||
|
|
||||||
# 验证用户是否存在
|
# 验证用户是否存在
|
||||||
if not current_user:
|
if not current_user:
|
||||||
|
|||||||
23
tests/fixtures/groups.py
vendored
23
tests/fixtures/groups.py
vendored
@@ -42,10 +42,9 @@ class GroupFactory:
|
|||||||
speed_limit=kwargs.get("speed_limit", 0),
|
speed_limit=kwargs.get("speed_limit", 0),
|
||||||
)
|
)
|
||||||
|
|
||||||
group = await group.save(session)
|
|
||||||
|
|
||||||
# 如果提供了选项参数,创建 GroupOptions
|
# 如果提供了选项参数,创建 GroupOptions
|
||||||
if kwargs.get("create_options", False):
|
if kwargs.get("create_options", False):
|
||||||
|
group = await group.save(session, commit=False)
|
||||||
options = GroupOptions(
|
options = GroupOptions(
|
||||||
group_id=group.id,
|
group_id=group.id,
|
||||||
share_download=kwargs.get("share_download", True),
|
share_download=kwargs.get("share_download", True),
|
||||||
@@ -55,7 +54,10 @@ class GroupFactory:
|
|||||||
select_node=kwargs.get("select_node", False),
|
select_node=kwargs.get("select_node", False),
|
||||||
advance_delete=kwargs.get("advance_delete", False),
|
advance_delete=kwargs.get("advance_delete", False),
|
||||||
)
|
)
|
||||||
await options.save(session)
|
await options.save(session, commit=False)
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
group = await group.save(session)
|
||||||
|
|
||||||
return group
|
return group
|
||||||
|
|
||||||
@@ -88,7 +90,7 @@ class GroupFactory:
|
|||||||
speed_limit=0,
|
speed_limit=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
admin_group = await admin_group.save(session)
|
admin_group = await admin_group.save(session, commit=False)
|
||||||
|
|
||||||
# 创建管理员组选项
|
# 创建管理员组选项
|
||||||
admin_options = GroupOptions(
|
admin_options = GroupOptions(
|
||||||
@@ -105,7 +107,8 @@ class GroupFactory:
|
|||||||
aria2=True,
|
aria2=True,
|
||||||
redirected_source=True,
|
redirected_source=True,
|
||||||
)
|
)
|
||||||
await admin_options.save(session)
|
await admin_options.save(session, commit=False)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
return admin_group
|
return admin_group
|
||||||
|
|
||||||
@@ -140,7 +143,7 @@ class GroupFactory:
|
|||||||
speed_limit=1024, # 1MB/s
|
speed_limit=1024, # 1MB/s
|
||||||
)
|
)
|
||||||
|
|
||||||
limited_group = await limited_group.save(session)
|
limited_group = await limited_group.save(session, commit=False)
|
||||||
|
|
||||||
# 创建限制组选项
|
# 创建限制组选项
|
||||||
limited_options = GroupOptions(
|
limited_options = GroupOptions(
|
||||||
@@ -152,7 +155,8 @@ class GroupFactory:
|
|||||||
select_node=False,
|
select_node=False,
|
||||||
advance_delete=False,
|
advance_delete=False,
|
||||||
)
|
)
|
||||||
await limited_options.save(session)
|
await limited_options.save(session, commit=False)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
return limited_group
|
return limited_group
|
||||||
|
|
||||||
@@ -185,7 +189,7 @@ class GroupFactory:
|
|||||||
speed_limit=512, # 512KB/s
|
speed_limit=512, # 512KB/s
|
||||||
)
|
)
|
||||||
|
|
||||||
free_group = await free_group.save(session)
|
free_group = await free_group.save(session, commit=False)
|
||||||
|
|
||||||
# 创建免费组选项
|
# 创建免费组选项
|
||||||
free_options = GroupOptions(
|
free_options = GroupOptions(
|
||||||
@@ -197,6 +201,7 @@ class GroupFactory:
|
|||||||
select_node=False,
|
select_node=False,
|
||||||
advance_delete=False,
|
advance_delete=False,
|
||||||
)
|
)
|
||||||
await free_options.save(session)
|
await free_options.save(session, commit=False)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
return free_group
|
return free_group
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(
|
|||||||
refreshUrl="/api/v1/user/session/refresh",
|
refreshUrl="/api/v1/user/session/refresh",
|
||||||
)
|
)
|
||||||
|
|
||||||
SECRET_KEY = ''
|
SECRET_KEY: str = ''
|
||||||
|
|
||||||
|
|
||||||
async def load_secret_key() -> None:
|
async def load_secret_key() -> None:
|
||||||
@@ -26,10 +26,10 @@ async def load_secret_key() -> None:
|
|||||||
|
|
||||||
global SECRET_KEY
|
global SECRET_KEY
|
||||||
async for session in get_session():
|
async for session in get_session():
|
||||||
setting = await Setting.get(
|
setting: Setting = await Setting.get(
|
||||||
session,
|
session,
|
||||||
(Setting.type == "auth") & (Setting.name == "secret_key")
|
(Setting.type == "auth") & (Setting.name == "secret_key")
|
||||||
)
|
) # type: ignore
|
||||||
if setting:
|
if setting:
|
||||||
SECRET_KEY = setting.value
|
SECRET_KEY = setting.value
|
||||||
|
|
||||||
@@ -40,7 +40,14 @@ def build_token_payload(
|
|||||||
algorithm: str,
|
algorithm: str,
|
||||||
expires_delta: timedelta | None = None,
|
expires_delta: timedelta | None = None,
|
||||||
) -> tuple[str, datetime]:
|
) -> tuple[str, datetime]:
|
||||||
"""构建令牌"""
|
"""
|
||||||
|
构建令牌。
|
||||||
|
|
||||||
|
:param data: 需要放进 JWT Payload 的字段
|
||||||
|
:param is_refresh: 是否为刷新令牌
|
||||||
|
:param algorithm: JWT 签名算法
|
||||||
|
:param expires_delta: 过期时间
|
||||||
|
"""
|
||||||
|
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
|
|
||||||
@@ -61,8 +68,11 @@ def build_token_payload(
|
|||||||
|
|
||||||
|
|
||||||
# 访问令牌
|
# 访问令牌
|
||||||
def create_access_token(data: dict, expires_delta: timedelta | None = None,
|
def create_access_token(
|
||||||
algorithm: str = "HS256") -> AccessTokenBase:
|
data: dict,
|
||||||
|
expires_delta: timedelta | None = None,
|
||||||
|
algorithm: str = "HS256"
|
||||||
|
) -> AccessTokenBase:
|
||||||
"""
|
"""
|
||||||
生成访问令牌,默认有效期 3 小时。
|
生成访问令牌,默认有效期 3 小时。
|
||||||
|
|
||||||
@@ -73,7 +83,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None,
|
|||||||
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
|
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta)
|
access_token, expire_at = build_token_payload(
|
||||||
|
data,
|
||||||
|
False,
|
||||||
|
algorithm,
|
||||||
|
expires_delta
|
||||||
|
)
|
||||||
return AccessTokenBase(
|
return AccessTokenBase(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
access_expires=expire_at,
|
access_expires=expire_at,
|
||||||
@@ -81,8 +96,11 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None,
|
|||||||
|
|
||||||
|
|
||||||
# 刷新令牌
|
# 刷新令牌
|
||||||
def create_refresh_token(data: dict, expires_delta: timedelta | None = None,
|
def create_refresh_token(
|
||||||
algorithm: str = "HS256") -> RefreshTokenBase:
|
data: dict,
|
||||||
|
expires_delta: timedelta | None = None,
|
||||||
|
algorithm: str = "HS256"
|
||||||
|
) -> RefreshTokenBase:
|
||||||
"""
|
"""
|
||||||
生成刷新令牌,默认有效期 30 天。
|
生成刷新令牌,默认有效期 30 天。
|
||||||
|
|
||||||
@@ -93,7 +111,12 @@ def create_refresh_token(data: dict, expires_delta: timedelta | None = None,
|
|||||||
:return: 包含密钥本身和过期时间的 `RefreshTokenBase`
|
:return: 包含密钥本身和过期时间的 `RefreshTokenBase`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta)
|
refresh_token, expire_at = build_token_payload(
|
||||||
|
data,
|
||||||
|
True,
|
||||||
|
algorithm,
|
||||||
|
expires_delta
|
||||||
|
)
|
||||||
return RefreshTokenBase(
|
return RefreshTokenBase(
|
||||||
refresh_token=refresh_token,
|
refresh_token=refresh_token,
|
||||||
refresh_expires=expire_at,
|
refresh_expires=expire_at,
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ license_info = {"name": "GPLv3", "url": "https://opensource.org/license/gpl-3.0"
|
|||||||
BackendVersion = "0.0.1"
|
BackendVersion = "0.0.1"
|
||||||
"""后端版本"""
|
"""后端版本"""
|
||||||
|
|
||||||
IsPro = False
|
IsPro: bool = False
|
||||||
|
|
||||||
mode: str = os.getenv('MODE', 'master')
|
mode: str = os.getenv('MODE', 'master')
|
||||||
"""运行模式"""
|
"""运行模式"""
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from argon2 import PasswordHasher
|
from argon2 import PasswordHasher
|
||||||
@@ -11,7 +12,23 @@ from pydantic import BaseModel, Field
|
|||||||
from utils.JWT import SECRET_KEY
|
from utils.JWT import SECRET_KEY
|
||||||
from utils.conf import appmeta
|
from utils.conf import appmeta
|
||||||
|
|
||||||
_ph = PasswordHasher()
|
# FIRST RECOMMENDED option per RFC 9106.
|
||||||
|
_ph_lowmem = PasswordHasher(
|
||||||
|
salt_len=16,
|
||||||
|
hash_len=32,
|
||||||
|
time_cost=3,
|
||||||
|
memory_cost=65536, # 64 MiB
|
||||||
|
parallelism=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# SECOND RECOMMENDED option per RFC 9106.
|
||||||
|
_ph_highmem = PasswordHasher(
|
||||||
|
salt_len=16,
|
||||||
|
hash_len=32,
|
||||||
|
time_cost=1,
|
||||||
|
memory_cost=2097152, # 2 GiB
|
||||||
|
parallelism=4,
|
||||||
|
)
|
||||||
|
|
||||||
class PasswordStatus(StrEnum):
|
class PasswordStatus(StrEnum):
|
||||||
"""密码校验状态枚举"""
|
"""密码校验状态枚举"""
|
||||||
@@ -48,7 +65,8 @@ class Password:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(
|
def generate(
|
||||||
length: int = 8
|
length: int = 8,
|
||||||
|
url_safe: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
生成指定长度的随机密码。
|
生成指定长度的随机密码。
|
||||||
@@ -58,7 +76,16 @@ class Password:
|
|||||||
:return: 随机密码
|
:return: 随机密码
|
||||||
:rtype: str
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
return secrets.token_hex(length)
|
if url_safe:
|
||||||
|
return secrets.token_urlsafe(length)
|
||||||
|
else:
|
||||||
|
return secrets.token_hex(length)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_hex(
|
||||||
|
length: int = 8
|
||||||
|
) -> bytes:
|
||||||
|
return secrets.token_bytes(length)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hash(
|
def hash(
|
||||||
@@ -72,7 +99,7 @@ class Password:
|
|||||||
:param password: 需要哈希的原始密码
|
:param password: 需要哈希的原始密码
|
||||||
:return: Argon2 哈希字符串
|
:return: Argon2 哈希字符串
|
||||||
"""
|
"""
|
||||||
return _ph.hash(password)
|
return _ph_lowmem.hash(password)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify(
|
def verify(
|
||||||
@@ -87,21 +114,16 @@ class Password:
|
|||||||
:return: 如果密码匹配返回 True, 否则返回 False
|
:return: 如果密码匹配返回 True, 否则返回 False
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# verify 函数会自动解析 stored_password 中的盐和参数
|
_ph_lowmem.verify(hash, password)
|
||||||
_ph.verify(hash, password)
|
|
||||||
|
|
||||||
# 检查哈希参数是否已过时。如果返回True,
|
# 检查哈希参数是否已过时
|
||||||
# 意味着你应该使用新的参数重新哈希密码并更新存储。
|
if _ph_lowmem.check_needs_rehash(hash):
|
||||||
# 这是一个很好的实践,可以随着时间推移增强安全性。
|
|
||||||
if _ph.check_needs_rehash(hash):
|
|
||||||
logger.warning("密码哈希参数已过时,建议重新哈希并更新。")
|
logger.warning("密码哈希参数已过时,建议重新哈希并更新。")
|
||||||
return PasswordStatus.EXPIRED
|
return PasswordStatus.EXPIRED
|
||||||
|
|
||||||
return PasswordStatus.VALID
|
return PasswordStatus.VALID
|
||||||
except VerifyMismatchError:
|
except VerifyMismatchError:
|
||||||
# 这是预期的异常,当密码不匹配时触发。
|
|
||||||
return PasswordStatus.INVALID
|
return PasswordStatus.INVALID
|
||||||
# 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def generate_totp(
|
async def generate_totp(
|
||||||
|
|||||||
Reference in New Issue
Block a user