Refactor and enhance OAuth2.0 implementation; update models and routes

- Refactored AdminSummaryData and AdminSummaryResponse classes for better clarity.
- Added OAUTH type to SettingsType enum.
- Cleaned up imports in webdav.py.
- Updated admin router to improve summary data retrieval and response handling.
- Enhanced file management routes with better condition handling and user storage updates.
- Improved group management routes by optimizing data retrieval.
- Refined task management routes for better condition handling.
- Updated user management routes to streamline access token retrieval.
- Implemented a new captcha verification structure with abstract base class.
- Removed deprecated env.md file and replaced with a new structured version.
- Introduced a unified OAuth2.0 client base class for GitHub and QQ integrations.
- Enhanced password management with improved hashing strategies.
- Added detailed comments and documentation throughout the codebase for clarity.
This commit is contained in:
2026-01-12 18:07:44 +08:00
parent 61ddc96f17
commit d2c914cff8
29 changed files with 814 additions and 4609 deletions

View File

@@ -0,0 +1,9 @@
{
"permissions": {
"allow": [
"Bash(git rev-parse:*)",
"Bash(findstr:*)",
"Bash(find:*)"
]
}
}

4407
.xml

File diff suppressed because it is too large Load Diff

View File

@@ -9,4 +9,4 @@
- `REDIS_PORT`: Redis 端口 - `REDIS_PORT`: Redis 端口
- `REDIS_PASSWORD`: Redis 密码 - `REDIS_PASSWORD`: Redis 密码
- `REDIS_DB`: Redis 数据库 - `REDIS_DB`: Redis 数据库
- `REDIS_PROTOCOL` - `REDIS_PROTOCOL`: Redis 协议

16
main.py
View File

@@ -33,8 +33,22 @@ app = FastAPI(
openapi_url="/openapi.json" if appmeta.debug else None, openapi_url="/openapi.json" if appmeta.debug else None,
) )
# 添加跨域 CORS 中间件,仅在调试模式下启用,以允许所有来源访问 API
if appmeta.debug:
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def handle_unexpected_exceptions(request: Request, exc: Exception) -> NoReturn: async def handle_unexpected_exceptions(
request: Request,
exc: Exception
) -> NoReturn:
""" """
捕获所有未经处理的 FastAPI 异常,防止敏感信息泄露。 捕获所有未经处理的 FastAPI 异常,防止敏感信息泄露。
""" """

View File

@@ -60,10 +60,10 @@ def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
try: try:
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"]) payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
if payload.get("type") != "download": if payload.get("type") != "download":
return None http_exceptions.raise_unauthorized("Download token required")
jti = payload.get("jti") jti = payload.get("jti")
if not jti: if not jti:
return None http_exceptions.raise_unauthorized("Download token required")
return jti, UUID(payload["file_id"]), UUID(payload["owner_id"]) return jti, UUID(payload["file_id"]), UUID(payload["owner_id"])
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): except jwt.InvalidTokenError:
return None http_exceptions.raise_unauthorized("Download token required")

View File

@@ -120,7 +120,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
async def init_default_settings() -> None: async def init_default_settings() -> None:
from .setting import Setting from .setting import Setting
from .database import get_session from .database import get_session
from sqlalchemy import and_
log.info('初始化设置...') log.info('初始化设置...')
@@ -128,7 +127,7 @@ async def init_default_settings() -> None:
# 检查是否已经存在版本设置 # 检查是否已经存在版本设置
ver = await Setting.get( ver = await Setting.get(
session, session,
and_(Setting.type == SettingsType.VERSION, Setting.name == f"db_version_{BackendVersion}") (Setting.type == SettingsType.VERSION) & (Setting.name == f"db_version_{BackendVersion}")
) )
if ver and ver.value == "installed": if ver and ver.value == "installed":
return return

View File

