- Add GroupClaims model for JWT permission snapshots
- Add JWTPayload model for typed JWT decoding
- Refactor auth middleware: jwt_required (no DB) -> admin_required (no DB) -> auth_required (DB)
- Add UserBanStore for instant ban enforcement via Redis + memory fallback
- Fix status check bug: StrEnum is always truthy, use explicit != ACTIVE
- Shorten access_token expiry from 3h to 1h
- Add CaptchaScene enum and verify_captcha_if_needed service
- Add require_captcha dependency injection factory
- Add CLA document and new default settings
- Update all tests for new JWT API

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
445 lines
13 KiB
Python
445 lines
13 KiB
Python
from typing import Annotated, Literal
|
||
from uuid import UUID, uuid4
|
||
|
||
import jwt
|
||
from fastapi import APIRouter, Depends, Form, HTTPException
|
||
from loguru import logger
|
||
from webauthn import generate_registration_options
|
||
from webauthn.helpers import options_to_json_dict
|
||
|
||
import service
|
||
import sqlmodels
|
||
from middleware.auth import auth_required
|
||
from middleware.dependencies import SessionDep, require_captcha
|
||
from service.captcha import CaptchaScene
|
||
from sqlmodels.user import UserStatus
|
||
from utils import JWT, Password, http_exceptions
|
||
from .settings import user_settings_router
|
||
|
||
# Router serving every "/user" endpoint; the per-user settings router is
# mounted underneath it so its routes share the same prefix and tag.
user_router = APIRouter(
    tags=["user"],
    prefix="/user",
)
user_router.include_router(user_settings_router)
|
||
|
||
class OAuth2PasswordWithExtrasForm:
    """
    OAuth2-style password form extended with an optional one-time code.

    Carries the usual ``username`` / ``password`` form fields plus
    ``otp_code`` (exactly 6 characters when present). The captcha token is
    deliberately not part of this form: it is checked by the separate
    ``require_captcha`` dependency attached to the route.

    Meant to be injected with ``Depends()``, so FastAPI derives the form
    fields from the keyword-only ``__init__`` signature below.
    """

    def __init__(
        self,
        *,
        username: Annotated[str, Form()],
        password: Annotated[str, Form()],
        otp_code: Annotated[str | None, Form(min_length=6, max_length=6)] = None,
    ):
        # Keep the submitted values as-is; checking them is the caller's job.
        self.username, self.password, self.otp_code = username, password, otp_code
|
||
|
||
|
||
@user_router.post(
    path='/session',
    summary='用户登录',
    description='用户登录端点,支持验证码校验和两步验证。',
    dependencies=[Depends(require_captcha(CaptchaScene.LOGIN))],
)
async def router_user_session(
    session: SessionDep,
    form_data: Annotated[OAuth2PasswordWithExtrasForm, Depends()],
) -> sqlmodels.TokenResponse:
    """
    Log a user in and return a token pair.

    Form fields:
    - username: the user's email address
    - password: the user's password
    - captcha_code: captcha token (optional; checked by the require_captcha dependency)
    - otp_code: two-factor code (optional; only needed when the user has 2FA enabled)

    Documented error responses:
    - 400: captcha required but not supplied
    - 401: wrong email/password, or wrong 2FA code
    - 403: account disabled / captcha verification failed
    - 428: 2FA required but no otp_code supplied
    """
    # The OAuth2 "username" field carries the email address.
    login_request = sqlmodels.LoginRequest(
        email=form_data.username,
        password=form_data.password,
        two_fa_code=form_data.otp_code,
    )
    return await service.user.login(session, login_request)
