Refactor and enhance OAuth2.0 implementation; update models and routes

- Refactored AdminSummaryData and AdminSummaryResponse classes for better clarity.
- Added OAUTH type to SettingsType enum.
- Cleaned up imports in webdav.py.
- Updated admin router to improve summary data retrieval and response handling.
- Enhanced file management routes with better condition handling and user storage updates.
- Improved group management routes by optimizing data retrieval.
- Refined task management routes for better condition handling.
- Updated user management routes to streamline access token retrieval.
- Implemented a new captcha verification structure with abstract base class.
- Removed deprecated env.md file and replaced with a new structured version.
- Introduced a unified OAuth2.0 client base class for GitHub and QQ integrations.
- Enhanced password management with improved hashing strategies.
- Added detailed comments and documentation throughout the codebase for clarity.
This commit is contained in:
2026-01-12 18:07:44 +08:00
parent 61ddc96f17
commit d2c914cff8
29 changed files with 814 additions and 4609 deletions

View File

@@ -1,5 +1,39 @@
import abc
import aiohttp
from pydantic import BaseModel
from .gcaptcha import GCaptcha
from .turnstile import TurnstileCaptcha
class CaptchaRequestBase(BaseModel):
"""验证码验证请求"""
token: str
secret: str
"""验证 token"""
secret: str
"""验证密钥"""
class CaptchaBase(abc.ABC):
"""验证码验证器抽象基类"""
verify_url: str
"""验证 API 地址(子类必须定义)"""
async def verify_captcha(self, request: CaptchaRequestBase) -> bool:
"""
验证 token 是否有效。
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(self.verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)

View File

@@ -1,21 +1,7 @@
import aiohttp
from . import CaptchaBase
from . import CaptchaRequestBase
async def verify_captcha(request: CaptchaRequestBase) -> bool:
"""
验证 Google reCAPTCHA v2/v3 的 token 是否有效。
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
verify_url = "https://www.google.com/recaptcha/api/siteverify"
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)
class GCaptcha(CaptchaBase):
"""Google reCAPTCHA v2/v3 验证器"""
verify_url = "https://www.google.com/recaptcha/api/siteverify"

View File

@@ -1,21 +1,7 @@
import aiohttp
from . import CaptchaBase
from . import CaptchaRequestBase
async def verify_captcha(request: CaptchaRequestBase) -> bool:
"""
验证 Turnstile 的 token 是否有效。
:return: 如果验证成功返回 True否则返回 False
:rtype: bool
"""
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
payload = request.model_dump()
async with aiohttp.ClientSession() as session:
async with session.post(verify_url, data=payload) as response:
if response.status != 200:
return False
result = await response.json()
return result.get('success', False)
class TurnstileCaptcha(CaptchaBase):
"""Cloudflare Turnstile 验证器"""
verify_url = "https://challenges.cloudflare.com/turnstile/v0/siteverify"

View File

@@ -1,12 +0,0 @@
# 环境变量字段
- `MODE` str 运行模式,默认 `master`
- `master` 主机模式
- `slave` 从机模式
- `DEBUG` bool 是否开启调试模式,默认 `false`
- `DATABASE_URL`: 数据库连接信息,默认 `sqlite+aiosqlite:///disknext.db`
- `REDIS_HOST`: Redis 主机地址
- `REDIS_PORT`: Redis 端口
- `REDIS_PASSWORD`: Redis 密码
- `REDIS_DB`: Redis 数据库
- `REDIS_PROTOCOL`

223
service/oauth/__init__.py Normal file
View File

