From d2c914cff89e566b0b7ab7a2655ae72905f888f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8E=E5=B0=8F=E4=B8=98?= Date: Mon, 12 Jan 2026 18:07:44 +0800 Subject: [PATCH] Refactor and enhance OAuth2.0 implementation; update models and routes - Refactored AdminSummaryData and AdminSummaryResponse classes for better clarity. - Added OAUTH type to SettingsType enum. - Cleaned up imports in webdav.py. - Updated admin router to improve summary data retrieval and response handling. - Enhanced file management routes with better condition handling and user storage updates. - Improved group management routes by optimizing data retrieval. - Refined task management routes for better condition handling. - Updated user management routes to streamline access token retrieval. - Implemented a new captcha verification structure with abstract base class. - Removed deprecated env.md file and replaced with a new structured version. - Introduced a unified OAuth2.0 client base class for GitHub and QQ integrations. - Enhanced password management with improved hashing strategies. - Added detailed comments and documentation throughout the codebase for clarity. --- .claude/settings.local.json | 9 + .xml | 4407 ------------------------ service/env.md => env.md | 2 +- main.py | 16 +- middleware/auth.py | 8 +- models/migration.py | 3 +- models/mixin/table.py | 170 +- models/model_base.py | 12 +- models/setting.py | 1 + models/webdav.py | 2 +- routers/api/v1/admin/__init__.py | 21 +- routers/api/v1/admin/file/__init__.py | 22 +- routers/api/v1/admin/group/__init__.py | 5 +- routers/api/v1/admin/task/__init__.py | 8 +- routers/api/v1/admin/user/__init__.py | 4 +- routers/api/v1/file/__init__.py | 15 +- routers/api/v1/site/__init__.py | 3 +- routers/api/v1/user/__init__.py | 9 +- service/captcha/__init__.py | 36 +- service/captcha/gcaptcha.py | 24 +- service/captcha/turnstile.py | 24 +- service/oauth/__init__.py | 223 ++ service/oauth/github.py | 122 +- service/oauth/qq.py | 159 +- service/user/login.py | 4 +- tests/fixtures/groups.py | 23 +- utils/JWT/__init__.py | 43 +- utils/conf/appmeta.py | 2 +- utils/password/pwd.py | 46 +- 29 files changed, 814 insertions(+), 4609 deletions(-) create mode 100644 .claude/settings.local.json delete mode 100644 .xml rename service/env.md => env.md (92%) create mode 100644 service/oauth/__init__.py diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..b06f87e --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,9 @@ +{ + "permissions": { + "allow": [ + "Bash(git rev-parse:*)", + "Bash(findstr:*)", + "Bash(find:*)" + ] + } +} diff --git a/.xml b/.xml deleted file mode 100644 index 81272a6..0000000 --- a/.xml +++ /dev/null @@ -1,4407 +0,0 @@ - - - - - - C:\Users\Administrator\Documents\Code\Server\middleware - C:\Users\Administrator\Documents\Code\Server\models - C:\Users\Administrator\Documents\Code\Server\routers - C:\Users\Administrator\Documents\Code\Server\service - C:\Users\Administrator\Documents\Code\Server\utils - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/service/env.md b/env.md similarity index 92% rename from service/env.md rename to env.md index d5e9bad..29e3b15 100644 --- a/service/env.md +++ b/env.md @@ -9,4 +9,4 @@ - `REDIS_PORT`: Redis 端口 - `REDIS_PASSWORD`: Redis 密码 - `REDIS_DB`: Redis 数据库 -- `REDIS_PROTOCOL` \ No newline at end of file +- `REDIS_PROTOCOL`: Redis 协议 diff --git a/main.py b/main.py index f1a3335..9052746 100644 --- a/main.py +++ b/main.py @@ -33,8 +33,22 @@ app = FastAPI( openapi_url="/openapi.json" if appmeta.debug else None, ) +# 添加跨域 CORS 中间件,仅在调试模式下启用,以允许所有来源访问 API +if appmeta.debug: + from fastapi.middleware.cors import CORSMiddleware + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + @app.exception_handler(Exception) -async def handle_unexpected_exceptions(request: Request, exc: Exception) -> NoReturn: +async def handle_unexpected_exceptions( + request: Request, + exc: Exception +) -> NoReturn: """ 捕获所有未经处理的 FastAPI 异常,防止敏感信息泄露。 """ diff --git a/middleware/auth.py b/middleware/auth.py index b57a664..e39d435 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -60,10 +60,10 @@ def verify_download_token(token: str) -> tuple[str, UUID, UUID] | None: try: payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=["HS256"]) if payload.get("type") != "download": - return None + http_exceptions.raise_unauthorized("Download token required") jti = payload.get("jti") if not jti: - return None + http_exceptions.raise_unauthorized("Download token required") return jti, UUID(payload["file_id"]), UUID(payload["owner_id"]) - except (jwt.ExpiredSignatureError, jwt.InvalidTokenError): - return None \ No newline at end of file + except jwt.InvalidTokenError: + http_exceptions.raise_unauthorized("Download token required") \ No newline at end of file diff --git a/models/migration.py b/models/migration.py index bc747f5..3733f24 100644 --- a/models/migration.py +++ b/models/migration.py @@ -120,7 +120,6 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti async def init_default_settings() -> None: from .setting import Setting from .database import get_session - from sqlalchemy import and_ log.info('初始化设置...') @@ -128,7 +127,7 @@ async def init_default_settings() -> None: # 检查是否已经存在版本设置 ver = await Setting.get( session, - and_(Setting.type == SettingsType.VERSION, Setting.name == f"db_version_{BackendVersion}") + (Setting.type == SettingsType.VERSION) & (Setting.name == f"db_version_{BackendVersion}") ) if ver and ver.value == "installed": return diff --git a/models/mixin/table.py b/models/mixin/table.py index c133ff5..419f298 100644 --- a/models/mixin/table.py +++ b/models/mixin/table.py @@ -218,7 +218,7 @@ class TableBaseMixin(AsyncAttrs): ) @classmethod - async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]: + async def add(cls: type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True, commit: bool = True) -> T | list[T]: """ 向数据库中添加一个新的或多个新的记录. @@ -230,6 +230,8 @@ class TableBaseMixin(AsyncAttrs): session (AsyncSession): 用于数据库操作的异步会话对象. instances (T | list[T]): 要添加的单个模型实例或模型实例列表. refresh (bool): 如果为 True, 将在提交后刷新实例以同步数据库状态. 默认为 True. + commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数, + 之后需要手动调用 `session.commit()`。默认为 True. Returns: T | list[T]: 已添加并(可选地)刷新的一个或多个模型实例. @@ -244,6 +246,11 @@ class TableBaseMixin(AsyncAttrs): # 添加单个实例 item3 = Item(name="Cherry") added_item = await Item.add(session, item3) + + # 批量操作,减少提交次数 + await Item.add(session, [item1, item2], commit=False) + await Item.add(session, [item3, item4], commit=False) + await session.commit() """ is_list = False if isinstance(instances, list): @@ -252,7 +259,10 @@ class TableBaseMixin(AsyncAttrs): else: session.add(instances) - await session.commit() + if commit: + await session.commit() + else: + await session.flush() if refresh: if is_list: @@ -266,15 +276,16 @@ class TableBaseMixin(AsyncAttrs): async def save( self: T, session: AsyncSession, - load: RelationshipInfo | None = None, - refresh: bool = True + load: RelationshipInfo | list[RelationshipInfo] | None = None, + refresh: bool = True, + commit: bool = True ) -> T: """ 保存(插入或更新)当前模型实例到数据库. 这是一个实例方法,它将当前对象添加到会话中并提交更改。 可以用于创建新记录或更新现有记录。还可以选择在保存后 - 预加载(eager load)一个关联关系. + 预加载(eager load)一个或多个关联关系. **重要**:调用此方法后,session中的所有对象都会过期(expired)。 如果需要继续使用该对象,必须使用返回值: @@ -287,6 +298,14 @@ class TableBaseMixin(AsyncAttrs): # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 await client.save(session, refresh=False) + # ✅ 正确:批量操作,减少提交次数 + await item1.save(session, commit=False) + await item2.save(session, commit=False) + await session.commit() + + # ✅ 正确:批量操作并预加载多个关联关系 + user = await user.save(session, load=[User.group, User.tags]) + # ❌ 错误:需要返回值但未使用 await client.save(session) return client # client 对象已过期 @@ -294,16 +313,22 @@ class TableBaseMixin(AsyncAttrs): Args: session (AsyncSession): 用于数据库操作的异步会话对象. - load (Relationship | None): 可选的,指定在保存和刷新后要预加载的关联属性. - 例如 `User.posts`. + load (Relationship | list[Relationship] | None): 可选的,指定在保存和刷新后要预加载的关联属性. + 可以是单个关系或关系列表. + 例如 `User.posts` 或 `[User.group, User.tags]`. refresh (bool): 是否在保存后刷新对象。如果不需要使用返回值, 设为 False 可节省一次数据库查询。默认为 True. + commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数, + 之后需要手动调用 `session.commit()`。默认为 True. Returns: T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self. """ session.add(self) - await session.commit() + if commit: + await session.commit() + else: + await session.flush() if not refresh: return self @@ -324,8 +349,9 @@ class TableBaseMixin(AsyncAttrs): extra_data: dict[str, Any] | None = None, exclude_unset: bool = True, exclude: set[str] | None = None, - load: RelationshipInfo | None = None, - refresh: bool = True + load: RelationshipInfo | list[RelationshipInfo] | None = None, + refresh: bool = True, + commit: bool = True ) -> T: """ 使用另一个模型实例或字典中的数据来更新当前实例. @@ -348,6 +374,14 @@ class TableBaseMixin(AsyncAttrs): # ✅ 正确:不需要返回值时,指定 refresh=False 节省性能 await client.update(session, update_data, refresh=False) + # ✅ 正确:批量操作,减少提交次数 + await user1.update(session, data1, commit=False) + await user2.update(session, data2, commit=False) + await session.commit() + + # ✅ 正确:批量操作并预加载多个关联关系 + user = await user.update(session, data, load=[User.group, User.tags]) + # ❌ 错误:需要返回值但未使用 await client.update(session, update_data) return client # client 对象已过期 @@ -360,10 +394,13 @@ class TableBaseMixin(AsyncAttrs): exclude_unset (bool): 如果为 True, `other` 对象中未设置(即值为 None 或未提供) 的字段将被忽略. 默认为 True. exclude (set[str] | None): 要从更新中排除的字段名集合。例如 {'permission'}. - load (RelationshipInfo | None): 可选的,指定在更新和刷新后要预加载的关联属性. - 例如 `User.permission`. + load (Relationship | list[Relationship] | None): 可选的,指定在更新和刷新后要预加载的关联属性. + 可以是单个关系或关系列表. + 例如 `User.permission` 或 `[User.group, User.tags]`. refresh (bool): 是否在更新后刷新对象。如果不需要使用返回值, 设为 False 可节省一次数据库查询。默认为 True. + commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数, + 之后需要手动调用 `session.commit()`。默认为 True. Returns: T: 如果 refresh=True,返回已刷新的模型实例;否则返回未刷新的 self. @@ -374,7 +411,10 @@ class TableBaseMixin(AsyncAttrs): ) session.add(self) - await session.commit() + if commit: + await session.commit() + else: + await session.flush() if not refresh: return self @@ -388,33 +428,82 @@ class TableBaseMixin(AsyncAttrs): return self @classmethod - async def delete(cls: type[T], session: AsyncSession, instances: T | list[T]) -> None: + async def delete( + cls: type[T], + session: AsyncSession, + instances: T | list[T] | None = None, + *, + condition: BinaryExpression | ClauseElement | None = None, + commit: bool = True + ) -> int: """ - 从数据库中删除一个或多个记录. + 从数据库中删除记录. + + 支持两种删除方式: + 1. 实例删除:传入 instances 参数,先加载再删除 + 2. 条件删除:传入 condition 参数,直接 SQL 删除(更高效) Args: session (AsyncSession): 用于数据库操作的异步会话对象. - instances (T | list[T]): 要删除的单个模型实例或模型实例列表. + instances (T | list[T] | None): 要删除的单个模型实例或模型实例列表(可选). + condition (BinaryExpression | ClauseElement | None): 删除条件(可选,与 instances 二选一). + commit (bool): 是否提交事务。设为 False 可在批量操作时减少提交次数, + 之后需要手动调用 `session.commit()`。默认为 True. Returns: - None + int: 删除的记录数量 Usage: + # 实例删除 item_to_delete = await Item.get(session, Item.id == 1) if item_to_delete: - await Item.delete(session, item_to_delete) + deleted_count = await Item.delete(session, item_to_delete) - items_to_delete = await Item.get(session, Item.name.in_(["Apple", "Banana"]), fetch_mode="all") - if items_to_delete: - await Item.delete(session, items_to_delete) + # 条件删除(更高效,无需加载实例) + deleted_count = await Item.delete( + session, + condition=(Item.status == "inactive") & (Item.created_at < cutoff_date) + ) + + # 批量删除后手动提交 + await Item.delete(session, item1, commit=False) + await Item.delete(session, item2, commit=False) + await session.commit() """ + # 条件删除模式 + if condition is not None: + from sqlmodel import delete as sql_delete + + if instances is not None: + raise ValueError("不能同时指定 instances 和 condition") + + # 执行条件删除 + stmt = sql_delete(cls).where(condition) + result = await session.exec(stmt) + deleted_count = result.rowcount + + if commit: + await session.commit() + + return deleted_count + + # 实例删除模式(原有逻辑) + if instances is None: + raise ValueError("必须指定 instances 或 condition") + + deleted_count = 0 if isinstance(instances, list): for instance in instances: await session.delete(instance) + deleted_count += 1 else: await session.delete(instances) + deleted_count = 1 - await session.commit() + if commit: + await session.commit() + + return deleted_count @classmethod def _build_time_filters( @@ -458,7 +547,7 @@ class TableBaseMixin(AsyncAttrs): fetch_mode: Literal["one", "first", "all"] = "first", join: type[T] | tuple[type[T], _OnClauseArgument] | None = None, options: list | None = None, - load: RelationshipInfo | None = None, + load: RelationshipInfo | list[RelationshipInfo] | None = None, order_by: list[ClauseElement] | None = None, filter: BinaryExpression | ClauseElement | None = None, with_for_update: bool = False, @@ -491,8 +580,9 @@ class TableBaseMixin(AsyncAttrs): 例如 `User` 或 `(Profile, User.id == Profile.user_id)`. options (list | None): SQLAlchemy 查询选项列表, 通常用于预加载关联数据, 例如 `[selectinload(User.posts)]`. - load (Relationship | None): `selectinload` 的快捷方式,用于预加载单个关联关系. - 例如 `User.profile`. + load (Relationship | list[Relationship] | None): `selectinload` 的快捷方式,用于预加载关联关系. + 可以是单个关系或关系列表. + 例如 `User.profile` 或 `[User.group, User.tags]`. order_by (list[ClauseElement] | None): 用于排序的排序列或表达式的列表. 例如 `[User.name.asc(), User.created_at.desc()]`. filter (BinaryExpression | ClauseElement | None): 附加的过滤条件. @@ -595,9 +685,15 @@ class TableBaseMixin(AsyncAttrs): statement = statement.options(*options) if load: + # 标准化为列表 + load_list = load if isinstance(load, list) else [load] + # 处理多态加载 if load_polymorphic is not None: - target_class = load.property.mapper.class_ + # 多态加载只支持单个关系 + if len(load_list) > 1: + raise ValueError("load_polymorphic 仅支持单个关系") + target_class = load_list[0].property.mapper.class_ # 检查目标类是否继承自 PolymorphicBaseMixin if not issubclass(target_class, PolymorphicBaseMixin): @@ -609,7 +705,7 @@ class TableBaseMixin(AsyncAttrs): if load_polymorphic == 'all': # 两阶段查询:获取实际关联的多态类型 subclasses_to_load = await cls._resolve_polymorphic_subclasses( - session, condition, load, target_class + session, condition, load_list[0], target_class ) else: subclasses_to_load = load_polymorphic @@ -618,12 +714,14 @@ class TableBaseMixin(AsyncAttrs): # 关键:selectin_polymorphic 必须作为 selectinload 的链式子选项 # 参考: https://docs.sqlalchemy.org/en/20/orm/queryguide/relationships.html#polymorphic-eager-loading statement = statement.options( - selectinload(load).selectin_polymorphic(subclasses_to_load) + selectinload(load_list[0]).selectin_polymorphic(subclasses_to_load) ) else: - statement = statement.options(selectinload(load)) + statement = statement.options(selectinload(load_list[0])) else: - statement = statement.options(selectinload(load)) + # 为每个关系添加 selectinload + for rel in load_list: + statement = statement.options(selectinload(rel)) if order_by is not None: statement = statement.order_by(*order_by) @@ -796,7 +894,7 @@ class TableBaseMixin(AsyncAttrs): *, join: type[T] | tuple[type[T], _OnClauseArgument] | None = None, options: list | None = None, - load: RelationshipInfo | None = None, + load: RelationshipInfo | list[RelationshipInfo] | None = None, order_by: list[ClauseElement] | None = None, filter: BinaryExpression | ClauseElement | None = None, table_view: TableViewRequest | None = None, @@ -865,7 +963,7 @@ class TableBaseMixin(AsyncAttrs): return ListResponse(count=total_count, items=items) @classmethod - async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | None = None) -> T: + async def get_exist_one(cls: type[T], session: AsyncSession, id: int, load: RelationshipInfo | list[RelationshipInfo] | None = None) -> T: """ 根据主键 ID 获取一个存在的记录, 如果不存在则抛出 404 异常. @@ -875,7 +973,8 @@ class TableBaseMixin(AsyncAttrs): Args: session (AsyncSession): 用于数据库操作的异步会话对象. id (int): 要查找的记录的主键 ID. - load (Relationship | None): 可选的,用于预加载的关联属性. + load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性. + 可以是单个关系或关系列表. Returns: T: 找到的模型实例. @@ -903,7 +1002,7 @@ class UUIDTableBaseMixin(TableBaseMixin): @override @classmethod - async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | None = None) -> T: + async def get_exist_one(cls: type[T], session: AsyncSession, id: uuid.UUID, load: Relationship | list[Relationship] | None = None) -> T: """ 根据 UUID 主键获取一个存在的记录, 如果不存在则抛出 404 异常. @@ -913,7 +1012,8 @@ class UUIDTableBaseMixin(TableBaseMixin): Args: session (AsyncSession): 用于数据库操作的异步会话对象. id (uuid.UUID): 要查找的记录的 UUID 主键. - load (Relationship | None): 可选的,用于预加载的关联属性. + load (Relationship | list[Relationship] | None): 可选的,用于预加载的关联属性. + 可以是单个关系或关系列表. Returns: T: 找到的模型实例. diff --git a/models/model_base.py b/models/model_base.py index a2f76f5..913a4e5 100644 --- a/models/model_base.py +++ b/models/model_base.py @@ -79,9 +79,8 @@ class VersionInfo(SQLModelBase): commit: str """提交哈希""" - -class AdminSummaryData(SQLModelBase): - """管理员概况数据""" +class AdminSummaryResponse(ResponseBase): + """管理员概况响应""" metrics_summary: MetricsSummary """统计摘要""" @@ -95,13 +94,6 @@ class AdminSummaryData(SQLModelBase): version: VersionInfo """版本信息""" - -class AdminSummaryResponse(ResponseBase): - """管理员概况响应""" - - data: AdminSummaryData | None = None - """响应数据""" - class MCPMethod(StrEnum): """MCP 方法枚举""" diff --git a/models/setting.py b/models/setting.py index 10dfeb1..5e6b5cf 100644 --- a/models/setting.py +++ b/models/setting.py @@ -104,6 +104,7 @@ class SettingsType(StrEnum): MAIL = "mail" MAIL_TEMPLATE = "mail_template" MOBILE = "mobile" + OAUTH = "oauth" PATH = "path" PREVIEW = "preview" PWA = "pwa" diff --git a/models/webdav.py b/models/webdav.py index 97098e7..8d7f9ce 100644 --- a/models/webdav.py +++ b/models/webdav.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING from uuid import UUID -from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime +from sqlmodel import Field, Relationship, UniqueConstraint from .base import SQLModelBase from .mixin import TableBaseMixin diff --git a/routers/api/v1/admin/__init__.py b/routers/api/v1/admin/__init__.py index 1ad8eee..455ddeb 100644 --- a/routers/api/v1/admin/__init__.py +++ b/routers/api/v1/admin/__init__.py @@ -2,14 +2,12 @@ from datetime import datetime, timedelta from fastapi import APIRouter, Depends from loguru import logger as l -from sqlalchemy import and_ from middleware.auth import admin_required from middleware.dependencies import SessionDep from models import ( User, ResponseBase, Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo, - AdminSummaryData, ) from models.base import SQLModelBase from models.setting import ( @@ -75,8 +73,8 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse: Returns: AdminSummaryResponse: 包含站点概况信息的响应模型。 """ - # 统计最近 12 天的数据 - days_count = 12 + # 统计最近 14 天的数据 + days_count = 14 now = datetime.now() today_start = now.replace(hour=0, minute=0, second=0, microsecond=0) @@ -135,7 +133,7 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse: site_urls: list[str] = [] site_url_setting = await Setting.get( session, - and_(Setting.type == SettingsType.BASIC, Setting.name == "siteURL"), + (Setting.type == SettingsType.BASIC) & (Setting.name == "siteURL"), ) if site_url_setting and site_url_setting.value: site_urls.append(site_url_setting.value) @@ -156,15 +154,13 @@ async def router_admin_get_summary(session: SessionDep) -> AdminSummaryResponse: commit="dev", ) - data = AdminSummaryData( + return AdminSummaryResponse( metrics_summary=metrics_summary, site_urls=site_urls, license=license_info, version=version_info, ) - return AdminSummaryResponse(data=data) - @admin_router.get( path='/news', summary='获取社区新闻', @@ -203,7 +199,7 @@ async def router_admin_update_settings( for item in request.settings: existing = await Setting.get( session, - and_(Setting.type == item.type, Setting.name == item.name) + (Setting.type == item.type) & (Setting.name == item.name) ) if existing: @@ -245,7 +241,12 @@ async def router_admin_get_settings( if name: conditions.append(Setting.name == name) - condition = and_(*conditions) if conditions else None + if conditions: + condition = conditions[0] + for c in conditions[1:]: + condition = condition & c + else: + condition = None settings = await Setting.get(session, condition, fetch_mode="all") diff --git a/routers/api/v1/admin/file/__init__.py b/routers/api/v1/admin/file/__init__.py index 0296af2..df53574 100644 --- a/routers/api/v1/admin/file/__init__.py +++ b/routers/api/v1/admin/file/__init__.py @@ -5,7 +5,6 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import FileResponse from loguru import logger as l -from sqlalchemy import and_ from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep @@ -51,7 +50,12 @@ async def router_admin_get_file_list( if keyword: conditions.append(Object.name.ilike(f"%{keyword}%")) - condition = and_(*conditions) if len(conditions) > 1 else conditions[0] + if len(conditions) > 1: + condition = conditions[0] + for c in conditions[1:]: + condition = condition & c + else: + condition = conditions[0] result = await Object.get_with_count(session, condition, table_view=table_view, load=Object.owner) # 构建响应 @@ -197,13 +201,15 @@ async def router_admin_delete_file( except Exception as e: l.warning(f"删除物理文件失败: {e}") - # 更新用户存储量 - owner = await User.get(session, User.id == owner_id) - if owner: - owner.storage = max(0, owner.storage - file_size) - await owner.save(session) + # 更新用户存储量(使用 SQL UPDATE 直接更新,无需加载实例) + from sqlmodel import update as sql_update + stmt = sql_update(User).where(User.id == owner_id).values( + storage=max(0, User.storage - file_size) + ) + await session.exec(stmt) - await Object.delete(session, file_obj) + # 使用条件删除 + await Object.delete(session, condition=Object.id == file_obj.id) l.info(f"管理员删除了文件: {file_name}") return ResponseBase(data={"deleted": True}) \ No newline at end of file diff --git a/routers/api/v1/admin/group/__init__.py b/routers/api/v1/admin/group/__init__.py index d010f27..7e58f4f 100644 --- a/routers/api/v1/admin/group/__init__.py +++ b/routers/api/v1/admin/group/__init__.py @@ -63,12 +63,13 @@ async def router_admin_get_group( :param group_id: 用户组UUID :return: 用户组详情 """ - group = await Group.get(session, Group.id == group_id, load=Group.options) + group = await Group.get(session, Group.id == group_id, load=[Group.options, Group.policies]) if not group: raise HTTPException(status_code=404, detail="用户组不存在") - policies = await group.awaitable_attrs.policies + # 直接访问已加载的关系,无需额外查询 + policies = group.policies user_count = await User.count(session, User.group_id == group_id) response = GroupDetailResponse.from_group(group, user_count, policies) diff --git a/routers/api/v1/admin/task/__init__.py b/routers/api/v1/admin/task/__init__.py index 558c1c8..e4c8577 100644 --- a/routers/api/v1/admin/task/__init__.py +++ b/routers/api/v1/admin/task/__init__.py @@ -2,7 +2,6 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException from loguru import logger as l -from sqlalchemy import and_ from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep @@ -43,7 +42,12 @@ async def router_admin_get_task_list( if status: conditions.append(Task.status == status) - condition = and_(*conditions) if conditions else None + if conditions: + condition = conditions[0] + for c in conditions[1:]: + condition = condition & c + else: + condition = None result = await Task.get_with_count(session, condition, table_view=table_view, load=Task.user) items: list[TaskSummary] = [] diff --git a/routers/api/v1/admin/user/__init__.py b/routers/api/v1/admin/user/__init__.py index 5357799..f7f5f77 100644 --- a/routers/api/v1/admin/user/__init__.py +++ b/routers/api/v1/admin/user/__init__.py @@ -2,7 +2,7 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException from loguru import logger as l -from sqlalchemy import func, and_ +from sqlalchemy import func from middleware.auth import admin_required from middleware.dependencies import SessionDep, TableViewRequestDep @@ -198,7 +198,7 @@ async def router_admin_calibrate_storage( from sqlmodel import select result = await session.execute( select(func.sum(Object.size), func.count(Object.id)).where( - and_(Object.owner_id == user_id, Object.type == ObjectType.FILE) + (Object.owner_id == user_id) & (Object.type == ObjectType.FILE) ) ) row = result.one() diff --git a/routers/api/v1/file/__init__.py b/routers/api/v1/file/__init__.py index 9eabfb0..81fee7b 100644 --- a/routers/api/v1/file/__init__.py +++ b/routers/api/v1/file/__init__.py @@ -233,7 +233,7 @@ async def upload_chunk( policy_id=upload_session.policy_id, reference_count=1, ) - physical_file = await physical_file.save(session) + physical_file = await physical_file.save(session, commit=False) # 创建 Object 记录 file_object = Object( @@ -246,11 +246,18 @@ async def upload_chunk( owner_id=user_id, policy_id=upload_session.policy_id, ) - file_object = await file_object.save(session) + file_object = await file_object.save(session, commit=False) file_object_id = file_object.id - # 删除上传会话 - await UploadSession.delete(session, upload_session) + # 删除上传会话(使用条件删除) + await UploadSession.delete( + session, + condition=UploadSession.id == upload_session.id, + commit=False + ) + + # 统一提交所有更改 + await session.commit() l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}") diff --git a/routers/api/v1/site/__init__.py b/routers/api/v1/site/__init__.py index b05f02e..69b0db6 100644 --- a/routers/api/v1/site/__init__.py +++ b/routers/api/v1/site/__init__.py @@ -1,5 +1,4 @@ from fastapi import APIRouter -from sqlalchemy import and_ from middleware.dependencies import SessionDep from models import ResponseBase, Setting, SettingsType, SiteConfigResponse @@ -55,5 +54,5 @@ async def router_site_config(session: SessionDep) -> SiteConfigResponse: dict: The site configuration. """ return SiteConfigResponse( - title=await Setting.get(session, and_(Setting.type == SettingsType.BASIC, Setting.name == "siteName")), + title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")), ) \ No newline at end of file diff --git a/routers/api/v1/user/__init__.py b/routers/api/v1/user/__init__.py index 5b07530..f951ed9 100644 --- a/routers/api/v1/user/__init__.py +++ b/routers/api/v1/user/__init__.py @@ -3,7 +3,6 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm -from sqlalchemy import and_ from webauthn import generate_registration_options from webauthn.helpers import options_to_json_dict from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired @@ -115,7 +114,7 @@ async def router_user_register( # 2. 获取默认用户组(从设置中读取 UUID) default_group_setting: models.Setting | None = await models.Setting.get( session, - and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group") + (models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group") ) if default_group_setting is None or not default_group_setting.value: logger.error("默认用户组不存在") @@ -352,18 +351,18 @@ async def router_user_authn_start( # TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等 authn_setting = await models.Setting.get( session, - and_(models.Setting.type == "authn", models.Setting.name == "authn_enabled") + (models.Setting.type == "authn") & (models.Setting.name == "authn_enabled") ) if not authn_setting or authn_setting.value != "1": raise HTTPException(status_code=400, detail="WebAuthn is not enabled") site_url_setting = await models.Setting.get( session, - and_(models.Setting.type == "basic", models.Setting.name == "siteURL") + (models.Setting.type == "basic") & (models.Setting.name == "siteURL") ) site_title_setting = await models.Setting.get( session, - and_(models.Setting.type == "basic", models.Setting.name == "siteTitle") + (models.Setting.type == "basic") & (models.Setting.name == "siteTitle") ) options = generate_registration_options( diff --git a/service/captcha/__init__.py b/service/captcha/__init__.py index a5565fb..9ccccc8 100644 --- a/service/captcha/__init__.py +++ b/service/captcha/__init__.py @@ -1,5 +1,39 @@ +import abc +import aiohttp + from pydantic import BaseModel +from .gcaptcha import GCaptcha +from .turnstile import TurnstileCaptcha + + class CaptchaRequestBase(BaseModel): + """验证码验证请求""" token: str - secret: str \ No newline at end of file + """验证 token""" + secret: str + """验证密钥""" + + +class CaptchaBase(abc.ABC): + """验证码验证器抽象基类""" + + verify_url: str + """验证 API 地址(子类必须定义)""" + + async def verify_captcha(self, request: CaptchaRequestBase) -> bool: + """ + 验证 token 是否有效。 + + :return: 如果验证成功返回 True,否则返回 False + :rtype: bool + """ + payload = request.model_dump() + + async with aiohttp.ClientSession() as session: + async with session.post(self.verify_url, data=payload) as response: + if response.status != 200: + return False + + result = await response.json() + return result.get('success', False) \ No newline at end of file diff --git a/service/captcha/gcaptcha.py b/service/captcha/gcaptcha.py index bb9e2ac..b30deb1 100644 --- a/service/captcha/gcaptcha.py +++ b/service/captcha/gcaptcha.py @@ -1,21 +1,7 @@ -import aiohttp +from . import CaptchaBase -from . import CaptchaRequestBase -async def verify_captcha(request: CaptchaRequestBase) -> bool: - """ - 验证 Google reCAPTCHA v2/v3 的 token 是否有效。 - - :return: 如果验证成功返回 True,否则返回 False - :rtype: bool - """ - verify_url = "https://www.google.com/recaptcha/api/siteverify" - payload = request.model_dump() - - async with aiohttp.ClientSession() as session: - async with session.post(verify_url, data=payload) as response: - if response.status != 200: - return False - - result = await response.json() - return result.get('success', False) \ No newline at end of file +class GCaptcha(CaptchaBase): + """Google reCAPTCHA v2/v3 验证器""" + + verify_url = "https://www.google.com/recaptcha/api/siteverify" \ No newline at end of file diff --git a/service/captcha/turnstile.py b/service/captcha/turnstile.py index eba255f..8ab43fc 100644 --- a/service/captcha/turnstile.py +++ b/service/captcha/turnstile.py @@ -1,21 +1,7 @@ -import aiohttp +from . import CaptchaBase -from . import CaptchaRequestBase -async def verify_captcha(request: CaptchaRequestBase) -> bool: - """ - 验证 Turnstile 的 token 是否有效。 - - :return: 如果验证成功返回 True,否则返回 False - :rtype: bool - """ - verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify" - payload = request.model_dump() - - async with aiohttp.ClientSession() as session: - async with session.post(verify_url, data=payload) as response: - if response.status != 200: - return False - - result = await response.json() - return result.get('success', False) \ No newline at end of file +class TurnstileCaptcha(CaptchaBase): + """Cloudflare Turnstile 验证器""" + + verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify" \ No newline at end of file diff --git a/service/oauth/__init__.py b/service/oauth/__init__.py new file mode 100644 index 0000000..c396e7b --- /dev/null +++ b/service/oauth/__init__.py @@ -0,0 +1,223 @@ +""" +OAuth2.0 认证模块 + +提供统一的 OAuth2.0 客户端基类,支持多种第三方登录平台。 +""" +import abc +import aiohttp + +from pydantic import BaseModel + + +# ==================== 共享数据模型 ==================== + +class AccessTokenBase(BaseModel): + """访问令牌基类""" + access_token: str + """访问令牌""" + + +class OAuthUserData(BaseModel): + """OAuth 用户数据通用 DTO""" + openid: str + """用户唯一标识(GitHub 为 id,QQ 为 openid)""" + nickname: str | None + """用户昵称""" + avatar_url: str | None + """头像 URL""" + email: str | None + """邮箱""" + bio: str | None + """个人简介""" + + +class OAuthUserInfoResponse(BaseModel): + """OAuth 用户信息响应""" + code: str + """状态码""" + openid: str + """用户唯一标识""" + user_data: OAuthUserData + """用户数据""" + + +# ==================== OAuth2.0 抽象基类 ==================== + +class OAuthBase(abc.ABC): + """ + OAuth2.0 客户端抽象基类 + + 子类需要定义以下类属性: + - access_token_url: 获取 Access Token 的 API 地址 + - user_info_url: 获取用户信息的 API 地址 + - http_method: 获取 token 的 HTTP 方法(POST 或 GET) + """ + + # 子类必须定义的类属性 + access_token_url: str + """获取 Access Token 的 API 地址""" + + user_info_url: str + """获取用户信息的 API 地址""" + + http_method: str = "POST" + """获取 token 的 HTTP 方法:POST 或 GET""" + + # 实例属性(构造函数传入) + client_id: str + client_secret: str + + def __init__(self, client_id: str, client_secret: str) -> None: + """ + 初始化 OAuth 客户端 + + Args: + client_id: 应用 client_id + client_secret: 应用 client_secret + """ + self.client_id = client_id + self.client_secret = client_secret + + async def get_access_token(self, code: str, **kwargs) -> AccessTokenBase: + """ + 通过 Authorization Code 获取 Access Token + + Args: + code: 授权码 + **kwargs: 额外参数(如 QQ 需要 redirect_uri) + + Returns: + AccessTokenBase: 访问令牌 + """ + params = { + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'code': code, + } + params.update(kwargs) + + async with aiohttp.ClientSession() as session: + if self.http_method == "POST": + async with session.post( + url=self.access_token_url, + params=params, + headers={'accept': 'application/json'}, + ) as access_resp: + access_data = await access_resp.json() + return self._parse_token_response(access_data) + else: + async with session.get( + url=self.access_token_url, + params=params, + ) as access_resp: + access_data = await access_resp.json() + return self._parse_token_response(access_data) + + async def get_user_info( + self, + access_token: str | AccessTokenBase, + **kwargs + ) -> OAuthUserInfoResponse: + """ + 获取用户信息 + + Args: + access_token: 访问令牌 + **kwargs: 额外参数(如 QQ 需要 app_id, openid) + + Returns: + OAuthUserInfoResponse: 用户信息 + """ + if isinstance(access_token, AccessTokenBase): + access_token = access_token.access_token + + async with aiohttp.ClientSession() as session: + async with session.get( + url=self.user_info_url, + params=self._build_user_info_params(access_token, **kwargs), + headers=self._build_user_info_headers(access_token), + ) as resp: + user_data = await resp.json() + return self._parse_user_response(user_data) + + # ==================== 钩子方法(子类可覆盖) ==================== + + def _build_user_info_params(self, access_token: str, **kwargs) -> dict: + """ + 构建获取用户信息的请求参数 + + Args: + access_token: 访问令牌 + **kwargs: 额外参数 + + Returns: + dict: 请求参数 + """ + return {} + + def _build_user_info_headers(self, access_token: str) -> dict: + """ + 构建获取用户信息的请求头 + + Args: + access_token: 访问令牌 + + Returns: + dict: 请求头 + """ + return { + 'accept': 'application/json', + } + + def _parse_token_response(self, data: dict) -> AccessTokenBase: + """ + 解析 token 响应 + + Args: + data: API 返回的数据 + + Returns: + AccessTokenBase: 访问令牌 + """ + return AccessTokenBase(access_token=data.get('access_token')) + + def _parse_user_response(self, data: dict) -> OAuthUserInfoResponse: + """ + 解析用户信息响应 + + Args: + data: API 返回的数据 + + Returns: + OAuthUserInfoResponse: 用户信息 + """ + return OAuthUserInfoResponse( + code='0', + openid='', + user_data=OAuthUserData(openid=''), + ) + + +# ==================== 导出 ==================== + +from .github import GithubOAuth, GithubAccessToken, GithubUserData +from .qq import QQOAuth, QQAccessToken, QQOpenIDResponse, QQUserData + +__all__ = [ + # 共享模型 + 'AccessTokenBase', + 'OAuthUserData', + 'OAuthUserInfoResponse', + 'OAuthBase', + + # GitHub + 'GithubOAuth', + 'GithubAccessToken', + 'GithubUserData', + + # QQ + 'QQOAuth', + 'QQAccessToken', + 'QQOpenIDResponse', + 'QQUserData', +] diff --git a/service/oauth/github.py b/service/oauth/github.py index 9b5ec21..2f14dc8 100644 --- a/service/oauth/github.py +++ b/service/oauth/github.py @@ -1,77 +1,127 @@ -from pydantic import BaseModel -import aiohttp +"""GitHub OAuth2.0 认证实现""" +from typing import TYPE_CHECKING -class GithubAccessToken(BaseModel): - access_token: str +from pydantic import BaseModel +from . import AccessTokenBase, OAuthBase, OAuthUserInfoResponse + +if TYPE_CHECKING: + from . import OAuthUserData + + +class GithubAccessToken(AccessTokenBase): + """GitHub 访问令牌响应""" token_type: str + """令牌类型""" scope: str + """授权范围""" + class GithubUserData(BaseModel): + """GitHub 用户数据""" login: str + """用户名""" id: int + """用户 ID""" node_id: str + """节点 ID""" avatar_url: str + """头像 URL""" gravatar_id: str | None + """Gravatar ID""" url: str + """API URL""" html_url: str + """主页 URL""" followers_url: str + """粉丝列表 URL""" following_url: str + """关注列表 URL""" gists_url: str + """Gists 列表 URL""" starred_url: str + """星标列表 URL""" subscriptions_url: str + """订阅列表 URL""" organizations_url: str + """组织列表 URL""" repos_url: str + """仓库列表 URL""" events_url: str + """事件列表 URL""" received_events_url: str + """接收的事件列表 URL""" type: str + """用户类型""" site_admin: bool + """是否为站点管理员""" name: str | None + """显示名称""" company: str | None + """公司""" blog: str | None + """博客""" location: str | None + """位置""" email: str | None + """邮箱""" hireable: bool | None + """是否可雇佣""" bio: str | None + """个人简介""" twitter_username: str | None + """Twitter 用户名""" public_repos: int + """公开仓库数""" public_gists: int + """公开 Gists 数""" followers: int + """粉丝数""" following: int - created_at: str # ISO 8601 format date-time string - updated_at: str # ISO 8601 format date-time string + """关注数""" + created_at: str + """创建时间(ISO 8601 格式)""" + updated_at: str + """更新时间(ISO 8601 格式)""" + class GithubUserInfoResponse(BaseModel): + """GitHub 用户信息响应""" code: str + """状态码""" user_data: GithubUserData + """用户数据""" -async def get_access_token(code: str) -> GithubAccessToken: - async with aiohttp.ClientSession() as session: - async with session.post( - url='https://github.com/login/oauth/access_token', - params={ - 'client_id': '', - 'client_secret': '', - 'code': code - }, - headers={'accept': 'application/json'}, - ) as access_resp: - access_data = await access_resp.json() - return GithubAccessToken( - access_token=access_data.get('access_token'), - token_type=access_data.get('token_type'), - scope=access_data.get('scope') - ) -async def get_user_info(access_token: str | GithubAccessToken) -> GithubUserInfoResponse: - if isinstance(access_token, GithubAccessToken): - access_token = access_token.access_token - - async with aiohttp.ClientSession() as session: - async with session.get( - url='https://api.github.com/user', - headers={ - 'accept': 'application/json', - 'Authorization': f'token {access_token}'}, - ) as resp: - user_data = await resp.json() - return GithubUserInfoResponse(**user_data) \ No newline at end of file +class GithubOAuth(OAuthBase): + """GitHub OAuth2.0 客户端""" + + access_token_url = "https://github.com/login/oauth/access_token" + """获取 Access Token 的 API 地址""" + + user_info_url = "https://api.github.com/user" + """获取用户信息的 API 地址""" + + http_method = "POST" + """获取 token 的 HTTP 方法""" + + def _parse_token_response(self, data: dict) -> GithubAccessToken: + """解析 GitHub token 响应""" + return GithubAccessToken( + access_token=data.get('access_token'), + token_type=data.get('token_type'), + scope=data.get('scope'), + ) + + def _build_user_info_headers(self, access_token: str) -> dict: + """构建 GitHub 用户信息请求头""" + return { + 'accept': 'application/json', + 'Authorization': f'token {access_token}', + } + + def _parse_user_response(self, data: dict) -> GithubUserInfoResponse: + """解析 GitHub 用户信息响应""" + return GithubUserInfoResponse( + code='0' if data.get('login') else '1', + user_data=GithubUserData(**data), + ) diff --git a/service/oauth/qq.py b/service/oauth/qq.py index 5a44485..fd92113 100644 --- a/service/oauth/qq.py +++ b/service/oauth/qq.py @@ -1,7 +1,158 @@ -from pydantic import BaseModel +"""QQ OAuth2.0 认证实现""" import aiohttp -async def get_access_token( +from pydantic import BaseModel +from . import AccessTokenBase, OAuthBase + + +class QQAccessToken(AccessTokenBase): + """QQ 访问令牌响应""" + expires_in: int + """access token 的有效期,单位为秒""" + refresh_token: str + """用于刷新 access token 的令牌""" + + +class QQOpenIDResponse(BaseModel): + """QQ OpenID 响应""" + client_id: str + """应用的 appid""" + openid: str + """用户的唯一标识""" + + +class QQUserData(BaseModel): + """QQ 用户数据""" + ret: int + """返回码,0 表示成功""" + msg: str + """返回信息""" + nickname: str | None + """用户昵称""" + gender: str | None + """性别""" + figureurl: str | None + """头像 URL""" + figureurl_1: str | None + """头像 URL(大图)""" + figureurl_2: str | None + """头像 URL(更大图)""" + figureurl_qq_1: str | None + """QQ 头像 URL(大图)""" + figureurl_qq_2: str | None + """QQ 头像 URL(更大图)""" + is_yellow_vip: str | None + """是否黄钻用户""" + vip: str | None + """是否 VIP 用户""" + yellow_vip_level: str | None + """黄钻等级""" + level: str | None + """等级""" + is_yellow_year_vip: str | None + """是否年费黄钻""" + + +class QQUserInfoResponse(BaseModel): + """QQ 用户信息响应""" code: str -): - ... \ No newline at end of file + """状态码""" + openid: str + """用户 OpenID""" + user_data: QQUserData + """用户数据""" + + +class QQOAuth(OAuthBase): + """QQ OAuth2.0 客户端""" + + access_token_url = "https://graph.qq.com/oauth2.0/token" + """获取 Access Token 的 API 地址""" + + user_info_url = "https://graph.qq.com/user/get_user_info" + """获取用户信息的 API 地址""" + + openid_url = "https://graph.qq.com/oauth2.0/me" + """获取 OpenID 的 API 地址""" + + http_method = "GET" + """获取 token 的 HTTP 方法""" + + async def get_access_token(self, code: str, redirect_uri: str) -> QQAccessToken: + """ + 通过 Authorization Code 获取 Access Token + + Args: + code: 授权码 + redirect_uri: 与授权时传入的 redirect_uri 保持一致,需要 URLEncode + + Returns: + QQAccessToken: 访问令牌 + + 文档: + https://wiki.connect.qq.com/%E4%BD%BF%E7%94%A8authorization_code%E8%8E%B7%E5%8F%96access_token + """ + params = { + 'grant_type': 'authorization_code', + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'code': code, + 'redirect_uri': redirect_uri, + 'fmt': 'json', + 'need_openid': 1, + } + + async with aiohttp.ClientSession() as session: + async with session.get(url=self.access_token_url, params=params) as access_resp: + access_data = await access_resp.json() + return QQAccessToken( + access_token=access_data.get('access_token'), + expires_in=access_data.get('expires_in'), + refresh_token=access_data.get('refresh_token'), + ) + + async def get_openid(self, access_token: str) -> QQOpenIDResponse: + """ + 获取用户 OpenID + + 注意:如果在 get_access_token 时传入了 need_openid=1,响应中已包含 openid, + 无需额外调用此接口。此函数用于单独获取 openid 的场景。 + + Args: + access_token: 访问令牌 + + Returns: + QQOpenIDResponse: 包含 client_id 和 openid + + 文档: + https://wiki.connect.qq.com/%E8%8E%B7%E5%8F%96%E7%94%A8%E6%88%B7openid%E7%9A%84oauth2.0%E6%8E%A5%E5%8F%A3 + """ + async with aiohttp.ClientSession() as session: + async with session.get( + url=self.openid_url, + params={ + 'access_token': access_token, + 'fmt': 'json', + }, + ) as resp: + data = await resp.json() + return QQOpenIDResponse( + client_id=data.get('client_id'), + openid=data.get('openid'), + ) + + def _build_user_info_params(self, access_token: str, **kwargs) -> dict: + """构建 QQ 用户信息请求参数""" + return { + 'access_token': access_token, + 'oauth_consumer_key': kwargs.get('app_id', self.client_id), + 'openid': kwargs.get('openid', ''), + } + + def _parse_user_response(self, data: dict) -> QQUserInfoResponse: + """解析 QQ 用户信息响应""" + return QQUserInfoResponse( + code='0' if data.get('ret') == 0 else str(data.get('ret')), + openid=data.get('openid', ''), + user_data=QQUserData(**data), + ) diff --git a/service/user/login.py b/service/user/login.py index 88f808f..5ee1dec 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -25,12 +25,12 @@ async def login( # TODO: 验证码校验 # captcha_setting = await Setting.get( # session, - # and_(Setting.type == "auth", Setting.name == "login_captcha") + # (Setting.type == "auth") & (Setting.name == "login_captcha") # ) # is_captcha_required = captcha_setting and captcha_setting.value == "1" # 获取用户信息 - current_user = await User.get(session, User.username == login_request.username, fetch_mode="first") + current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore # 验证用户是否存在 if not current_user: diff --git a/tests/fixtures/groups.py b/tests/fixtures/groups.py index 56cc0ef..3198c59 100644 --- a/tests/fixtures/groups.py +++ b/tests/fixtures/groups.py @@ -42,10 +42,9 @@ class GroupFactory: speed_limit=kwargs.get("speed_limit", 0), ) - group = await group.save(session) - # 如果提供了选项参数,创建 GroupOptions if kwargs.get("create_options", False): + group = await group.save(session, commit=False) options = GroupOptions( group_id=group.id, share_download=kwargs.get("share_download", True), @@ -55,7 +54,10 @@ class GroupFactory: select_node=kwargs.get("select_node", False), advance_delete=kwargs.get("advance_delete", False), ) - await options.save(session) + await options.save(session, commit=False) + await session.commit() + else: + group = await group.save(session) return group @@ -88,7 +90,7 @@ class GroupFactory: speed_limit=0, ) - admin_group = await admin_group.save(session) + admin_group = await admin_group.save(session, commit=False) # 创建管理员组选项 admin_options = GroupOptions( @@ -105,7 +107,8 @@ class GroupFactory: aria2=True, redirected_source=True, ) - await admin_options.save(session) + await admin_options.save(session, commit=False) + await session.commit() return admin_group @@ -140,7 +143,7 @@ class GroupFactory: speed_limit=1024, # 1MB/s ) - limited_group = await limited_group.save(session) + limited_group = await limited_group.save(session, commit=False) # 创建限制组选项 limited_options = GroupOptions( @@ -152,7 +155,8 @@ class GroupFactory: select_node=False, advance_delete=False, ) - await limited_options.save(session) + await limited_options.save(session, commit=False) + await session.commit() return limited_group @@ -185,7 +189,7 @@ class GroupFactory: speed_limit=512, # 512KB/s ) - free_group = await free_group.save(session) + free_group = await free_group.save(session, commit=False) # 创建免费组选项 free_options = GroupOptions( @@ -197,6 +201,7 @@ class GroupFactory: select_node=False, advance_delete=False, ) - await free_options.save(session) + await free_options.save(session, commit=False) + await session.commit() return free_group diff --git a/utils/JWT/__init__.py b/utils/JWT/__init__.py index 6e50030..408114f 100644 --- a/utils/JWT/__init__.py +++ b/utils/JWT/__init__.py @@ -13,7 +13,7 @@ oauth2_scheme = OAuth2PasswordBearer( refreshUrl="/api/v1/user/session/refresh", ) -SECRET_KEY = '' +SECRET_KEY: str = '' async def load_secret_key() -> None: @@ -26,10 +26,10 @@ async def load_secret_key() -> None: global SECRET_KEY async for session in get_session(): - setting = await Setting.get( + setting: Setting = await Setting.get( session, (Setting.type == "auth") & (Setting.name == "secret_key") - ) + ) # type: ignore if setting: SECRET_KEY = setting.value @@ -40,7 +40,14 @@ def build_token_payload( algorithm: str, expires_delta: timedelta | None = None, ) -> tuple[str, datetime]: - """构建令牌""" + """ + 构建令牌。 + + :param data: 需要放进 JWT Payload 的字段 + :param is_refresh: 是否为刷新令牌 + :param algorithm: JWT 签名算法 + :param expires_delta: 过期时间 + """ to_encode = data.copy() @@ -61,8 +68,11 @@ def build_token_payload( # 访问令牌 -def create_access_token(data: dict, expires_delta: timedelta | None = None, - algorithm: str = "HS256") -> AccessTokenBase: +def create_access_token( + data: dict, + expires_delta: timedelta | None = None, + algorithm: str = "HS256" +) -> AccessTokenBase: """ 生成访问令牌,默认有效期 3 小时。 @@ -73,7 +83,12 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None, :return: 包含密钥本身和过期时间的 `AccessTokenBase` """ - access_token, expire_at = build_token_payload(data, False, algorithm, expires_delta) + access_token, expire_at = build_token_payload( + data, + False, + algorithm, + expires_delta + ) return AccessTokenBase( access_token=access_token, access_expires=expire_at, @@ -81,8 +96,11 @@ def create_access_token(data: dict, expires_delta: timedelta | None = None, # 刷新令牌 -def create_refresh_token(data: dict, expires_delta: timedelta | None = None, - algorithm: str = "HS256") -> RefreshTokenBase: +def create_refresh_token( + data: dict, + expires_delta: timedelta | None = None, + algorithm: str = "HS256" +) -> RefreshTokenBase: """ 生成刷新令牌,默认有效期 30 天。 @@ -93,7 +111,12 @@ def create_refresh_token(data: dict, expires_delta: timedelta | None = None, :return: 包含密钥本身和过期时间的 `RefreshTokenBase` """ - refresh_token, expire_at = build_token_payload(data, True, algorithm, expires_delta) + refresh_token, expire_at = build_token_payload( + data, + True, + algorithm, + expires_delta + ) return RefreshTokenBase( refresh_token=refresh_token, refresh_expires=expire_at, diff --git a/utils/conf/appmeta.py b/utils/conf/appmeta.py index 7c92fbe..61461c5 100644 --- a/utils/conf/appmeta.py +++ b/utils/conf/appmeta.py @@ -13,7 +13,7 @@ license_info = {"name": "GPLv3", "url": "https://opensource.org/license/gpl-3.0" BackendVersion = "0.0.1" """后端版本""" -IsPro = False +IsPro: bool = False mode: str = os.getenv('MODE', 'master') """运行模式""" diff --git a/utils/password/pwd.py b/utils/password/pwd.py index 76397e1..eb96844 100644 --- a/utils/password/pwd.py +++ b/utils/password/pwd.py @@ -1,4 +1,5 @@ import secrets +from typing import Literal from loguru import logger from argon2 import PasswordHasher @@ -11,7 +12,23 @@ from pydantic import BaseModel, Field from utils.JWT import SECRET_KEY from utils.conf import appmeta -_ph = PasswordHasher() +# FIRST RECOMMENDED option per RFC 9106. +_ph_lowmem = PasswordHasher( + salt_len=16, + hash_len=32, + time_cost=3, + memory_cost=65536, # 64 MiB + parallelism=4, +) + +# SECOND RECOMMENDED option per RFC 9106. +_ph_highmem = PasswordHasher( + salt_len=16, + hash_len=32, + time_cost=1, + memory_cost=2097152, # 2 GiB + parallelism=4, +) class PasswordStatus(StrEnum): """密码校验状态枚举""" @@ -48,7 +65,8 @@ class Password: @staticmethod def generate( - length: int = 8 + length: int = 8, + url_safe: bool = False ) -> str: """ 生成指定长度的随机密码。 @@ -58,7 +76,16 @@ class Password: :return: 随机密码 :rtype: str """ - return secrets.token_hex(length) + if url_safe: + return secrets.token_urlsafe(length) + else: + return secrets.token_hex(length) + + @staticmethod + def generate_hex( + length: int = 8 + ) -> bytes: + return secrets.token_bytes(length) @staticmethod def hash( @@ -72,7 +99,7 @@ class Password: :param password: 需要哈希的原始密码 :return: Argon2 哈希字符串 """ - return _ph.hash(password) + return _ph_lowmem.hash(password) @staticmethod def verify( @@ -87,21 +114,16 @@ class Password: :return: 如果密码匹配返回 True, 否则返回 False """ try: - # verify 函数会自动解析 stored_password 中的盐和参数 - _ph.verify(hash, password) + _ph_lowmem.verify(hash, password) - # 检查哈希参数是否已过时。如果返回True, - # 意味着你应该使用新的参数重新哈希密码并更新存储。 - # 这是一个很好的实践,可以随着时间推移增强安全性。 - if _ph.check_needs_rehash(hash): + # 检查哈希参数是否已过时 + if _ph_lowmem.check_needs_rehash(hash): logger.warning("密码哈希参数已过时,建议重新哈希并更新。") return PasswordStatus.EXPIRED return PasswordStatus.VALID except VerifyMismatchError: - # 这是预期的异常,当密码不匹配时触发。 return PasswordStatus.INVALID - # 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题 @staticmethod async def generate_totp(