|
||
|
||
@user_router.post(
    path='/session/refresh',
    summary="用刷新令牌刷新会话",
    description="Refresh the user session using a refresh token.",
)
async def router_user_session_refresh(
    session: SessionDep,
    request: sqlmodels.RefreshTokenRequest,
) -> sqlmodels.TokenResponse:
    """
    Exchange a refresh token for a new access token + refresh token pair.

    Steps:
    1. Decode and verify the refresh_token JWT (HS256, ``JWT.SECRET_KEY``).
    2. Require ``token_type == "refresh"``.
    3. Require a ``sub`` claim that parses as a UUID.
    4. Require that the user exists and has status ACTIVE.
    5. Reload the group's options so the new access token carries the
       current permission snapshot, then issue both tokens.

    :param session: database session
    :param request: body carrying the refresh token
    :return: new TokenResponse
    :raises HTTPException 401: token invalid or expired, wrong token type,
        missing or malformed subject, or user not found
    :raises HTTPException 403: user status is not ACTIVE
    """
    try:
        payload = jwt.decode(request.refresh_token, JWT.SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError:
        # Expired tokens land here too (ExpiredSignatureError subclasses InvalidTokenError).
        http_exceptions.raise_unauthorized("刷新令牌无效或已过期")

    # Only refresh tokens may be exchanged; an access token must be rejected.
    if payload.get("token_type") != "refresh":
        http_exceptions.raise_unauthorized("非刷新令牌")

    user_id_str = payload.get("sub")
    if not user_id_str:
        http_exceptions.raise_unauthorized("令牌缺少用户标识")

    # A correctly signed token whose subject is not a UUID would otherwise
    # raise ValueError here and surface as an HTTP 500; treat it as invalid.
    try:
        user_id = UUID(str(user_id_str))
    except ValueError:
        http_exceptions.raise_unauthorized("刷新令牌无效或已过期")

    # NOTE(review): the token's jti is not checked against any revocation or
    # rotation store here, so a leaked refresh token stays usable until it
    # expires — confirm whether that is enforced elsewhere.

    user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id, load=sqlmodels.User.group)
    if not user:
        http_exceptions.raise_unauthorized("用户不存在")

    # Explicit comparison: any status other than ACTIVE is refused.
    if user.status != UserStatus.ACTIVE:
        http_exceptions.raise_forbidden("账户已被禁用")

    # Reload GroupOptions so the new access token reflects current permissions.
    group_options = await sqlmodels.GroupOptions.get(
        session,
        sqlmodels.GroupOptions.group_id == user.group_id,
    )
    user.group.options = group_options
    group_claims = sqlmodels.GroupClaims.from_group(user.group)

    # Issue a fresh pair; each token gets its own random jti.
    access_token = JWT.create_access_token(
        sub=user.id,
        jti=uuid4(),
        status=user.status.value,
        group=group_claims,
    )
    refresh_token = JWT.create_refresh_token(
        sub=user.id,
        jti=uuid4(),
    )

    return sqlmodels.TokenResponse(
        access_token=access_token.access_token,
        access_expires=access_token.access_expires,
        refresh_token=refresh_token.refresh_token,
        refresh_expires=refresh_token.refresh_expires,
    )
|
||
|
||
@user_router.post(
    path='/',
    summary='用户注册',
    description='User registration endpoint.',
    status_code=204,
)
async def router_user_register(
    session: SessionDep,
    request: sqlmodels.RegisterRequest,
) -> None:
    """
    Register a new user account.

    Steps:
    1. Reject the request if the email is already registered.
    2. Resolve the default group (its UUID is stored in the REGISTER
       ``default_group`` setting).
    3. Resolve the default storage policy (named "本地存储").
    4. Create the user record.
    5. Create the user's root folder (name="/").

    Every configuration lookup (steps 2-3) runs before anything is written,
    so a misconfigured server cannot leave behind a user without a root
    folder (which would also block re-registering the same email).

    :param session: database session
    :param request: registration request
    :raises HTTPException 400: email already registered
    :raises HTTPException 500: default group or default storage policy is
        missing or misconfigured
    """
    # 1. Email uniqueness.
    existing_user = await sqlmodels.User.get(
        session,
        sqlmodels.User.email == request.email,
    )
    if existing_user:
        raise HTTPException(status_code=400, detail="邮箱已存在")

    # 2. Default group, referenced by UUID from the settings table.
    default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
        session,
        (sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group"),
    )
    if default_group_setting is None or not default_group_setting.value:
        logger.error("默认用户组不存在")
        http_exceptions.raise_internal_error()

    try:
        default_group_id = UUID(default_group_setting.value)
    except ValueError:
        # A malformed setting is a server misconfiguration, not a client error.
        logger.error("默认用户组设置不是合法的 UUID: {}", default_group_setting.value)
        http_exceptions.raise_internal_error()

    default_group = await sqlmodels.Group.get(session, sqlmodels.Group.id == default_group_id)
    if not default_group:
        logger.error("默认用户组不存在")
        http_exceptions.raise_internal_error()

    # 3. Storage policy for the root folder. Resolved *before* the user is
    #    saved so that a missing policy cannot leave an orphaned account.
    default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
    if not default_policy:
        logger.error("默认存储策略不存在")
        http_exceptions.raise_internal_error()

    # 4. Create the user.
    new_user = sqlmodels.User(
        email=request.email,
        password=Password.hash(request.password),
        group_id=default_group.id,
    )
    # Ids are read before save(), as the original code did for the user id —
    # presumably save() may commit and expire loaded instances; keep this order.
    new_user_id = new_user.id
    default_policy_id = default_policy.id
    await new_user.save(session)

    # 5. Root folder for the new user.
    await sqlmodels.Object(
        name="/",
        type=sqlmodels.ObjectType.FOLDER,
        owner_id=new_user_id,
        parent_id=None,
        policy_id=default_policy_id,
    ).save(session)