@@ -0,0 +1,223 @@
"""
OAuth2.0 认证模块
提供统一的 OAuth2.0 客户端基类,支持多种第三方登录平台。
"""
import abc
import aiohttp
from pydantic import BaseModel
# ==================== 共享数据模型 ====================
class AccessTokenBase(BaseModel):
"""访问令牌基类"""
access_token: str
"""访问令牌"""
class OAuthUserData(BaseModel):
"""OAuth 用户数据通用 DTO"""
openid: str
"""用户唯一标识GitHub 为 idQQ 为 openid"""
nickname: str | None
"""用户昵称"""
avatar_url: str | None
"""头像 URL"""
email: str | None
"""邮箱"""
bio: str | None
"""个人简介"""
class OAuthUserInfoResponse(BaseModel):
"""OAuth 用户信息响应"""
code: str
"""状态码"""
openid: str
"""用户唯一标识"""
user_data: OAuthUserData
"""用户数据"""
# ==================== OAuth2.0 抽象基类 ====================
class OAuthBase(abc.ABC):
"""
OAuth2.0 客户端抽象基类
子类需要定义以下类属性:
- access_token_url: 获取 Access Token 的 API 地址
- user_info_url: 获取用户信息的 API 地址
- http_method: 获取 token 的 HTTP 方法POST 或 GET
"""
# 子类必须定义的类属性
access_token_url: str
"""获取 Access Token 的 API 地址"""
user_info_url: str
"""获取用户信息的 API 地址"""
http_method: str = "POST"
"""获取 token 的 HTTP 方法POST 或 GET"""
# 实例属性(构造函数传入)
client_id: str
client_secret: str
def __init__(self, client_id: str, client_secret: str) -> None:
"""
初始化 OAuth 客户端
Args:
client_id: 应用 client_id
client_secret: 应用 client_secret
"""
self.client_id = client_id
self.client_secret = client_secret
async def get_access_token(self, code: str, **kwargs) -> AccessTokenBase:
"""
通过 Authorization Code 获取 Access Token
Args:
code: 授权码
**kwargs: 额外参数(如 QQ 需要 redirect_uri
Returns:
AccessTokenBase: 访问令牌
"""
params = {
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
}
params.update(kwargs)
async with aiohttp.ClientSession() as session:
if self.http_method == "POST":
async with session.post(
url=self.access_token_url,
params=params,
headers={'accept': 'application/json'},
) as access_resp:
access_data = await access_resp.json()
return self._parse_token_response(access_data)
else:
async with session.get(
url=self.access_token_url,
params=params,
) as access_resp:
access_data = await access_resp.json()
return self._parse_token_response(access_data)
async def get_user_info(
self,
access_token: str | AccessTokenBase,
**kwargs
) -> OAuthUserInfoResponse:
"""
获取用户信息
Args:
access_token: 访问令牌
**kwargs: 额外参数(如 QQ 需要 app_id, openid
Returns:
OAuthUserInfoResponse: 用户信息
"""
if isinstance(access_token, AccessTokenBase):
access_token = access_token.access_token
async with aiohttp.ClientSession() as session:
async with session.get(
url=self.user_info_url,
params=self._build_user_info_params(access_token, **kwargs),
headers=self._build_user_info_headers(access_token),
) as resp:
user_data = await resp.json()
return self._parse_user_response(user_data)
# ==================== 钩子方法(子类可覆盖) ====================
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
"""
构建获取用户信息的请求参数
Args:
access_token: 访问令牌
**kwargs: 额外参数
Returns:
dict: 请求参数
"""
return {}
def _build_user_info_headers(self, access_token: str) -> dict:
"""
构建获取用户信息的请求头
Args:
access_token: 访问令牌
Returns:
dict: 请求头
"""
return {
'accept': 'application/json',
}
def _parse_token_response(self, data: dict) -> AccessTokenBase:
"""
解析 token 响应
Args:
data: API 返回的数据
Returns:
AccessTokenBase: 访问令牌
"""
return AccessTokenBase(access_token=data.get('access_token'))
def _parse_user_response(self, data: dict) -> OAuthUserInfoResponse:
"""
解析用户信息响应
Args:
data: API 返回的数据
Returns:
OAuthUserInfoResponse: 用户信息
"""
return OAuthUserInfoResponse(
code='0',
openid='',
user_data=OAuthUserData(openid=''),
)
# ==================== 导出 ====================
from .github import GithubOAuth, GithubAccessToken, GithubUserData
from .qq import QQOAuth, QQAccessToken, QQOpenIDResponse, QQUserData
__all__ = [
# 共享模型
'AccessTokenBase',
'OAuthUserData',
'OAuthUserInfoResponse',
'OAuthBase',
# GitHub
'GithubOAuth',
'GithubAccessToken',
'GithubUserData',
# QQ
'QQOAuth',
'QQAccessToken',
'QQOpenIDResponse',
'QQUserData',
]

View File

