数据库创建

This commit is contained in:
2025-06-22 19:26:23 +08:00
parent 6094d8219e
commit f6825b670f
31 changed files with 1494 additions and 270 deletions

2
.gitignore vendored
View File

@@ -10,3 +10,5 @@ __pycache__/
*.pyd *.pyd
*.code-workspace *.code-workspace
*.db

243
clean.py Normal file
View File

@@ -0,0 +1,243 @@
import os
import shutil
from pkg.log import log as log
import argparse
from typing import List, Tuple, Set
import time
Version = "2.1.0"
# 默认排除的目录
DEFAULT_EXCLUDE_DIRS = {"venv", "env", ".venv", ".env", "node_modules"}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description='清理 Python 缓存文件')
parser.add_argument('-p', '--path', default=os.getcwd(),
help='要清理的目录路径,默认为当前工作目录')
parser.add_argument('-y', '--yes', action='store_true',
help='自动确认所有操作,不再询问')
parser.add_argument('--no-pycache', action='store_true',
help='不清理 __pycache__ 目录')
parser.add_argument('--no-nicegui', action='store_true',
help='不清理 .nicegui 目录')
parser.add_argument('--no-testdb', action='store_true',
help='不清理 test.db 文件')
parser.add_argument('--pyc', action='store_true',
help='清理 .pyc 文件')
parser.add_argument('--pytest-cache', action='store_true',
help='清理 .pytest_cache 目录')
parser.add_argument('--exclude', type=str, default="",
help='排除的目录,多个目录用逗号分隔')
parser.add_argument('--log-file',
help='指定日志文件路径')
parser.add_argument('--dry-run', action='store_true',
help='仅列出将要删除的文件,不实际删除')
return parser.parse_args()
def confirm_action(message: str, auto_yes: str = False) -> bool:
if auto_yes:
return True
return input(f"{message} (y/N): ").lower() == 'y'
def safe_remove(path: str, dry_run: bool = False) -> Tuple[bool, str]:
"""安全删除文件或目录"""
try:
if dry_run:
return True, f"DRY RUN: 将删除 {path}"
if os.path.isdir(path):
shutil.rmtree(path)
elif os.path.isfile(path):
os.remove(path)
return True, ""
except PermissionError as e:
return False, f"权限错误: {e}"
except OSError as e:
return False, f"系统错误: {e}"
except Exception as e:
return False, f"未知错误: {e}"
def get_excluded_dirs(exclude_arg: str) -> Set[str]:
"""获取要排除的目录列表"""
result = set(DEFAULT_EXCLUDE_DIRS)
if exclude_arg:
for item in exclude_arg.split(','):
item = item.strip()
if item:
result.add(item)
return result
def clean_pycache(root_dir: str, exclude_dirs: Set[str], dry_run: bool = False) -> List[str]:
"""清理 __pycache__ 目录"""
log.info("开始清理 __pycache__ 目录...")
cleaned_paths = []
for dirpath, dirnames, _ in os.walk(root_dir):
# 排除指定目录
for exclude in exclude_dirs:
if exclude in dirnames:
dirnames.remove(exclude)
if "__pycache__" in dirnames:
pycache_dir = os.path.join(dirpath, "__pycache__")
success, error = safe_remove(pycache_dir, dry_run)
if success:
cleaned_paths.append(pycache_dir)
else:
log.error(f"无法清理 {pycache_dir}: {error}")
return cleaned_paths
def clean_pyc_files(root_dir: str, exclude_dirs: Set[str], dry_run: bool = False) -> List[str]:
"""清理 .pyc 文件"""
log.info("开始清理 .pyc 文件...")
cleaned_files = []
for dirpath, dirnames, filenames in os.walk(root_dir):
# 排除指定目录
for exclude in exclude_dirs:
if exclude in dirnames:
dirnames.remove(exclude)
for filename in filenames:
if filename.endswith('.pyc'):
file_path = os.path.join(dirpath, filename)
success, error = safe_remove(file_path, dry_run)
if success:
cleaned_files.append(file_path)
else:
log.error(f"无法清理 {file_path}: {error}")
return cleaned_files
def clean_pytest_cache(root_dir: str, exclude_dirs: Set[str], dry_run: bool = False) -> List[str]:
"""清理 .pytest_cache 目录"""
log.info("开始清理 .pytest_cache 目录...")
cleaned_paths = []
for dirpath, dirnames, _ in os.walk(root_dir):
# 排除指定目录
for exclude in exclude_dirs:
if exclude in dirnames:
dirnames.remove(exclude)
if ".pytest_cache" in dirnames:
cache_dir = os.path.join(dirpath, ".pytest_cache")
success, error = safe_remove(cache_dir, dry_run)
if success:
cleaned_paths.append(cache_dir)
else:
log.error(f"无法清理 {cache_dir}: {error}")
return cleaned_paths
def clean_nicegui(root_dir: str, dry_run: bool = False) -> Tuple[bool, str]:
"""清理 .nicegui 目录"""
log.info("开始清理 .nicegui 目录...")
nicegui_dir = os.path.join(root_dir, ".nicegui")
if os.path.exists(nicegui_dir) and os.path.isdir(nicegui_dir):
success, error = safe_remove(nicegui_dir, dry_run)
if success:
return True, nicegui_dir
else:
log.error(f"无法清理 {nicegui_dir}: {error}")
return False, nicegui_dir
return False, nicegui_dir
def clean_testdb(root_dir: str, dry_run: bool = False) -> Tuple[bool, str, str]:
"""清理测试数据库文件"""
log.info("开始清理 test.db 文件...")
test_db = os.path.join(root_dir, "test.db")
if os.path.exists(test_db) and os.path.isfile(test_db):
success, error = safe_remove(test_db, dry_run)
if success:
return True, test_db, ""
else:
return False, test_db, error
return False, test_db, "文件不存在"
def main():
start_time = time.time()
args = parse_args()
root_dir = os.path.abspath(args.path)
exclude_dirs = get_excluded_dirs(args.exclude)
# 设置日志文件
if args.log_file:
log.set_log_file(args.log_file)
log.title()
log.title(title=f"清理工具 Cleaner\t\tVersion:{Version}", size="h2")
print('')
if not os.path.exists(root_dir):
log.error(f"目录不存在 Directory not exists: {root_dir}")
return 1
log.info(f"清理目录 Clean Directory: {root_dir}")
if args.dry_run:
log.warning("模拟运行模式: 将只列出要删除的文件,不会实际删除")
if exclude_dirs:
log.info(f"排除目录: {', '.join(exclude_dirs)}")
try:
total_cleaned = 0
# 清理 __pycache__
if not args.no_pycache and confirm_action("是否清理 __pycache__ 目录?", args.yes):
cleaned = clean_pycache(root_dir, exclude_dirs, args.dry_run)
for path in cleaned:
log.info(f"已清理 Removed: {path}")
total_cleaned += len(cleaned)
# 清理 .pyc 文件
if args.pyc and confirm_action("是否清理 .pyc 文件?", args.yes):
cleaned = clean_pyc_files(root_dir, exclude_dirs, args.dry_run)
for path in cleaned:
log.info(f"已清理 Removed: {path}")
total_cleaned += len(cleaned)
# 清理 .pytest_cache
if args.pytest_cache and confirm_action("是否清理 .pytest_cache 目录?", args.yes):
cleaned = clean_pytest_cache(root_dir, exclude_dirs, args.dry_run)
for path in cleaned:
log.info(f"已清理 Removed: {path}")
total_cleaned += len(cleaned)
# 清理 .nicegui
if not args.no_nicegui and confirm_action("是否清理 .nicegui 目录?", args.yes):
cleaned, path = clean_nicegui(root_dir, args.dry_run)
if cleaned:
log.info(f"已清理 Removed: {path}")
total_cleaned += 1
else:
log.debug(f"未找到 Not found: {path}")
# 清理 test.db
if not args.no_testdb and confirm_action("是否清理 test.db 文件?", args.yes):
success, path, error = clean_testdb(root_dir, args.dry_run)
if success:
log.info(f"已清理 Removed: {path}")
total_cleaned += 1
elif error == "文件不存在":
log.debug(f"未找到 Not found: {path}")
else:
log.error(f"清理失败 Failed: {error}")
except KeyboardInterrupt:
log.warning("操作被用户中断")
return 1
except Exception as e:
log.error(f"错误 Error: {e}")
return 1
else:
elapsed_time = time.time() - start_time
if args.dry_run:
log.success(f"模拟清理结束,发现 {total_cleaned} 个可清理项目 (用时: {elapsed_time:.2f}秒)")
else:
log.success(f"清理结束,共清理 {total_cleaned} 个项目 (用时: {elapsed_time:.2f}秒)")
return 0
if __name__ == "__main__":
exit(main())