|
||
|
||
@user_router.post(
    path='/code',
    summary='发送验证码邮件',
    description='Send a verification code email.',
)
def router_user_email_code(
    reason: Literal['register', 'reset'] = 'register',
) -> sqlmodels.ResponseBase:
    """
    Send a verification-code email (not implemented yet).

    :param reason: purpose of the code, ``'register'`` (default) or
        ``'reset'``; currently unused.
    :return: nothing yet — the body only calls
        ``http_exceptions.raise_not_implemented()`` (presumably an HTTP 501).
    """
    # Placeholder until email delivery exists.
    http_exceptions.raise_not_implemented()
|
||
|
||
@user_router.get(
    path='/qq',
    summary='初始化QQ登录',
    description='Initialize QQ login for a user.',
)
def router_user_qq() -> sqlmodels.ResponseBase:
    """
    Start a QQ OAuth login (not implemented yet).

    :return: nothing yet — the body only calls
        ``http_exceptions.raise_not_implemented()`` (presumably an HTTP 501).
    """
    # Placeholder until QQ login is wired up.
    http_exceptions.raise_not_implemented()
|
||
|
||
@user_router.get(
    # Leading slash is required: the router prefix "/user" is concatenated
    # verbatim, so 'authn/{username}' would be served at "/userauthn/{username}".
    path='/authn/{username}',
    summary='WebAuthn登录初始化',
    description='Initialize WebAuthn login for a user.',
)
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
    """
    Start a WebAuthn login for ``username`` (not implemented yet).

    :param username: account identifier taken from the URL path
    :return: nothing yet — the body only calls
        ``http_exceptions.raise_not_implemented()`` (presumably an HTTP 501).
    """
    http_exceptions.raise_not_implemented()
|
||
|
||
@user_router.post(
    # Leading slash is required: without it the "/user" prefix yields
    # "/userauthn/finish/{username}".
    path='/authn/finish/{username}',
    summary='WebAuthn登录',
    description='Finish WebAuthn login for a user.',
)
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
    """
    Complete a WebAuthn login for ``username`` (not implemented yet).

    NOTE: a second function with this same name (PUT /authn/finish) is
    defined later in this module and rebinds the module-level name. Both
    routes are still registered, because the decorator runs at definition
    time.

    :param username: account identifier taken from the URL path
    :return: nothing yet — the body only calls
        ``http_exceptions.raise_not_implemented()`` (presumably an HTTP 501).
    """
    http_exceptions.raise_not_implemented()
|
||
|
||
@user_router.get(
    path='/profile/{id}',
    summary='获取用户主页展示用分享',
    description='Get user profile for display.',
)
def router_user_profile(id: str) -> sqlmodels.ResponseBase:
    """
    Public profile / shared items for a user (not implemented yet).

    :param id: user identifier from the URL path (named ``id`` to match the
        ``{id}`` path parameter)
    :return: nothing yet — the body only calls
        ``http_exceptions.raise_not_implemented()`` (presumably an HTTP 501).
    """
    # Placeholder until public profiles are implemented.
    http_exceptions.raise_not_implemented()
|
||
|
||
@user_router.get(
    path='/avatar/{id}/{size}',
    summary='获取用户头像',
    description='Get user avatar by ID and size.',
)
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
    """
    Fetch a user's avatar at a requested size (not implemented yet).

    :param id: user identifier from the URL path
    :param size: requested avatar size (default 128)
    :return: nothing yet — the body only calls
        ``http_exceptions.raise_not_implemented()`` (presumably an HTTP 501).
    """
    # Placeholder until avatar storage is implemented.
    http_exceptions.raise_not_implemented()
|
||
|
||
#####################
|
||
# 需要登录的接口
|
||
#####################
|
||
|
||
@user_router.get(
    path='/me',
    summary='获取用户信息',
    description='Get user information.',
    dependencies=[Depends(dependency=auth_required)],
    response_model=sqlmodels.UserResponse,
)
async def router_user_me(
    session: SessionDep,
    user: Annotated[sqlmodels.User, Depends(auth_required)],
) -> sqlmodels.UserResponse:
    """
    Return the authenticated user's profile.

    The user's group is loaded together with its ``options`` relationship
    and converted with ``Group.to_response()``; when no group row is found
    the ``group`` field is ``None``. Tag names come from the user's ``tags``
    relationship.

    :param session: database session
    :param user: the authenticated user (from ``auth_required``)
    :return: UserResponse for the current user
    """
    group = await sqlmodels.Group.get(
        session,
        sqlmodels.Group.id == user.group_id,
        load=sqlmodels.Group.options,
    )
    group_response = None
    if group:
        group_response = group.to_response()

    # The tags relationship must be awaited under async SQLAlchemy.
    loaded_tags = await user.awaitable_attrs.tags
    tag_names = [t.name for t in (loaded_tags or ())]

    return sqlmodels.UserResponse(
        id=user.id,
        email=user.email,
        status=user.status,
        score=user.score,
        nickname=user.nickname,
        avatar=user.avatar,
        created_at=user.created_at,
        group=group_response,
        tags=tag_names,
    )
