修复数据库迁移问题、新增环境变量读写

This commit is contained in:
2025-07-15 17:32:00 +08:00
parent dc522a8e93
commit 33cca4e271
10 changed files with 432 additions and 39 deletions

1
.gitignore vendored
View File

@@ -12,3 +12,4 @@ __pycache__/
*.code-workspace *.code-workspace
*.db *.db
.env

12
main.py
View File

@@ -2,13 +2,13 @@ from fastapi import FastAPI
from routers import routers from routers import routers
from pkg.conf import appmeta from pkg.conf import appmeta
from models.database import init_db from models.database import init_db
from models.migration import init_default_settings from models.migration import migration
from pkg.lifespan import lifespan from pkg.lifespan import lifespan
from pkg.JWT import jwt from pkg.JWT import jwt
# 添加初始化数据库启动项 # 添加初始化数据库启动项
lifespan.add_startup(init_db) lifespan.add_startup(init_db)
lifespan.add_startup(init_default_settings) lifespan.add_startup(migration)
lifespan.add_startup(jwt.load_secret_key) lifespan.add_startup(jwt.load_secret_key)
# 创建应用实例并设置元数据 # 创建应用实例并设置元数据
@@ -20,6 +20,7 @@ app = FastAPI(
openapi_tags=appmeta.tags_meta, openapi_tags=appmeta.tags_meta,
license_info=appmeta.license_info, license_info=appmeta.license_info,
lifespan=lifespan.lifespan, lifespan=lifespan.lifespan,
debug=appmeta.debug,
) )
# 挂载路由 # 挂载路由
@@ -38,5 +39,8 @@ for router in routers.Router:
# 启动时打印欢迎信息 # 启动时打印欢迎信息
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
# uvicorn.run(app=app, host="0.0.0.0", port=5213) # 生产环境
uvicorn.run(app='main:app', host="0.0.0.0", port=5213, reload=True) # 开发环境 if appmeta.debug:
uvicorn.run(app='main:app', host=appmeta.host, port=appmeta.port, reload=True)
else:
uvicorn.run(app=app, host=appmeta.host, port=appmeta.port)

View File