View File

@@ -1,6 +1,10 @@
from fastapi import FastAPI from fastapi import FastAPI
from routers import routers from routers import routers
from pkg.conf import appmeta from pkg.conf import appmeta
from models.database import init_db
from pkg.lifespan import lifespan
lifespan.add_startup(init_db)
app = FastAPI( app = FastAPI(
title=appmeta.APP_NAME, title=appmeta.APP_NAME,
@@ -9,7 +13,7 @@ app = FastAPI(
version=appmeta.BackendVersion, version=appmeta.BackendVersion,
openapi_tags=appmeta.tags_meta, openapi_tags=appmeta.tags_meta,
license_info=appmeta.license_info, license_info=appmeta.license_info,
lifespan=lifespan.lifespan,
) )
for router in routers.Router: for router in routers.Router:

30
models/__init__.py Normal file
View File

@@ -0,0 +1,30 @@
# my_project/models/__init__.py
from . import response
# 将所有模型导入到这个包的命名空间中
from .base import BaseModel
from .download import Download
from .file import File
from .folder import Folder
from .group import Group
from .node import Node
from .order import Order
from .policy import Policy
from .redeem import Redeem
from .report import Report
from .setting import Setting
from .share import Share
from .source_link import SourceLink
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .user import User
from .webdav import WebDAV
# 可以定义一个 __all__ 列表来明确指定可以被 from .models import * 导入的内容
__all__ = [
"BaseModel", "Download", "File", "Folder", "Group", "Node", "Order",
"Policy", "Redeem", "Report", "Setting", "Share", "SourceLink",
"StoragePack", "Tag", "Task", "User", "WebDAV"
]

9
models/base.py Normal file
View File

@@ -0,0 +1,9 @@
# my_project/models/base.py
from typing import Optional
from sqlmodel import SQLModel, Field
class BaseModel(SQLModel):
__abstract__ = True
id: Optional[int] = Field(default=None, primary_key=True, description="主键ID")

31
models/database.py Normal file
View File

@@ -0,0 +1,31 @@
# my_project/database.py
from sqlmodel import SQLModel
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import sessionmaker
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///database.db"
engine = create_async_engine(
ASYNC_DATABASE_URL,
echo=True,
connect_args={"check_same_thread": False}
if ASYNC_DATABASE_URL.startswith("sqlite")
else None,
future=True,
# pool_size=POOL_SIZE,
# max_overflow=64,
)
_async_session_factory = sessionmaker(engine, class_=AsyncSession)
async def get_session():
async with _async_session_factory() as session:
yield session
async def init_db():
"""初始化数据库"""
# 创建所有表
async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)

56
models/download.py Normal file
View File

@@ -0,0 +1,56 @@
# my_project/models/download.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
from .task import Task
from .node import Node
class Download(BaseModel, table=True):
__tablename__ = 'downloads'
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载状态: 0=进行中, 1=完成, 2=错误")
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务类型")
source: str = Field(description="来源URL或标识")
total_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="总大小(字节)")
downloaded_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已下载大小(字节)")
g_id: Optional[str] = Field(default=None, index=True, description="Aria2 GID")
speed: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载速度 (bytes/s)")
parent: Optional[str] = Field(default=None, description="父任务标识")
attrs: Optional[str] = Field(default=None, description="额外属性 (JSON格式)")
error: Optional[str] = Field(default=None, description="错误信息")
dst: str = Field(description="目标存储路径")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
# 外键
user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID")
task_id: Optional[int] = Field(default=None, foreign_key="tasks.id", index=True, description="关联的任务ID")
node_id: int = Field(foreign_key="nodes.id", index=True, description="执行下载的节点ID")
# 关系
user: "User" = Relationship(back_populates="downloads")
task: Optional["Task"] = Relationship(back_populates="downloads")
node: "Node" = Relationship(back_populates="downloads")

