feat: add database session dependency for FastAPI routes

- Introduced a new dependency in `middleware/dependencies.py` to provide an asynchronous database session using SQLModel.
- This dependency can be utilized in route functions to facilitate database operations.
This commit is contained in:
2025-11-27 22:18:50 +08:00
parent b364b740ca
commit b02a4638da
25 changed files with 909 additions and 748 deletions

View File

@@ -41,183 +41,4 @@ class Group(TableBase, table=True):
previous_user: List["User"] = Relationship(
back_populates="previous_group",
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
)
@staticmethod
async def create(
group: Optional["Group"] = None,
name: str | None = None,
policies: str | None = None,
max_storage: int = 0,
share_enabled: bool = False,
web_dav_enabled: bool = False,
speed_limit: int = 0,
options: dict | None = None,
) -> "Group":
"""
向数据库内添加用户组。如果提供了 `group` 参数,则使用该对象,否则创建一个新的用户组对象。
:param group: 用户组对象
:type group: Group
:return: 新创建的用户组对象
:rtype: Group
"""
from .database import get_session
import json
if not group:
if not name:
raise ValueError("Group name is required.")
group = Group(
name=name,
policies=policies,
max_storage=max_storage,
share_enabled=share_enabled,
web_dav_enabled=web_dav_enabled,
speed_limit=speed_limit,
options=json.dumps(options) if options else None,
)
async for session in get_session():
try:
session.add(group)
await session.commit()
await session.refresh(group)
except Exception as e:
await session.rollback()
raise e
return group
@staticmethod
async def get(
id: int = None
) -> Optional["Group"]:
"""
获取用户组信息。
:param id: 用户组ID默认为 None
:type id: int
:return: 用户组对象或 None
:rtype: Optional[Group]
"""
from .database import get_session
from sqlmodel import select
session = get_session()
if id is None:
return None
async for session in get_session():
try:
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if group:
return group
else:
return None
except Exception as e:
raise e
@staticmethod
async def set(
id: int,
name: str | None = None,
policies: str | None = None,
max_storage: int | None = None,
share_enabled: bool | None = None,
web_dav_enabled: bool | None = None,
speed_limit: int | None = None,
options: str | None = None
) -> Optional["Group"]:
"""
更新用户组信息。
:param id: 用户组ID
:type id: int
:param name: 用户组名
:type name: str | None
:param policies: 允许的策略ID列表逗号分隔
:type policies: str | None
:param max_storage: 最大存储空间(字节)
:type max_storage: int | None
:param share_enabled: 是否允许创建分享
:type share_enabled: bool | None
:param web_dav_enabled: 是否允许使用WebDAV
:type web_dav_enabled: bool | None
:param speed_limit: 速度限制 (KB/s), 0为不限制
:type speed_limit: int | None
:param options: 其他选项 (JSON格式)
:type options: str | None
:return: 更新后的用户组对象或 None
:rtype: Optional[Group]
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if not group:
raise ValueError(f"Group with id {id} not found.")
if name is not None:
group.name = name
if policies is not None:
group.policies = policies
if max_storage is not None:
group.max_storage = max_storage
if share_enabled is not None:
group.share_enabled = share_enabled
if web_dav_enabled is not None:
group.web_dav_enabled = web_dav_enabled
if speed_limit is not None:
group.speed_limit = speed_limit
if options is not None:
group.options = options
session.add(group)
await session.commit()
return group
except Exception as e:
await session.rollback()
raise e
@staticmethod
async def delete(
id: int
) -> None:
"""
删除用户组。
:param id: 用户组ID
:type id: int
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(Group).where(Group.id == id)
result = await session.exec(statement)
group = result.one_or_none()
if group is None:
raise ValueError(f"Group with id {id} not found.")
await session.delete(group)
await session.commit()
except Exception as e:
await session.rollback()
raise e
)

View File