@@ -1,16 +1,17 @@
# my_project/database.py # my_project/database.py
from sqlmodel import SQLModel from sqlmodel import SQLModel
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from pkg.conf import appmeta
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from typing import AsyncGenerator from typing import AsyncGenerator
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///:memory:" ASYNC_DATABASE_URL = appmeta.database_url
engine = create_async_engine( engine: AsyncEngine = create_async_engine(
ASYNC_DATABASE_URL, ASYNC_DATABASE_URL,
echo=True, echo=appmeta.debug,
connect_args={"check_same_thread": False} connect_args={"check_same_thread": False}
if ASYNC_DATABASE_URL.startswith("sqlite") if ASYNC_DATABASE_URL.startswith("sqlite")
else None, else None,

View File

@@ -1,5 +1,6 @@
# my_project/models/group.py # my_project/models/group.py
from tokenize import group
from typing import Optional, List, TYPE_CHECKING from typing import Optional, List, TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Column, func, DateTime from sqlmodel import Field, Relationship, text, Column, func, DateTime
from .base import BaseModel from .base import BaseModel
@@ -51,10 +52,17 @@ class Group(BaseModel, table=True):
@staticmethod @staticmethod
async def create( async def create(
group: "Group" group: Optional["Group"] = None,
name: Optional[str] = None,
policies: Optional[str] = None,
max_storage: int = 0,
share_enabled: bool = False,
web_dav_enabled: bool = False,
speed_limit: int = 0,
options: Optional[dict] = None,
) -> "Group": ) -> "Group":
""" """
向数据库内添加用户组。 向数据库内添加用户组。如果提供了 `group` 参数,则使用该对象,否则创建一个新的用户组对象。
:param group: 用户组对象 :param group: 用户组对象
:type group: Group :type group: Group
@@ -63,12 +71,27 @@ class Group(BaseModel, table=True):
""" """
from .database import get_session from .database import get_session
import json
if not group:
if not name:
raise ValueError("Group name is required.")
group = Group(
name=name,
policies=policies,
max_storage=max_storage,
share_enabled=share_enabled,
web_dav_enabled=web_dav_enabled,
speed_limit=speed_limit,
options=json.dumps(options) if options else None,
)
async for session in get_session(): async for session in get_session():
try: try:
session.add(group) session.add(group)
await session.commit() await session.commit()
await session.refresh(group)
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise e raise e
@@ -96,11 +119,111 @@ class Group(BaseModel, table=True):
return None return None
async for session in get_session(): async for session in get_session():
statement = select(Group).where(Group.id == id) try:
result = await session.exec(statement) statement = select(Group).where(Group.id == id)
group = result.one_or_none() result = await session.exec(statement)
group = result.one_or_none()
if group:
return group
else:
return None
except Exception as e:
raise e
@staticmethod
async def set(
id: int,
name: Optional[str] = None,
policies: Optional[str] = None,
max_storage: Optional[int] = None,
share_enabled: Optional[bool] = None,
web_dav_enabled: Optional[bool] = None,
speed_limit: Optional[int] = None,
options: Optional[str] = None
) -> Optional["Group"]:
"""
更新用户组信息。
:param id: 用户组ID
:type id: int
:param name: 用户组名
:type name: Optional[str]
:param policies: 允许的策略ID列表逗号分隔
:type policies: Optional[str]
:param max_storage: 最大存储空间(字节)
:type max_storage: Optional[int]
:param share_enabled: 是否允许创建分享
:type share_enabled: Optional[bool]
:param web_dav_enabled: 是否允许使用WebDAV
:type web_dav_enabled: Optional[bool]
:param speed_limit: 速度限制 (KB/s), 0为不限制
:type speed_limit: Optional[int]
:param options: 其他选项 (JSON格式)
:type options: Optional[str]
:return: 更新后的用户组对象或 None
:rtype: Optional[Group]
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if not group:
raise ValueError(f"Group with id {id} not found.")
if name is not None:
group.name = name
if policies is not None:
group.policies = policies
if max_storage is not None:
group.max_storage = max_storage
if share_enabled is not None:
group.share_enabled = share_enabled
if web_dav_enabled is not None:
group.web_dav_enabled = web_dav_enabled
if speed_limit is not None:
group.speed_limit = speed_limit
if options is not None:
group.options = options
session.add(group)
await session.commit()
if group:
return group return group
else: except Exception as e:
return None await session.rollback()
raise e
@staticmethod
async def delete(
id: int
) -> None:
"""
删除用户组。
:param id: 用户组ID
:type id: int
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if group is None:
raise ValueError(f"Group with id {id} not found.")
await session.delete(group)
await session.commit()
except Exception as e:
await session.rollback()
raise e

View File

@@ -1,6 +1,22 @@
from .setting import Setting from .setting import Setting
from pkg.conf.appmeta import BackendVersion from pkg.conf.appmeta import BackendVersion
from pkg.password.pwd import Password from pkg.password.pwd import Password
from pkg.log import log
async def migration() -> None:
"""
数据库迁移函数,初始化默认设置和用户组。
:return: None
"""
log.info('开始进行数据库初始化...')
await init_default_settings()
await init_default_group()
await init_default_user()
log.info('数据库初始化结束')
default_settings: list[Setting] = [ default_settings: list[Setting] = [
Setting(name="siteURL", value="http://localhost", type="basic"), Setting(name="siteURL", value="http://localhost", type="basic"),
@@ -101,6 +117,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
async def init_default_settings() -> None: async def init_default_settings() -> None:
from .setting import Setting from .setting import Setting
log.info('初始化设置...')
try: try:
# 检查是否已经存在版本设置 # 检查是否已经存在版本设置
ver = await Setting.get(type="version", name=f"db_version_{BackendVersion}") ver = await Setting.get(type="version", name=f"db_version_{BackendVersion}")
@@ -118,10 +136,12 @@ async def init_default_settings() -> None:
async def init_default_group() -> None: async def init_default_group() -> None:
from .group import Group from .group import Group
log.info('初始化用户组...')
try: try:
# 未找到初始管理组时,则创建 # 未找到初始管理组时,则创建
if not Group.get(id=1): if not await Group.get(id=1):
Group.add( await Group.create(
name="管理员", name="管理员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True, share_enabled=True,
@@ -134,12 +154,12 @@ async def init_default_group() -> None:
} }
) )
except Exception as e: except Exception as e:
raise RuntimeError(f"无法创建管理员用户组: {e}") from e raise RuntimeError(f"无法创建管理员用户组: {e}")
try: try:
# 未找到初始注册会员时,则创建 # 未找到初始注册会员时,则创建
if not Group.get(id=2): if not await Group.get(id=2):
Group.add( await Group.create(
name="注册会员", name="注册会员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True, share_enabled=True,
@@ -149,12 +169,12 @@ async def init_default_group() -> None:
} }
) )
except Exception as e: except Exception as e:
raise RuntimeError(f"无法创建初始注册会员用户组: {e}") from e raise RuntimeError(f"无法创建初始注册会员用户组: {e}")
try: try:
# 未找到初始游客组时,则创建 # 未找到初始游客组时,则创建
if not Group.get(id=3): if not await Group.get(id=3):
Group.add( await Group.create(
name="游客", name="游客",
policies="[]", policies="[]",
share_enabled=False, share_enabled=False,
@@ -164,4 +184,40 @@ async def init_default_group() -> None:
} }
) )
except Exception as e: except Exception as e:
raise RuntimeError(f"无法创建初始游客用户组: {e}") from e raise RuntimeError(f"无法创建初始游客用户组: {e}")
async def init_default_user() -> None:
log.info('初始化管理员用户...')
from .user import User
from .group import Group
# 检查管理员用户是否存在
admin_user = await User.get(id=1)
if not admin_user:
# 创建初始管理员用户
# 获取管理员组
admin_group = await Group.get(id=1)
if not admin_group:
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
# 生成管理员密码
from pkg.password.pwd import Password
admin_password = Password.generate(8)
hashed_admin_password = Password.hash(admin_password)
admin_user = User(
email="admin@yxqi.cn",
nick="admin",
status=1, # 正常状态
group_id=admin_group.id,
password=hashed_admin_password,
)
admin_user = await User.create(admin_user)
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')