@@ -218,7 +218,7 @@ class TableBaseMixin(AsyncAttrs):
) )
@classmethod @classmethod
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]: async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]:
""" """
向数据库中添加一个新的或多个新的记录. 向数据库中添加一个新的或多个新的记录.
@@ -230,6 +230,8 @@ class TableBaseMixin(AsyncAttrs):
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要添加的单个模型实例或模型实例列表. instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True. refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
之后需要手动调用 `session.commit()`。默认为 True.
Returns: Returns:
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例. T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
@@ -244,6 +246,11 @@ class TableBaseMixin(AsyncAttrs):
# 添加单个实例 # 添加单个实例
item3 = Item(name="Cherry") item3 = Item(name="Cherry")
added_item = await Item.add(session, item3) added_item = await Item.add(session, item3)
# 批量操作,减少提交次数
await Item.add(session, [item1, item2], commit=False)
await Item.add(session, [item3, item4], commit=False)
await session.commit()
""" """
is_list = False is_list = False
if isinstance(instances, list): if isinstance(instances, list):
@@ -252,7 +259,10 @@ class TableBaseMixin(AsyncAttrs):
else: else:
session.add(instances) session.add(instances)
await session.commit() if commit:
await session.commit()
else:
await session.flush()
if refresh: if refresh:
if is_list: if is_list:
@@ -266,15 +276,16 @@ class TableBaseMixin(AsyncAttrs):
async def save( async def save(
self: T, self: T,
session: AsyncSession, session: AsyncSession,
load: RelationshipInfo | None = None, load: RelationshipInfo | list[RelationshipInfo] | None = None,
refresh: bool = True refresh: bool = True,
commit: bool = True
) -> T: ) -> T:
""" """
保存(插入或更新)当前模型实例到数据库. 保存(插入或更新)当前模型实例到数据库.
这是一个实例方法,它将当前对象添加到会话中并提交更改。 这是一个实例方法,它将当前对象添加到会话中并提交更改。
可以用于创建新记录或更新现有记录。还可以选择在保存后 可以用于创建新记录或更新现有记录。还可以选择在保存后
预加载eager load一个关联关系. 预加载eager load一个或多个关联关系.
**重要**调用此方法后session中的所有对象都会过期expired **重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象,必须使用返回值: 如果需要继续使用该对象,必须使用返回值:
@@ -287,6 +298,14 @@ class TableBaseMixin(AsyncAttrs):
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.save(session, refresh=False) await client.save(session, refresh=False)
# ✅ 正确:批量操作,减少提交次数
await item1.save(session, commit=False)
await item2.save(session, commit=False)
await session.commit()
# ✅ 正确:批量操作并预加载多个关联关系
user = await user.save(session, load=[User.group, User.tags])
# ❌ 错误:需要返回值但未使用 # ❌ 错误:需要返回值但未使用
await client.save(session) await client.save(session)
return client # client 对象已过期 return client # client 对象已过期
@@ -294,16 +313,22 @@ class TableBaseMixin(AsyncAttrs):
Args: Args:
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性. load (Relationship | list[Relationship] | None): 可选的,指定在保存和刷新后要预加载的关联属性.
例如 `User.posts`. 可以是单个关系或关系列表.
例如 `User.posts` 或 `[User.group, User.tags]`.
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值, refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True. 设为 False 可节省一次数据库查询。默认为 True.
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
之后需要手动调用 `session.commit()`。默认为 True.
Returns: Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self. T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
""" """
session.add(self) session.add(self)
await session.commit() if commit:
await session.commit()
else:
await session.flush()
if not refresh: if not refresh:
return self return self
@@ -324,8 +349,9 @@ class TableBaseMixin(AsyncAttrs):
extra_data: dict[str, Any] | None = None, extra_data: dict[str, Any] | None = None,
exclude_unset: bool = True, exclude_unset: bool = True,
exclude: set[str] | None = None, exclude: set[str] | None = None,
load: RelationshipInfo | None = None, load: RelationshipInfo | list[RelationshipInfo] | None = None,
refresh: bool = True refresh: bool = True,
commit: bool = True
) -> T: ) -> T:
""" """
使用另一个模型实例或字典中的数据来更新当前实例. 使用另一个模型实例或字典中的数据来更新当前实例.
@@ -348,6 +374,14 @@ class TableBaseMixin(AsyncAttrs):
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.update(session, update_data, refresh=False) await client.update(session, update_data, refresh=False)
# ✅ 正确:批量操作,减少提交次数
await user1.update(session, data1, commit=False)
await user2.update(session, data2, commit=False)
await session.commit()
# ✅ 正确:批量操作并预加载多个关联关系
user = await user.update(session, data, load=[User.group, User.tags])
# ❌ 错误:需要返回值但未使用 # ❌ 错误:需要返回值但未使用
await client.update(session, update_data) await client.update(session, update_data)
return client # client 对象已过期 return client # client 对象已过期
@@ -360,10 +394,13 @@ class TableBaseMixin(AsyncAttrs):
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供) exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
的字段将被忽略. 默认为 True. 的字段将被忽略. 默认为 True.
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}. exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性. load (Relationship | list[Relationship] | None): 可选的,指定在更新和刷新后要预加载的关联属性.
例如 `User.permission`. 可以是单个关系或关系列表.
例如 `User.permission` 或 `[User.group, User.tags]`.
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值, refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True. 设为 False 可节省一次数据库查询。默认为 True.
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
之后需要手动调用 `session.commit()`。默认为 True.
Returns: Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self. T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
@@ -374,7 +411,10 @@ class TableBaseMixin(AsyncAttrs):
) )
session.add(self) session.add(self)
await session.commit() if commit:
await session.commit()
else:
await session.flush()
if not refresh: if not refresh:
return self return self
@@ -388,33 +428,82 @@ class TableBaseMixin(AsyncAttrs):
return self return self
@classmethod @classmethod
async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None: async def delete(
cls: type[T],
session: AsyncSession,
instances: T | list[T] | None = None,
*,
condition: BinaryExpression | ClauseElement | None = None,
commit: bool = True
) -> int:
""" """
从数据库中删除一个或多个记录. 从数据库中删除记录.
支持两种删除方式:
1. 实例删除:传入 instances 参数,先加载再删除
2. 条件删除:传入 condition 参数,直接 SQL 删除(更高效)
Args: Args:
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要删除的单个模型实例或模型实例列表. instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表(可选).
condition (BinaryExpression | ClauseElement | None): 删除条件(可选,与 instances 二选一).
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
之后需要手动调用 `session.commit()`。默认为 True.
Returns: Returns:
None int: 删除的记录数量
Usage: Usage:
# 实例删除
item_to_delete = await Item.get(session, Item.id == 1) item_to_delete = await Item.get(session, Item.id == 1)
if item_to_delete: if item_to_delete:
await Item.delete(session, item_to_delete) deleted_count = await Item.delete(session, item_to_delete)
items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all") # 条件删除(更高效,无需加载实例)
if items_to_delete: deleted_count = await Item.delete(
await Item.delete(session, items_to_delete) session,
condition=(Item.status == "inactive") & (Item.created_at < cutoff_date)
)
# 批量删除后手动提交
await Item.delete(session, item1, commit=False)
await Item.delete(session, item2, commit=False)
await session.commit()
""" """
# 条件删除模式
if condition is not None:
from sqlmodel import delete as sql_delete
if instances is not None:
raise ValueError("不能同时指定 instances 和 condition")
# 执行条件删除
stmt = sql_delete(cls).where(condition)
result = await session.exec(stmt)
deleted_count = result.rowcount
if commit:
await session.commit()
return deleted_count
# 实例删除模式(原有逻辑)
if instances is None:
raise ValueError("必须指定 instances 或 condition")
deleted_count = 0
if isinstance(instances, list): if isinstance(instances, list):
for instance in instances: for instance in instances:
await session.delete(instance) await session.delete(instance)
deleted_count += 1
else: else:
await session.delete(instances) await session.delete(instances)
deleted_count = 1
await session.commit() if commit:
await session.commit()
return deleted_count
@classmethod @classmethod
def _build_time_filters( def _build_time_filters(
@@ -458,7 +547,7 @@ class TableBaseMixin(AsyncAttrs):
fetch_mode: Literal["one", "first", "all"] = "first", fetch_mode: Literal["one", "first", "all"] = "first",
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None, join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None, options: list | None = None,
load: RelationshipInfo | None = None, load: RelationshipInfo | list[RelationshipInfo] | None = None,
order_by: list[ClauseElement] | None = None, order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None, filter: BinaryExpression | ClauseElement | None = None,
with_for_update: bool = False, with_for_update: bool = False,
@@ -491,8 +580,9 @@ class TableBaseMixin(AsyncAttrs):
例如 `User` 或 `(Profile, User.id == Profile.user_id)`. 例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据, options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
例如 `[selectinload(User.posts)]`. 例如 `[selectinload(User.posts)]`.
load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系. load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
例如 `User.profile`. 可以是单个关系或关系列表.
例如 `User.profile` 或 `[User.group, User.tags]`.
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表. order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
例如 `[User.name.asc(), User.created_at.desc()]`. 例如 `[User.name.asc(), User.created_at.desc()]`.
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件. filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
@@ -595,9 +685,15 @@ class TableBaseMixin(AsyncAttrs):
statement = statement.options(*options) statement = statement.options(*options)
if load: if load:
# 标准化为列表
load_list = load if isinstance(load, list) else [load]
# 处理多态加载 # 处理多态加载
if load_polymorphic is not None: if load_polymorphic is not None:
target_class = load.property.mapper.class_ # 多态加载只支持单个关系
if len(load_list) > 1:
raise ValueError("load_polymorphic 仅支持单个关系")
target_class = load_list[0].property.mapper.class_
# 检查目标类是否继承自 PolymorphicBaseMixin # 检查目标类是否继承自 PolymorphicBaseMixin
if not issubclass(target_class, PolymorphicBaseMixin): if not issubclass(target_class, PolymorphicBaseMixin):
@@ -609,7 +705,7 @@ class TableBaseMixin(AsyncAttrs):
if load_polymorphic == 'all': if load_polymorphic == 'all':
# 两阶段查询:获取实际关联的多态类型 # 两阶段查询:获取实际关联的多态类型
subclasses_to_load = await cls._resolve_polymorphic_subclasses( subclasses_to_load = await cls._resolve_polymorphic_subclasses(
session, condition, load, target_class session, condition, load_list[0], target_class
) )
else: else:
subclasses_to_load = load_polymorphic subclasses_to_load = load_polymorphic
@@ -618,12 +714,14 @@ class TableBaseMixin(AsyncAttrs):
# 关键selectin_polymorphic 必须作为 selectinload 的链式子选项 # 关键selectin_polymorphic 必须作为 selectinload 的链式子选项
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading # 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
statement = statement.options( statement = statement.options(
selectinload(load).selectin_polymorphic(subclasses_to_load) selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load)
) )
else: else:
statement = statement.options(selectinload(load)) statement = statement.options(selectinload(load_list[0]))
else: else:
statement = statement.options(selectinload(load)) # 为每个关系添加 selectinload
for rel in load_list:
statement = statement.options(selectinload(rel))
if order_by is not None: if order_by is not None:
statement = statement.order_by(*order_by) statement = statement.order_by(*order_by)
@@ -796,7 +894,7 @@ class TableBaseMixin(AsyncAttrs):
*, *,
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None, join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None, options: list | None = None,
load: RelationshipInfo | None = None, load: RelationshipInfo | list[RelationshipInfo] | None = None,
order_by: list[ClauseElement] | None = None, order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None, filter: BinaryExpression | ClauseElement | None = None,
table_view: TableViewRequest | None = None, table_view: TableViewRequest | None = None,
@@ -865,7 +963,7 @@ class TableBaseMixin(AsyncAttrs):
return ListResponse(count=total_count, items=items) return ListResponse(count=total_count, items=items)
@classmethod @classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T: async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T:
""" """
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常. 根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
@@ -875,7 +973,8 @@ class TableBaseMixin(AsyncAttrs):
Args: Args:
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
id (int): 要查找的记录的主键 ID. id (int): 要查找的记录的主键 ID.
load (Relationship | None): 可选的,用于预加载的关联属性. load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
可以是单个关系或关系列表.
Returns: Returns:
T: 找到的模型实例. T: 找到的模型实例.
@@ -903,7 +1002,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
@override @override
@classmethod @classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T: async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T:
""" """
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常. 根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
@@ -913,7 +1012,8 @@ class UUIDTableBaseMixin(TableBaseMixin):
Args: Args:
session (AsyncSession): 用于数据库操作的异步会话对象. session (AsyncSession): 用于数据库操作的异步会话对象.
id (uuid.UUID): 要查找的记录的 UUID 主键. id (uuid.UUID): 要查找的记录的 UUID 主键.
load (Relationship | None): 可选的,用于预加载的关联属性. load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
可以是单个关系或关系列表.
Returns: Returns:
T: 找到的模型实例. T: 找到的模型实例.

View File

@@ -79,9 +79,8 @@ class VersionInfo(SQLModelBase):
commit: str commit: str
"""提交哈希""" """提交哈希"""
class AdminSummaryResponse(ResponseBase):
class AdminSummaryData(SQLModelBase): """管理员概况响应"""
"""管理员概况数据"""
metrics_summary: MetricsSummary metrics_summary: MetricsSummary
"""统计摘要""" """统计摘要"""
@@ -95,13 +94,6 @@ class AdminSummaryData(SQLModelBase):
version: VersionInfo version: VersionInfo
"""版本信息""" """版本信息"""
class AdminSummaryResponse(ResponseBase):
"""管理员概况响应"""
data: AdminSummaryData | None = None
"""响应数据"""
class MCPMethod(StrEnum): class MCPMethod(StrEnum):
"""MCP 方法枚举""" """MCP 方法枚举"""

View File

@@ -104,6 +104,7 @@ class SettingsType(StrEnum):
MAIL = "mail" MAIL = "mail"
MAIL_TEMPLATE = "mail_template" MAIL_TEMPLATE = "mail_template"
MOBILE = "mobile" MOBILE = "mobile"
OAUTH = "oauth"
PATH = "path" PATH = "path"
PREVIEW = "preview" PREVIEW = "preview"
PWA = "pwa" PWA = "pwa"

View File

@@ -2,7 +2,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from uuid import UUID from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime from sqlmodel import Field, Relationship, UniqueConstraint
from .base import SQLModelBase from .base import SQLModelBase
from .mixin import TableBaseMixin from .mixin import TableBaseMixin

View File

@@ -2,14 +2,12 @@ from datetime import datetime, timedelta
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from loguru import logger as l from loguru import logger as l
from sqlalchemy import and_
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ( from models import (
User, ResponseBase, User, ResponseBase,
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo, Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
AdminSummaryData,
) )
from models.base import SQLModelBase from models.base import SQLModelBase
from models.setting import ( from models.setting import (
@@ -75,8 +73,8 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
Returns: Returns:
AdminSummaryResponse: 包含站点概况信息的响应模型。 AdminSummaryResponse: 包含站点概况信息的响应模型。
""" """
# 统计最近 12 天的数据 # 统计最近 14 天的数据
days_count = 12 days_count = 14
now = datetime.now() now = datetime.now()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
@@ -135,7 +133,7 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
site_urls: list[str] = [] site_urls: list[str] = []
site_url_setting = await Setting.get( site_url_setting = await Setting.get(
session, session,
and_(Setting.type == SettingsType.BASIC, Setting.name == "siteURL"), (Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
) )
if site_url_setting and site_url_setting.value: if site_url_setting and site_url_setting.value:
site_urls.append(site_url_setting.value) site_urls.append(site_url_setting.value)
@@ -156,15 +154,13 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
commit="dev", commit="dev",
) )
data = AdminSummaryData( return AdminSummaryResponse(
metrics_summary=metrics_summary, metrics_summary=metrics_summary,
site_urls=site_urls, site_urls=site_urls,
license=license_info, license=license_info,
version=version_info, version=version_info,
) )
return AdminSummaryResponse(data=data)
@admin_router.get( @admin_router.get(
path='/news', path='/news',
summary='获取社区新闻', summary='获取社区新闻',
@@ -203,7 +199,7 @@ async def router_admin_update_settings(
for item in request.settings: for item in request.settings:
existing = await Setting.get( existing = await Setting.get(
session, session,
and_(Setting.type == item.type, Setting.name == item.name) (Setting.type == item.type) & (Setting.name == item.name)
) )
if existing: if existing:
@@ -245,7 +241,12 @@ async def router_admin_get_settings(
if name: if name:
conditions.append(Setting.name == name) conditions.append(Setting.name == name)
condition = and_(*conditions) if conditions else None if conditions:
condition = conditions[0]
for c in conditions[1:]:
condition = condition & c
else:
condition = None
settings = await Setting.get(session, condition, fetch_mode="all") settings = await Setting.get(session, condition, fetch_mode="all")

View File

@@ -5,7 +5,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from loguru import logger as l from loguru import logger as l
from sqlalchemy import and_
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
@@ -51,7 +50,12 @@ async def router_admin_get_file_list(
if keyword: if keyword:
conditions.append(Object.name.ilike(f"%{keyword}%")) conditions.append(Object.name.ilike(f"%{keyword}%"))
condition = and_(*conditions) if len(conditions) > 1 else conditions[0] if len(conditions) > 1:
condition = conditions[0]
for c in conditions[1:]:
condition = condition & c
else:
condition = conditions[0]
result = await Object.get_with_count(session, condition, table_view=table_view, load=Object.owner) result = await Object.get_with_count(session, condition, table_view=table_view, load=Object.owner)
# 构建响应 # 构建响应
@@ -197,13 +201,15 @@ async def router_admin_delete_file(
except Exception as e: except Exception as e:
l.warning(f"删除物理文件失败: {e}") l.warning(f"删除物理文件失败: {e}")
# 更新用户存储量 # 更新用户存储量(使用 SQL UPDATE 直接更新,无需加载实例)
owner = await User.get(session, User.id == owner_id) from sqlmodel import update as sql_update
if owner: stmt = sql_update(User).where(User.id == owner_id).values(
owner.storage = max(0, owner.storage - file_size) storage=max(0, User.storage - file_size)
await owner.save(session) )
await session.exec(stmt)
await Object.delete(session, file_obj) # 使用条件删除
await Object.delete(session, condition=Object.id == file_obj.id)
l.info(f"管理员删除了文件: {file_name}") l.info(f"管理员删除了文件: {file_name}")
return ResponseBase(data={"deleted": True}) return ResponseBase(data={"deleted": True})

View File

@@ -63,12 +63,13 @@ async def router_admin_get_group(
:param group_id: 用户组UUID :param group_id: 用户组UUID
:return: 用户组详情 :return: 用户组详情
""" """
group = await Group.get(session, Group.id == group_id, load=Group.options) group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
if not group: if not group:
raise HTTPException(status_code=404, detail="用户组不存在") raise HTTPException(status_code=404, detail="用户组不存在")
policies = await group.awaitable_attrs.policies # 直接访问已加载的关系,无需额外查询
policies = group.policies
user_count = await User.count(session, User.group_id == group_id) user_count = await User.count(session, User.group_id == group_id)
response = GroupDetailResponse.from_group(group, user_count, policies) response = GroupDetailResponse.from_group(group, user_count, policies)

View File

@@ -2,7 +2,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l from loguru import logger as l
from sqlalchemy import and_
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
@@ -43,7 +42,12 @@ async def router_admin_get_task_list(
if status: if status:
conditions.append(Task.status == status) conditions.append(Task.status == status)
condition = and_(*conditions) if conditions else None if conditions:
condition = conditions[0]
for c in conditions[1:]:
condition = condition & c
else:
condition = None
result = await Task.get_with_count(session, condition, table_view=table_view, load=Task.user) result = await Task.get_with_count(session, condition, table_view=table_view, load=Task.user)
items: list[TaskSummary] = [] items: list[TaskSummary] = []

View File

@@ -2,7 +2,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l from loguru import logger as l
from sqlalchemy import func, and_ from sqlalchemy import func
from middleware.auth import admin_required from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep from middleware.dependencies import SessionDep, TableViewRequestDep
@@ -198,7 +198,7 @@ async def router_admin_calibrate_storage(
from sqlmodel import select from sqlmodel import select
result = await session.execute( result = await session.execute(
select(func.sum(Object.size), func.count(Object.id)).where( select(func.sum(Object.size), func.count(Object.id)).where(
and_(Object.owner_id == user_id, Object.type == ObjectType.FILE) (Object.owner_id == user_id) & (Object.type == ObjectType.FILE)
) )
) )
row = result.one() row = result.one()

View File

@@ -233,7 +233,7 @@ async def upload_chunk(
policy_id=upload_session.policy_id, policy_id=upload_session.policy_id,
reference_count=1, reference_count=1,
) )
physical_file = await physical_file.save(session) physical_file = await physical_file.save(session, commit=False)
# 创建 Object 记录 # 创建 Object 记录
file_object = Object( file_object = Object(
@@ -246,11 +246,18 @@ async def upload_chunk(
owner_id=user_id, owner_id=user_id,
policy_id=upload_session.policy_id, policy_id=upload_session.policy_id,
) )
file_object = await file_object.save(session) file_object = await file_object.save(session, commit=False)
file_object_id = file_object.id file_object_id = file_object.id
# 删除上传会话 # 删除上传会话(使用条件删除)
await UploadSession.delete(session, upload_session) await UploadSession.delete(
session,
condition=UploadSession.id == upload_session.id,
commit=False
)
# 统一提交所有更改
await session.commit()
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}") l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")

View File

@@ -1,5 +1,4 @@
from fastapi import APIRouter from fastapi import APIRouter
from sqlalchemy import and_
from middleware.dependencies import SessionDep from middleware.dependencies import SessionDep
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
@@ -55,5 +54,5 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
dict: The site configuration. dict: The site configuration.
""" """
return SiteConfigResponse( return SiteConfigResponse(
title=await Setting.get(session, and_(Setting.type == SettingsType.BASIC, Setting.name == "siteName")), title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
) )

View File

@@ -3,7 +3,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy import and_
from webauthn import generate_registration_options from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict from webauthn.helpers import options_to_json_dict
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
@@ -115,7 +114,7 @@ async def router_user_register(
# 2. 获取默认用户组(从设置中读取 UUID # 2. 获取默认用户组(从设置中读取 UUID
default_group_setting: models.Setting | None = await models.Setting.get( default_group_setting: models.Setting | None = await models.Setting.get(
session, session,
and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group") (models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group")
) )
if default_group_setting is None or not default_group_setting.value: if default_group_setting is None or not default_group_setting.value:
logger.error("默认用户组不存在") logger.error("默认用户组不存在")
@@ -352,18 +351,18 @@ async def router_user_authn_start(
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等 # TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
authn_setting = await models.Setting.get( authn_setting = await models.Setting.get(
session, session,
and_(models.Setting.type == "authn", models.Setting.name == "authn_enabled") (models.Setting.type == "authn") & (models.Setting.name == "authn_enabled")
) )
if not authn_setting or authn_setting.value != "1": if not authn_setting or authn_setting.value != "1":
raise HTTPException(status_code=400, detail="WebAuthn is not enabled") raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
site_url_setting = await models.Setting.get( site_url_setting = await models.Setting.get(
session, session,
and_(models.Setting.type == "basic", models.Setting.name == "siteURL") (models.Setting.type == "basic") & (models.Setting.name == "siteURL")
) )
site_title_setting = await models.Setting.get( site_title_setting = await models.Setting.get(
session, session,
and_(models.Setting.type == "basic", models.Setting.name == "siteTitle") (models.Setting.type == "basic") & (models.Setting.name == "siteTitle")
) )
options = generate_registration_options( options = generate_registration_options(

View File

@@ -1,5 +1,39 @@
import abc
import aiohttp
from pydantic import BaseModel from pydantic import BaseModel
from .gcaptcha import GCaptcha
from .turnstile import TurnstileCaptcha
class CaptchaRequestBase(BaseModel): class CaptchaRequestBase(BaseModel):
"""验证码验证请求"""
token: str token: str
secret: str """验证 token"""
secret: str
"""验证密钥"""
class CaptchaBase(abc.ABC):
"""验证码验证器抽象基类"""
verify_url: str
"""验证 API 地址(子类必须定义)"""
async def verify_captcha(self, request: CaptchaRequestBase) -> bool:
"""
验证 token 是否有效。
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(self.verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)

View File

@@ -1,21 +1,7 @@
import aiohttp from . import CaptchaBase
from . import CaptchaRequestBase
async def verify_captcha(request: CaptchaRequestBase) -> bool: class GCaptcha(CaptchaBase):
""" """Google reCAPTCHA v2/v3 验证器"""
验证 Google reCAPTCHA v2/v3 的 token 是否有效。
verify_url = "https://www.google.com/recaptcha/api/siteverify"
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
verify_url = "https://www.google.com/recaptcha/api/siteverify"
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)

View File

@@ -1,21 +1,7 @@
import aiohttp from . import CaptchaBase
from . import CaptchaRequestBase
async def verify_captcha(request: CaptchaRequestBase) -> bool: class TurnstileCaptcha(CaptchaBase):
""" """Cloudflare Turnstile 验证器"""
验证 Turnstile 的 token 是否有效。
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)

223
service/oauth/__init__.py Normal file
View File

@@ -0,0 +1,223 @@
"""
OAuth2.0 认证模块
提供统一的 OAuth2.0 客户端基类,支持多种第三方登录平台。
"""
import abc
import aiohttp
from pydantic import BaseModel
# ==================== 共享数据模型 ====================
class AccessTokenBase(BaseModel):
"""访问令牌基类"""
access_token: str
"""访问令牌"""
class OAuthUserData(BaseModel):
"""OAuth 用户数据通用 DTO"""
openid: str
"""用户唯一标识GitHub 为 idQQ 为 openid"""
nickname: str | None
"""用户昵称"""
avatar_url: str | None
"""头像 URL"""
email: str | None
"""邮箱"""
bio: str | None
"""个人简介"""
class OAuthUserInfoResponse(BaseModel):
"""OAuth 用户信息响应"""
code: str
"""状态码"""
openid: str
"""用户唯一标识"""
user_data: OAuthUserData
"""用户数据"""
# ==================== OAuth2.0 抽象基类 ====================
class OAuthBase(abc.ABC):
"""
OAuth2.0 客户端抽象基类
子类需要定义以下类属性:
- access_token_url: 获取 Access Token 的 API 地址
- user_info_url: 获取用户信息的 API 地址
- http_method: 获取 token 的 HTTP 方法POST 或 GET
"""
# 子类必须定义的类属性
access_token_url: str
"""获取 Access Token 的 API 地址"""
user_info_url: str
"""获取用户信息的 API 地址"""
http_method: str = "POST"
"""获取 token 的 HTTP 方法POST 或 GET"""
# 实例属性(构造函数传入)
client_id: str
client_secret: str
def __init__(self, client_id: str, client_secret: str) -> None:
"""
初始化 OAuth 客户端
Args:
client_id: 应用 client_id
client_secret: 应用 client_secret
"""
self.client_id = client_id
self.client_secret = client_secret
async def get_access_token(self, code: str, **kwargs) -> AccessTokenBase:
"""
通过 Authorization Code 获取 Access Token
Args:
code: 授权码
**kwargs: 额外参数(如 QQ 需要 redirect_uri
Returns:
AccessTokenBase: 访问令牌
"""
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
}
params.update(kwargs)
async with aiohttp.ClientSession() as session:
if self.http_method == "POST":
async with session.post(
url=self.access_token_url,
params=params,
headers={'accept': 'application/json'},
) as access_resp:
access_data = await access_resp.json()
return self._parse_token_response(access_data)
else:
async with session.get(
url=self.access_token_url,
params=params,
) as access_resp:
access_data = await access_resp.json()
return self._parse_token_response(access_data)
async def get_user_info(
self,
access_token: str | AccessTokenBase,
**kwargs
) -> OAuthUserInfoResponse:
"""
获取用户信息
Args:
access_token: 访问令牌
**kwargs: 额外参数(如 QQ 需要 app_id, openid
Returns:
OAuthUserInfoResponse: 用户信息
"""
if isinstance(access_token, AccessTokenBase):
access_token = access_token.access_token
async with aiohttp.ClientSession() as session:
async with session.get(
url=self.user_info_url,
params=self._build_user_info_params(access_token, **kwargs),
headers=self._build_user_info_headers(access_token),
) as resp:
user_data = await resp.json()
return self._parse_user_response(user_data)
# ==================== 钩子方法(子类可覆盖) ====================
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
"""
构建获取用户信息的请求参数
Args:
access_token: 访问令牌
**kwargs: 额外参数
Returns:
dict: 请求参数
"""
return {}
def _build_user_info_headers(self, access_token: str) -> dict:
"""
构建获取用户信息的请求头
Args:
access_token: 访问令牌
Returns:
dict: 请求头
"""
return {
'accept': 'application/json',
}
def _parse_token_response(self, data: dict) -> AccessTokenBase:
"""
解析 token 响应
Args:
data: API 返回的数据
Returns:
AccessTokenBase: 访问令牌
"""
return AccessTokenBase(access_token=data.get('access_token'))
def _parse_user_response(self, data: dict) -> OAuthUserInfoResponse:
"""
解析用户信息响应
Args:
data: API 返回的数据
Returns:
OAuthUserInfoResponse: 用户信息
"""
return OAuthUserInfoResponse(
code='0',
openid='',
user_data=OAuthUserData(openid=''),
)
# ==================== 导出 ====================
from .github import GithubOAuth, GithubAccessToken, GithubUserData
from .qq import QQOAuth, QQAccessToken, QQOpenIDResponse, QQUserData
__all__ = [
# 共享模型
'AccessTokenBase',
'OAuthUserData',
'OAuthUserInfoResponse',
'OAuthBase',
# GitHub
'GithubOAuth',
'GithubAccessToken',
'GithubUserData',
# QQ
'QQOAuth',
'QQAccessToken',
'QQOpenIDResponse',
'QQUserData',
]

View File

@@ -1,77 +1,127 @@
from pydantic import BaseModel """GitHub OAuth2.0 认证实现"""
import aiohttp from typing import TYPE_CHECKING
class GithubAccessToken(BaseModel): from pydantic import BaseModel
access_token: str from . import AccessTokenBase, OAuthBase, OAuthUserInfoResponse
if TYPE_CHECKING:
from . import OAuthUserData
class GithubAccessToken(AccessTokenBase):
"""GitHub 访问令牌响应"""
token_type: str token_type: str
"""令牌类型"""
scope: str scope: str
"""授权范围"""
class GithubUserData(BaseModel): class GithubUserData(BaseModel):
"""GitHub 用户数据"""
login: str login: str
"""用户名"""
id: int id: int
"""用户 ID"""
node_id: str node_id: str
"""节点 ID"""
avatar_url: str avatar_url: str
"""头像 URL"""
gravatar_id: str | None gravatar_id: str | None
"""Gravatar ID"""
url: str url: str
"""API URL"""
html_url: str html_url: str
"""主页 URL"""
followers_url: str followers_url: str
"""粉丝列表 URL"""
following_url: str following_url: str
"""关注列表 URL"""
gists_url: str gists_url: str
"""Gists 列表 URL"""
starred_url: str starred_url: str
"""星标列表 URL"""
subscriptions_url: str subscriptions_url: str
"""订阅列表 URL"""
organizations_url: str organizations_url: str
"""组织列表 URL"""
repos_url: str repos_url: str
"""仓库列表 URL"""
events_url: str events_url: str
"""事件列表 URL"""
received_events_url: str received_events_url: str
"""接收的事件列表 URL"""
type: str type: str
"""用户类型"""
site_admin: bool site_admin: bool
"""是否为站点管理员"""
name: str | None name: str | None
"""显示名称"""
company: str | None company: str | None
"""公司"""
blog: str | None blog: str | None
"""博客"""
location: str | None location: str | None
"""位置"""
email: str | None email: str | None
"""邮箱"""
hireable: bool | None hireable: bool | None
"""是否可雇佣"""
bio: str | None bio: str | None
"""个人简介"""
twitter_username: str | None twitter_username: str | None
"""Twitter 用户名"""
public_repos: int public_repos: int
"""公开仓库数"""
public_gists: int public_gists: int
"""公开 Gists 数"""
followers: int followers: int
"""粉丝数"""
following: int following: int
created_at: str # ISO 8601 format date-time string """关注数"""
updated_at: str # ISO 8601 format date-time string created_at: str
"""创建时间ISO 8601 格式)"""
updated_at: str
"""更新时间ISO 8601 格式)"""
class GithubUserInfoResponse(BaseModel): class GithubUserInfoResponse(BaseModel):
"""GitHub 用户信息响应"""
code: str code: str
"""状态码"""
user_data: GithubUserData user_data: GithubUserData
"""用户数据"""
async def get_access_token(code: str) -> GithubAccessToken:
async with aiohttp.ClientSession() as session:
async with session.post(
url='https://github.com/login/oauth/access_token',
params={
'client_id': '',
'client_secret': '',
'code': code
},
headers={'accept': 'application/json'},
) as access_resp:
access_data = await access_resp.json()
return GithubAccessToken(
access_token=access_data.get('access_token'),
token_type=access_data.get('token_type'),
scope=access_data.get('scope')
)
async def get_user_info(access_token: str | GithubAccessToken) -> GithubUserInfoResponse: class GithubOAuth(OAuthBase):
if isinstance(access_token, GithubAccessToken): """GitHub OAuth2.0 客户端"""
access_token = access_token.access_token
access_token_url = "https://github.com/login/oauth/access_token"
async with aiohttp.ClientSession() as session: """获取 Access Token 的 API 地址"""
async with session.get(
url='https://api.github.com/user', user_info_url = "https://api.github.com/user"
headers={ """获取用户信息的 API 地址"""
'accept': 'application/json',
'Authorization': f'token {access_token}'}, http_method = "POST"
) as resp: """获取 token 的 HTTP 方法"""
user_data = await resp.json()
return GithubUserInfoResponse(**user_data) def _parse_token_response(self, data: dict) -> GithubAccessToken:
"""解析 GitHub token 响应"""
return GithubAccessToken(
access_token=data.get('access_token'),
token_type=data.get('token_type'),
scope=data.get('scope'),
)
def _build_user_info_headers(self, access_token: str) -> dict:
"""构建 GitHub 用户信息请求头"""
return {
'accept': 'application/json',
'Authorization': f'token {access_token}',
}
def _parse_user_response(self, data: dict) -> GithubUserInfoResponse:
"""解析 GitHub 用户信息响应"""
return GithubUserInfoResponse(
code='0' if data.get('login') else '1',
user_data=GithubUserData(**data),
)

View File

@@ -1,7 +1,158 @@
from pydantic import BaseModel """QQ OAuth2.0 认证实现"""
import aiohttp import aiohttp
async def get_access_token( from pydantic import BaseModel
from . import AccessTokenBase, OAuthBase
class QQAccessToken(AccessTokenBase):
"""QQ 访问令牌响应"""
expires_in: int
"""access token 的有效期,单位为秒"""
refresh_token: str
"""用于刷新 access token 的令牌"""
class QQOpenIDResponse(BaseModel):
"""QQ OpenID 响应"""
client_id: str
"""应用的 appid"""
openid: str
"""用户的唯一标识"""
class QQUserData(BaseModel):
"""QQ 用户数据"""
ret: int
"""返回码0 表示成功"""
msg: str
"""返回信息"""
nickname: str | None
"""用户昵称"""
gender: str | None
"""性别"""
figureurl: str | None
"""头像 URL"""
figureurl_1: str | None
"""头像 URL大图"""
figureurl_2: str | None
"""头像 URL更大图"""
figureurl_qq_1: str | None
"""QQ 头像 URL大图"""
figureurl_qq_2: str | None
"""QQ 头像 URL更大图"""
is_yellow_vip: str | None
"""是否黄钻用户"""
vip: str | None
"""是否 VIP 用户"""
yellow_vip_level: str | None
"""黄钻等级"""
level: str | None
"""等级"""
is_yellow_year_vip: str | None
"""是否年费黄钻"""
class QQUserInfoResponse(BaseModel):
"""QQ 用户信息响应"""
code: str code: str
): """状态码"""
... openid: str
"""用户 OpenID"""
user_data: QQUserData
"""用户数据"""
class QQOAuth(OAuthBase):
"""QQ OAuth2.0 客户端"""
access_token_url = "https://graph.qq.com/oauth2.0/token"
"""获取 Access Token 的 API 地址"""
user_info_url = "https://graph.qq.com/user/get_user_info"
"""获取用户信息的 API 地址"""
openid_url = "https://graph.qq.com/oauth2.0/me"
"""获取 OpenID 的 API 地址"""
http_method = "GET"
"""获取 token 的 HTTP 方法"""
async def get_access_token(self, code: str, redirect_uri: str) -> QQAccessToken:
"""
通过 Authorization Code 获取 Access Token
Args:
code: 授权码
redirect_uri: 与授权时传入的 redirect_uri 保持一致,需要 URLEncode
Returns:
QQAccessToken: 访问令牌
文档:
https://wiki.connect.qq.com/%E4%BD%BF%E7%94%A8authorization_code%E8%8E%B7%E5%8F%96access_token
"""
params = {
'grant_type': 'authorization_code',
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
'redirect_uri': redirect_uri,
'fmt': 'json',
'need_openid': 1,
}
async with aiohttp.ClientSession() as session:
async with session.get(url=self.access_token_url, params=params) as access_resp:
access_data = await access_resp.json()
return QQAccessToken(
access_token=access_data.get('access_token'),
expires_in=access_data.get('expires_in'),
refresh_token=access_data.get('refresh_token'),
)
async def get_openid(self, access_token: str) -> QQOpenIDResponse:
"""
获取用户 OpenID
注意:如果在 get_access_token 时传入了 need_openid=1响应中已包含 openid
无需额外调用此接口。此函数用于单独获取 openid 的场景。
Args:
access_token: 访问令牌
Returns:
QQOpenIDResponse: 包含 client_id 和 openid
文档:
https://wiki.connect.qq.com/%E8%8E%B7%E5%8F%96%E7%94%A8%E6%88%B7openid%E7%9A%84oauth2.0%E6%8E%A5%E5%8F%A3
"""
async with aiohttp.ClientSession() as session:
async with session.get(
url=self.openid_url,
params={
'access_token': access_token,
'fmt': 'json',
},
) as resp:
data = await resp.json()
return QQOpenIDResponse(
client_id=data.get('client_id'),
openid=data.get('openid'),
)
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
"""构建 QQ 用户信息请求参数"""
return {
'access_token': access_token,
'oauth_consumer_key': kwargs.get('app_id', self.client_id),
'openid': kwargs.get('openid', ''),
}
def _parse_user_response(self, data: dict) -> QQUserInfoResponse:
"""解析 QQ 用户信息响应"""
return QQUserInfoResponse(
code='0' if data.get('ret') == 0 else str(data.get('ret')),
openid=data.get('openid', ''),
user_data=QQUserData(**data),
)

View File

@@ -25,12 +25,12 @@ async def login(
# TODO: 验证码校验 # TODO: 验证码校验
# captcha_setting = await Setting.get( # captcha_setting = await Setting.get(
# session, # session,
# and_(Setting.type == "auth", Setting.name == "login_captcha") # (Setting.type == "auth") & (Setting.name == "login_captcha")
# ) # )
# is_captcha_required = captcha_setting and captcha_setting.value == "1" # is_captcha_required = captcha_setting and captcha_setting.value == "1"
# 获取用户信息 # 获取用户信息
current_user = await User.get(session, User.username == login_request.username, fetch_mode="first") current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore
# 验证用户是否存在 # 验证用户是否存在
if not current_user: if not current_user:

View File

@@ -42,10 +42,9 @@ class GroupFactory:
speed_limit=kwargs.get("speed_limit", 0), speed_limit=kwargs.get("speed_limit", 0),
) )
group = await group.save(session)
# 如果提供了选项参数,创建 GroupOptions # 如果提供了选项参数,创建 GroupOptions
if kwargs.get("create_options", False): if kwargs.get("create_options", False):
group = await group.save(session, commit=False)
options = GroupOptions( options = GroupOptions(
group_id=group.id, group_id=group.id,
share_download=kwargs.get("share_download", True), share_download=kwargs.get("share_download", True),
@@ -55,7 +54,10 @@ class GroupFactory:
select_node=kwargs.get("select_node", False), select_node=kwargs.get("select_node", False),
advance_delete=kwargs.get("advance_delete", False), advance_delete=kwargs.get("advance_delete", False),
) )
await options.save(session) await options.save(session, commit=False)
await session.commit()
else:
group = await group.save(session)
return group return group
@@ -88,7 +90,7 @@ class GroupFactory:
speed_limit=0, speed_limit=0,
) )
admin_group = await admin_group.save(session) admin_group = await admin_group.save(session, commit=False)
# 创建管理员组选项 # 创建管理员组选项
admin_options = GroupOptions( admin_options = GroupOptions(
@@ -105,7 +107,8 @@ class GroupFactory:
aria2=True, aria2=True,
redirected_source=True, redirected_source=True,
) )
await admin_options.save(session) await admin_options.save(session, commit=False)
await session.commit()
return admin_group return admin_group
@@ -140,7 +143,7 @@ class GroupFactory:
speed_limit=1024, # 1MB/s speed_limit=1024, # 1MB/s
) )
limited_group = await limited_group.save(session) limited_group = await limited_group.save(session, commit=False)
# 创建限制组选项 # 创建限制组选项
limited_options = GroupOptions( limited_options = GroupOptions(
@@ -152,7 +155,8 @@ class GroupFactory:
select_node=False, select_node=False,
advance_delete=False, advance_delete=False,
) )
await limited_options.save(session) await limited_options.save(session, commit=False)
await session.commit()
return limited_group return limited_group
@@ -185,7 +189,7 @@ class GroupFactory:
speed_limit=512, # 512KB/s speed_limit=512, # 512KB/s
) )
free_group = await free_group.save(session) free_group = await free_group.save(session, commit=False)
# 创建免费组选项 # 创建免费组选项
free_options = GroupOptions( free_options = GroupOptions(
@@ -197,6 +201,7 @@ class GroupFactory:
select_node=False, select_node=False,
advance_delete=False, advance_delete=False,
) )
await free_options.save(session) await free_options.save(session, commit=False)
await session.commit()
return free_group return free_group

View File

@@ -13,7 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(
refreshUrl="/api/v1/user/session/refresh", refreshUrl="/api/v1/user/session/refresh",
) )
SECRET_KEY = '' SECRET_KEY: str = ''
async def load_secret_key() -> None: async def load_secret_key() -> None:
@@ -26,10 +26,10 @@ async def load_secret_key() -> None:
global SECRET_KEY global SECRET_KEY
async for session in get_session(): async for session in get_session():
setting = await Setting.get( setting: Setting = await Setting.get(
session, session,
(Setting.type == "auth") & (Setting.name == "secret_key") (Setting.type == "auth") & (Setting.name == "secret_key")
) ) # type: ignore
if setting: if setting:
SECRET_KEY = setting.value SECRET_KEY = setting.value
@@ -40,7 +40,14 @@ def build_token_payload(
algorithm: str, algorithm: str,
expires_delta: timedelta | None = None, expires_delta: timedelta | None = None,
) -> tuple[str, datetime]: ) -> tuple[str, datetime]:
"""构建令牌""" """
构建令牌。
:param data: 需要放进 JWT Payload 的字段
:param is_refresh: 是否为刷新令牌
:param algorithm: JWT 签名算法
:param expires_delta: 过期时间
"""
to_encode = data.copy() to_encode = data.copy()
@@ -61,8 +68,11 @@ def build_token_payload(
# 访问令牌 # 访问令牌
def create_access_token(data: dict, expires_delta: timedelta | None = None, def create_access_token(
algorithm: str = "HS256") -> AccessTokenBase: data: dict,
expires_delta: timedelta | None = None,
algorithm: str = "HS256"
) -> AccessTokenBase:
""" """
生成访问令牌,默认有效期 3 小时。 生成访问令牌,默认有效期 3 小时。
@@ -73,7 +83,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None,
:return: 包含密钥本身和过期时间的 `AccessTokenBase` :return: 包含密钥本身和过期时间的 `AccessTokenBase`
""" """
access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta) access_token, expire_at = build_token_payload(
data,
False,
algorithm,
expires_delta
)
return AccessTokenBase( return AccessTokenBase(
access_token=access_token, access_token=access_token,
access_expires=expire_at, access_expires=expire_at,
@@ -81,8 +96,11 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None,
# 刷新令牌 # 刷新令牌
def create_refresh_token(data: dict, expires_delta: timedelta | None = None, def create_refresh_token(
algorithm: str = "HS256") -> RefreshTokenBase: data: dict,
expires_delta: timedelta | None = None,
algorithm: str = "HS256"
) -> RefreshTokenBase:
""" """
生成刷新令牌,默认有效期 30 天。 生成刷新令牌,默认有效期 30 天。
@@ -93,7 +111,12 @@ def create_refresh_token(data: dict, expires_delta: timedelta | None = None,
:return: 包含密钥本身和过期时间的 `RefreshTokenBase` :return: 包含密钥本身和过期时间的 `RefreshTokenBase`
""" """
refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta) refresh_token, expire_at = build_token_payload(
data,
True,
algorithm,
expires_delta
)
return RefreshTokenBase( return RefreshTokenBase(
refresh_token=refresh_token, refresh_token=refresh_token,
refresh_expires=expire_at, refresh_expires=expire_at,

View File

@@ -13,7 +13,7 @@ license_info = {"name": "GPLv3", "url": "https://opensource.org/license/gpl-3.0"
BackendVersion = "0.0.1" BackendVersion = "0.0.1"
"""后端版本""" """后端版本"""
IsPro = False IsPro: bool = False
mode: str = os.getenv('MODE', 'master') mode: str = os.getenv('MODE', 'master')
"""运行模式""" """运行模式"""

View File

@@ -1,4 +1,5 @@
import secrets import secrets
from typing import Literal
from loguru import logger from loguru import logger
from argon2 import PasswordHasher from argon2 import PasswordHasher
@@ -11,7 +12,23 @@ from pydantic import BaseModel, Field
from utils.JWT import SECRET_KEY from utils.JWT import SECRET_KEY
from utils.conf import appmeta from utils.conf import appmeta
_ph = PasswordHasher() # FIRST RECOMMENDED option per RFC 9106.
_ph_lowmem = PasswordHasher(
salt_len=16,
hash_len=32,
time_cost=3,
memory_cost=65536, # 64 MiB
parallelism=4,
)
# SECOND RECOMMENDED option per RFC 9106.
_ph_highmem = PasswordHasher(
salt_len=16,
hash_len=32,
time_cost=1,
memory_cost=2097152, # 2 GiB
parallelism=4,
)
class PasswordStatus(StrEnum): class PasswordStatus(StrEnum):
"""密码校验状态枚举""" """密码校验状态枚举"""
@@ -48,7 +65,8 @@ class Password:
@staticmethod @staticmethod
def generate( def generate(
length: int = 8 length: int = 8,
url_safe: bool = False
) -> str: ) -> str:
""" """
生成指定长度的随机密码。 生成指定长度的随机密码。
@@ -58,7 +76,16 @@ class Password:
:return: 随机密码 :return: 随机密码
:rtype: str :rtype: str
""" """
return secrets.token_hex(length) if url_safe:
return secrets.token_urlsafe(length)
else:
return secrets.token_hex(length)
@staticmethod
def generate_hex(
length: int = 8
) -> bytes:
return secrets.token_bytes(length)
@staticmethod @staticmethod
def hash( def hash(
@@ -72,7 +99,7 @@ class Password:
:param password: 需要哈希的原始密码 :param password: 需要哈希的原始密码
:return: Argon2 哈希字符串 :return: Argon2 哈希字符串
""" """
return _ph.hash(password) return _ph_lowmem.hash(password)
@staticmethod @staticmethod
def verify( def verify(
@@ -87,21 +114,16 @@ class Password:
:return: 如果密码匹配返回 True, 否则返回 False :return: 如果密码匹配返回 True, 否则返回 False
""" """
try: try:
# verify 函数会自动解析 stored_password 中的盐和参数 _ph_lowmem.verify(hash, password)
_ph.verify(hash, password)
# 检查哈希参数是否已过时。如果返回True # 检查哈希参数是否已过时
# 意味着你应该使用新的参数重新哈希密码并更新存储。 if _ph_lowmem.check_needs_rehash(hash):
# 这是一个很好的实践,可以随着时间推移增强安全性。
if _ph.check_needs_rehash(hash):
logger.warning("密码哈希参数已过时,建议重新哈希并更新。") logger.warning("密码哈希参数已过时,建议重新哈希并更新。")
return PasswordStatus.EXPIRED return PasswordStatus.EXPIRED
return PasswordStatus.VALID return PasswordStatus.VALID
except VerifyMismatchError: except VerifyMismatchError:
# 这是预期的异常,当密码不匹配时触发。
return PasswordStatus.INVALID return PasswordStatus.INVALID
# 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
@staticmethod @staticmethod
async def generate_totp( async def generate_totp(