@@ -73,7 +73,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
Setting(name="hot_share_num", value="10", type="share"),
Setting(name="gravatar_server", value="https://www.gravatar.com/", type="avatar"),
Setting(name="defaultTheme", value="#3f51b5", type="basic"),
Setting(name="themes", value=ThemeModel().model_dump(), type="basic"),
Setting(name="themes", value=ThemeModel().model_dump_json(), type="basic"),
Setting(name="aria2_token", value="", type="aria2"),
Setting(name="aria2_rpcurl", value="", type="aria2"),
Setting(name="aria2_temp_path", value="", type="aria2"),
@@ -118,110 +118,99 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
async def init_default_settings() -> None:
from .setting import Setting
from .database import get_session
from sqlalchemy import and_
log.info('初始化设置...')
try:
async for session in get_session():
# 检查是否已经存在版本设置
ver = await Setting.get(type="version", name=f"db_version_{BackendVersion}")
if ver == "installed":
ver = await Setting.get(
session,
and_(Setting.type == "version", Setting.name == f"db_version_{BackendVersion}")
)
if ver and ver.value == "installed":
return
else: raise ValueError("Database version mismatch or not installed.")
except:
for setting in default_settings:
await Setting.add(
type=setting.type,
name=setting.name,
value=setting.value
)
# 批量添加默认设置
await Setting.add(session, default_settings)
async def init_default_group() -> None:
from .group import Group
from .group import Group, GroupOptions
from .database import get_session
log.info('初始化用户组...')
try:
async for session in get_session():
# 未找到初始管理组时,则创建
if not await Group.get(id=1):
await Group.create(
if not await Group.get(session, Group.id == 1):
admin_group = Group(
name="管理员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
options={
"ArchiveDownload": True,
"ArchiveTask": True,
"ShareDownload": True,
"Aria2": True,
}
admin=True,
options=GroupOptions(
archive_download=True,
archive_task=True,
share_download=True,
aria2=True,
).model_dump(),
)
except Exception as e:
raise RuntimeError(f"无法创建管理员用户组: {e}")
await admin_group.save(session)
try:
# 未找到初始注册会员时,则创建
if not await Group.get(id=2):
await Group.create(
if not await Group.get(session, Group.id == 2):
member_group = Group(
name="注册会员",
max_storage=1 * 1024 * 1024 * 1024, # 1GB
share_enabled=True,
web_dav_enabled=True,
options={
"ShareDownload": True,
}
options=GroupOptions(share_download=True).model_dump(),
)
except Exception as e:
raise RuntimeError(f"无法创建初始注册会员用户组: {e}")
try:
await member_group.save(session)
# 未找到初始游客组时,则创建
if not await Group.get(id=3):
await Group.create(
if not await Group.get(session, Group.id == 3):
guest_group = Group(
name="游客",
policies="[]",
share_enabled=False,
web_dav_enabled=False,
options={
"ShareDownload": True,
}
options=GroupOptions(share_download=True).model_dump(),
)
except Exception as e:
raise RuntimeError(f"无法创建初始游客用户组: {e}")
await guest_group.save(session)
async def init_default_user() -> None:
log.info('初始化管理员用户...')
from .user import User
from .group import Group
from .database import get_session
log.info('初始化管理员用户...')
async for session in get_session():
# 检查管理员用户是否存在
admin_user = await User.get(session, User.id == 1)
if not admin_user:
# 创建初始管理员用户
# 获取管理员组
admin_group = await Group.get(id=1)
admin_group = await Group.get(session, Group.id == 1)
if not admin_group:
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
# 生成管理员密码
from pkg.password.pwd import Password
admin_password = Password.generate(8)
hashed_admin_password = Password.hash(admin_password)
admin_user = User(
email="admin@yxqi.cn",
username="admin",
nick="admin",
status=True, # 正常状态
status=True,
group_id=admin_group.id,
password=hashed_admin_password,
)
admin_user = await admin_user.save(session)
await admin_user.save(session)
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
log.info(f'初始管理员账号:[bold]admin[/bold]')
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')

View File

@@ -39,7 +39,7 @@ class TokenModel(BaseModel):
refresh_expires: datetime = Field(default=None, description="刷新令牌的过期时间")
refresh_token: str = Field(default=None, description="刷新令牌")
class groupModel(BaseModel):
class GroupModel(BaseModel):
'''
用户组模型
'''
@@ -58,7 +58,7 @@ class groupModel(BaseModel):
selectNode: bool = Field(default=False, description="是否允许选择离线下载节点")
advanceDelete: bool = Field(default=False, description="是否允许高级删除")
class userModel(BaseModel):
class UserModel(BaseModel):
'''
用户模型
'''
@@ -71,7 +71,7 @@ class userModel(BaseModel):
preferred_theme: ThemeModel = Field(default_factory=ThemeModel, description="用户首选主题")
score: int = Field(default=0, description="用户积分")
anonymous: bool = Field(default=False, description="是否为匿名用户")
group: groupModel = Field(default_factory=None, description="用户所属用户组")
group: GroupModel = Field(default_factory=None, description="用户所属用户组")
tags: list = Field(default_factory=list, description="用户标签列表")
class SiteConfigModel(ResponseModel):

View File

@@ -1,36 +1,36 @@
from typing import Optional, Literal
from sqlmodel import Field, UniqueConstraint, Column, func, DateTime
from sqlmodel import Field, UniqueConstraint
from .base import TableBase
from datetime import datetime
from enum import StrEnum
SETTINGS_TYPE = Literal[
"auth",
"authn",
"avatar",
"basic",
"captcha",
"cron",
"file_edit",
"login",
"mail",
"mail_template",
"mobile",
"path",
"preview",
"pwa",
"register",
"retry",
"share",
"slave",
"task",
"thumb",
"timeout",
"upload",
"version",
"view",
"wopi"
]
class SettingsType(StrEnum):
"""设置类型枚举"""
ARIA2 = "aria2"
AUTH = "auth"
AUTHN = "authn"
AVATAR = "avatar"
BASIC = "basic"
CAPTCHA = "captcha"
CRON = "cron"
FILE_EDIT = "file_edit"
LOGIN = "login"
MAIL = "mail"
MAIL_TEMPLATE = "mail_template"
MOBILE = "mobile"
PATH = "path"
PREVIEW = "preview"
PWA = "pwa"
REGISTER = "register"
RETRY = "retry"
SHARE = "share"
SLAVE = "slave"
TASK = "task"
THUMB = "thumb"
TIMEOUT = "timeout"
UPLOAD = "upload"
VERSION = "version"
VIEW = "view"
WOPI = "wopi"
# 数据库模型
class Setting(TableBase, table=True):
@@ -38,149 +38,6 @@ class Setting(TableBase, table=True):
__table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),)
type: str = Field(max_length=255, description="设置类型/分组")
type: SettingsType = Field(max_length=255, description="设置类型/分组")
name: str = Field(max_length=255, description="设置项名称")
value: str | None = Field(default=None, description="设置值")
@staticmethod
async def add(
type: SETTINGS_TYPE = None,
name: str = None,
value: str | None = None
) -> None:
"""
向数据库内添加设置项目。
:param type: 设置类型/分组
:type type: SETTINGS_TYPE
:param name: 设置项名称
:type name: str
:param value: 设置值,默认为 None
:type value: str | None
"""
from .database import get_session
if isinstance(value, (dict, list)):
value = str(value)
async for session in get_session():
new_setting = Setting(type=type, name=name, value=value)
session.add(new_setting)
await session.commit()
@staticmethod
async def get(
type: SETTINGS_TYPE,
name: str,
format: Literal['int', 'float', 'bool', 'str'] = 'str'
) -> Optional['Setting']:
"""
从数据库中获取指定类型和名称的设置项。
:param type: 设置类型/分组
:type type: SETTINGS_TYPE
:param name: 设置项名称
:type name: str
:return: 返回设置项对象,如果不存在则返回 None
:rtype: Optional[Setting]
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
statment = select(Setting).where(
Setting.type == type,
Setting.name == name
)
statment = await session.exec(statment)
result = statment.one_or_none()
result = result.value if result else None
# 根据 format 参数转换结果类型
if format == 'int':
return int(result) if result is not None else None
elif format == 'float':
return float(result) if result is not None else None
elif format == 'bool':
return result.lower() in ['true', '1'] if isinstance(result, str) else bool(result)
elif format == 'str':
return str(result) if result is not None else None
else:
raise ValueError(f"Unsupported format: {format}")
@staticmethod
async def set(
type: SETTINGS_TYPE,
name: str,
value: str | None = None
) -> None:
"""
更新指定类型和名称的设置项的值。
:param type: 设置类型/分组
:type type: SETTINGS_TYPE
:param name: 设置项名称
:type name: str
:param value: 新的设置值,默认为 None
:type value: str | None
:raises ValueError: 如果设置项不存在,则抛出异常
"""
from .database import get_session
from sqlmodel import select
if isinstance(value, (dict, list)):
value = str(value)
async for session in get_session():
statment = select(Setting).where(
Setting.type == type,
Setting.name == name
)
result = await session.exec(statment)
setting = result.one_or_none()
if not setting:
raise ValueError(f"Setting {type}.{name} does not exist.")
# 设置项存在,更新数据
setting.value = value
await session.commit()
@staticmethod
async def delete(
type: SETTINGS_TYPE,
name: str
) -> None:
"""
删除指定类型和名称的设置项。
:param type: 设置类型/分组
:type type: SETTINGS_TYPE
:param name: 设置项名称
:type name: str
:raises ValueError: 如果设置项不存在,则抛出异常
"""
from .database import get_session
from sqlmodel import select, delete
async for session in get_session():
statment = select(Setting).where(
Setting.type == type,
Setting.name == name
)
result = await session.exec(statment)
setting = result.one_or_none()
if not setting:
raise ValueError(f"Setting {type}.{name} does not exist.")
# 设置项存在,删除数据
await session.delete(setting)
await session.commit()
value: str | None = Field(default=None, description="设置值")

View File

@@ -31,7 +31,7 @@ class Share(TableBase, table=True):
remain_downloads: int | None = Field(default=None, description="剩余下载次数 (NULL为不限制)")
expires: datetime | None = Field(default=None, description="过期时间 (NULL为永不过期)")
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}, description="是否允许预览")
source_name: str | None = Field(default=None, max_length=255, index=True, description="源名称(冗余字段,便于展示)")
source_name: str | None = Field(default=None, max_length=255, description="源名称(冗余字段,便于展示)")
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="兑换此分享所需的积分")
# 外键

View File

@@ -3,25 +3,23 @@ from datetime import datetime
from sqlmodel import Field, Relationship, UniqueConstraint
from .base import TableBase
from .group import Group
from .download import Download
from .file import File
from .folder import Folder
from .order import Order
from .share import Share
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .webdav import WebDAV
if TYPE_CHECKING:
from .group import Group
from .download import Download
from .file import File
from .folder import Folder
from .order import Order
from .share import Share
from .storage_pack import StoragePack
from .tag import Tag
from .task import Task
from .webdav import WebDAV
class User(TableBase, table=True):
"""用户模型"""
email: str = Field(max_length=100, unique=True, index=True)
"""用户邮箱,唯一"""
phone: str | None = Field(default=None, nullable=True, index=True)
"""用户手机号,唯一"""
username: str = Field(max_length=50, unique=True, index=True)
"""用户,唯一"""
nick: str | None = Field(default=None, max_length=50)
"""用户昵称"""