feat: add models for physical files, policies, and user management

- Implement PhysicalFile model to manage physical file references and reference counting.
- Create Policy model with associated options and group links for storage policies.
- Introduce Redeem and Report models for handling redeem codes and reports.
- Add Settings model for site configuration and user settings management.
- Develop Share model for sharing objects with unique codes and associated metadata.
- Implement SourceLink model for managing download links associated with objects.
- Create StoragePack model for managing user storage packages.
- Add Tag model for user-defined tags with manual and automatic types.
- Implement Task model for managing background tasks with status tracking.
- Develop User model with comprehensive user management features including authentication.
- Introduce UserAuthn model for managing WebAuthn credentials.
- Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
2026-02-10 16:25:49 +08:00
parent 62c671e07b
commit 209cb24ab4
92 changed files with 3640 additions and 1444 deletions

View File

@@ -5,15 +5,15 @@ from loguru import logger as l
from middleware.auth import admin_required
from middleware.dependencies import SessionDep
from models import (
from sqlmodels import (
User, ResponseBase,
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
)
from models.base import SQLModelBase
from models.setting import (
from sqlmodels.base import SQLModelBase
from sqlmodels.setting import (
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
)
from models.setting import SettingsType
from sqlmodels.setting import SettingsType
from utils import http_exceptions
from utils.conf import appmeta
from .file import admin_file_router

View File

@@ -5,14 +5,60 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger as l
from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from models import (
Policy, PolicyType, User, ResponseBase, ListResponse,
from sqlmodels import (
Policy, PolicyType, User, ListResponse,
Object, ObjectType, AdminFileResponse, FileBanRequest, )
from service.storage import LocalStorageService
async def _set_ban_recursive(
session: AsyncSession,
obj: Object,
ban: bool,
admin_id: UUID,
reason: str | None,
) -> int:
"""
递归设置封禁状态,返回受影响对象数量。
:param session: 数据库会话
:param obj: 要封禁/解禁的对象
:param ban: True=封禁, False=解禁
:param admin_id: 管理员UUID
:param reason: 封禁原因
:return: 受影响的对象数量
"""
count = 0
# 如果是文件夹,先递归处理子对象
if obj.is_folder:
children = await Object.get(
session,
Object.parent_id == obj.id,
fetch_mode="all",
)
for child in children:
count += await _set_ban_recursive(session, child, ban, admin_id, reason)
# 设置当前对象
obj.is_banned = ban
if ban:
obj.banned_at = datetime.now()
obj.banned_by = admin_id
obj.ban_reason = reason
else:
obj.banned_at = None
obj.banned_by = None
obj.ban_reason = None
await obj.save(session)
count += 1
return count
admin_file_router = APIRouter(
prefix="/file",
tags=["admin", "admin_file"],
@@ -119,15 +165,17 @@ async def router_admin_preview_file(
summary='封禁/解禁文件',
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_admin_ban_file(
session: SessionDep,
file_id: UUID,
request: FileBanRequest,
admin: Annotated[User, Depends(admin_required)],
) -> ResponseBase:
) -> None:
"""
封禁或解禁文件。封禁后用户无法访问该文件。
封禁或解禁文件/文件夹。封禁后用户无法访问该文件。
封禁文件夹时会级联封禁所有子对象。
:param session: 数据库会话
:param file_id: 文件UUID
@@ -139,24 +187,10 @@ async def router_admin_ban_file(
if not file_obj:
raise HTTPException(status_code=404, detail="文件不存在")
file_obj.is_banned = request.is_banned
if request.is_banned:
file_obj.banned_at = datetime.now()
file_obj.banned_by = admin.id
file_obj.ban_reason = request.reason
else:
file_obj.banned_at = None
file_obj.banned_by = None
file_obj.ban_reason = None
count = await _set_ban_recursive(session, file_obj, request.ban, admin.id, request.reason)
file_obj = await file_obj.save(session)
action = "封禁" if request.is_banned else "解禁"
l.info(f"管理员{action}了文件: {file_obj.name}")
return ResponseBase(data={
"id": str(file_obj.id),
"is_banned": file_obj.is_banned,
})
action = "封禁" if request.ban else "解禁"
l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象")
@admin_file_router.delete(
@@ -164,12 +198,13 @@ async def router_admin_ban_file(
summary='删除文件',
description='Delete file by ID',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_admin_delete_file(
session: SessionDep,
file_id: UUID,
delete_physical: bool = True,
) -> ResponseBase:
) -> None:
"""
删除文件。
@@ -211,5 +246,4 @@ async def router_admin_delete_file(
# 使用条件删除
await Object.delete(session, condition=Object.id == file_obj.id)
l.info(f"管理员删除了文件: {file_name}")
return ResponseBase(data={"deleted": True})
l.info(f"管理员删除了文件: {file_name}")

View File

@@ -5,12 +5,12 @@ from loguru import logger as l
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from models import (
from sqlmodels import (
User, ResponseBase, UserPublic, ListResponse,
Group, GroupOptions, )
from models.group import (
from sqlmodels.group import (
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, )
from models.policy import GroupPolicyLink
from sqlmodels.policy import GroupPolicyLink
admin_group_router = APIRouter(
prefix="/group",
@@ -113,11 +113,12 @@ async def router_admin_get_group_members(
summary='创建用户组',
description='Create a new user group',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_admin_create_group(
session: SessionDep,
request: GroupCreateRequest,
) -> ResponseBase:
) -> None:
"""
创建新的用户组。
@@ -164,7 +165,6 @@ async def router_admin_create_group(
await session.commit()
l.info(f"管理员创建了用户组: {group.name}")
return ResponseBase(data={"id": str(group.id), "name": group.name})
@admin_group_router.patch(
@@ -172,12 +172,13 @@ async def router_admin_create_group(
summary='更新用户组信息',
description='Update user group information by ID',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_admin_update_group(
session: SessionDep,
group_id: UUID,
request: GroupUpdateRequest,
) -> ResponseBase:
) -> None:
"""
根据用户组ID更新用户组信息。
@@ -233,8 +234,7 @@ async def router_admin_update_group(
session.add(link)
await session.commit()
l.info(f"管理员更新了用户组: {group.name}")
return ResponseBase(data={"id": str(group.id)})
l.info(f"管理员更新了用户组: {group_id}")
@admin_group_router.delete(
@@ -242,11 +242,12 @@ async def router_admin_update_group(
summary='删除用户组',
description='Delete user group by ID',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_admin_delete_group(
session: SessionDep,
group_id: UUID,
) -> ResponseBase:
) -> None:
"""
根据用户组ID删除用户组。
@@ -271,5 +272,4 @@ async def router_admin_delete_group(
group_name = group.name
await Group.delete(session, group)
l.info(f"管理员删除了用户组: {group_name}")
return ResponseBase(data={"deleted": True})
l.info(f"管理员删除了用户组: {group_id}")

View File

@@ -6,10 +6,10 @@ from sqlmodel import Field
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from models import (
from sqlmodels import (
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
ListResponse, Object, )
from models.base import SQLModelBase
from sqlmodels.base import SQLModelBase
from service.storage import DirectoryCreationError, LocalStorageService
admin_policy_router = APIRouter(

View File

@@ -5,7 +5,7 @@ from loguru import logger as l
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from models import (
from sqlmodels import (
ResponseBase, ListResponse,
Share, AdminShareListItem, )
@@ -80,7 +80,7 @@ async def router_admin_get_share(
"score": share.score,
"has_password": bool(share.password),
"user_id": str(share.user_id),
"username": user.username if user else None,
"username": user.email if user else None,
"object": {
"id": str(obj.id),
"name": obj.name,

View File

@@ -5,7 +5,7 @@ from loguru import logger as l
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep
from models import (
from sqlmodels import (
ResponseBase, ListResponse,
Task, TaskSummary,
)
@@ -89,7 +89,7 @@ async def router_admin_get_task(
"progress": task.progress,
"error": task.error,
"user_id": str(task.user_id),
"username": user.username if user else None,
"username": user.email if user else None,
"props": props.model_dump() if props else None,
"created_at": task.created_at.isoformat(),
"updated_at": task.updated_at.isoformat(),

View File

@@ -6,11 +6,13 @@ from sqlalchemy import func
from middleware.auth import admin_required
from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep
from models import (
from sqlmodels import (
User, ResponseBase, UserPublic, ListResponse,
Group, Object, ObjectType, )
from models.user import (
UserAdminUpdateRequest, UserCalibrateResponse,
Group, Object, ObjectType, Setting, SettingsType,
BatchDeleteRequest,
)
from sqlmodels.user import (
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse,
)
from utils import Password, http_exceptions
@@ -26,19 +28,19 @@ admin_user_router = APIRouter(
description='Get user information by ID',
dependencies=[Depends(admin_required)],
)
async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase:
async def router_admin_get_user(session: SessionDep, user_id: UUID) -> UserPublic:
"""
根据用户ID获取用户信息包括用户名、邮箱、注册时间等。
Args:
session(SessionDep): 数据库会话依赖项。
user_id (int): 用户ID。
user_id (UUID): 用户ID。
Returns:
ResponseBase: 包含用户信息的响应模型。
"""
user = await User.get_exist_one(session, user_id)
return ResponseBase(data=user.to_public().model_dump())
return user.to_public()
@admin_user_router.get(
@@ -60,7 +62,7 @@ async def router_admin_get_users(
:param filter_params: 用户筛选参数(用户组、用户名、昵称、状态)
:return: 分页用户列表
"""
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view)
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view, load=User.group)
return ListResponse(
items=[user.to_public() for user in result.items],
count=result.count,
@@ -75,22 +77,33 @@ async def router_admin_get_users(
)
async def router_admin_create_user(
session: SessionDep,
user: User,
) -> ResponseBase:
request: UserAdminCreateRequest,
) -> UserPublic:
"""
创建一个新的用户,设置用户名、密码等信息。
创建一个新的用户,设置邮箱、密码、用户组等信息。
Returns:
ResponseBase: 包含创建结果的响应模型。
:param session: 数据库会话
:param request: 创建用户请求 DTO
:return: 创建结果
"""
existing_user = await User.get(session, User.username == user.username)
existing_user = await User.get(session, User.email == request.email)
if existing_user:
return ResponseBase(
code=400,
msg="User with this username already exists."
)
raise HTTPException(status_code=409, detail="该邮箱已被注册")
# 验证用户组存在
group = await Group.get(session, Group.id == request.group_id)
if not group:
raise HTTPException(status_code=400, detail="目标用户组不存在")
user = User(
email=request.email,
password=Password.hash(request.password),
nickname=request.nickname,
group_id=request.group_id,
status=request.status,
)
user = await user.save(session)
return ResponseBase(data=user.to_public().model_dump())
return user.to_public()
@admin_user_router.patch(
@@ -98,12 +111,13 @@ async def router_admin_create_user(
summary='更新用户信息',
description='Update user information by ID',
dependencies=[Depends(admin_required)],
status_code=204
)
async def router_admin_update_user(
session: SessionDep,
user_id: UUID,
request: UserAdminUpdateRequest,
) -> ResponseBase:
) -> None:
"""
根据用户ID更新用户信息。
@@ -116,8 +130,15 @@ async def router_admin_update_user(
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
# 默认管理员(用户名为 admin不允许更改用户组
if request.group_id and user.username == "admin" and request.group_id != user.group_id:
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
default_admin_setting = await Setting.get(
session,
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
)
if (request.group_id
and default_admin_setting
and default_admin_setting.value == str(user_id)
and request.group_id != user.group_id):
http_exceptions.raise_forbidden("默认管理员不允许更改用户组")
# 如果更新用户组,验证新组存在
@@ -143,38 +164,35 @@ async def router_admin_update_user(
setattr(user, key, value)
user = await user.save(session)
l.info(f"管理员更新了用户: {user.username}")
return ResponseBase(data=user.to_public().model_dump())
l.info(f"管理员更新了用户: {request.email}")
@admin_user_router.delete(
path='/{user_id}',
summary='删除用户',
description='Delete user by ID',
path='/',
summary='删除用户(支持批量)',
description='Delete users by ID list',
dependencies=[Depends(admin_required)],
status_code=204,
)
async def router_admin_delete_user(
async def router_admin_delete_users(
session: SessionDep,
user_id: UUID,
) -> ResponseBase:
request: BatchDeleteRequest,
) -> None:
"""
根据用户ID删除用户及其所有数据。
批量删除用户及其所有数据。
注意: 这是一个危险操作,会级联删除用户的所有文件、分享、任务等。
:param session: 数据库会话
:param user_id: 用户UUID
:return: 删除结果
:param request: 批量删除请求,包含待删除用户的 UUID 列表
:return: 删除结果(已删除数 / 总请求数)
"""
user = await User.get(session, User.id == user_id)
if not user:
raise HTTPException(status_code=404, detail="用户不存在")
username = user.username
await User.delete(session, user)
l.info(f"管理员删除了用户: {username}")
return ResponseBase(data={"deleted": True})
deleted = 0
for uid in request.ids:
user = await User.get(session, User.id == uid)
if user:
await User.delete(session, user)
l.info(f"管理员删除了用户: {user.email}")
@admin_user_router.post(
@@ -186,7 +204,7 @@ async def router_admin_delete_user(
async def router_admin_calibrate_storage(
session: SessionDep,
user_id: UUID,
) -> ResponseBase:
) -> UserCalibrateResponse:
"""
重新计算用户的已用存储空间。
@@ -228,5 +246,5 @@ async def router_admin_calibrate_storage(
file_count=file_count,
)
l.info(f"管理员校准了用户存储: {user.username}, 差值: {actual_storage - previous_storage}")
return ResponseBase(data=response.model_dump())
l.info(f"管理员校准了用户存储: {user.email}, 差值: {actual_storage - previous_storage}")
return response

View File

@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException
from middleware.auth import admin_required
from middleware.dependencies import SessionDep
from models import (
from sqlmodels import (
ResponseBase,
)

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Query
from fastapi.responses import PlainTextResponse
from models import ResponseBase
from sqlmodels import ResponseBase
import service.oauth
from utils import http_exceptions

View File

@@ -1,10 +1,12 @@
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from models import (
from sqlmodels import (
DirectoryCreateRequest,
DirectoryResponse,
Object,
@@ -14,50 +16,28 @@ from models import (
User,
ResponseBase,
)
from utils import http_exceptions
directory_router = APIRouter(
prefix="/directory",
tags=["directory"]
)
@directory_router.get(
path="/{path:path}",
summary="获取目录内容",
)
async def router_directory_get(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
path: str
async def _get_directory_response(
session: AsyncSession,
user_id: UUID,
folder: Object,
) -> DirectoryResponse:
"""
获取目录内容
路径必须以用户名或 `.crash` 开头,如 /api/directory/admin 或 /api/directory/admin/docs
`.crash` 代表回收站,也就意味着用户名禁止为 `.crash`
构建目录响应 DTO
:param session: 数据库会话
:param user: 当前登录用户
:param path: 目录路径(必须以用户名开头)
:return: 目录内容
:param user_id: 用户UUID
:param folder: 目录对象
:return: DirectoryResponse
"""
# 路径必须以用户名开头
path = path.strip("/")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空,请使用 /{username} 格式")
path_parts = path.split("/")
if path_parts[0] != user.username:
raise HTTPException(status_code=403, detail="无权访问其他用户的目录")
folder = await Object.get_by_path(session, user.id, "/" + path, user.username)
if not folder:
raise HTTPException(status_code=404, detail="目录不存在")
if not folder.is_folder:
raise HTTPException(status_code=400, detail="指定路径不是目录")
children = await Object.get_children(session, user.id, folder.id)
children = await Object.get_children(session, user_id, folder.id)
policy = await folder.awaitable_attrs.policy
objects = [
@@ -67,8 +47,8 @@ async def router_directory_get(
thumb=False,
size=child.size,
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
date=child.updated_at,
create_date=child.created_at,
created_at=child.created_at,
updated_at=child.updated_at,
source_enabled=False,
)
for child in children
@@ -89,7 +69,74 @@ async def router_directory_get(
)
@directory_router.put(
@directory_router.get(
path="/",
summary="获取根目录内容",
)
async def router_directory_root(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
) -> DirectoryResponse:
"""
获取当前用户的根目录内容
:param session: 数据库会话
:param user: 当前登录用户
:return: 根目录内容
"""
root = await Object.get_root(session, user.id)
if not root:
raise HTTPException(status_code=404, detail="根目录不存在")
if root.is_banned:
http_exceptions.raise_banned()
return await _get_directory_response(session, user.id, root)
@directory_router.get(
path="/{path:path}",
summary="获取目录内容",
)
async def router_directory_get(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
path: str
) -> DirectoryResponse:
"""
获取目录内容
路径从用户根目录开始,不包含用户名前缀。
如 /api/v1/directory/docs 表示根目录下的 docs 目录。
:param session: 数据库会话
:param user: 当前登录用户
:param path: 目录路径(从根目录开始的相对路径)
:return: 目录内容
"""
path = path.strip("/")
if not path:
# 空路径交给根目录端点处理(理论上不会到达这里)
root = await Object.get_root(session, user.id)
if not root:
raise HTTPException(status_code=404, detail="根目录不存在")
return await _get_directory_response(session, user.id, root)
folder = await Object.get_by_path(session, user.id, "/" + path)
if not folder:
raise HTTPException(status_code=404, detail="目录不存在")
if not folder.is_folder:
raise HTTPException(status_code=400, detail="指定路径不是目录")
if folder.is_banned:
http_exceptions.raise_banned()
return await _get_directory_response(session, user.id, folder)
@directory_router.post(
path="/",
summary="创建目录",
)
@@ -123,6 +170,9 @@ async def router_directory_create(
if not parent.is_folder:
raise HTTPException(status_code=400, detail="父路径不是目录")
if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 检查是否已存在同名对象
existing = await Object.get(
session,

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends
from middleware.auth import auth_required
from models import ResponseBase
from sqlmodels import ResponseBase
from utils import http_exceptions
download_router = APIRouter(

View File

@@ -18,7 +18,7 @@ from loguru import logger as l
from middleware.auth import auth_required, verify_download_token
from middleware.dependencies import SessionDep
from models import (
from sqlmodels import (
CreateFileRequest,
CreateUploadSessionRequest,
Object,
@@ -91,6 +91,9 @@ async def create_upload_session(
if not parent.is_folder:
raise HTTPException(status_code=400, detail="父对象不是目录")
if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 确定存储策略
policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id)
@@ -100,7 +103,7 @@ async def create_upload_session(
# 验证文件大小限制
if policy.max_size > 0 and request.file_size > policy.max_size:
raise HTTPException(
status_code=400,
status_code=413,
detail=f"文件大小超过限制 ({policy.max_size} bytes)"
)
@@ -221,30 +224,40 @@ async def upload_chunk(
upload_session.uploaded_size += len(content)
upload_session = await upload_session.save(session)
# 检查是否完成
# 在后续可能的 commit 前保存需要的属性
is_complete = upload_session.is_complete
uploaded_chunks = upload_session.uploaded_chunks
total_chunks = upload_session.total_chunks
file_object_id: UUID | None = None
if is_complete:
# 保存 upload_session 属性commit 后会过期)
file_name = upload_session.file_name
uploaded_size = upload_session.uploaded_size
storage_path = upload_session.storage_path
upload_session_id = upload_session.id
parent_id = upload_session.parent_id
policy_id = upload_session.policy_id
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
storage_path=upload_session.storage_path,
size=upload_session.uploaded_size,
policy_id=upload_session.policy_id,
storage_path=storage_path,
size=uploaded_size,
policy_id=policy_id,
reference_count=1,
)
physical_file = await physical_file.save(session, commit=False)
# 创建 Object 记录
file_object = Object(
name=upload_session.file_name,
name=file_name,
type=ObjectType.FILE,
size=upload_session.uploaded_size,
size=uploaded_size,
physical_file_id=physical_file.id,
upload_session_id=str(upload_session.id),
parent_id=upload_session.parent_id,
upload_session_id=str(upload_session_id),
parent_id=parent_id,
owner_id=user_id,
policy_id=upload_session.policy_id,
policy_id=policy_id,
)
file_object = await file_object.save(session, commit=False)
file_object_id = file_object.id
@@ -252,18 +265,18 @@ async def upload_chunk(
# 删除上传会话(使用条件删除)
await UploadSession.delete(
session,
condition=UploadSession.id == upload_session.id,
condition=UploadSession.id == upload_session_id,
commit=False
)
# 统一提交所有更改
await session.commit()
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")
l.info(f"文件上传完成: {file_name}, size={uploaded_size}, id={file_object_id}")
return UploadChunkResponse(
uploaded_chunks=upload_session.uploaded_chunks if not is_complete else upload_session.total_chunks,
total_chunks=upload_session.total_chunks,
uploaded_chunks=uploaded_chunks if not is_complete else total_chunks,
total_chunks=total_chunks,
is_complete=is_complete,
object_id=file_object_id,
)
@@ -368,6 +381,9 @@ async def create_download_token_endpoint(
if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件")
if file_obj.is_banned:
http_exceptions.raise_banned()
token = create_download_token(file_id, user.id)
l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}")
@@ -410,6 +426,9 @@ async def download_file(
if not file_obj.is_file:
raise HTTPException(status_code=400, detail="对象不是文件")
if file_obj.is_banned:
http_exceptions.raise_banned()
# 预加载 physical_file 关系以获取存储路径
physical_file = await file_obj.awaitable_attrs.physical_file
if not physical_file or not physical_file.storage_path:
@@ -470,6 +489,9 @@ async def create_empty_file(
if not parent.is_folder:
raise HTTPException(status_code=400, detail="父对象不是目录")
if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 检查是否已存在同名文件
existing = await Object.get(
session,

View File

@@ -1,19 +0,0 @@
from fastapi import APIRouter
from models import MCPRequestBase, MCPResponseBase, MCPMethod
# MCP 路由
MCP_router = APIRouter(
prefix='/mcp',
tags=["mcp"],
)
@MCP_router.get(
"/",
)
async def mcp_root(
param: MCPRequestBase
):
match param.method:
case MCPMethod.PING:
return MCPResponseBase(result="pong", **param.model_dump())

View File

@@ -14,7 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from models import (
from sqlmodels import (
CreateFileRequest,
Object,
ObjectCopyRequest,
ObjectDeleteRequest,
@@ -26,10 +27,11 @@ from models import (
PhysicalFile,
Policy,
PolicyType,
ResponseBase,
User,
)
from models import ResponseBase
from service.storage import LocalStorageService
from utils import http_exceptions
object_router = APIRouter(
prefix="/object",
@@ -59,15 +61,22 @@ async def _delete_object_recursive(
"""
deleted_count = 0
if obj.is_folder:
# 在任何数据库操作前保存所有需要的属性,避免 commit 后对象过期导致懒加载失败
obj_id = obj.id
obj_name = obj.name
obj_is_folder = obj.is_folder
obj_is_file = obj.is_file
obj_physical_file_id = obj.physical_file_id
if obj_is_folder:
# 递归删除子对象
children = await Object.get_children(session, user_id, obj.id)
children = await Object.get_children(session, user_id, obj_id)
for child in children:
deleted_count += await _delete_object_recursive(session, child, user_id)
# 如果是文件,处理物理文件引用
if obj.is_file and obj.physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
if obj_is_file and obj_physical_file_id:
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj_physical_file_id)
if physical_file:
# 减少引用计数
new_count = physical_file.decrement_reference()
@@ -81,11 +90,11 @@ async def _delete_object_recursive(
await storage_service.move_to_trash(
source_path=physical_file.storage_path,
user_id=user_id,
object_id=obj.id,
object_id=obj_id,
)
l.debug(f"物理文件已移动到回收站: {obj.name}")
l.debug(f"物理文件已移动到回收站: {obj_name}")
except Exception as e:
l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}")
l.warning(f"移动物理文件到回收站失败: {obj_name}, 错误: {e}")
# 删除 PhysicalFile 记录
await PhysicalFile.delete(session, physical_file)
@@ -95,8 +104,8 @@ async def _delete_object_recursive(
await physical_file.save(session)
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
# 删除数据库记录
await Object.delete(session, obj)
# 使用条件删除,避免访问过期的 obj 实例
await Object.delete(session, condition=Object.id == obj_id)
deleted_count += 1
return deleted_count
@@ -168,6 +177,97 @@ async def _copy_object_recursive(
return copied_count, new_ids
@object_router.post(
path='/',
summary='创建空白文件',
description='在指定目录下创建空白文件。',
)
async def router_object_create(
session: SessionDep,
user: Annotated[User, Depends(auth_required)],
request: CreateFileRequest,
) -> ResponseBase:
"""
创建空白文件端点
:param session: 数据库会话
:param user: 当前登录用户
:param request: 创建文件请求parent_id, name
:return: 创建结果
"""
user_id = user.id
# 验证文件名
if not request.name or '/' in request.name or '\\' in request.name:
raise HTTPException(status_code=400, detail="无效的文件名")
# 验证父目录
parent = await Object.get(session, Object.id == request.parent_id)
if not parent or parent.owner_id != user_id:
raise HTTPException(status_code=404, detail="父目录不存在")
if not parent.is_folder:
raise HTTPException(status_code=400, detail="父对象不是目录")
if parent.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 检查是否已存在同名文件
existing = await Object.get(
session,
(Object.owner_id == user_id) &
(Object.parent_id == parent.id) &
(Object.name == request.name)
)
if existing:
raise HTTPException(status_code=409, detail="同名文件已存在")
# 确定存储策略
policy_id = request.policy_id or parent.policy_id
policy = await Policy.get(session, Policy.id == policy_id)
if not policy:
raise HTTPException(status_code=404, detail="存储策略不存在")
parent_id = parent.id
# 生成存储路径并创建空文件
if policy.type == PolicyType.LOCAL:
storage_service = LocalStorageService(policy)
dir_path, storage_name, full_path = await storage_service.generate_file_path(
user_id=user_id,
original_filename=request.name,
)
await storage_service.create_empty_file(full_path)
storage_path = full_path
else:
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
# 创建 PhysicalFile 记录
physical_file = PhysicalFile(
storage_path=storage_path,
size=0,
policy_id=policy_id,
reference_count=1,
)
physical_file = await physical_file.save(session)
# 创建 Object 记录
file_object = Object(
name=request.name,
type=ObjectType.FILE,
size=0,
physical_file_id=physical_file.id,
parent_id=parent_id,
owner_id=user_id,
policy_id=policy_id,
)
await file_object.save(session)
l.info(f"创建空白文件: {request.name}")
return ResponseBase()
@object_router.delete(
path='/',
summary='删除对象',
@@ -197,10 +297,7 @@ async def router_object_delete(
user_id = user.id
deleted_count = 0
# 处理单个 UUID 或 UUID 列表
ids = request.ids if isinstance(request.ids, list) else [request.ids]
for obj_id in ids:
for obj_id in request.ids:
obj = await Object.get(session, Object.id == obj_id)
if not obj or obj.owner_id != user_id:
continue
@@ -219,7 +316,7 @@ async def router_object_delete(
return ResponseBase(
data={
"deleted": deleted_count,
"total": len(ids),
"total": len(request.ids),
}
)
@@ -253,6 +350,9 @@ async def router_object_move(
if not dst.is_folder:
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
if dst.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
# 存储 dst 的属性,避免后续数据库操作导致 dst 过期后无法访问
dst_id = dst.id
dst_parent_id = dst.parent_id
@@ -264,6 +364,9 @@ async def router_object_move(
if not src or src.owner_id != user_id:
continue
if src.is_banned:
continue
# 不能移动根目录
if src.parent_id is None:
continue
@@ -348,6 +451,9 @@ async def router_object_copy(
if not dst.is_folder:
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
if dst.is_banned:
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
copied_count = 0
new_ids: list[UUID] = []
@@ -356,6 +462,9 @@ async def router_object_copy(
if not src or src.owner_id != user_id:
continue
if src.is_banned:
continue
# 不能复制根目录
if src.parent_id is None:
continue
@@ -438,6 +547,9 @@ async def router_object_rename(
if obj.owner_id != user_id:
raise HTTPException(status_code=403, detail="无权操作此对象")
if obj.is_banned:
http_exceptions.raise_banned()
# 不能重命名根目录
if obj.parent_id is None:
raise HTTPException(status_code=400, detail="无法重命名根目录")
@@ -543,7 +655,7 @@ async def router_object_property_detail(
policy_name = policy.name if policy else None
# 获取分享统计
from models import Share
from sqlmodels import Share
shares = await Share.get(
session,
Share.object_id == obj.id,

View File

@@ -7,11 +7,11 @@ from loguru import logger as l
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from models import ResponseBase
from models.user import User
from models.share import Share, ShareCreateRequest, ShareResponse
from models.object import Object
from models.mixin import ListResponse, TableViewRequest
from sqlmodels import ResponseBase
from sqlmodels.user import User
from sqlmodels.share import Share, ShareCreateRequest, ShareResponse
from sqlmodels.object import Object
from sqlmodels.mixin import ListResponse, TableViewRequest
from utils import http_exceptions
from utils.password.pwd import Password
@@ -72,23 +72,6 @@ def router_share_preview(id: str) -> ResponseBase:
"""
http_exceptions.raise_not_implemented()
@share_router.get(
path='/doc/{id}',
summary='取得Office文档预览地址',
description='Get Office document preview URL by ID.',
)
def router_share_doc(id: str) -> ResponseBase:
"""
Get Office document preview URL by ID.
Args:
id (str): The ID of the Office document.
Returns:
dict: A dictionary containing the document preview URL.
"""
http_exceptions.raise_not_implemented()
@share_router.get(
path='/content/{id}',
summary='获取文本文件内容',
@@ -261,6 +244,9 @@ async def router_share_create(
if not obj or obj.owner_id != user.id:
raise HTTPException(status_code=404, detail="对象不存在或无权限")
if obj.is_banned:
http_exceptions.raise_banned()
# 生成分享码
code = str(uuid4())

View File

@@ -1,7 +1,8 @@
from fastapi import APIRouter
from middleware.dependencies import SessionDep
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
from sqlmodels import ResponseBase, Setting, SettingsType, SiteConfigResponse
from sqlmodels.setting import CaptchaType
from utils import http_exceptions
site_router = APIRouter(
@@ -43,16 +44,43 @@ def router_site_captcha():
@site_router.get(
path='/config',
summary='站点全局配置',
description='Get the configuration file.',
response_model=ResponseBase,
description='获取站点全局配置,包括验证码设置、注册开关等。',
)
async def router_site_config(session: SessionDep) -> SiteConfigResponse:
"""
Get the configuration file.
获取站点全局配置
Returns:
dict: The site configuration.
无需认证。前端在初始化时调用此端点获取验证码类型、
登录/注册/找回密码是否需要验证码等配置。
"""
# 批量查询所需设置
settings: list[Setting] = await Setting.get(
session,
(Setting.type == SettingsType.BASIC) |
(Setting.type == SettingsType.LOGIN) |
(Setting.type == SettingsType.REGISTER) |
(Setting.type == SettingsType.CAPTCHA),
fetch_mode="all",
)
# 构建 name→value 映射
s: dict[str, str | None] = {item.name: item.value for item in settings}
# 根据 captcha_type 选择对应的 public key
captcha_type_str = s.get("captcha_type", "default")
captcha_type = CaptchaType(captcha_type_str) if captcha_type_str else CaptchaType.DEFAULT
captcha_key: str | None = None
if captcha_type == CaptchaType.GCAPTCHA:
captcha_key = s.get("captcha_ReCaptchaKey") or None
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
captcha_key = s.get("captcha_CloudflareKey") or None
return SiteConfigResponse(
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
title=s.get("siteName") or "DiskNext",
register_enabled=s.get("register_enabled") == "1",
login_captcha=s.get("login_captcha") == "1",
reg_captcha=s.get("reg_captcha") == "1",
forget_captcha=s.get("forget_captcha") == "1",
captcha_type=captcha_type,
captcha_key=captcha_key,
)

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
from fastapi.responses import FileResponse
from middleware.auth import auth_required
from models import ResponseBase
from sqlmodels import ResponseBase
from utils import http_exceptions
slave_router = APIRouter(

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends
from middleware.auth import auth_required
from models import ResponseBase
from sqlmodels import ResponseBase
from utils import http_exceptions
tag_router = APIRouter(

View File

@@ -1,30 +1,26 @@
from typing import Annotated, Literal
from uuid import UUID
from uuid import UUID, uuid4
import jwt
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from loguru import logger
from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from loguru import logger
import models
import service
import sqlmodels
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from utils.JWT import SECRET_KEY
from utils import Password, http_exceptions
from utils import JWT, Password, http_exceptions
from .settings import user_settings_router
user_router = APIRouter(
prefix="/user",
tags=["user"],
)
user_settings_router = APIRouter(
prefix='/user/settings',
tags=["user", "user_settings"],
dependencies=[Depends(auth_required)],
)
user_router.include_router(user_settings_router)
@user_router.post(
path='/session',
@@ -34,7 +30,7 @@ user_settings_router = APIRouter(
async def router_user_session(
session: SessionDep,
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
) -> models.TokenResponse:
) -> sqlmodels.TokenResponse:
"""
用户登录端点。
@@ -43,7 +39,7 @@ async def router_user_session(
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
"""
username = form_data.username
email = form_data.username # OAuth2 表单字段名为 username实际传入的是 email
password = form_data.password
# 从 scopes 中提取 OTP 验证码OAuth2.1 扩展方式)
@@ -59,8 +55,8 @@ async def router_user_session(
result = await service.user.login(
session,
models.LoginRequest(
username=username,
sqlmodels.LoginRequest(
email=email,
password=password,
two_fa_code=otp_code,
),
@@ -75,19 +71,70 @@ async def router_user_session(
)
async def router_user_session_refresh(
session: SessionDep,
request, # RefreshTokenRequest
) -> models.TokenResponse:
http_exceptions.raise_not_implemented()
request: sqlmodels.RefreshTokenRequest,
) -> sqlmodels.TokenResponse:
"""
使用 refresh_token 签发新的 access_token 和 refresh_token。
流程:
1. 解码 refresh_token JWT
2. 验证 token_type 为 refresh
3. 验证用户存在且状态正常
4. 签发新的 access_token + refresh_token
:param session: 数据库会话
:param request: 刷新令牌请求
:return: 新的 TokenResponse
"""
try:
payload = jwt.decode(request.refresh_token, JWT.SECRET_KEY, algorithms=["HS256"])
except jwt.InvalidTokenError:
http_exceptions.raise_unauthorized("刷新令牌无效或已过期")
# 验证是 refresh token
if payload.get("token_type") != "refresh":
http_exceptions.raise_unauthorized("非刷新令牌")
user_id_str = payload.get("sub")
if not user_id_str:
http_exceptions.raise_unauthorized("令牌缺少用户标识")
user_id = UUID(user_id_str)
user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id)
if not user:
http_exceptions.raise_unauthorized("用户不存在")
if not user.status:
http_exceptions.raise_forbidden("账户已被禁用")
# 签发新令牌
access_token = JWT.create_access_token(
sub=user.id,
jti=uuid4(),
)
refresh_token = JWT.create_refresh_token(
sub=user.id,
jti=uuid4(),
)
return sqlmodels.TokenResponse(
access_token=access_token.access_token,
access_expires=access_token.access_expires,
refresh_token=refresh_token.refresh_token,
refresh_expires=refresh_token.refresh_expires,
)
@user_router.post(
path='/',
summary='用户注册',
description='User registration endpoint.',
status_code=204,
)
async def router_user_register(
session: SessionDep,
request: models.RegisterRequest,
) -> models.ResponseBase:
request: sqlmodels.RegisterRequest,
) -> None:
"""
用户注册端点
@@ -95,7 +142,7 @@ async def router_user_register(
1. 验证用户名唯一性
2. 获取默认用户组
3. 创建用户记录
4. 创建用户名命名的根目录
4. 创建用户根目录name="/"
:param session: 数据库会话
:param request: 注册请求
@@ -103,62 +150,53 @@ async def router_user_register(
:raises HTTPException 400: 用户名已存在
:raises HTTPException 500: 默认用户组或存储策略不存在
"""
# 1. 验证用户名唯一性
existing_user = await models.User.get(
# 1. 验证邮箱唯一性
existing_user = await sqlmodels.User.get(
session,
models.User.username == request.username
sqlmodels.User.email == request.email
)
if existing_user:
raise HTTPException(status_code=400, detail="用户名已存在")
raise HTTPException(status_code=400, detail="邮箱已存在")
# 2. 获取默认用户组(从设置中读取 UUID
default_group_setting: models.Setting | None = await models.Setting.get(
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
session,
(models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group")
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
)
if default_group_setting is None or not default_group_setting.value:
logger.error("默认用户组不存在")
http_exceptions.raise_internal_error()
default_group_id = UUID(default_group_setting.value)
default_group = await models.Group.get(session, models.Group.id == default_group_id)
default_group = await sqlmodels.Group.get(session, sqlmodels.Group.id == default_group_id)
if not default_group:
logger.error("默认用户组不存在")
http_exceptions.raise_internal_error()
# 3. 创建用户
hashed_password = Password.hash(request.password)
new_user = models.User(
username=request.username,
new_user = sqlmodels.User(
email=request.email,
password=hashed_password,
group_id=default_group.id,
)
new_user_id = new_user.id # 在 save 前保存 UUID
new_user_username = new_user.username
new_user_id = new_user.id
await new_user.save(session)
# 4. 创建用户名命名的根目录
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储")
# 4. 创建用户根目录
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
if not default_policy:
logger.error("默认存储策略不存在")
http_exceptions.raise_internal_error()
await models.Object(
name=new_user_username,
type=models.ObjectType.FOLDER,
await sqlmodels.Object(
name="/",
type=sqlmodels.ObjectType.FOLDER,
owner_id=new_user_id,
parent_id=None,
policy_id=default_policy.id,
).save(session)
return models.ResponseBase(
data={
"user_id": new_user_id,
"username": new_user_username,
},
msg="注册成功",
)
@user_router.post(
path='/code',
summary='发送验证码邮件',
@@ -166,7 +204,7 @@ async def router_user_register(
)
def router_user_email_code(
reason: Literal['register', 'reset'] = 'register',
) -> models.ResponseBase:
) -> sqlmodels.ResponseBase:
"""
Send a verification code email.
@@ -180,7 +218,7 @@ def router_user_email_code(
summary='初始化QQ登录',
description='Initialize QQ login for a user.',
)
def router_user_qq() -> models.ResponseBase:
def router_user_qq() -> sqlmodels.ResponseBase:
"""
Initialize QQ login for a user.
@@ -194,7 +232,7 @@ def router_user_qq() -> models.ResponseBase:
summary='WebAuthn登录初始化',
description='Initialize WebAuthn login for a user.',
)
async def router_user_authn(username: str) -> models.ResponseBase:
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
http_exceptions.raise_not_implemented()
@@ -203,7 +241,7 @@ async def router_user_authn(username: str) -> models.ResponseBase:
summary='WebAuthn登录',
description='Finish WebAuthn login for a user.',
)
def router_user_authn_finish(username: str) -> models.ResponseBase:
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
"""
Finish WebAuthn login for a user.
@@ -220,7 +258,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase:
summary='获取用户主页展示用分享',
description='Get user profile for display.',
)
def router_user_profile(id: str) -> models.ResponseBase:
def router_user_profile(id: str) -> sqlmodels.ResponseBase:
"""
Get user profile for display.
@@ -237,7 +275,7 @@ def router_user_profile(id: str) -> models.ResponseBase:
summary='获取用户头像',
description='Get user avatar by ID and size.',
)
def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
"""
Get user avatar by ID and size.
@@ -259,12 +297,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
summary='获取用户信息',
description='Get user information.',
dependencies=[Depends(dependency=auth_required)],
response_model=models.UserResponse,
response_model=sqlmodels.UserResponse,
)
async def router_user_me(
session: SessionDep,
user: Annotated[models.User, Depends(auth_required)],
) -> models.ResponseBase:
user: Annotated[sqlmodels.User, Depends(auth_required)],
) -> sqlmodels.UserResponse:
"""
获取用户信息.
@@ -272,10 +310,10 @@ async def router_user_me(
:rtype: ResponseBase
"""
# 加载 group 及其 options 关系
group = await models.Group.get(
group = await sqlmodels.Group.get(
session,
models.Group.id == user.group_id,
load=models.Group.options
sqlmodels.Group.id == user.group_id,
load=sqlmodels.Group.options
)
# 构建 GroupResponse
@@ -284,9 +322,9 @@ async def router_user_me(
# 异步加载 tags 关系
user_tags = await user.awaitable_attrs.tags
return models.UserResponse(
return sqlmodels.UserResponse(
id=user.id,
username=user.username,
email=user.email,
status=user.status,
score=user.score,
nickname=user.nickname,
@@ -304,30 +342,26 @@ async def router_user_me(
)
async def router_user_storage(
session: SessionDep,
user: Annotated[models.user.User, Depends(auth_required)],
) -> models.ResponseBase:
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.UserStorageResponse:
"""
获取用户存储空间信息。
返回值:
- used: 已使用空间(字节)
- free: 剩余空间(字节)
- total: 总容量(字节)= 用户组容量
"""
# 获取用户组的基础存储容量
group = await models.Group.get(session, models.Group.id == user.group_id)
group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
if not group:
raise HTTPException(status_code=500, detail="用户组不存在")
raise HTTPException(status_code=404, detail="用户组不存在")
# [TODO] 总空间加上用户购买的额外空间
total: int = group.max_storage
used: int = user.storage
free: int = max(0, total - used)
return models.ResponseBase(
data={
"used": used,
"free": free,
"total": total,
}
return sqlmodels.UserStorageResponse(
used=used,
free=free,
total=total,
)
@user_router.put(
@@ -338,8 +372,8 @@ async def router_user_storage(
)
async def router_user_authn_start(
session: SessionDep,
user: Annotated[models.user.User, Depends(auth_required)],
) -> models.ResponseBase:
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.ResponseBase:
"""
Initialize WebAuthn login for a user.
@@ -347,30 +381,30 @@ async def router_user_authn_start(
dict: A dictionary containing WebAuthn initialization information.
"""
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
authn_setting = await models.Setting.get(
authn_setting = await sqlmodels.Setting.get(
session,
(models.Setting.type == "authn") & (models.Setting.name == "authn_enabled")
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
)
if not authn_setting or authn_setting.value != "1":
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
site_url_setting = await models.Setting.get(
site_url_setting = await sqlmodels.Setting.get(
session,
(models.Setting.type == "basic") & (models.Setting.name == "siteURL")
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteURL")
)
site_title_setting = await models.Setting.get(
site_title_setting = await sqlmodels.Setting.get(
session,
(models.Setting.type == "basic") & (models.Setting.name == "siteTitle")
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteTitle")
)
options = generate_registration_options(
rp_id=site_url_setting.value if site_url_setting else "",
rp_name=site_title_setting.value if site_title_setting else "",
user_name=user.username,
user_display_name=user.nick or user.username,
user_name=user.email,
user_display_name=user.nickname or user.email,
)
return models.ResponseBase(data=options_to_json_dict(options))
return sqlmodels.ResponseBase(data=options_to_json_dict(options))
@user_router.put(
path='/authn/finish',
@@ -378,179 +412,11 @@ async def router_user_authn_start(
description='Finish WebAuthn login for a user.',
dependencies=[Depends(auth_required)],
)
def router_user_authn_finish() -> models.ResponseBase:
def router_user_authn_finish() -> sqlmodels.ResponseBase:
"""
Finish WebAuthn login for a user.
Returns:
dict: A dictionary containing WebAuthn login information.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/policies',
summary='获取用户可选存储策略',
description='Get user selectable storage policies.',
)
def router_user_settings_policies() -> models.ResponseBase:
"""
Get user selectable storage policies.
Returns:
dict: A dictionary containing available storage policies for the user.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/nodes',
summary='获取用户可选节点',
description='Get user selectable nodes.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_nodes() -> models.ResponseBase:
"""
Get user selectable nodes.
Returns:
dict: A dictionary containing available nodes for the user.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/tasks',
summary='任务队列',
description='Get user task queue.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_tasks() -> models.ResponseBase:
"""
Get user task queue.
Returns:
dict: A dictionary containing the user's task queue information.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/',
summary='获取当前用户设定',
description='Get current user settings.',
dependencies=[Depends(auth_required)],
)
def router_user_settings() -> models.ResponseBase:
"""
Get current user settings.
Returns:
dict: A dictionary containing the current user settings.
"""
return models.ResponseBase(data=models.UserSettingResponse().model_dump())
@user_settings_router.post(
path='/avatar',
summary='从文件上传头像',
description='Upload user avatar from file.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar() -> models.ResponseBase:
"""
Upload user avatar from file.
Returns:
dict: A dictionary containing the result of the avatar upload.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.put(
path='/avatar',
summary='设定为Gravatar头像',
description='Set user avatar to Gravatar.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar_gravatar() -> models.ResponseBase:
"""
Set user avatar to Gravatar.
Returns:
dict: A dictionary containing the result of setting the Gravatar avatar.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.patch(
path='/{option}',
summary='更新用户设定',
description='Update user settings.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_patch(option: str) -> models.ResponseBase:
"""
Update user settings.
Args:
option (str): The setting option to update.
Returns:
dict: A dictionary containing the result of the settings update.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/2fa',
summary='获取两步验证初始化信息',
description='Get two-factor authentication initialization information.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa(
user: Annotated[models.user.User, Depends(auth_required)],
) -> models.ResponseBase:
"""
Get two-factor authentication initialization information.
Returns:
dict: A dictionary containing two-factor authentication setup information.
"""
return models.ResponseBase(
data=await Password.generate_totp(user.username)
)
@user_settings_router.post(
path='/2fa',
summary='启用两步验证',
description='Enable two-factor authentication.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa_enable(
session: SessionDep,
user: Annotated[models.user.User, Depends(auth_required)],
setup_token: str,
code: str,
) -> models.ResponseBase:
"""
Enable two-factor authentication for the user.
Returns:
dict: A dictionary containing the result of enabling two-factor authentication.
"""
serializer = URLSafeTimedSerializer(SECRET_KEY)
try:
# 1. 解包 Token设置有效期例如 600秒
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
except SignatureExpired:
raise HTTPException(status_code=400, detail="Setup session expired")
except BadSignature:
raise HTTPException(status_code=400, detail="Invalid token")
# 2. 验证用户输入的 6 位验证码
if not Password.verify_totp(secret, code):
raise HTTPException(status_code=400, detail="Invalid OTP code")
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
user.two_factor = secret
user = await user.save(session)
return models.ResponseBase(
data={"message": "Two-factor authentication enabled successfully"}
)
http_exceptions.raise_not_implemented()

View File

@@ -0,0 +1,203 @@
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import sqlmodels
from middleware.auth import auth_required
from middleware.dependencies import SessionDep
from utils import JWT, Password, http_exceptions
user_settings_router = APIRouter(
prefix='/settings',
tags=["user", "user_settings"],
dependencies=[Depends(auth_required)],
)
@user_settings_router.get(
path='/policies',
summary='获取用户可选存储策略',
description='Get user selectable storage policies.',
)
def router_user_settings_policies() -> sqlmodels.ResponseBase:
"""
Get user selectable storage policies.
Returns:
dict: A dictionary containing available storage policies for the user.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/nodes',
summary='获取用户可选节点',
description='Get user selectable nodes.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_nodes() -> sqlmodels.ResponseBase:
"""
Get user selectable nodes.
Returns:
dict: A dictionary containing available nodes for the user.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/tasks',
summary='任务队列',
description='Get user task queue.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_tasks() -> sqlmodels.ResponseBase:
"""
Get user task queue.
Returns:
dict: A dictionary containing the user's task queue information.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/',
summary='获取当前用户设定',
description='Get current user settings.',
)
def router_user_settings(
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.UserSettingResponse:
"""
Get current user settings.
Returns:
dict: A dictionary containing the current user settings.
"""
return sqlmodels.UserSettingResponse(
id=user.id,
email=user.email,
nickname=user.nickname,
created_at=user.created_at,
group_name=user.group.name,
language=user.language,
timezone=user.timezone,
group_expires=user.group_expires,
two_factor=user.two_factor is not None,
)
@user_settings_router.post(
path='/avatar',
summary='从文件上传头像',
description='Upload user avatar from file.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
"""
Upload user avatar from file.
Returns:
dict: A dictionary containing the result of the avatar upload.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.put(
path='/avatar',
summary='设定为Gravatar头像',
description='Set user avatar to Gravatar.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_avatar_gravatar() -> sqlmodels.ResponseBase:
"""
Set user avatar to Gravatar.
Returns:
dict: A dictionary containing the result of setting the Gravatar avatar.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.patch(
path='/{option}',
summary='更新用户设定',
description='Update user settings.',
dependencies=[Depends(auth_required)],
)
def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase:
"""
Update user settings.
Args:
option (str): The setting option to update.
Returns:
dict: A dictionary containing the result of the settings update.
"""
http_exceptions.raise_not_implemented()
@user_settings_router.get(
path='/2fa',
summary='获取两步验证初始化信息',
description='Get two-factor authentication initialization information.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa(
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
) -> sqlmodels.ResponseBase:
"""
Get two-factor authentication initialization information.
Returns:
dict: A dictionary containing two-factor authentication setup information.
"""
return sqlmodels.ResponseBase(
data=await Password.generate_totp(user.email)
)
@user_settings_router.post(
path='/2fa',
summary='启用两步验证',
description='Enable two-factor authentication.',
dependencies=[Depends(auth_required)],
)
async def router_user_settings_2fa_enable(
session: SessionDep,
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
setup_token: str,
code: str,
) -> sqlmodels.ResponseBase:
"""
Enable two-factor authentication for the user.
Returns:
dict: A dictionary containing the result of enabling two-factor authentication.
"""
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
try:
# 1. 解包 Token设置有效期例如 600秒
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
except SignatureExpired:
raise HTTPException(status_code=400, detail="Setup session expired")
except BadSignature:
raise HTTPException(status_code=400, detail="Invalid token")
# 2. 验证用户输入的 6 位验证码
if not Password.verify_totp(secret, code):
raise HTTPException(status_code=400, detail="Invalid OTP code")
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
user.two_factor = secret
user = await user.save(session)
return sqlmodels.ResponseBase(
data={"message": "Two-factor authentication enabled successfully"}
)

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends
from middleware.auth import auth_required
from models import ResponseBase
from sqlmodels import ResponseBase
from utils import http_exceptions
vas_router = APIRouter(

View File

@@ -1,7 +1,7 @@
from fastapi import APIRouter, Depends
from middleware.auth import auth_required
from models import ResponseBase
from sqlmodels import ResponseBase
from utils import http_exceptions
# WebDAV 管理路由