feat: 更新模型以支持 UUID,添加注册请求 DTO,重构用户注册逻辑

This commit is contained in:
2025-12-19 16:32:49 +08:00
parent e031f3cc40
commit 922692b820
17 changed files with 380 additions and 147 deletions

View File

@@ -2,6 +2,7 @@ from . import response
from .user import (
LoginRequest,
RegisterRequest,
TokenResponse,
User,
UserBase,
@@ -21,6 +22,8 @@ from .object import (
DirectoryResponse,
Object,
ObjectBase,
ObjectDeleteRequest,
ObjectMoveRequest,
ObjectResponse,
ObjectType,
PolicyResponse,

View File

@@ -1,5 +1,8 @@
from typing import Optional, TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
@@ -31,7 +34,7 @@ class Download(DownloadBase, UUIDTableBase, table=True):
dst: str = Field(description="目标存储路径")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
task_id: int | None = Field(default=None, foreign_key="task.id", index=True, description="关联的任务ID")
node_id: int = Field(foreign_key="node.id", index=True, description="执行下载的节点ID")

View File

@@ -1,9 +1,10 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, text
from .base import TableBase, SQLModelBase
from .base import TableBase, SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .user import User
@@ -45,8 +46,8 @@ class GroupOptionsBase(SQLModelBase):
class GroupResponse(GroupBase, GroupOptionsBase):
"""用户组响应 DTO"""
id: int
"""用户组ID"""
id: UUID
"""用户组UUID"""
allow_share: bool = False
"""是否允许分享"""
@@ -72,8 +73,8 @@ class GroupResponse(GroupBase, GroupOptionsBase):
class GroupOptions(GroupOptionsBase, TableBase, table=True):
"""用户组选项模型"""
group_id: int = Field(foreign_key="group.id", unique=True)
"""关联的用户组ID"""
group_id: UUID = Field(foreign_key="group.id", unique=True)
"""关联的用户组UUID"""
archive_download: bool = False
"""是否允许打包下载"""
@@ -97,7 +98,7 @@ class GroupOptions(GroupOptionsBase, TableBase, table=True):
group: "Group" = Relationship(back_populates="options")
class Group(GroupBase, TableBase, table=True):
class Group(GroupBase, UUIDTableBase, table=True):
"""用户组模型"""
name: str = Field(max_length=255, unique=True)

View File

@@ -25,7 +25,7 @@ default_settings: list[Setting] = [
Setting(name="siteURL", value="http://localhost", type=SettingsType.BASIC),
Setting(name="siteName", value="DiskNext", type=SettingsType.BASIC),
Setting(name="register_enabled", value="1", type=SettingsType.REGISTER),
Setting(name="default_group", value="2", type=SettingsType.REGISTER),
Setting(name="default_group", value="", type=SettingsType.REGISTER), # UUID 在组创建后更新
Setting(name="siteKeywords", value="网盘,网盘", type=SettingsType.BASIC),
Setting(name="siteDes", value="DiskNext", type=SettingsType.BASIC),
Setting(name="siteTitle", value="云星启智", type=SettingsType.BASIC),
@@ -138,24 +138,27 @@ async def init_default_settings() -> None:
async def init_default_group() -> None:
from .group import Group, GroupOptions
from .setting import Setting
from .database import get_session
log.info('初始化用户组...')
async for session in get_session():
# 未找到初始管理组时,则创建
if not await Group.get(session, Group.id == 1):
admin_group = await Group(
if not await Group.get(session, Group.name == "管理员"):
admin_group = Group(
name="管理员",
policies="1",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
admin=True,
).save(session)
assert admin_group.id is not None
)
admin_group_id = admin_group.id # 在 save 前保存 UUID
await admin_group.save(session)
await GroupOptions(
group_id=admin_group.id,
group_id=admin_group_id,
archive_download=True,
archive_task=True,
share_download=True,
@@ -166,30 +169,40 @@ async def init_default_group() -> None:
).save(session)
# 未找到初始注册会员时,则创建
if not await Group.get(session, Group.id == 2):
member_group = await Group(
if not await Group.get(session, Group.name == "注册会员"):
member_group = Group(
name="注册会员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
).save(session)
assert member_group.id is not None
)
member_group_id = member_group.id # 在 save 前保存 UUID
await member_group.save(session)
await GroupOptions(
group_id=member_group.id,
group_id=member_group_id,
share_download=True,
).save(session)
# 更新 default_group 设置为注册会员组的 UUID
default_group_setting = await Setting.get(session, Setting.name == "default_group")
if default_group_setting:
default_group_setting.value = str(member_group_id)
await default_group_setting.save(session)
# 未找到初始游客组时,则创建
if not await Group.get(session, Group.id == 3):
guest_group = await Group(
if not await Group.get(session, Group.name == "游客"):
guest_group = Group(
name="游客",
policies="[]",
share_enabled=False,
web_dav_enabled=False,
).save(session)
assert guest_group.id is not None
)
guest_group_id = guest_group.id # 在 save 前保存 UUID
await guest_group.save(session)
await GroupOptions(
group_id=guest_group.id,
group_id=guest_group_id,
share_download=True,
).save(session)
@@ -203,11 +216,11 @@ async def init_default_user() -> None:
async for session in get_session():
# 检查管理员用户是否存在
admin_user = await User.get(session, User.id == 1)
admin_user = await User.get(session, User.username == "admin")
if not admin_user:
# 获取管理员组
admin_group = await Group.get(session, Group.id == 1)
admin_group = await Group.get(session, Group.name == "管理员")
if not admin_group:
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
@@ -215,19 +228,22 @@ async def init_default_user() -> None:
admin_password = Password.generate(8)
hashed_admin_password = Password.hash(admin_password)
admin_user = await User(
admin_user = User(
username="admin",
nick="admin",
nickname="admin",
status=True,
group_id=admin_group.id,
password=hashed_admin_password,
).save(session)
)
admin_user_id = admin_user.id # 在 save 前保存 UUID
admin_username = admin_user.username
await admin_user.save(session)
# 为管理员创建根目录(使用默认存储策略
# 为管理员创建根目录(使用用户名作为目录名
await Object(
name="my",
name=admin_username,
type=ObjectType.FOLDER,
owner_id=admin_user.id,
owner_id=admin_user_id,
parent_id=None,
policy_id=1, # 默认本地存储策略
).save(session)
@@ -244,7 +260,7 @@ async def init_default_policy() -> None:
async for session in get_session():
# 检查默认存储策略是否存在
default_policy = await Policy.get(session, Policy.id == 1)
default_policy = await Policy.get(session, Policy.name == "本地存储")
if not default_policy:
local_policy = Policy(

View File

@@ -1,11 +1,12 @@
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Optional
from uuid import UUID
from enum import StrEnum
from sqlmodel import Field, Relationship, UniqueConstraint, CheckConstraint, Index
from .base import TableBase, SQLModelBase
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .user import User
@@ -40,18 +41,38 @@ class ObjectBase(SQLModelBase):
class DirectoryCreateRequest(SQLModelBase):
"""创建目录请求 DTO"""
path: str
"""目录路径,如 /docs/images"""
parent_id: UUID
"""目录UUID"""
name: str
"""目录名称"""
policy_id: int | None = None
"""存储策略ID不指定则继承父目录"""
class ObjectMoveRequest(SQLModelBase):
"""移动对象请求 DTO"""
src_ids: list[UUID]
"""源对象UUID列表"""
dst_id: UUID
"""目标文件夹UUID"""
class ObjectDeleteRequest(SQLModelBase):
"""删除对象请求 DTO"""
ids: list[UUID]
"""待删除对象UUID列表"""
class ObjectResponse(ObjectBase):
"""对象响应 DTO"""
id: str
"""对象ID"""
id: UUID
"""对象UUID"""
path: str
"""对象路径"""
@@ -91,8 +112,11 @@ class PolicyResponse(SQLModelBase):
class DirectoryResponse(SQLModelBase):
"""目录响应 DTO"""
parent: str | None = None
"""父目录ID根目录为None"""
id: UUID
"""当前目录UUID"""
parent: UUID | None = None
"""父目录UUID根目录为None"""
objects: list[ObjectResponse] = []
"""目录下的对象列表"""
@@ -103,16 +127,17 @@ class DirectoryResponse(SQLModelBase):
# ==================== 数据库模型 ====================
class Object(ObjectBase, TableBase, table=True):
class Object(ObjectBase, UUIDTableBase, table=True):
"""
统一对象模型
合并了原有的 File 和 Folder 模型,通过 type 字段区分文件和目录。
根目录规则:
- 每个用户有一个显式根目录对象name="my", parent_id=NULL
- 每个用户有一个显式根目录对象name=用户的username, parent_id=NULL
- 用户创建的文件/文件夹的 parent_id 指向根目录或其他文件夹的 id
- 根目录的 policy_id 指定用户默认存储策略
- 路径格式:/username/path/to/file如 /admin/docs/readme.md
"""
__table_args__ = (
@@ -158,11 +183,11 @@ class Object(ObjectBase, TableBase, table=True):
# ==================== 外键 ====================
parent_id: int | None = Field(default=None, foreign_key="object.id", index=True)
"""父目录IDNULL 表示这是用户的根目录"""
parent_id: UUID | None = Field(default=None, foreign_key="object.id", index=True)
"""父目录UUIDNULL 表示这是用户的根目录"""
owner_id: int = Field(foreign_key="user.id", index=True)
"""所有者用户ID"""
owner_id: UUID = Field(foreign_key="user.id", index=True)
"""所有者用户UUID"""
policy_id: int = Field(foreign_key="policy.id", index=True)
"""存储策略ID文件直接使用目录作为子文件的默认策略"""
@@ -207,12 +232,12 @@ class Object(ObjectBase, TableBase, table=True):
# ==================== 业务方法 ====================
@classmethod
async def get_root(cls, session, user_id: int) -> "Object | None":
async def get_root(cls, session, user_id: UUID) -> "Object | None":
"""
获取用户的根目录
:param session: 数据库会话
:param user_id: 用户ID
:param user_id: 用户UUID
:return: 根目录对象,不存在则返回 None
"""
return await cls.get(
@@ -221,33 +246,51 @@ class Object(ObjectBase, TableBase, table=True):
)
@classmethod
async def get_by_path(cls, session, user_id: int, path: str) -> "Object | None":
async def get_by_path(
cls,
session,
user_id: UUID,
path: str,
username: str,
) -> "Object | None":
"""
根据路径获取对象
:param session: 数据库会话
:param user_id: 用户ID
:param path: 路径,如 "/""/docs/images"
:param user_id: 用户UUID
:param path: 路径,如 "/username""/username/docs/images"
:param username: 用户名,用于识别根目录
:return: Object 或 None
"""
path = path.strip()
if not path:
raise ValueError("路径不能为空")
if path in ["/my"]:
return await cls.get_root(session, user_id)
# 获取用户根目录
root = await cls.get_root(session, user_id)
if not root:
return None
# 移除开头的斜杠并分割路径
if path.startswith("/"):
path = path[1:]
parts = [p for p in path.split("/") if p]
# 空路径 -> 返回根目录
if not parts:
return await cls.get_root(session, user_id)
return root
# 从根目录开始遍历
current = await cls.get_root(session, user_id)
# 检查第一部分是否是用户名(根目录名)
if parts[0] == username:
# 路径以用户名开头,如 /admin/docs
if len(parts) == 1:
# 只有用户名,返回根目录
return root
# 去掉用户名部分,从第二个部分开始遍历
parts = parts[1:]
# 从根目录开始遍历剩余路径
current = root
for part in parts:
if not current:
return None
@@ -262,13 +305,13 @@ class Object(ObjectBase, TableBase, table=True):
return current
@classmethod
async def get_children(cls, session, user_id: int, parent_id: int) -> list["Object"]:
async def get_children(cls, session, user_id: UUID, parent_id: UUID) -> list["Object"]:
"""
获取目录下的所有子对象
:param session: 数据库会话
:param user_id: 用户ID
:param parent_id: 父目录ID
:param user_id: 用户UUID
:param parent_id: 父目录UUID
:return: 子对象列表
"""
return await cls.get(

View File

@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship
from .base import TableBase
if TYPE_CHECKING:
@@ -19,7 +22,7 @@ class Order(TableBase, table=True):
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单状态: 0=待支付, 1=已完成, 2=已取消")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="orders")

View File

@@ -1,7 +1,10 @@
from typing import TYPE_CHECKING
from datetime import datetime
from uuid import UUID
from sqlmodel import Field, Relationship, text, UniqueConstraint, Index
from .base import TableBase
if TYPE_CHECKING:
@@ -26,8 +29,8 @@ class Share(TableBase, table=True):
password: str | None = Field(default=None, max_length=255)
"""分享密码(加密后)"""
object_id: int = Field(foreign_key="object.id", index=True)
"""关联的对象ID"""
object_id: UUID = Field(foreign_key="object.id", index=True)
"""关联的对象UUID"""
views: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
"""浏览次数"""
@@ -51,8 +54,8 @@ class Share(TableBase, table=True):
"""兑换此分享所需的积分"""
# 外键
user_id: int = Field(foreign_key="user.id", index=True)
"""创建分享的用户ID"""
user_id: UUID = Field(foreign_key="user.id", index=True)
"""创建分享的用户UUID"""
# 关系
user: "User" = Relationship(back_populates="shares")

View File

@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, Index
from .base import TableBase
if TYPE_CHECKING:
@@ -21,8 +24,8 @@ class SourceLink(TableBase, table=True):
"""通过此链接的下载次数"""
# 外键
object_id: int = Field(foreign_key="object.id", index=True)
"""关联的对象ID必须是文件类型"""
object_id: UUID = Field(foreign_key="object.id", index=True)
"""关联的对象UUID必须是文件类型"""
# 关系
object: "Object" = Relationship(back_populates="source_links")

View File

@@ -1,9 +1,11 @@
from typing import Optional, TYPE_CHECKING
from datetime import datetime
from uuid import UUID
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import TableBase
from datetime import datetime
if TYPE_CHECKING:
from .user import User
@@ -17,7 +19,7 @@ class StoragePack(TableBase, table=True):
size: int = Field(description="容量包大小(字节)")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="storage_packs")

View File

@@ -1,9 +1,12 @@
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from .base import TableBase
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from .base import TableBase
if TYPE_CHECKING:
from .user import User
@@ -19,7 +22,7 @@ class Tag(TableBase, table=True):
expression: str | None = Field(default=None, description="自动标签的匹配表达式")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="tags")

View File

@@ -1,9 +1,12 @@
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, CheckConstraint
from .base import TableBase
from uuid import UUID
from datetime import datetime
from sqlmodel import Field, Relationship, CheckConstraint
from .base import TableBase
if TYPE_CHECKING:
from .user import User
from .download import Download
@@ -22,7 +25,7 @@ class Task(TableBase, table=True):
props: str | None = Field(default=None, description="任务属性 (JSON格式)")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="tasks")

View File

@@ -1,10 +1,11 @@
from datetime import datetime
from enum import StrEnum
from typing import Literal, Optional, TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship
from .base import TableBase, SQLModelBase
from .base import SQLModelBase, UUIDTableBase
if TYPE_CHECKING:
from .group import Group
@@ -75,6 +76,19 @@ class LoginRequest(SQLModelBase):
"""两步验证代码"""
class RegisterRequest(SQLModelBase):
"""注册请求 DTO"""
username: str
"""用户名,唯一,一经注册不可更改"""
password: str
"""用户密码"""
captcha: str | None = None
"""验证码"""
class WebAuthnInfo(SQLModelBase):
"""WebAuthn 信息 DTO"""
@@ -116,8 +130,8 @@ class TokenResponse(SQLModelBase):
class UserResponse(UserBase):
"""用户响应 DTO"""
id: int
"""用户ID"""
id: UUID
"""用户UUID"""
nickname: str | None = None
"""用户昵称"""
@@ -141,8 +155,8 @@ class UserResponse(UserBase):
class UserPublic(UserBase):
"""用户公开信息 DTO用于 API 响应"""
id: int | None = None
"""用户ID"""
id: UUID | None = None
"""用户UUID"""
nick: str | None = None
"""昵称"""
@@ -156,8 +170,8 @@ class UserPublic(UserBase):
group_expires: datetime | None = None
"""用户组过期时间"""
group_id: int | None = None
"""所属用户组ID"""
group_id: UUID | None = None
"""所属用户组UUID"""
created_at: datetime | None = None
"""创建时间"""
@@ -187,8 +201,8 @@ class UserSettingResponse(SQLModelBase):
two_factor: bool = False
"""是否启用两步验证"""
uid: int = 0
"""用户UID"""
uid: UUID | None = None
"""用户UUID"""
# 前向引用导入
@@ -202,7 +216,7 @@ UserSettingResponse.model_rebuild()
# ==================== 数据库模型 ====================
class User(UserBase, TableBase, table=True):
class User(UserBase, UUIDTableBase, table=True):
"""用户模型"""
username: str = Field(max_length=50, unique=True, index=True)
@@ -243,11 +257,11 @@ class User(UserBase, TableBase, table=True):
"""时区UTC 偏移小时数"""
# 外键
group_id: int = Field(foreign_key="group.id", index=True)
"""所属用户组ID"""
group_id: UUID = Field(foreign_key="group.id", index=True)
"""所属用户组UUID"""
previous_group_id: int | None = Field(default=None, foreign_key="group.id")
"""之前的用户组ID用于过期后恢复"""
previous_group_id: UUID | None = Field(default=None, foreign_key="group.id")
"""之前的用户组UUID用于过期后恢复"""
# 关系

View File

@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlalchemy import Column, Text
from sqlmodel import Field, Relationship
@@ -48,8 +49,8 @@ class UserAuthn(TableBase, table=True):
"""用户自定义的凭证名称,便于识别"""
# 外键
user_id: int = Field(foreign_key="user.id", index=True)
"""所属用户ID"""
user_id: UUID = Field(foreign_key="user.id", index=True)
"""所属用户UUID"""
# 关系
user: "User" = Relationship(back_populates="authns")

View File

@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING
from uuid import UUID
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
from .base import TableBase
if TYPE_CHECKING:
@@ -18,7 +21,7 @@ class WebDAV(TableBase, table=True):
use_proxy: bool = Field(default=False, description="是否使用代理下载")
# 外键
user_id: int = Field(foreign_key="user.id", index=True, description="所属用户ID")
user_id: UUID = Field(foreign_key="user.id", index=True, description="所属用户UUID")
# 关系
user: "User" = Relationship(back_populates="webdavs")