修复数据库迁移问题、新增环境变量读写

This commit is contained in:
2025-07-15 17:32:00 +08:00
parent dc522a8e93
commit 33cca4e271
10 changed files with 432 additions and 39 deletions

View File

@@ -1,16 +1,17 @@
# my_project/database.py
from sqlmodel import SQLModel
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from pkg.conf import appmeta
from sqlalchemy.orm import sessionmaker
from typing import AsyncGenerator
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
ASYNC_DATABASE_URL = appmeta.database_url
engine = create_async_engine(
engine: AsyncEngine = create_async_engine(
ASYNC_DATABASE_URL,
echo=True,
echo=appmeta.debug,
connect_args={"check_same_thread": False}
if ASYNC_DATABASE_URL.startswith("sqlite")
else None,

View File

@@ -1,5 +1,6 @@
# my_project/models/group.py
from tokenize import group
from typing import Optional, List, TYPE_CHECKING
from sqlmodel import Field, Relationship, text, Column, func, DateTime
from .base import BaseModel
@@ -51,10 +52,17 @@ class Group(BaseModel, table=True):
@staticmethod
async def create(
group: "Group"
group: Optional["Group"] = None,
name: Optional[str] = None,
policies: Optional[str] = None,
max_storage: int = 0,
share_enabled: bool = False,
web_dav_enabled: bool = False,
speed_limit: int = 0,
options: Optional[dict] = None,
) -> "Group":
"""
向数据库内添加用户组。
向数据库内添加用户组。如果提供了 `group` 参数,则使用该对象,否则创建一个新的用户组对象。
:param group: 用户组对象
:type group: Group
@@ -63,12 +71,27 @@ class Group(BaseModel, table=True):
"""
from .database import get_session
import json
if not group:
if not name:
raise ValueError("Group name is required.")
group = Group(
name=name,
policies=policies,
max_storage=max_storage,
share_enabled=share_enabled,
web_dav_enabled=web_dav_enabled,
speed_limit=speed_limit,
options=json.dumps(options) if options else None,
)
async for session in get_session():
try:
session.add(group)
await session.commit()
await session.refresh(group)
except Exception as e:
await session.rollback()
raise e
@@ -96,11 +119,111 @@ class Group(BaseModel, table=True):
return None
async for session in get_session():
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if group:
try:
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if group:
return group
else:
return None
except Exception as e:
raise e
@staticmethod
async def set(
id: int,
name: Optional[str] = None,
policies: Optional[str] = None,
max_storage: Optional[int] = None,
share_enabled: Optional[bool] = None,
web_dav_enabled: Optional[bool] = None,
speed_limit: Optional[int] = None,
options: Optional[str] = None
) -> Optional["Group"]:
"""
更新用户组信息。
:param id: 用户组ID
:type id: int
:param name: 用户组名
:type name: Optional[str]
:param policies: 允许的策略ID列表逗号分隔
:type policies: Optional[str]
:param max_storage: 最大存储空间(字节)
:type max_storage: Optional[int]
:param share_enabled: 是否允许创建分享
:type share_enabled: Optional[bool]
:param web_dav_enabled: 是否允许使用WebDAV
:type web_dav_enabled: Optional[bool]
:param speed_limit: 速度限制 (KB/s), 0为不限制
:type speed_limit: Optional[int]
:param options: 其他选项 (JSON格式)
:type options: Optional[str]
:return: 更新后的用户组对象或 None
:rtype: Optional[Group]
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if not group:
raise ValueError(f"Group with id {id} not found.")
if name is not None:
group.name = name
if policies is not None:
group.policies = policies
if max_storage is not None:
group.max_storage = max_storage
if share_enabled is not None:
group.share_enabled = share_enabled
if web_dav_enabled is not None:
group.web_dav_enabled = web_dav_enabled
if speed_limit is not None:
group.speed_limit = speed_limit
if options is not None:
group.options = options
session.add(group)
await session.commit()
return group
else:
return None
except Exception as e:
await session.rollback()
raise e
@staticmethod
async def delete(
id: int
) -> None:
"""
删除用户组。
:param id: 用户组ID
:type id: int
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if group is None:
raise ValueError(f"Group with id {id} not found.")
await session.delete(group)
await session.commit()
except Exception as e:
await session.rollback()
raise e

View File

@@ -1,6 +1,22 @@
from .setting import Setting
from pkg.conf.appmeta import BackendVersion
from pkg.password.pwd import Password
from pkg.log import log
async def migration() -> None:
"""
数据库迁移函数,初始化默认设置和用户组。
:return: None
"""
log.info('开始进行数据库初始化...')
await init_default_settings()
await init_default_group()
await init_default_user()
log.info('数据库初始化结束')
default_settings: list[Setting] = [
Setting(name="siteURL", value="http://localhost", type="basic"),
@@ -101,6 +117,8 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
async def init_default_settings() -> None:
from .setting import Setting
log.info('初始化设置...')
try:
# 检查是否已经存在版本设置
ver = await Setting.get(type="version", name=f"db_version_{BackendVersion}")
@@ -118,10 +136,12 @@ async def init_default_settings() -> None:
async def init_default_group() -> None:
from .group import Group
log.info('初始化用户组...')
try:
# 未找到初始管理组时,则创建
if not Group.get(id=1):
Group.add(
if not await Group.get(id=1):
await Group.create(
name="管理员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
@@ -134,12 +154,12 @@ async def init_default_group() -> None:
}
)
except Exception as e:
raise RuntimeError(f"无法创建管理员用户组: {e}") from e
raise RuntimeError(f"无法创建管理员用户组: {e}")
try:
# 未找到初始注册会员时,则创建
if not Group.get(id=2):
Group.add(
if not await Group.get(id=2):
await Group.create(
name="注册会员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
@@ -149,12 +169,12 @@ async def init_default_group() -> None:
}
)
except Exception as e:
raise RuntimeError(f"无法创建初始注册会员用户组: {e}") from e
raise RuntimeError(f"无法创建初始注册会员用户组: {e}")
try:
# 未找到初始游客组时,则创建
if not Group.get(id=3):
Group.add(
if not await Group.get(id=3):
await Group.create(
name="游客",
policies="[]",
share_enabled=False,
@@ -164,4 +184,40 @@ async def init_default_group() -> None:
}
)
except Exception as e:
raise RuntimeError(f"无法创建初始游客用户组: {e}") from e
raise RuntimeError(f"无法创建初始游客用户组: {e}")
async def init_default_user() -> None:
log.info('初始化管理员用户...')
from .user import User
from .group import Group
# 检查管理员用户是否存在
admin_user = await User.get(id=1)
if not admin_user:
# 创建初始管理员用户
# 获取管理员组
admin_group = await Group.get(id=1)
if not admin_group:
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
# 生成管理员密码
from pkg.password.pwd import Password
admin_password = Password.generate(8)
hashed_admin_password = Password.hash(admin_password)
admin_user = User(
email="admin@yxqi.cn",
nick="admin",
status=1, # 正常状态
group_id=admin_group.id,
password=hashed_admin_password,
)
admin_user = await User.create(admin_user)
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')

View File

@@ -88,8 +88,20 @@ class User(BaseModel, table=True):
tasks: list["Task"] = Relationship(back_populates="user")
webdavs: list["WebDAV"] = Relationship(back_populates="user")
@staticmethod
async def create(
user: "User"
user: Optional["User"] = None,
email: str = None,
nick: Optional[str] = None,
password: str = None,
status: int = 0,
two_factor: Optional[str] = None,
avatar: Optional[str] = None,
options: Optional[str] = None,
authn: Optional[str] = None,
open_id: Optional[str] = None,
score: int = 0,
phone: Optional[str] = None
):
"""
向数据库内添加用户。
@@ -97,18 +109,33 @@ class User(BaseModel, table=True):
:param user: User 实例
:type user: User
"""
if not user:
user = User(
email=email,
nick=nick,
password=password,
status=status,
two_factor=two_factor,
avatar=avatar,
options=options,
authn=authn,
open_id=open_id,
score=score,
phone=phone
)
from .database import get_session
async for session in get_session():
try:
session.add(user)
await session.commit()
await session.refresh(user)
except Exception as e:
await session.rollback()
raise e
return user
@staticmethod
async def get(
id: int = None,
email: str = None
@@ -142,4 +169,100 @@ class User(BaseModel, table=True):
result = await session.exec(query)
user = result.one_or_none()
return user
return user
@staticmethod
async def update(
id: int,
email: Optional[str] = None,
nick: Optional[str] = None,
password: Optional[str] = None,
status: Optional[int] = None,
storage: Optional[int] = None,
two_factor: Optional[str] = None,
avatar: Optional[str] = None,
options: Optional[str] = None,
authn: Optional[str] = None,
open_id: Optional[str] = None,
score: Optional[int] = None,
group_id: Optional[int] = None
) -> "User":
"""
更新用户信息。
:return: 更新后的用户对象
:rtype: User
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(User).where(User.id == id)
result = await session.exec(statement)
user = result.first()
if user is None:
raise ValueError(f"User with id {id} not found.")
if email is not None:
user.email = email
if nick is not None:
user.nick = nick
if password is not None:
user.password = password
if status is not None:
user.status = status
if storage is not None:
user.storage = storage
if two_factor is not None:
user.two_factor = two_factor
if avatar is not None:
user.avatar = avatar
if options is not None:
user.options = options
if authn is not None:
user.authn = authn
if open_id is not None:
user.open_id = open_id
if score is not None:
user.score = score
if group_id is not None:
user.group_id = group_id
await session.commit()
await session.refresh(user)
return user
except Exception as e:
await session.rollback()
raise e
@staticmethod
async def delete(
id: int
) -> None:
"""
删除用户。
:param id: 用户ID
:type id: int
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(User).where(User.id == id)
result = await session.exec(statement)
user = result.one_or_none()
if user is None:
raise ValueError(f"User with id {id} not found.")
await session.delete(user)
await session.commit()
except Exception as e:
await session.rollback()
raise e