Refactor and enhance OAuth2.0 implementation; update models and routes

- Refactored AdminSummaryData and AdminSummaryResponse classes for better clarity.
- Added OAUTH type to SettingsType enum.
- Cleaned up imports in webdav.py.
- Updated admin router to improve summary data retrieval and response handling.
- Enhanced file management routes with better condition handling and user storage updates.
- Improved group management routes by optimizing data retrieval.
- Refined task management routes for better condition handling.
- Updated user management routes to streamline access token retrieval.
- Implemented a new captcha verification structure with abstract base class.
- Removed deprecated env.md file and replaced with a new structured version.
- Introduced a unified OAuth2.0 client base class for GitHub and QQ integrations.
- Enhanced password management with improved hashing strategies.
- Added detailed comments and documentation throughout the codebase for clarity.
This commit is contained in:
2026-01-12 18:07:44 +08:00
parent 61ddc96f17
commit d2c914cff8
29 changed files with 814 additions and 4609 deletions

View File

@@ -0,0 +1,9 @@
{
"permissions": {
"allow": [
"Bash(git rev-parse:*)",
"Bash(findstr:*)",
"Bash(find:*)"
]
}
}

4407
.xml

File diff suppressed because it is too large Load Diff

View File

@@ -9,4 +9,4 @@
- `REDIS_PORT`: Redis 端口
- `REDIS_PASSWORD`: Redis 密码
- `REDIS_DB`: Redis 数据库
- `REDIS_PROTOCOL`
- `REDIS_PROTOCOL`: Redis 协议

16
main.py
View File

@@ -33,8 +33,22 @@ app = FastAPI(
openapi_url="/openapi.json" if appmeta.debug else None,
)
# 添加跨域 CORS 中间件,仅在调试模式下启用,以允许所有来源访问 API
if appmeta.debug:
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.exception_handler(Exception)
async def handle_unexpected_exceptions(request: Request, exc: Exception) -> NoReturn:
async def handle_unexpected_exceptions(
request: Request,
exc: Exception
) -> NoReturn:
"""
捕获所有未经处理的 FastAPI 异常,防止敏感信息泄露。
"""

View File

@@ -60,10 +60,10 @@ def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None:
try:
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"])
if payload.get("type") != "download":
return None
http_exceptions.raise_unauthorized("Download token required")
jti = payload.get("jti")
if not jti:
return None
http_exceptions.raise_unauthorized("Download token required")
return jti, UUID(payload["file_id"]), UUID(payload["owner_id"])
except (jwt.ExpiredSignatureError, jwt.InvalidTokenError):
return None
except jwt.InvalidTokenError:
http_exceptions.raise_unauthorized("Download token required")

View File

@@ -120,7 +120,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
async def init_default_settings() -> None:
from .setting import Setting
from .database import get_session
from sqlalchemy import and_
log.info('初始化设置...')
@@ -128,7 +127,7 @@ async def init_default_settings() -> None:
# 检查是否已经存在版本设置
ver = await Setting.get(
session,
and_(Setting.type == SettingsType.VERSION, Setting.name == f"db_version_{BackendVersion}")
(Setting.type == SettingsType.VERSION) & (Setting.name == f"db_version_{BackendVersion}")
)
if ver and ver.value == "installed":
return

View File

