优化数据表结构
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from . import response
|
from . import response
|
||||||
|
|
||||||
from .user import User
|
from .user import User
|
||||||
|
from .user_authn import UserAuthn
|
||||||
|
|
||||||
from .download import Download
|
from .download import Download
|
||||||
from .file import File
|
from .file import File
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ async def migration() -> None:
|
|||||||
|
|
||||||
await init_default_settings()
|
await init_default_settings()
|
||||||
await init_default_group()
|
await init_default_group()
|
||||||
|
await init_default_policy()
|
||||||
await init_default_user()
|
await init_default_user()
|
||||||
|
|
||||||
log.info('数据库初始化结束')
|
log.info('数据库初始化结束')
|
||||||
@@ -214,3 +215,30 @@ async def init_default_user() -> None:
|
|||||||
|
|
||||||
log.info(f'初始管理员账号:[bold]admin[/bold]')
|
log.info(f'初始管理员账号:[bold]admin[/bold]')
|
||||||
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')
|
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')
|
||||||
|
|
||||||
|
|
||||||
|
async def init_default_policy() -> None:
|
||||||
|
from .policy import Policy, PolicyType
|
||||||
|
from .database import get_session
|
||||||
|
|
||||||
|
log.info('初始化默认存储策略...')
|
||||||
|
|
||||||
|
async for session in get_session():
|
||||||
|
# 检查默认存储策略是否存在
|
||||||
|
default_policy = await Policy.get(session, Policy.id == 1)
|
||||||
|
|
||||||
|
if not default_policy:
|
||||||
|
local_policy = Policy(
|
||||||
|
name="本地存储",
|
||||||
|
type=PolicyType.LOCAL,
|
||||||
|
server="./data",
|
||||||
|
is_private=True,
|
||||||
|
max_size=0,
|
||||||
|
auto_rename=True,
|
||||||
|
dir_name_rule="{date}/{randomkey16}",
|
||||||
|
file_name_rule="{randomkey16}_{originname}",
|
||||||
|
)
|
||||||
|
|
||||||
|
await local_policy.save(session)
|
||||||
|
|
||||||
|
log.info('已创建默认本地存储策略,存储目录:./data')
|
||||||
@@ -2,28 +2,59 @@
|
|||||||
from typing import Optional, List, TYPE_CHECKING
|
from typing import Optional, List, TYPE_CHECKING
|
||||||
from sqlmodel import Field, Relationship, text
|
from sqlmodel import Field, Relationship, text
|
||||||
from .base import TableBase
|
from .base import TableBase
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .file import File
|
from .file import File
|
||||||
from .folder import Folder
|
from .folder import Folder
|
||||||
|
|
||||||
|
class PolicyType(StrEnum):
|
||||||
|
LOCAL = "local"
|
||||||
|
S3 = "s3"
|
||||||
|
|
||||||
class Policy(TableBase, table=True):
|
class Policy(TableBase, table=True):
|
||||||
"""存储策略模型"""
|
"""存储策略模型"""
|
||||||
|
|
||||||
name: str = Field(max_length=255, unique=True, description="策略名称")
|
name: str = Field(max_length=255, unique=True)
|
||||||
type: str = Field(max_length=255, description="存储类型 (e.g. 'local', 's3')")
|
"""策略名称"""
|
||||||
server: str | None = Field(default=None, max_length=255, description="服务器地址(本地策略为路径)")
|
|
||||||
bucket_name: str | None = Field(default=None, max_length=255, description="存储桶名称")
|
type: PolicyType
|
||||||
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}, description="是否为私有空间")
|
"""存储策略类型"""
|
||||||
base_url: str | None = Field(default=None, max_length=255, description="访问文件的基础URL")
|
|
||||||
access_key: str | None = Field(default=None, description="Access Key")
|
server: str | None = Field(default=None, max_length=255)
|
||||||
secret_key: str | None = Field(default=None, description="Secret Key")
|
"""服务器地址(本地策略为绝对路径)"""
|
||||||
max_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="允许上传的最大文件尺寸(字节)")
|
|
||||||
auto_rename: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否自动重命名")
|
bucket_name: str | None = Field(default=None, max_length=255)
|
||||||
dir_name_rule: str | None = Field(default=None, max_length=255, description="目录命名规则")
|
"""存储桶名称"""
|
||||||
file_name_rule: str | None = Field(default=None, max_length=255, description="文件命名规则")
|
|
||||||
is_origin_link_enable: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否开启源链接访问")
|
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")})
|
||||||
options: str | None = Field(default=None, description="其他选项 (JSON格式)")
|
"""是否为私有空间"""
|
||||||
|
|
||||||
|
base_url: str | None = Field(default=None, max_length=255)
|
||||||
|
"""访问文件的基础URL"""
|
||||||
|
|
||||||
|
access_key: str | None = Field(default=None)
|
||||||
|
"""Access Key"""
|
||||||
|
|
||||||
|
secret_key: str | None = Field(default=None)
|
||||||
|
"""Secret Key"""
|
||||||
|
max_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"})
|
||||||
|
"""允许上传的最大文件尺寸(字节)"""
|
||||||
|
|
||||||
|
auto_rename: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||||
|
"""是否自动重命名"""
|
||||||
|
|
||||||
|
dir_name_rule: str | None = Field(default=None, max_length=255)
|
||||||
|
"""目录命名规则"""
|
||||||
|
|
||||||
|
file_name_rule: str | None = Field(default=None, max_length=255)
|
||||||
|
"""文件命名规则"""
|
||||||
|
|
||||||
|
is_origin_link_enable: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")})
|
||||||
|
"""是否开启源链接访问"""
|
||||||
|
|
||||||
|
options: str | None = Field(default=None)
|
||||||
|
"""其他选项 (JSON格式)"""
|
||||||
# options 示例: {"token":"","file_type":null,"mimetype":"","od_redirect":"http://127.0.0.1:8000/...","chunk_size":52428800,"s3_path_style":false}
|
# options 示例: {"token":"","file_type":null,"mimetype":"","od_redirect":"http://127.0.0.1:8000/...","chunk_size":52428800,"s3_path_style":false}
|
||||||
|
|
||||||
# 关系
|
# 关系
|
||||||
|
|||||||
@@ -1,17 +0,0 @@
|
|||||||
"""
|
|
||||||
请求模型定义
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Literal, Union, Optional
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
|
||||||
"""
|
|
||||||
登录请求模型
|
|
||||||
"""
|
|
||||||
username: str = Field(..., description="用户名或邮箱")
|
|
||||||
password: str = Field(..., description="用户密码")
|
|
||||||
captcha: str | None = Field(None, description="验证码")
|
|
||||||
twoFaCode: str | None = Field(None, description="两步验证代码")
|
|
||||||
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
|
|||||||
from .storage_pack import StoragePack
|
from .storage_pack import StoragePack
|
||||||
from .tag import Tag
|
from .tag import Tag
|
||||||
from .task import Task
|
from .task import Task
|
||||||
|
from .user_authn import UserAuthn
|
||||||
from .webdav import WebDAV
|
from .webdav import WebDAV
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@@ -27,6 +28,15 @@ Option 需求
|
|||||||
- 切换到不同存储策略是否提醒
|
- 切换到不同存储策略是否提醒
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
"""
|
||||||
|
登录请求模型
|
||||||
|
"""
|
||||||
|
username: str = Field(..., description="用户名或邮箱")
|
||||||
|
password: str = Field(..., description="用户密码")
|
||||||
|
captcha: str | None = Field(None, description="验证码")
|
||||||
|
twoFaCode: str | None = Field(None, description="两步验证代码")
|
||||||
|
|
||||||
class WebAuthnInfo(BaseModel):
|
class WebAuthnInfo(BaseModel):
|
||||||
"""WebAuthn 信息模型"""
|
"""WebAuthn 信息模型"""
|
||||||
|
|
||||||
@@ -75,8 +85,6 @@ class User(TableBase, table=True):
|
|||||||
options: str | None = Field(default=None)
|
options: str | None = Field(default=None)
|
||||||
"""[TODO] 用户个人设置 需要更改,参考上方的需求"""
|
"""[TODO] 用户个人设置 需要更改,参考上方的需求"""
|
||||||
|
|
||||||
authn: str | None = Field(default=None)
|
|
||||||
"""[TODO] WebAuthn 凭证,可不存,也可设置一个或多个"""
|
|
||||||
|
|
||||||
github_open_id: str | None = Field(default=None, unique=True, index=True)
|
github_open_id: str | None = Field(default=None, unique=True, index=True)
|
||||||
"""Github OpenID"""
|
"""Github OpenID"""
|
||||||
@@ -125,6 +133,7 @@ class User(TableBase, table=True):
|
|||||||
tags: list["Tag"] = Relationship(back_populates="user")
|
tags: list["Tag"] = Relationship(back_populates="user")
|
||||||
tasks: list["Task"] = Relationship(back_populates="user")
|
tasks: list["Task"] = Relationship(back_populates="user")
|
||||||
webdavs: list["WebDAV"] = Relationship(back_populates="user")
|
webdavs: list["WebDAV"] = Relationship(back_populates="user")
|
||||||
|
authns: list["UserAuthn"] = Relationship(back_populates="user")
|
||||||
|
|
||||||
def to_public(self) -> "UserPublic":
|
def to_public(self) -> "UserPublic":
|
||||||
"""转换为公开 DTO,排除敏感字段"""
|
"""转换为公开 DTO,排除敏感字段"""
|
||||||
|
|||||||
43
models/user_authn.py
Normal file
43
models/user_authn.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy import Column, Text
|
||||||
|
from sqlmodel import Field, Relationship
|
||||||
|
|
||||||
|
from .base import TableBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .user import User
|
||||||
|
|
||||||
|
|
||||||
|
class UserAuthn(TableBase, table=True):
|
||||||
|
"""用户 WebAuthn 凭证模型,与 User 为多对一关系"""
|
||||||
|
|
||||||
|
__tablename__ = "user_authn"
|
||||||
|
|
||||||
|
credential_id: str = Field(max_length=255, unique=True, index=True)
|
||||||
|
"""凭证 ID,Base64 编码"""
|
||||||
|
|
||||||
|
credential_public_key: str = Field(sa_column=Column(Text))
|
||||||
|
"""凭证公钥,Base64 编码"""
|
||||||
|
|
||||||
|
sign_count: int = Field(default=0, ge=0)
|
||||||
|
"""签名计数器,用于防重放攻击"""
|
||||||
|
|
||||||
|
credential_device_type: str = Field(max_length=32)
|
||||||
|
"""凭证设备类型:'single_device' 或 'multi_device'"""
|
||||||
|
|
||||||
|
credential_backed_up: bool = Field(default=False)
|
||||||
|
"""凭证是否已备份"""
|
||||||
|
|
||||||
|
transports: str | None = Field(default=None, max_length=255)
|
||||||
|
"""支持的传输方式,逗号分隔,如 'usb,nfc,ble,internal'"""
|
||||||
|
|
||||||
|
name: str | None = Field(default=None, max_length=100)
|
||||||
|
"""用户自定义的凭证名称,便于识别"""
|
||||||
|
|
||||||
|
# 外键
|
||||||
|
user_id: int = Field(foreign_key="user.id", index=True)
|
||||||
|
"""所属用户ID"""
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
user: "User" = Relationship(back_populates="authns")
|
||||||
@@ -1,55 +1,79 @@
|
|||||||
import secrets
|
import secrets
|
||||||
|
from loguru import logger
|
||||||
from argon2 import PasswordHasher
|
from argon2 import PasswordHasher
|
||||||
|
from argon2.exceptions import VerifyMismatchError
|
||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
_ph = PasswordHasher()
|
||||||
|
|
||||||
|
class PasswordStatus(StrEnum):
|
||||||
|
"""密码校验状态枚举"""
|
||||||
|
|
||||||
|
VALID = "valid"
|
||||||
|
"""密码校验通过"""
|
||||||
|
|
||||||
|
INVALID = "invalid"
|
||||||
|
"""密码校验失败"""
|
||||||
|
|
||||||
|
EXPIRED = "expired"
|
||||||
|
"""密码哈希已过时,建议重新哈希"""
|
||||||
|
|
||||||
class Password:
|
class Password:
|
||||||
|
"""密码处理工具类,包含密码生成、哈希和验证功能"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate(
|
def generate(
|
||||||
length: int = 16,
|
length: int = 8
|
||||||
url_safe: bool = False
|
) -> str:
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
生成一个随机密码。
|
生成指定长度的随机密码。
|
||||||
|
|
||||||
:param length: 密码长度,默认为 `16` 个字符。
|
:param length: 密码长度
|
||||||
:param url_safe: 是否生成URL安全的密码,默认为 `False` 。
|
:type length: int
|
||||||
:return: 生成的随机密码字符串。
|
:return: 随机密码
|
||||||
|
:rtype: str
|
||||||
"""
|
"""
|
||||||
if url_safe:
|
|
||||||
return secrets.token_urlsafe(length)
|
|
||||||
return secrets.token_hex(length)
|
return secrets.token_hex(length)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def hash(
|
def hash(
|
||||||
password: str,
|
password: str
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
生成密码的Argon2哈希值。
|
使用 Argon2 生成密码的哈希值。
|
||||||
|
|
||||||
:param password: 要哈希的密码。
|
返回的哈希字符串已经包含了所有需要验证的信息(盐、算法参数等)。
|
||||||
:return: 使用Argon2算法生成的密码哈希。
|
|
||||||
:rtype: str
|
:param password: 需要哈希的原始密码
|
||||||
|
:return: Argon2 哈希字符串
|
||||||
"""
|
"""
|
||||||
ph = PasswordHasher()
|
return _ph.hash(password)
|
||||||
return ph.hash(password)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def verify(
|
def verify(
|
||||||
stored_password: str,
|
hash: str,
|
||||||
provided_password: str,
|
password: str
|
||||||
) -> bool:
|
) -> PasswordStatus:
|
||||||
"""
|
"""
|
||||||
验证存储的Argon2密码哈希值与用户提供的密码是否匹配。
|
验证存储的 Argon2 哈希值与用户提供的密码是否匹配。
|
||||||
|
|
||||||
:param stored_password: 存储的Argon2密码哈希值。
|
:param hash: 数据库中存储的 Argon2 哈希字符串
|
||||||
:param provided_password: 用户提供的密码。
|
:param password: 用户本次提供的密码
|
||||||
|
:return: 如果密码匹配返回 True, 否则返回 False
|
||||||
:return: 如果密码匹配返回 `True` ,否则返回 `False` 。
|
|
||||||
:rtype: bool
|
|
||||||
"""
|
"""
|
||||||
ph = PasswordHasher()
|
|
||||||
try:
|
try:
|
||||||
ph.verify(stored_password, provided_password)
|
# verify 函数会自动解析 stored_password 中的盐和参数
|
||||||
return True
|
_ph.verify(hash, password)
|
||||||
except:
|
|
||||||
return False
|
# 检查哈希参数是否已过时。如果返回True,
|
||||||
|
# 意味着你应该使用新的参数重新哈希密码并更新存储。
|
||||||
|
# 这是一个很好的实践,可以随着时间推移增强安全性。
|
||||||
|
if _ph.check_needs_rehash(hash):
|
||||||
|
logger.warning("密码哈希参数已过时,建议重新哈希并更新。")
|
||||||
|
return PasswordStatus.EXPIRED
|
||||||
|
|
||||||
|
return PasswordStatus.VALID
|
||||||
|
except VerifyMismatchError:
|
||||||
|
# 这是预期的异常,当密码不匹配时触发。
|
||||||
|
return PasswordStatus.INVALID
|
||||||
|
# 其他异常(如哈希格式错误)应该传播,让调用方感知系统问题
|
||||||
@@ -36,7 +36,7 @@ async def router_user_session(
|
|||||||
|
|
||||||
result = await service.user.Login(
|
result = await service.user.Login(
|
||||||
session,
|
session,
|
||||||
models.request.LoginRequest(username=username, password=password),
|
models.user.LoginRequest(username=username, password=password),
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(result, models.response.TokenModel):
|
if isinstance(result, models.response.TokenModel):
|
||||||
|
|||||||
@@ -2,14 +2,13 @@ from loguru import logger as log
|
|||||||
from sqlalchemy import and_
|
from sqlalchemy import and_
|
||||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from models.request import LoginRequest
|
|
||||||
from models.response import TokenModel
|
from models.response import TokenModel
|
||||||
from models.setting import Setting
|
from models import user
|
||||||
from models.user import User
|
from models.user import User
|
||||||
from pkg.JWT.jwt import create_access_token, create_refresh_token
|
from pkg.JWT.jwt import create_access_token, create_refresh_token
|
||||||
|
|
||||||
|
|
||||||
async def Login(session: AsyncSession, login_request: LoginRequest) -> TokenModel | bool | None:
|
async def Login(session: AsyncSession, login_request: user.LoginRequest) -> TokenModel | bool | None:
|
||||||
"""
|
"""
|
||||||
根据账号密码进行登录。
|
根据账号密码进行登录。
|
||||||
|
|
||||||
@@ -32,26 +31,26 @@ async def Login(session: AsyncSession, login_request: LoginRequest) -> TokenMode
|
|||||||
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
# is_captcha_required = captcha_setting and captcha_setting.value == "1"
|
||||||
|
|
||||||
# 获取用户信息
|
# 获取用户信息
|
||||||
user = await User.get(session, User.username == login_request.username)
|
current_user = await User.get(session, User.username == login_request.username, fetch_mode="one")
|
||||||
|
|
||||||
# 验证用户是否存在
|
# 验证用户是否存在
|
||||||
if not user:
|
if not current_user:
|
||||||
log.debug(f"Cannot find user with username: {login_request.username}")
|
log.debug(f"Cannot find user with username: {login_request.username}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 验证密码是否正确
|
# 验证密码是否正确
|
||||||
if not Password.verify(user.password, login_request.password):
|
if not Password.verify(current_user.password, login_request.password):
|
||||||
log.debug(f"Password verification failed for user: {login_request.username}")
|
log.debug(f"Password verification failed for user: {login_request.username}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 验证用户是否可登录
|
# 验证用户是否可登录
|
||||||
if not user.status:
|
if not current_user.status:
|
||||||
# 未完成注册 or 账号已被封禁
|
# 未完成注册 or 账号已被封禁
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 创建令牌
|
# 创建令牌
|
||||||
access_token, access_expire = create_access_token(data={'sub': user.username})
|
access_token, access_expire = create_access_token(data={'sub': current_user.username})
|
||||||
refresh_token, refresh_expire = create_refresh_token(data={'sub': user.username})
|
refresh_token, refresh_expire = create_refresh_token(data={'sub': current_user.username})
|
||||||
|
|
||||||
return TokenModel(
|
return TokenModel(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
|
|||||||
Reference in New Issue
Block a user