feat: add models for physical files, policies, and user management
- Implement PhysicalFile model to manage physical file references and reference counting. - Create Policy model with associated options and group links for storage policies. - Introduce Redeem and Report models for handling redeem codes and reports. - Add Settings model for site configuration and user settings management. - Develop Share model for sharing objects with unique codes and associated metadata. - Implement SourceLink model for managing download links associated with objects. - Create StoragePack model for managing user storage packages. - Add Tag model for user-defined tags with manual and automatic types. - Implement Task model for managing background tasks with status tracking. - Develop User model with comprehensive user management features including authentication. - Introduce UserAuthn model for managing WebAuthn credentials. - Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
@@ -5,15 +5,15 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase,
|
||||
Setting, Object, ObjectType, Share, AdminSummaryResponse, MetricsSummary, LicenseInfo, VersionInfo,
|
||||
)
|
||||
from models.base import SQLModelBase
|
||||
from models.setting import (
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from sqlmodels.setting import (
|
||||
SettingItem, SettingsListResponse, SettingsUpdateRequest, SettingsUpdateResponse,
|
||||
)
|
||||
from models.setting import SettingsType
|
||||
from sqlmodels.setting import SettingsType
|
||||
from utils import http_exceptions
|
||||
from utils.conf import appmeta
|
||||
from .file import admin_file_router
|
||||
|
||||
@@ -5,14 +5,60 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger as l
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
Policy, PolicyType, User, ResponseBase, ListResponse,
|
||||
from sqlmodels import (
|
||||
Policy, PolicyType, User, ListResponse,
|
||||
Object, ObjectType, AdminFileResponse, FileBanRequest, )
|
||||
from service.storage import LocalStorageService
|
||||
|
||||
async def _set_ban_recursive(
|
||||
session: AsyncSession,
|
||||
obj: Object,
|
||||
ban: bool,
|
||||
admin_id: UUID,
|
||||
reason: str | None,
|
||||
) -> int:
|
||||
"""
|
||||
递归设置封禁状态,返回受影响对象数量。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param obj: 要封禁/解禁的对象
|
||||
:param ban: True=封禁, False=解禁
|
||||
:param admin_id: 管理员UUID
|
||||
:param reason: 封禁原因
|
||||
:return: 受影响的对象数量
|
||||
"""
|
||||
count = 0
|
||||
|
||||
# 如果是文件夹,先递归处理子对象
|
||||
if obj.is_folder:
|
||||
children = await Object.get(
|
||||
session,
|
||||
Object.parent_id == obj.id,
|
||||
fetch_mode="all",
|
||||
)
|
||||
for child in children:
|
||||
count += await _set_ban_recursive(session, child, ban, admin_id, reason)
|
||||
|
||||
# 设置当前对象
|
||||
obj.is_banned = ban
|
||||
if ban:
|
||||
obj.banned_at = datetime.now()
|
||||
obj.banned_by = admin_id
|
||||
obj.ban_reason = reason
|
||||
else:
|
||||
obj.banned_at = None
|
||||
obj.banned_by = None
|
||||
obj.ban_reason = None
|
||||
|
||||
await obj.save(session)
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
admin_file_router = APIRouter(
|
||||
prefix="/file",
|
||||
tags=["admin", "admin_file"],
|
||||
@@ -119,15 +165,17 @@ async def router_admin_preview_file(
|
||||
summary='封禁/解禁文件',
|
||||
description='Ban the file, user can\'t open, copy, move, download or share this file if administrator ban.',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_ban_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
request: FileBanRequest,
|
||||
admin: Annotated[User, Depends(admin_required)],
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
封禁或解禁文件。封禁后用户无法访问该文件。
|
||||
封禁或解禁文件/文件夹。封禁后用户无法访问该文件。
|
||||
封禁文件夹时会级联封禁所有子对象。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param file_id: 文件UUID
|
||||
@@ -139,24 +187,10 @@ async def router_admin_ban_file(
|
||||
if not file_obj:
|
||||
raise HTTPException(status_code=404, detail="文件不存在")
|
||||
|
||||
file_obj.is_banned = request.is_banned
|
||||
if request.is_banned:
|
||||
file_obj.banned_at = datetime.now()
|
||||
file_obj.banned_by = admin.id
|
||||
file_obj.ban_reason = request.reason
|
||||
else:
|
||||
file_obj.banned_at = None
|
||||
file_obj.banned_by = None
|
||||
file_obj.ban_reason = None
|
||||
count = await _set_ban_recursive(session, file_obj, request.ban, admin.id, request.reason)
|
||||
|
||||
file_obj = await file_obj.save(session)
|
||||
|
||||
action = "封禁" if request.is_banned else "解禁"
|
||||
l.info(f"管理员{action}了文件: {file_obj.name}")
|
||||
return ResponseBase(data={
|
||||
"id": str(file_obj.id),
|
||||
"is_banned": file_obj.is_banned,
|
||||
})
|
||||
action = "封禁" if request.ban else "解禁"
|
||||
l.info(f"管理员{action}了对象: {file_obj.name},共影响 {count} 个对象")
|
||||
|
||||
|
||||
@admin_file_router.delete(
|
||||
@@ -164,12 +198,13 @@ async def router_admin_ban_file(
|
||||
summary='删除文件',
|
||||
description='Delete file by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_file(
|
||||
session: SessionDep,
|
||||
file_id: UUID,
|
||||
delete_physical: bool = True,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
删除文件。
|
||||
|
||||
@@ -211,5 +246,4 @@ async def router_admin_delete_file(
|
||||
# 使用条件删除
|
||||
await Object.delete(session, condition=Object.id == file_obj.id)
|
||||
|
||||
l.info(f"管理员删除了文件: {file_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
l.info(f"管理员删除了文件: {file_name}")
|
||||
@@ -5,12 +5,12 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, GroupOptions, )
|
||||
from models.group import (
|
||||
from sqlmodels.group import (
|
||||
GroupCreateRequest, GroupUpdateRequest, GroupDetailResponse, )
|
||||
from models.policy import GroupPolicyLink
|
||||
from sqlmodels.policy import GroupPolicyLink
|
||||
|
||||
admin_group_router = APIRouter(
|
||||
prefix="/group",
|
||||
@@ -113,11 +113,12 @@ async def router_admin_get_group_members(
|
||||
summary='创建用户组',
|
||||
description='Create a new user group',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_create_group(
|
||||
session: SessionDep,
|
||||
request: GroupCreateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
创建新的用户组。
|
||||
|
||||
@@ -164,7 +165,6 @@ async def router_admin_create_group(
|
||||
await session.commit()
|
||||
|
||||
l.info(f"管理员创建了用户组: {group.name}")
|
||||
return ResponseBase(data={"id": str(group.id), "name": group.name})
|
||||
|
||||
|
||||
@admin_group_router.patch(
|
||||
@@ -172,12 +172,13 @@ async def router_admin_create_group(
|
||||
summary='更新用户组信息',
|
||||
description='Update user group information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_update_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
request: GroupUpdateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户组ID更新用户组信息。
|
||||
|
||||
@@ -233,8 +234,7 @@ async def router_admin_update_group(
|
||||
session.add(link)
|
||||
await session.commit()
|
||||
|
||||
l.info(f"管理员更新了用户组: {group.name}")
|
||||
return ResponseBase(data={"id": str(group.id)})
|
||||
l.info(f"管理员更新了用户组: {group_id}")
|
||||
|
||||
|
||||
@admin_group_router.delete(
|
||||
@@ -242,11 +242,12 @@ async def router_admin_update_group(
|
||||
summary='删除用户组',
|
||||
description='Delete user group by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_group(
|
||||
session: SessionDep,
|
||||
group_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户组ID删除用户组。
|
||||
|
||||
@@ -271,5 +272,4 @@ async def router_admin_delete_group(
|
||||
group_name = group.name
|
||||
await Group.delete(session, group)
|
||||
|
||||
l.info(f"管理员删除了用户组: {group_name}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
l.info(f"管理员删除了用户组: {group_id}")
|
||||
@@ -6,10 +6,10 @@ from sqlmodel import Field
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
Policy, PolicyBase, PolicyType, PolicySummary, ResponseBase,
|
||||
ListResponse, Object, )
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
from service.storage import DirectoryCreationError, LocalStorageService
|
||||
|
||||
admin_policy_router = APIRouter(
|
||||
|
||||
@@ -5,7 +5,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase, ListResponse,
|
||||
Share, AdminShareListItem, )
|
||||
|
||||
@@ -80,7 +80,7 @@ async def router_admin_get_share(
|
||||
"score": share.score,
|
||||
"has_password": bool(share.password),
|
||||
"user_id": str(share.user_id),
|
||||
"username": user.username if user else None,
|
||||
"username": user.email if user else None,
|
||||
"object": {
|
||||
"id": str(obj.id),
|
||||
"name": obj.name,
|
||||
|
||||
@@ -5,7 +5,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase, ListResponse,
|
||||
Task, TaskSummary,
|
||||
)
|
||||
@@ -89,7 +89,7 @@ async def router_admin_get_task(
|
||||
"progress": task.progress,
|
||||
"error": task.error,
|
||||
"user_id": str(task.user_id),
|
||||
"username": user.username if user else None,
|
||||
"username": user.email if user else None,
|
||||
"props": props.model_dump() if props else None,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
"updated_at": task.updated_at.isoformat(),
|
||||
|
||||
@@ -6,11 +6,13 @@ from sqlalchemy import func
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep, TableViewRequestDep, UserFilterParamsDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
User, ResponseBase, UserPublic, ListResponse,
|
||||
Group, Object, ObjectType, )
|
||||
from models.user import (
|
||||
UserAdminUpdateRequest, UserCalibrateResponse,
|
||||
Group, Object, ObjectType, Setting, SettingsType,
|
||||
BatchDeleteRequest,
|
||||
)
|
||||
from sqlmodels.user import (
|
||||
UserAdminCreateRequest, UserAdminUpdateRequest, UserCalibrateResponse,
|
||||
)
|
||||
from utils import Password, http_exceptions
|
||||
|
||||
@@ -26,19 +28,19 @@ admin_user_router = APIRouter(
|
||||
description='Get user information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
)
|
||||
async def router_admin_get_user(session: SessionDep, user_id: int) -> ResponseBase:
|
||||
async def router_admin_get_user(session: SessionDep, user_id: UUID) -> UserPublic:
|
||||
"""
|
||||
根据用户ID获取用户信息,包括用户名、邮箱、注册时间等。
|
||||
|
||||
Args:
|
||||
session(SessionDep): 数据库会话依赖项。
|
||||
user_id (int): 用户ID。
|
||||
user_id (UUID): 用户ID。
|
||||
|
||||
Returns:
|
||||
ResponseBase: 包含用户信息的响应模型。
|
||||
"""
|
||||
user = await User.get_exist_one(session, user_id)
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@admin_user_router.get(
|
||||
@@ -60,7 +62,7 @@ async def router_admin_get_users(
|
||||
:param filter_params: 用户筛选参数(用户组、用户名、昵称、状态)
|
||||
:return: 分页用户列表
|
||||
"""
|
||||
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view)
|
||||
result = await User.get_with_count(session, filter_params=filter_params, table_view=table_view, load=User.group)
|
||||
return ListResponse(
|
||||
items=[user.to_public() for user in result.items],
|
||||
count=result.count,
|
||||
@@ -75,22 +77,33 @@ async def router_admin_get_users(
|
||||
)
|
||||
async def router_admin_create_user(
|
||||
session: SessionDep,
|
||||
user: User,
|
||||
) -> ResponseBase:
|
||||
request: UserAdminCreateRequest,
|
||||
) -> UserPublic:
|
||||
"""
|
||||
创建一个新的用户,设置用户名、密码等信息。
|
||||
创建一个新的用户,设置邮箱、密码、用户组等信息。
|
||||
|
||||
Returns:
|
||||
ResponseBase: 包含创建结果的响应模型。
|
||||
:param session: 数据库会话
|
||||
:param request: 创建用户请求 DTO
|
||||
:return: 创建结果
|
||||
"""
|
||||
existing_user = await User.get(session, User.username == user.username)
|
||||
existing_user = await User.get(session, User.email == request.email)
|
||||
if existing_user:
|
||||
return ResponseBase(
|
||||
code=400,
|
||||
msg="User with this username already exists."
|
||||
)
|
||||
raise HTTPException(status_code=409, detail="该邮箱已被注册")
|
||||
|
||||
# 验证用户组存在
|
||||
group = await Group.get(session, Group.id == request.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=400, detail="目标用户组不存在")
|
||||
|
||||
user = User(
|
||||
email=request.email,
|
||||
password=Password.hash(request.password),
|
||||
nickname=request.nickname,
|
||||
group_id=request.group_id,
|
||||
status=request.status,
|
||||
)
|
||||
user = await user.save(session)
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
return user.to_public()
|
||||
|
||||
|
||||
@admin_user_router.patch(
|
||||
@@ -98,12 +111,13 @@ async def router_admin_create_user(
|
||||
summary='更新用户信息',
|
||||
description='Update user information by ID',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204
|
||||
)
|
||||
async def router_admin_update_user(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
request: UserAdminUpdateRequest,
|
||||
) -> ResponseBase:
|
||||
) -> None:
|
||||
"""
|
||||
根据用户ID更新用户信息。
|
||||
|
||||
@@ -116,8 +130,15 @@ async def router_admin_update_user(
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
# 默认管理员(用户名为 admin)不允许更改用户组
|
||||
if request.group_id and user.username == "admin" and request.group_id != user.group_id:
|
||||
# 默认管理员不允许更改用户组(通过 Setting 中的 default_admin_id 识别)
|
||||
default_admin_setting = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.AUTH) & (Setting.name == "default_admin_id")
|
||||
)
|
||||
if (request.group_id
|
||||
and default_admin_setting
|
||||
and default_admin_setting.value == str(user_id)
|
||||
and request.group_id != user.group_id):
|
||||
http_exceptions.raise_forbidden("默认管理员不允许更改用户组")
|
||||
|
||||
# 如果更新用户组,验证新组存在
|
||||
@@ -143,38 +164,35 @@ async def router_admin_update_user(
|
||||
setattr(user, key, value)
|
||||
user = await user.save(session)
|
||||
|
||||
l.info(f"管理员更新了用户: {user.username}")
|
||||
return ResponseBase(data=user.to_public().model_dump())
|
||||
l.info(f"管理员更新了用户: {request.email}")
|
||||
|
||||
|
||||
@admin_user_router.delete(
|
||||
path='/{user_id}',
|
||||
summary='删除用户',
|
||||
description='Delete user by ID',
|
||||
path='/',
|
||||
summary='删除用户(支持批量)',
|
||||
description='Delete users by ID list',
|
||||
dependencies=[Depends(admin_required)],
|
||||
status_code=204,
|
||||
)
|
||||
async def router_admin_delete_user(
|
||||
async def router_admin_delete_users(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
) -> ResponseBase:
|
||||
request: BatchDeleteRequest,
|
||||
) -> None:
|
||||
"""
|
||||
根据用户ID删除用户及其所有数据。
|
||||
批量删除用户及其所有数据。
|
||||
|
||||
注意: 这是一个危险操作,会级联删除用户的所有文件、分享、任务等。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户UUID
|
||||
:return: 删除结果
|
||||
:param request: 批量删除请求,包含待删除用户的 UUID 列表
|
||||
:return: 删除结果(已删除数 / 总请求数)
|
||||
"""
|
||||
user = await User.get(session, User.id == user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="用户不存在")
|
||||
|
||||
username = user.username
|
||||
await User.delete(session, user)
|
||||
|
||||
l.info(f"管理员删除了用户: {username}")
|
||||
return ResponseBase(data={"deleted": True})
|
||||
deleted = 0
|
||||
for uid in request.ids:
|
||||
user = await User.get(session, User.id == uid)
|
||||
if user:
|
||||
await User.delete(session, user)
|
||||
l.info(f"管理员删除了用户: {user.email}")
|
||||
|
||||
|
||||
@admin_user_router.post(
|
||||
@@ -186,7 +204,7 @@ async def router_admin_delete_user(
|
||||
async def router_admin_calibrate_storage(
|
||||
session: SessionDep,
|
||||
user_id: UUID,
|
||||
) -> ResponseBase:
|
||||
) -> UserCalibrateResponse:
|
||||
"""
|
||||
重新计算用户的已用存储空间。
|
||||
|
||||
@@ -228,5 +246,5 @@ async def router_admin_calibrate_storage(
|
||||
file_count=file_count,
|
||||
)
|
||||
|
||||
l.info(f"管理员校准了用户存储: {user.username}, 差值: {actual_storage - previous_storage}")
|
||||
return ResponseBase(data=response.model_dump())
|
||||
l.info(f"管理员校准了用户存储: {user.email}, 差值: {actual_storage - previous_storage}")
|
||||
return response
|
||||
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import admin_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
ResponseBase,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Query
|
||||
from fastapi.responses import PlainTextResponse
|
||||
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
import service.oauth
|
||||
from utils import http_exceptions
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
DirectoryCreateRequest,
|
||||
DirectoryResponse,
|
||||
Object,
|
||||
@@ -14,50 +16,28 @@ from models import (
|
||||
User,
|
||||
ResponseBase,
|
||||
)
|
||||
from utils import http_exceptions
|
||||
|
||||
directory_router = APIRouter(
|
||||
prefix="/directory",
|
||||
tags=["directory"]
|
||||
)
|
||||
|
||||
@directory_router.get(
|
||||
path="/{path:path}",
|
||||
summary="获取目录内容",
|
||||
)
|
||||
async def router_directory_get(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
path: str
|
||||
|
||||
async def _get_directory_response(
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
folder: Object,
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取目录内容
|
||||
|
||||
路径必须以用户名或 `.crash` 开头,如 /api/directory/admin 或 /api/directory/admin/docs
|
||||
`.crash` 代表回收站,也就意味着用户名禁止为 `.crash`
|
||||
构建目录响应 DTO
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param path: 目录路径(必须以用户名开头)
|
||||
:return: 目录内容
|
||||
:param user_id: 用户UUID
|
||||
:param folder: 目录对象
|
||||
:return: DirectoryResponse
|
||||
"""
|
||||
# 路径必须以用户名开头
|
||||
path = path.strip("/")
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空,请使用 /{username} 格式")
|
||||
|
||||
path_parts = path.split("/")
|
||||
if path_parts[0] != user.username:
|
||||
raise HTTPException(status_code=403, detail="无权访问其他用户的目录")
|
||||
|
||||
folder = await Object.get_by_path(session, user.id, "/" + path, user.username)
|
||||
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="目录不存在")
|
||||
|
||||
if not folder.is_folder:
|
||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||
|
||||
children = await Object.get_children(session, user.id, folder.id)
|
||||
children = await Object.get_children(session, user_id, folder.id)
|
||||
policy = await folder.awaitable_attrs.policy
|
||||
|
||||
objects = [
|
||||
@@ -67,8 +47,8 @@ async def router_directory_get(
|
||||
thumb=False,
|
||||
size=child.size,
|
||||
type=ObjectType.FOLDER if child.is_folder else ObjectType.FILE,
|
||||
date=child.updated_at,
|
||||
create_date=child.created_at,
|
||||
created_at=child.created_at,
|
||||
updated_at=child.updated_at,
|
||||
source_enabled=False,
|
||||
)
|
||||
for child in children
|
||||
@@ -89,7 +69,74 @@ async def router_directory_get(
|
||||
)
|
||||
|
||||
|
||||
@directory_router.put(
|
||||
@directory_router.get(
|
||||
path="/",
|
||||
summary="获取根目录内容",
|
||||
)
|
||||
async def router_directory_root(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取当前用户的根目录内容
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:return: 根目录内容
|
||||
"""
|
||||
root = await Object.get_root(session, user.id)
|
||||
if not root:
|
||||
raise HTTPException(status_code=404, detail="根目录不存在")
|
||||
|
||||
if root.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
return await _get_directory_response(session, user.id, root)
|
||||
|
||||
|
||||
@directory_router.get(
|
||||
path="/{path:path}",
|
||||
summary="获取目录内容",
|
||||
)
|
||||
async def router_directory_get(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
path: str
|
||||
) -> DirectoryResponse:
|
||||
"""
|
||||
获取目录内容
|
||||
|
||||
路径从用户根目录开始,不包含用户名前缀。
|
||||
如 /api/v1/directory/docs 表示根目录下的 docs 目录。
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param path: 目录路径(从根目录开始的相对路径)
|
||||
:return: 目录内容
|
||||
"""
|
||||
path = path.strip("/")
|
||||
if not path:
|
||||
# 空路径交给根目录端点处理(理论上不会到达这里)
|
||||
root = await Object.get_root(session, user.id)
|
||||
if not root:
|
||||
raise HTTPException(status_code=404, detail="根目录不存在")
|
||||
return await _get_directory_response(session, user.id, root)
|
||||
|
||||
folder = await Object.get_by_path(session, user.id, "/" + path)
|
||||
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="目录不存在")
|
||||
|
||||
if not folder.is_folder:
|
||||
raise HTTPException(status_code=400, detail="指定路径不是目录")
|
||||
|
||||
if folder.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
return await _get_directory_response(session, user.id, folder)
|
||||
|
||||
|
||||
@directory_router.post(
|
||||
path="/",
|
||||
summary="创建目录",
|
||||
)
|
||||
@@ -123,6 +170,9 @@ async def router_directory_create(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父路径不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名对象
|
||||
existing = await Object.get(
|
||||
session,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
download_router = APIRouter(
|
||||
|
||||
@@ -18,7 +18,7 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required, verify_download_token
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
CreateUploadSessionRequest,
|
||||
Object,
|
||||
@@ -91,6 +91,9 @@ async def create_upload_session(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
@@ -100,7 +103,7 @@ async def create_upload_session(
|
||||
# 验证文件大小限制
|
||||
if policy.max_size > 0 and request.file_size > policy.max_size:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
status_code=413,
|
||||
detail=f"文件大小超过限制 ({policy.max_size} bytes)"
|
||||
)
|
||||
|
||||
@@ -221,30 +224,40 @@ async def upload_chunk(
|
||||
upload_session.uploaded_size += len(content)
|
||||
upload_session = await upload_session.save(session)
|
||||
|
||||
# 检查是否完成
|
||||
# 在后续可能的 commit 前保存需要的属性
|
||||
is_complete = upload_session.is_complete
|
||||
uploaded_chunks = upload_session.uploaded_chunks
|
||||
total_chunks = upload_session.total_chunks
|
||||
file_object_id: UUID | None = None
|
||||
|
||||
if is_complete:
|
||||
# 保存 upload_session 属性(commit 后会过期)
|
||||
file_name = upload_session.file_name
|
||||
uploaded_size = upload_session.uploaded_size
|
||||
storage_path = upload_session.storage_path
|
||||
upload_session_id = upload_session.id
|
||||
parent_id = upload_session.parent_id
|
||||
policy_id = upload_session.policy_id
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=upload_session.storage_path,
|
||||
size=upload_session.uploaded_size,
|
||||
policy_id=upload_session.policy_id,
|
||||
storage_path=storage_path,
|
||||
size=uploaded_size,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
physical_file = await physical_file.save(session, commit=False)
|
||||
|
||||
# 创建 Object 记录
|
||||
file_object = Object(
|
||||
name=upload_session.file_name,
|
||||
name=file_name,
|
||||
type=ObjectType.FILE,
|
||||
size=upload_session.uploaded_size,
|
||||
size=uploaded_size,
|
||||
physical_file_id=physical_file.id,
|
||||
upload_session_id=str(upload_session.id),
|
||||
parent_id=upload_session.parent_id,
|
||||
upload_session_id=str(upload_session_id),
|
||||
parent_id=parent_id,
|
||||
owner_id=user_id,
|
||||
policy_id=upload_session.policy_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
file_object = await file_object.save(session, commit=False)
|
||||
file_object_id = file_object.id
|
||||
@@ -252,18 +265,18 @@ async def upload_chunk(
|
||||
# 删除上传会话(使用条件删除)
|
||||
await UploadSession.delete(
|
||||
session,
|
||||
condition=UploadSession.id == upload_session.id,
|
||||
condition=UploadSession.id == upload_session_id,
|
||||
commit=False
|
||||
)
|
||||
|
||||
# 统一提交所有更改
|
||||
await session.commit()
|
||||
|
||||
l.info(f"文件上传完成: {file_object.name}, size={file_object.size}, id={file_object.id}")
|
||||
l.info(f"文件上传完成: {file_name}, size={uploaded_size}, id={file_object_id}")
|
||||
|
||||
return UploadChunkResponse(
|
||||
uploaded_chunks=upload_session.uploaded_chunks if not is_complete else upload_session.total_chunks,
|
||||
total_chunks=upload_session.total_chunks,
|
||||
uploaded_chunks=uploaded_chunks if not is_complete else total_chunks,
|
||||
total_chunks=total_chunks,
|
||||
is_complete=is_complete,
|
||||
object_id=file_object_id,
|
||||
)
|
||||
@@ -368,6 +381,9 @@ async def create_download_token_endpoint(
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
|
||||
if file_obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
token = create_download_token(file_id, user.id)
|
||||
|
||||
l.debug(f"创建下载令牌: file_id={file_id}, user_id={user.id}")
|
||||
@@ -410,6 +426,9 @@ async def download_file(
|
||||
if not file_obj.is_file:
|
||||
raise HTTPException(status_code=400, detail="对象不是文件")
|
||||
|
||||
if file_obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 预加载 physical_file 关系以获取存储路径
|
||||
physical_file = await file_obj.awaitable_attrs.physical_file
|
||||
if not physical_file or not physical_file.storage_path:
|
||||
@@ -470,6 +489,9 @@ async def create_empty_file(
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名文件
|
||||
existing = await Object.get(
|
||||
session,
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from models import MCPRequestBase, MCPResponseBase, MCPMethod
|
||||
|
||||
# MCP 路由
|
||||
MCP_router = APIRouter(
|
||||
prefix='/mcp',
|
||||
tags=["mcp"],
|
||||
)
|
||||
|
||||
@MCP_router.get(
|
||||
"/",
|
||||
)
|
||||
async def mcp_root(
|
||||
param: MCPRequestBase
|
||||
):
|
||||
match param.method:
|
||||
case MCPMethod.PING:
|
||||
return MCPResponseBase(result="pong", **param.model_dump())
|
||||
@@ -14,7 +14,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import (
|
||||
from sqlmodels import (
|
||||
CreateFileRequest,
|
||||
Object,
|
||||
ObjectCopyRequest,
|
||||
ObjectDeleteRequest,
|
||||
@@ -26,10 +27,11 @@ from models import (
|
||||
PhysicalFile,
|
||||
Policy,
|
||||
PolicyType,
|
||||
ResponseBase,
|
||||
User,
|
||||
)
|
||||
from models import ResponseBase
|
||||
from service.storage import LocalStorageService
|
||||
from utils import http_exceptions
|
||||
|
||||
object_router = APIRouter(
|
||||
prefix="/object",
|
||||
@@ -59,15 +61,22 @@ async def _delete_object_recursive(
|
||||
"""
|
||||
deleted_count = 0
|
||||
|
||||
if obj.is_folder:
|
||||
# 在任何数据库操作前保存所有需要的属性,避免 commit 后对象过期导致懒加载失败
|
||||
obj_id = obj.id
|
||||
obj_name = obj.name
|
||||
obj_is_folder = obj.is_folder
|
||||
obj_is_file = obj.is_file
|
||||
obj_physical_file_id = obj.physical_file_id
|
||||
|
||||
if obj_is_folder:
|
||||
# 递归删除子对象
|
||||
children = await Object.get_children(session, user_id, obj.id)
|
||||
children = await Object.get_children(session, user_id, obj_id)
|
||||
for child in children:
|
||||
deleted_count += await _delete_object_recursive(session, child, user_id)
|
||||
|
||||
# 如果是文件,处理物理文件引用
|
||||
if obj.is_file and obj.physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj.physical_file_id)
|
||||
if obj_is_file and obj_physical_file_id:
|
||||
physical_file = await PhysicalFile.get(session, PhysicalFile.id == obj_physical_file_id)
|
||||
if physical_file:
|
||||
# 减少引用计数
|
||||
new_count = physical_file.decrement_reference()
|
||||
@@ -81,11 +90,11 @@ async def _delete_object_recursive(
|
||||
await storage_service.move_to_trash(
|
||||
source_path=physical_file.storage_path,
|
||||
user_id=user_id,
|
||||
object_id=obj.id,
|
||||
object_id=obj_id,
|
||||
)
|
||||
l.debug(f"物理文件已移动到回收站: {obj.name}")
|
||||
l.debug(f"物理文件已移动到回收站: {obj_name}")
|
||||
except Exception as e:
|
||||
l.warning(f"移动物理文件到回收站失败: {obj.name}, 错误: {e}")
|
||||
l.warning(f"移动物理文件到回收站失败: {obj_name}, 错误: {e}")
|
||||
|
||||
# 删除 PhysicalFile 记录
|
||||
await PhysicalFile.delete(session, physical_file)
|
||||
@@ -95,8 +104,8 @@ async def _delete_object_recursive(
|
||||
await physical_file.save(session)
|
||||
l.debug(f"物理文件仍有 {new_count} 个引用,不删除: {physical_file.storage_path}")
|
||||
|
||||
# 删除数据库记录
|
||||
await Object.delete(session, obj)
|
||||
# 使用条件删除,避免访问过期的 obj 实例
|
||||
await Object.delete(session, condition=Object.id == obj_id)
|
||||
deleted_count += 1
|
||||
|
||||
return deleted_count
|
||||
@@ -168,6 +177,97 @@ async def _copy_object_recursive(
|
||||
return copied_count, new_ids
|
||||
|
||||
|
||||
@object_router.post(
|
||||
path='/',
|
||||
summary='创建空白文件',
|
||||
description='在指定目录下创建空白文件。',
|
||||
)
|
||||
async def router_object_create(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(auth_required)],
|
||||
request: CreateFileRequest,
|
||||
) -> ResponseBase:
|
||||
"""
|
||||
创建空白文件端点
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param request: 创建文件请求(parent_id, name)
|
||||
:return: 创建结果
|
||||
"""
|
||||
user_id = user.id
|
||||
|
||||
# 验证文件名
|
||||
if not request.name or '/' in request.name or '\\' in request.name:
|
||||
raise HTTPException(status_code=400, detail="无效的文件名")
|
||||
|
||||
# 验证父目录
|
||||
parent = await Object.get(session, Object.id == request.parent_id)
|
||||
if not parent or parent.owner_id != user_id:
|
||||
raise HTTPException(status_code=404, detail="父目录不存在")
|
||||
|
||||
if not parent.is_folder:
|
||||
raise HTTPException(status_code=400, detail="父对象不是目录")
|
||||
|
||||
if parent.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 检查是否已存在同名文件
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user_id) &
|
||||
(Object.parent_id == parent.id) &
|
||||
(Object.name == request.name)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="同名文件已存在")
|
||||
|
||||
# 确定存储策略
|
||||
policy_id = request.policy_id or parent.policy_id
|
||||
policy = await Policy.get(session, Policy.id == policy_id)
|
||||
if not policy:
|
||||
raise HTTPException(status_code=404, detail="存储策略不存在")
|
||||
|
||||
parent_id = parent.id
|
||||
|
||||
# 生成存储路径并创建空文件
|
||||
if policy.type == PolicyType.LOCAL:
|
||||
storage_service = LocalStorageService(policy)
|
||||
dir_path, storage_name, full_path = await storage_service.generate_file_path(
|
||||
user_id=user_id,
|
||||
original_filename=request.name,
|
||||
)
|
||||
await storage_service.create_empty_file(full_path)
|
||||
storage_path = full_path
|
||||
else:
|
||||
raise HTTPException(status_code=501, detail="S3 存储暂未实现")
|
||||
|
||||
# 创建 PhysicalFile 记录
|
||||
physical_file = PhysicalFile(
|
||||
storage_path=storage_path,
|
||||
size=0,
|
||||
policy_id=policy_id,
|
||||
reference_count=1,
|
||||
)
|
||||
physical_file = await physical_file.save(session)
|
||||
|
||||
# 创建 Object 记录
|
||||
file_object = Object(
|
||||
name=request.name,
|
||||
type=ObjectType.FILE,
|
||||
size=0,
|
||||
physical_file_id=physical_file.id,
|
||||
parent_id=parent_id,
|
||||
owner_id=user_id,
|
||||
policy_id=policy_id,
|
||||
)
|
||||
await file_object.save(session)
|
||||
|
||||
l.info(f"创建空白文件: {request.name}")
|
||||
|
||||
return ResponseBase()
|
||||
|
||||
|
||||
@object_router.delete(
|
||||
path='/',
|
||||
summary='删除对象',
|
||||
@@ -197,10 +297,7 @@ async def router_object_delete(
|
||||
user_id = user.id
|
||||
deleted_count = 0
|
||||
|
||||
# 处理单个 UUID 或 UUID 列表
|
||||
ids = request.ids if isinstance(request.ids, list) else [request.ids]
|
||||
|
||||
for obj_id in ids:
|
||||
for obj_id in request.ids:
|
||||
obj = await Object.get(session, Object.id == obj_id)
|
||||
if not obj or obj.owner_id != user_id:
|
||||
continue
|
||||
@@ -219,7 +316,7 @@ async def router_object_delete(
|
||||
return ResponseBase(
|
||||
data={
|
||||
"deleted": deleted_count,
|
||||
"total": len(ids),
|
||||
"total": len(request.ids),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -253,6 +350,9 @@ async def router_object_move(
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
if dst.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
# 存储 dst 的属性,避免后续数据库操作导致 dst 过期后无法访问
|
||||
dst_id = dst.id
|
||||
dst_parent_id = dst.parent_id
|
||||
@@ -264,6 +364,9 @@ async def router_object_move(
|
||||
if not src or src.owner_id != user_id:
|
||||
continue
|
||||
|
||||
if src.is_banned:
|
||||
continue
|
||||
|
||||
# 不能移动根目录
|
||||
if src.parent_id is None:
|
||||
continue
|
||||
@@ -348,6 +451,9 @@ async def router_object_copy(
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
if dst.is_banned:
|
||||
http_exceptions.raise_banned("目标目录已被封禁,无法执行此操作")
|
||||
|
||||
copied_count = 0
|
||||
new_ids: list[UUID] = []
|
||||
|
||||
@@ -356,6 +462,9 @@ async def router_object_copy(
|
||||
if not src or src.owner_id != user_id:
|
||||
continue
|
||||
|
||||
if src.is_banned:
|
||||
continue
|
||||
|
||||
# 不能复制根目录
|
||||
if src.parent_id is None:
|
||||
continue
|
||||
@@ -438,6 +547,9 @@ async def router_object_rename(
|
||||
if obj.owner_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="无权操作此对象")
|
||||
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 不能重命名根目录
|
||||
if obj.parent_id is None:
|
||||
raise HTTPException(status_code=400, detail="无法重命名根目录")
|
||||
@@ -543,7 +655,7 @@ async def router_object_property_detail(
|
||||
policy_name = policy.name if policy else None
|
||||
|
||||
# 获取分享统计
|
||||
from models import Share
|
||||
from sqlmodels import Share
|
||||
shares = await Share.get(
|
||||
session,
|
||||
Share.object_id == obj.id,
|
||||
|
||||
@@ -7,11 +7,11 @@ from loguru import logger as l
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase
|
||||
from models.user import User
|
||||
from models.share import Share, ShareCreateRequest, ShareResponse
|
||||
from models.object import Object
|
||||
from models.mixin import ListResponse, TableViewRequest
|
||||
from sqlmodels import ResponseBase
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.share import Share, ShareCreateRequest, ShareResponse
|
||||
from sqlmodels.object import Object
|
||||
from sqlmodels.mixin import ListResponse, TableViewRequest
|
||||
from utils import http_exceptions
|
||||
from utils.password.pwd import Password
|
||||
|
||||
@@ -72,23 +72,6 @@ def router_share_preview(id: str) -> ResponseBase:
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/doc/{id}',
|
||||
summary='取得Office文档预览地址',
|
||||
description='Get Office document preview URL by ID.',
|
||||
)
|
||||
def router_share_doc(id: str) -> ResponseBase:
|
||||
"""
|
||||
Get Office document preview URL by ID.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the Office document.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the document preview URL.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@share_router.get(
|
||||
path='/content/{id}',
|
||||
summary='获取文本文件内容',
|
||||
@@ -261,6 +244,9 @@ async def router_share_create(
|
||||
if not obj or obj.owner_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="对象不存在或无权限")
|
||||
|
||||
if obj.is_banned:
|
||||
http_exceptions.raise_banned()
|
||||
|
||||
# 生成分享码
|
||||
code = str(uuid4())
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
from sqlmodels import ResponseBase, Setting, SettingsType, SiteConfigResponse
|
||||
from sqlmodels.setting import CaptchaType
|
||||
from utils import http_exceptions
|
||||
|
||||
site_router = APIRouter(
|
||||
@@ -43,16 +44,43 @@ def router_site_captcha():
|
||||
@site_router.get(
|
||||
path='/config',
|
||||
summary='站点全局配置',
|
||||
description='Get the configuration file.',
|
||||
response_model=ResponseBase,
|
||||
description='获取站点全局配置,包括验证码设置、注册开关等。',
|
||||
)
|
||||
async def router_site_config(session: SessionDep) -> SiteConfigResponse:
|
||||
"""
|
||||
Get the configuration file.
|
||||
获取站点全局配置
|
||||
|
||||
Returns:
|
||||
dict: The site configuration.
|
||||
无需认证。前端在初始化时调用此端点获取验证码类型、
|
||||
登录/注册/找回密码是否需要验证码等配置。
|
||||
"""
|
||||
# 批量查询所需设置
|
||||
settings: list[Setting] = await Setting.get(
|
||||
session,
|
||||
(Setting.type == SettingsType.BASIC) |
|
||||
(Setting.type == SettingsType.LOGIN) |
|
||||
(Setting.type == SettingsType.REGISTER) |
|
||||
(Setting.type == SettingsType.CAPTCHA),
|
||||
fetch_mode="all",
|
||||
)
|
||||
|
||||
# 构建 name→value 映射
|
||||
s: dict[str, str | None] = {item.name: item.value for item in settings}
|
||||
|
||||
# 根据 captcha_type 选择对应的 public key
|
||||
captcha_type_str = s.get("captcha_type", "default")
|
||||
captcha_type = CaptchaType(captcha_type_str) if captcha_type_str else CaptchaType.DEFAULT
|
||||
captcha_key: str | None = None
|
||||
if captcha_type == CaptchaType.GCAPTCHA:
|
||||
captcha_key = s.get("captcha_ReCaptchaKey") or None
|
||||
elif captcha_type == CaptchaType.CLOUD_FLARE_TURNSTILE:
|
||||
captcha_key = s.get("captcha_CloudflareKey") or None
|
||||
|
||||
return SiteConfigResponse(
|
||||
title=await Setting.get(session, (Setting.type == SettingsType.BASIC) & (Setting.name == "siteName")),
|
||||
title=s.get("siteName") or "DiskNext",
|
||||
register_enabled=s.get("register_enabled") == "1",
|
||||
login_captcha=s.get("login_captcha") == "1",
|
||||
reg_captcha=s.get("reg_captcha") == "1",
|
||||
forget_captcha=s.get("forget_captcha") == "1",
|
||||
captcha_type=captcha_type,
|
||||
captcha_key=captcha_key,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
slave_router = APIRouter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import auth_required
|
||||
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
tag_router = APIRouter(
|
||||
|
||||
@@ -1,30 +1,26 @@
|
||||
from typing import Annotated, Literal
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import jwt
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from loguru import logger
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
from loguru import logger
|
||||
|
||||
import models
|
||||
import service
|
||||
import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from utils.JWT import SECRET_KEY
|
||||
from utils import Password, http_exceptions
|
||||
from utils import JWT, Password, http_exceptions
|
||||
from .settings import user_settings_router
|
||||
|
||||
user_router = APIRouter(
|
||||
prefix="/user",
|
||||
tags=["user"],
|
||||
)
|
||||
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/user/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
user_router.include_router(user_settings_router)
|
||||
|
||||
@user_router.post(
|
||||
path='/session',
|
||||
@@ -34,7 +30,7 @@ user_settings_router = APIRouter(
|
||||
async def router_user_session(
|
||||
session: SessionDep,
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> models.TokenResponse:
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
用户登录端点。
|
||||
|
||||
@@ -43,7 +39,7 @@ async def router_user_session(
|
||||
|
||||
OAuth2 scopes 字段格式: "otp:123456" 或直接传入验证码
|
||||
"""
|
||||
username = form_data.username
|
||||
email = form_data.username # OAuth2 表单字段名为 username,实际传入的是 email
|
||||
password = form_data.password
|
||||
|
||||
# 从 scopes 中提取 OTP 验证码(OAuth2.1 扩展方式)
|
||||
@@ -59,8 +55,8 @@ async def router_user_session(
|
||||
|
||||
result = await service.user.login(
|
||||
session,
|
||||
models.LoginRequest(
|
||||
username=username,
|
||||
sqlmodels.LoginRequest(
|
||||
email=email,
|
||||
password=password,
|
||||
two_fa_code=otp_code,
|
||||
),
|
||||
@@ -75,19 +71,70 @@ async def router_user_session(
|
||||
)
|
||||
async def router_user_session_refresh(
|
||||
session: SessionDep,
|
||||
request, # RefreshTokenRequest
|
||||
) -> models.TokenResponse:
|
||||
http_exceptions.raise_not_implemented()
|
||||
request: sqlmodels.RefreshTokenRequest,
|
||||
) -> sqlmodels.TokenResponse:
|
||||
"""
|
||||
使用 refresh_token 签发新的 access_token 和 refresh_token。
|
||||
|
||||
流程:
|
||||
1. 解码 refresh_token JWT
|
||||
2. 验证 token_type 为 refresh
|
||||
3. 验证用户存在且状态正常
|
||||
4. 签发新的 access_token + refresh_token
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 刷新令牌请求
|
||||
:return: 新的 TokenResponse
|
||||
"""
|
||||
|
||||
try:
|
||||
payload = jwt.decode(request.refresh_token, JWT.SECRET_KEY, algorithms=["HS256"])
|
||||
except jwt.InvalidTokenError:
|
||||
http_exceptions.raise_unauthorized("刷新令牌无效或已过期")
|
||||
|
||||
# 验证是 refresh token
|
||||
if payload.get("token_type") != "refresh":
|
||||
http_exceptions.raise_unauthorized("非刷新令牌")
|
||||
|
||||
user_id_str = payload.get("sub")
|
||||
if not user_id_str:
|
||||
http_exceptions.raise_unauthorized("令牌缺少用户标识")
|
||||
|
||||
user_id = UUID(user_id_str)
|
||||
user = await sqlmodels.User.get(session, sqlmodels.User.id == user_id)
|
||||
if not user:
|
||||
http_exceptions.raise_unauthorized("用户不存在")
|
||||
|
||||
if not user.status:
|
||||
http_exceptions.raise_forbidden("账户已被禁用")
|
||||
|
||||
# 签发新令牌
|
||||
access_token = JWT.create_access_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
)
|
||||
refresh_token = JWT.create_refresh_token(
|
||||
sub=user.id,
|
||||
jti=uuid4(),
|
||||
)
|
||||
|
||||
return sqlmodels.TokenResponse(
|
||||
access_token=access_token.access_token,
|
||||
access_expires=access_token.access_expires,
|
||||
refresh_token=refresh_token.refresh_token,
|
||||
refresh_expires=refresh_token.refresh_expires,
|
||||
)
|
||||
|
||||
@user_router.post(
|
||||
path='/',
|
||||
summary='用户注册',
|
||||
description='User registration endpoint.',
|
||||
status_code=204,
|
||||
)
|
||||
async def router_user_register(
|
||||
session: SessionDep,
|
||||
request: models.RegisterRequest,
|
||||
) -> models.ResponseBase:
|
||||
request: sqlmodels.RegisterRequest,
|
||||
) -> None:
|
||||
"""
|
||||
用户注册端点
|
||||
|
||||
@@ -95,7 +142,7 @@ async def router_user_register(
|
||||
1. 验证用户名唯一性
|
||||
2. 获取默认用户组
|
||||
3. 创建用户记录
|
||||
4. 创建以用户名命名的根目录
|
||||
4. 创建用户根目录(name="/")
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 注册请求
|
||||
@@ -103,62 +150,53 @@ async def router_user_register(
|
||||
:raises HTTPException 400: 用户名已存在
|
||||
:raises HTTPException 500: 默认用户组或存储策略不存在
|
||||
"""
|
||||
# 1. 验证用户名唯一性
|
||||
existing_user = await models.User.get(
|
||||
# 1. 验证邮箱唯一性
|
||||
existing_user = await sqlmodels.User.get(
|
||||
session,
|
||||
models.User.username == request.username
|
||||
sqlmodels.User.email == request.email
|
||||
)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
raise HTTPException(status_code=400, detail="邮箱已存在")
|
||||
|
||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||
default_group_setting: models.Setting | None = await models.Setting.get(
|
||||
default_group_setting: sqlmodels.Setting | None = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == models.SettingsType.REGISTER) & (models.Setting.name == "default_group")
|
||||
(sqlmodels.Setting.type == sqlmodels.SettingsType.REGISTER) & (sqlmodels.Setting.name == "default_group")
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
default_group_id = UUID(default_group_setting.value)
|
||||
default_group = await models.Group.get(session, models.Group.id == default_group_id)
|
||||
default_group = await sqlmodels.Group.get(session, sqlmodels.Group.id == default_group_id)
|
||||
if not default_group:
|
||||
logger.error("默认用户组不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
# 3. 创建用户
|
||||
hashed_password = Password.hash(request.password)
|
||||
new_user = models.User(
|
||||
username=request.username,
|
||||
new_user = sqlmodels.User(
|
||||
email=request.email,
|
||||
password=hashed_password,
|
||||
group_id=default_group.id,
|
||||
)
|
||||
new_user_id = new_user.id # 在 save 前保存 UUID
|
||||
new_user_username = new_user.username
|
||||
new_user_id = new_user.id
|
||||
await new_user.save(session)
|
||||
|
||||
# 4. 创建以用户名命名的根目录
|
||||
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储")
|
||||
# 4. 创建用户根目录
|
||||
default_policy = await sqlmodels.Policy.get(session, sqlmodels.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
logger.error("默认存储策略不存在")
|
||||
http_exceptions.raise_internal_error()
|
||||
|
||||
await models.Object(
|
||||
name=new_user_username,
|
||||
type=models.ObjectType.FOLDER,
|
||||
await sqlmodels.Object(
|
||||
name="/",
|
||||
type=sqlmodels.ObjectType.FOLDER,
|
||||
owner_id=new_user_id,
|
||||
parent_id=None,
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={
|
||||
"user_id": new_user_id,
|
||||
"username": new_user_username,
|
||||
},
|
||||
msg="注册成功",
|
||||
)
|
||||
|
||||
@user_router.post(
|
||||
path='/code',
|
||||
summary='发送验证码邮件',
|
||||
@@ -166,7 +204,7 @@ async def router_user_register(
|
||||
)
|
||||
def router_user_email_code(
|
||||
reason: Literal['register', 'reset'] = 'register',
|
||||
) -> models.ResponseBase:
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Send a verification code email.
|
||||
|
||||
@@ -180,7 +218,7 @@ def router_user_email_code(
|
||||
summary='初始化QQ登录',
|
||||
description='Initialize QQ login for a user.',
|
||||
)
|
||||
def router_user_qq() -> models.ResponseBase:
|
||||
def router_user_qq() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize QQ login for a user.
|
||||
|
||||
@@ -194,7 +232,7 @@ def router_user_qq() -> models.ResponseBase:
|
||||
summary='WebAuthn登录初始化',
|
||||
description='Initialize WebAuthn login for a user.',
|
||||
)
|
||||
async def router_user_authn(username: str) -> models.ResponseBase:
|
||||
async def router_user_authn(username: str) -> sqlmodels.ResponseBase:
|
||||
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@@ -203,7 +241,7 @@ async def router_user_authn(username: str) -> models.ResponseBase:
|
||||
summary='WebAuthn登录',
|
||||
description='Finish WebAuthn login for a user.',
|
||||
)
|
||||
def router_user_authn_finish(username: str) -> models.ResponseBase:
|
||||
def router_user_authn_finish(username: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
@@ -220,7 +258,7 @@ def router_user_authn_finish(username: str) -> models.ResponseBase:
|
||||
summary='获取用户主页展示用分享',
|
||||
description='Get user profile for display.',
|
||||
)
|
||||
def router_user_profile(id: str) -> models.ResponseBase:
|
||||
def router_user_profile(id: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user profile for display.
|
||||
|
||||
@@ -237,7 +275,7 @@ def router_user_profile(id: str) -> models.ResponseBase:
|
||||
summary='获取用户头像',
|
||||
description='Get user avatar by ID and size.',
|
||||
)
|
||||
def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
|
||||
def router_user_avatar(id: str, size: int = 128) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user avatar by ID and size.
|
||||
|
||||
@@ -259,12 +297,12 @@ def router_user_avatar(id: str, size: int = 128) -> models.ResponseBase:
|
||||
summary='获取用户信息',
|
||||
description='Get user information.',
|
||||
dependencies=[Depends(dependency=auth_required)],
|
||||
response_model=models.UserResponse,
|
||||
response_model=sqlmodels.UserResponse,
|
||||
)
|
||||
async def router_user_me(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserResponse:
|
||||
"""
|
||||
获取用户信息.
|
||||
|
||||
@@ -272,10 +310,10 @@ async def router_user_me(
|
||||
:rtype: ResponseBase
|
||||
"""
|
||||
# 加载 group 及其 options 关系
|
||||
group = await models.Group.get(
|
||||
group = await sqlmodels.Group.get(
|
||||
session,
|
||||
models.Group.id == user.group_id,
|
||||
load=models.Group.options
|
||||
sqlmodels.Group.id == user.group_id,
|
||||
load=sqlmodels.Group.options
|
||||
)
|
||||
|
||||
# 构建 GroupResponse
|
||||
@@ -284,9 +322,9 @@ async def router_user_me(
|
||||
# 异步加载 tags 关系
|
||||
user_tags = await user.awaitable_attrs.tags
|
||||
|
||||
return models.UserResponse(
|
||||
return sqlmodels.UserResponse(
|
||||
id=user.id,
|
||||
username=user.username,
|
||||
email=user.email,
|
||||
status=user.status,
|
||||
score=user.score,
|
||||
nickname=user.nickname,
|
||||
@@ -304,30 +342,26 @@ async def router_user_me(
|
||||
)
|
||||
async def router_user_storage(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserStorageResponse:
|
||||
"""
|
||||
获取用户存储空间信息。
|
||||
|
||||
返回值:
|
||||
- used: 已使用空间(字节)
|
||||
- free: 剩余空间(字节)
|
||||
- total: 总容量(字节)= 用户组容量
|
||||
"""
|
||||
# 获取用户组的基础存储容量
|
||||
group = await models.Group.get(session, models.Group.id == user.group_id)
|
||||
group = await sqlmodels.Group.get(session, sqlmodels.Group.id == user.group_id)
|
||||
if not group:
|
||||
raise HTTPException(status_code=500, detail="用户组不存在")
|
||||
raise HTTPException(status_code=404, detail="用户组不存在")
|
||||
|
||||
# [TODO] 总空间加上用户购买的额外空间
|
||||
|
||||
total: int = group.max_storage
|
||||
used: int = user.storage
|
||||
free: int = max(0, total - used)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={
|
||||
"used": used,
|
||||
"free": free,
|
||||
"total": total,
|
||||
}
|
||||
return sqlmodels.UserStorageResponse(
|
||||
used=used,
|
||||
free=free,
|
||||
total=total,
|
||||
)
|
||||
|
||||
@user_router.put(
|
||||
@@ -338,8 +372,8 @@ async def router_user_storage(
|
||||
)
|
||||
async def router_user_authn_start(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Initialize WebAuthn login for a user.
|
||||
|
||||
@@ -347,30 +381,30 @@ async def router_user_authn_start(
|
||||
dict: A dictionary containing WebAuthn initialization information.
|
||||
"""
|
||||
# TODO: 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
|
||||
authn_setting = await models.Setting.get(
|
||||
authn_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "authn") & (models.Setting.name == "authn_enabled")
|
||||
(sqlmodels.Setting.type == "authn") & (sqlmodels.Setting.name == "authn_enabled")
|
||||
)
|
||||
if not authn_setting or authn_setting.value != "1":
|
||||
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
|
||||
|
||||
site_url_setting = await models.Setting.get(
|
||||
site_url_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "basic") & (models.Setting.name == "siteURL")
|
||||
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteURL")
|
||||
)
|
||||
site_title_setting = await models.Setting.get(
|
||||
site_title_setting = await sqlmodels.Setting.get(
|
||||
session,
|
||||
(models.Setting.type == "basic") & (models.Setting.name == "siteTitle")
|
||||
(sqlmodels.Setting.type == "basic") & (sqlmodels.Setting.name == "siteTitle")
|
||||
)
|
||||
|
||||
options = generate_registration_options(
|
||||
rp_id=site_url_setting.value if site_url_setting else "",
|
||||
rp_name=site_title_setting.value if site_title_setting else "",
|
||||
user_name=user.username,
|
||||
user_display_name=user.nick or user.username,
|
||||
user_name=user.email,
|
||||
user_display_name=user.nickname or user.email,
|
||||
)
|
||||
|
||||
return models.ResponseBase(data=options_to_json_dict(options))
|
||||
return sqlmodels.ResponseBase(data=options_to_json_dict(options))
|
||||
|
||||
@user_router.put(
|
||||
path='/authn/finish',
|
||||
@@ -378,179 +412,11 @@ async def router_user_authn_start(
|
||||
description='Finish WebAuthn login for a user.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_authn_finish() -> models.ResponseBase:
|
||||
def router_user_authn_finish() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Finish WebAuthn login for a user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing WebAuthn login information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
summary='获取用户可选存储策略',
|
||||
description='Get user selectable storage policies.',
|
||||
)
|
||||
def router_user_settings_policies() -> models.ResponseBase:
|
||||
"""
|
||||
Get user selectable storage policies.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/nodes',
|
||||
summary='获取用户可选节点',
|
||||
description='Get user selectable nodes.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_nodes() -> models.ResponseBase:
|
||||
"""
|
||||
Get user selectable nodes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available nodes for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/tasks',
|
||||
summary='任务队列',
|
||||
description='Get user task queue.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_tasks() -> models.ResponseBase:
|
||||
"""
|
||||
Get user task queue.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the user's task queue information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/',
|
||||
summary='获取当前用户设定',
|
||||
description='Get current user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings() -> models.ResponseBase:
|
||||
"""
|
||||
Get current user settings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the current user settings.
|
||||
"""
|
||||
return models.ResponseBase(data=models.UserSettingResponse().model_dump())
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/avatar',
|
||||
summary='从文件上传头像',
|
||||
description='Upload user avatar from file.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar() -> models.ResponseBase:
|
||||
"""
|
||||
Upload user avatar from file.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the avatar upload.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.put(
|
||||
path='/avatar',
|
||||
summary='设定为Gravatar头像',
|
||||
description='Set user avatar to Gravatar.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar_gravatar() -> models.ResponseBase:
|
||||
"""
|
||||
Set user avatar to Gravatar.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.patch(
|
||||
path='/{option}',
|
||||
summary='更新用户设定',
|
||||
description='Update user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_patch(option: str) -> models.ResponseBase:
|
||||
"""
|
||||
Update user settings.
|
||||
|
||||
Args:
|
||||
option (str): The setting option to update.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the settings update.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/2fa',
|
||||
summary='获取两步验证初始化信息',
|
||||
description='Get two-factor authentication initialization information.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa(
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
Get two-factor authentication initialization information.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing two-factor authentication setup information.
|
||||
"""
|
||||
|
||||
return models.ResponseBase(
|
||||
data=await Password.generate_totp(user.username)
|
||||
)
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/2fa',
|
||||
summary='启用两步验证',
|
||||
description='Enable two-factor authentication.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa_enable(
|
||||
session: SessionDep,
|
||||
user: Annotated[models.user.User, Depends(auth_required)],
|
||||
setup_token: str,
|
||||
code: str,
|
||||
) -> models.ResponseBase:
|
||||
"""
|
||||
Enable two-factor authentication for the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of enabling two-factor authentication.
|
||||
"""
|
||||
|
||||
serializer = URLSafeTimedSerializer(SECRET_KEY)
|
||||
|
||||
try:
|
||||
# 1. 解包 Token,设置有效期(例如 600秒)
|
||||
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||
except BadSignature:
|
||||
raise HTTPException(status_code=400, detail="Invalid token")
|
||||
|
||||
# 2. 验证用户输入的 6 位验证码
|
||||
if not Password.verify_totp(secret, code):
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
|
||||
return models.ResponseBase(
|
||||
data={"message": "Two-factor authentication enabled successfully"}
|
||||
)
|
||||
http_exceptions.raise_not_implemented()
|
||||
203
routers/api/v1/user/settings/__init__.py
Normal file
203
routers/api/v1/user/settings/__init__.py
Normal file
@@ -0,0 +1,203 @@
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
|
||||
import sqlmodels
|
||||
from middleware.auth import auth_required
|
||||
from middleware.dependencies import SessionDep
|
||||
from utils import JWT, Password, http_exceptions
|
||||
|
||||
user_settings_router = APIRouter(
|
||||
prefix='/settings',
|
||||
tags=["user", "user_settings"],
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/policies',
|
||||
summary='获取用户可选存储策略',
|
||||
description='Get user selectable storage policies.',
|
||||
)
|
||||
def router_user_settings_policies() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user selectable storage policies.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available storage policies for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/nodes',
|
||||
summary='获取用户可选节点',
|
||||
description='Get user selectable nodes.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_nodes() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user selectable nodes.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing available nodes for the user.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/tasks',
|
||||
summary='任务队列',
|
||||
description='Get user task queue.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_tasks() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get user task queue.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the user's task queue information.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/',
|
||||
summary='获取当前用户设定',
|
||||
description='Get current user settings.',
|
||||
)
|
||||
def router_user_settings(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.UserSettingResponse:
|
||||
"""
|
||||
Get current user settings.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the current user settings.
|
||||
"""
|
||||
return sqlmodels.UserSettingResponse(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
nickname=user.nickname,
|
||||
created_at=user.created_at,
|
||||
group_name=user.group.name,
|
||||
language=user.language,
|
||||
timezone=user.timezone,
|
||||
group_expires=user.group_expires,
|
||||
two_factor=user.two_factor is not None,
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/avatar',
|
||||
summary='从文件上传头像',
|
||||
description='Upload user avatar from file.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Upload user avatar from file.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the avatar upload.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.put(
|
||||
path='/avatar',
|
||||
summary='设定为Gravatar头像',
|
||||
description='Set user avatar to Gravatar.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_avatar_gravatar() -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Set user avatar to Gravatar.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of setting the Gravatar avatar.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.patch(
|
||||
path='/{option}',
|
||||
summary='更新用户设定',
|
||||
description='Update user settings.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
def router_user_settings_patch(option: str) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Update user settings.
|
||||
|
||||
Args:
|
||||
option (str): The setting option to update.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of the settings update.
|
||||
"""
|
||||
http_exceptions.raise_not_implemented()
|
||||
|
||||
|
||||
@user_settings_router.get(
|
||||
path='/2fa',
|
||||
summary='获取两步验证初始化信息',
|
||||
description='Get two-factor authentication initialization information.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa(
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Get two-factor authentication initialization information.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing two-factor authentication setup information.
|
||||
"""
|
||||
|
||||
return sqlmodels.ResponseBase(
|
||||
data=await Password.generate_totp(user.email)
|
||||
)
|
||||
|
||||
|
||||
@user_settings_router.post(
|
||||
path='/2fa',
|
||||
summary='启用两步验证',
|
||||
description='Enable two-factor authentication.',
|
||||
dependencies=[Depends(auth_required)],
|
||||
)
|
||||
async def router_user_settings_2fa_enable(
|
||||
session: SessionDep,
|
||||
user: Annotated[sqlmodels.user.User, Depends(auth_required)],
|
||||
setup_token: str,
|
||||
code: str,
|
||||
) -> sqlmodels.ResponseBase:
|
||||
"""
|
||||
Enable two-factor authentication for the user.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the result of enabling two-factor authentication.
|
||||
"""
|
||||
|
||||
serializer = URLSafeTimedSerializer(JWT.SECRET_KEY)
|
||||
|
||||
try:
|
||||
# 1. 解包 Token,设置有效期(例如 600秒)
|
||||
secret = serializer.loads(setup_token, salt="2fa-setup-salt", max_age=600)
|
||||
except SignatureExpired:
|
||||
raise HTTPException(status_code=400, detail="Setup session expired")
|
||||
except BadSignature:
|
||||
raise HTTPException(status_code=400, detail="Invalid token")
|
||||
|
||||
# 2. 验证用户输入的 6 位验证码
|
||||
if not Password.verify_totp(secret, code):
|
||||
raise HTTPException(status_code=400, detail="Invalid OTP code")
|
||||
|
||||
# 3. 将 secret 存储到用户的数据库记录中,启用 2FA
|
||||
user.two_factor = secret
|
||||
user = await user.save(session)
|
||||
|
||||
return sqlmodels.ResponseBase(
|
||||
data={"message": "Two-factor authentication enabled successfully"}
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
vas_router = APIRouter(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from middleware.auth import auth_required
|
||||
from models import ResponseBase
|
||||
from sqlmodels import ResponseBase
|
||||
from utils import http_exceptions
|
||||
|
||||
# WebDAV 管理路由
|
||||
|
||||
Reference in New Issue
Block a user