@@ -218,7 +218,7 @@ class TableBaseMixin(AsyncAttrs):
)
@classmethod
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]:
"""
向数据库中添加一个新的或多个新的记录.
@@ -230,6 +230,8 @@ class TableBaseMixin(AsyncAttrs):
session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要添加的单个模型实例或模型实例列表.
refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True.
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
之后需要手动调用 `session.commit()`。默认为 True.
Returns:
T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例.
@@ -244,6 +246,11 @@ class TableBaseMixin(AsyncAttrs):
# 添加单个实例
item3 = Item(name="Cherry")
added_item = await Item.add(session, item3)
# 批量操作,减少提交次数
await Item.add(session, [item1, item2], commit=False)
await Item.add(session, [item3, item4], commit=False)
await session.commit()
"""
is_list = False
if isinstance(instances, list):
@@ -252,7 +259,10 @@ class TableBaseMixin(AsyncAttrs):
else:
session.add(instances)
await session.commit()
if commit:
await session.commit()
else:
await session.flush()
if refresh:
if is_list:
@@ -266,15 +276,16 @@ class TableBaseMixin(AsyncAttrs):
async def save(
self: T,
session: AsyncSession,
load: RelationshipInfo | None = None,
refresh: bool = True
load: RelationshipInfo | list[RelationshipInfo] | None = None,
refresh: bool = True,
commit: bool = True
) -> T:
"""
保存(插入或更新)当前模型实例到数据库.
这是一个实例方法,它将当前对象添加到会话中并提交更改。
可以用于创建新记录或更新现有记录。还可以选择在保存后
预加载eager load一个关联关系.
预加载eager load一个或多个关联关系.
**重要**调用此方法后session中的所有对象都会过期expired
如果需要继续使用该对象,必须使用返回值:
@@ -287,6 +298,14 @@ class TableBaseMixin(AsyncAttrs):
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.save(session, refresh=False)
# ✅ 正确:批量操作,减少提交次数
await item1.save(session, commit=False)
await item2.save(session, commit=False)
await session.commit()
# ✅ 正确:批量操作并预加载多个关联关系
user = await user.save(session, load=[User.group, User.tags])
# ❌ 错误:需要返回值但未使用
await client.save(session)
return client # client 对象已过期
@@ -294,16 +313,22 @@ class TableBaseMixin(AsyncAttrs):
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性.
例如 `User.posts`.
load (Relationship | list[Relationship] | None): 可选的,指定在保存和刷新后要预加载的关联属性.
可以是单个关系或关系列表.
例如 `User.posts` 或 `[User.group, User.tags]`.
refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True.
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
之后需要手动调用 `session.commit()`。默认为 True.
Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
"""
session.add(self)
await session.commit()
if commit:
await session.commit()
else:
await session.flush()
if not refresh:
return self
@@ -324,8 +349,9 @@ class TableBaseMixin(AsyncAttrs):
extra_data: dict[str, Any] | None = None,
exclude_unset: bool = True,
exclude: set[str] | None = None,
load: RelationshipInfo | None = None,
refresh: bool = True
load: RelationshipInfo | list[RelationshipInfo] | None = None,
refresh: bool = True,
commit: bool = True
) -> T:
"""
使用另一个模型实例或字典中的数据来更新当前实例.
@@ -348,6 +374,14 @@ class TableBaseMixin(AsyncAttrs):
# ✅ 正确:不需要返回值时,指定 refresh=False 节省性能
await client.update(session, update_data, refresh=False)
# ✅ 正确:批量操作,减少提交次数
await user1.update(session, data1, commit=False)
await user2.update(session, data2, commit=False)
await session.commit()
# ✅ 正确:批量操作并预加载多个关联关系
user = await user.update(session, data, load=[User.group, User.tags])
# ❌ 错误:需要返回值但未使用
await client.update(session, update_data)
return client # client 对象已过期
@@ -360,10 +394,13 @@ class TableBaseMixin(AsyncAttrs):
exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供)
的字段将被忽略. 默认为 True.
exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}.
load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性.
例如 `User.permission`.
load (Relationship | list[Relationship] | None): 可选的,指定在更新和刷新后要预加载的关联属性.
可以是单个关系或关系列表.
例如 `User.permission` 或 `[User.group, User.tags]`.
refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值,
设为 False 可节省一次数据库查询。默认为 True.
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
之后需要手动调用 `session.commit()`。默认为 True.
Returns:
T: 如果 refresh=True返回已刷新的模型实例否则返回未刷新的 self.
@@ -374,7 +411,10 @@ class TableBaseMixin(AsyncAttrs):
)
session.add(self)
await session.commit()
if commit:
await session.commit()
else:
await session.flush()
if not refresh:
return self
@@ -388,33 +428,82 @@ class TableBaseMixin(AsyncAttrs):
return self
@classmethod
async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None:
async def delete(
cls: type[T],
session: AsyncSession,
instances: T | list[T] | None = None,
*,
condition: BinaryExpression | ClauseElement | None = None,
commit: bool = True
) -> int:
"""
从数据库中删除一个或多个记录.
从数据库中删除记录.
支持两种删除方式:
1. 实例删除:传入 instances 参数,先加载再删除
2. 条件删除:传入 condition 参数,直接 SQL 删除(更高效)
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
instances (T | list[T]): 要删除的单个模型实例或模型实例列表.
instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表(可选).
condition (BinaryExpression | ClauseElement | None): 删除条件(可选,与 instances 二选一).
commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数,
之后需要手动调用 `session.commit()`。默认为 True.
Returns:
None
int: 删除的记录数量
Usage:
# 实例删除
item_to_delete = await Item.get(session, Item.id == 1)
if item_to_delete:
await Item.delete(session, item_to_delete)
deleted_count = await Item.delete(session, item_to_delete)
items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all")
if items_to_delete:
await Item.delete(session, items_to_delete)
# 条件删除(更高效,无需加载实例)
deleted_count = await Item.delete(
session,
condition=(Item.status == "inactive") & (Item.created_at < cutoff_date)
)
# 批量删除后手动提交
await Item.delete(session, item1, commit=False)
await Item.delete(session, item2, commit=False)
await session.commit()
"""
# 条件删除模式
if condition is not None:
from sqlmodel import delete as sql_delete
if instances is not None:
raise ValueError("不能同时指定 instances 和 condition")
# 执行条件删除
stmt = sql_delete(cls).where(condition)
result = await session.exec(stmt)
deleted_count = result.rowcount
if commit:
await session.commit()
return deleted_count
# 实例删除模式(原有逻辑)
if instances is None:
raise ValueError("必须指定 instances 或 condition")
deleted_count = 0
if isinstance(instances, list):
for instance in instances:
await session.delete(instance)
deleted_count += 1
else:
await session.delete(instances)
deleted_count = 1
await session.commit()
if commit:
await session.commit()
return deleted_count
@classmethod
def _build_time_filters(
@@ -458,7 +547,7 @@ class TableBaseMixin(AsyncAttrs):
fetch_mode: Literal["one", "first", "all"] = "first",
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: RelationshipInfo | None = None,
load: RelationshipInfo | list[RelationshipInfo] | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
with_for_update: bool = False,
@@ -491,8 +580,9 @@ class TableBaseMixin(AsyncAttrs):
例如 `User` 或 `(Profile, User.id == Profile.user_id)`.
options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据,
例如 `[selectinload(User.posts)]`.
load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系.
例如 `User.profile`.
load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系.
可以是单个关系或关系列表.
例如 `User.profile` 或 `[User.group, User.tags]`.
order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表.
例如 `[User.name.asc(), User.created_at.desc()]`.
filter (BinaryExpression | ClauseElement | None): 附加的过滤条件.
@@ -595,9 +685,15 @@ class TableBaseMixin(AsyncAttrs):
statement = statement.options(*options)
if load:
# 标准化为列表
load_list = load if isinstance(load, list) else [load]
# 处理多态加载
if load_polymorphic is not None:
target_class = load.property.mapper.class_
# 多态加载只支持单个关系
if len(load_list) > 1:
raise ValueError("load_polymorphic 仅支持单个关系")
target_class = load_list[0].property.mapper.class_
# 检查目标类是否继承自 PolymorphicBaseMixin
if not issubclass(target_class, PolymorphicBaseMixin):
@@ -609,7 +705,7 @@ class TableBaseMixin(AsyncAttrs):
if load_polymorphic == 'all':
# 两阶段查询:获取实际关联的多态类型
subclasses_to_load = await cls._resolve_polymorphic_subclasses(
session, condition, load, target_class
session, condition, load_list[0], target_class
)
else:
subclasses_to_load = load_polymorphic
@@ -618,12 +714,14 @@ class TableBaseMixin(AsyncAttrs):
# 关键selectin_polymorphic 必须作为 selectinload 的链式子选项
# 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading
statement = statement.options(
selectinload(load).selectin_polymorphic(subclasses_to_load)
selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load)
)
else:
statement = statement.options(selectinload(load))
statement = statement.options(selectinload(load_list[0]))
else:
statement = statement.options(selectinload(load))
# 为每个关系添加 selectinload
for rel in load_list:
statement = statement.options(selectinload(rel))
if order_by is not None:
statement = statement.order_by(*order_by)
@@ -796,7 +894,7 @@ class TableBaseMixin(AsyncAttrs):
*,
join: type[T] | tuple[type[T], _OnClauseArgument] | None = None,
options: list | None = None,
load: RelationshipInfo | None = None,
load: RelationshipInfo | list[RelationshipInfo] | None = None,
order_by: list[ClauseElement] | None = None,
filter: BinaryExpression | ClauseElement | None = None,
table_view: TableViewRequest | None = None,
@@ -865,7 +963,7 @@ class TableBaseMixin(AsyncAttrs):
return ListResponse(count=total_count, items=items)
@classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T:
async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T:
"""
根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常.
@@ -875,7 +973,8 @@ class TableBaseMixin(AsyncAttrs):
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
id (int): 要查找的记录的主键 ID.
load (Relationship | None): 可选的,用于预加载的关联属性.
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
可以是单个关系或关系列表.
Returns:
T: 找到的模型实例.
@@ -903,7 +1002,7 @@ class UUIDTableBaseMixin(TableBaseMixin):
@override
@classmethod
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T:
async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T:
"""
根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常.
@@ -913,7 +1012,8 @@ class UUIDTableBaseMixin(TableBaseMixin):
Args:
session (AsyncSession): 用于数据库操作的异步会话对象.
id (uuid.UUID): 要查找的记录的 UUID 主键.
load (Relationship | None): 可选的,用于预加载的关联属性.
load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性.
可以是单个关系或关系列表.
Returns:
T: 找到的模型实例.

View File

@@ -79,9 +79,8 @@ class VersionInfo(SQLModelBase):
commit: str
"""提交哈希"""
class AdminSummaryData(SQLModelBase):
"""管理员概况数据"""
class AdminSummaryResponse(ResponseBase):
"""管理员概况响应"""
metrics_summary: MetricsSummary
"""统计摘要"""
@@ -95,13 +94,6 @@ class AdminSummaryData(SQLModelBase):
version: VersionInfo
"""版本信息"""
class AdminSummaryResponse(ResponseBase):
"""管理员概况响应"""
data: AdminSummaryData | None = None
"""响应数据"""
class MCPMethod(StrEnum):
"""MCP 方法枚举"""

View File

@@ -104,6 +104,7 @@ class SettingsType(StrEnum):
MAIL = "mail"
MAIL_TEMPLATE = "mail_template"
MOBILE = "mobile"
OAUTH = "oauth"
PATH = "path"
PREVIEW = "preview"
PWA = "pwa"

View File

@@ -2,7 +2,7 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
from sqlmodel import Field, Relationship, UniqueConstraint
from .base import SQLModelBase
from .mixin import TableBaseMixin

View File

@@ -2,14 +2,12 @@ from datetime import datetime, timedelta
from fastapi import APIRouter, Depends
from loguru import logger as l
from sqlalchemy import and_
from middleware.auth import admin_required
from middleware.dependencies import SessionDep
from models import (
User, ResponseBase,
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
AdminSummaryData,
)
from models.base import SQLModelBase
from models.setting import (
@@ -75,8 +73,8 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
Returns:
AdminSummaryResponse: 包含站点概况信息的响应模型。
"""
# 统计最近 12 天的数据
days_count = 12
# 统计最近 14 天的数据
days_count = 14
now = datetime.now()
today_start = now.replace(hour=0, minute=0, second=0, microsecond=0)
@@ -135,7 +133,7 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
site_urls: list[str] = []
site_url_setting = await Setting.get(
session,
and_(Setting.type == SettingsType.BASIC, Setting.name == "siteURL"),
(Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"),
)
if site_url_setting and site_url_setting.value:
site_urls.append(site_url_setting.value)
@@ -156,15 +154,13 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse:
commit="dev",
)
data = AdminSummaryData(
return AdminSummaryResponse(
metrics_summary=metrics_summary,
site_urls=site_urls,
license=license_info,
version=version_info,
)
return AdminSummaryResponse(data=data)
@admin_router.get(
path='/news',
summary='获取社区新闻',
@@ -203,7 +199,7 @@ async def router_admin_update_settings(
for item in request.settings:
existing = await Setting.get(
session,
and_(Setting.type == item.type, Setting.name == item.name)
(Setting.type == item.type) & (Setting.name == item.name)
)
if existing:
@@ -245,7 +241,12 @@ async def router_admin_get_settings(
if name:
conditions.append(Setting.name == name)
condition = and_(*conditions) if conditions else None
if conditions:
condition = conditions[0]
for c in conditions[1:]:
condition = condition & c
else:
condition = None
settings = await Setting.get(session, condition, fetch_mode="all")

View File

@@ -5,7 +5,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger as l
from sqlalchemy import and_
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
@@ -51,7 +50,12 @@ async def router_admin_get_file_list(
if keyword:
conditions.append(Object.name.ilike(f"%{keyword}%"))
condition = and_(*conditions) if len(conditions) > 1 else conditions[0]
if len(conditions) > 1:
condition = conditions[0]
for c in conditions[1:]:
condition = condition & c
else:
condition = conditions[0]
result = await Object.get_with_count(session, condition, table_view=table_view, load=Object.owner)
# 构建响应
@@ -197,13 +201,15 @@ async def router_admin_delete_file(
except Exception as e:
l.warning(f"删除物理文件失败: {e}")
# 更新用户存储量
owner = await User.get(session, User.id == owner_id)
if owner:
owner.storage = max(0, owner.storage - file_size)
await owner.save(session)
# 更新用户存储量(使用 SQL UPDATE 直接更新,无需加载实例)
from sqlmodel import update as sql_update
stmt = sql_update(User).where(User.id == owner_id).values(
storage=max(0, User.storage - file_size)
)
await session.exec(stmt)
await Object.delete(session, file_obj)
# 使用条件删除
await Object.delete(session, condition=Object.id == file_obj.id)
l.info(f"管理员删除了文件: {file_name}")
return ResponseBase(data={"deleted": True})

View File

@@ -63,12 +63,13 @@ async def router_admin_get_group(
:param group_id: 用户组UUID
:return: 用户组详情
"""
group = await Group.get(session, Group.id == group_id, load=Group.options)
group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies])
if not group:
raise HTTPException(status_code=404, detail="用户组不存在")
policies = await group.awaitable_attrs.policies
# 直接访问已加载的关系,无需额外查询
policies = group.policies
user_count = await User.count(session, User.group_id == group_id)
response = GroupDetailResponse.from_group(group, user_count, policies)

View File

@@ -2,7 +2,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from sqlalchemy import and_
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
@@ -43,7 +42,12 @@ async def router_admin_get_task_list(
if status:
conditions.append(Task.status == status)
condition = and_(*conditions) if conditions else None
if conditions:
condition = conditions[0]
for c in conditions[1:]:
condition = condition & c
else:
condition = None
result = await Task.get_with_count(session, condition, table_view=table_view, load=Task.user)
items: list[TaskSummary] = []

View File

@@ -2,7 +2,7 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger as l
from sqlalchemy import func, and_
from sqlalchemy import func
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
@@ -198,7 +198,7 @@ async def router_admin_calibrate_storage(
from sqlmodel import select
result = await session.execute(
select(func.sum(Object.size), func.count(Object.id)).where(
and_(Object.owner_id == user_id, Object.type == ObjectType.FILE)
(Object.owner_id == user_id) & (Object.type == ObjectType.FILE)
)
)
row = result.one()

View File

@@ -233,7 +233,7 @@ async def upload_chunk(
policy_id=upload_session.policy_id,
reference_count=1,
)
physical_file = await physical_file.save(session)
physical_file = await physical_file.save(session, commit=False)
# 创建 Object 记录
file_object = Object(
@@ -246,11 +246,18 @@ async def upload_chunk(
owner_id=user_id,
policy_id=upload_session.policy_id,
)
file_object = await file_object.save(session)
file_object = await file_object.save(session, commit=False)
file_object_id = file_object.id
# 删除上传会话
await UploadSession.delete(session, upload_session)
# 删除上传会话(使用条件删除)
await UploadSession.delete(
session,
condition=UploadSession.id == upload_session.id,
commit=False
)
# 统一提交所有更改
await session.commit()
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")

View File

@@ -1,5 +1,4 @@
from fastapi import APIRouter
from sqlalchemy import and_
from middleware.dependencies import SessionDep
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
@@ -55,5 +54,5 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse:
dict: The site configuration.
"""
return SiteConfigResponse(
title=await Setting.get(session, and_(Setting.type == SettingsType.BASIC, Setting.name == "siteName")),
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
)

View File

@@ -3,7 +3,6 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy import and_
from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
@@ -115,7 +114,7 @@ async def router_user_register(
# 2. 获取默认用户组(从设置中读取 UUID
default_group_setting: models.Setting | None = await models.Setting.get(
session,
and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group")
(models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group")
)
if default_group_setting is None or not default_group_setting.value:
logger.error("默认用户组不存在")
@@ -352,18 +351,18 @@ async def router_user_authn_start(
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
authn_setting = await models.Setting.get(
session,
and_(models.Setting.type == "authn", models.Setting.name == "authn_enabled")
(models.Setting.type == "authn") & (models.Setting.name == "authn_enabled")
)
if not authn_setting or authn_setting.value != "1":
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
site_url_setting = await models.Setting.get(
session,
and_(models.Setting.type == "basic", models.Setting.name == "siteURL")
(models.Setting.type == "basic") & (models.Setting.name == "siteURL")
)
site_title_setting = await models.Setting.get(
session,
and_(models.Setting.type == "basic", models.Setting.name == "siteTitle")
(models.Setting.type == "basic") & (models.Setting.name == "siteTitle")
)
options = generate_registration_options(

View File

@@ -1,5 +1,39 @@
import abc
import aiohttp
from pydantic import BaseModel
from .gcaptcha import GCaptcha
from .turnstile import TurnstileCaptcha
class CaptchaRequestBase(BaseModel):
"""验证码验证请求"""
token: str
secret: str
"""验证 token"""
secret: str
"""验证密钥"""
class CaptchaBase(abc.ABC):
"""验证码验证器抽象基类"""
verify_url: str
"""验证 API 地址(子类必须定义)"""
async def verify_captcha(self, request: CaptchaRequestBase) -> bool:
"""
验证 token 是否有效。
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(self.verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)

View File

@@ -1,21 +1,7 @@
import aiohttp
from . import CaptchaBase
from . import CaptchaRequestBase
async def verify_captcha(request: CaptchaRequestBase) -> bool:
"""
验证 Google reCAPTCHA v2/v3 的 token 是否有效。
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
verify_url = "https://www.google.com/recaptcha/api/siteverify"
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)
class GCaptcha(CaptchaBase):
"""Google reCAPTCHA v2/v3 验证器"""
verify_url = "https://www.google.com/recaptcha/api/siteverify"

View File

@@ -1,21 +1,7 @@
import aiohttp
from . import CaptchaBase
from . import CaptchaRequestBase
async def verify_captcha(request: CaptchaRequestBase) -> bool:
"""
验证 Turnstile 的 token 是否有效。
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)
class TurnstileCaptcha(CaptchaBase):
"""Cloudflare Turnstile 验证器"""
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"

223
service/oauth/__init__.py Normal file
View File

@@ -0,0 +1,223 @@
"""
OAuth2.0 认证模块
提供统一的 OAuth2.0 客户端基类,支持多种第三方登录平台。
"""
import abc
import aiohttp
from pydantic import BaseModel
# ==================== 共享数据模型 ====================
class AccessTokenBase(BaseModel):
"""访问令牌基类"""
access_token: str
"""访问令牌"""
class OAuthUserData(BaseModel):
"""OAuth 用户数据通用 DTO"""
openid: str
"""用户唯一标识GitHub 为 idQQ 为 openid"""
nickname: str | None
"""用户昵称"""
avatar_url: str | None
"""头像 URL"""
email: str | None
"""邮箱"""
bio: str | None
"""个人简介"""
class OAuthUserInfoResponse(BaseModel):
"""OAuth 用户信息响应"""
code: str
"""状态码"""
openid: str
"""用户唯一标识"""
user_data: OAuthUserData
"""用户数据"""
# ==================== OAuth2.0 抽象基类 ====================
class OAuthBase(abc.ABC):
"""
OAuth2.0 客户端抽象基类
子类需要定义以下类属性:
- access_token_url: 获取 Access Token 的 API 地址
- user_info_url: 获取用户信息的 API 地址
- http_method: 获取 token 的 HTTP 方法POST 或 GET
"""
# 子类必须定义的类属性
access_token_url: str
"""获取 Access Token 的 API 地址"""
user_info_url: str
"""获取用户信息的 API 地址"""
http_method: str = "POST"
"""获取 token 的 HTTP 方法POST 或 GET"""
# 实例属性(构造函数传入)
client_id: str
client_secret: str
def __init__(self, client_id: str, client_secret: str) -> None:
"""
初始化 OAuth 客户端
Args:
client_id: 应用 client_id
client_secret: 应用 client_secret
"""
self.client_id = client_id
self.client_secret = client_secret
async def get_access_token(self, code: str, **kwargs) -> AccessTokenBase:
"""
通过 Authorization Code 获取 Access Token
Args:
code: 授权码
**kwargs: 额外参数(如 QQ 需要 redirect_uri
Returns:
AccessTokenBase: 访问令牌
"""
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
}
params.update(kwargs)
async with aiohttp.ClientSession() as session:
if self.http_method == "POST":
async with session.post(
url=self.access_token_url,
params=params,
headers={'accept': 'application/json'},
) as access_resp:
access_data = await access_resp.json()
return self._parse_token_response(access_data)
else:
async with session.get(
url=self.access_token_url,
params=params,
) as access_resp:
access_data = await access_resp.json()
return self._parse_token_response(access_data)
async def get_user_info(
self,
access_token: str | AccessTokenBase,
**kwargs
) -> OAuthUserInfoResponse:
"""
获取用户信息
Args:
access_token: 访问令牌
**kwargs: 额外参数(如 QQ 需要 app_id, openid
Returns:
OAuthUserInfoResponse: 用户信息
"""
if isinstance(access_token, AccessTokenBase):
access_token = access_token.access_token
async with aiohttp.ClientSession() as session:
async with session.get(
url=self.user_info_url,
params=self._build_user_info_params(access_token, **kwargs),
headers=self._build_user_info_headers(access_token),
) as resp:
user_data = await resp.json()
return self._parse_user_response(user_data)
# ==================== 钩子方法(子类可覆盖) ====================
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
"""
构建获取用户信息的请求参数
Args:
access_token: 访问令牌
**kwargs: 额外参数
Returns:
dict: 请求参数
"""
return {}
def _build_user_info_headers(self, access_token: str) -> dict:
"""
构建获取用户信息的请求头
Args:
access_token: 访问令牌
Returns:
dict: 请求头
"""
return {
'accept': 'application/json',
}
def _parse_token_response(self, data: dict) -> AccessTokenBase:
"""
解析 token 响应
Args:
data: API 返回的数据
Returns:
AccessTokenBase: 访问令牌
"""
return AccessTokenBase(access_token=data.get('access_token'))
def _parse_user_response(self, data: dict) -> OAuthUserInfoResponse:
"""
解析用户信息响应
Args:
data: API 返回的数据
Returns:
OAuthUserInfoResponse: 用户信息
"""
return OAuthUserInfoResponse(
code='0',
openid='',
user_data=OAuthUserData(openid=''),
)
# ==================== 导出 ====================
from .github import GithubOAuth, GithubAccessToken, GithubUserData
from .qq import QQOAuth, QQAccessToken, QQOpenIDResponse, QQUserData
__all__ = [
# 共享模型
'AccessTokenBase',
'OAuthUserData',
'OAuthUserInfoResponse',
'OAuthBase',
# GitHub
'GithubOAuth',
'GithubAccessToken',
'GithubUserData',
# QQ
'QQOAuth',
'QQAccessToken',
'QQOpenIDResponse',
'QQUserData',
]

View File

@@ -1,77 +1,127 @@
from pydantic import BaseModel
import aiohttp
"""GitHub OAuth2.0 认证实现"""
from typing import TYPE_CHECKING
class GithubAccessToken(BaseModel):
access_token: str
from pydantic import BaseModel
from . import AccessTokenBase, OAuthBase, OAuthUserInfoResponse
if TYPE_CHECKING:
from . import OAuthUserData
class GithubAccessToken(AccessTokenBase):
"""GitHub 访问令牌响应"""
token_type: str
"""令牌类型"""
scope: str
"""授权范围"""
class GithubUserData(BaseModel):
"""GitHub 用户数据"""
login: str
"""用户名"""
id: int
"""用户 ID"""
node_id: str
"""节点 ID"""
avatar_url: str
"""头像 URL"""
gravatar_id: str | None
"""Gravatar ID"""
url: str
"""API URL"""
html_url: str
"""主页 URL"""
followers_url: str
"""粉丝列表 URL"""
following_url: str
"""关注列表 URL"""
gists_url: str
"""Gists 列表 URL"""
starred_url: str
"""星标列表 URL"""
subscriptions_url: str
"""订阅列表 URL"""
organizations_url: str
"""组织列表 URL"""
repos_url: str
"""仓库列表 URL"""
events_url: str
"""事件列表 URL"""
received_events_url: str
"""接收的事件列表 URL"""
type: str
"""用户类型"""
site_admin: bool
"""是否为站点管理员"""
name: str | None
"""显示名称"""
company: str | None
"""公司"""
blog: str | None
"""博客"""
location: str | None
"""位置"""
email: str | None
"""邮箱"""
hireable: bool | None
"""是否可雇佣"""
bio: str | None
"""个人简介"""
twitter_username: str | None
"""Twitter 用户名"""
public_repos: int
"""公开仓库数"""
public_gists: int
"""公开 Gists 数"""
followers: int
"""粉丝数"""
following: int
created_at: str # ISO 8601 format date-time string
updated_at: str # ISO 8601 format date-time string
"""关注数"""
created_at: str
"""创建时间ISO 8601 格式)"""
updated_at: str
"""更新时间ISO 8601 格式)"""
class GithubUserInfoResponse(BaseModel):
"""GitHub 用户信息响应"""
code: str
"""状态码"""
user_data: GithubUserData
"""用户数据"""
async def get_access_token(code: str) -> GithubAccessToken:
async with aiohttp.ClientSession() as session:
async with session.post(
url='https://github.com/login/oauth/access_token',
params={
'client_id': '',
'client_secret': '',
'code': code
},
headers={'accept': 'application/json'},
) as access_resp:
access_data = await access_resp.json()
return GithubAccessToken(
access_token=access_data.get('access_token'),
token_type=access_data.get('token_type'),
scope=access_data.get('scope')
)
async def get_user_info(access_token: str | GithubAccessToken) -> GithubUserInfoResponse:
if isinstance(access_token, GithubAccessToken):
access_token = access_token.access_token
async with aiohttp.ClientSession() as session:
async with session.get(
url='https://api.github.com/user',
headers={
'accept': 'application/json',
'Authorization': f'token {access_token}'},
) as resp:
user_data = await resp.json()
return GithubUserInfoResponse(**user_data)
class GithubOAuth(OAuthBase):
"""GitHub OAuth2.0 客户端"""
access_token_url = "https://github.com/login/oauth/access_token"
"""获取 Access Token 的 API 地址"""
user_info_url = "https://api.github.com/user"
"""获取用户信息的 API 地址"""
http_method = "POST"
"""获取 token 的 HTTP 方法"""
def _parse_token_response(self, data: dict) -> GithubAccessToken:
"""解析 GitHub token 响应"""
return GithubAccessToken(
access_token=data.get('access_token'),
token_type=data.get('token_type'),
scope=data.get('scope'),
)
def _build_user_info_headers(self, access_token: str) -> dict:
"""构建 GitHub 用户信息请求头"""
return {
'accept': 'application/json',
'Authorization': f'token {access_token}',
}
def _parse_user_response(self, data: dict) -> GithubUserInfoResponse:
"""解析 GitHub 用户信息响应"""
return GithubUserInfoResponse(
code='0' if data.get('login') else '1',
user_data=GithubUserData(**data),
)

View File

@@ -1,7 +1,158 @@
from pydantic import BaseModel
"""QQ OAuth2.0 认证实现"""
import aiohttp
async def get_access_token(
from pydantic import BaseModel
from . import AccessTokenBase, OAuthBase
class QQAccessToken(AccessTokenBase):
"""QQ 访问令牌响应"""
expires_in: int
"""access token 的有效期,单位为秒"""
refresh_token: str
"""用于刷新 access token 的令牌"""
class QQOpenIDResponse(BaseModel):
"""QQ OpenID 响应"""
client_id: str
"""应用的 appid"""
openid: str
"""用户的唯一标识"""
class QQUserData(BaseModel):
"""QQ 用户数据"""
ret: int
"""返回码0 表示成功"""
msg: str
"""返回信息"""
nickname: str | None
"""用户昵称"""
gender: str | None
"""性别"""
figureurl: str | None
"""头像 URL"""
figureurl_1: str | None
"""头像 URL大图"""
figureurl_2: str | None
"""头像 URL更大图"""
figureurl_qq_1: str | None
"""QQ 头像 URL大图"""
figureurl_qq_2: str | None
"""QQ 头像 URL更大图"""
is_yellow_vip: str | None
"""是否黄钻用户"""
vip: str | None
"""是否 VIP 用户"""
yellow_vip_level: str | None
"""黄钻等级"""
level: str | None
"""等级"""
is_yellow_year_vip: str | None
"""是否年费黄钻"""
class QQUserInfoResponse(BaseModel):
"""QQ 用户信息响应"""
code: str
):
...
"""状态码"""
openid: str
"""用户 OpenID"""
user_data: QQUserData
"""用户数据"""
class QQOAuth(OAuthBase):
"""QQ OAuth2.0 客户端"""
access_token_url = "https://graph.qq.com/oauth2.0/token"
"""获取 Access Token 的 API 地址"""
user_info_url = "https://graph.qq.com/user/get_user_info"
"""获取用户信息的 API 地址"""
openid_url = "https://graph.qq.com/oauth2.0/me"
"""获取 OpenID 的 API 地址"""
http_method = "GET"
"""获取 token 的 HTTP 方法"""
async def get_access_token(self, code: str, redirect_uri: str) -> QQAccessToken:
"""
通过 Authorization Code 获取 Access Token
Args:
code: 授权码
redirect_uri: 与授权时传入的 redirect_uri 保持一致,需要 URLEncode
Returns:
QQAccessToken: 访问令牌
文档:
https://wiki.connect.qq.com/%E4%BD%BF%E7%94%A8authorization_code%E8%8E%B7%E5%8F%96access_token
"""
params = {
'grant_type': 'authorization_code',
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
'redirect_uri': redirect_uri,
'fmt': 'json',
'need_openid': 1,
}
async with aiohttp.ClientSession() as session:
async with session.get(url=self.access_token_url, params=params) as access_resp:
access_data = await access_resp.json()
return QQAccessToken(
access_token=access_data.get('access_token'),
expires_in=access_data.get('expires_in'),
refresh_token=access_data.get('refresh_token'),
)
async def get_openid(self, access_token: str) -> QQOpenIDResponse:
"""
获取用户 OpenID
注意:如果在 get_access_token 时传入了 need_openid=1响应中已包含 openid
无需额外调用此接口。此函数用于单独获取 openid 的场景。
Args:
access_token: 访问令牌
Returns:
QQOpenIDResponse: 包含 client_id 和 openid
文档:
https://wiki.connect.qq.com/%E8%8E%B7%E5%8F%96%E7%94%A8%E6%88%B7openid%E7%9A%84oauth2.0%E6%8E%A5%E5%8F%A3
"""
async with aiohttp.ClientSession() as session:
async with session.get(
url=self.openid_url,
params={
'access_token': access_token,
'fmt': 'json',
},
) as resp:
data = await resp.json()
return QQOpenIDResponse(
client_id=data.get('client_id'),
openid=data.get('openid'),
)
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
"""构建 QQ 用户信息请求参数"""
return {
'access_token': access_token,
'oauth_consumer_key': kwargs.get('app_id', self.client_id),
'openid': kwargs.get('openid', ''),
}
def _parse_user_response(self, data: dict) -> QQUserInfoResponse:
"""解析 QQ 用户信息响应"""
return QQUserInfoResponse(
code='0' if data.get('ret') == 0 else str(data.get('ret')),
openid=data.get('openid', ''),
user_data=QQUserData(**data),
)

View File

@@ -25,12 +25,12 @@ async def login(
# TODO: 验证码校验
# captcha_setting = await Setting.get(
# session,
# and_(Setting.type == "auth", Setting.name == "login_captcha")
# (Setting.type == "auth") & (Setting.name == "login_captcha")
# )
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
# 获取用户信息
current_user = await User.get(session, User.username == login_request.username, fetch_mode="first")
current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore
# 验证用户是否存在
if not current_user:

View File

@@ -42,10 +42,9 @@ class GroupFactory:
speed_limit=kwargs.get("speed_limit", 0),
)
group = await group.save(session)
# 如果提供了选项参数,创建 GroupOptions
if kwargs.get("create_options", False):
group = await group.save(session, commit=False)
options = GroupOptions(
group_id=group.id,
share_download=kwargs.get("share_download", True),
@@ -55,7 +54,10 @@ class GroupFactory:
select_node=kwargs.get("select_node", False),
advance_delete=kwargs.get("advance_delete", False),
)
await options.save(session)
await options.save(session, commit=False)
await session.commit()
else:
group = await group.save(session)
return group
@@ -88,7 +90,7 @@ class GroupFactory:
speed_limit=0,
)
admin_group = await admin_group.save(session)
admin_group = await admin_group.save(session, commit=False)
# 创建管理员组选项
admin_options = GroupOptions(
@@ -105,7 +107,8 @@ class GroupFactory:
aria2=True,
redirected_source=True,
)
await admin_options.save(session)
await admin_options.save(session, commit=False)
await session.commit()
return admin_group
@@ -140,7 +143,7 @@ class GroupFactory:
speed_limit=1024, # 1MB/s
)
limited_group = await limited_group.save(session)
limited_group = await limited_group.save(session, commit=False)
# 创建限制组选项
limited_options = GroupOptions(
@@ -152,7 +155,8 @@ class GroupFactory:
select_node=False,
advance_delete=False,
)
await limited_options.save(session)
await limited_options.save(session, commit=False)
await session.commit()
return limited_group
@@ -185,7 +189,7 @@ class GroupFactory:
speed_limit=512, # 512KB/s
)
free_group = await free_group.save(session)
free_group = await free_group.save(session, commit=False)
# 创建免费组选项
free_options = GroupOptions(
@@ -197,6 +201,7 @@ class GroupFactory:
select_node=False,
advance_delete=False,
)
await free_options.save(session)
await free_options.save(session, commit=False)
await session.commit()
return free_group

View File

@@ -13,7 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer(
refreshUrl="/api/v1/user/session/refresh",
)
SECRET_KEY = ''
SECRET_KEY: str = ''
async def load_secret_key() -> None:
@@ -26,10 +26,10 @@ async def load_secret_key() -> None:
global SECRET_KEY
async for session in get_session():
setting = await Setting.get(
setting: Setting = await Setting.get(
session,
(Setting.type == "auth") & (Setting.name == "secret_key")
)
) # type: ignore
if setting:
SECRET_KEY = setting.value
@@ -40,7 +40,14 @@ def build_token_payload(
algorithm: str,
expires_delta: timedelta | None = None,
) -> tuple[str, datetime]:
"""构建令牌"""
"""
构建令牌。
:param data: 需要放进 JWT Payload 的字段
:param is_refresh: 是否为刷新令牌
:param algorithm: JWT 签名算法
:param expires_delta: 过期时间
"""
to_encode = data.copy()
@@ -61,8 +68,11 @@ def build_token_payload(
# 访问令牌
def create_access_token(data: dict, expires_delta: timedelta | None = None,
algorithm: str = "HS256") -> AccessTokenBase:
def create_access_token(
data: dict,
expires_delta: timedelta | None = None,
algorithm: str = "HS256"
) -> AccessTokenBase:
"""
生成访问令牌,默认有效期 3 小时。
@@ -73,7 +83,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None,
:return: 包含密钥本身和过期时间的 `AccessTokenBase`
"""
access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta)
access_token, expire_at = build_token_payload(
data,
False,
algorithm,
expires_delta
)
return AccessTokenBase(
access_token=access_token,
access_expires=expire_at,
@@ -81,8 +96,11 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None,
# 刷新令牌
def create_refresh_token(data: dict, expires_delta: timedelta | None = None,
algorithm: str = "HS256") -> RefreshTokenBase:
def create_refresh_token(
data: dict,
expires_delta: timedelta | None = None,
algorithm: str = "HS256"
) -> RefreshTokenBase:
"""
生成刷新令牌,默认有效期 30 天。
@@ -93,7 +111,12 @@ def create_refresh_token(data: dict, expires_delta: timedelta | None = None,
:return: 包含密钥本身和过期时间的 `RefreshTokenBase`
"""
refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta)
refresh_token, expire_at = build_token_payload(
data,
True,
algorithm,
expires_delta
)
return RefreshTokenBase(
refresh_token=refresh_token,
refresh_expires=expire_at,

View File

@@ -13,7 +13,7 @@ license_info = {"name": "GPLv3", "url": "https://opensource.org/license/gpl-3.0"
BackendVersion = "0.0.1"
"""后端版本"""
IsPro = False
IsPro: bool = False
mode: str = os.getenv('MODE', 'master')
"""运行模式"""

View File

@@ -1,4 +1,5 @@
import secrets
from typing import Literal
from loguru import logger
from argon2 import PasswordHasher
@@ -11,7 +12,23 @@ from pydantic import BaseModel, Field
from utils.JWT import SECRET_KEY
from utils.conf import appmeta
_ph = PasswordHasher()
# FIRST RECOMMENDED option per RFC 9106.
_ph_lowmem = PasswordHasher(
salt_len=16,
hash_len=32,
time_cost=3,
memory_cost=65536, # 64 MiB
parallelism=4,
)
# SECOND RECOMMENDED option per RFC 9106.
_ph_highmem = PasswordHasher(
salt_len=16,
hash_len=32,
time_cost=1,
memory_cost=2097152, # 2 GiB
parallelism=4,
)
class PasswordStatus(StrEnum):
"""密码校验状态枚举"""
@@ -48,7 +65,8 @@ class Password:
@staticmethod
def generate(
length: int = 8
length: int = 8,
url_safe: bool = False
) -> str:
"""
生成指定长度的随机密码。
@@ -58,7 +76,16 @@ class Password:
:return: 随机密码
:rtype: str
"""
return secrets.token_hex(length)
if url_safe:
return secrets.token_urlsafe(length)
else:
return secrets.token_hex(length)
@staticmethod
def generate_hex(
length: int = 8
) -> bytes:
return secrets.token_bytes(length)
@staticmethod
def hash(
@@ -72,7 +99,7 @@ class Password:
:param password: 需要哈希的原始密码
:return: Argon2 哈希字符串
"""
return _ph.hash(password)
return _ph_lowmem.hash(password)
@staticmethod
def verify(
@@ -87,21 +114,16 @@ class Password:
:return: 如果密码匹配返回 True, 否则返回 False
"""
try:
# verify 函数会自动解析 stored_password 中的盐和参数
_ph.verify(hash, password)
_ph_lowmem.verify(hash, password)
# 检查哈希参数是否已过时。如果返回True
# 意味着你应该使用新的参数重新哈希密码并更新存储。
# 这是一个很好的实践,可以随着时间推移增强安全性。
if _ph.check_needs_rehash(hash):
# 检查哈希参数是否已过时
if _ph_lowmem.check_needs_rehash(hash):
logger.warning("密码哈希参数已过时,建议重新哈希并更新。")
return PasswordStatus.EXPIRED
return PasswordStatus.VALID
except VerifyMismatchError:
# 这是预期的异常,当密码不匹配时触发。
return PasswordStatus.INVALID
# 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
@staticmethod
async def generate_totp(