53
models/file.py Normal file
View File

@@ -0,0 +1,53 @@
# my_project/models/file.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
from .folder import Folder
from .policy import Policy
from .source_link import SourceLink
class File(BaseModel, table=True):
__tablename__ = 'files'
name: str = Field(max_length=255, description="文件名")
source_name: Optional[str] = Field(default=None, description="源文件名")
size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="文件大小(字节)")
pic_info: Optional[str] = Field(default=None, max_length=255, description="图片信息(如尺寸)")
upload_session_id: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="分块上传会话ID")
file_metadata: Optional[str] = Field(default=None, description="文件元数据 (JSON格式)")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
# 外键
user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID")
folder_id: int = Field(foreign_key="folders.id", index=True, description="所在目录ID")
policy_id: int = Field(foreign_key="policies.id", index=True, description="所属存储策略ID")
# 关系
user: list["User"] = Relationship(back_populates="files")
folder: list["Folder"] = Relationship(back_populates="files")
policy: list["Policy"] = Relationship(back_populates="files")
source_links: list["SourceLink"] = Relationship(back_populates="file")

52
models/folder.py Normal file
View File

@@ -0,0 +1,52 @@
# my_project/models/folder.py
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
from .policy import Policy
from .file import File
class Folder(BaseModel, table=True):
__tablename__ = 'folders'
__table_args__ = (UniqueConstraint("name", "parent_id", name="uq_folder_name_parent"),)
name: str = Field(max_length=255, description="目录名")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
# 外键
parent_id: Optional[int] = Field(default=None, foreign_key="folders.id", index=True, description="父目录ID")
owner_id: int = Field(foreign_key="users.id", index=True, description="所有者用户ID")
policy_id: int = Field(foreign_key="policies.id", index=True, description="所属存储策略ID")
# 关系
owner: "User" = Relationship(back_populates="folders")
policy: "Policy" = Relationship(back_populates="folders")
# 自我引用关系
parent: Optional["Folder"] = Relationship(back_populates="children", sa_relationship_kwargs={"remote_side": "Folder.id"})
children: List["Folder"] = Relationship(back_populates="parent")
files: List["File"] = Relationship(back_populates="folder")

44
models/group.py Normal file
View File

@@ -0,0 +1,44 @@
# my_project/models/group.py
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
class Group(BaseModel, table=True):
__tablename__ = 'groups'
name: str = Field(max_length=255, unique=True, description="用户组名")
policies: Optional[str] = Field(default=None, max_length=255, description="允许的策略ID列表逗号分隔")
max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="最大存储空间(字节)")
share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否允许创建分享")
web_dav_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否允许使用WebDAV")
speed_limit: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="速度限制 (KB/s), 0为不限制")
options: Optional[str] = Field(default=None, description="其他选项 (JSON格式)")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
# 关系:一个组可以有多个用户
users: List["User"] = Relationship(back_populates="group")
previous_users: List["User"] = Relationship(back_populates="previous_group")

View File

