diff --git a/models/group.py b/models/group.py index 87e6cb1..725fa0b 100644 --- a/models/group.py +++ b/models/group.py @@ -25,7 +25,7 @@ class GroupOptionsBase(SQLModelBase): """是否允许分享下载""" share_free: bool = False - """是否免积分分享""" + """是否免积分获取需要积分的内容""" relocate: bool = False """是否允许文件重定位""" @@ -136,3 +136,22 @@ class Group(GroupBase, TableBase, table=True): back_populates="previous_group", sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"} ) + + def to_response(self) -> "GroupResponse": + """转换为响应 DTO""" + opts = self.options + return GroupResponse( + id=self.id, + name=self.name, + allow_share=self.share_enabled, + webdav=self.web_dav_enabled, + share_download=opts.share_download if opts else False, + share_free=opts.share_free if opts else False, + relocate=opts.relocate if opts else False, + source_batch=opts.source_batch if opts else 0, + select_node=opts.select_node if opts else False, + advance_delete=opts.advance_delete if opts else False, + allow_remote_download=opts.aria2 if opts else False, + allow_archive_download=opts.archive_download if opts else False, + allow_webdav_proxy=opts.webdav_proxy if opts else False, + ) diff --git a/models/migration.py b/models/migration.py index e2d779e..3668ff5 100644 --- a/models/migration.py +++ b/models/migration.py @@ -15,8 +15,8 @@ async def migration() -> None: log.info('开始进行数据库初始化...') await init_default_settings() - await init_default_group() await init_default_policy() + await init_default_group() await init_default_user() log.info('数据库初始化结束') @@ -147,6 +147,7 @@ async def init_default_group() -> None: if not await Group.get(session, Group.id == 1): admin_group = await Group( name="管理员", + policies="1", max_storage=1 * 1024 * 1024 * 1024, # 1GB share_enabled=True, web_dav_enabled=True, @@ -158,7 +159,10 @@ async def init_default_group() -> None: archive_download=True, archive_task=True, share_download=True, + share_free=True, aria2=True, + select_node=True, + advance_delete=True, ).save(session) # 未找到初始注册会员时,则创建 diff --git a/models/object.py b/models/object.py index 754dfd0..e79c4f9 100644 --- a/models/object.py +++ b/models/object.py @@ -84,8 +84,8 @@ class PolicyResponse(SQLModelBase): max_size: int = 0 """单文件最大限制,单位字节,0表示不限制""" - file_type: list[str] = [] - """允许的文件类型列表,空列表表示不限制""" + file_type: list[str] | None = None + """允许的文件类型列表,None 表示不限制""" class DirectoryResponse(SQLModelBase): diff --git a/models/user.py b/models/user.py index 04e3e9e..cbb37ac 100644 --- a/models/user.py +++ b/models/user.py @@ -1,4 +1,5 @@ from datetime import datetime +from enum import StrEnum from typing import Literal, Optional, TYPE_CHECKING from sqlmodel import Field, Relationship @@ -26,6 +27,13 @@ Option 需求 - 切换到不同存储策略是否提醒 """ +class AvatarType(StrEnum): + """头像类型枚举""" + + DEFAULT = "default" + GRAVATAR = "gravatar" + FILE = "file" + # ==================== Base 模型 ==================== @@ -227,7 +235,7 @@ class User(UserBase, TableBase, table=True): username: str = Field(max_length=50, unique=True, index=True) """用户名,唯一,一经注册不可更改""" - nick: str | None = Field(default=None, max_length=50) + nickname: str | None = Field(default=None, max_length=50) """用于公开展示的名字,可使用真实姓名或昵称""" password: str = Field(max_length=255) @@ -242,7 +250,7 @@ class User(UserBase, TableBase, table=True): two_factor: str | None = Field(default=None, min_length=32, max_length=32) """两步验证密钥""" - avatar: str | None = Field(default=None, max_length=255) + avatar: str = Field(default="default", max_length=255) """头像地址""" options: str | None = None diff --git a/pkg/__init__.py b/pkg/__init__.py index d8d4fd5..e69de29 100644 --- a/pkg/__init__.py +++ b/pkg/__init__.py @@ -1,4 +0,0 @@ -# 延迟导入以避免循环依赖 -# JWT 和 lifespan 应在需要时直接从子模块导入 -# from .JWT import JWT -# from .lifespan import lifespan \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cec3f65..b1972ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,8 +9,10 @@ dependencies = [ "aiosqlite>=0.21.0", "argon2-cffi>=25.1.0", "fastapi[standard]>=0.122.0", + "itsdangerous>=2.2.0", "loguru>=0.7.3", "pyjwt>=2.10.1", + "pyotp>=2.9.0", "python-dotenv>=1.2.1", "python-multipart>=0.0.20", "sqlalchemy>=2.0.44", diff --git a/routers/controllers/directory.py b/routers/controllers/directory.py index 3d468cd..702cbba 100644 --- a/routers/controllers/directory.py +++ b/routers/controllers/directory.py @@ -28,13 +28,13 @@ async def router_directory_get( session: SessionDep, user: Annotated[User, Depends(AuthRequired)], path: str = "" -) -> response.ResponseModel: +) -> DirectoryResponse: """ 获取目录内容 :param session: 数据库会话 :param user: 当前登录用户 - :param path: 目录路径,空或 "/" 表示根目录 + :param path: 目录路径, "~" 表示根目录 :return: 目录内容 """ folder = await Object.get_by_path(session, user.id, path or "/") @@ -44,6 +44,9 @@ async def router_directory_get( if not folder.is_folder: raise HTTPException(status_code=400, detail="指定路径不是目录") + + if path != "~": + path = path.lstrip("~") children = await Object.get_children(session, user.id, folder.id) policy = await folder.awaitable_attrs.policy @@ -55,7 +58,7 @@ async def router_directory_get( path=f"/{child.name}", # TODO: 完整路径 thumb=False, size=child.size, - type="folder" if child.is_folder else "file", + type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE, date=child.updated_at, create_date=child.created_at, source_enabled=False, @@ -63,18 +66,17 @@ async def router_directory_get( for child in children ] - return response.ResponseModel( - data=DirectoryResponse( - parent=str(folder.parent_id) if folder.parent_id else None, - objects=objects, - policy=PolicyResponse( - id=str(policy.id), - name=policy.name, - type=policy.type.value, - max_size=policy.max_size, - file_type=[], - ), - ) + policy=PolicyResponse( + id=str(policy.id), + name=policy.name, + type=policy.type.value, + max_size=policy.max_size, + ) + + return DirectoryResponse( + parent=str(folder.parent_id) if folder.parent_id else None, + objects=objects, + policy=policy, ) diff --git a/routers/controllers/user.py b/routers/controllers/user.py index 93ad9e0..38f2e04 100644 --- a/routers/controllers/user.py +++ b/routers/controllers/user.py @@ -1,15 +1,18 @@ -from typing import Annotated +from typing import Annotated, Literal from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm from sqlalchemy import and_ from webauthn import generate_registration_options from webauthn.helpers import options_to_json_dict +import pyotp +from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired import models import service from middleware.auth import AuthRequired from middleware.dependencies import SessionDep +from pkg.JWT.JWT import SECRET_KEY user_router = APIRouter( prefix="/user", @@ -25,18 +28,46 @@ user_settings_router = APIRouter( @user_router.post( path='/session', summary='用户登录', - description='User login endpoint.', + description='User login endpoint. 当用户启用两步验证时,需要传入 otp 参数。', ) async def router_user_session( session: SessionDep, form_data: Annotated[OAuth2PasswordRequestForm, Depends()], ) -> models.TokenResponse: + """ + 用户登录端点。 + + 根据 OAuth2.1 规范,使用 password grant type 进行登录。 + 当用户启用两步验证时,需要在表单中传入 otp 参数(通过 scopes 字段传递)。 + + OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码 + + :raises HTTPException 401: 用户名或密码错误 + :raises HTTPException 403: 用户账号被封禁或未完成注册 + :raises HTTPException 428: 需要两步验证但未提供验证码 + :raises HTTPException 400: 两步验证码无效 + """ username = form_data.username password = form_data.password + # 从 scopes 中提取 OTP 验证码(OAuth2.1 扩展方式) + # scopes 格式可以是 ["otp:123456"] 或 ["123456"] + otp_code: str | None = None + for scope in form_data.scopes: + if scope.startswith("otp:"): + otp_code = scope[4:] + break + elif scope.isdigit() and len(scope) == 6: + otp_code = scope + break + result = await service.user.Login( session, - models.LoginRequest(username=username, password=password), + models.LoginRequest( + username=username, + password=password, + two_fa_code=otp_code, + ), ) if isinstance(result, models.TokenResponse): @@ -45,6 +76,14 @@ async def router_user_session( raise HTTPException(status_code=401, detail="Invalid username or password") elif result is False: raise HTTPException(status_code=403, detail="User account is banned or not fully registered") + elif result == "2fa_required": + raise HTTPException( + status_code=428, + detail="Two-factor authentication required", + headers={"X-2FA-Required": "true"}, + ) + elif result == "2fa_invalid": + raise HTTPException(status_code=400, detail="Invalid two-factor authentication code") else: raise HTTPException(status_code=500, detail="Internal server error during login") @@ -62,26 +101,14 @@ def router_user_register() -> models.response.ResponseModel: """ pass -@user_router.post( - path='/2fa', - summary='用两步验证登录', - description='Two-factor authentication login endpoint.', -) -def router_user_2fa() -> models.response.ResponseModel: - """ - Two-factor authentication login endpoint. - - Returns: - dict: A dictionary containing two-factor authentication information. - """ - pass - @user_router.post( path='/code', summary='发送验证码邮件', description='Send a verification code email.', ) -def router_user_email_code() -> models.response.ResponseModel: +def router_user_email_code( + reason: Literal['register', 'reset'] = 'register', +) -> models.response.ResponseModel: """ Send a verification code email. @@ -90,21 +117,6 @@ def router_user_email_code() -> models.response.ResponseModel: """ pass -@user_router.patch( - path='/reset', - summary='通过邮件里的链接重设密码', - description='Reset password via email link.', - deprecated=True, -) -def router_user_reset_patch() -> models.response.ResponseModel: - """ - Reset password via email link. - - Returns: - dict: A dictionary containing information about the password reset. - """ - pass - @user_router.get( path='/qq', summary='初始化QQ登录', @@ -193,7 +205,7 @@ def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseMode ) async def router_user_me( session: SessionDep, - user: Annotated[models.user.User, Depends(AuthRequired)], + user: Annotated[models.User, Depends(AuthRequired)], ) -> models.response.ResponseModel: """ 获取用户信息. @@ -201,25 +213,32 @@ async def router_user_me( :return: response.ResponseModel containing user information. :rtype: response.ResponseModel """ - group = await models.Group.get(session, models.Group.id == user.group_id) - - user_group = models.GroupResponse( - id=group.id, - name=group.name, - allow_share=group.share_enabled, + # 加载 group 及其 options 关系 + group = await models.Group.get( + session, + models.Group.id == user.group_id, + load=models.Group.options ) - users = models.UserResponse( + # 构建 GroupResponse + group_response = group.to_response() if group else None + + # 异步加载 tags 关系 + user_tags = await user.awaitable_attrs.tags + + user_response = models.UserResponse( id=user.id, username=user.username, - nickname=user.nick, status=user.status, - created_at=user.created_at, score=user.score, - group=user_group, - ).model_dump() + nickname=user.nickname, + avatar=user.avatar, + created_at=user.created_at, + group=group_response, + tags=[tag.name for tag in user_tags] if user_tags else [], + ) - return models.response.ResponseModel(data=users) + return models.response.ResponseModel(data=user_response.model_dump()) @user_router.get( path='/storage', @@ -425,11 +444,77 @@ def router_user_settings_patch(option: str) -> models.response.ResponseModel: description='Get two-factor authentication initialization information.', dependencies=[Depends(AuthRequired)], ) -def router_user_settings_2fa() -> models.response.ResponseModel: +async def router_user_settings_2fa( + session: SessionDep, + user: Annotated[models.user.User, Depends(AuthRequired)], +) -> models.response.ResponseModel: """ Get two-factor authentication initialization information. Returns: dict: A dictionary containing two-factor authentication setup information. """ - pass \ No newline at end of file + + serializer = URLSafeTimedSerializer(SECRET_KEY) + + secret = pyotp.random_base32() + + setup_token = serializer.dumps( + secret, + salt="2fa-setup-salt" + ) + + site_Name = await models.Setting.get(session, (models.Setting.type == models.SettingsType.BASIC) & (models.Setting.name == "siteName")) + + otp_uri = pyotp.totp.TOTP(secret).provisioning_uri( + name=user.username, + issuer_name=site_Name.value + ) + + return models.response.ResponseModel( + data={ + "setup_token": setup_token, + "otp_uri": otp_uri, + } + ) + +@user_settings_router.post( + path='/2fa', + summary='启用两步验证', + description='Enable two-factor authentication.', + dependencies=[Depends(AuthRequired)], +) +async def router_user_settings_2fa_enable( + session: SessionDep, + user: Annotated[models.user.User, Depends(AuthRequired)], + setup_token: str, + code: str, +) -> models.response.ResponseModel: + """ + Enable two-factor authentication for the user. + + Returns: + dict: A dictionary containing the result of enabling two-factor authentication. + """ + + serializer = URLSafeTimedSerializer(SECRET_KEY) + + try: + # 1. 解包 Token,设置有效期(例如 600秒) + secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600) + except SignatureExpired: + raise HTTPException(status_code=400, detail="Setup session expired") + except BadSignature: + raise HTTPException(status_code=400, detail="Invalid token") + + # 2. 验证用户输入的 6 位验证码 + if not service.user.verify_totp(secret, code): + raise HTTPException(status_code=400, detail="Invalid OTP code") + + # 3. 将 secret 存储到用户的数据库记录中,启用 2FA + user.two_factor = secret + user = await user.save(session) + + return models.response.ResponseModel( + data={"message": "Two-factor authentication enabled successfully"} + ) \ No newline at end of file diff --git a/service/__init__.py b/service/__init__.py index 54e9d6d..062c801 100644 --- a/service/__init__.py +++ b/service/__init__.py @@ -2,4 +2,4 @@ 服务层 """ -from .user import login \ No newline at end of file +from . import user \ No newline at end of file diff --git a/service/user/__init__.py b/service/user/__init__.py index 751474d..e72eaf5 100644 --- a/service/user/__init__.py +++ b/service/user/__init__.py @@ -1 +1,2 @@ -from .login import Login \ No newline at end of file +from .login import Login +from .totp import verify_totp \ No newline at end of file diff --git a/service/user/login.py b/service/user/login.py index fff8fec..578dcfa 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -1,17 +1,25 @@ +from typing import Literal + from loguru import logger as log from sqlmodel.ext.asyncio.session import AsyncSession from models import LoginRequest, TokenResponse, User from pkg.JWT.JWT import create_access_token, create_refresh_token +from .totp import verify_totp -async def Login(session: AsyncSession, login_request: LoginRequest) -> TokenResponse | bool | None: +async def Login( + session: AsyncSession, + login_request: LoginRequest, +) -> TokenResponse | bool | Literal["2fa_required", "2fa_invalid"] | None: """ 根据账号密码进行登录。 如果登录成功,返回一个 TokenResponse 对象,包含访问令牌和刷新令牌以及它们的过期时间。 如果登录异常,返回 `False`(未完成注册或账号被封禁)。 如果登录失败,返回 `None`。 + 如果需要两步验证但未提供验证码,返回 `"2fa_required"`。 + 如果两步验证码无效,返回 `"2fa_invalid"`。 :param session: 数据库会话 :param login_request: 登录请求 @@ -45,6 +53,18 @@ async def Login(session: AsyncSession, login_request: LoginRequest) -> TokenResp # 未完成注册 or 账号已被封禁 return False + # 检查两步验证 + if current_user.two_factor: + # 用户已启用两步验证 + if not login_request.two_fa_code: + log.debug(f"2FA required for user: {login_request.username}") + return "2fa_required" + + # 验证 OTP 码 + if not verify_totp(current_user.two_factor, login_request.two_fa_code): + log.debug(f"Invalid 2FA code for user: {login_request.username}") + return "2fa_invalid" + # 创建令牌 access_token, access_expire = create_access_token(data={'sub': current_user.username}) refresh_token, refresh_expire = create_refresh_token(data={'sub': current_user.username}) diff --git a/service/user/totp.py b/service/user/totp.py new file mode 100644 index 0000000..6afd906 --- /dev/null +++ b/service/user/totp.py @@ -0,0 +1,13 @@ +import pyotp + + +def verify_totp(secret: str, code: str) -> bool: + """ + 验证 TOTP 验证码。 + + :param secret: TOTP 密钥(Base32 编码) + :param code: 用户输入的 6 位验证码 + :return: 验证是否成功 + """ + totp = pyotp.TOTP(secret) + return totp.verify(code) diff --git a/uv.lock b/uv.lock index e3b04a1..43df353 100644 --- a/uv.lock +++ b/uv.lock @@ -364,8 +364,10 @@ dependencies = [ { name = "aiosqlite" }, { name = "argon2-cffi" }, { name = "fastapi", extra = ["standard"] }, + { name = "itsdangerous" }, { name = "loguru" }, { name = "pyjwt" }, + { name = "pyotp" }, { name = "python-dotenv" }, { name = "python-multipart" }, { name = "sqlalchemy" }, @@ -380,8 +382,10 @@ requires-dist = [ { name = "aiosqlite", specifier = ">=0.21.0" }, { name = "argon2-cffi", specifier = ">=25.1.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.122.0" }, + { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "loguru", specifier = ">=0.7.3" }, { name = "pyjwt", specifier = ">=2.10.1" }, + { name = "pyotp", specifier = ">=2.9.0" }, { name = "python-dotenv", specifier = ">=1.2.1" }, { name = "python-multipart", specifier = ">=0.0.20" }, { name = "sqlalchemy", specifier = ">=2.0.44" }, @@ -698,6 +702,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, ] +[[package]] +name = "itsdangerous" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -1058,6 +1071,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/81/ef2b1dfd1862567d573a4fdbc9f969067621764fbb74338496840a1d2977/pyopenssl-25.3.0-py3-none-any.whl", hash = "sha256:1fda6fc034d5e3d179d39e59c1895c9faeaf40a79de5fc4cbbfbe0d36f4a77b6", size = 57268, upload-time = "2025-09-17T00:32:19.474Z" }, ] +[[package]] +name = "pyotp" +version = "2.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/b2/1d5994ba2acde054a443bd5e2d384175449c7d2b6d1a0614dbca3a63abfc/pyotp-2.9.0.tar.gz", hash = "sha256:346b6642e0dbdde3b4ff5a930b664ca82abfa116356ed48cc42c7d6590d36f63", size = 17763, upload-time = "2023-07-27T23:41:03.295Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/c0/c33c8792c3e50193ef55adb95c1c3c2786fe281123291c2dbf0eaab95a6f/pyotp-2.9.0-py3-none-any.whl", hash = "sha256:81c2e5865b8ac55e825b0358e496e1d9387c811e85bb40e71a3b29b288963612", size = 13376, upload-time = "2023-07-27T23:41:01.685Z" }, +] + [[package]] name = "python-dotenv" version = "1.2.1"