|
||
|
||
@user_router.get(
    path='/storage',
    summary='存储信息',
    description='Get user storage information.',
    dependencies=[Depends(auth_required)],
)
async def router_user_storage(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.UserStorageResponse:
    """
    Report storage usage for the authenticated user.

    ``total`` is the group's ``max_storage``, ``used`` is the user's
    ``storage`` counter, and ``free`` is their difference clamped at zero.

    :raises HTTPException 404: the user's group does not exist
    """
    group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
    if not group:
        raise HTTPException(status_code=404, detail="用户组不存在")

    # TODO: add storage the user has purchased on top of the group quota.
    quota: int = group.max_storage
    consumed: int = user.storage

    return sqlmodels.UserStorageResponse(
        used=consumed,
        free=max(0, quota - consumed),
        total=quota,
    )
|
||
|
||
@user_router.put(
    path='/authn/start',
    summary='WebAuthn登录初始化',
    description='Initialize WebAuthn login for a user.',
    dependencies=[Depends(auth_required)],
)
async def router_user_authn_start(
    session: SessionDep,
    user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.ResponseBase:
    """
    Begin WebAuthn credential registration for the authenticated user.

    Despite the route summary, this builds *registration* options with
    ``webauthn.generate_registration_options``. The relying party ID is the
    host name taken from the ``basic/siteURL`` setting, and the RP name is the
    ``basic/siteTitle`` setting.

    :param session: database session
    :param user: the authenticated user (from ``auth_required``)
    :return: ResponseBase whose ``data`` is the JSON-ready options dict
    :raises HTTPException 400: WebAuthn disabled (``authn/authn_enabled`` != "1")
    """
    from urllib.parse import urlparse

    # TODO: also pass the user's already-registered credentials as exclusions.
    authn_setting = await sqlmodels.Setting.get(
        session,
        (sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled"),
    )
    if not authn_setting or authn_setting.value != "1":
        raise HTTPException(status_code=400, detail="WebAuthn is not enabled")

    site_url_setting = await sqlmodels.Setting.get(
        session,
        (sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteURL"),
    )
    site_title_setting = await sqlmodels.Setting.get(
        session,
        (sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteTitle"),
    )

    # A WebAuthn RP ID must be a bare domain (no scheme, port or path). Passing
    # a full URL such as "https://example.com/" makes browsers reject the
    # ceremony, so extract the host. A value that is already a bare domain is
    # kept unchanged.
    site_url = (site_url_setting.value if site_url_setting else "") or ""
    rp_id = urlparse(site_url if "://" in site_url else f"//{site_url}").hostname or site_url

    options = generate_registration_options(
        rp_id=rp_id,
        rp_name=site_title_setting.value if site_title_setting else "",
        # Stable user handle bound to this account. When it is omitted,
        # py_webauthn 2.x generates a random handle on every call.
        user_id=user.id.bytes,
        user_name=user.email,
        user_display_name=user.nickname or user.email,
    )

    # NOTE(review): options.challenge is not persisted here, so a later
    # registration-finish step has nothing to verify the authenticator's
    # response against — store it (per user, short-lived) before relying on it.
    return sqlmodels.ResponseBase(data=options_to_json_dict(options))
|
||
|
||
@user_router.put(
    path='/authn/finish',
    summary='WebAuthn登录',
    description='Finish WebAuthn login for a user.',
    dependencies=[Depends(auth_required)],
)
def router_user_authn_finish() -> sqlmodels.ResponseBase:
    """
    Complete the authenticated user's WebAuthn flow (not implemented yet).

    NOTE: this reuses the name of the earlier POST /authn/finish/{username}
    handler and rebinds the module-level name. Both routes remain registered,
    because each decorator runs at definition time.

    :return: nothing yet — the body only calls
        ``http_exceptions.raise_not_implemented()`` (presumably an HTTP 501).
    """
    # Placeholder until the WebAuthn finish step is implemented.
    http_exceptions.raise_not_implemented()