@@ -1,251 +0,0 @@
from sqlalchemy import (
Column, Integer, String, Text, BigInteger, Boolean, DateTime,
ForeignKey, func, text, UniqueConstraint
)
from sqlalchemy.orm import declarative_base
Base = declarative_base()
class BaseModel(Base):
__abstract__ = True
id = Column(Integer, primary_key=True, comment="主键ID")
created_at = Column(
DateTime,
server_default=func.now(),
comment="创建时间"
)
updated_at = Column(
DateTime,
server_default=func.now(),
onupdate=func.now(),
server_onupdate=func.now(),
comment="更新时间"
)
deleted_at = Column(DateTime, nullable=True, comment="软删除时间")
class Download(BaseModel):
__tablename__ = 'downloads'
status = Column(Integer, nullable=False, server_default='0', comment="下载状态: 0=进行中, 1=完成, 2=错误")
type = Column(Integer, nullable=False, server_default='0', comment="任务类型")
source = Column(Text, nullable=False, comment="来源URL或标识")
total_size = Column(BigInteger, nullable=False, server_default='0', comment="总大小(字节)")
downloaded_size = Column(BigInteger, nullable=False, server_default='0', comment="已下载大小(字节)")
g_id = Column(Text, index=True, comment="Aria2 GID") # GID经常用于查询建议索引
speed = Column(Integer, nullable=False, server_default='0', comment="下载速度 (bytes/s)")
parent = Column(Text, comment="父任务标识")
attrs = Column(Text, comment="额外属性 (JSON格式)")
error = Column(Text, comment="错误信息")
dst = Column(Text, nullable=False, comment="目标存储路径")
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID")
task_id = Column(Integer, ForeignKey('tasks.id'), nullable=True, index=True, comment="关联的任务ID")
node_id = Column(Integer, ForeignKey('nodes.id'), nullable=False, index=True, comment="执行下载的节点ID")
class File(BaseModel):
__tablename__ = 'files'
name = Column(String(255), nullable=False, comment="文件名")
source_name = Column(Text, comment="源文件名")
size = Column(BigInteger, nullable=False, server_default='0', comment="文件大小(字节)")
pic_info = Column(String(255), comment="图片信息(如尺寸)")
upload_session_id = Column(String(255), unique=True, index=True, comment="分块上传会话ID")
metadata = Column(Text, comment="文件元数据 (JSON格式)")
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID")
folder_id = Column(Integer, ForeignKey('folders.id'), nullable=False, index=True, comment="所在目录ID")
policy_id = Column(Integer, ForeignKey('policies.id'), nullable=False, index=True, comment="所属存储策略ID")
class Folder(BaseModel):
__tablename__ = 'folders'
name = Column(String(255), nullable=False, comment="目录名")
parent_id = Column(Integer, ForeignKey('folders.id'), nullable=True, index=True, comment="父目录ID")
owner_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所有者用户ID")
policy_id = Column(Integer, ForeignKey('policies.id'), nullable=False, index=True, comment="所属存储策略ID")
__table_args__ = (
UniqueConstraint('name', 'parent_id', name='uq_folder_name_parent'),
)
class Group(BaseModel):
__tablename__ = 'groups'
name = Column(String(255), nullable=False, unique=True, comment="用户组名")
policies = Column(String(255), comment="允许的策略ID列表逗号分隔")
max_storage = Column(BigInteger, nullable=False, server_default='0', comment="最大存储空间(字节)")
share_enabled = Column(Boolean, nullable=False, server_default=text('false'), comment="是否允许创建分享")
web_dav_enabled = Column(Boolean, nullable=False, server_default=text('false'), comment="是否允许使用WebDAV")
speed_limit = Column(Integer, nullable=False, server_default='0', comment="速度限制 (KB/s), 0为不限制")
options = Column(Text, comment="其他选项 (JSON格式)")
class Node(BaseModel):
__tablename__ = 'nodes'
status = Column(Integer, nullable=False, server_default='0', comment="节点状态: 0=正常, 1=离线")
name = Column(String(255), nullable=False, unique=True, comment="节点名称")
type = Column(Integer, nullable=False, server_default='0', comment="节点类型")
server = Column(String(255), nullable=False, comment="节点地址IP或域名")
slave_key = Column(Text, comment="从机通讯密钥")
master_key = Column(Text, comment="主机通讯密钥")
aria2_enabled = Column(Boolean, nullable=False, server_default=text('false'), comment="是否启用Aria2")
aria2_options = Column(Text, comment="Aria2配置 (JSON格式)")
rank = Column(Integer, nullable=False, server_default='0', comment="节点排序权重")
class Order(BaseModel):
__tablename__ = 'orders'
order_no = Column(String(255), nullable=False, unique=True, index=True, comment="订单号,唯一")
type = Column(Integer, nullable=False, comment="订单类型")
method = Column(String(255), comment="支付方式")
product_id = Column(BigInteger, comment="商品ID")
num = Column(Integer, nullable=False, server_default='1', comment="购买数量")
name = Column(String(255), nullable=False, comment="商品名称")
price = Column(Integer, nullable=False, server_default='0', comment="订单价格(分)")
status = Column(Integer, nullable=False, server_default='0', comment="订单状态: 0=待支付, 1=已完成, 2=已取消")
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID")
class Policy(BaseModel):
__tablename__ = 'policies'
name = Column(String(255), nullable=False, unique=True, comment="策略名称")
type = Column(String(255), nullable=False, comment="存储类型 (e.g. 'local', 's3')")
server = Column(String(255), comment="服务器地址(本地策略为路径)")
bucket_name = Column(String(255), comment="存储桶名称")
is_private = Column(Boolean, nullable=False, server_default=text('true'), comment="是否为私有空间")
base_url = Column(String(255), comment="访问文件的基础URL")
access_key = Column(Text, comment="Access Key")
secret_key = Column(Text, comment="Secret Key")
max_size = Column(BigInteger, nullable=False, server_default='0', comment="允许上传的最大文件尺寸(字节)")
auto_rename = Column(Boolean, nullable=False, server_default=text('false'), comment="是否自动重命名")
dir_name_rule = Column(String(255), comment="目录命名规则")
file_name_rule = Column(String(255), comment="文件命名规则")
is_origin_link_enable = Column(Boolean, nullable=False, server_default=text('false'), comment="是否开启源链接访问")
options = Column(Text, comment="其他选项 (JSON格式)")
class Setting(BaseModel):
__tablename__ = 'settings'
# 优化点: type和name的组合应该是唯一的
type = Column(String(255), nullable=False, comment="设置类型/分组")
name = Column(String(255), nullable=False, comment="设置项名称")
value = Column(Text, comment="设置值")
__table_args__ = (
UniqueConstraint('type', 'name', name='uq_setting_type_name'),
)
class Share(BaseModel):
__tablename__ = 'shares'
password = Column(String(255), comment="分享密码(加密后)")
is_dir = Column(Boolean, nullable=False, server_default=text('false'), comment="是否为目录分享")
source_id = Column(Integer, nullable=False, comment="源文件或目录的ID")
views = Column(Integer, nullable=False, server_default='0', comment="浏览次数")
downloads = Column(Integer, nullable=False, server_default='0', comment="下载次数")
remain_downloads = Column(Integer, comment="剩余下载次数 (NULL为不限制)")
expires = Column(DateTime, comment="过期时间 (NULL为永不过期)")
preview_enabled = Column(Boolean, nullable=False, server_default=text('true'), comment="是否允许预览")
source_name = Column(String(255), index=True, comment="源名称(冗余字段,便于展示)")
score = Column(Integer, nullable=False, server_default='0', comment="分享评分/权重")
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="创建分享的用户ID")
class Task(BaseModel):
__tablename__ = 'tasks'
status = Column(Integer, nullable=False, server_default='0', comment="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误")
type = Column(Integer, nullable=False, comment="任务类型")
progress = Column(Integer, nullable=False, server_default='0', comment="任务进度 (0-100)")
error = Column(Text, comment="错误信息")
props = Column(Text, comment="任务属性 (JSON格式)")
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID")
class User(BaseModel):
__tablename__ = 'users'
email = Column(String(100), nullable=False, unique=True, index=True, comment="用户邮箱,唯一")
nick = Column(String(50), comment="用户昵称")
password = Column(String(255), nullable=False, comment="用户密码(加密后)")
status = Column(Integer, nullable=False, server_default='0', comment="用户状态: 0=正常, 1=未激活, 2=封禁")
storage = Column(BigInteger, nullable=False, server_default='0', comment="已用存储空间(字节)")
two_factor = Column(String(255), comment="两步验证密钥")
avatar = Column(String(255), comment="头像地址")
options = Column(Text, comment="用户个人设置 (JSON格式)")
authn = Column(Text, comment="WebAuthn 凭证")
open_id = Column(String(255), unique=True, index=True, nullable=True, comment="第三方登录OpenID")
score = Column(Integer, nullable=False, server_default='0', comment="用户积分")
group_expires = Column(DateTime, comment="当前用户组过期时间")
phone = Column(String(255), unique=True, nullable=True, index=True, comment="手机号")
group_id = Column(Integer, ForeignKey('groups.id'), nullable=False, index=True, comment="所属用户组ID")
previous_group_id = Column(Integer, ForeignKey('groups.id'), nullable=True, comment="之前的用户组ID用于过期后恢复")
class Redeem(BaseModel):
__tablename__ = 'redeems'
type = Column(Integer, nullable=False, comment="兑换码类型")
product_id = Column(BigInteger, comment="关联的商品/权益ID")
num = Column(Integer, nullable=False, server_default='1', comment="可兑换数量/时长等")
code = Column(Text, nullable=False, unique=True, index=True, comment="兑换码,唯一")
used = Column(Boolean, nullable=False, server_default=text('false'), comment="是否已使用")
class Report(BaseModel):
__tablename__ = 'reports'
share_id = Column(Integer, ForeignKey('shares.id'), index=True, nullable=False, comment="被举报的分享ID")
reason = Column(Integer, nullable=False, comment="举报原因代码")
description = Column(String(255), comment="补充描述")
class SourceLink(BaseModel):
__tablename__ = 'source_links'
file_id = Column(Integer, ForeignKey('files.id'), nullable=False, index=True, comment="关联的文件ID")
name = Column(String(255), nullable=False, comment="链接名称")
downloads = Column(Integer, nullable=False, server_default='0', comment="通过此链接的下载次数")
class StoragePack(BaseModel):
__tablename__ = 'storage_packs'
name = Column(String(255), nullable=False, comment="容量包名称")
active_time = Column(DateTime, comment="激活时间")
expired_time = Column(DateTime, index=True, comment="过期时间")
size = Column(BigInteger, nullable=False, comment="容量包大小(字节)")
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID")
class Tag(BaseModel):
__tablename__ = 'tags'
name = Column(String(255), nullable=False, comment="标签名称")
icon = Column(String(255), comment="标签图标")
color = Column(String(255), comment="标签颜色")
type = Column(Integer, nullable=False, server_default='0', comment="标签类型: 0=手动, 1=自动")
expression = Column(Text, comment="自动标签的匹配表达式")
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID")
__table_args__ = (
UniqueConstraint('name', 'user_id', name='uq_tag_name_user'),
)
class WebDAV(BaseModel):
__tablename__ = 'webdavs'
name = Column(String(255), nullable=False, comment="WebDAV账户名")
password = Column(String(255), nullable=False, comment="WebDAV密码加密后")
root = Column(Text, nullable=False, server_default="'/'", comment="根目录路径")
readonly = Column(Boolean, nullable=False, server_default=text('false'), comment="是否只读")
use_proxy = Column(Boolean, nullable=False, server_default=text('false'), comment="是否使用代理下载")
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID")
__table_args__ = (
UniqueConstraint('name', 'user_id', name='uq_webdav_name_user'),
)