@@ -1,77 +1,127 @@
from pydantic import BaseModel
import aiohttp
"""GitHub OAuth2.0 认证实现"""
from typing import TYPE_CHECKING
class GithubAccessToken(BaseModel):
access_token: str
from pydantic import BaseModel
from . import AccessTokenBase, OAuthBase, OAuthUserInfoResponse
if TYPE_CHECKING:
from . import OAuthUserData
class GithubAccessToken(AccessTokenBase):
"""GitHub 访问令牌响应"""
token_type: str
"""令牌类型"""
scope: str
"""授权范围"""
class GithubUserData(BaseModel):
"""GitHub 用户数据"""
login: str
"""用户名"""
id: int
"""用户 ID"""
node_id: str
"""节点 ID"""
avatar_url: str
"""头像 URL"""
gravatar_id: str | None
"""Gravatar ID"""
url: str
"""API URL"""
html_url: str
"""主页 URL"""
followers_url: str
"""粉丝列表 URL"""
following_url: str
"""关注列表 URL"""
gists_url: str
"""Gists 列表 URL"""
starred_url: str
"""星标列表 URL"""
subscriptions_url: str
"""订阅列表 URL"""
organizations_url: str
"""组织列表 URL"""
repos_url: str
"""仓库列表 URL"""
events_url: str
"""事件列表 URL"""
received_events_url: str
"""接收的事件列表 URL"""
type: str
"""用户类型"""
site_admin: bool
"""是否为站点管理员"""
name: str | None
"""显示名称"""
company: str | None
"""公司"""
blog: str | None
"""博客"""
location: str | None
"""位置"""
email: str | None
"""邮箱"""
hireable: bool | None
"""是否可雇佣"""
bio: str | None
"""个人简介"""
twitter_username: str | None
"""Twitter 用户名"""
public_repos: int
"""公开仓库数"""
public_gists: int
"""公开 Gists 数"""
followers: int
"""粉丝数"""
following: int
created_at: str # ISO 8601 format date-time string
updated_at: str # ISO 8601 format date-time string
"""关注数"""
created_at: str
"""创建时间ISO 8601 格式)"""
updated_at: str
"""更新时间ISO 8601 格式)"""
class GithubUserInfoResponse(BaseModel):
"""GitHub 用户信息响应"""
code: str
"""状态码"""
user_data: GithubUserData
"""用户数据"""
async def get_access_token(code: str) -> GithubAccessToken:
async with aiohttp.ClientSession() as session:
async with session.post(
url='https://github.com/login/oauth/access_token',
params={
'client_id': '',
'client_secret': '',
'code': code
},
headers={'accept': 'application/json'},
) as access_resp:
access_data = await access_resp.json()
return GithubAccessToken(
access_token=access_data.get('access_token'),
token_type=access_data.get('token_type'),
scope=access_data.get('scope')
)
async def get_user_info(access_token: str | GithubAccessToken) -> GithubUserInfoResponse:
if isinstance(access_token, GithubAccessToken):
access_token = access_token.access_token
async with aiohttp.ClientSession() as session:
async with session.get(
url='https://api.github.com/user',
headers={
'accept': 'application/json',
'Authorization': f'token {access_token}'},
) as resp:
user_data = await resp.json()
return GithubUserInfoResponse(**user_data)
class GithubOAuth(OAuthBase):
"""GitHub OAuth2.0 客户端"""
access_token_url = "https://github.com/login/oauth/access_token"
"""获取 Access Token 的 API 地址"""
user_info_url = "https://api.github.com/user"
"""获取用户信息的 API 地址"""
http_method = "POST"
"""获取 token 的 HTTP 方法"""
def _parse_token_response(self, data: dict) -> GithubAccessToken:
"""解析 GitHub token 响应"""
return GithubAccessToken(
access_token=data.get('access_token'),
token_type=data.get('token_type'),
scope=data.get('scope'),
)
def _build_user_info_headers(self, access_token: str) -> dict:
"""构建 GitHub 用户信息请求头"""
return {
'accept': 'application/json',
'Authorization': f'token {access_token}',
}
def _parse_user_response(self, data: dict) -> GithubUserInfoResponse:
"""解析 GitHub 用户信息响应"""
return GithubUserInfoResponse(
code='0' if data.get('login') else '1',
user_data=GithubUserData(**data),
)

View File

