diff --git a/.gitignore b/.gitignore index 267ca8b..c52ca08 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,6 @@ __pycache__/ *.pyo *.pyd -*.code-workspace \ No newline at end of file +*.code-workspace + +*.db \ No newline at end of file diff --git a/clean.py b/clean.py new file mode 100644 index 0000000..56fbb1a --- /dev/null +++ b/clean.py @@ -0,0 +1,243 @@ +import os +import shutil +from pkg.log import log as log +import argparse +from typing import List, Tuple, Set +import time + +Version = "2.1.0" + +# 默认排除的目录 +DEFAULT_EXCLUDE_DIRS = {"venv", "env", ".venv", ".env", "node_modules"} + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description='清理 Python 缓存文件') + parser.add_argument('-p', '--path', default=os.getcwd(), + help='要清理的目录路径,默认为当前工作目录') + parser.add_argument('-y', '--yes', action='store_true', + help='自动确认所有操作,不再询问') + parser.add_argument('--no-pycache', action='store_true', + help='不清理 __pycache__ 目录') + parser.add_argument('--no-nicegui', action='store_true', + help='不清理 .nicegui 目录') + parser.add_argument('--no-testdb', action='store_true', + help='不清理 test.db 文件') + parser.add_argument('--pyc', action='store_true', + help='清理 .pyc 文件') + parser.add_argument('--pytest-cache', action='store_true', + help='清理 .pytest_cache 目录') + parser.add_argument('--exclude', type=str, default="", + help='排除的目录,多个目录用逗号分隔') + parser.add_argument('--log-file', + help='指定日志文件路径') + parser.add_argument('--dry-run', action='store_true', + help='仅列出将要删除的文件,不实际删除') + return parser.parse_args() + +def confirm_action(message: str, auto_yes: str = False) -> bool: + if auto_yes: + return True + return input(f"{message} (y/N): ").lower() == 'y' + +def safe_remove(path: str, dry_run: bool = False) -> Tuple[bool, str]: + """安全删除文件或目录""" + try: + if dry_run: + return True, f"DRY RUN: 将删除 {path}" + + if os.path.isdir(path): + shutil.rmtree(path) + elif os.path.isfile(path): + os.remove(path) + return True, "" + except PermissionError as e: + return False, f"权限错误: {e}" + except OSError as e: + return False, f"系统错误: {e}" + except Exception as e: + return False, f"未知错误: {e}" + +def get_excluded_dirs(exclude_arg: str) -> Set[str]: + """获取要排除的目录列表""" + result = set(DEFAULT_EXCLUDE_DIRS) + if exclude_arg: + for item in exclude_arg.split(','): + item = item.strip() + if item: + result.add(item) + return result + +def clean_pycache(root_dir: str, exclude_dirs: Set[str], dry_run: bool = False) -> List[str]: + """清理 __pycache__ 目录""" + log.info("开始清理 __pycache__ 目录...") + cleaned_paths = [] + + for dirpath, dirnames, _ in os.walk(root_dir): + # 排除指定目录 + for exclude in exclude_dirs: + if exclude in dirnames: + dirnames.remove(exclude) + + if "__pycache__" in dirnames: + pycache_dir = os.path.join(dirpath, "__pycache__") + success, error = safe_remove(pycache_dir, dry_run) + if success: + cleaned_paths.append(pycache_dir) + else: + log.error(f"无法清理 {pycache_dir}: {error}") + + return cleaned_paths + +def clean_pyc_files(root_dir: str, exclude_dirs: Set[str], dry_run: bool = False) -> List[str]: + """清理 .pyc 文件""" + log.info("开始清理 .pyc 文件...") + cleaned_files = [] + + for dirpath, dirnames, filenames in os.walk(root_dir): + # 排除指定目录 + for exclude in exclude_dirs: + if exclude in dirnames: + dirnames.remove(exclude) + + for filename in filenames: + if filename.endswith('.pyc'): + file_path = os.path.join(dirpath, filename) + success, error = safe_remove(file_path, dry_run) + if success: + cleaned_files.append(file_path) + else: + log.error(f"无法清理 {file_path}: {error}") + + return cleaned_files + +def clean_pytest_cache(root_dir: str, exclude_dirs: Set[str], dry_run: bool = False) -> List[str]: + """清理 .pytest_cache 目录""" + log.info("开始清理 .pytest_cache 目录...") + cleaned_paths = [] + + for dirpath, dirnames, _ in os.walk(root_dir): + # 排除指定目录 + for exclude in exclude_dirs: + if exclude in dirnames: + dirnames.remove(exclude) + + if ".pytest_cache" in dirnames: + cache_dir = os.path.join(dirpath, ".pytest_cache") + success, error = safe_remove(cache_dir, dry_run) + if success: + cleaned_paths.append(cache_dir) + else: + log.error(f"无法清理 {cache_dir}: {error}") + + return cleaned_paths + +def clean_nicegui(root_dir: str, dry_run: bool = False) -> Tuple[bool, str]: + """清理 .nicegui 目录""" + log.info("开始清理 .nicegui 目录...") + nicegui_dir = os.path.join(root_dir, ".nicegui") + if os.path.exists(nicegui_dir) and os.path.isdir(nicegui_dir): + success, error = safe_remove(nicegui_dir, dry_run) + if success: + return True, nicegui_dir + else: + log.error(f"无法清理 {nicegui_dir}: {error}") + return False, nicegui_dir + return False, nicegui_dir + +def clean_testdb(root_dir: str, dry_run: bool = False) -> Tuple[bool, str, str]: + """清理测试数据库文件""" + log.info("开始清理 test.db 文件...") + test_db = os.path.join(root_dir, "test.db") + if os.path.exists(test_db) and os.path.isfile(test_db): + success, error = safe_remove(test_db, dry_run) + if success: + return True, test_db, "" + else: + return False, test_db, error + return False, test_db, "文件不存在" + +def main(): + start_time = time.time() + args = parse_args() + root_dir = os.path.abspath(args.path) + exclude_dirs = get_excluded_dirs(args.exclude) + + # 设置日志文件 + if args.log_file: + log.set_log_file(args.log_file) + + log.title() + log.title(title=f"清理工具 Cleaner\t\tVersion:{Version}", size="h2") + print('') + + if not os.path.exists(root_dir): + log.error(f"目录不存在 Directory not exists: {root_dir}") + return 1 + + log.info(f"清理目录 Clean Directory: {root_dir}") + if args.dry_run: + log.warning("模拟运行模式: 将只列出要删除的文件,不会实际删除") + + if exclude_dirs: + log.info(f"排除目录: {', '.join(exclude_dirs)}") + + try: + total_cleaned = 0 + + # 清理 __pycache__ + if not args.no_pycache and confirm_action("是否清理 __pycache__ 目录?", args.yes): + cleaned = clean_pycache(root_dir, exclude_dirs, args.dry_run) + for path in cleaned: + log.info(f"已清理 Removed: {path}") + total_cleaned += len(cleaned) + + # 清理 .pyc 文件 + if args.pyc and confirm_action("是否清理 .pyc 文件?", args.yes): + cleaned = clean_pyc_files(root_dir, exclude_dirs, args.dry_run) + for path in cleaned: + log.info(f"已清理 Removed: {path}") + total_cleaned += len(cleaned) + + # 清理 .pytest_cache + if args.pytest_cache and confirm_action("是否清理 .pytest_cache 目录?", args.yes): + cleaned = clean_pytest_cache(root_dir, exclude_dirs, args.dry_run) + for path in cleaned: + log.info(f"已清理 Removed: {path}") + total_cleaned += len(cleaned) + + # 清理 .nicegui + if not args.no_nicegui and confirm_action("是否清理 .nicegui 目录?", args.yes): + cleaned, path = clean_nicegui(root_dir, args.dry_run) + if cleaned: + log.info(f"已清理 Removed: {path}") + total_cleaned += 1 + else: + log.debug(f"未找到 Not found: {path}") + + # 清理 test.db + if not args.no_testdb and confirm_action("是否清理 test.db 文件?", args.yes): + success, path, error = clean_testdb(root_dir, args.dry_run) + if success: + log.info(f"已清理 Removed: {path}") + total_cleaned += 1 + elif error == "文件不存在": + log.debug(f"未找到 Not found: {path}") + else: + log.error(f"清理失败 Failed: {error}") + + except KeyboardInterrupt: + log.warning("操作被用户中断") + return 1 + except Exception as e: + log.error(f"错误 Error: {e}") + return 1 + else: + elapsed_time = time.time() - start_time + if args.dry_run: + log.success(f"模拟清理结束,发现 {total_cleaned} 个可清理项目 (用时: {elapsed_time:.2f}秒)") + else: + log.success(f"清理结束,共清理 {total_cleaned} 个项目 (用时: {elapsed_time:.2f}秒)") + return 0 + +if __name__ == "__main__": + exit(main()) \ No newline at end of file diff --git a/main.py b/main.py index 84e0456..40427dc 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,10 @@ from fastapi import FastAPI from routers import routers from pkg.conf import appmeta +from models.database import init_db +from pkg.lifespan import lifespan + +lifespan.add_startup(init_db) app = FastAPI( title=appmeta.APP_NAME, @@ -9,7 +13,7 @@ app = FastAPI( version=appmeta.BackendVersion, openapi_tags=appmeta.tags_meta, license_info=appmeta.license_info, - + lifespan=lifespan.lifespan, ) for router in routers.Router: diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..fc977d1 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,30 @@ +# my_project/models/__init__.py + +from . import response + +# 将所有模型导入到这个包的命名空间中 +from .base import BaseModel +from .download import Download +from .file import File +from .folder import Folder +from .group import Group +from .node import Node +from .order import Order +from .policy import Policy +from .redeem import Redeem +from .report import Report +from .setting import Setting +from .share import Share +from .source_link import SourceLink +from .storage_pack import StoragePack +from .tag import Tag +from .task import Task +from .user import User +from .webdav import WebDAV + +# 可以定义一个 __all__ 列表来明确指定可以被 from .models import * 导入的内容 +__all__ = [ + "BaseModel", "Download", "File", "Folder", "Group", "Node", "Order", + "Policy", "Redeem", "Report", "Setting", "Share", "SourceLink", + "StoragePack", "Tag", "Task", "User", "WebDAV" +] \ No newline at end of file diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000..e1fd70a --- /dev/null +++ b/models/base.py @@ -0,0 +1,9 @@ +# my_project/models/base.py + +from typing import Optional +from sqlmodel import SQLModel, Field + +class BaseModel(SQLModel): + __abstract__ = True + + id: Optional[int] = Field(default=None, primary_key=True, description="主键ID") \ No newline at end of file diff --git a/models/database.py b/models/database.py new file mode 100644 index 0000000..1764bb2 --- /dev/null +++ b/models/database.py @@ -0,0 +1,31 @@ +# my_project/database.py + +from sqlmodel import SQLModel +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel.ext.asyncio.session import AsyncSession +from sqlalchemy.orm import sessionmaker + +ASYNC_DATABASE_URL = "sqlite+aiosqlite:///database.db" + +engine = create_async_engine( + ASYNC_DATABASE_URL, + echo=True, + connect_args={"check_same_thread": False} + if ASYNC_DATABASE_URL.startswith("sqlite") + else None, + future=True, + # pool_size=POOL_SIZE, + # max_overflow=64, +) + +_async_session_factory = sessionmaker(engine, class_=AsyncSession) + +async def get_session(): + async with _async_session_factory() as session: + yield session + +async def init_db(): + """初始化数据库""" + # 创建所有表 + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) \ No newline at end of file diff --git a/models/download.py b/models/download.py new file mode 100644 index 0000000..c52ef92 --- /dev/null +++ b/models/download.py @@ -0,0 +1,56 @@ +# my_project/models/download.py + +from typing import Optional, TYPE_CHECKING +from sqlmodel import Field, Relationship, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + from .task import Task + from .node import Node + +class Download(BaseModel, table=True): + __tablename__ = 'downloads' + + status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载状态: 0=进行中, 1=完成, 2=错误") + type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务类型") + source: str = Field(description="来源URL或标识") + total_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="总大小(字节)") + downloaded_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已下载大小(字节)") + g_id: Optional[str] = Field(default=None, index=True, description="Aria2 GID") + speed: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载速度 (bytes/s)") + parent: Optional[str] = Field(default=None, description="父任务标识") + attrs: Optional[str] = Field(default=None, description="额外属性 (JSON格式)") + error: Optional[str] = Field(default=None, description="错误信息") + dst: str = Field(description="目标存储路径") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + # 外键 + user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID") + task_id: Optional[int] = Field(default=None, foreign_key="tasks.id", index=True, description="关联的任务ID") + node_id: int = Field(foreign_key="nodes.id", index=True, description="执行下载的节点ID") + + # 关系 + user: "User" = Relationship(back_populates="downloads") + task: Optional["Task"] = Relationship(back_populates="downloads") + node: "Node" = Relationship(back_populates="downloads") \ No newline at end of file diff --git a/models/file.py b/models/file.py new file mode 100644 index 0000000..9af8416 --- /dev/null +++ b/models/file.py @@ -0,0 +1,53 @@ +# my_project/models/file.py + +from typing import Optional, TYPE_CHECKING +from sqlmodel import Field, Relationship, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + from .folder import Folder + from .policy import Policy + from .source_link import SourceLink + +class File(BaseModel, table=True): + __tablename__ = 'files' + + name: str = Field(max_length=255, description="文件名") + source_name: Optional[str] = Field(default=None, description="源文件名") + size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="文件大小(字节)") + pic_info: Optional[str] = Field(default=None, max_length=255, description="图片信息(如尺寸)") + upload_session_id: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="分块上传会话ID") + file_metadata: Optional[str] = Field(default=None, description="文件元数据 (JSON格式)") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + # 外键 + user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID") + folder_id: int = Field(foreign_key="folders.id", index=True, description="所在目录ID") + policy_id: int = Field(foreign_key="policies.id", index=True, description="所属存储策略ID") + + # 关系 + user: list["User"] = Relationship(back_populates="files") + folder: list["Folder"] = Relationship(back_populates="files") + policy: list["Policy"] = Relationship(back_populates="files") + source_links: list["SourceLink"] = Relationship(back_populates="file") \ No newline at end of file diff --git a/models/folder.py b/models/folder.py new file mode 100644 index 0000000..bf7e208 --- /dev/null +++ b/models/folder.py @@ -0,0 +1,52 @@ +# my_project/models/folder.py + +from typing import Optional, List, TYPE_CHECKING +from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + from .policy import Policy + from .file import File + +class Folder(BaseModel, table=True): + __tablename__ = 'folders' + __table_args__ = (UniqueConstraint("name", "parent_id", name="uq_folder_name_parent"),) + + name: str = Field(max_length=255, description="目录名") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + # 外键 + parent_id: Optional[int] = Field(default=None, foreign_key="folders.id", index=True, description="父目录ID") + owner_id: int = Field(foreign_key="users.id", index=True, description="所有者用户ID") + policy_id: int = Field(foreign_key="policies.id", index=True, description="所属存储策略ID") + + # 关系 + owner: "User" = Relationship(back_populates="folders") + policy: "Policy" = Relationship(back_populates="folders") + + # 自我引用关系 + parent: Optional["Folder"] = Relationship(back_populates="children", sa_relationship_kwargs={"remote_side": "Folder.id"}) + children: List["Folder"] = Relationship(back_populates="parent") + + files: List["File"] = Relationship(back_populates="folder") \ No newline at end of file diff --git a/models/group.py b/models/group.py new file mode 100644 index 0000000..c8d31b3 --- /dev/null +++ b/models/group.py @@ -0,0 +1,44 @@ +# my_project/models/group.py + +from typing import Optional, List, TYPE_CHECKING +from sqlmodel import Field, Relationship, text, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + +class Group(BaseModel, table=True): + __tablename__ = 'groups' + + name: str = Field(max_length=255, unique=True, description="用户组名") + policies: Optional[str] = Field(default=None, max_length=255, description="允许的策略ID列表,逗号分隔") + max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="最大存储空间(字节)") + share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否允许创建分享") + web_dav_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否允许使用WebDAV") + speed_limit: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="速度限制 (KB/s), 0为不限制") + options: Optional[str] = Field(default=None, description="其他选项 (JSON格式)") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + # 关系:一个组可以有多个用户 + users: List["User"] = Relationship(back_populates="group") + previous_users: List["User"] = Relationship(back_populates="previous_group") \ No newline at end of file diff --git a/models/model.py b/models/model.py deleted file mode 100644 index 824629e..0000000 --- a/models/model.py +++ /dev/null @@ -1,251 +0,0 @@ -from sqlalchemy import ( - Column, Integer, String, Text, BigInteger, Boolean, DateTime, - ForeignKey, func, text, UniqueConstraint -) -from sqlalchemy.orm import declarative_base - -Base = declarative_base() - -class BaseModel(Base): - __abstract__ = True - - id = Column(Integer, primary_key=True, comment="主键ID") - - created_at = Column( - DateTime, - server_default=func.now(), - comment="创建时间" - ) - - updated_at = Column( - DateTime, - server_default=func.now(), - onupdate=func.now(), - server_onupdate=func.now(), - comment="更新时间" - ) - - deleted_at = Column(DateTime, nullable=True, comment="软删除时间") - -class Download(BaseModel): - __tablename__ = 'downloads' - - status = Column(Integer, nullable=False, server_default='0', comment="下载状态: 0=进行中, 1=完成, 2=错误") - type = Column(Integer, nullable=False, server_default='0', comment="任务类型") - source = Column(Text, nullable=False, comment="来源URL或标识") - total_size = Column(BigInteger, nullable=False, server_default='0', comment="总大小(字节)") - downloaded_size = Column(BigInteger, nullable=False, server_default='0', comment="已下载大小(字节)") - g_id = Column(Text, index=True, comment="Aria2 GID") # GID经常用于查询,建议索引 - speed = Column(Integer, nullable=False, server_default='0', comment="下载速度 (bytes/s)") - parent = Column(Text, comment="父任务标识") - attrs = Column(Text, comment="额外属性 (JSON格式)") - error = Column(Text, comment="错误信息") - dst = Column(Text, nullable=False, comment="目标存储路径") - - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID") - task_id = Column(Integer, ForeignKey('tasks.id'), nullable=True, index=True, comment="关联的任务ID") - node_id = Column(Integer, ForeignKey('nodes.id'), nullable=False, index=True, comment="执行下载的节点ID") - -class File(BaseModel): - __tablename__ = 'files' - - name = Column(String(255), nullable=False, comment="文件名") - source_name = Column(Text, comment="源文件名") - size = Column(BigInteger, nullable=False, server_default='0', comment="文件大小(字节)") - pic_info = Column(String(255), comment="图片信息(如尺寸)") - upload_session_id = Column(String(255), unique=True, index=True, comment="分块上传会话ID") - metadata = Column(Text, comment="文件元数据 (JSON格式)") - - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID") - folder_id = Column(Integer, ForeignKey('folders.id'), nullable=False, index=True, comment="所在目录ID") - policy_id = Column(Integer, ForeignKey('policies.id'), nullable=False, index=True, comment="所属存储策略ID") - -class Folder(BaseModel): - __tablename__ = 'folders' - - name = Column(String(255), nullable=False, comment="目录名") - - parent_id = Column(Integer, ForeignKey('folders.id'), nullable=True, index=True, comment="父目录ID") - owner_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所有者用户ID") - policy_id = Column(Integer, ForeignKey('policies.id'), nullable=False, index=True, comment="所属存储策略ID") - - __table_args__ = ( - UniqueConstraint('name', 'parent_id', name='uq_folder_name_parent'), - ) - -class Group(BaseModel): - __tablename__ = 'groups' - - name = Column(String(255), nullable=False, unique=True, comment="用户组名") - policies = Column(String(255), comment="允许的策略ID列表,逗号分隔") - max_storage = Column(BigInteger, nullable=False, server_default='0', comment="最大存储空间(字节)") - - share_enabled = Column(Boolean, nullable=False, server_default=text('false'), comment="是否允许创建分享") - web_dav_enabled = Column(Boolean, nullable=False, server_default=text('false'), comment="是否允许使用WebDAV") - - speed_limit = Column(Integer, nullable=False, server_default='0', comment="速度限制 (KB/s), 0为不限制") - options = Column(Text, comment="其他选项 (JSON格式)") - -class Node(BaseModel): - __tablename__ = 'nodes' - - status = Column(Integer, nullable=False, server_default='0', comment="节点状态: 0=正常, 1=离线") - name = Column(String(255), nullable=False, unique=True, comment="节点名称") - type = Column(Integer, nullable=False, server_default='0', comment="节点类型") - server = Column(String(255), nullable=False, comment="节点地址(IP或域名)") - slave_key = Column(Text, comment="从机通讯密钥") - master_key = Column(Text, comment="主机通讯密钥") - aria2_enabled = Column(Boolean, nullable=False, server_default=text('false'), comment="是否启用Aria2") - aria2_options = Column(Text, comment="Aria2配置 (JSON格式)") - rank = Column(Integer, nullable=False, server_default='0', comment="节点排序权重") - -class Order(BaseModel): - __tablename__ = 'orders' - - order_no = Column(String(255), nullable=False, unique=True, index=True, comment="订单号,唯一") - type = Column(Integer, nullable=False, comment="订单类型") - method = Column(String(255), comment="支付方式") - product_id = Column(BigInteger, comment="商品ID") - num = Column(Integer, nullable=False, server_default='1', comment="购买数量") - name = Column(String(255), nullable=False, comment="商品名称") - price = Column(Integer, nullable=False, server_default='0', comment="订单价格(分)") - status = Column(Integer, nullable=False, server_default='0', comment="订单状态: 0=待支付, 1=已完成, 2=已取消") - - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID") - -class Policy(BaseModel): - __tablename__ = 'policies' - - name = Column(String(255), nullable=False, unique=True, comment="策略名称") - type = Column(String(255), nullable=False, comment="存储类型 (e.g. 'local', 's3')") - server = Column(String(255), comment="服务器地址(本地策略为路径)") - bucket_name = Column(String(255), comment="存储桶名称") - is_private = Column(Boolean, nullable=False, server_default=text('true'), comment="是否为私有空间") - base_url = Column(String(255), comment="访问文件的基础URL") - access_key = Column(Text, comment="Access Key") - secret_key = Column(Text, comment="Secret Key") - max_size = Column(BigInteger, nullable=False, server_default='0', comment="允许上传的最大文件尺寸(字节)") - auto_rename = Column(Boolean, nullable=False, server_default=text('false'), comment="是否自动重命名") - dir_name_rule = Column(String(255), comment="目录命名规则") - file_name_rule = Column(String(255), comment="文件命名规则") - is_origin_link_enable = Column(Boolean, nullable=False, server_default=text('false'), comment="是否开启源链接访问") - options = Column(Text, comment="其他选项 (JSON格式)") - -class Setting(BaseModel): - __tablename__ = 'settings' - - # 优化点: type和name的组合应该是唯一的 - type = Column(String(255), nullable=False, comment="设置类型/分组") - name = Column(String(255), nullable=False, comment="设置项名称") - value = Column(Text, comment="设置值") - - __table_args__ = ( - UniqueConstraint('type', 'name', name='uq_setting_type_name'), - ) - -class Share(BaseModel): - __tablename__ = 'shares' - - password = Column(String(255), comment="分享密码(加密后)") - is_dir = Column(Boolean, nullable=False, server_default=text('false'), comment="是否为目录分享") - source_id = Column(Integer, nullable=False, comment="源文件或目录的ID") - views = Column(Integer, nullable=False, server_default='0', comment="浏览次数") - downloads = Column(Integer, nullable=False, server_default='0', comment="下载次数") - remain_downloads = Column(Integer, comment="剩余下载次数 (NULL为不限制)") - expires = Column(DateTime, comment="过期时间 (NULL为永不过期)") - preview_enabled = Column(Boolean, nullable=False, server_default=text('true'), comment="是否允许预览") - source_name = Column(String(255), index=True, comment="源名称(冗余字段,便于展示)") - score = Column(Integer, nullable=False, server_default='0', comment="分享评分/权重") - - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="创建分享的用户ID") - -class Task(BaseModel): - __tablename__ = 'tasks' - - status = Column(Integer, nullable=False, server_default='0', comment="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误") - type = Column(Integer, nullable=False, comment="任务类型") - progress = Column(Integer, nullable=False, server_default='0', comment="任务进度 (0-100)") - error = Column(Text, comment="错误信息") - props = Column(Text, comment="任务属性 (JSON格式)") - - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID") - -class User(BaseModel): - __tablename__ = 'users' - - email = Column(String(100), nullable=False, unique=True, index=True, comment="用户邮箱,唯一") - nick = Column(String(50), comment="用户昵称") - password = Column(String(255), nullable=False, comment="用户密码(加密后)") - status = Column(Integer, nullable=False, server_default='0', comment="用户状态: 0=正常, 1=未激活, 2=封禁") - storage = Column(BigInteger, nullable=False, server_default='0', comment="已用存储空间(字节)") - two_factor = Column(String(255), comment="两步验证密钥") - avatar = Column(String(255), comment="头像地址") - options = Column(Text, comment="用户个人设置 (JSON格式)") - authn = Column(Text, comment="WebAuthn 凭证") - open_id = Column(String(255), unique=True, index=True, nullable=True, comment="第三方登录OpenID") - score = Column(Integer, nullable=False, server_default='0', comment="用户积分") - group_expires = Column(DateTime, comment="当前用户组过期时间") - phone = Column(String(255), unique=True, nullable=True, index=True, comment="手机号") - - group_id = Column(Integer, ForeignKey('groups.id'), nullable=False, index=True, comment="所属用户组ID") - previous_group_id = Column(Integer, ForeignKey('groups.id'), nullable=True, comment="之前的用户组ID(用于过期后恢复)") - -class Redeem(BaseModel): - __tablename__ = 'redeems' - - type = Column(Integer, nullable=False, comment="兑换码类型") - product_id = Column(BigInteger, comment="关联的商品/权益ID") - num = Column(Integer, nullable=False, server_default='1', comment="可兑换数量/时长等") - code = Column(Text, nullable=False, unique=True, index=True, comment="兑换码,唯一") - used = Column(Boolean, nullable=False, server_default=text('false'), comment="是否已使用") - -class Report(BaseModel): - __tablename__ = 'reports' - - share_id = Column(Integer, ForeignKey('shares.id'), index=True, nullable=False, comment="被举报的分享ID") - reason = Column(Integer, nullable=False, comment="举报原因代码") - description = Column(String(255), comment="补充描述") - -class SourceLink(BaseModel): - __tablename__ = 'source_links' - - file_id = Column(Integer, ForeignKey('files.id'), nullable=False, index=True, comment="关联的文件ID") - name = Column(String(255), nullable=False, comment="链接名称") - downloads = Column(Integer, nullable=False, server_default='0', comment="通过此链接的下载次数") - -class StoragePack(BaseModel): - __tablename__ = 'storage_packs' - - name = Column(String(255), nullable=False, comment="容量包名称") - active_time = Column(DateTime, comment="激活时间") - expired_time = Column(DateTime, index=True, comment="过期时间") - size = Column(BigInteger, nullable=False, comment="容量包大小(字节)") - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID") - -class Tag(BaseModel): - __tablename__ = 'tags' - - name = Column(String(255), nullable=False, comment="标签名称") - icon = Column(String(255), comment="标签图标") - color = Column(String(255), comment="标签颜色") - type = Column(Integer, nullable=False, server_default='0', comment="标签类型: 0=手动, 1=自动") - expression = Column(Text, comment="自动标签的匹配表达式") - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID") - - __table_args__ = ( - UniqueConstraint('name', 'user_id', name='uq_tag_name_user'), - ) - -class WebDAV(BaseModel): - __tablename__ = 'webdavs' - - name = Column(String(255), nullable=False, comment="WebDAV账户名") - password = Column(String(255), nullable=False, comment="WebDAV密码(加密后)") - root = Column(Text, nullable=False, server_default="'/'", comment="根目录路径") - readonly = Column(Boolean, nullable=False, server_default=text('false'), comment="是否只读") - use_proxy = Column(Boolean, nullable=False, server_default=text('false'), comment="是否使用代理下载") - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, index=True, comment="所属用户ID") - - __table_args__ = ( - UniqueConstraint('name', 'user_id', name='uq_webdav_name_user'), - ) \ No newline at end of file diff --git a/models/node.py b/models/node.py new file mode 100644 index 0000000..76b3949 --- /dev/null +++ b/models/node.py @@ -0,0 +1,45 @@ +# my_project/models/node.py + +from typing import Optional, TYPE_CHECKING +from sqlmodel import Field, Relationship, text, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .download import Download + +class Node(BaseModel, table=True): + __tablename__ = 'nodes' + + status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点状态: 0=正常, 1=离线") + name: str = Field(max_length=255, unique=True, description="节点名称") + type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点类型") + server: str = Field(max_length=255, description="节点地址(IP或域名)") + slave_key: Optional[str] = Field(default=None, description="从机通讯密钥") + master_key: Optional[str] = Field(default=None, description="主机通讯密钥") + aria2_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否启用Aria2") + aria2_options: Optional[str] = Field(default=None, description="Aria2配置 (JSON格式)") + rank: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点排序权重") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + # 关系 + downloads: list["Download"] = Relationship(back_populates="node") \ No newline at end of file diff --git a/models/order.py b/models/order.py new file mode 100644 index 0000000..5aa6e1a --- /dev/null +++ b/models/order.py @@ -0,0 +1,47 @@ +# my_project/models/order.py + +from typing import Optional, TYPE_CHECKING +from sqlmodel import Field, Relationship, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + +class Order(BaseModel, table=True): + __tablename__ = 'orders' + + order_no: str = Field(max_length=255, unique=True, index=True, description="订单号,唯一") + type: int = Field(description="订单类型") + method: Optional[str] = Field(default=None, max_length=255, description="支付方式") + product_id: Optional[int] = Field(default=None, description="商品ID") + num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="购买数量") + name: str = Field(max_length=255, description="商品名称") + price: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单价格(分)") + status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="订单状态: 0=待支付, 1=已完成, 2=已取消") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + # 外键 + user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID") + + # 关系 + user: "User" = Relationship(back_populates="orders") \ No newline at end of file diff --git a/models/policy.py b/models/policy.py new file mode 100644 index 0000000..2b0652b --- /dev/null +++ b/models/policy.py @@ -0,0 +1,52 @@ +# my_project/models/policy.py + +from typing import Optional, List, TYPE_CHECKING +from sqlmodel import Field, Relationship, text, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .file import File + from .folder import Folder + +class Policy(BaseModel, table=True): + __tablename__ = 'policies' + + name: str = Field(max_length=255, unique=True, description="策略名称") + type: str = Field(max_length=255, description="存储类型 (e.g. 'local', 's3')") + server: Optional[str] = Field(default=None, max_length=255, description="服务器地址(本地策略为路径)") + bucket_name: Optional[str] = Field(default=None, max_length=255, description="存储桶名称") + is_private: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}, description="是否为私有空间") + base_url: Optional[str] = Field(default=None, max_length=255, description="访问文件的基础URL") + access_key: Optional[str] = Field(default=None, description="Access Key") + secret_key: Optional[str] = Field(default=None, description="Secret Key") + max_size: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="允许上传的最大文件尺寸(字节)") + auto_rename: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否自动重命名") + dir_name_rule: Optional[str] = Field(default=None, max_length=255, description="目录命名规则") + file_name_rule: Optional[str] = Field(default=None, max_length=255, description="文件命名规则") + is_origin_link_enable: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否开启源链接访问") + options: Optional[str] = Field(default=None, description="其他选项 (JSON格式)") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + # 关系 + files: List["File"] = Relationship(back_populates="policy") + folders: List["Folder"] = Relationship(back_populates="policy") \ No newline at end of file diff --git a/models/redeem.py b/models/redeem.py new file mode 100644 index 0000000..eb40758 --- /dev/null +++ b/models/redeem.py @@ -0,0 +1,35 @@ +# my_project/models/redeem.py + +from typing import Optional +from sqlmodel import Field, text, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +class Redeem(BaseModel, table=True): + __tablename__ = 'redeems' + + type: int = Field(description="兑换码类型") + product_id: Optional[int] = Field(default=None, description="关联的商品/权益ID") + num: int = Field(default=1, sa_column_kwargs={"server_default": "1"}, description="可兑换数量/时长等") + code: str = Field(unique=True, index=True, description="兑换码,唯一") + used: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否已使用") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) \ No newline at end of file diff --git a/models/report.py b/models/report.py new file mode 100644 index 0000000..f64a27a --- /dev/null +++ b/models/report.py @@ -0,0 +1,41 @@ +# my_project/models/report.py + +from typing import Optional, TYPE_CHECKING +from sqlmodel import Field, Relationship, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .share import Share + +class Report(BaseModel, table=True): + __tablename__ = 'reports' + + reason: int = Field(description="举报原因代码") + description: Optional[str] = Field(default=None, max_length=255, description="补充描述") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + # 外键 + share_id: int = Field(foreign_key="shares.id", index=True, description="被举报的分享ID") + + # 关系 + share: "Share" = Relationship(back_populates="reports") \ No newline at end of file diff --git a/models/response.py b/models/response.py index c00211d..824ad4e 100644 --- a/models/response.py +++ b/models/response.py @@ -1,7 +1,20 @@ from pydantic import BaseModel, Field -from typing import Union, Optional +from typing import Literal, Union, Optional +from uuid import uuid4 class ResponseModel(BaseModel): code: int = Field(default=0, description="系统内部状态码, 0表示成功,其他表示失败", lt=60000, gt=0) data: Union[dict, list, str, int, float, None] = Field(None, description="响应数据") - msg: Optional[str] = Field(default=None, description="响应消息,可以是错误消息或信息提示") \ No newline at end of file + msg: Optional[str] = Field(default=None, description="响应消息,可以是错误消息或信息提示") + instance_id: str = Field(default_factory=lambda: str(uuid4()), description="实例ID,用于标识请求的唯一性") + +class SiteConfigModel(ResponseModel): + title: str = Field(default="DiskNext", description="网站标题") + themes: dict = Field(default_factory=dict, description="网站主题配置") + default_theme: str = Field(default="default", description="默认主题RGB色号") + site_notice: Optional[str] = Field(default=None, description="网站公告") + user: dict = Field(default_factory=dict, description="用户信息") + logo_light: Optional[str] = Field(default=None, description="网站Logo URL") + logo_dark: Optional[str] = Field(default=None, description="网站Logo URL(深色模式)") + captcha_type: Literal['none', 'default', 'gcaptcha', 'cloudflare turnstile'] = Field(default='none', description="验证码类型") + captcha_key: Optional[str] = Field(default=None, description="验证码密钥") \ No newline at end of file diff --git a/models/setting.py b/models/setting.py new file mode 100644 index 0000000..75fe289 --- /dev/null +++ b/models/setting.py @@ -0,0 +1,34 @@ +# my_project/models/setting.py + +from typing import Optional +from sqlmodel import Field, UniqueConstraint, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +class Setting(BaseModel, table=True): + __tablename__ = 'settings' + __table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),) + + type: str = Field(max_length=255, description="设置类型/分组") + name: str = Field(max_length=255, description="设置项名称") + value: Optional[str] = Field(default=None, description="设置值") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) \ No newline at end of file diff --git a/models/share.py b/models/share.py new file mode 100644 index 0000000..c37611d --- /dev/null +++ b/models/share.py @@ -0,0 +1,61 @@ +# my_project/models/share.py + +from typing import Optional, TYPE_CHECKING +from datetime import datetime +from sqlmodel import Field, Relationship, text, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + from .report import Report + +class Share(BaseModel, table=True): + __tablename__ = 'shares' + + password: Optional[str] = Field(default=None, max_length=255, description="分享密码(加密后)") + is_dir: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否为目录分享") + source_id: int = Field(description="源文件或目录的ID") + views: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="浏览次数") + downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载次数") + remain_downloads: Optional[int] = Field(default=None, description="剩余下载次数 (NULL为不限制)") + expires: Optional[datetime] = Field(default=None, description="过期时间 (NULL为永不过期)") + preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}, description="是否允许预览") + source_name: Optional[str] = Field(default=None, max_length=255, index=True, description="源名称(冗余字段,便于展示)") + score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="兑换此分享所需的积分") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + delete_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=True, + comment="删除时间", + ), + ) + + # 外键 + user_id: int = Field(foreign_key="users.id", index=True, description="创建分享的用户ID") + + # 关系 + user: "User" = Relationship(back_populates="shares") + reports: list["Report"] = Relationship(back_populates="share") \ No newline at end of file diff --git a/models/source_link.py b/models/source_link.py new file mode 100644 index 0000000..8142ab1 --- /dev/null +++ b/models/source_link.py @@ -0,0 +1,50 @@ +# my_project/models/source_link.py + +from typing import TYPE_CHECKING, Optional +from sqlmodel import Field, Relationship, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .file import File + +class SourceLink(BaseModel, table=True): + __tablename__ = 'source_links' + + name: str = Field(max_length=255, description="链接名称") + downloads: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="通过此链接的下载次数") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + delete_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=True, + comment="删除时间", + ), + ) + + # 外键 + file_id: int = Field(foreign_key="files.id", index=True, description="关联的文件ID") + + # 关系 + file: "File" = Relationship(back_populates="source_links") \ No newline at end of file diff --git a/models/storage_pack.py b/models/storage_pack.py new file mode 100644 index 0000000..a843593 --- /dev/null +++ b/models/storage_pack.py @@ -0,0 +1,52 @@ +# my_project/models/storage_pack.py + +from typing import Optional, TYPE_CHECKING +from datetime import datetime +from sqlmodel import Field, Relationship, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + +class StoragePack(BaseModel, table=True): + __tablename__ = 'storage_packs' + + name: str = Field(max_length=255, description="容量包名称") + active_time: Optional[datetime] = Field(default=None, description="激活时间") + expired_time: Optional[datetime] = Field(default=None, index=True, description="过期时间") + size: int = Field(description="容量包大小(字节)") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + delete_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=True, + comment="删除时间", + ), + ) + # 外键 + user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID") + + # 关系 + user: "User" = Relationship(back_populates="storage_packs") \ No newline at end of file diff --git a/models/tag.py b/models/tag.py new file mode 100644 index 0000000..3d7f8fc --- /dev/null +++ b/models/tag.py @@ -0,0 +1,54 @@ +# my_project/models/tag.py + +from typing import Optional, TYPE_CHECKING +from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + +class Tag(BaseModel, table=True): + __tablename__ = 'tags' + __table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),) + + name: str = Field(max_length=255, description="标签名称") + icon: Optional[str] = Field(default=None, max_length=255, description="标签图标") + color: Optional[str] = Field(default=None, max_length=255, description="标签颜色") + type: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="标签类型: 0=手动, 1=自动") + expression: Optional[str] = Field(default=None, description="自动标签的匹配表达式") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + delete_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=True, + comment="删除时间", + ), + ) + + # 外键 + user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID") + + # 关系 + user: "User" = Relationship(back_populates="tags") \ No newline at end of file diff --git a/models/task.py b/models/task.py new file mode 100644 index 0000000..eea0df3 --- /dev/null +++ b/models/task.py @@ -0,0 +1,55 @@ +# my_project/models/task.py + +from typing import Optional, TYPE_CHECKING +from sqlmodel import Field, Relationship, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + from .download import Download + +class Task(BaseModel, table=True): + __tablename__ = 'tasks' + + status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误") + type: int = Field(description="任务类型") + progress: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务进度 (0-100)") + error: Optional[str] = Field(default=None, description="错误信息") + props: Optional[str] = Field(default=None, description="任务属性 (JSON格式)") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + delete_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=True, + comment="删除时间", + ), + ) + + # 外键 + user_id: "int" = Field(foreign_key="users.id", index=True, description="所属用户ID") + + # 关系 + user: "User" = Relationship(back_populates="tasks") + downloads: list["Download"] = Relationship(back_populates="task") \ No newline at end of file diff --git a/models/user.py b/models/user.py new file mode 100644 index 0000000..1f8aa70 --- /dev/null +++ b/models/user.py @@ -0,0 +1,83 @@ +# my_project/models/user.py + +from typing import Optional, TYPE_CHECKING +from datetime import datetime +from sqlmodel import Field, Relationship, Column, func, DateTime +from .base import BaseModel + +# TYPE_CHECKING 用于解决循环导入问题,只在类型检查时导入 +if TYPE_CHECKING: + from .group import Group + from .download import Download + from .file import File + from .folder import Folder + from .order import Order + from .share import Share + from .storage_pack import StoragePack + from .tag import Tag + from .task import Task + from .webdav import WebDAV + +class User(BaseModel, table=True): + __tablename__ = 'users' + + email: str = Field(max_length=100, unique=True, index=True, description="用户邮箱,唯一") + nick: Optional[str] = Field(default=None, max_length=50, description="用户昵称") + password: str = Field(max_length=255, description="用户密码(加密后)") + status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="用户状态: 0=正常, 1=未激活, 2=封禁") + storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已用存储空间(字节)") + two_factor: Optional[str] = Field(default=None, max_length=255, description="两步验证密钥") + avatar: Optional[str] = Field(default=None, max_length=255, description="头像地址") + options: Optional[str] = Field(default=None, description="用户个人设置 (JSON格式)") + authn: Optional[str] = Field(default=None, description="WebAuthn 凭证") + open_id: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="第三方登录OpenID") + score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="用户积分") + group_expires: Optional[datetime] = Field(default=None, description="当前用户组过期时间") + phone: Optional[str] = Field(default=None, max_length=255, unique=True, index=True, description="手机号") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + delete_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=True, + comment="删除时间", + ), + ) + + # 外键 + group_id: int = Field(foreign_key="groups.id", index=True, description="所属用户组ID") + previous_group_id: Optional[int] = Field(default=None, foreign_key="groups.id", description="之前的用户组ID(用于过期后恢复)") + + # 关系 + group: "Group" = Relationship(back_populates="users") + previous_group: Optional["Group"] = Relationship(back_populates="previous_users") + + downloads: list["Download"] = Relationship(back_populates="user") + files: list["File"] = Relationship(back_populates="user") + folders: list["Folder"] = Relationship(back_populates="owner") + orders: list["Order"] = Relationship(back_populates="user") + shares: list["Share"] = Relationship(back_populates="user") + storage_packs: list["StoragePack"] = Relationship(back_populates="user") + tags: list["Tag"] = Relationship(back_populates="user") + tasks: list["Task"] = Relationship(back_populates="user") + webdavs: list["WebDAV"] = Relationship(back_populates="user") \ No newline at end of file diff --git a/models/webdav.py b/models/webdav.py new file mode 100644 index 0000000..cb3af41 --- /dev/null +++ b/models/webdav.py @@ -0,0 +1,54 @@ +# my_project/models/webdav.py + +from typing import Optional, TYPE_CHECKING +from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime +from .base import BaseModel +from datetime import datetime + +if TYPE_CHECKING: + from .user import User + +class WebDAV(BaseModel, table=True): + __tablename__ = 'webdavs' + __table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),) + + name: str = Field(max_length=255, description="WebDAV账户名") + password: str = Field(max_length=255, description="WebDAV密码(加密后)") + root: str = Field(default="/", sa_column_kwargs={"server_default": "'/'"}, description="根目录路径") + readonly: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否只读") + use_proxy: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否使用代理下载") + created_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + comment="创建时间", + ), + ) + + updated_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=False, + server_default=func.now(), + onupdate=func.now(), + comment="更新时间", + ), + ) + + delete_at: Optional[datetime] = Field( + default=None, + sa_column=Column( + DateTime, + nullable=True, + comment="删除时间", + ), + ) + + # 外键 + user_id: int = Field(foreign_key="users.id", index=True, description="所属用户ID") + + # 关系 + user: "User" = Relationship(back_populates="webdavs") \ No newline at end of file diff --git a/pkg/lifespan/lifespan.py b/pkg/lifespan/lifespan.py new file mode 100644 index 0000000..57824d7 --- /dev/null +++ b/pkg/lifespan/lifespan.py @@ -0,0 +1,39 @@ +from fastapi import FastAPI +from contextlib import asynccontextmanager + +__on_startup: list[callable] = [] +__on_shutdown: list[callable] = [] + +def add_startup(func: callable): + """ + 注册一个函数,在应用启动时调用。 + + :param func: 需要注册的函数。它应该是一个异步函数。 + """ + __on_startup.append(func) + +def add_shutdown(func: callable): + """ + 注册一个函数,在应用关闭时调用。 + + :param func: 需要注册的函数。 + """ + __on_shutdown.append(func) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + 应用程序的生命周期管理器。 + + 此函数在应用启动时执行所有注册的启动函数, + 并在应用关闭时执行所有注册的关闭函数。 + """ + # Execute all startup functions + for func in __on_startup: + await func() + + yield + + # Execute all shutdown functions + for func in __on_shutdown: + await func() \ No newline at end of file diff --git a/pkg/log/log.py b/pkg/log/log.py new file mode 100644 index 0000000..50718af --- /dev/null +++ b/pkg/log/log.py @@ -0,0 +1,211 @@ +from rich import print +from rich.console import Console +from rich.markdown import Markdown +from configparser import ConfigParser +from typing import Literal, Optional, Dict, Union +from enum import Enum +import time +import os +import inspect + +class LogLevelEnum(str, Enum): + DEBUG = 'debug' + INFO = 'info' + WARNING = 'warning' + ERROR = 'error' + SUCCESS = 'success' + +# 默认日志级别 +LogLevel = LogLevelEnum.INFO +# 日志文件路径 +LOG_FILE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs') +# 是否启用文件日志 +ENABLE_FILE_LOG = False + +def set_log_level(level: Union[str, LogLevelEnum]) -> None: + """设置日志级别""" + global LogLevel + if isinstance(level, str): + try: + LogLevel = LogLevelEnum(level.lower()) + except ValueError: + print(f"[bold red]无效的日志级别: {level},使用默认级别: {LogLevel}[/bold red]") + else: + LogLevel = level + +def enable_file_log(enable: bool = True) -> None: + """启用或禁用文件日志""" + global ENABLE_FILE_LOG + ENABLE_FILE_LOG = enable + if enable and not os.path.exists(LOG_FILE_PATH): + try: + os.makedirs(LOG_FILE_PATH) + except Exception as e: + print(f"[bold red]创建日志目录失败: {e}[/bold red]") + ENABLE_FILE_LOG = False + +def truncate_path(full_path: str, marker: str = "HeyAuth") -> str: + """截断路径,只保留从marker开始的部分""" + try: + marker_index = full_path.find(marker) + if marker_index != -1: + return '.' + full_path[marker_index + len(marker):] + return full_path + except Exception: + return full_path + +def get_caller_info(depth: int = 2) -> tuple: + """获取调用者信息""" + try: + frame = inspect.currentframe() + # 向上查找指定深度的调用帧 + for _ in range(depth): + if frame.f_back is None: + break + frame = frame.f_back + + filename = frame.f_code.co_filename + lineno = frame.f_lineno + return truncate_path(filename), lineno + except Exception: + return "", 0 + finally: + # 确保引用被释放 + del frame + +def log(level: str = 'debug', message: str = ''): + """ + 输出日志 + --- + 通过传入的`level`和`message`参数,输出不同级别的日志信息。
+ `level`参数为日志级别,支持`红色error`、`紫色info`、`绿色success`、`黄色warning`、`淡蓝色debug`。
+ `message`参数为日志信息。
+ """ + level_colors: Dict[str, str] = { + 'debug': '[bold cyan][DEBUG][/bold cyan]', + 'info': '[bold blue][INFO][/bold blue]', + 'warning': '[bold yellow][WARN][/bold yellow]', + 'error': '[bold red][ERROR][/bold red]', + 'success': '[bold green][SUCCESS][/bold green]' + } + + level_value = level.lower() + lv = level_colors.get(level_value, '[bold magenta][UNKNOWN][/bold magenta]') + + # 获取调用者信息 + filename, lineno = get_caller_info(3) # 考虑lambda调用和包装函数,深度为3 + timestamp = time.strftime('%Y/%m/%d %H:%M:%S %p', time.localtime()) + log_message = f"{lv}\t{timestamp} [bold]From {filename}, line {lineno}[/bold] {message}" + + # 根据日志级别判断是否输出 + global LogLevel + should_log = False + + if level_value == 'debug' and LogLevel == LogLevelEnum.DEBUG: + should_log = True + elif level_value == 'info' and LogLevel in [LogLevelEnum.DEBUG, LogLevelEnum.INFO]: + should_log = True + elif level_value == 'warning' and LogLevel in [LogLevelEnum.DEBUG, LogLevelEnum.INFO, LogLevelEnum.WARNING]: + should_log = True + elif level_value == 'error': + should_log = True + elif level_value == 'success': + should_log = False + + if should_log: + print(log_message) + + # 文件日志记录 + if ENABLE_FILE_LOG: + try: + # 去除rich格式化标记 + clean_message = f"{level_value.upper()}\t{timestamp} From {filename}, line {lineno} {message}" + log_file = os.path.join(LOG_FILE_PATH, f"{time.strftime('%Y%m%d')}.log") + with open(log_file, 'a', encoding='utf-8') as f: + f.write(f"{clean_message}\n") + except Exception as e: + print(f"[bold red]写入日志文件失败: {e}[/bold red]") + +# 便捷日志函数 +debug = lambda message: log('debug', message) +info = lambda message: log('info', message) +warning = lambda message: log('warn', message) +error = lambda message: log('error', message) +success = lambda message: log('success', message) + +def load_config(config_path: str) -> bool: + """从配置文件加载日志配置""" + try: + if not os.path.exists(config_path): + return False + + config = ConfigParser() + config.read(config_path, encoding='utf-8') + + if 'log' in config: + log_config = config['log'] + if 'level' in log_config: + set_log_level(log_config['level']) + if 'file_log' in log_config: + enable_file_log(log_config.getboolean('file_log')) + if 'log_path' in log_config: + global LOG_FILE_PATH + custom_path = log_config['log_path'] + if os.path.exists(custom_path) or os.makedirs(custom_path, exist_ok=True): + LOG_FILE_PATH = custom_path + return True + except Exception as e: + error(f"加载日志配置失败: {e}") + return False + +def title(title: str = '海枫授权系统 HeyAuth', size: Optional[Literal['h1', 'h2', 'h3', 'h4', 'h5']] = 'h1'): + """ + 输出标题 + --- + 通过传入的`title`参数,输出一个整行的标题。
+ `title`参数为标题内容。
+ """ + try: + console = Console() + markdown_sizes = { + 'h1': '# ', + 'h2': '## ', + 'h3': '### ', + 'h4': '#### ', + 'h5': '##### ' + } + + markdown_tag = markdown_sizes.get(size, '# ') + console.print(Markdown(markdown_tag + title)) + except Exception as e: + error(f"输出标题失败: {e}") + finally: + if 'console' in locals(): + del console + +if True: + pass + + +if __name__ == '__main__': + # 测试代码 + title('海枫授权系统 日志组件测试', 'h1') + title('测试h2标题', 'h2') + title('测试h3标题', 'h3') + title('测试h4标题', 'h4') + title('测试h5标题', 'h5') + + print("\n默认日志级别(INFO)测试:") + debug('这是一个debug日志') # 不会显示 + info('这是一个info日志') + warning('这是一个warning日志') + error('这是一个error日志') + success('这是一个success日志') + + print("\n设置为DEBUG级别测试:") + set_log_level(LogLevelEnum.DEBUG) + debug('这是一个debug日志') # 现在会显示 + + print("\n启用文件日志测试:") + enable_file_log() + info('此日志将同时记录到文件') \ No newline at end of file diff --git a/test_main.py b/test_main.py deleted file mode 100644 index 074f5fc..0000000 --- a/test_main.py +++ /dev/null @@ -1,15 +0,0 @@ -from fastapi.testclient import TestClient - -from main import app - -client = TestClient(app) - - -def test_read_main(): - from pkg.conf.appmeta import BackendVersion - response = client.get("/api/site/ping") - assert response.status_code == 200 - assert response.json() == { - "code": 0, - 'data': BackendVersion, - 'msg': None} \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d3a5d7d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,9 @@ +""" +Pytest配置文件 +""" +import pytest +import os +import sys + +# 添加项目根目录到Python路径,确保可以导入项目模块 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..cb3fe33 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,8 @@ +from models import database + +import pytest + +@pytest.mark.asyncio +async def test_initialize_db(): + """Fixture to initialize the database before tests.""" + await database.init_db() \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..e1e9125 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,23 @@ +from fastapi.testclient import TestClient + +from main import app + +client = TestClient(app) + + +def test_read_main(): + from pkg.conf.appmeta import BackendVersion + import uuid + + response = client.get("/api/site/ping") + json_response = response.json() + + assert response.status_code == 200 + assert json_response['code'] == 0 + assert json_response['data'] == BackendVersion + assert json_response['msg'] is None + assert 'instance_id' in json_response + try: + uuid.UUID(json_response['instance_id'], version=4) + except (ValueError, TypeError): + assert False, f"instance_id is not a valid UUID4: {json_response['instance_id']}" \ No newline at end of file