45
models/node.py Normal file
View File

@@ -0,0 +1,45 @@
# my_project/models/node.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .download import Download
class Node(BaseModel, table=True):
__tablename__ = 'nodes'
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点状态: 0=正常, 1=离线")
name: str = Field(max_length=255, unique=True, description="节点名称")
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点类型")
server: str = Field(max_length=255, description="节点地址IP或域名")
slave_key: Optional[str] = Field(default=None, description="从机通讯密钥")
master_key: Optional[str] = Field(default=None, description="主机通讯密钥")
aria2_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否启用Aria2")
aria2_options: Optional[str] = Field(default=None, description="Aria2配置 (JSON格式)")
rank: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点排序权重")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
# 关系
downloads: list["Download"] = Relationship(back_populates="node")

47
models/order.py Normal file
View File

@@ -0,0 +1,47 @@
# my_project/models/order.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
class Order(BaseModel, table=True):
__tablename__ = 'orders'
order_no: str = Field(max_length=255, unique=True, index=True, description="订单号,唯一")
type: int = Field(description="订单类型")
method: Optional[str] = Field(default=None, max_length=255, description="支付方式")
product_id: Optional[int] = Field(default=None, description="商品ID")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="购买数量")
name: str = Field(max_length=255, description="商品名称")
price: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单价格(分)")
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单状态: 0=待支付, 1=已完成, 2=已取消")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
# 外键
user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID")
# 关系
user: "User" = Relationship(back_populates="orders")

52
models/policy.py Normal file
View File

@@ -0,0 +1,52 @@
# my_project/models/policy.py
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .file import File
from .folder import Folder
class Policy(BaseModel, table=True):
__tablename__ = 'policies'
name: str = Field(max_length=255, unique=True, description="策略名称")
type: str = Field(max_length=255, description="存储类型 (e.g. 'local', 's3')")
server: Optional[str] = Field(default=None, max_length=255, description="服务器地址(本地策略为路径)")
bucket_name: Optional[str] = Field(default=None, max_length=255, description="存储桶名称")
is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}, description="是否为私有空间")
base_url: Optional[str] = Field(default=None, max_length=255, description="访问文件的基础URL")
access_key: Optional[str] = Field(default=None, description="Access Key")
secret_key: Optional[str] = Field(default=None, description="Secret Key")
max_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="允许上传的最大文件尺寸(字节)")
auto_rename: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否自动重命名")
dir_name_rule: Optional[str] = Field(default=None, max_length=255, description="目录命名规则")
file_name_rule: Optional[str] = Field(default=None, max_length=255, description="文件命名规则")
is_origin_link_enable: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否开启源链接访问")
options: Optional[str] = Field(default=None, description="其他选项 (JSON格式)")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
# 关系
files: List["File"] = Relationship(back_populates="policy")
folders: List["Folder"] = Relationship(back_populates="policy")

35
models/redeem.py Normal file
View File

@@ -0,0 +1,35 @@
# my_project/models/redeem.py
from typing import Optional
from sqlmodel import Field, text, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
class Redeem(BaseModel, table=True):
__tablename__ = 'redeems'
type: int = Field(description="兑换码类型")
product_id: Optional[int] = Field(default=None, description="关联的商品/权益ID")
num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等")
code: str = Field(unique=True, index=True, description="兑换码,唯一")
used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)

41
models/report.py Normal file
View File