@@ -1,7 +1,158 @@
from pydantic import BaseModel
"""QQ OAuth2.0 认证实现"""
import aiohttp
async def get_access_token(
from pydantic import BaseModel
from . import AccessTokenBase, OAuthBase
class QQAccessToken(AccessTokenBase):
"""QQ 访问令牌响应"""
expires_in: int
"""access token 的有效期,单位为秒"""
refresh_token: str
"""用于刷新 access token 的令牌"""
class QQOpenIDResponse(BaseModel):
"""QQ OpenID 响应"""
client_id: str
"""应用的 appid"""
openid: str
"""用户的唯一标识"""
class QQUserData(BaseModel):
"""QQ 用户数据"""
ret: int
"""返回码0 表示成功"""
msg: str
"""返回信息"""
nickname: str | None
"""用户昵称"""
gender: str | None
"""性别"""
figureurl: str | None
"""头像 URL"""
figureurl_1: str | None
"""头像 URL大图"""
figureurl_2: str | None
"""头像 URL更大图"""
figureurl_qq_1: str | None
"""QQ 头像 URL大图"""
figureurl_qq_2: str | None
"""QQ 头像 URL更大图"""
is_yellow_vip: str | None
"""是否黄钻用户"""
vip: str | None
"""是否 VIP 用户"""
yellow_vip_level: str | None
"""黄钻等级"""
level: str | None
"""等级"""
is_yellow_year_vip: str | None
"""是否年费黄钻"""
class QQUserInfoResponse(BaseModel):
"""QQ 用户信息响应"""
code: str
):
...
"""状态码"""
openid: str
"""用户 OpenID"""
user_data: QQUserData
"""用户数据"""
class QQOAuth(OAuthBase):
"""QQ OAuth2.0 客户端"""
access_token_url = "https://graph.qq.com/oauth2.0/token"
"""获取 Access Token 的 API 地址"""
user_info_url = "https://graph.qq.com/user/get_user_info"
"""获取用户信息的 API 地址"""
openid_url = "https://graph.qq.com/oauth2.0/me"
"""获取 OpenID 的 API 地址"""
http_method = "GET"
"""获取 token 的 HTTP 方法"""
async def get_access_token(self, code: str, redirect_uri: str) -> QQAccessToken:
"""
通过 Authorization Code 获取 Access Token
Args:
code: 授权码
redirect_uri: 与授权时传入的 redirect_uri 保持一致,需要 URLEncode
Returns:
QQAccessToken: 访问令牌
文档:
https://wiki.connect.qq.com/%E4%BD%BF%E7%94%A8authorization_code%E8%8E%B7%E5%8F%96access_token
"""
params = {
'grant_type': 'authorization_code',
'client_id': self.client_id,
'client_secret': self.client_secret,
'code': code,
'redirect_uri': redirect_uri,
'fmt': 'json',
'need_openid': 1,
}
async with aiohttp.ClientSession() as session:
async with session.get(url=self.access_token_url, params=params) as access_resp:
access_data = await access_resp.json()
return QQAccessToken(
access_token=access_data.get('access_token'),
expires_in=access_data.get('expires_in'),
refresh_token=access_data.get('refresh_token'),
)
async def get_openid(self, access_token: str) -> QQOpenIDResponse:
"""
获取用户 OpenID
注意:如果在 get_access_token 时传入了 need_openid=1响应中已包含 openid
无需额外调用此接口。此函数用于单独获取 openid 的场景。
Args:
access_token: 访问令牌
Returns:
QQOpenIDResponse: 包含 client_id 和 openid
文档:
https://wiki.connect.qq.com/%E8%8E%B7%E5%8F%96%E7%94%A8%E6%88%B7openid%E7%9A%84oauth2.0%E6%8E%A5%E5%8F%A3
"""
async with aiohttp.ClientSession() as session:
async with session.get(
url=self.openid_url,
params={
'access_token': access_token,
'fmt': 'json',
},
) as resp:
data = await resp.json()
return QQOpenIDResponse(
client_id=data.get('client_id'),
openid=data.get('openid'),
)
def _build_user_info_params(self, access_token: str, **kwargs) -> dict:
"""构建 QQ 用户信息请求参数"""
return {
'access_token': access_token,
'oauth_consumer_key': kwargs.get('app_id', self.client_id),
'openid': kwargs.get('openid', ''),
}
def _parse_user_response(self, data: dict) -> QQUserInfoResponse:
"""解析 QQ 用户信息响应"""
return QQUserInfoResponse(
code='0' if data.get('ret') == 0 else str(data.get('ret')),
openid=data.get('openid', ''),
user_data=QQUserData(**data),
)

View File

@@ -25,12 +25,12 @@ async def login(
# TODO: 验证码校验
# captcha_setting = await Setting.get(
# session,
# and_(Setting.type == "auth", Setting.name == "login_captcha")
# (Setting.type == "auth") & (Setting.name == "login_captcha")
# )
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
# 获取用户信息
current_user = await User.get(session, User.username == login_request.username, fetch_mode="first")
current_user: User = await User.get(session, User.username == login_request.username, fetch_mode="first") #type: ignore
# 验证用户是否存在
if not current_user: