修复数据库迁移问题、新增环境变量读写
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,3 +12,4 @@ __pycache__/
|
||||
*.code-workspace
|
||||
|
||||
*.db
|
||||
.env
|
||||
12
main.py
12
main.py
@@ -2,13 +2,13 @@ from fastapi import FastAPI
|
||||
from routers import routers
|
||||
from pkg.conf import appmeta
|
||||
from models.database import init_db
|
||||
from models.migration import init_default_settings
|
||||
from models.migration import migration
|
||||
from pkg.lifespan import lifespan
|
||||
from pkg.JWT import jwt
|
||||
|
||||
# 添加初始化数据库启动项
|
||||
lifespan.add_startup(init_db)
|
||||
lifespan.add_startup(init_default_settings)
|
||||
lifespan.add_startup(migration)
|
||||
lifespan.add_startup(jwt.load_secret_key)
|
||||
|
||||
# 创建应用实例并设置元数据
|
||||
@@ -20,6 +20,7 @@ app = FastAPI(
|
||||
openapi_tags=appmeta.tags_meta,
|
||||
license_info=appmeta.license_info,
|
||||
lifespan=lifespan.lifespan,
|
||||
debug=appmeta.debug,
|
||||
)
|
||||
|
||||
# 挂载路由
|
||||
@@ -38,5 +39,8 @@ for router in routers.Router:
|
||||
# 启动时打印欢迎信息
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# uvicorn.run(app=app, host="0.0.0.0", port=5213) # 生产环境
|
||||
uvicorn.run(app='main:app', host="0.0.0.0", port=5213, reload=True) # 开发环境
|
||||
|
||||
if appmeta.debug:
|
||||
uvicorn.run(app='main:app', host=appmeta.host, port=appmeta.port, reload=True)
|
||||
else:
|
||||
uvicorn.run(app=app, host=appmeta.host, port=appmeta.port)
|
||||
@@ -1,16 +1,17 @@
|
||||
# my_project/database.py
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from pkg.conf import appmeta
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from typing import AsyncGenerator
|
||||
|
||||
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
|
||||
ASYNC_DATABASE_URL = appmeta.database_url
|
||||
|
||||
engine = create_async_engine(
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
echo=True,
|
||||
echo=appmeta.debug,
|
||||
connect_args={"check_same_thread": False}
|
||||
if ASYNC_DATABASE_URL.startswith("sqlite")
|
||||
else None,
|
||||
|
||||
129
models/group.py
129
models/group.py
@@ -1,5 +1,6 @@
|
||||
# my_project/models/group.py
|
||||
|
||||
from tokenize import group
|
||||
from typing import Optional, List, TYPE_CHECKING
|
||||
from sqlmodel import Field, Relationship, text, Column, func, DateTime
|
||||
from .base import BaseModel
|
||||
@@ -51,10 +52,17 @@ class Group(BaseModel, table=True):
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
group: "Group"
|
||||
group: Optional["Group"] = None,
|
||||
name: Optional[str] = None,
|
||||
policies: Optional[str] = None,
|
||||
max_storage: int = 0,
|
||||
share_enabled: bool = False,
|
||||
web_dav_enabled: bool = False,
|
||||
speed_limit: int = 0,
|
||||
options: Optional[dict] = None,
|
||||
) -> "Group":
|
||||
"""
|
||||
向数据库内添加用户组。
|
||||
向数据库内添加用户组。如果提供了 `group` 参数,则使用该对象,否则创建一个新的用户组对象。
|
||||
|
||||
:param group: 用户组对象
|
||||
:type group: Group
|
||||
@@ -63,12 +71,27 @@ class Group(BaseModel, table=True):
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
import json
|
||||
|
||||
if not group:
|
||||
|
||||
if not name:
|
||||
raise ValueError("Group name is required.")
|
||||
|
||||
group = Group(
|
||||
name=name,
|
||||
policies=policies,
|
||||
max_storage=max_storage,
|
||||
share_enabled=share_enabled,
|
||||
web_dav_enabled=web_dav_enabled,
|
||||
speed_limit=speed_limit,
|
||||
options=json.dumps(options) if options else None,
|
||||
)
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
session.add(group)
|
||||
await session.commit()
|
||||
await session.refresh(group)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
@@ -96,6 +119,7 @@ class Group(BaseModel, table=True):
|
||||
return None
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(Group).where(Group.id == id)
|
||||
result = await session.exec(statement)
|
||||
group = result.one_or_none()
|
||||
@@ -104,3 +128,102 @@ class Group(BaseModel, table=True):
|
||||
return group
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def set(
|
||||
id: int,
|
||||
name: Optional[str] = None,
|
||||
policies: Optional[str] = None,
|
||||
max_storage: Optional[int] = None,
|
||||
share_enabled: Optional[bool] = None,
|
||||
web_dav_enabled: Optional[bool] = None,
|
||||
speed_limit: Optional[int] = None,
|
||||
options: Optional[str] = None
|
||||
) -> Optional["Group"]:
|
||||
"""
|
||||
更新用户组信息。
|
||||
|
||||
:param id: 用户组ID
|
||||
:type id: int
|
||||
:param name: 用户组名
|
||||
:type name: Optional[str]
|
||||
:param policies: 允许的策略ID列表,逗号分隔
|
||||
:type policies: Optional[str]
|
||||
:param max_storage: 最大存储空间(字节)
|
||||
:type max_storage: Optional[int]
|
||||
:param share_enabled: 是否允许创建分享
|
||||
:type share_enabled: Optional[bool]
|
||||
:param web_dav_enabled: 是否允许使用WebDAV
|
||||
:type web_dav_enabled: Optional[bool]
|
||||
:param speed_limit: 速度限制 (KB/s), 0为不限制
|
||||
:type speed_limit: Optional[int]
|
||||
:param options: 其他选项 (JSON格式)
|
||||
:type options: Optional[str]
|
||||
:return: 更新后的用户组对象或 None
|
||||
:rtype: Optional[Group]
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(Group).where(Group.id == id)
|
||||
result = await session.exec(statement)
|
||||
group = result.one_or_none()
|
||||
|
||||
if not group:
|
||||
raise ValueError(f"Group with id {id} not found.")
|
||||
|
||||
if name is not None:
|
||||
group.name = name
|
||||
if policies is not None:
|
||||
group.policies = policies
|
||||
if max_storage is not None:
|
||||
group.max_storage = max_storage
|
||||
if share_enabled is not None:
|
||||
group.share_enabled = share_enabled
|
||||
if web_dav_enabled is not None:
|
||||
group.web_dav_enabled = web_dav_enabled
|
||||
if speed_limit is not None:
|
||||
group.speed_limit = speed_limit
|
||||
if options is not None:
|
||||
group.options = options
|
||||
|
||||
session.add(group)
|
||||
await session.commit()
|
||||
|
||||
return group
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def delete(
|
||||
id: int
|
||||
) -> None:
|
||||
"""
|
||||
删除用户组。
|
||||
|
||||
:param id: 用户组ID
|
||||
:type id: int
|
||||
"""
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(Group).where(Group.id == id)
|
||||
result = await session.exec(statement)
|
||||
group = result.one_or_none()
|
||||
|
||||
if group is None:
|
||||
raise ValueError(f"Group with id {id} not found.")
|
||||
|
||||
await session.delete(group)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
@@ -1,6 +1,22 @@
|
||||
from .setting import Setting
|
||||
from pkg.conf.appmeta import BackendVersion
|
||||
from pkg.password.pwd import Password
|
||||
from pkg.log import log
|
||||
|
||||
async def migration() -> None:
|
||||
"""
|
||||
数据库迁移函数,初始化默认设置和用户组。
|
||||
|
||||
:return: None
|
||||
"""
|
||||
|
||||
log.info('开始进行数据库初始化...')
|
||||
|
||||
await init_default_settings()
|
||||
await init_default_group()
|
||||
await init_default_user()
|
||||
|
||||
log.info('数据库初始化结束')
|
||||
|
||||
default_settings: list[Setting] = [
|
||||
Setting(name="siteURL", value="http://localhost", type="basic"),
|
||||
@@ -101,6 +117,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
async def init_default_settings() -> None:
|
||||
from .setting import Setting
|
||||
|
||||
log.info('初始化设置...')
|
||||
|
||||
try:
|
||||
# 检查是否已经存在版本设置
|
||||
ver = await Setting.get(type="version", name=f"db_version_{BackendVersion}")
|
||||
@@ -118,10 +136,12 @@ async def init_default_settings() -> None:
|
||||
async def init_default_group() -> None:
|
||||
from .group import Group
|
||||
|
||||
log.info('初始化用户组...')
|
||||
|
||||
try:
|
||||
# 未找到初始管理组时,则创建
|
||||
if not Group.get(id=1):
|
||||
Group.add(
|
||||
if not await Group.get(id=1):
|
||||
await Group.create(
|
||||
name="管理员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
@@ -134,12 +154,12 @@ async def init_default_group() -> None:
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建管理员用户组: {e}") from e
|
||||
raise RuntimeError(f"无法创建管理员用户组: {e}")
|
||||
|
||||
try:
|
||||
# 未找到初始注册会员时,则创建
|
||||
if not Group.get(id=2):
|
||||
Group.add(
|
||||
if not await Group.get(id=2):
|
||||
await Group.create(
|
||||
name="注册会员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
@@ -149,12 +169,12 @@ async def init_default_group() -> None:
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建初始注册会员用户组: {e}") from e
|
||||
raise RuntimeError(f"无法创建初始注册会员用户组: {e}")
|
||||
|
||||
try:
|
||||
# 未找到初始游客组时,则创建
|
||||
if not Group.get(id=3):
|
||||
Group.add(
|
||||
if not await Group.get(id=3):
|
||||
await Group.create(
|
||||
name="游客",
|
||||
policies="[]",
|
||||
share_enabled=False,
|
||||
@@ -164,4 +184,40 @@ async def init_default_group() -> None:
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建初始游客用户组: {e}") from e
|
||||
raise RuntimeError(f"无法创建初始游客用户组: {e}")
|
||||
|
||||
async def init_default_user() -> None:
|
||||
|
||||
log.info('初始化管理员用户...')
|
||||
|
||||
from .user import User
|
||||
from .group import Group
|
||||
|
||||
# 检查管理员用户是否存在
|
||||
admin_user = await User.get(id=1)
|
||||
|
||||
if not admin_user:
|
||||
# 创建初始管理员用户
|
||||
|
||||
# 获取管理员组
|
||||
admin_group = await Group.get(id=1)
|
||||
if not admin_group:
|
||||
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
|
||||
|
||||
# 生成管理员密码
|
||||
from pkg.password.pwd import Password
|
||||
admin_password = Password.generate(8)
|
||||
hashed_admin_password = Password.hash(admin_password)
|
||||
|
||||
admin_user = User(
|
||||
email="admin@yxqi.cn",
|
||||
nick="admin",
|
||||
status=1, # 正常状态
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
)
|
||||
|
||||
admin_user = await User.create(admin_user)
|
||||
|
||||
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
|
||||
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')
|
||||
127
models/user.py
127
models/user.py
@@ -88,8 +88,20 @@ class User(BaseModel, table=True):
|
||||
tasks: list["Task"] = Relationship(back_populates="user")
|
||||
webdavs: list["WebDAV"] = Relationship(back_populates="user")
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
user: "User"
|
||||
user: Optional["User"] = None,
|
||||
email: str = None,
|
||||
nick: Optional[str] = None,
|
||||
password: str = None,
|
||||
status: int = 0,
|
||||
two_factor: Optional[str] = None,
|
||||
avatar: Optional[str] = None,
|
||||
options: Optional[str] = None,
|
||||
authn: Optional[str] = None,
|
||||
open_id: Optional[str] = None,
|
||||
score: int = 0,
|
||||
phone: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
向数据库内添加用户。
|
||||
@@ -97,18 +109,33 @@ class User(BaseModel, table=True):
|
||||
:param user: User 实例
|
||||
:type user: User
|
||||
"""
|
||||
if not user:
|
||||
user = User(
|
||||
email=email,
|
||||
nick=nick,
|
||||
password=password,
|
||||
status=status,
|
||||
two_factor=two_factor,
|
||||
avatar=avatar,
|
||||
options=options,
|
||||
authn=authn,
|
||||
open_id=open_id,
|
||||
score=score,
|
||||
phone=phone
|
||||
)
|
||||
|
||||
from .database import get_session
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def get(
|
||||
id: int = None,
|
||||
email: str = None
|
||||
@@ -143,3 +170,99 @@ class User(BaseModel, table=True):
|
||||
user = result.one_or_none()
|
||||
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def update(
|
||||
id: int,
|
||||
email: Optional[str] = None,
|
||||
nick: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
status: Optional[int] = None,
|
||||
storage: Optional[int] = None,
|
||||
two_factor: Optional[str] = None,
|
||||
avatar: Optional[str] = None,
|
||||
options: Optional[str] = None,
|
||||
authn: Optional[str] = None,
|
||||
open_id: Optional[str] = None,
|
||||
score: Optional[int] = None,
|
||||
group_id: Optional[int] = None
|
||||
) -> "User":
|
||||
"""
|
||||
更新用户信息。
|
||||
|
||||
:return: 更新后的用户对象
|
||||
:rtype: User
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(User).where(User.id == id)
|
||||
result = await session.exec(statement)
|
||||
user = result.first()
|
||||
|
||||
if user is None:
|
||||
raise ValueError(f"User with id {id} not found.")
|
||||
|
||||
if email is not None:
|
||||
user.email = email
|
||||
if nick is not None:
|
||||
user.nick = nick
|
||||
if password is not None:
|
||||
user.password = password
|
||||
if status is not None:
|
||||
user.status = status
|
||||
if storage is not None:
|
||||
user.storage = storage
|
||||
if two_factor is not None:
|
||||
user.two_factor = two_factor
|
||||
if avatar is not None:
|
||||
user.avatar = avatar
|
||||
if options is not None:
|
||||
user.options = options
|
||||
if authn is not None:
|
||||
user.authn = authn
|
||||
if open_id is not None:
|
||||
user.open_id = open_id
|
||||
if score is not None:
|
||||
user.score = score
|
||||
if group_id is not None:
|
||||
user.group_id = group_id
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def delete(
|
||||
id: int
|
||||
) -> None:
|
||||
"""
|
||||
删除用户。
|
||||
|
||||
:param id: 用户ID
|
||||
:type id: int
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(User).where(User.id == id)
|
||||
result = await session.exec(statement)
|
||||
user = result.one_or_none()
|
||||
|
||||
if user is None:
|
||||
raise ValueError(f"User with id {id} not found.")
|
||||
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
@@ -1,3 +1,9 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from pkg.log import log
|
||||
|
||||
load_dotenv()
|
||||
|
||||
APP_NAME = 'DiskNext Server'
|
||||
summary = '一款基于 FastAPI 的可公私兼备的网盘系统'
|
||||
description = 'DiskNext Server 是一款基于 FastAPI 的网盘系统,支持个人和企业使用。它提供了高性能的文件存储和管理功能,支持多种认证方式。'
|
||||
@@ -7,6 +13,18 @@ BackendVersion = "0.0.1"
|
||||
|
||||
IsPro = False
|
||||
|
||||
debug: bool = os.getenv("DEBUG", "false").lower() in ("true", "1", "yes")
|
||||
|
||||
if debug:
|
||||
log.info("Debug mode is enabled. This is not recommended for production use.")
|
||||
|
||||
host: str = os.getenv("HOST", "0.0.0.0")
|
||||
port: int = int(os.getenv("PORT", 5213))
|
||||
|
||||
log.info(f"Starting DiskNext Server {BackendVersion} on {host}:{port}")
|
||||
|
||||
database_url: str = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///disknext.db")
|
||||
|
||||
tags_meta = [
|
||||
{
|
||||
"name": "site",
|
||||
|
||||
37
tests/test_db_group.py
Normal file
37
tests/test_db_group.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database
|
||||
from models.group import Group
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
# 测试增 Create
|
||||
test_group = Group(name='test_group')
|
||||
created_group = await Group.create(test_group)
|
||||
|
||||
assert created_group is not None
|
||||
assert created_group.id is not None
|
||||
assert created_group.name == 'test_group'
|
||||
|
||||
# 测试查 Read
|
||||
fetched_group = await Group.get(id=created_group.id)
|
||||
assert fetched_group is not None
|
||||
assert fetched_group.id == created_group.id
|
||||
assert fetched_group.name == 'test_group'
|
||||
|
||||
# 测试更新 Update
|
||||
updated_group = await Group.set(
|
||||
id=fetched_group.id,
|
||||
name='updated_group')
|
||||
|
||||
assert updated_group is not None
|
||||
assert updated_group.id == fetched_group.id
|
||||
assert updated_group.name == 'updated_group'
|
||||
|
||||
# 测试删除 Delete
|
||||
await Group.delete(id=updated_group.id)
|
||||
deleted_group = await Group.get(id=updated_group.id)
|
||||
assert deleted_group is None
|
||||
@@ -10,8 +10,8 @@ async def test_user_curd():
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
# 新建一个测试用户组
|
||||
test_group = Group(name='test_group')
|
||||
created_group = await Group.create(test_group)
|
||||
test_user_group = Group(name='test_user_group')
|
||||
created_group = await Group.create(test_user_group)
|
||||
|
||||
test_user = User(
|
||||
email='test_user',
|
||||
@@ -37,3 +37,18 @@ async def test_user_curd():
|
||||
assert fetched_user.group_id == created_group.id
|
||||
|
||||
# 测试改 Update
|
||||
updated_user = await User.update(
|
||||
id=fetched_user.id,
|
||||
email='updated_user',
|
||||
password='updated_password'
|
||||
)
|
||||
|
||||
assert updated_user is not None
|
||||
assert updated_user.email == 'updated_user'
|
||||
assert updated_user.password == 'updated_password'
|
||||
|
||||
# 测试删除 Delete
|
||||
await User.delete(id=updated_user.id)
|
||||
deleted_user = await User.get(id=updated_user.id)
|
||||
|
||||
assert deleted_user is None
|
||||
@@ -4,10 +4,18 @@ from main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
def is_valid_instance_id(instance_id):
|
||||
"""Check if a string is a valid UUID4."""
|
||||
|
||||
import uuid
|
||||
|
||||
try:
|
||||
uuid.UUID(instance_id, version=4)
|
||||
except (ValueError, TypeError):
|
||||
assert False, f"instance_id is not a valid UUID4: {instance_id}"
|
||||
|
||||
def test_read_main():
|
||||
from pkg.conf.appmeta import BackendVersion
|
||||
import uuid
|
||||
|
||||
response = client.get("/api/site/ping")
|
||||
json_response = response.json()
|
||||
@@ -17,7 +25,14 @@ def test_read_main():
|
||||
assert json_response['data'] == BackendVersion
|
||||
assert json_response['msg'] is None
|
||||
assert 'instance_id' in json_response
|
||||
try:
|
||||
uuid.UUID(json_response['instance_id'], version=4)
|
||||
except (ValueError, TypeError):
|
||||
assert False, f"instance_id is not a valid UUID4: {json_response['instance_id']}"
|
||||
is_valid_instance_id(json_response['instance_id'])
|
||||
|
||||
response = client.get("/api/site/config")
|
||||
json_response = response.json()
|
||||
|
||||
assert response.status_code == 200
|
||||
assert json_response['code'] == 0
|
||||
assert json_response['data'] is not None
|
||||
assert json_response['msg'] is None
|
||||
assert 'instance_id' in json_response
|
||||
is_valid_instance_id(json_response['instance_id'])
|
||||
Reference in New Issue
Block a user