View File

@@ -88,8 +88,20 @@ class User(BaseModel, table=True):
tasks: list["Task"] = Relationship(back_populates="user") tasks: list["Task"] = Relationship(back_populates="user")
webdavs: list["WebDAV"] = Relationship(back_populates="user") webdavs: list["WebDAV"] = Relationship(back_populates="user")
@staticmethod
async def create( async def create(
user: "User" user: Optional["User"] = None,
email: str = None,
nick: Optional[str] = None,
password: str = None,
status: int = 0,
two_factor: Optional[str] = None,
avatar: Optional[str] = None,
options: Optional[str] = None,
authn: Optional[str] = None,
open_id: Optional[str] = None,
score: int = 0,
phone: Optional[str] = None
): ):
""" """
向数据库内添加用户。 向数据库内添加用户。
@@ -97,18 +109,33 @@ class User(BaseModel, table=True):
:param user: User 实例 :param user: User 实例
:type user: User :type user: User
""" """
if not user:
user = User(
email=email,
nick=nick,
password=password,
status=status,
two_factor=two_factor,
avatar=avatar,
options=options,
authn=authn,
open_id=open_id,
score=score,
phone=phone
)
from .database import get_session from .database import get_session
async for session in get_session(): async for session in get_session():
try: try:
session.add(user) session.add(user)
await session.commit() await session.commit()
await session.refresh(user)
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise e raise e
return user return user
@staticmethod
async def get( async def get(
id: int = None, id: int = None,
email: str = None email: str = None
@@ -143,3 +170,99 @@ class User(BaseModel, table=True):
user = result.one_or_none() user = result.one_or_none()
return user return user
@staticmethod
async def update(
id: int,
email: Optional[str] = None,
nick: Optional[str] = None,
password: Optional[str] = None,
status: Optional[int] = None,
storage: Optional[int] = None,
two_factor: Optional[str] = None,
avatar: Optional[str] = None,
options: Optional[str] = None,
authn: Optional[str] = None,
open_id: Optional[str] = None,
score: Optional[int] = None,
group_id: Optional[int] = None
) -> "User":
"""
更新用户信息。
:return: 更新后的用户对象
:rtype: User
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(User).where(User.id == id)
result = await session.exec(statement)
user = result.first()
if user is None:
raise ValueError(f"User with id {id} not found.")
if email is not None:
user.email = email
if nick is not None:
user.nick = nick
if password is not None:
user.password = password
if status is not None:
user.status = status
if storage is not None:
user.storage = storage
if two_factor is not None:
user.two_factor = two_factor
if avatar is not None:
user.avatar = avatar
if options is not None:
user.options = options
if authn is not None:
user.authn = authn
if open_id is not None:
user.open_id = open_id
if score is not None:
user.score = score
if group_id is not None:
user.group_id = group_id
await session.commit()
await session.refresh(user)
return user
except Exception as e:
await session.rollback()
raise e
@staticmethod
async def delete(
id: int
) -> None:
"""
删除用户。
:param id: 用户ID
:type id: int
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(User).where(User.id == id)
result = await session.exec(statement)
user = result.one_or_none()
if user is None:
raise ValueError(f"User with id {id} not found.")
await session.delete(user)
await session.commit()
except Exception as e:
await session.rollback()
raise e

View File

@@ -1,3 +1,9 @@
import os
from dotenv import load_dotenv
from pkg.log import log
load_dotenv()
APP_NAME = 'DiskNext Server' APP_NAME = 'DiskNext Server'
summary = '一款基于 FastAPI 的可公私兼备的网盘系统' summary = '一款基于 FastAPI 的可公私兼备的网盘系统'
description = 'DiskNext Server 是一款基于 FastAPI 的网盘系统,支持个人和企业使用。它提供了高性能的文件存储和管理功能,支持多种认证方式。' description = 'DiskNext Server 是一款基于 FastAPI 的网盘系统,支持个人和企业使用。它提供了高性能的文件存储和管理功能,支持多种认证方式。'
@@ -7,6 +13,18 @@ BackendVersion = "0.0.1"
IsPro = False IsPro = False
debug: bool = os.getenv("DEBUG", "false").lower() in ("true", "1", "yes")
if debug:
log.info("Debug mode is enabled. This is not recommended for production use.")
host: str = os.getenv("HOST", "0.0.0.0")
port: int = int(os.getenv("PORT", 5213))
log.info(f"Starting DiskNext Server {BackendVersion} on {host}:{port}")
database_url: str = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///disknext.db")
tags_meta = [ tags_meta = [
{ {
"name": "site", "name": "site",

37
tests/test_db_group.py Normal file
View File

@@ -0,0 +1,37 @@
import pytest
@pytest.mark.asyncio
async def test_group_curd():
"""测试数据库的增删改查"""
from models import database
from models.group import Group
await database.init_db(url='sqlite:///:memory:')
# 测试增 Create
test_group = Group(name='test_group')
created_group = await Group.create(test_group)
assert created_group is not None
assert created_group.id is not None
assert created_group.name == 'test_group'
# 测试查 Read
fetched_group = await Group.get(id=created_group.id)
assert fetched_group is not None
assert fetched_group.id == created_group.id
assert fetched_group.name == 'test_group'
# 测试更新 Update
updated_group = await Group.set(
id=fetched_group.id,
name='updated_group')
assert updated_group is not None
assert updated_group.id == fetched_group.id
assert updated_group.name == 'updated_group'
# 测试删除 Delete
await Group.delete(id=updated_group.id)
deleted_group = await Group.get(id=updated_group.id)
assert deleted_group is None

View File

@@ -10,8 +10,8 @@ async def test_user_curd():
await database.init_db(url='sqlite:///:memory:') await database.init_db(url='sqlite:///:memory:')
# 新建一个测试用户组 # 新建一个测试用户组
test_group = Group(name='test_group') test_user_group = Group(name='test_user_group')
created_group = await Group.create(test_group) created_group = await Group.create(test_user_group)
test_user = User( test_user = User(
email='test_user', email='test_user',
@@ -37,3 +37,18 @@ async def test_user_curd():
assert fetched_user.group_id == created_group.id assert fetched_user.group_id == created_group.id
# 测试改 Update # 测试改 Update
updated_user = await User.update(
id=fetched_user.id,
email='updated_user',
password='updated_password'
)
assert updated_user is not None
assert updated_user.email == 'updated_user'
assert updated_user.password == 'updated_password'
# 测试删除 Delete
await User.delete(id=updated_user.id)
deleted_user = await User.get(id=updated_user.id)
assert deleted_user is None

View File

@@ -4,10 +4,18 @@ from main import app
client = TestClient(app) client = TestClient(app)
def is_valid_instance_id(instance_id):
"""Check if a string is a valid UUID4."""
import uuid
try:
uuid.UUID(instance_id, version=4)
except (ValueError, TypeError):
assert False, f"instance_id is not a valid UUID4: {instance_id}"
def test_read_main(): def test_read_main():
from pkg.conf.appmeta import BackendVersion from pkg.conf.appmeta import BackendVersion
import uuid
response = client.get("/api/site/ping") response = client.get("/api/site/ping")
json_response = response.json() json_response = response.json()
@@ -17,7 +25,14 @@ def test_read_main():
assert json_response['data'] == BackendVersion assert json_response['data'] == BackendVersion
assert json_response['msg'] is None assert json_response['msg'] is None
assert 'instance_id' in json_response assert 'instance_id' in json_response
try: is_valid_instance_id(json_response['instance_id'])
uuid.UUID(json_response['instance_id'], version=4)
except (ValueError, TypeError): response = client.get("/api/site/config")
assert False, f"instance_id is not a valid UUID4: {json_response['instance_id']}" json_response = response.json()
assert response.status_code == 200
assert json_response['code'] == 0
assert json_response['data'] is not None
assert json_response['msg'] is None
assert 'instance_id' in json_response
is_valid_instance_id(json_response['instance_id'])