feat: 更新模型以支持 UUID,添加注册请求 DTO,重构用户注册逻辑

This commit is contained in:
2025-12-19 16:32:49 +08:00
parent e031f3cc40
commit 922692b820
17 changed files with 380 additions and 147 deletions

View File

@@ -2,6 +2,7 @@ from . import response
from .user import (
LoginRequest,
RegisterRequest,
TokenResponse,
User,
UserBase,
@@ -21,6 +22,8 @@ from .object import (
DirectoryResponse,
Object,
ObjectBase,
ObjectDeleteRequest,
ObjectMoveRequest,
ObjectResponse,
ObjectType,
PolicyResponse,

View File

@@ -1,5 +1,8 @@
from typing import Optional, TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
@@ -31,7 +34,7 @@ class Download(DownloadBase, UUIDTableBase, table=True):
dst: str = Field(description="目标存储路径")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
task_id: int | None = Field(default=None, foreign_key="task.id", index=True, description="关联的任务ID")
node_id: int = Field(foreign_key="node.id", index=True, description="执行下载的节点ID")

View File

@@ -1,9 +1,10 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, text
from .base import TableBase, SQLModelBase
from .base import TableBase, SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .user import User
@@ -45,8 +46,8 @@ class GroupOptionsBase(SQLModelBase):
class GroupResponse(GroupBase, GroupOptionsBase):
"""用户组响应 DTO"""
id: int
"""用户组ID"""
id: UUID
"""用户组UUID"""
allow_share: bool = False
"""是否允许分享"""
@@ -72,8 +73,8 @@ class GroupResponse(GroupBase, GroupOptionsBase):
class GroupOptions(GroupOptionsBase, TableBase, table=True):
"""用户组选项模型"""
group_id: int = Field(foreign_key="group.id", unique=True)
"""关联的用户组ID"""
group_id: UUID = Field(foreign_key="group.id", unique=True)
"""关联的用户组UUID"""
archive_download: bool = False
"""是否允许打包下载"""
@@ -97,7 +98,7 @@ class GroupOptions(GroupOptionsBase, TableBase, table=True):
group: "Group" = Relationship(back_populates="options")
class Group(GroupBase, TableBase, table=True):
class Group(GroupBase, UUIDTableBase, table=True):
"""用户组模型"""
name: str = Field(max_length=255, unique=True)

View File

@@ -25,7 +25,7 @@ default_settings: list[Setting] = [
Setting(name="siteURL", value="http://localhost", type=SettingsType.BASIC),
Setting(name="siteName", value="DiskNext", type=SettingsType.BASIC),
Setting(name="register_enabled", value="1", type=SettingsType.REGISTER),
Setting(name="default_group", value="2", type=SettingsType.REGISTER),
Setting(name="default_group", value="", type=SettingsType.REGISTER), # UUID 在组创建后更新
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
@@ -138,24 +138,27 @@ async def init_default_settings() -> None:
async def init_default_group() -> None:
from .group import Group, GroupOptions
from .setting import Setting
from .database import get_session
log.info('初始化用户组...')
async for session in get_session():
# 未找到初始管理组时,则创建
if not await Group.get(session, Group.id == 1):
admin_group = await Group(
if not await Group.get(session, Group.name == "管理员"):
admin_group = Group(
name="管理员",
policies="1",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
admin=True,
).save(session)
assert admin_group.id is not None
)
admin_group_id = admin_group.id # 在 save 前保存 UUID
await admin_group.save(session)
await GroupOptions(
group_id=admin_group.id,
group_id=admin_group_id,
archive_download=True,
archive_task=True,
share_download=True,
@@ -166,30 +169,40 @@ async def init_default_group() -> None:
).save(session)
# 未找到初始注册会员时,则创建
if not await Group.get(session, Group.id == 2):
member_group = await Group(
if not await Group.get(session, Group.name == "注册会员"):
member_group = Group(
name="注册会员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
).save(session)
assert member_group.id is not None
)
member_group_id = member_group.id # 在 save 前保存 UUID
await member_group.save(session)
await GroupOptions(
group_id=member_group.id,
group_id=member_group_id,
share_download=True,
).save(session)
# 更新 default_group 设置为注册会员组的 UUID
default_group_setting = await Setting.get(session, Setting.name == "default_group")
if default_group_setting:
default_group_setting.value = str(member_group_id)
await default_group_setting.save(session)
# 未找到初始游客组时,则创建
if not await Group.get(session, Group.id == 3):
guest_group = await Group(
if not await Group.get(session, Group.name == "游客"):
guest_group = Group(
name="游客",
policies="[]",
share_enabled=False,
web_dav_enabled=False,
).save(session)
assert guest_group.id is not None
)
guest_group_id = guest_group.id # 在 save 前保存 UUID
await guest_group.save(session)
await GroupOptions(
group_id=guest_group.id,
group_id=guest_group_id,
share_download=True,
).save(session)
@@ -203,11 +216,11 @@ async def init_default_user() -> None:
async for session in get_session():
# 检查管理员用户是否存在
admin_user = await User.get(session, User.id == 1)
admin_user = await User.get(session, User.username == "admin")
if not admin_user:
# 获取管理员组
admin_group = await Group.get(session, Group.id == 1)
admin_group = await Group.get(session, Group.name == "管理员")
if not admin_group:
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
@@ -215,19 +228,22 @@ async def init_default_user() -> None:
admin_password = Password.generate(8)
hashed_admin_password = Password.hash(admin_password)
admin_user = await User(
admin_user = User(
username="admin",
nick="admin",
nickname="admin",
status=True,
group_id=admin_group.id,
password=hashed_admin_password,
).save(session)
)
admin_user_id = admin_user.id # 在 save 前保存 UUID
admin_username = admin_user.username
await admin_user.save(session)
# 为管理员创建根目录(使用默认存储策略
# 为管理员创建根目录(使用用户名作为目录名
await Object(
name="my",
name=admin_username,
type=ObjectType.FOLDER,
owner_id=admin_user.id,
owner_id=admin_user_id,
parent_id=None,
policy_id=1, # 默认本地存储策略
).save(session)
@@ -244,7 +260,7 @@ async def init_default_policy() -> None:
async for session in get_session():
# 检查默认存储策略是否存在
default_policy = await Policy.get(session, Policy.id == 1)
default_policy = await Policy.get(session, Policy.name == "本地存储")
if not default_policy:
local_policy = Policy(

View File

@@ -1,11 +1,12 @@
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Optional
from uuid import UUID
from enum import StrEnum
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index
from .base import TableBase, SQLModelBase
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .user import User
@@ -40,18 +41,38 @@ class ObjectBase(SQLModelBase):
class DirectoryCreateRequest(SQLModelBase):
"""创建目录请求 DTO"""
path: str
"""目录路径,如 /docs/images"""
parent_id: UUID
"""目录UUID"""
name: str
"""目录名称"""
policy_id: int | None = None
"""存储策略ID不指定则继承父目录"""
class ObjectMoveRequest(SQLModelBase):
"""移动对象请求 DTO"""
src_ids: list[UUID]
"""源对象UUID列表"""
dst_id: UUID
"""目标文件夹UUID"""
class ObjectDeleteRequest(SQLModelBase):
"""删除对象请求 DTO"""
ids: list[UUID]
"""待删除对象UUID列表"""
class ObjectResponse(ObjectBase):
"""对象响应 DTO"""
id: str
"""对象ID"""
id: UUID
"""对象UUID"""
path: str
"""对象路径"""
@@ -91,8 +112,11 @@ class PolicyResponse(SQLModelBase):
class DirectoryResponse(SQLModelBase):
"""目录响应 DTO"""
parent: str | None = None
"""父目录ID根目录为None"""
id: UUID
"""当前目录UUID"""
parent: UUID | None = None
"""父目录UUID根目录为None"""
objects: list[ObjectResponse] = []
"""目录下的对象列表"""
@@ -103,16 +127,17 @@ class DirectoryResponse(SQLModelBase):
# ==================== 数据库模型 ====================
class Object(ObjectBase, TableBase, table=True):
class Object(ObjectBase, UUIDTableBase, table=True):
"""
统一对象模型
合并了原有的 File 和 Folder 模型,通过 type 字段区分文件和目录。
根目录规则:
- 每个用户有一个显式根目录对象name="my", parent_id=NULL
- 每个用户有一个显式根目录对象name=用户的username, parent_id=NULL
- 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id
- 根目录的 policy_id 指定用户默认存储策略
- 路径格式:/username/path/to/file如 /admin/docs/readme.md
"""
__table_args__ = (
@@ -158,11 +183,11 @@ class Object(ObjectBase, TableBase, table=True):
# ==================== 外键 ====================
parent_id: int | None = Field(default=None, foreign_key="object.id", index=True)
"""父目录IDNULL 表示这是用户的根目录"""
parent_id: UUID | None = Field(default=None, foreign_key="object.id", index=True)
"""父目录UUIDNULL 表示这是用户的根目录"""
owner_id: int = Field(foreign_key="user.id", index=True)
"""所有者用户ID"""
owner_id: UUID = Field(foreign_key="user.id", index=True)
"""所有者用户UUID"""
policy_id: int = Field(foreign_key="policy.id", index=True)
"""存储策略ID文件直接使用目录作为子文件的默认策略"""
@@ -207,12 +232,12 @@ class Object(ObjectBase, TableBase, table=True):
# ==================== 业务方法 ====================
@classmethod
async def get_root(cls, session, user_id: int) -> "Object | None":
async def get_root(cls, session, user_id: UUID) -> "Object | None":
"""
获取用户的根目录
:param session: 数据库会话
:param user_id: 用户ID
:param user_id: 用户UUID
:return: 根目录对象,不存在则返回 None
"""
return await cls.get(
@@ -221,33 +246,51 @@ class Object(ObjectBase, TableBase, table=True):
)
@classmethod
async def get_by_path(cls, session, user_id: int, path: str) -> "Object | None":
async def get_by_path(
cls,
session,
user_id: UUID,
path: str,
username: str,
) -> "Object | None":
"""
根据路径获取对象
:param session: 数据库会话
:param user_id: 用户ID
:param path: 路径,如 "/""/docs/images"
:param user_id: 用户UUID
:param path: 路径,如 "/username""/username/docs/images"
:param username: 用户名,用于识别根目录
:return: Object 或 None
"""
path = path.strip()
if not path:
raise ValueError("路径不能为空")
if path in ["/my"]:
return await cls.get_root(session, user_id)
# 获取用户根目录
root = await cls.get_root(session, user_id)
if not root:
return None
# 移除开头的斜杠并分割路径
if path.startswith("/"):
path = path[1:]
parts = [p for p in path.split("/") if p]
# 空路径 -> 返回根目录
if not parts:
return await cls.get_root(session, user_id)
return root
# 从根目录开始遍历
current = await cls.get_root(session, user_id)
# 检查第一部分是否是用户名(根目录名)
if parts[0] == username:
# 路径以用户名开头,如 /admin/docs
if len(parts) == 1:
# 只有用户名,返回根目录
return root
# 去掉用户名部分,从第二个部分开始遍历
parts = parts[1:]
# 从根目录开始遍历剩余路径
current = root
for part in parts:
if not current:
return None
@@ -262,13 +305,13 @@ class Object(ObjectBase, TableBase, table=True):
return current
@classmethod
async def get_children(cls, session, user_id: int, parent_id: int) -> list["Object"]:
async def get_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
"""
获取目录下的所有子对象
:param session: 数据库会话
:param user_id: 用户ID
:param parent_id: 父目录ID
:param user_id: 用户UUID
:param parent_id: 父目录UUID
:return: 子对象列表
"""
return await cls.get(

View File

@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship
from .base import TableBase
if TYPE_CHECKING:
@@ -19,7 +22,7 @@ class Order(TableBase, table=True):
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单状态: 0=待支付, 1=已完成, 2=已取消")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="orders")

View File

@@ -1,7 +1,10 @@
from typing import TYPE_CHECKING
from datetime import datetime
from uuid import UUID
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
from .base import TableBase
if TYPE_CHECKING:
@@ -26,8 +29,8 @@ class Share(TableBase, table=True):
password: str | None = Field(default=None, max_length=255)
"""分享密码(加密后)"""
object_id: int = Field(foreign_key="object.id", index=True)
"""关联的对象ID"""
object_id: UUID = Field(foreign_key="object.id", index=True)
"""关联的对象UUID"""
views: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""浏览次数"""
@@ -51,8 +54,8 @@ class Share(TableBase, table=True):
"""兑换此分享所需的积分"""
# 外键
user_id: int = Field(foreign_key="user.id", index=True)
"""创建分享的用户ID"""
user_id: UUID = Field(foreign_key="user.id", index=True)
"""创建分享的用户UUID"""
# 关系
user: "User" = Relationship(back_populates="shares")

View File

@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, Index
from .base import TableBase
if TYPE_CHECKING:
@@ -21,8 +24,8 @@ class SourceLink(TableBase, table=True):
"""通过此链接的下载次数"""
# 外键
object_id: int = Field(foreign_key="object.id", index=True)
"""关联的对象ID必须是文件类型"""
object_id: UUID = Field(foreign_key="object.id", index=True)
"""关联的对象UUID必须是文件类型"""
# 关系
object: "Object" = Relationship(back_populates="source_links")

View File

@@ -1,9 +1,11 @@
from typing import Optional, TYPE_CHECKING
from datetime import datetime
from uuid import UUID
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import TableBase
from datetime import datetime
if TYPE_CHECKING:
from .user import User
@@ -17,7 +19,7 @@ class StoragePack(TableBase, table=True):
size: int = Field(description="容量包大小(字节)")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="storage_packs")

View File

@@ -1,9 +1,12 @@
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from .base import TableBase
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from .base import TableBase
if TYPE_CHECKING:
from .user import User
@@ -19,7 +22,7 @@ class Tag(TableBase, table=True):
expression: str | None = Field(default=None, description="自动标签的匹配表达式")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="tags")

View File

@@ -1,9 +1,12 @@
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, CheckConstraint
from .base import TableBase
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, Relationship, CheckConstraint
from .base import TableBase
if TYPE_CHECKING:
from .user import User
from .download import Download
@@ -22,7 +25,7 @@ class Task(TableBase, table=True):
props: str | None = Field(default=None, description="任务属性 (JSON格式)")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="tasks")

View File

@@ -1,10 +1,11 @@
from datetime import datetime
from enum import StrEnum
from typing import Literal, Optional, TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship
from .base import TableBase, SQLModelBase
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .group import Group
@@ -75,6 +76,19 @@ class LoginRequest(SQLModelBase):
"""两步验证代码"""
class RegisterRequest(SQLModelBase):
"""注册请求 DTO"""
username: str
"""用户名,唯一,一经注册不可更改"""
password: str
"""用户密码"""
captcha: str | None = None
"""验证码"""
class WebAuthnInfo(SQLModelBase):
"""WebAuthn 信息 DTO"""
@@ -116,8 +130,8 @@ class TokenResponse(SQLModelBase):
class UserResponse(UserBase):
"""用户响应 DTO"""
id: int
"""用户ID"""
id: UUID
"""用户UUID"""
nickname: str | None = None
"""用户昵称"""
@@ -141,8 +155,8 @@ class UserResponse(UserBase):
class UserPublic(UserBase):
"""用户公开信息 DTO用于 API 响应"""
id: int | None = None
"""用户ID"""
id: UUID | None = None
"""用户UUID"""
nick: str | None = None
"""昵称"""
@@ -156,8 +170,8 @@ class UserPublic(UserBase):
group_expires: datetime | None = None
"""用户组过期时间"""
group_id: int | None = None
"""所属用户组ID"""
group_id: UUID | None = None
"""所属用户组UUID"""
created_at: datetime | None = None
"""创建时间"""
@@ -187,8 +201,8 @@ class UserSettingResponse(SQLModelBase):
two_factor: bool = False
"""是否启用两步验证"""
uid: int = 0
"""用户UID"""
uid: UUID | None = None
"""用户UUID"""
# 前向引用导入
@@ -202,7 +216,7 @@ UserSettingResponse.model_rebuild()
# ==================== 数据库模型 ====================
class User(UserBase, TableBase, table=True):
class User(UserBase, UUIDTableBase, table=True):
"""用户模型"""
username: str = Field(max_length=50, unique=True, index=True)
@@ -243,11 +257,11 @@ class User(UserBase, TableBase, table=True):
"""时区UTC 偏移小时数"""
# 外键
group_id: int = Field(foreign_key="group.id", index=True)
"""所属用户组ID"""
group_id: UUID = Field(foreign_key="group.id", index=True)
"""所属用户组UUID"""
previous_group_id: int | None = Field(default=None, foreign_key="group.id")
"""之前的用户组ID用于过期后恢复"""
previous_group_id: UUID | None = Field(default=None, foreign_key="group.id")
"""之前的用户组UUID用于过期后恢复"""
# 关系

View File

@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import Column, Text
from sqlmodel import Field, Relationship
@@ -48,8 +49,8 @@ class UserAuthn(TableBase, table=True):
"""用户自定义的凭证名称,便于识别"""
# 外键
user_id: int = Field(foreign_key="user.id", index=True)
"""所属用户ID"""
user_id: UUID = Field(foreign_key="user.id", index=True)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="authns")

View File

@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
from .base import TableBase
if TYPE_CHECKING:
@@ -18,7 +21,7 @@ class WebDAV(TableBase, table=True):
use_proxy: bool = Field(default=False, description="是否使用代理下载")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="webdavs")

View File

@@ -37,7 +37,7 @@ async def router_directory_get(
:param path: 目录路径
:return: 目录内容
"""
folder = await Object.get_by_path(session, user.id, path or "/")
folder = await Object.get_by_path(session, user.id, path or "/", user.username)
if not folder:
raise HTTPException(status_code=404, detail="目录不存在")
@@ -50,7 +50,7 @@ async def router_directory_get(
objects = [
ObjectResponse(
id=str(child.id),
id=child.id,
name=child.name,
path=f"/{child.name}", # TODO: 完整路径
thumb=False,
@@ -63,7 +63,7 @@ async def router_directory_get(
for child in children
]
policy=PolicyResponse(
policy_response = PolicyResponse(
id=str(policy.id),
name=policy.name,
type=policy.type.value,
@@ -71,9 +71,10 @@ async def router_directory_get(
)
return DirectoryResponse(
parent=str(folder.parent_id) if folder.parent_id else None,
id=folder.id,
parent=folder.parent_id,
objects=objects,
policy=policy,
policy=policy_response,
)
@@ -91,26 +92,20 @@ async def router_directory_create(
:param session: 数据库会话
:param user: 当前登录用户
:param request: 创建请求
:param request: 创建请求(包含 parent_id UUID 和 name
:return: 创建结果
"""
path = request.path.strip()
if not path or path == "/":
raise HTTPException(status_code=400, detail="路径不能为空或根目录")
# 验证目录名称
name = request.name.strip()
if not name:
raise HTTPException(status_code=400, detail="目录名称不能为空")
# 解析路径
if path.startswith("/"):
path = path[1:]
parts = [p for p in path.split("/") if p]
if "/" in name or "\\" in name:
raise HTTPException(status_code=400, detail="目录名称不能包含斜杠")
if not parts:
raise HTTPException(status_code=400, detail="无效的目录路径")
new_folder_name = parts[-1]
parent_path = "/" + "/".join(parts[:-1]) if len(parts) > 1 else "/"
parent = await Object.get_by_path(session, user.id, parent_path)
if not parent:
# 通过 UUID 获取父目录
parent = await Object.get(session, Object.id == request.parent_id)
if not parent or parent.owner_id != user.id:
raise HTTPException(status_code=404, detail="父目录不存在")
if not parent.is_folder:
@@ -121,25 +116,29 @@ async def router_directory_create(
session,
(Object.owner_id == user.id) &
(Object.parent_id == parent.id) &
(Object.name == new_folder_name)
(Object.name == name)
)
if existing:
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
policy_id = request.policy_id if request.policy_id else parent.policy_id
parent_id = parent.id # 在 save 前保存
new_folder = await Object(
name=new_folder_name,
new_folder = Object(
name=name,
type=ObjectType.FOLDER,
owner_id=user.id,
parent_id=parent.id,
parent_id=parent_id,
policy_id=policy_id,
).save(session)
)
new_folder_id = new_folder.id # 在 save 前保存 UUID
new_folder_name = new_folder.name
await new_folder.save(session)
return response.ResponseModel(
data={
"id": new_folder.id,
"name": new_folder.name,
"path": f"{parent_path.rstrip('/')}/{new_folder_name}",
"id": new_folder_id,
"name": new_folder_name,
"parent_id": parent_id,
}
)

View File

@@ -1,5 +1,10 @@
from fastapi import APIRouter, Depends
from middleware.auth import SignRequired
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from middleware.auth import AuthRequired
from middleware.dependencies import SessionDep
from models import Object, ObjectDeleteRequest, ObjectMoveRequest, User
from models.response import ResponseModel
object_router = APIRouter(
@@ -7,41 +12,106 @@ object_router = APIRouter(
tags=["object"]
)
@object_router.delete(
path='/',
summary='删除对象',
description='Delete an object endpoint.',
dependencies=[Depends(SignRequired)]
description='删除一个或多个对象(文件或目录)',
)
def router_object_delete() -> ResponseModel:
async def router_object_delete(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
request: ObjectDeleteRequest,
) -> ResponseModel:
"""
Delete an object endpoint.
Returns:
ResponseModel: A model containing the response data for the object deletion.
删除对象端点
:param session: 数据库会话
:param user: 当前登录用户
:param request: 删除请求包含待删除对象的UUID列表
:return: 删除结果
"""
pass
deleted_count = 0
for obj_id in request.ids:
obj = await Object.get(session, Object.id == obj_id)
if obj and obj.owner_id == user.id:
# TODO: 递归删除子对象(如果是目录)
# TODO: 更新用户存储空间
await obj.delete(session)
deleted_count += 1
return ResponseModel(
data={
"deleted": deleted_count,
"total": len(request.ids),
}
)
@object_router.patch(
path='/',
summary='移动对象',
description='Move an object endpoint.',
dependencies=[Depends(SignRequired)]
description='移动一个或多个对象到目标目录',
)
def router_object_move() -> ResponseModel:
async def router_object_move(
session: SessionDep,
user: Annotated[User, Depends(AuthRequired)],
request: ObjectMoveRequest,
) -> ResponseModel:
"""
Move an object endpoint.
Returns:
ResponseModel: A model containing the response data for the object move.
移动对象端点
:param session: 数据库会话
:param user: 当前登录用户
:param request: 移动请求包含源对象UUID列表和目标目录UUID
:return: 移动结果
"""
pass
# 验证目标目录
dst = await Object.get(session, Object.id == request.dst_id)
if not dst or dst.owner_id != user.id:
raise HTTPException(status_code=404, detail="目标目录不存在")
if not dst.is_folder:
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
moved_count = 0
for src_id in request.src_ids:
src = await Object.get(session, Object.id == src_id)
if not src or src.owner_id != user.id:
continue
# 检查是否移动到自身或子目录(防止循环引用)
if src.id == dst.id:
continue
# 检查目标目录下是否存在同名对象
existing = await Object.get(
session,
(Object.owner_id == user.id) &
(Object.parent_id == dst.id) &
(Object.name == src.name)
)
if existing:
continue # 跳过重名对象
src.parent_id = dst.id
await src.save(session)
moved_count += 1
return ResponseModel(
data={
"moved": moved_count,
"total": len(request.src_ids),
}
)
@object_router.post(
path='/copy',
summary='复制对象',
description='Copy an object endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_object_copy() -> ResponseModel:
"""
@@ -56,7 +126,7 @@ def router_object_copy() -> ResponseModel:
path='/rename',
summary='重命名对象',
description='Rename an object endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_object_rename() -> ResponseModel:
"""
@@ -71,7 +141,7 @@ def router_object_rename() -> ResponseModel:
path='/property/{id}',
summary='获取对象属性',
description='Get object properties endpoint.',
dependencies=[Depends(SignRequired)]
dependencies=[Depends(AuthRequired)]
)
def router_object_property(id: str) -> ResponseModel:
"""

View File

@@ -1,11 +1,11 @@
from typing import Annotated, Literal
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy import and_
from webauthn import generate_registration_options
from webauthn.helpers import options_to_json_dict
import pyotp
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
import models
@@ -93,14 +93,77 @@ async def router_user_session(
summary='用户注册',
description='User registration endpoint.',
)
def router_user_register() -> models.response.ResponseModel:
async def router_user_register(
session: SessionDep,
request: models.RegisterRequest,
) -> models.response.ResponseModel:
"""
User registration endpoint.
Returns:
dict: A dictionary containing user registration information.
用户注册端点
流程:
1. 验证用户名唯一性
2. 获取默认用户组
3. 创建用户记录
4. 创建以用户名命名的根目录
:param session: 数据库会话
:param request: 注册请求
:return: 注册结果
:raises HTTPException 400: 用户名已存在
:raises HTTPException 500: 默认用户组或存储策略不存在
"""
pass
# 1. 验证用户名唯一性
existing_user = await models.User.get(
session,
models.User.username == request.username
)
if existing_user:
raise HTTPException(status_code=400, detail="用户名已存在")
# 2. 获取默认用户组(从设置中读取 UUID
default_group_setting: models.Setting | None = await models.Setting.get(
session,
and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group")
)
if default_group_setting is None or not default_group_setting.value:
raise HTTPException(status_code=500, detail="默认用户组设置不存在")
default_group_id = UUID(default_group_setting.value)
default_group = await models.Group.get(session, models.Group.id == default_group_id)
if not default_group:
raise HTTPException(status_code=500, detail="默认用户组不存在")
# 3. 创建用户
hashed_password = Password.hash(request.password)
new_user = models.User(
username=request.username,
password=hashed_password,
group_id=default_group.id,
)
new_user_id = new_user.id # 在 save 前保存 UUID
new_user_username = new_user.username
await new_user.save(session)
# 4. 创建以用户名命名的根目录
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储")
if not default_policy:
raise HTTPException(status_code=500, detail="默认存储策略不存在")
await models.Object(
name=new_user_username,
type=models.ObjectType.FOLDER,
owner_id=new_user_id,
parent_id=None,
policy_id=default_policy.id,
).save(session)
return models.response.ResponseModel(
data={
"user_id": new_user_id,
"username": new_user_username,
},
msg="注册成功",
)
@user_router.post(
path='/code',