@@ -0,0 +1,41 @@
# my_project/models/report.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .share import Share
class Report(BaseModel, table=True):
__tablename__ = 'reports'
reason: int = Field(description="举报原因代码")
description: Optional[str] = Field(default=None, max_length=255, description="补充描述")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
# 外键
share_id: int = Field(foreign_key="shares.id", index=True, description="被举报的分享ID")
# 关系
share: "Share" = Relationship(back_populates="reports")

View File

@@ -1,7 +1,20 @@
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Union, Optional from typing import Literal, Union, Optional
from uuid import uuid4
class ResponseModel(BaseModel): class ResponseModel(BaseModel):
code: int = Field(default=0, description="系统内部状态码, 0表示成功其他表示失败", lt=60000, gt=0) code: int = Field(default=0, description="系统内部状态码, 0表示成功其他表示失败", lt=60000, gt=0)
data: Union[dict, list, str, int, float, None] = Field(None, description="响应数据") data: Union[dict, list, str, int, float, None] = Field(None, description="响应数据")
msg: Optional[str] = Field(default=None, description="响应消息,可以是错误消息或信息提示") msg: Optional[str] = Field(default=None, description="响应消息,可以是错误消息或信息提示")
instance_id: str = Field(default_factory=lambda: str(uuid4()), description="实例ID用于标识请求的唯一性")
class SiteConfigModel(ResponseModel):
title: str = Field(default="DiskNext", description="网站标题")
themes: dict = Field(default_factory=dict, description="网站主题配置")
default_theme: str = Field(default="default", description="默认主题RGB色号")
site_notice: Optional[str] = Field(default=None, description="网站公告")
user: dict = Field(default_factory=dict, description="用户信息")
logo_light: Optional[str] = Field(default=None, description="网站Logo URL")
logo_dark: Optional[str] = Field(default=None, description="网站Logo URL深色模式")
captcha_type: Literal['none', 'default', 'gcaptcha', 'cloudflare turnstile'] = Field(default='none', description="验证码类型")
captcha_key: Optional[str] = Field(default=None, description="验证码密钥")

34
models/setting.py Normal file
View File

@@ -0,0 +1,34 @@
# my_project/models/setting.py
from typing import Optional
from sqlmodel import Field, UniqueConstraint, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
class Setting(BaseModel, table=True):
__tablename__ = 'settings'
__table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),)
type: str = Field(max_length=255, description="设置类型/分组")
name: str = Field(max_length=255, description="设置项名称")
value: Optional[str] = Field(default=None, description="设置值")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)

61
models/share.py Normal file
View File

@@ -0,0 +1,61 @@
# my_project/models/share.py
from typing import Optional, TYPE_CHECKING
from datetime import datetime
from sqlmodel import Field, Relationship, text, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
from .report import Report
class Share(BaseModel, table=True):
__tablename__ = 'shares'
password: Optional[str] = Field(default=None, max_length=255, description="分享密码(加密后)")
is_dir: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否为目录分享")
source_id: int = Field(description="源文件或目录的ID")
views: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="浏览次数")
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载次数")
remain_downloads: Optional[int] = Field(default=None, description="剩余下载次数 (NULL为不限制)")
expires: Optional[datetime] = Field(default=None, description="过期时间 (NULL为永不过期)")
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}, description="是否允许预览")
source_name: Optional[str] = Field(default=None, max_length=255, index=True, description="源名称(冗余字段,便于展示)")
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="兑换此分享所需的积分")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
delete_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=True,
comment="删除时间",
),
)
# 外键
user_id: int = Field(foreign_key="users.id", index=True, description="创建分享的用户ID")
# 关系
user: "User" = Relationship(back_populates="shares")
reports: list["Report"] = Relationship(back_populates="share")

50
models/source_link.py Normal file
View File

@@ -0,0 +1,50 @@
# my_project/models/source_link.py
from typing import TYPE_CHECKING, Optional
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .file import File
class SourceLink(BaseModel, table=True):
__tablename__ = 'source_links'
name: str = Field(max_length=255, description="链接名称")
downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="通过此链接的下载次数")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
delete_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=True,
comment="删除时间",
),
)
# 外键
file_id: int = Field(foreign_key="files.id", index=True, description="关联的文件ID")
# 关系
file: "File" = Relationship(back_populates="source_links")

52
models/storage_pack.py Normal file
View File

@@ -0,0 +1,52 @@
# my_project/models/storage_pack.py
from typing import Optional, TYPE_CHECKING
from datetime import datetime
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
class StoragePack(BaseModel, table=True):
__tablename__ = 'storage_packs'
name: str = Field(max_length=255, description="容量包名称")
active_time: Optional[datetime] = Field(default=None, description="激活时间")
expired_time: Optional[datetime] = Field(default=None, index=True, description="过期时间")
size: int = Field(description="容量包大小(字节)")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
delete_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=True,
comment="删除时间",
),
)
# 外键
user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID")
# 关系
user: "User" = Relationship(back_populates="storage_packs")

54
models/tag.py Normal file
View File

@@ -0,0 +1,54 @@
# my_project/models/tag.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
class Tag(BaseModel, table=True):
__tablename__ = 'tags'
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),)
name: str = Field(max_length=255, description="标签名称")
icon: Optional[str] = Field(default=None, max_length=255, description="标签图标")
color: Optional[str] = Field(default=None, max_length=255, description="标签颜色")
type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="标签类型: 0=手动, 1=自动")
expression: Optional[str] = Field(default=None, description="自动标签的匹配表达式")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
delete_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=True,
comment="删除时间",
),
)
# 外键
user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID")
# 关系
user: "User" = Relationship(back_populates="tags")

55
models/task.py Normal file
View File

@@ -0,0 +1,55 @@
# my_project/models/task.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
from .download import Download
class Task(BaseModel, table=True):
__tablename__ = 'tasks'
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误")
type: int = Field(description="任务类型")
progress: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务进度 (0-100)")
error: Optional[str] = Field(default=None, description="错误信息")
props: Optional[str] = Field(default=None, description="任务属性 (JSON格式)")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
delete_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=True,
comment="删除时间",
),
)
# 外键
user_id: "int" = Field(foreign_key="users.id", index=True, description="所属用户ID")
# 关系
user: "User" = Relationship(back_populates="tasks")
downloads: list["Download"] = Relationship(back_populates="task")

