Refactor and enhance OAuth2.0 implementation; update models and routes
- Refactored AdminSummaryData and AdminSummaryResponse classes for better clarity. - Added OAUTH type to SettingsType enum. - Cleaned up imports in webdav.py. - Updated admin router to improve summary data retrieval and response handling. - Enhanced file management routes with better condition handling and user storage updates. - Improved group management routes by optimizing data retrieval. - Refined task management routes for better condition handling. - Updated user management routes to streamline access token retrieval. - Implemented a new captcha verification structure with abstract base class. - Removed deprecated env.md file and replaced with a new structured version. - Introduced a unified OAuth2.0 client base class for GitHub and QQ integrations. - Enhanced password management with improved hashing strategies. - Added detailed comments and documentation throughout the codebase for clarity.
This commit is contained in:
9
.claude/settings.local.json
Normal file
9
.claude/settings.local.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(git rev-parse:*)",
|
||||
"Bash(findstr:*)",
|
||||
"Bash(find:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -9,4 +9,4 @@
|
||||
- `REDIS_PORT`: Redis 端口
|
||||
- `REDIS_PASSWORD`: Redis 密码
|
||||
- `REDIS_DB`: Redis 数据库
|
||||
- `REDIS_PROTOCOL`
|
||||
- `REDIS_PROTOCOL`: Redis 协议
|
||||
16
main.py
16
main.py
@@ -33,8 +33,22 @@ app = FastAPI(
|
||||
openapi_url="/openapi.json" if appmeta.debug else None,
|
||||
)
|
||||
|
||||
# 添加跨域 CORS 中间件,仅在调试模式下启用,以允许所有来源访问 API
|
||||
if appmeta.debug:
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def handle_unexpected_exceptions(request: Request, exc: Exception) -> NoReturn:
|
||||
async def handle_unexpected_exceptions(
|
||||
request: Request,
|
||||
exc: Exception
|
||||
) -> NoReturn:
|
||||
"""
|
||||
捕获所有未经处理的 FastAPI 异常,防止敏感信息泄露。
|
||||
"""
|
||||
|
||||
@@ -60,10 +60,10 @@ def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
|
||||
try:
|
||||
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
|
||||
if payload.get("type") != "download":
|
||||
return None
|
||||
http_exceptions.raise_unauthorized("Download token required")
|
||||
jti = payload.get("jti")
|
||||
if not jti:
|
||||
return None
|
||||
http_exceptions.raise_unauthorized("Download token required")
|
||||
return jti, UUID(payload["file_id"]), UUID(payload["owner_id"])
|
||||
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
|
||||
return None
|
||||
except jwt.InvalidTokenError:
|
||||
http_exceptions.raise_unauthorized("Download token required")
|
||||
@@ -120,7 +120,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
async def init_default_settings() -> None:
|
||||
from .setting import Setting
|
||||
from .database import get_session
|
||||
from sqlalchemy import and_
|
||||
|
||||
log.info('初始化设置...')
|
||||
|
||||
@@ -128,7 +127,7 @@ async def init_default_settings() -> None:
|
||||
# 检查是否已经存在版本设置
|
||||
ver = await Setting.get(
|
||||
session,
|
||||
and_(Setting.type == SettingsType.VERSION, Setting.name == f"db_version_{BackendVersion}")
|
||||
(Setting.type == SettingsType.VERSION) & (Setting.name == f"db_version_{BackendVersion}")
|
||||
)
|
||||
if ver and ver.value == "installed":
|
||||
return
|
||||
|
||||
@@ -218,7 +218,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
|
||||
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]:
|
||||
"""
|
||||
向数据库中添加一个新的或多个新的记录.
|
||||
|
||||
@@ -230,6 +230,8 @@ class TableBaseMixin(AsyncAttrs):
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
|
||||
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
|
||||
Returns:
|
||||
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
|
||||
@@ -244,6 +246,11 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# 添加单个实例
|
||||
item3 = Item(name="Cherry")
|
||||
added_item = await Item.add(session, item3)
|
||||
|
||||
# 批量操作,减少提交次数
|
||||
await Item.add(session, [item1, item2], commit=False)
|
||||
await Item.add(session, [item3, item4], commit=False)
|
||||
await session.commit()
|
||||
"""
|
||||
is_list = False
|
||||
if isinstance(instances, list):
|
||||
@@ -252,7 +259,10 @@ class TableBaseMixin(AsyncAttrs):
|
||||
else:
|
||||
session.add(instances)
|
||||
|
||||
await session.commit()
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
|
||||
if refresh:
|
||||
if is_list:
|
||||
@@ -266,15 +276,16 @@ class TableBaseMixin(AsyncAttrs):
|
||||
async def save(
|
||||
self: T,
|
||||
session: AsyncSession,
|
||||
load: RelationshipInfo | None = None,
|
||||
refresh: bool = True
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
refresh: bool = True,
|
||||
commit: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
保存(插入或更新)当前模型实例到数据库.
|
||||
|
||||
这是一个实例方法,它将当前对象添加到会话中并提交更改。
|
||||
可以用于创建新记录或更新现有记录。还可以选择在保存后
|
||||
预加载(eager load)一个关联关系.
|
||||
预加载(eager load)一个或多个关联关系.
|
||||
|
||||
**重要**:调用此方法后,session中的所有对象都会过期(expired)。
|
||||
如果需要继续使用该对象,必须使用返回值:
|
||||
@@ -287,6 +298,14 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.save(session, refresh=False)
|
||||
|
||||
# ✅ 正确:批量操作,减少提交次数
|
||||
await item1.save(session, commit=False)
|
||||
await item2.save(session, commit=False)
|
||||
await session.commit()
|
||||
|
||||
# ✅ 正确:批量操作并预加载多个关联关系
|
||||
user = await user.save(session, load=[User.group, User.tags])
|
||||
|
||||
# ❌ 错误:需要返回值但未使用
|
||||
await client.save(session)
|
||||
return client # client 对象已过期
|
||||
@@ -294,16 +313,22 @@ class TableBaseMixin(AsyncAttrs):
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||
例如 `User.posts`.
|
||||
load (Relationship | list[Relationship] | None): 可选的,指定在保存和刷新后要预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.posts` 或 `[User.group, User.tags]`.
|
||||
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
"""
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
@@ -324,8 +349,9 @@ class TableBaseMixin(AsyncAttrs):
|
||||
extra_data: dict[str, Any] | None = None,
|
||||
exclude_unset: bool = True,
|
||||
exclude: set[str] | None = None,
|
||||
load: RelationshipInfo | None = None,
|
||||
refresh: bool = True
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
refresh: bool = True,
|
||||
commit: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
使用另一个模型实例或字典中的数据来更新当前实例.
|
||||
@@ -348,6 +374,14 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
|
||||
await client.update(session, update_data, refresh=False)
|
||||
|
||||
# ✅ 正确:批量操作,减少提交次数
|
||||
await user1.update(session, data1, commit=False)
|
||||
await user2.update(session, data2, commit=False)
|
||||
await session.commit()
|
||||
|
||||
# ✅ 正确:批量操作并预加载多个关联关系
|
||||
user = await user.update(session, data, load=[User.group, User.tags])
|
||||
|
||||
# ❌ 错误:需要返回值但未使用
|
||||
await client.update(session, update_data)
|
||||
return client # client 对象已过期
|
||||
@@ -360,10 +394,13 @@ class TableBaseMixin(AsyncAttrs):
|
||||
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
|
||||
的字段将被忽略. 默认为 True.
|
||||
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
|
||||
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||
例如 `User.permission`.
|
||||
load (Relationship | list[Relationship] | None): 可选的,指定在更新和刷新后要预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.permission` 或 `[User.group, User.tags]`.
|
||||
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
|
||||
设为 False 可节省一次数据库查询。默认为 True.
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
|
||||
Returns:
|
||||
T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self.
|
||||
@@ -374,7 +411,10 @@ class TableBaseMixin(AsyncAttrs):
|
||||
)
|
||||
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
if commit:
|
||||
await session.commit()
|
||||
else:
|
||||
await session.flush()
|
||||
|
||||
if not refresh:
|
||||
return self
|
||||
@@ -388,33 +428,82 @@ class TableBaseMixin(AsyncAttrs):
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None:
|
||||
async def delete(
|
||||
cls: type[T],
|
||||
session: AsyncSession,
|
||||
instances: T | list[T] | None = None,
|
||||
*,
|
||||
condition: BinaryExpression | ClauseElement | None = None,
|
||||
commit: bool = True
|
||||
) -> int:
|
||||
"""
|
||||
从数据库中删除一个或多个记录.
|
||||
从数据库中删除记录.
|
||||
|
||||
支持两种删除方式:
|
||||
1. 实例删除:传入 instances 参数,先加载再删除
|
||||
2. 条件删除:传入 condition 参数,直接 SQL 删除(更高效)
|
||||
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
instances (T | list[T]): 要删除的单个模型实例或模型实例列表.
|
||||
instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表(可选).
|
||||
condition (BinaryExpression | ClauseElement | None): 删除条件(可选,与 instances 二选一).
|
||||
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
|
||||
之后需要手动调用 `session.commit()`。默认为 True.
|
||||
|
||||
Returns:
|
||||
None
|
||||
int: 删除的记录数量
|
||||
|
||||
Usage:
|
||||
# 实例删除
|
||||
item_to_delete = await Item.get(session, Item.id == 1)
|
||||
if item_to_delete:
|
||||
await Item.delete(session, item_to_delete)
|
||||
deleted_count = await Item.delete(session, item_to_delete)
|
||||
|
||||
items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all")
|
||||
if items_to_delete:
|
||||
await Item.delete(session, items_to_delete)
|
||||
# 条件删除(更高效,无需加载实例)
|
||||
deleted_count = await Item.delete(
|
||||
session,
|
||||
condition=(Item.status == "inactive") & (Item.created_at < cutoff_date)
|
||||
)
|
||||
|
||||
# 批量删除后手动提交
|
||||
await Item.delete(session, item1, commit=False)
|
||||
await Item.delete(session, item2, commit=False)
|
||||
await session.commit()
|
||||
"""
|
||||
# 条件删除模式
|
||||
if condition is not None:
|
||||
from sqlmodel import delete as sql_delete
|
||||
|
||||
if instances is not None:
|
||||
raise ValueError("不能同时指定 instances 和 condition")
|
||||
|
||||
# 执行条件删除
|
||||
stmt = sql_delete(cls).where(condition)
|
||||
result = await session.exec(stmt)
|
||||
deleted_count = result.rowcount
|
||||
|
||||
if commit:
|
||||
await session.commit()
|
||||
|
||||
return deleted_count
|
||||
|
||||
# 实例删除模式(原有逻辑)
|
||||
if instances is None:
|
||||
raise ValueError("必须指定 instances 或 condition")
|
||||
|
||||
deleted_count = 0
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
deleted_count += 1
|
||||
else:
|
||||
await session.delete(instances)
|
||||
deleted_count = 1
|
||||
|
||||
await session.commit()
|
||||
if commit:
|
||||
await session.commit()
|
||||
|
||||
return deleted_count
|
||||
|
||||
@classmethod
|
||||
def _build_time_filters(
|
||||
@@ -458,7 +547,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: RelationshipInfo | None = None,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
with_for_update: bool = False,
|
||||
@@ -491,8 +580,9 @@ class TableBaseMixin(AsyncAttrs):
|
||||
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
|
||||
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
|
||||
例如 `[selectinload(User.posts)]`.
|
||||
load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系.
|
||||
例如 `User.profile`.
|
||||
load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
|
||||
可以是单个关系或关系列表.
|
||||
例如 `User.profile` 或 `[User.group, User.tags]`.
|
||||
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
|
||||
例如 `[User.name.asc(), User.created_at.desc()]`.
|
||||
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
|
||||
@@ -595,9 +685,15 @@ class TableBaseMixin(AsyncAttrs):
|
||||
statement = statement.options(*options)
|
||||
|
||||
if load:
|
||||
# 标准化为列表
|
||||
load_list = load if isinstance(load, list) else [load]
|
||||
|
||||
# 处理多态加载
|
||||
if load_polymorphic is not None:
|
||||
target_class = load.property.mapper.class_
|
||||
# 多态加载只支持单个关系
|
||||
if len(load_list) > 1:
|
||||
raise ValueError("load_polymorphic 仅支持单个关系")
|
||||
target_class = load_list[0].property.mapper.class_
|
||||
|
||||
# 检查目标类是否继承自 PolymorphicBaseMixin
|
||||
if not issubclass(target_class, PolymorphicBaseMixin):
|
||||
@@ -609,7 +705,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
if load_polymorphic == 'all':
|
||||
# 两阶段查询:获取实际关联的多态类型
|
||||
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
|
||||
session, condition, load, target_class
|
||||
session, condition, load_list[0], target_class
|
||||
)
|
||||
else:
|
||||
subclasses_to_load = load_polymorphic
|
||||
@@ -618,12 +714,14 @@ class TableBaseMixin(AsyncAttrs):
|
||||
# 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项
|
||||
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
|
||||
statement = statement.options(
|
||||
selectinload(load).selectin_polymorphic(subclasses_to_load)
|
||||
selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load)
|
||||
)
|
||||
else:
|
||||
statement = statement.options(selectinload(load))
|
||||
statement = statement.options(selectinload(load_list[0]))
|
||||
else:
|
||||
statement = statement.options(selectinload(load))
|
||||
# 为每个关系添加 selectinload
|
||||
for rel in load_list:
|
||||
statement = statement.options(selectinload(rel))
|
||||
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
@@ -796,7 +894,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
*,
|
||||
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: RelationshipInfo | None = None,
|
||||
load: RelationshipInfo | list[RelationshipInfo] | None = None,
|
||||
order_by: list[ClauseElement] | None = None,
|
||||
filter: BinaryExpression | ClauseElement | None = None,
|
||||
table_view: TableViewRequest | None = None,
|
||||
@@ -865,7 +963,7 @@ class TableBaseMixin(AsyncAttrs):
|
||||
return ListResponse(count=total_count, items=items)
|
||||
|
||||
@classmethod
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T:
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T:
|
||||
"""
|
||||
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||
|
||||
@@ -875,7 +973,8 @@ class TableBaseMixin(AsyncAttrs):
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (int): 要查找的记录的主键 ID.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
@@ -903,7 +1002,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
|
||||
|
||||
@override
|
||||
@classmethod
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
|
||||
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T:
|
||||
"""
|
||||
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
|
||||
|
||||
@@ -913,7 +1012,8 @@ class UUIDTableBaseMixin(TableBaseMixin):
|
||||
Args:
|
||||
session (AsyncSession): 用于数据库操作的异步会话对象.
|
||||
id (uuid.UUID): 要查找的记录的 UUID 主键.
|
||||
load (Relationship | None): 可选的,用于预加载的关联属性.
|
||||
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
|
||||
可以是单个关系或关系列表.
|
||||
|
||||
Returns:
|
||||
T: 找到的模型实例.
|
||||
|
||||
@@ -79,9 +79,8 @@ class VersionInfo(SQLModelBase):
|
||||
commit: str
|
||||
"""提交哈希"""
|
||||
|
||||
|
||||
class AdminSummaryData(SQLModelBase):
|
||||
"""管理员概况数据"""
|
||||
class AdminSummaryResponse(ResponseBase):
|
||||
"""管理员概况响应"""
|
||||
|
||||
metrics_summary: MetricsSummary
|
||||
"""统计摘要"""
|
||||
@@ -95,13 +94,6 @@ class AdminSummaryData(SQLModelBase):
|
||||
version: VersionInfo
|
||||
"""版本信息"""
|
||||
|
||||
|
||||
class AdminSummaryResponse(ResponseBase):
|
||||
"""管理员概况响应"""
|
||||
|
||||
data: AdminSummaryData | None = None
|
||||
"""响应数据"""
|
||||
|
||||
class MCPMethod(StrEnum):
|
||||
"""MCP 方法枚举"""
|
||||
|
||||
|
||||
@@ -104,6 +104,7 @@ class SettingsType(StrEnum):
|
||||
MAIL = "mail"
|
||||
MAIL_TEMPLATE = "mail_template"
|
||||
MOBILE = "mobile"
|
||||
OAUTH = "oauth"
|
||||
PATH = "path"
|
||||
PREVIEW = "preview"
|
||||
PWA = "pwa"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from .base import SQLModelBase
|
||||
from .mixin import TableBaseMixin
|
||||
|
||||
@@ -2,14 +2,12 @@ from datetime import datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import and_
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
User, ResponseBase,
|
||||
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
||||
AdminSummaryData,
|
||||
)
|
||||
from models.base import SQLModelBase
|
||||
from models.setting import (
|
||||
@@ -75,8 +73,8 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
||||
Returns:
|
||||
AdminSummaryResponse: 包含站点概况信息的响应模型。
|
||||
"""
|
||||
# 统计最近 12 天的数据
|
||||
days_count = 12
|
||||
# 统计最近 14 天的数据
|
||||
days_count = 14
|
||||
now = datetime.now()
|
||||
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
@@ -135,7 +133,7 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
||||
site_urls: list[str] = []
|
||||
site_url_setting = await Setting.get(
|
||||
session,
|
||||
and_(Setting.type == SettingsType.BASIC, Setting.name == "siteURL"),
|
||||
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
|
||||
)
|
||||
if site_url_setting and site_url_setting.value:
|
||||
site_urls.append(site_url_setting.value)
|
||||
@@ -156,15 +154,13 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
|
||||
commit="dev",
|
||||
)
|
||||
|
||||
data = AdminSummaryData(
|
||||
return AdminSummaryResponse(
|
||||
metrics_summary=metrics_summary,
|
||||
site_urls=site_urls,
|
||||
license=license_info,
|
||||
version=version_info,
|
||||
)
|
||||
|
||||
return AdminSummaryResponse(data=data)
|
||||
|
||||
@admin_router.get(
|
||||
path='/news',
|
||||
summary='获取社区新闻',
|
||||
@@ -203,7 +199,7 @@ async def router_admin_update_settings(
|
||||
for item in request.settings:
|
||||
existing = await Setting.get(
|
||||
session,
|
||||
and_(Setting.type == item.type, Setting.name == item.name)
|
||||
(Setting.type == item.type) & (Setting.name == item.name)
|
||||
)
|
||||
|
||||
if existing:
|
||||
@@ -245,7 +241,12 @@ async def router_admin_get_settings(
|
||||
if name:
|
||||
conditions.append(Setting.name == name)
|
||||
|
||||
condition = and_(*conditions) if conditions else None
|
||||
if conditions:
|
||||
condition = conditions[0]
|
||||
for c in conditions[1:]:
|
||||
condition = condition & c
|
||||
else:
|
||||
condition = None
|
||||
|
||||
settings = await Setting.get(session, condition, fetch_mode="all")
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import and_
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
@@ -51,7 +50,12 @@ async def router_admin_get_file_list(
|
||||
if keyword:
|
||||
conditions.append(Object.name.ilike(f"%{keyword}%"))
|
||||
|
||||
condition = and_(*conditions) if len(conditions) > 1 else conditions[0]
|
||||
if len(conditions) > 1:
|
||||
condition = conditions[0]
|
||||
for c in conditions[1:]:
|
||||
condition = condition & c
|
||||
else:
|
||||
condition = conditions[0]
|
||||
result = await Object.get_with_count(session, condition, table_view=table_view, load=Object.owner)
|
||||
|
||||
# 构建响应
|
||||
@@ -197,13 +201,15 @@ async def router_admin_delete_file(
|
||||
except Exception as e:
|
||||
l.warning(f"删除物理文件失败: {e}")
|
||||
|
||||
# 更新用户存储量
|
||||
owner = await User.get(session, User.id == owner_id)
|
||||
if owner:
|
||||
owner.storage = max(0, owner.storage - file_size)
|
||||
await owner.save(session)
|
||||
# 更新用户存储量(使用 SQL UPDATE 直接更新,无需加载实例)
|
||||
from sqlmodel import update as sql_update
|
||||
stmt = sql_update(User).where(User.id == owner_id).values(
|
||||
storage=max(0, User.storage - file_size)
|
||||
)
|
||||
await session.exec(stmt)
|
||||
|
||||
await Object.delete(session, file_obj)
|
||||
# 使用条件删除
|
||||
await Object.delete(session, condition=Object.id == file_obj.id)
|
||||
|
||||
l.info(f"管理员删除了文件: {file_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
@@ -63,12 +63,13 @@ async def router_admin_get_group(
|
||||
:param group_id: 用户组UUID
|
||||
:return: 用户组详情
|
||||
"""
|
||||
group = await Group.get(session, Group.id == group_id, load=Group.options)
|
||||
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
|
||||
|
||||
if not group:
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
|
||||
policies = await group.awaitable_attrs.policies
|
||||
# 直接访问已加载的关系,无需额外查询
|
||||
policies = group.policies
|
||||
user_count = await User.count(session, User.group_id == group_id)
|
||||
response = GroupDetailResponse.from_group(group, user_count, policies)
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import and_
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
@@ -43,7 +42,12 @@ async def router_admin_get_task_list(
|
||||
if status:
|
||||
conditions.append(Task.status == status)
|
||||
|
||||
condition = and_(*conditions) if conditions else None
|
||||
if conditions:
|
||||
condition = conditions[0]
|
||||
for c in conditions[1:]:
|
||||
condition = condition & c
|
||||
else:
|
||||
condition = None
|
||||
result = await Task.get_with_count(session, condition, table_view=table_view, load=Task.user)
|
||||
|
||||
items: list[TaskSummary] = []
|
||||
|
||||
@@ -2,7 +2,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger as l
|
||||
from sqlalchemy import func, and_
|
||||
from sqlalchemy import func
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
@@ -198,7 +198,7 @@ async def router_admin_calibrate_storage(
|
||||
from sqlmodel import select
|
||||
result = await session.execute(
|
||||
select(func.sum(Object.size), func.count(Object.id)).where(
|
||||
and_(Object.owner_id == user_id, Object.type == ObjectType.FILE)
|
||||
(Object.owner_id == user_id) & (Object.type == ObjectType.FILE)
|
||||
)
|
||||
)
|
||||
row = result.one()
|
||||
|
||||
@@ -233,7 +233,7 @@ async def upload_chunk(
|
||||
policy_id=upload_session.policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
physical_file = await physical_file.save(session)
|
||||
physical_file = await physical_file.save(session, commit=False)
|
||||
|
||||
# 创建 Object 记录
|
||||
file_object = Object(
|
||||
@@ -246,11 +246,18 @@ async def upload_chunk(
|
||||
owner_id=user_id,
|
||||
policy_id=upload_session.policy_id,
|
||||
)
|
||||
file_object = await file_object.save(session)
|
||||
file_object = await file_object.save(session, commit=False)
|
||||
file_object_id = file_object.id
|
||||
|
||||
# 删除上传会话
|
||||
await UploadSession.delete(session, upload_session)
|
||||
# 删除上传会话(使用条件删除)
|
||||
await UploadSession.delete(
|
||||
session,
|
||||
condition=UploadSession.id == upload_session.id,
|
||||
commit=False
|
||||
)
|
||||
|
||||
# 统一提交所有更改
|
||||
await session.commit()
|
||||
|
||||
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from fastapi import APIRouter
|
||||
from sqlalchemy import and_
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
@@ -55,5 +54,5 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
dict: The site configuration.
|
||||
"""
|
||||
return SiteConfigResponse(
|
||||
title=await Setting.get(session, and_(Setting.type == SettingsType.BASIC, Setting.name == "siteName")),
|
||||
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
|
||||
)
|
||||
@@ -3,7 +3,6 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy import and_
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
@@ -115,7 +114,7 @@ async def router_user_register(
|
||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||
default_group_setting: models.Setting | None = await models.Setting.get(
|
||||
session,
|
||||
and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group")
|
||||
(models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group")
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
logger.error("默认用户组不存在")
|
||||
@@ -352,18 +351,18 @@ async def router_user_authn_start(
|
||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
||||
authn_setting = await models.Setting.get(
|
||||
session,
|
||||
and_(models.Setting.type == "authn", models.Setting.name == "authn_enabled")
|
||||
(models.Setting.type == "authn") & (models.Setting.name == "authn_enabled")
|
||||
)
|
||||
if not authn_setting or authn_setting.value != "1":
|
||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
||||
|
||||
site_url_setting = await models.Setting.get(
|
||||
session,
|
||||
and_(models.Setting.type == "basic", models.Setting.name == "siteURL")
|
||||
(models.Setting.type == "basic") & (models.Setting.name == "siteURL")
|
||||
)
|
||||
site_title_setting = await models.Setting.get(
|
||||
session,
|
||||
and_(models.Setting.type == "basic", models.Setting.name == "siteTitle")
|
||||
(models.Setting.type == "basic") & (models.Setting.name == "siteTitle")
|
||||
)
|
||||
|
||||
options = generate_registration_options(
|
||||
|
||||
@@ -1,5 +1,39 @@
|
||||
import abc
|
||||
import aiohttp
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .gcaptcha import GCaptcha
|
||||
from .turnstile import TurnstileCaptcha
|
||||
|
||||
|
||||
class CaptchaRequestBase(BaseModel):
|
||||
"""验证码验证请求"""
|
||||
token: str
|
||||
secret: str
|
||||
"""验证 token"""
|
||||
secret: str
|
||||
"""验证密钥"""
|
||||
|
||||
|
||||
class CaptchaBase(abc.ABC):
|
||||
"""验证码验证器抽象基类"""
|
||||
|
||||
verify_url: str
|
||||
"""验证 API 地址(子类必须定义)"""
|
||||
|
||||
async def verify_captcha(self, request: CaptchaRequestBase) -> bool:
|
||||
"""
|
||||
验证 token 是否有效。
|
||||
|
||||
:return: 如果验证成功返回 True,否则返回 False
|
||||
:rtype: bool
|
||||
"""
|
||||
payload = request.model_dump()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.verify_url, data=payload) as response:
|
||||
if response.status != 200:
|
||||
return False
|
||||
|
||||
result = await response.json()
|
||||
return result.get('success', False)
|
||||
@@ -1,21 +1,7 @@
|
||||
import aiohttp
|
||||
from . import CaptchaBase
|
||||
|
||||
from . import CaptchaRequestBase
|
||||
|
||||
async def verify_captcha(request: CaptchaRequestBase) -> bool:
|
||||
"""
|
||||
验证 Google reCAPTCHA v2/v3 的 token 是否有效。
|
||||
|
||||
:return: 如果验证成功返回 True,否则返回 False
|
||||
:rtype: bool
|
||||
"""
|
||||
verify_url = "https://www.google.com/recaptcha/api/siteverify"
|
||||
payload = request.model_dump()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(verify_url, data=payload) as response:
|
||||
if response.status != 200:
|
||||
return False
|
||||
|
||||
result = await response.json()
|
||||
return result.get('success', False)
|
||||
class GCaptcha(CaptchaBase):
|
||||
"""Google reCAPTCHA v2/v3 验证器"""
|
||||
|
||||
verify_url = "https://www.google.com/recaptcha/api/siteverify"
|
||||
@@ -1,21 +1,7 @@
|
||||
import aiohttp
|
||||
from . import CaptchaBase
|
||||
|
||||
from . import CaptchaRequestBase
|
||||
|
||||
async def verify_captcha(request: CaptchaRequestBase) -> bool:
|
||||
"""
|
||||
验证 Turnstile 的 token 是否有效。
|
||||
|
||||
:return: 如果验证成功返回 True,否则返回 False
|
||||
:rtype: bool
|
||||
"""
|
||||
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
payload = request.model_dump()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(verify_url, data=payload) as response:
|
||||
if response.status != 200:
|
||||
return False
|
||||
|
||||
result = await response.json()
|
||||
return result.get('success', False)
|
||||
class TurnstileCaptcha(CaptchaBase):
|
||||
"""Cloudflare Turnstile 验证器"""
|
||||
|
||||
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
223
service/oauth/__init__.py
Normal file
223
service/oauth/__init__.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
OAuth2.0 认证模块
|
||||
|
||||
提供统一的 OAuth2.0 客户端基类,支持多种第三方登录平台。
|
||||
"""
|
||||
import abc
|
||||
import aiohttp
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# ==================== 共享数据模型 ====================
|
||||
|
||||
class AccessTokenBase(BaseModel):
|
||||
"""访问令牌基类"""
|
||||
access_token: str
|
||||
"""访问令牌"""
|
||||
|
||||
|
||||
class OAuthUserData(BaseModel):
|
||||
"""OAuth 用户数据通用 DTO"""
|
||||
openid: str
|
||||
"""用户唯一标识(GitHub 为 id,QQ 为 openid)"""
|
||||
nickname: str | None
|
||||
"""用户昵称"""
|
||||
avatar_url: str | None
|
||||
"""头像 URL"""
|
||||
email: str | None
|
||||
"""邮箱"""
|
||||
bio: str | None
|
||||
"""个人简介"""
|
||||
|
||||
|
||||
class OAuthUserInfoResponse(BaseModel):
|
||||
"""OAuth 用户信息响应"""
|
||||
code: str
|
||||
"""状态码"""
|
||||
openid: str
|
||||
"""用户唯一标识"""
|
||||
user_data: OAuthUserData
|
||||
"""用户数据"""
|
||||
|
||||
|
||||
# ==================== OAuth2.0 抽象基类 ====================
|
||||
|
||||
class OAuthBase(abc.ABC):
|
||||
"""
|
||||
OAuth2.0 客户端抽象基类
|
||||
|
||||
子类需要定义以下类属性:
|
||||
- access_token_url: 获取 Access Token 的 API 地址
|
||||
- user_info_url: 获取用户信息的 API 地址
|
||||
- http_method: 获取 token 的 HTTP 方法(POST 或 GET)
|
||||
"""
|
||||
|
||||
# 子类必须定义的类属性
|
||||
access_token_url: str
|
||||
"""获取 Access Token 的 API 地址"""
|
||||
|
||||
user_info_url: str
|
||||
"""获取用户信息的 API 地址"""
|
||||
|
||||
http_method: str = "POST"
|
||||
"""获取 token 的 HTTP 方法:POST 或 GET"""
|
||||
|
||||
# 实例属性(构造函数传入)
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str) -> None:
|
||||
"""
|
||||
初始化 OAuth 客户端
|
||||
|
||||
Args:
|
||||
client_id: 应用 client_id
|
||||
client_secret: 应用 client_secret
|
||||
"""
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
|
||||
async def get_access_token(self, code: str, **kwargs) -> AccessTokenBase:
|
||||
"""
|
||||
通过 Authorization Code 获取 Access Token
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
**kwargs: 额外参数(如 QQ 需要 redirect_uri)
|
||||
|
||||
Returns:
|
||||
AccessTokenBase: 访问令牌
|
||||
"""
|
||||
params = {
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': code,
|
||||
}
|
||||
params.update(kwargs)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if self.http_method == "POST":
|
||||
async with session.post(
|
||||
url=self.access_token_url,
|
||||
params=params,
|
||||
headers={'accept': 'application/json'},
|
||||
) as access_resp:
|
||||
access_data = await access_resp.json()
|
||||
return self._parse_token_response(access_data)
|
||||
else:
|
||||
async with session.get(
|
||||
url=self.access_token_url,
|
||||
params=params,
|
||||
) as access_resp:
|
||||
access_data = await access_resp.json()
|
||||
return self._parse_token_response(access_data)
|
||||
|
||||
async def get_user_info(
|
||||
self,
|
||||
access_token: str | AccessTokenBase,
|
||||
**kwargs
|
||||
) -> OAuthUserInfoResponse:
|
||||
"""
|
||||
获取用户信息
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
**kwargs: 额外参数(如 QQ 需要 app_id, openid)
|
||||
|
||||
Returns:
|
||||
OAuthUserInfoResponse: 用户信息
|
||||
"""
|
||||
if isinstance(access_token, AccessTokenBase):
|
||||
access_token = access_token.access_token
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url=self.user_info_url,
|
||||
params=self._build_user_info_params(access_token, **kwargs),
|
||||
headers=self._build_user_info_headers(access_token),
|
||||
) as resp:
|
||||
user_data = await resp.json()
|
||||
return self._parse_user_response(user_data)
|
||||
|
||||
# ==================== 钩子方法(子类可覆盖) ====================
|
||||
|
||||
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
|
||||
"""
|
||||
构建获取用户信息的请求参数
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
**kwargs: 额外参数
|
||||
|
||||
Returns:
|
||||
dict: 请求参数
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _build_user_info_headers(self, access_token: str) -> dict:
|
||||
"""
|
||||
构建获取用户信息的请求头
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
|
||||
Returns:
|
||||
dict: 请求头
|
||||
"""
|
||||
return {
|
||||
'accept': 'application/json',
|
||||
}
|
||||
|
||||
def _parse_token_response(self, data: dict) -> AccessTokenBase:
|
||||
"""
|
||||
解析 token 响应
|
||||
|
||||
Args:
|
||||
data: API 返回的数据
|
||||
|
||||
Returns:
|
||||
AccessTokenBase: 访问令牌
|
||||
"""
|
||||
return AccessTokenBase(access_token=data.get('access_token'))
|
||||
|
||||
def _parse_user_response(self, data: dict) -> OAuthUserInfoResponse:
|
||||
"""
|
||||
解析用户信息响应
|
||||
|
||||
Args:
|
||||
data: API 返回的数据
|
||||
|
||||
Returns:
|
||||
OAuthUserInfoResponse: 用户信息
|
||||
"""
|
||||
return OAuthUserInfoResponse(
|
||||
code='0',
|
||||
openid='',
|
||||
user_data=OAuthUserData(openid=''),
|
||||
)
|
||||
|
||||
|
||||
# ==================== 导出 ====================
|
||||
|
||||
from .github import GithubOAuth, GithubAccessToken, GithubUserData
|
||||
from .qq import QQOAuth, QQAccessToken, QQOpenIDResponse, QQUserData
|
||||
|
||||
__all__ = [
|
||||
# 共享模型
|
||||
'AccessTokenBase',
|
||||
'OAuthUserData',
|
||||
'OAuthUserInfoResponse',
|
||||
'OAuthBase',
|
||||
|
||||
# GitHub
|
||||
'GithubOAuth',
|
||||
'GithubAccessToken',
|
||||
'GithubUserData',
|
||||
|
||||
# QQ
|
||||
'QQOAuth',
|
||||
'QQAccessToken',
|
||||
'QQOpenIDResponse',
|
||||
'QQUserData',
|
||||
]
|
||||
@@ -1,77 +1,127 @@
|
||||
from pydantic import BaseModel
|
||||
import aiohttp
|
||||
"""GitHub OAuth2.0 认证实现"""
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
class GithubAccessToken(BaseModel):
|
||||
access_token: str
|
||||
from pydantic import BaseModel
|
||||
from . import AccessTokenBase, OAuthBase, OAuthUserInfoResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import OAuthUserData
|
||||
|
||||
|
||||
class GithubAccessToken(AccessTokenBase):
|
||||
"""GitHub 访问令牌响应"""
|
||||
token_type: str
|
||||
"""令牌类型"""
|
||||
scope: str
|
||||
"""授权范围"""
|
||||
|
||||
|
||||
class GithubUserData(BaseModel):
|
||||
"""GitHub 用户数据"""
|
||||
login: str
|
||||
"""用户名"""
|
||||
id: int
|
||||
"""用户 ID"""
|
||||
node_id: str
|
||||
"""节点 ID"""
|
||||
avatar_url: str
|
||||
"""头像 URL"""
|
||||
gravatar_id: str | None
|
||||
"""Gravatar ID"""
|
||||
url: str
|
||||
"""API URL"""
|
||||
html_url: str
|
||||
"""主页 URL"""
|
||||
followers_url: str
|
||||
"""粉丝列表 URL"""
|
||||
following_url: str
|
||||
"""关注列表 URL"""
|
||||
gists_url: str
|
||||
"""Gists 列表 URL"""
|
||||
starred_url: str
|
||||
"""星标列表 URL"""
|
||||
subscriptions_url: str
|
||||
"""订阅列表 URL"""
|
||||
organizations_url: str
|
||||
"""组织列表 URL"""
|
||||
repos_url: str
|
||||
"""仓库列表 URL"""
|
||||
events_url: str
|
||||
"""事件列表 URL"""
|
||||
received_events_url: str
|
||||
"""接收的事件列表 URL"""
|
||||
type: str
|
||||
"""用户类型"""
|
||||
site_admin: bool
|
||||
"""是否为站点管理员"""
|
||||
name: str | None
|
||||
"""显示名称"""
|
||||
company: str | None
|
||||
"""公司"""
|
||||
blog: str | None
|
||||
"""博客"""
|
||||
location: str | None
|
||||
"""位置"""
|
||||
email: str | None
|
||||
"""邮箱"""
|
||||
hireable: bool | None
|
||||
"""是否可雇佣"""
|
||||
bio: str | None
|
||||
"""个人简介"""
|
||||
twitter_username: str | None
|
||||
"""Twitter 用户名"""
|
||||
public_repos: int
|
||||
"""公开仓库数"""
|
||||
public_gists: int
|
||||
"""公开 Gists 数"""
|
||||
followers: int
|
||||
"""粉丝数"""
|
||||
following: int
|
||||
created_at: str # ISO 8601 format date-time string
|
||||
updated_at: str # ISO 8601 format date-time string
|
||||
"""关注数"""
|
||||
created_at: str
|
||||
"""创建时间(ISO 8601 格式)"""
|
||||
updated_at: str
|
||||
"""更新时间(ISO 8601 格式)"""
|
||||
|
||||
|
||||
class GithubUserInfoResponse(BaseModel):
|
||||
"""GitHub 用户信息响应"""
|
||||
code: str
|
||||
"""状态码"""
|
||||
user_data: GithubUserData
|
||||
"""用户数据"""
|
||||
|
||||
async def get_access_token(code: str) -> GithubAccessToken:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url='https://github.com/login/oauth/access_token',
|
||||
params={
|
||||
'client_id': '',
|
||||
'client_secret': '',
|
||||
'code': code
|
||||
},
|
||||
headers={'accept': 'application/json'},
|
||||
) as access_resp:
|
||||
access_data = await access_resp.json()
|
||||
return GithubAccessToken(
|
||||
access_token=access_data.get('access_token'),
|
||||
token_type=access_data.get('token_type'),
|
||||
scope=access_data.get('scope')
|
||||
)
|
||||
|
||||
async def get_user_info(access_token: str | GithubAccessToken) -> GithubUserInfoResponse:
|
||||
if isinstance(access_token, GithubAccessToken):
|
||||
access_token = access_token.access_token
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url='https://api.github.com/user',
|
||||
headers={
|
||||
'accept': 'application/json',
|
||||
'Authorization': f'token {access_token}'},
|
||||
) as resp:
|
||||
user_data = await resp.json()
|
||||
return GithubUserInfoResponse(**user_data)
|
||||
class GithubOAuth(OAuthBase):
|
||||
"""GitHub OAuth2.0 客户端"""
|
||||
|
||||
access_token_url = "https://github.com/login/oauth/access_token"
|
||||
"""获取 Access Token 的 API 地址"""
|
||||
|
||||
user_info_url = "https://api.github.com/user"
|
||||
"""获取用户信息的 API 地址"""
|
||||
|
||||
http_method = "POST"
|
||||
"""获取 token 的 HTTP 方法"""
|
||||
|
||||
def _parse_token_response(self, data: dict) -> GithubAccessToken:
|
||||
"""解析 GitHub token 响应"""
|
||||
return GithubAccessToken(
|
||||
access_token=data.get('access_token'),
|
||||
token_type=data.get('token_type'),
|
||||
scope=data.get('scope'),
|
||||
)
|
||||
|
||||
def _build_user_info_headers(self, access_token: str) -> dict:
|
||||
"""构建 GitHub 用户信息请求头"""
|
||||
return {
|
||||
'accept': 'application/json',
|
||||
'Authorization': f'token {access_token}',
|
||||
}
|
||||
|
||||
def _parse_user_response(self, data: dict) -> GithubUserInfoResponse:
|
||||
"""解析 GitHub 用户信息响应"""
|
||||
return GithubUserInfoResponse(
|
||||
code='0' if data.get('login') else '1',
|
||||
user_data=GithubUserData(**data),
|
||||
)
|
||||
|
||||
@@ -1,7 +1,158 @@
|
||||
from pydantic import BaseModel
|
||||
"""QQ OAuth2.0 认证实现"""
|
||||
import aiohttp
|
||||
|
||||
async def get_access_token(
|
||||
from pydantic import BaseModel
|
||||
from . import AccessTokenBase, OAuthBase
|
||||
|
||||
|
||||
class QQAccessToken(AccessTokenBase):
|
||||
"""QQ 访问令牌响应"""
|
||||
expires_in: int
|
||||
"""access token 的有效期,单位为秒"""
|
||||
refresh_token: str
|
||||
"""用于刷新 access token 的令牌"""
|
||||
|
||||
|
||||
class QQOpenIDResponse(BaseModel):
|
||||
"""QQ OpenID 响应"""
|
||||
client_id: str
|
||||
"""应用的 appid"""
|
||||
openid: str
|
||||
"""用户的唯一标识"""
|
||||
|
||||
|
||||
class QQUserData(BaseModel):
|
||||
"""QQ 用户数据"""
|
||||
ret: int
|
||||
"""返回码,0 表示成功"""
|
||||
msg: str
|
||||
"""返回信息"""
|
||||
nickname: str | None
|
||||
"""用户昵称"""
|
||||
gender: str | None
|
||||
"""性别"""
|
||||
figureurl: str | None
|
||||
"""头像 URL"""
|
||||
figureurl_1: str | None
|
||||
"""头像 URL(大图)"""
|
||||
figureurl_2: str | None
|
||||
"""头像 URL(更大图)"""
|
||||
figureurl_qq_1: str | None
|
||||
"""QQ 头像 URL(大图)"""
|
||||
figureurl_qq_2: str | None
|
||||
"""QQ 头像 URL(更大图)"""
|
||||
is_yellow_vip: str | None
|
||||
"""是否黄钻用户"""
|
||||
vip: str | None
|
||||
"""是否 VIP 用户"""
|
||||
yellow_vip_level: str | None
|
||||
"""黄钻等级"""
|
||||
level: str | None
|
||||
"""等级"""
|
||||
is_yellow_year_vip: str | None
|
||||
"""是否年费黄钻"""
|
||||
|
||||
|
||||
class QQUserInfoResponse(BaseModel):
|
||||
"""QQ 用户信息响应"""
|
||||
code: str
|
||||
):
|
||||
...
|
||||
"""状态码"""
|
||||
openid: str
|
||||
"""用户 OpenID"""
|
||||
user_data: QQUserData
|
||||
"""用户数据"""
|
||||
|
||||
|
||||
class QQOAuth(OAuthBase):
|
||||
"""QQ OAuth2.0 客户端"""
|
||||
|
||||
access_token_url = "https://graph.qq.com/oauth2.0/token"
|
||||
"""获取 Access Token 的 API 地址"""
|
||||
|
||||
user_info_url = "https://graph.qq.com/user/get_user_info"
|
||||
"""获取用户信息的 API 地址"""
|
||||
|
||||
openid_url = "https://graph.qq.com/oauth2.0/me"
|
||||
"""获取 OpenID 的 API 地址"""
|
||||
|
||||
http_method = "GET"
|
||||
"""获取 token 的 HTTP 方法"""
|
||||
|
||||
async def get_access_token(self, code: str, redirect_uri: str) -> QQAccessToken:
|
||||
"""
|
||||
通过 Authorization Code 获取 Access Token
|
||||
|
||||
Args:
|
||||
code: 授权码
|
||||
redirect_uri: 与授权时传入的 redirect_uri 保持一致,需要 URLEncode
|
||||
|
||||
Returns:
|
||||
QQAccessToken: 访问令牌
|
||||
|
||||
文档:
|
||||
https://wiki.connect.qq.com/%E4%BD%BF%E7%94%A8authorization_code%E8%8E%B7%E5%8F%96access_token
|
||||
"""
|
||||
params = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.client_id,
|
||||
'client_secret': self.client_secret,
|
||||
'code': code,
|
||||
'redirect_uri': redirect_uri,
|
||||
'fmt': 'json',
|
||||
'need_openid': 1,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url=self.access_token_url, params=params) as access_resp:
|
||||
access_data = await access_resp.json()
|
||||
return QQAccessToken(
|
||||
access_token=access_data.get('access_token'),
|
||||
expires_in=access_data.get('expires_in'),
|
||||
refresh_token=access_data.get('refresh_token'),
|
||||
)
|
||||
|
||||
async def get_openid(self, access_token: str) -> QQOpenIDResponse:
|
||||
"""
|
||||
获取用户 OpenID
|
||||
|
||||
注意:如果在 get_access_token 时传入了 need_openid=1,响应中已包含 openid,
|
||||
无需额外调用此接口。此函数用于单独获取 openid 的场景。
|
||||
|
||||
Args:
|
||||
access_token: 访问令牌
|
||||
|
||||
Returns:
|
||||
QQOpenIDResponse: 包含 client_id 和 openid
|
||||
|
||||
文档:
|
||||
https://wiki.connect.qq.com/%E8%8E%B7%E5%8F%96%E7%94%A8%E6%88%B7openid%E7%9A%84oauth2.0%E6%8E%A5%E5%8F%A3
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
url=self.openid_url,
|
||||
params={
|
||||
'access_token': access_token,
|
||||
'fmt': 'json',
|
||||
},
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
return QQOpenIDResponse(
|
||||
client_id=data.get('client_id'),
|
||||
openid=data.get('openid'),
|
||||
)
|
||||
|
||||
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
|
||||
"""构建 QQ 用户信息请求参数"""
|
||||
return {
|
||||
'access_token': access_token,
|
||||
'oauth_consumer_key': kwargs.get('app_id', self.client_id),
|
||||
'openid': kwargs.get('openid', ''),
|
||||
}
|
||||
|
||||
def _parse_user_response(self, data: dict) -> QQUserInfoResponse:
|
||||
"""解析 QQ 用户信息响应"""
|
||||
return QQUserInfoResponse(
|
||||
code='0' if data.get('ret') == 0 else str(data.get('ret')),
|
||||
openid=data.get('openid', ''),
|
||||
user_data=QQUserData(**data),
|
||||
)
|
||||
|
||||
@@ -25,12 +25,12 @@ async def login(
|
||||
# TODO: 验证码校验
|
||||
# captcha_setting = await Setting.get(
|
||||
# session,
|
||||
# and_(Setting.type == "auth", Setting.name == "login_captcha")
|
||||
# (Setting.type == "auth") & (Setting.name == "login_captcha")
|
||||
# )
|
||||
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
||||
|
||||
# 获取用户信息
|
||||
current_user = await User.get(session, User.username == login_request.username, fetch_mode="first")
|
||||
current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore
|
||||
|
||||
# 验证用户是否存在
|
||||
if not current_user:
|
||||
|
||||
23
tests/fixtures/groups.py
vendored
23
tests/fixtures/groups.py
vendored
@@ -42,10 +42,9 @@ class GroupFactory:
|
||||
speed_limit=kwargs.get("speed_limit", 0),
|
||||
)
|
||||
|
||||
group = await group.save(session)
|
||||
|
||||
# 如果提供了选项参数,创建 GroupOptions
|
||||
if kwargs.get("create_options", False):
|
||||
group = await group.save(session, commit=False)
|
||||
options = GroupOptions(
|
||||
group_id=group.id,
|
||||
share_download=kwargs.get("share_download", True),
|
||||
@@ -55,7 +54,10 @@ class GroupFactory:
|
||||
select_node=kwargs.get("select_node", False),
|
||||
advance_delete=kwargs.get("advance_delete", False),
|
||||
)
|
||||
await options.save(session)
|
||||
await options.save(session, commit=False)
|
||||
await session.commit()
|
||||
else:
|
||||
group = await group.save(session)
|
||||
|
||||
return group
|
||||
|
||||
@@ -88,7 +90,7 @@ class GroupFactory:
|
||||
speed_limit=0,
|
||||
)
|
||||
|
||||
admin_group = await admin_group.save(session)
|
||||
admin_group = await admin_group.save(session, commit=False)
|
||||
|
||||
# 创建管理员组选项
|
||||
admin_options = GroupOptions(
|
||||
@@ -105,7 +107,8 @@ class GroupFactory:
|
||||
aria2=True,
|
||||
redirected_source=True,
|
||||
)
|
||||
await admin_options.save(session)
|
||||
await admin_options.save(session, commit=False)
|
||||
await session.commit()
|
||||
|
||||
return admin_group
|
||||
|
||||
@@ -140,7 +143,7 @@ class GroupFactory:
|
||||
speed_limit=1024, # 1MB/s
|
||||
)
|
||||
|
||||
limited_group = await limited_group.save(session)
|
||||
limited_group = await limited_group.save(session, commit=False)
|
||||
|
||||
# 创建限制组选项
|
||||
limited_options = GroupOptions(
|
||||
@@ -152,7 +155,8 @@ class GroupFactory:
|
||||
select_node=False,
|
||||
advance_delete=False,
|
||||
)
|
||||
await limited_options.save(session)
|
||||
await limited_options.save(session, commit=False)
|
||||
await session.commit()
|
||||
|
||||
return limited_group
|
||||
|
||||
@@ -185,7 +189,7 @@ class GroupFactory:
|
||||
speed_limit=512, # 512KB/s
|
||||
)
|
||||
|
||||
free_group = await free_group.save(session)
|
||||
free_group = await free_group.save(session, commit=False)
|
||||
|
||||
# 创建免费组选项
|
||||
free_options = GroupOptions(
|
||||
@@ -197,6 +201,7 @@ class GroupFactory:
|
||||
select_node=False,
|
||||
advance_delete=False,
|
||||
)
|
||||
await free_options.save(session)
|
||||
await free_options.save(session, commit=False)
|
||||
await session.commit()
|
||||
|
||||
return free_group
|
||||
|
||||
@@ -13,7 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(
|
||||
refreshUrl="/api/v1/user/session/refresh",
|
||||
)
|
||||
|
||||
SECRET_KEY = ''
|
||||
SECRET_KEY: str = ''
|
||||
|
||||
|
||||
async def load_secret_key() -> None:
|
||||
@@ -26,10 +26,10 @@ async def load_secret_key() -> None:
|
||||
|
||||
global SECRET_KEY
|
||||
async for session in get_session():
|
||||
setting = await Setting.get(
|
||||
setting: Setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == "auth") & (Setting.name == "secret_key")
|
||||
)
|
||||
) # type: ignore
|
||||
if setting:
|
||||
SECRET_KEY = setting.value
|
||||
|
||||
@@ -40,7 +40,14 @@ def build_token_payload(
|
||||
algorithm: str,
|
||||
expires_delta: timedelta | None = None,
|
||||
) -> tuple[str, datetime]:
|
||||
"""构建令牌"""
|
||||
"""
|
||||
构建令牌。
|
||||
|
||||
:param data: 需要放进 JWT Payload 的字段
|
||||
:param is_refresh: 是否为刷新令牌
|
||||
:param algorithm: JWT 签名算法
|
||||
:param expires_delta: 过期时间
|
||||
"""
|
||||
|
||||
to_encode = data.copy()
|
||||
|
||||
@@ -61,8 +68,11 @@ def build_token_payload(
|
||||
|
||||
|
||||
# 访问令牌
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256") -> AccessTokenBase:
|
||||
def create_access_token(
|
||||
data: dict,
|
||||
expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256"
|
||||
) -> AccessTokenBase:
|
||||
"""
|
||||
生成访问令牌,默认有效期 3 小时。
|
||||
|
||||
@@ -73,7 +83,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None,
|
||||
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
|
||||
"""
|
||||
|
||||
access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta)
|
||||
access_token, expire_at = build_token_payload(
|
||||
data,
|
||||
False,
|
||||
algorithm,
|
||||
expires_delta
|
||||
)
|
||||
return AccessTokenBase(
|
||||
access_token=access_token,
|
||||
access_expires=expire_at,
|
||||
@@ -81,8 +96,11 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None,
|
||||
|
||||
|
||||
# 刷新令牌
|
||||
def create_refresh_token(data: dict, expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256") -> RefreshTokenBase:
|
||||
def create_refresh_token(
|
||||
data: dict,
|
||||
expires_delta: timedelta | None = None,
|
||||
algorithm: str = "HS256"
|
||||
) -> RefreshTokenBase:
|
||||
"""
|
||||
生成刷新令牌,默认有效期 30 天。
|
||||
|
||||
@@ -93,7 +111,12 @@ def create_refresh_token(data: dict, expires_delta: timedelta | None = None,
|
||||
:return: 包含密钥本身和过期时间的 `RefreshTokenBase`
|
||||
"""
|
||||
|
||||
refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta)
|
||||
refresh_token, expire_at = build_token_payload(
|
||||
data,
|
||||
True,
|
||||
algorithm,
|
||||
expires_delta
|
||||
)
|
||||
return RefreshTokenBase(
|
||||
refresh_token=refresh_token,
|
||||
refresh_expires=expire_at,
|
||||
|
||||
@@ -13,7 +13,7 @@ license_info = {"name": "GPLv3", "url": "https://opensource.org/license/gpl-3.0"
|
||||
BackendVersion = "0.0.1"
|
||||
"""后端版本"""
|
||||
|
||||
IsPro = False
|
||||
IsPro: bool = False
|
||||
|
||||
mode: str = os.getenv('MODE', 'master')
|
||||
"""运行模式"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import secrets
|
||||
from typing import Literal
|
||||
|
||||
from loguru import logger
|
||||
from argon2 import PasswordHasher
|
||||
@@ -11,7 +12,23 @@ from pydantic import BaseModel, Field
|
||||
from utils.JWT import SECRET_KEY
|
||||
from utils.conf import appmeta
|
||||
|
||||
_ph = PasswordHasher()
|
||||
# FIRST RECOMMENDED option per RFC 9106.
|
||||
_ph_lowmem = PasswordHasher(
|
||||
salt_len=16,
|
||||
hash_len=32,
|
||||
time_cost=3,
|
||||
memory_cost=65536, # 64 MiB
|
||||
parallelism=4,
|
||||
)
|
||||
|
||||
# SECOND RECOMMENDED option per RFC 9106.
|
||||
_ph_highmem = PasswordHasher(
|
||||
salt_len=16,
|
||||
hash_len=32,
|
||||
time_cost=1,
|
||||
memory_cost=2097152, # 2 GiB
|
||||
parallelism=4,
|
||||
)
|
||||
|
||||
class PasswordStatus(StrEnum):
|
||||
"""密码校验状态枚举"""
|
||||
@@ -48,7 +65,8 @@ class Password:
|
||||
|
||||
@staticmethod
|
||||
def generate(
|
||||
length: int = 8
|
||||
length: int = 8,
|
||||
url_safe: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
生成指定长度的随机密码。
|
||||
@@ -58,7 +76,16 @@ class Password:
|
||||
:return: 随机密码
|
||||
:rtype: str
|
||||
"""
|
||||
return secrets.token_hex(length)
|
||||
if url_safe:
|
||||
return secrets.token_urlsafe(length)
|
||||
else:
|
||||
return secrets.token_hex(length)
|
||||
|
||||
@staticmethod
|
||||
def generate_hex(
|
||||
length: int = 8
|
||||
) -> bytes:
|
||||
return secrets.token_bytes(length)
|
||||
|
||||
@staticmethod
|
||||
def hash(
|
||||
@@ -72,7 +99,7 @@ class Password:
|
||||
:param password: 需要哈希的原始密码
|
||||
:return: Argon2 哈希字符串
|
||||
"""
|
||||
return _ph.hash(password)
|
||||
return _ph_lowmem.hash(password)
|
||||
|
||||
@staticmethod
|
||||
def verify(
|
||||
@@ -87,21 +114,16 @@ class Password:
|
||||
:return: 如果密码匹配返回 True, 否则返回 False
|
||||
"""
|
||||
try:
|
||||
# verify 函数会自动解析 stored_password 中的盐和参数
|
||||
_ph.verify(hash, password)
|
||||
_ph_lowmem.verify(hash, password)
|
||||
|
||||
# 检查哈希参数是否已过时。如果返回True,
|
||||
# 意味着你应该使用新的参数重新哈希密码并更新存储。
|
||||
# 这是一个很好的实践,可以随着时间推移增强安全性。
|
||||
if _ph.check_needs_rehash(hash):
|
||||
# 检查哈希参数是否已过时
|
||||
if _ph_lowmem.check_needs_rehash(hash):
|
||||
logger.warning("密码哈希参数已过时,建议重新哈希并更新。")
|
||||
return PasswordStatus.EXPIRED
|
||||
|
||||
return PasswordStatus.VALID
|
||||
except VerifyMismatchError:
|
||||
# 这是预期的异常,当密码不匹配时触发。
|
||||
return PasswordStatus.INVALID
|
||||
# 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
|
||||
|
||||
@staticmethod
|
||||
async def generate_totp(
|
||||
|
||||
Reference in New Issue
Block a user