From 33cca4e2716fc926601ed8cac428316b49a83204 Mon Sep 17 00:00:00 2001 From: Yuerchu Date: Tue, 15 Jul 2025 17:32:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E8=BF=81=E7=A7=BB=E9=97=AE=E9=A2=98=E3=80=81=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=E7=8E=AF=E5=A2=83=E5=8F=98=E9=87=8F=E8=AF=BB=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- main.py | 12 ++-- models/database.py | 9 +-- models/group.py | 143 ++++++++++++++++++++++++++++++++++++++--- models/migration.py | 74 ++++++++++++++++++--- models/user.py | 129 ++++++++++++++++++++++++++++++++++++- pkg/conf/appmeta.py | 18 ++++++ tests/test_db_group.py | 37 +++++++++++ tests/test_db_user.py | 21 +++++- tests/test_main.py | 25 +++++-- 10 files changed, 432 insertions(+), 39 deletions(-) create mode 100644 tests/test_db_group.py diff --git a/.gitignore b/.gitignore index c52ca08..3d4415e 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ __pycache__/ *.code-workspace -*.db \ No newline at end of file +*.db +.env \ No newline at end of file diff --git a/main.py b/main.py index 6c110ee..e0d1944 100644 --- a/main.py +++ b/main.py @@ -2,13 +2,13 @@ from fastapi import FastAPI from routers import routers from pkg.conf import appmeta from models.database import init_db -from models.migration import init_default_settings +from models.migration import migration from pkg.lifespan import lifespan from pkg.JWT import jwt # 添加初始化数据库启动项 lifespan.add_startup(init_db) -lifespan.add_startup(init_default_settings) +lifespan.add_startup(migration) lifespan.add_startup(jwt.load_secret_key) # 创建应用实例并设置元数据 @@ -20,6 +20,7 @@ app = FastAPI( openapi_tags=appmeta.tags_meta, license_info=appmeta.license_info, lifespan=lifespan.lifespan, + debug=appmeta.debug, ) # 挂载路由 @@ -38,5 +39,8 @@ for router in routers.Router: # 启动时打印欢迎信息 if __name__ == "__main__": import uvicorn - # uvicorn.run(app=app, host="0.0.0.0", port=5213) # 生产环境 - uvicorn.run(app='main:app', host="0.0.0.0", port=5213, reload=True) # 开发环境 \ No newline at end of file + + if appmeta.debug: + uvicorn.run(app='main:app', host=appmeta.host, port=appmeta.port, reload=True) + else: + uvicorn.run(app=app, host=appmeta.host, port=appmeta.port) \ No newline at end of file diff --git a/models/database.py b/models/database.py index e24f438..9c5bf5f 100644 --- a/models/database.py +++ b/models/database.py @@ -1,16 +1,17 @@ # my_project/database.py from sqlmodel import SQLModel -from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from sqlmodel.ext.asyncio.session import AsyncSession +from pkg.conf import appmeta from sqlalchemy.orm import sessionmaker from typing import AsyncGenerator -ASYNC_DATABASE_URL = "sqlite+aiosqlite:///:memory:" +ASYNC_DATABASE_URL = appmeta.database_url -engine = create_async_engine( +engine: AsyncEngine = create_async_engine( ASYNC_DATABASE_URL, - echo=True, + echo=appmeta.debug, connect_args={"check_same_thread": False} if ASYNC_DATABASE_URL.startswith("sqlite") else None, diff --git a/models/group.py b/models/group.py index b6ae66e..9ade27e 100644 --- a/models/group.py +++ b/models/group.py @@ -1,5 +1,6 @@ # my_project/models/group.py +from tokenize import group from typing import Optional, List, TYPE_CHECKING from sqlmodel import Field, Relationship, text, Column, func, DateTime from .base import BaseModel @@ -51,10 +52,17 @@ class Group(BaseModel, table=True): @staticmethod async def create( - group: "Group" + group: Optional["Group"] = None, + name: Optional[str] = None, + policies: Optional[str] = None, + max_storage: int = 0, + share_enabled: bool = False, + web_dav_enabled: bool = False, + speed_limit: int = 0, + options: Optional[dict] = None, ) -> "Group": """ - 向数据库内添加用户组。 + 向数据库内添加用户组。如果提供了 `group` 参数,则使用该对象,否则创建一个新的用户组对象。 :param group: 用户组对象 :type group: Group @@ -63,12 +71,27 @@ class Group(BaseModel, table=True): """ from .database import get_session + import json + + if not group: + + if not name: + raise ValueError("Group name is required.") + + group = Group( + name=name, + policies=policies, + max_storage=max_storage, + share_enabled=share_enabled, + web_dav_enabled=web_dav_enabled, + speed_limit=speed_limit, + options=json.dumps(options) if options else None, + ) async for session in get_session(): try: session.add(group) await session.commit() - await session.refresh(group) except Exception as e: await session.rollback() raise e @@ -96,11 +119,111 @@ class Group(BaseModel, table=True): return None async for session in get_session(): - statement = select(Group).where(Group.id == id) - result = await session.exec(statement) - group = result.one_or_none() - - if group: + try: + statement = select(Group).where(Group.id == id) + result = await session.exec(statement) + group = result.one_or_none() + + if group: + return group + else: + return None + except Exception as e: + raise e + + @staticmethod + async def set( + id: int, + name: Optional[str] = None, + policies: Optional[str] = None, + max_storage: Optional[int] = None, + share_enabled: Optional[bool] = None, + web_dav_enabled: Optional[bool] = None, + speed_limit: Optional[int] = None, + options: Optional[str] = None + ) -> Optional["Group"]: + """ + 更新用户组信息。 + + :param id: 用户组ID + :type id: int + :param name: 用户组名 + :type name: Optional[str] + :param policies: 允许的策略ID列表,逗号分隔 + :type policies: Optional[str] + :param max_storage: 最大存储空间(字节) + :type max_storage: Optional[int] + :param share_enabled: 是否允许创建分享 + :type share_enabled: Optional[bool] + :param web_dav_enabled: 是否允许使用WebDAV + :type web_dav_enabled: Optional[bool] + :param speed_limit: 速度限制 (KB/s), 0为不限制 + :type speed_limit: Optional[int] + :param options: 其他选项 (JSON格式) + :type options: Optional[str] + :return: 更新后的用户组对象或 None + :rtype: Optional[Group] + """ + + from .database import get_session + from sqlmodel import select + + async for session in get_session(): + try: + statement = select(Group).where(Group.id == id) + result = await session.exec(statement) + group = result.one_or_none() + + if not group: + raise ValueError(f"Group with id {id} not found.") + + if name is not None: + group.name = name + if policies is not None: + group.policies = policies + if max_storage is not None: + group.max_storage = max_storage + if share_enabled is not None: + group.share_enabled = share_enabled + if web_dav_enabled is not None: + group.web_dav_enabled = web_dav_enabled + if speed_limit is not None: + group.speed_limit = speed_limit + if options is not None: + group.options = options + + session.add(group) + await session.commit() + return group - else: - return None \ No newline at end of file + except Exception as e: + await session.rollback() + raise e + + @staticmethod + async def delete( + id: int + ) -> None: + """ + 删除用户组。 + + :param id: 用户组ID + :type id: int + """ + from .database import get_session + from sqlmodel import select + + async for session in get_session(): + try: + statement = select(Group).where(Group.id == id) + result = await session.exec(statement) + group = result.one_or_none() + + if group is None: + raise ValueError(f"Group with id {id} not found.") + + await session.delete(group) + await session.commit() + except Exception as e: + await session.rollback() + raise e \ No newline at end of file diff --git a/models/migration.py b/models/migration.py index 5d4a6d8..0099b74 100644 --- a/models/migration.py +++ b/models/migration.py @@ -1,6 +1,22 @@ from .setting import Setting from pkg.conf.appmeta import BackendVersion from pkg.password.pwd import Password +from pkg.log import log + +async def migration() -> None: + """ + 数据库迁移函数,初始化默认设置和用户组。 + + :return: None + """ + + log.info('开始进行数据库初始化...') + + await init_default_settings() + await init_default_group() + await init_default_user() + + log.info('数据库初始化结束') default_settings: list[Setting] = [ Setting(name="siteURL", value="http://localhost", type="basic"), @@ -101,6 +117,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti async def init_default_settings() -> None: from .setting import Setting + log.info('初始化设置...') + try: # 检查是否已经存在版本设置 ver = await Setting.get(type="version", name=f"db_version_{BackendVersion}") @@ -118,10 +136,12 @@ async def init_default_settings() -> None: async def init_default_group() -> None: from .group import Group + log.info('初始化用户组...') + try: # 未找到初始管理组时,则创建 - if not Group.get(id=1): - Group.add( + if not await Group.get(id=1): + await Group.create( name="管理员", max_storage=1 * 1024 * 1024 * 1024, # 1GB share_enabled=True, @@ -134,12 +154,12 @@ async def init_default_group() -> None: } ) except Exception as e: - raise RuntimeError(f"无法创建管理员用户组: {e}") from e + raise RuntimeError(f"无法创建管理员用户组: {e}") try: # 未找到初始注册会员时,则创建 - if not Group.get(id=2): - Group.add( + if not await Group.get(id=2): + await Group.create( name="注册会员", max_storage=1 * 1024 * 1024 * 1024, # 1GB share_enabled=True, @@ -149,12 +169,12 @@ async def init_default_group() -> None: } ) except Exception as e: - raise RuntimeError(f"无法创建初始注册会员用户组: {e}") from e + raise RuntimeError(f"无法创建初始注册会员用户组: {e}") try: # 未找到初始游客组时,则创建 - if not Group.get(id=3): - Group.add( + if not await Group.get(id=3): + await Group.create( name="游客", policies="[]", share_enabled=False, @@ -164,4 +184,40 @@ async def init_default_group() -> None: } ) except Exception as e: - raise RuntimeError(f"无法创建初始游客用户组: {e}") from e + raise RuntimeError(f"无法创建初始游客用户组: {e}") + +async def init_default_user() -> None: + + log.info('初始化管理员用户...') + + from .user import User + from .group import Group + + # 检查管理员用户是否存在 + admin_user = await User.get(id=1) + + if not admin_user: + # 创建初始管理员用户 + + # 获取管理员组 + admin_group = await Group.get(id=1) + if not admin_group: + raise RuntimeError("管理员用户组不存在,无法创建管理员用户") + + # 生成管理员密码 + from pkg.password.pwd import Password + admin_password = Password.generate(8) + hashed_admin_password = Password.hash(admin_password) + + admin_user = User( + email="admin@yxqi.cn", + nick="admin", + status=1, # 正常状态 + group_id=admin_group.id, + password=hashed_admin_password, + ) + + admin_user = await User.create(admin_user) + + log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]') + log.info(f'初始管理员密码:[bold]{admin_password}[/bold]') \ No newline at end of file diff --git a/models/user.py b/models/user.py index 4a571f1..3a9ca0a 100644 --- a/models/user.py +++ b/models/user.py @@ -88,8 +88,20 @@ class User(BaseModel, table=True): tasks: list["Task"] = Relationship(back_populates="user") webdavs: list["WebDAV"] = Relationship(back_populates="user") + @staticmethod async def create( - user: "User" + user: Optional["User"] = None, + email: str = None, + nick: Optional[str] = None, + password: str = None, + status: int = 0, + two_factor: Optional[str] = None, + avatar: Optional[str] = None, + options: Optional[str] = None, + authn: Optional[str] = None, + open_id: Optional[str] = None, + score: int = 0, + phone: Optional[str] = None ): """ 向数据库内添加用户。 @@ -97,18 +109,33 @@ class User(BaseModel, table=True): :param user: User 实例 :type user: User """ + if not user: + user = User( + email=email, + nick=nick, + password=password, + status=status, + two_factor=two_factor, + avatar=avatar, + options=options, + authn=authn, + open_id=open_id, + score=score, + phone=phone + ) + from .database import get_session async for session in get_session(): try: session.add(user) await session.commit() - await session.refresh(user) except Exception as e: await session.rollback() raise e return user + @staticmethod async def get( id: int = None, email: str = None @@ -142,4 +169,100 @@ class User(BaseModel, table=True): result = await session.exec(query) user = result.one_or_none() - return user \ No newline at end of file + return user + + @staticmethod + async def update( + id: int, + email: Optional[str] = None, + nick: Optional[str] = None, + password: Optional[str] = None, + status: Optional[int] = None, + storage: Optional[int] = None, + two_factor: Optional[str] = None, + avatar: Optional[str] = None, + options: Optional[str] = None, + authn: Optional[str] = None, + open_id: Optional[str] = None, + score: Optional[int] = None, + group_id: Optional[int] = None + ) -> "User": + """ + 更新用户信息。 + + :return: 更新后的用户对象 + :rtype: User + """ + + from .database import get_session + from sqlmodel import select + + async for session in get_session(): + try: + statement = select(User).where(User.id == id) + result = await session.exec(statement) + user = result.first() + + if user is None: + raise ValueError(f"User with id {id} not found.") + + if email is not None: + user.email = email + if nick is not None: + user.nick = nick + if password is not None: + user.password = password + if status is not None: + user.status = status + if storage is not None: + user.storage = storage + if two_factor is not None: + user.two_factor = two_factor + if avatar is not None: + user.avatar = avatar + if options is not None: + user.options = options + if authn is not None: + user.authn = authn + if open_id is not None: + user.open_id = open_id + if score is not None: + user.score = score + if group_id is not None: + user.group_id = group_id + + await session.commit() + await session.refresh(user) + return user + except Exception as e: + await session.rollback() + raise e + + @staticmethod + async def delete( + id: int + ) -> None: + """ + 删除用户。 + + :param id: 用户ID + :type id: int + """ + + from .database import get_session + from sqlmodel import select + + async for session in get_session(): + try: + statement = select(User).where(User.id == id) + result = await session.exec(statement) + user = result.one_or_none() + + if user is None: + raise ValueError(f"User with id {id} not found.") + + await session.delete(user) + await session.commit() + except Exception as e: + await session.rollback() + raise e \ No newline at end of file diff --git a/pkg/conf/appmeta.py b/pkg/conf/appmeta.py index f0f10d6..fccb7c0 100644 --- a/pkg/conf/appmeta.py +++ b/pkg/conf/appmeta.py @@ -1,3 +1,9 @@ +import os +from dotenv import load_dotenv +from pkg.log import log + +load_dotenv() + APP_NAME = 'DiskNext Server' summary = '一款基于 FastAPI 的可公私兼备的网盘系统' description = 'DiskNext Server 是一款基于 FastAPI 的网盘系统,支持个人和企业使用。它提供了高性能的文件存储和管理功能,支持多种认证方式。' @@ -7,6 +13,18 @@ BackendVersion = "0.0.1" IsPro = False +debug: bool = os.getenv("DEBUG", "false").lower() in ("true", "1", "yes") + +if debug: + log.info("Debug mode is enabled. This is not recommended for production use.") + +host: str = os.getenv("HOST", "0.0.0.0") +port: int = int(os.getenv("PORT", 5213)) + +log.info(f"Starting DiskNext Server {BackendVersion} on {host}:{port}") + +database_url: str = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///disknext.db") + tags_meta = [ { "name": "site", diff --git a/tests/test_db_group.py b/tests/test_db_group.py new file mode 100644 index 0000000..2590d2b --- /dev/null +++ b/tests/test_db_group.py @@ -0,0 +1,37 @@ +import pytest + +@pytest.mark.asyncio +async def test_group_curd(): + """测试数据库的增删改查""" + from models import database + from models.group import Group + + await database.init_db(url='sqlite:///:memory:') + + # 测试增 Create + test_group = Group(name='test_group') + created_group = await Group.create(test_group) + + assert created_group is not None + assert created_group.id is not None + assert created_group.name == 'test_group' + + # 测试查 Read + fetched_group = await Group.get(id=created_group.id) + assert fetched_group is not None + assert fetched_group.id == created_group.id + assert fetched_group.name == 'test_group' + + # 测试更新 Update + updated_group = await Group.set( + id=fetched_group.id, + name='updated_group') + + assert updated_group is not None + assert updated_group.id == fetched_group.id + assert updated_group.name == 'updated_group' + + # 测试删除 Delete + await Group.delete(id=updated_group.id) + deleted_group = await Group.get(id=updated_group.id) + assert deleted_group is None \ No newline at end of file diff --git a/tests/test_db_user.py b/tests/test_db_user.py index 6455016..fec7f04 100644 --- a/tests/test_db_user.py +++ b/tests/test_db_user.py @@ -10,8 +10,8 @@ async def test_user_curd(): await database.init_db(url='sqlite:///:memory:') # 新建一个测试用户组 - test_group = Group(name='test_group') - created_group = await Group.create(test_group) + test_user_group = Group(name='test_user_group') + created_group = await Group.create(test_user_group) test_user = User( email='test_user', @@ -36,4 +36,19 @@ async def test_user_curd(): assert fetched_user.password == 'test_password' assert fetched_user.group_id == created_group.id - # 测试改 Update \ No newline at end of file + # 测试改 Update + updated_user = await User.update( + id=fetched_user.id, + email='updated_user', + password='updated_password' + ) + + assert updated_user is not None + assert updated_user.email == 'updated_user' + assert updated_user.password == 'updated_password' + + # 测试删除 Delete + await User.delete(id=updated_user.id) + deleted_user = await User.get(id=updated_user.id) + + assert deleted_user is None \ No newline at end of file diff --git a/tests/test_main.py b/tests/test_main.py index e1e9125..d5025d3 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,10 +4,18 @@ from main import app client = TestClient(app) +def is_valid_instance_id(instance_id): + """Check if a string is a valid UUID4.""" + + import uuid + + try: + uuid.UUID(instance_id, version=4) + except (ValueError, TypeError): + assert False, f"instance_id is not a valid UUID4: {instance_id}" def test_read_main(): from pkg.conf.appmeta import BackendVersion - import uuid response = client.get("/api/site/ping") json_response = response.json() @@ -17,7 +25,14 @@ def test_read_main(): assert json_response['data'] == BackendVersion assert json_response['msg'] is None assert 'instance_id' in json_response - try: - uuid.UUID(json_response['instance_id'], version=4) - except (ValueError, TypeError): - assert False, f"instance_id is not a valid UUID4: {json_response['instance_id']}" \ No newline at end of file + is_valid_instance_id(json_response['instance_id']) + + response = client.get("/api/site/config") + json_response = response.json() + + assert response.status_code == 200 + assert json_response['code'] == 0 + assert json_response['data'] is not None + assert json_response['msg'] is None + assert 'instance_id' in json_response + is_valid_instance_id(json_response['instance_id']) \ No newline at end of file