feat: 更新模型以支持 UUID,添加注册请求 DTO,重构用户注册逻辑
This commit is contained in:
@@ -2,6 +2,7 @@ from . import response
|
||||
|
||||
from .user import (
|
||||
LoginRequest,
|
||||
RegisterRequest,
|
||||
TokenResponse,
|
||||
User,
|
||||
UserBase,
|
||||
@@ -21,6 +22,8 @@ from .object import (
|
||||
DirectoryResponse,
|
||||
Object,
|
||||
ObjectBase,
|
||||
ObjectDeleteRequest,
|
||||
ObjectMoveRequest,
|
||||
ObjectResponse,
|
||||
ObjectType,
|
||||
PolicyResponse,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -31,7 +34,7 @@ class Download(DownloadBase, UUIDTableBase, table=True):
|
||||
dst: str = Field(description="目标存储路径")
|
||||
|
||||
# 外键
|
||||
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
|
||||
task_id: int | None = Field(default=None, foreign_key="task.id", index=True, description="关联的任务ID")
|
||||
node_id: int = Field(foreign_key="node.id", index=True, description="执行下载的节点ID")
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, text
|
||||
|
||||
from .base import TableBase, SQLModelBase
|
||||
from .base import TableBase, SQLModelBase, UUIDTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -45,8 +46,8 @@ class GroupOptionsBase(SQLModelBase):
|
||||
class GroupResponse(GroupBase, GroupOptionsBase):
|
||||
"""用户组响应 DTO"""
|
||||
|
||||
id: int
|
||||
"""用户组ID"""
|
||||
id: UUID
|
||||
"""用户组UUID"""
|
||||
|
||||
allow_share: bool = False
|
||||
"""是否允许分享"""
|
||||
@@ -72,8 +73,8 @@ class GroupResponse(GroupBase, GroupOptionsBase):
|
||||
class GroupOptions(GroupOptionsBase, TableBase, table=True):
|
||||
"""用户组选项模型"""
|
||||
|
||||
group_id: int = Field(foreign_key="group.id", unique=True)
|
||||
"""关联的用户组ID"""
|
||||
group_id: UUID = Field(foreign_key="group.id", unique=True)
|
||||
"""关联的用户组UUID"""
|
||||
|
||||
archive_download: bool = False
|
||||
"""是否允许打包下载"""
|
||||
@@ -97,7 +98,7 @@ class GroupOptions(GroupOptionsBase, TableBase, table=True):
|
||||
group: "Group" = Relationship(back_populates="options")
|
||||
|
||||
|
||||
class Group(GroupBase, TableBase, table=True):
|
||||
class Group(GroupBase, UUIDTableBase, table=True):
|
||||
"""用户组模型"""
|
||||
|
||||
name: str = Field(max_length=255, unique=True)
|
||||
|
||||
@@ -25,7 +25,7 @@ default_settings: list[Setting] = [
|
||||
Setting(name="siteURL", value="http://localhost", type=SettingsType.BASIC),
|
||||
Setting(name="siteName", value="DiskNext", type=SettingsType.BASIC),
|
||||
Setting(name="register_enabled", value="1", type=SettingsType.REGISTER),
|
||||
Setting(name="default_group", value="2", type=SettingsType.REGISTER),
|
||||
Setting(name="default_group", value="", type=SettingsType.REGISTER), # UUID 在组创建后更新
|
||||
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
|
||||
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
|
||||
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
|
||||
@@ -138,24 +138,27 @@ async def init_default_settings() -> None:
|
||||
|
||||
async def init_default_group() -> None:
|
||||
from .group import Group, GroupOptions
|
||||
from .setting import Setting
|
||||
from .database import get_session
|
||||
|
||||
log.info('初始化用户组...')
|
||||
|
||||
async for session in get_session():
|
||||
# 未找到初始管理组时,则创建
|
||||
if not await Group.get(session, Group.id == 1):
|
||||
admin_group = await Group(
|
||||
if not await Group.get(session, Group.name == "管理员"):
|
||||
admin_group = Group(
|
||||
name="管理员",
|
||||
policies="1",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
web_dav_enabled=True,
|
||||
admin=True,
|
||||
).save(session)
|
||||
assert admin_group.id is not None
|
||||
)
|
||||
admin_group_id = admin_group.id # 在 save 前保存 UUID
|
||||
await admin_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=admin_group.id,
|
||||
group_id=admin_group_id,
|
||||
archive_download=True,
|
||||
archive_task=True,
|
||||
share_download=True,
|
||||
@@ -166,30 +169,40 @@ async def init_default_group() -> None:
|
||||
).save(session)
|
||||
|
||||
# 未找到初始注册会员时,则创建
|
||||
if not await Group.get(session, Group.id == 2):
|
||||
member_group = await Group(
|
||||
if not await Group.get(session, Group.name == "注册会员"):
|
||||
member_group = Group(
|
||||
name="注册会员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
web_dav_enabled=True,
|
||||
).save(session)
|
||||
assert member_group.id is not None
|
||||
)
|
||||
member_group_id = member_group.id # 在 save 前保存 UUID
|
||||
await member_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=member_group.id,
|
||||
group_id=member_group_id,
|
||||
share_download=True,
|
||||
).save(session)
|
||||
|
||||
# 更新 default_group 设置为注册会员组的 UUID
|
||||
default_group_setting = await Setting.get(session, Setting.name == "default_group")
|
||||
if default_group_setting:
|
||||
default_group_setting.value = str(member_group_id)
|
||||
await default_group_setting.save(session)
|
||||
|
||||
# 未找到初始游客组时,则创建
|
||||
if not await Group.get(session, Group.id == 3):
|
||||
guest_group = await Group(
|
||||
if not await Group.get(session, Group.name == "游客"):
|
||||
guest_group = Group(
|
||||
name="游客",
|
||||
policies="[]",
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
).save(session)
|
||||
assert guest_group.id is not None
|
||||
)
|
||||
guest_group_id = guest_group.id # 在 save 前保存 UUID
|
||||
await guest_group.save(session)
|
||||
|
||||
await GroupOptions(
|
||||
group_id=guest_group.id,
|
||||
group_id=guest_group_id,
|
||||
share_download=True,
|
||||
).save(session)
|
||||
|
||||
@@ -203,11 +216,11 @@ async def init_default_user() -> None:
|
||||
|
||||
async for session in get_session():
|
||||
# 检查管理员用户是否存在
|
||||
admin_user = await User.get(session, User.id == 1)
|
||||
admin_user = await User.get(session, User.username == "admin")
|
||||
|
||||
if not admin_user:
|
||||
# 获取管理员组
|
||||
admin_group = await Group.get(session, Group.id == 1)
|
||||
admin_group = await Group.get(session, Group.name == "管理员")
|
||||
if not admin_group:
|
||||
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
|
||||
|
||||
@@ -215,19 +228,22 @@ async def init_default_user() -> None:
|
||||
admin_password = Password.generate(8)
|
||||
hashed_admin_password = Password.hash(admin_password)
|
||||
|
||||
admin_user = await User(
|
||||
admin_user = User(
|
||||
username="admin",
|
||||
nick="admin",
|
||||
nickname="admin",
|
||||
status=True,
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
).save(session)
|
||||
)
|
||||
admin_user_id = admin_user.id # 在 save 前保存 UUID
|
||||
admin_username = admin_user.username
|
||||
await admin_user.save(session)
|
||||
|
||||
# 为管理员创建根目录(使用默认存储策略)
|
||||
# 为管理员创建根目录(使用用户名作为目录名)
|
||||
await Object(
|
||||
name="my",
|
||||
name=admin_username,
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=admin_user.id,
|
||||
owner_id=admin_user_id,
|
||||
parent_id=None,
|
||||
policy_id=1, # 默认本地存储策略
|
||||
).save(session)
|
||||
@@ -244,7 +260,7 @@ async def init_default_policy() -> None:
|
||||
|
||||
async for session in get_session():
|
||||
# 检查默认存储策略是否存在
|
||||
default_policy = await Policy.get(session, Policy.id == 1)
|
||||
default_policy = await Policy.get(session, Policy.name == "本地存储")
|
||||
|
||||
if not default_policy:
|
||||
local_policy = Policy(
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from enum import StrEnum
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index
|
||||
|
||||
from .base import TableBase, SQLModelBase
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -40,18 +41,38 @@ class ObjectBase(SQLModelBase):
|
||||
class DirectoryCreateRequest(SQLModelBase):
|
||||
"""创建目录请求 DTO"""
|
||||
|
||||
path: str
|
||||
"""目录路径,如 /docs/images"""
|
||||
parent_id: UUID
|
||||
"""父目录UUID"""
|
||||
|
||||
name: str
|
||||
"""目录名称"""
|
||||
|
||||
policy_id: int | None = None
|
||||
"""存储策略ID,不指定则继承父目录"""
|
||||
|
||||
|
||||
class ObjectMoveRequest(SQLModelBase):
|
||||
"""移动对象请求 DTO"""
|
||||
|
||||
src_ids: list[UUID]
|
||||
"""源对象UUID列表"""
|
||||
|
||||
dst_id: UUID
|
||||
"""目标文件夹UUID"""
|
||||
|
||||
|
||||
class ObjectDeleteRequest(SQLModelBase):
|
||||
"""删除对象请求 DTO"""
|
||||
|
||||
ids: list[UUID]
|
||||
"""待删除对象UUID列表"""
|
||||
|
||||
|
||||
class ObjectResponse(ObjectBase):
|
||||
"""对象响应 DTO"""
|
||||
|
||||
id: str
|
||||
"""对象ID"""
|
||||
id: UUID
|
||||
"""对象UUID"""
|
||||
|
||||
path: str
|
||||
"""对象路径"""
|
||||
@@ -91,8 +112,11 @@ class PolicyResponse(SQLModelBase):
|
||||
class DirectoryResponse(SQLModelBase):
|
||||
"""目录响应 DTO"""
|
||||
|
||||
parent: str | None = None
|
||||
"""父目录ID,根目录为None"""
|
||||
id: UUID
|
||||
"""当前目录UUID"""
|
||||
|
||||
parent: UUID | None = None
|
||||
"""父目录UUID,根目录为None"""
|
||||
|
||||
objects: list[ObjectResponse] = []
|
||||
"""目录下的对象列表"""
|
||||
@@ -103,16 +127,17 @@ class DirectoryResponse(SQLModelBase):
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class Object(ObjectBase, TableBase, table=True):
|
||||
class Object(ObjectBase, UUIDTableBase, table=True):
|
||||
"""
|
||||
统一对象模型
|
||||
|
||||
合并了原有的 File 和 Folder 模型,通过 type 字段区分文件和目录。
|
||||
|
||||
根目录规则:
|
||||
- 每个用户有一个显式根目录对象(name="my", parent_id=NULL)
|
||||
- 每个用户有一个显式根目录对象(name=用户的username, parent_id=NULL)
|
||||
- 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id
|
||||
- 根目录的 policy_id 指定用户默认存储策略
|
||||
- 路径格式:/username/path/to/file(如 /admin/docs/readme.md)
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
@@ -158,11 +183,11 @@ class Object(ObjectBase, TableBase, table=True):
|
||||
|
||||
# ==================== 外键 ====================
|
||||
|
||||
parent_id: int | None = Field(default=None, foreign_key="object.id", index=True)
|
||||
"""父目录ID,NULL 表示这是用户的根目录"""
|
||||
parent_id: UUID | None = Field(default=None, foreign_key="object.id", index=True)
|
||||
"""父目录UUID,NULL 表示这是用户的根目录"""
|
||||
|
||||
owner_id: int = Field(foreign_key="user.id", index=True)
|
||||
"""所有者用户ID"""
|
||||
owner_id: UUID = Field(foreign_key="user.id", index=True)
|
||||
"""所有者用户UUID"""
|
||||
|
||||
policy_id: int = Field(foreign_key="policy.id", index=True)
|
||||
"""存储策略ID(文件直接使用,目录作为子文件的默认策略)"""
|
||||
@@ -207,12 +232,12 @@ class Object(ObjectBase, TableBase, table=True):
|
||||
# ==================== 业务方法 ====================
|
||||
|
||||
@classmethod
|
||||
async def get_root(cls, session, user_id: int) -> "Object | None":
|
||||
async def get_root(cls, session, user_id: UUID) -> "Object | None":
|
||||
"""
|
||||
获取用户的根目录
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户ID
|
||||
:param user_id: 用户UUID
|
||||
:return: 根目录对象,不存在则返回 None
|
||||
"""
|
||||
return await cls.get(
|
||||
@@ -221,33 +246,51 @@ class Object(ObjectBase, TableBase, table=True):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def get_by_path(cls, session, user_id: int, path: str) -> "Object | None":
|
||||
async def get_by_path(
|
||||
cls,
|
||||
session,
|
||||
user_id: UUID,
|
||||
path: str,
|
||||
username: str,
|
||||
) -> "Object | None":
|
||||
"""
|
||||
根据路径获取对象
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户ID
|
||||
:param path: 路径,如 "/" 或 "/docs/images"
|
||||
:param user_id: 用户UUID
|
||||
:param path: 路径,如 "/username" 或 "/username/docs/images"
|
||||
:param username: 用户名,用于识别根目录
|
||||
:return: Object 或 None
|
||||
"""
|
||||
path = path.strip()
|
||||
if not path:
|
||||
raise ValueError("路径不能为空")
|
||||
|
||||
if path in ["/my"]:
|
||||
return await cls.get_root(session, user_id)
|
||||
# 获取用户根目录
|
||||
root = await cls.get_root(session, user_id)
|
||||
if not root:
|
||||
return None
|
||||
|
||||
# 移除开头的斜杠并分割路径
|
||||
if path.startswith("/"):
|
||||
path = path[1:]
|
||||
parts = [p for p in path.split("/") if p]
|
||||
|
||||
# 空路径 -> 返回根目录
|
||||
if not parts:
|
||||
return await cls.get_root(session, user_id)
|
||||
return root
|
||||
|
||||
# 从根目录开始遍历
|
||||
current = await cls.get_root(session, user_id)
|
||||
# 检查第一部分是否是用户名(根目录名)
|
||||
if parts[0] == username:
|
||||
# 路径以用户名开头,如 /admin/docs
|
||||
if len(parts) == 1:
|
||||
# 只有用户名,返回根目录
|
||||
return root
|
||||
# 去掉用户名部分,从第二个部分开始遍历
|
||||
parts = parts[1:]
|
||||
|
||||
# 从根目录开始遍历剩余路径
|
||||
current = root
|
||||
for part in parts:
|
||||
if not current:
|
||||
return None
|
||||
@@ -262,13 +305,13 @@ class Object(ObjectBase, TableBase, table=True):
|
||||
return current
|
||||
|
||||
@classmethod
|
||||
async def get_children(cls, session, user_id: int, parent_id: int) -> list["Object"]:
|
||||
async def get_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
|
||||
"""
|
||||
获取目录下的所有子对象
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user_id: 用户ID
|
||||
:param parent_id: 父目录ID
|
||||
:param user_id: 用户UUID
|
||||
:param parent_id: 父目录UUID
|
||||
:return: 子对象列表
|
||||
"""
|
||||
return await cls.get(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import TableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -19,7 +22,7 @@ class Order(TableBase, table=True):
|
||||
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单状态: 0=待支付, 1=已完成, 2=已取消")
|
||||
|
||||
# 外键
|
||||
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="orders")
|
||||
@@ -1,7 +1,10 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
|
||||
|
||||
from .base import TableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -26,8 +29,8 @@ class Share(TableBase, table=True):
|
||||
password: str | None = Field(default=None, max_length=255)
|
||||
"""分享密码(加密后)"""
|
||||
|
||||
object_id: int = Field(foreign_key="object.id", index=True)
|
||||
"""关联的对象ID"""
|
||||
object_id: UUID = Field(foreign_key="object.id", index=True)
|
||||
"""关联的对象UUID"""
|
||||
|
||||
views: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||
"""浏览次数"""
|
||||
@@ -51,8 +54,8 @@ class Share(TableBase, table=True):
|
||||
"""兑换此分享所需的积分"""
|
||||
|
||||
# 外键
|
||||
user_id: int = Field(foreign_key="user.id", index=True)
|
||||
"""创建分享的用户ID"""
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True)
|
||||
"""创建分享的用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="shares")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Index
|
||||
|
||||
from .base import TableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -21,8 +24,8 @@ class SourceLink(TableBase, table=True):
|
||||
"""通过此链接的下载次数"""
|
||||
|
||||
# 外键
|
||||
object_id: int = Field(foreign_key="object.id", index=True)
|
||||
"""关联的对象ID(必须是文件类型)"""
|
||||
object_id: UUID = Field(foreign_key="object.id", index=True)
|
||||
"""关联的对象UUID(必须是文件类型)"""
|
||||
|
||||
# 关系
|
||||
object: "Object" = Relationship(back_populates="source_links")
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||
|
||||
from .base import TableBase
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
@@ -17,7 +19,7 @@ class StoragePack(TableBase, table=True):
|
||||
size: int = Field(description="容量包大小(字节)")
|
||||
|
||||
# 外键
|
||||
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="storage_packs")
|
||||
@@ -1,9 +1,12 @@
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
|
||||
from .base import TableBase
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
|
||||
|
||||
from .base import TableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
@@ -19,7 +22,7 @@ class Tag(TableBase, table=True):
|
||||
expression: str | None = Field(default=None, description="自动标签的匹配表达式")
|
||||
|
||||
# 外键
|
||||
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="tags")
|
||||
@@ -1,9 +1,12 @@
|
||||
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, CheckConstraint
|
||||
from .base import TableBase
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from sqlmodel import Field, Relationship, CheckConstraint
|
||||
|
||||
from .base import TableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
from .download import Download
|
||||
@@ -22,7 +25,7 @@ class Task(TableBase, table=True):
|
||||
props: str | None = Field(default=None, description="任务属性 (JSON格式)")
|
||||
|
||||
# 外键
|
||||
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="tasks")
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import TableBase, SQLModelBase
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .group import Group
|
||||
@@ -75,6 +76,19 @@ class LoginRequest(SQLModelBase):
|
||||
"""两步验证代码"""
|
||||
|
||||
|
||||
class RegisterRequest(SQLModelBase):
|
||||
"""注册请求 DTO"""
|
||||
|
||||
username: str
|
||||
"""用户名,唯一,一经注册不可更改"""
|
||||
|
||||
password: str
|
||||
"""用户密码"""
|
||||
|
||||
captcha: str | None = None
|
||||
"""验证码"""
|
||||
|
||||
|
||||
class WebAuthnInfo(SQLModelBase):
|
||||
"""WebAuthn 信息 DTO"""
|
||||
|
||||
@@ -116,8 +130,8 @@ class TokenResponse(SQLModelBase):
|
||||
class UserResponse(UserBase):
|
||||
"""用户响应 DTO"""
|
||||
|
||||
id: int
|
||||
"""用户ID"""
|
||||
id: UUID
|
||||
"""用户UUID"""
|
||||
|
||||
nickname: str | None = None
|
||||
"""用户昵称"""
|
||||
@@ -141,8 +155,8 @@ class UserResponse(UserBase):
|
||||
class UserPublic(UserBase):
|
||||
"""用户公开信息 DTO,用于 API 响应"""
|
||||
|
||||
id: int | None = None
|
||||
"""用户ID"""
|
||||
id: UUID | None = None
|
||||
"""用户UUID"""
|
||||
|
||||
nick: str | None = None
|
||||
"""昵称"""
|
||||
@@ -156,8 +170,8 @@ class UserPublic(UserBase):
|
||||
group_expires: datetime | None = None
|
||||
"""用户组过期时间"""
|
||||
|
||||
group_id: int | None = None
|
||||
"""所属用户组ID"""
|
||||
group_id: UUID | None = None
|
||||
"""所属用户组UUID"""
|
||||
|
||||
created_at: datetime | None = None
|
||||
"""创建时间"""
|
||||
@@ -187,8 +201,8 @@ class UserSettingResponse(SQLModelBase):
|
||||
two_factor: bool = False
|
||||
"""是否启用两步验证"""
|
||||
|
||||
uid: int = 0
|
||||
"""用户UID"""
|
||||
uid: UUID | None = None
|
||||
"""用户UUID"""
|
||||
|
||||
|
||||
# 前向引用导入
|
||||
@@ -202,7 +216,7 @@ UserSettingResponse.model_rebuild()
|
||||
|
||||
# ==================== 数据库模型 ====================
|
||||
|
||||
class User(UserBase, TableBase, table=True):
|
||||
class User(UserBase, UUIDTableBase, table=True):
|
||||
"""用户模型"""
|
||||
|
||||
username: str = Field(max_length=50, unique=True, index=True)
|
||||
@@ -243,11 +257,11 @@ class User(UserBase, TableBase, table=True):
|
||||
"""时区,UTC 偏移小时数"""
|
||||
|
||||
# 外键
|
||||
group_id: int = Field(foreign_key="group.id", index=True)
|
||||
"""所属用户组ID"""
|
||||
group_id: UUID = Field(foreign_key="group.id", index=True)
|
||||
"""所属用户组UUID"""
|
||||
|
||||
previous_group_id: int | None = Field(default=None, foreign_key="group.id")
|
||||
"""之前的用户组ID(用于过期后恢复)"""
|
||||
previous_group_id: UUID | None = Field(default=None, foreign_key="group.id")
|
||||
"""之前的用户组UUID(用于过期后恢复)"""
|
||||
|
||||
|
||||
# 关系
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import Column, Text
|
||||
from sqlmodel import Field, Relationship
|
||||
@@ -48,8 +49,8 @@ class UserAuthn(TableBase, table=True):
|
||||
"""用户自定义的凭证名称,便于识别"""
|
||||
|
||||
# 外键
|
||||
user_id: int = Field(foreign_key="user.id", index=True)
|
||||
"""所属用户ID"""
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True)
|
||||
"""所属用户UUID"""
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="authns")
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
|
||||
|
||||
from .base import TableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -18,7 +21,7 @@ class WebDAV(TableBase, table=True):
|
||||
use_proxy: bool = Field(default=False, description="是否使用代理下载")
|
||||
|
||||
# 外键
|
||||
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
|
||||
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
|
||||
|
||||
# 关系
|
||||
user: "User" = Relationship(back_populates="webdavs")
|
||||
@@ -37,7 +37,7 @@ async def router_directory_get(
|
||||
:param path: 目录路径
|
||||
:return: 目录内容
|
||||
"""
|
||||
folder = await Object.get_by_path(session, user.id, path or "/")
|
||||
folder = await Object.get_by_path(session, user.id, path or "/", user.username)
|
||||
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="目录不存在")
|
||||
@@ -50,7 +50,7 @@ async def router_directory_get(
|
||||
|
||||
objects = [
|
||||
ObjectResponse(
|
||||
id=str(child.id),
|
||||
id=child.id,
|
||||
name=child.name,
|
||||
path=f"/{child.name}", # TODO: 完整路径
|
||||
thumb=False,
|
||||
@@ -63,7 +63,7 @@ async def router_directory_get(
|
||||
for child in children
|
||||
]
|
||||
|
||||
policy=PolicyResponse(
|
||||
policy_response = PolicyResponse(
|
||||
id=str(policy.id),
|
||||
name=policy.name,
|
||||
type=policy.type.value,
|
||||
@@ -71,9 +71,10 @@ async def router_directory_get(
|
||||
)
|
||||
|
||||
return DirectoryResponse(
|
||||
parent=str(folder.parent_id) if folder.parent_id else None,
|
||||
id=folder.id,
|
||||
parent=folder.parent_id,
|
||||
objects=objects,
|
||||
policy=policy,
|
||||
policy=policy_response,
|
||||
)
|
||||
|
||||
|
||||
@@ -91,26 +92,20 @@ async def router_directory_create(
|
||||
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param request: 创建请求
|
||||
:param request: 创建请求(包含 parent_id UUID 和 name)
|
||||
:return: 创建结果
|
||||
"""
|
||||
path = request.path.strip()
|
||||
if not path or path == "/":
|
||||
raise HTTPException(status_code=400, detail="路径不能为空或根目录")
|
||||
# 验证目录名称
|
||||
name = request.name.strip()
|
||||
if not name:
|
||||
raise HTTPException(status_code=400, detail="目录名称不能为空")
|
||||
|
||||
# 解析路径
|
||||
if path.startswith("/"):
|
||||
path = path[1:]
|
||||
parts = [p for p in path.split("/") if p]
|
||||
if "/" in name or "\\" in name:
|
||||
raise HTTPException(status_code=400, detail="目录名称不能包含斜杠")
|
||||
|
||||
if not parts:
|
||||
raise HTTPException(status_code=400, detail="无效的目录路径")
|
||||
|
||||
new_folder_name = parts[-1]
|
||||
parent_path = "/" + "/".join(parts[:-1]) if len(parts) > 1 else "/"
|
||||
|
||||
parent = await Object.get_by_path(session, user.id, parent_path)
|
||||
if not parent:
|
||||
# 通过 UUID 获取父目录
|
||||
parent = await Object.get(session, Object.id == request.parent_id)
|
||||
if not parent or parent.owner_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="父目录不存在")
|
||||
|
||||
if not parent.is_folder:
|
||||
@@ -121,25 +116,29 @@ async def router_directory_create(
|
||||
session,
|
||||
(Object.owner_id == user.id) &
|
||||
(Object.parent_id == parent.id) &
|
||||
(Object.name == new_folder_name)
|
||||
(Object.name == name)
|
||||
)
|
||||
if existing:
|
||||
raise HTTPException(status_code=409, detail="同名文件或目录已存在")
|
||||
|
||||
policy_id = request.policy_id if request.policy_id else parent.policy_id
|
||||
parent_id = parent.id # 在 save 前保存
|
||||
|
||||
new_folder = await Object(
|
||||
name=new_folder_name,
|
||||
new_folder = Object(
|
||||
name=name,
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=user.id,
|
||||
parent_id=parent.id,
|
||||
parent_id=parent_id,
|
||||
policy_id=policy_id,
|
||||
).save(session)
|
||||
)
|
||||
new_folder_id = new_folder.id # 在 save 前保存 UUID
|
||||
new_folder_name = new_folder.name
|
||||
await new_folder.save(session)
|
||||
|
||||
return response.ResponseModel(
|
||||
data={
|
||||
"id": new_folder.id,
|
||||
"name": new_folder.name,
|
||||
"path": f"{parent_path.rstrip('/')}/{new_folder_name}",
|
||||
"id": new_folder_id,
|
||||
"name": new_folder_name,
|
||||
"parent_id": parent_id,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from middleware.auth import SignRequired
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from middleware.auth import AuthRequired
|
||||
from middleware.dependencies import SessionDep
|
||||
from models import Object, ObjectDeleteRequest, ObjectMoveRequest, User
|
||||
from models.response import ResponseModel
|
||||
|
||||
object_router = APIRouter(
|
||||
@@ -7,41 +12,106 @@ object_router = APIRouter(
|
||||
tags=["object"]
|
||||
)
|
||||
|
||||
|
||||
@object_router.delete(
|
||||
path='/',
|
||||
summary='删除对象',
|
||||
description='Delete an object endpoint.',
|
||||
dependencies=[Depends(SignRequired)]
|
||||
description='删除一个或多个对象(文件或目录)',
|
||||
)
|
||||
def router_object_delete() -> ResponseModel:
|
||||
async def router_object_delete(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
request: ObjectDeleteRequest,
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Delete an object endpoint.
|
||||
删除对象端点
|
||||
|
||||
Returns:
|
||||
ResponseModel: A model containing the response data for the object deletion.
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param request: 删除请求(包含待删除对象的UUID列表)
|
||||
:return: 删除结果
|
||||
"""
|
||||
pass
|
||||
deleted_count = 0
|
||||
|
||||
for obj_id in request.ids:
|
||||
obj = await Object.get(session, Object.id == obj_id)
|
||||
if obj and obj.owner_id == user.id:
|
||||
# TODO: 递归删除子对象(如果是目录)
|
||||
# TODO: 更新用户存储空间
|
||||
await obj.delete(session)
|
||||
deleted_count += 1
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"deleted": deleted_count,
|
||||
"total": len(request.ids),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@object_router.patch(
|
||||
path='/',
|
||||
summary='移动对象',
|
||||
description='Move an object endpoint.',
|
||||
dependencies=[Depends(SignRequired)]
|
||||
description='移动一个或多个对象到目标目录',
|
||||
)
|
||||
def router_object_move() -> ResponseModel:
|
||||
async def router_object_move(
|
||||
session: SessionDep,
|
||||
user: Annotated[User, Depends(AuthRequired)],
|
||||
request: ObjectMoveRequest,
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Move an object endpoint.
|
||||
移动对象端点
|
||||
|
||||
Returns:
|
||||
ResponseModel: A model containing the response data for the object move.
|
||||
:param session: 数据库会话
|
||||
:param user: 当前登录用户
|
||||
:param request: 移动请求(包含源对象UUID列表和目标目录UUID)
|
||||
:return: 移动结果
|
||||
"""
|
||||
pass
|
||||
# 验证目标目录
|
||||
dst = await Object.get(session, Object.id == request.dst_id)
|
||||
if not dst or dst.owner_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="目标目录不存在")
|
||||
|
||||
if not dst.is_folder:
|
||||
raise HTTPException(status_code=400, detail="目标不是有效文件夹")
|
||||
|
||||
moved_count = 0
|
||||
|
||||
for src_id in request.src_ids:
|
||||
src = await Object.get(session, Object.id == src_id)
|
||||
if not src or src.owner_id != user.id:
|
||||
continue
|
||||
|
||||
# 检查是否移动到自身或子目录(防止循环引用)
|
||||
if src.id == dst.id:
|
||||
continue
|
||||
|
||||
# 检查目标目录下是否存在同名对象
|
||||
existing = await Object.get(
|
||||
session,
|
||||
(Object.owner_id == user.id) &
|
||||
(Object.parent_id == dst.id) &
|
||||
(Object.name == src.name)
|
||||
)
|
||||
if existing:
|
||||
continue # 跳过重名对象
|
||||
|
||||
src.parent_id = dst.id
|
||||
await src.save(session)
|
||||
moved_count += 1
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"moved": moved_count,
|
||||
"total": len(request.src_ids),
|
||||
}
|
||||
)
|
||||
|
||||
@object_router.post(
|
||||
path='/copy',
|
||||
summary='复制对象',
|
||||
description='Copy an object endpoint.',
|
||||
dependencies=[Depends(SignRequired)]
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
)
|
||||
def router_object_copy() -> ResponseModel:
|
||||
"""
|
||||
@@ -56,7 +126,7 @@ def router_object_copy() -> ResponseModel:
|
||||
path='/rename',
|
||||
summary='重命名对象',
|
||||
description='Rename an object endpoint.',
|
||||
dependencies=[Depends(SignRequired)]
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
)
|
||||
def router_object_rename() -> ResponseModel:
|
||||
"""
|
||||
@@ -71,7 +141,7 @@ def router_object_rename() -> ResponseModel:
|
||||
path='/property/{id}',
|
||||
summary='获取对象属性',
|
||||
description='Get object properties endpoint.',
|
||||
dependencies=[Depends(SignRequired)]
|
||||
dependencies=[Depends(AuthRequired)]
|
||||
)
|
||||
def router_object_property(id: str) -> ResponseModel:
|
||||
"""
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Annotated, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
from sqlalchemy import and_
|
||||
from webauthn import generate_registration_options
|
||||
from webauthn.helpers import options_to_json_dict
|
||||
import pyotp
|
||||
from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
|
||||
|
||||
import models
|
||||
@@ -93,14 +93,77 @@ async def router_user_session(
|
||||
summary='用户注册',
|
||||
description='User registration endpoint.',
|
||||
)
|
||||
def router_user_register() -> models.response.ResponseModel:
|
||||
async def router_user_register(
|
||||
session: SessionDep,
|
||||
request: models.RegisterRequest,
|
||||
) -> models.response.ResponseModel:
|
||||
"""
|
||||
User registration endpoint.
|
||||
用户注册端点
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing user registration information.
|
||||
流程:
|
||||
1. 验证用户名唯一性
|
||||
2. 获取默认用户组
|
||||
3. 创建用户记录
|
||||
4. 创建以用户名命名的根目录
|
||||
|
||||
:param session: 数据库会话
|
||||
:param request: 注册请求
|
||||
:return: 注册结果
|
||||
:raises HTTPException 400: 用户名已存在
|
||||
:raises HTTPException 500: 默认用户组或存储策略不存在
|
||||
"""
|
||||
pass
|
||||
# 1. 验证用户名唯一性
|
||||
existing_user = await models.User.get(
|
||||
session,
|
||||
models.User.username == request.username
|
||||
)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="用户名已存在")
|
||||
|
||||
# 2. 获取默认用户组(从设置中读取 UUID)
|
||||
default_group_setting: models.Setting | None = await models.Setting.get(
|
||||
session,
|
||||
and_(models.Setting.type == models.SettingsType.REGISTER, models.Setting.name == "default_group")
|
||||
)
|
||||
if default_group_setting is None or not default_group_setting.value:
|
||||
raise HTTPException(status_code=500, detail="默认用户组设置不存在")
|
||||
|
||||
default_group_id = UUID(default_group_setting.value)
|
||||
default_group = await models.Group.get(session, models.Group.id == default_group_id)
|
||||
if not default_group:
|
||||
raise HTTPException(status_code=500, detail="默认用户组不存在")
|
||||
|
||||
# 3. 创建用户
|
||||
hashed_password = Password.hash(request.password)
|
||||
new_user = models.User(
|
||||
username=request.username,
|
||||
password=hashed_password,
|
||||
group_id=default_group.id,
|
||||
)
|
||||
new_user_id = new_user.id # 在 save 前保存 UUID
|
||||
new_user_username = new_user.username
|
||||
await new_user.save(session)
|
||||
|
||||
# 4. 创建以用户名命名的根目录
|
||||
default_policy = await models.Policy.get(session, models.Policy.name == "本地存储")
|
||||
if not default_policy:
|
||||
raise HTTPException(status_code=500, detail="默认存储策略不存在")
|
||||
|
||||
await models.Object(
|
||||
name=new_user_username,
|
||||
type=models.ObjectType.FOLDER,
|
||||
owner_id=new_user_id,
|
||||
parent_id=None,
|
||||
policy_id=default_policy.id,
|
||||
).save(session)
|
||||
|
||||
return models.response.ResponseModel(
|
||||
data={
|
||||
"user_id": new_user_id,
|
||||
"username": new_user_username,
|
||||
},
|
||||
msg="注册成功",
|
||||
)
|
||||
|
||||
@user_router.post(
|
||||
path='/code',
|
||||
|
||||
Reference in New Issue
Block a user