83
models/user.py Normal file
View File

@@ -0,0 +1,83 @@
# my_project/models/user.py
from typing import Optional, TYPE_CHECKING
from datetime import datetime
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
# TYPE_CHECKING 用于解决循环导入问题,只在类型检查时导入
if TYPE_CHECKING:
from .group import Group
from .download import Download
from .file import File
from .folder import Folder
from .order import Order
from .share import Share
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .webdav import WebDAV
class User(BaseModel, table=True):
__tablename__ = 'users'
email: str = Field(max_length=100, unique=True, index=True, description="用户邮箱,唯一")
nick: Optional[str] = Field(default=None, max_length=50, description="用户昵称")
password: str = Field(max_length=255, description="用户密码(加密后)")
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="用户状态: 0=正常, 1=未激活, 2=封禁")
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已用存储空间(字节)")
two_factor: Optional[str] = Field(default=None, max_length=255, description="两步验证密钥")
avatar: Optional[str] = Field(default=None, max_length=255, description="头像地址")
options: Optional[str] = Field(default=None, description="用户个人设置 (JSON格式)")
authn: Optional[str] = Field(default=None, description="WebAuthn 凭证")
open_id: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="第三方登录OpenID")
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="用户积分")
group_expires: Optional[datetime] = Field(default=None, description="当前用户组过期时间")
phone: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="手机号")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
delete_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=True,
comment="删除时间",
),
)
# 外键
group_id: int = Field(foreign_key="groups.id", index=True, description="所属用户组ID")
previous_group_id: Optional[int] = Field(default=None, foreign_key="groups.id", description="之前的用户组ID用于过期后恢复")
# 关系
group: "Group" = Relationship(back_populates="users")
previous_group: Optional["Group"] = Relationship(back_populates="previous_users")
downloads: list["Download"] = Relationship(back_populates="user")
files: list["File"] = Relationship(back_populates="user")
folders: list["Folder"] = Relationship(back_populates="owner")
orders: list["Order"] = Relationship(back_populates="user")
shares: list["Share"] = Relationship(back_populates="user")
storage_packs: list["StoragePack"] = Relationship(back_populates="user")
tags: list["Tag"] = Relationship(back_populates="user")
tasks: list["Task"] = Relationship(back_populates="user")
webdavs: list["WebDAV"] = Relationship(back_populates="user")

54
models/webdav.py Normal file
View File

@@ -0,0 +1,54 @@
# my_project/models/webdav.py
from typing import Optional, TYPE_CHECKING
from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime
from .base import BaseModel
from datetime import datetime
if TYPE_CHECKING:
from .user import User
class WebDAV(BaseModel, table=True):
__tablename__ = 'webdavs'
__table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),)
name: str = Field(max_length=255, description="WebDAV账户名")
password: str = Field(max_length=255, description="WebDAV密码加密后")
root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径")
readonly: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否只读")
use_proxy: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否使用代理下载")
created_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
comment="创建时间",
),
)
updated_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=False,
server_default=func.now(),
onupdate=func.now(),
comment="更新时间",
),
)
delete_at: Optional[datetime] = Field(
default=None,
sa_column=Column(
DateTime,
nullable=True,
comment="删除时间",
),
)
# 外键
user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID")
# 关系
user: "User" = Relationship(back_populates="webdavs")

39
pkg/lifespan/lifespan.py Normal file
View File

@@ -0,0 +1,39 @@
from fastapi import FastAPI
from contextlib import asynccontextmanager
__on_startup: list[callable] = []
__on_shutdown: list[callable] = []
def add_startup(func: callable):
"""
注册一个函数,在应用启动时调用。
:param func: 需要注册的函数。它应该是一个异步函数。
"""
__on_startup.append(func)
def add_shutdown(func: callable):
"""
注册一个函数,在应用关闭时调用。
:param func: 需要注册的函数。
"""
__on_shutdown.append(func)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""
应用程序的生命周期管理器。
此函数在应用启动时执行所有注册的启动函数,
并在应用关闭时执行所有注册的关闭函数。
"""
# Execute all startup functions
for func in __on_startup:
await func()
yield
# Execute all shutdown functions
for func in __on_shutdown:
await func()

211
pkg/log/log.py Normal file
View File

@@ -0,0 +1,211 @@
from rich import print
from rich.console import Console
from rich.markdown import Markdown
from configparser import ConfigParser
from typing import Literal, Optional, Dict, Union
from enum import Enum
import time
import os
import inspect
class LogLevelEnum(str, Enum):
DEBUG = 'debug'
INFO = 'info'
WARNING = 'warning'
ERROR = 'error'
SUCCESS = 'success'
# 默认日志级别
LogLevel = LogLevelEnum.INFO
# 日志文件路径
LOG_FILE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
# 是否启用文件日志
ENABLE_FILE_LOG = False
def set_log_level(level: Union[str, LogLevelEnum]) -> None:
"""设置日志级别"""
global LogLevel
if isinstance(level, str):
try:
LogLevel = LogLevelEnum(level.lower())
except ValueError:
print(f"[bold red]无效的日志级别: {level},使用默认级别: {LogLevel}[/bold red]")
else:
LogLevel = level
def enable_file_log(enable: bool = True) -> None:
"""启用或禁用文件日志"""
global ENABLE_FILE_LOG
ENABLE_FILE_LOG = enable
if enable and not os.path.exists(LOG_FILE_PATH):
try:
os.makedirs(LOG_FILE_PATH)
except Exception as e:
print(f"[bold red]创建日志目录失败: {e}[/bold red]")
ENABLE_FILE_LOG = False
def truncate_path(full_path: str, marker: str = "HeyAuth") -> str:
"""截断路径只保留从marker开始的部分"""
try:
marker_index = full_path.find(marker)
if marker_index != -1:
return '.' + full_path[marker_index + len(marker):]
return full_path
except Exception:
return full_path
def get_caller_info(depth: int = 2) -> tuple:
"""获取调用者信息"""
try:
frame = inspect.currentframe()
# 向上查找指定深度的调用帧
for _ in range(depth):
if frame.f_back is None:
break
frame = frame.f_back
filename = frame.f_code.co_filename
lineno = frame.f_lineno
return truncate_path(filename), lineno
except Exception:
return "<unknown>", 0
finally:
# 确保引用被释放
del frame
def log(level: str = 'debug', message: str = ''):
"""
输出日志
---
通过传入的`level`和`message`参数,输出不同级别的日志信息。<br>
`level`参数为日志级别,支持`红色error`、`紫色info`、`绿色success`、`黄色warning`、`淡蓝色debug`。<br>
`message`参数为日志信息。<br>
"""
level_colors: Dict[str, str] = {
'debug': '[bold cyan][DEBUG][/bold cyan]',
'info': '[bold blue][INFO][/bold blue]',
'warning': '[bold yellow][WARN][/bold yellow]',
'error': '[bold red][ERROR][/bold red]',
'success': '[bold green][SUCCESS][/bold green]'
}
level_value = level.lower()
lv = level_colors.get(level_value, '[bold magenta][UNKNOWN][/bold magenta]')
# 获取调用者信息
filename, lineno = get_caller_info(3) # 考虑lambda调用和包装函数深度为3
timestamp = time.strftime('%Y/%m/%d %H:%M:%S %p', time.localtime())
log_message = f"{lv}\t{timestamp} [bold]From {filename}, line {lineno}[/bold] {message}"
# 根据日志级别判断是否输出
global LogLevel
should_log = False
if level_value == 'debug' and LogLevel == LogLevelEnum.DEBUG:
should_log = True
elif level_value == 'info' and LogLevel in [LogLevelEnum.DEBUG, LogLevelEnum.INFO]:
should_log = True
elif level_value == 'warning' and LogLevel in [LogLevelEnum.DEBUG, LogLevelEnum.INFO, LogLevelEnum.WARNING]:
should_log = True
elif level_value == 'error':
should_log = True
elif level_value == 'success':
should_log = False
if should_log:
print(log_message)
# 文件日志记录
if ENABLE_FILE_LOG:
try:
# 去除rich格式化标记
clean_message = f"{level_value.upper()}\t{timestamp} From {filename}, line {lineno} {message}"
log_file = os.path.join(LOG_FILE_PATH, f"{time.strftime('%Y%m%d')}.log")
with open(log_file, 'a', encoding='utf-8') as f:
f.write(f"{clean_message}\n")
except Exception as e:
print(f"[bold red]写入日志文件失败: {e}[/bold red]")
# 便捷日志函数
debug = lambda message: log('debug', message)
info = lambda message: log('info', message)
warning = lambda message: log('warn', message)
error = lambda message: log('error', message)
success = lambda message: log('success', message)
def load_config(config_path: str) -> bool:
"""从配置文件加载日志配置"""
try:
if not os.path.exists(config_path):
return False
config = ConfigParser()
config.read(config_path, encoding='utf-8')
if 'log' in config:
log_config = config['log']
if 'level' in log_config:
set_log_level(log_config['level'])
if 'file_log' in log_config:
enable_file_log(log_config.getboolean('file_log'))
if 'log_path' in log_config:
global LOG_FILE_PATH
custom_path = log_config['log_path']
if os.path.exists(custom_path) or os.makedirs(custom_path, exist_ok=True):
LOG_FILE_PATH = custom_path
return True
except Exception as e:
error(f"加载日志配置失败: {e}")
return False
def title(title: str = '海枫授权系统 HeyAuth', size: Optional[Literal['h1', 'h2', 'h3', 'h4', 'h5']] = 'h1'):
"""
输出标题
---
通过传入的`title`参数,输出一个整行的标题。<br>
`title`参数为标题内容。<br>
"""
try:
console = Console()
markdown_sizes = {
'h1': '# ',
'h2': '## ',
'h3': '### ',
'h4': '#### ',
'h5': '##### '
}
markdown_tag = markdown_sizes.get(size, '# ')
console.print(Markdown(markdown_tag + title))
except Exception as e:
error(f"输出标题失败: {e}")
finally:
if 'console' in locals():
del console
if True:
pass
if __name__ == '__main__':
# 测试代码
title('海枫授权系统 日志组件测试', 'h1')
title('测试h2标题', 'h2')
title('测试h3标题', 'h3')
title('测试h4标题', 'h4')
title('测试h5标题', 'h5')
print("\n默认日志级别(INFO)测试:")
debug('这是一个debug日志') # 不会显示
info('这是一个info日志')
warning('这是一个warning日志')
error('这是一个error日志')
success('这是一个success日志')
print("\n设置为DEBUG级别测试:")
set_log_level(LogLevelEnum.DEBUG)
debug('这是一个debug日志') # 现在会显示
print("\n启用文件日志测试:")
enable_file_log()
info('此日志将同时记录到文件')

View File

@@ -1,15 +0,0 @@
from fastapi.testclient import TestClient
from main import app
client = TestClient(app)
def test_read_main():
from pkg.conf.appmeta import BackendVersion
response = client.get("/api/site/ping")
assert response.status_code == 200
assert response.json() == {
"code": 0,
'data': BackendVersion,
'msg': None}

9
tests/conftest.py Normal file
View File

@@ -0,0 +1,9 @@
"""
Pytest配置文件
"""
import pytest
import os
import sys
# 添加项目根目录到Python路径确保可以导入项目模块
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

8
tests/test_database.py Normal file
View File

@@ -0,0 +1,8 @@
from models import database
import pytest
@pytest.mark.asyncio
async def test_initialize_db():
"""Fixture to initialize the database before tests."""
await database.init_db()

23
tests/test_main.py Normal file
View File

@@ -0,0 +1,23 @@
from fastapi.testclient import TestClient
from main import app
client = TestClient(app)
def test_read_main():
from pkg.conf.appmeta import BackendVersion
import uuid
response = client.get("/api/site/ping")
json_response = response.json()
assert response.status_code == 200
assert json_response['code'] == 0
assert json_response['data'] == BackendVersion
assert json_response['msg'] is None
assert 'instance_id' in json_response
try:
uuid.UUID(json_response['instance_id'], version=4)
except (ValueError, TypeError):
assert False, f"instance_id is not a valid UUID4: {json_response['instance_id']}"