feat: add database session dependency for FastAPI routes
- Introduced a new dependency in `middleware/dependencies.py` to provide an asynchronous database session using SQLModel. - This dependency can be utilized in route functions to facilitate database operations.
This commit is contained in:
181
models/group.py
181
models/group.py
@@ -41,183 +41,4 @@ class Group(TableBase, table=True):
|
||||
previous_user: List["User"] = Relationship(
|
||||
back_populates="previous_group",
|
||||
sa_relationship_kwargs={"foreign_keys": "User.previous_group_id"}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
group: Optional["Group"] = None,
|
||||
name: str | None = None,
|
||||
policies: str | None = None,
|
||||
max_storage: int = 0,
|
||||
share_enabled: bool = False,
|
||||
web_dav_enabled: bool = False,
|
||||
speed_limit: int = 0,
|
||||
options: dict | None = None,
|
||||
) -> "Group":
|
||||
"""
|
||||
向数据库内添加用户组。如果提供了 `group` 参数,则使用该对象,否则创建一个新的用户组对象。
|
||||
|
||||
:param group: 用户组对象
|
||||
:type group: Group
|
||||
:return: 新创建的用户组对象
|
||||
:rtype: Group
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
import json
|
||||
|
||||
if not group:
|
||||
|
||||
if not name:
|
||||
raise ValueError("Group name is required.")
|
||||
|
||||
group = Group(
|
||||
name=name,
|
||||
policies=policies,
|
||||
max_storage=max_storage,
|
||||
share_enabled=share_enabled,
|
||||
web_dav_enabled=web_dav_enabled,
|
||||
speed_limit=speed_limit,
|
||||
options=json.dumps(options) if options else None,
|
||||
)
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
session.add(group)
|
||||
await session.commit()
|
||||
await session.refresh(group)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
return group
|
||||
|
||||
@staticmethod
|
||||
async def get(
|
||||
id: int = None
|
||||
) -> Optional["Group"]:
|
||||
"""
|
||||
获取用户组信息。
|
||||
|
||||
:param id: 用户组ID,默认为 None
|
||||
:type id: int
|
||||
|
||||
:return: 用户组对象或 None
|
||||
:rtype: Optional[Group]
|
||||
"""
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
session = get_session()
|
||||
|
||||
if id is None:
|
||||
return None
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(Group).where(Group.id == id)
|
||||
result = await session.exec(statement)
|
||||
group = result.one_or_none()
|
||||
|
||||
if group:
|
||||
return group
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def set(
|
||||
id: int,
|
||||
name: str | None = None,
|
||||
policies: str | None = None,
|
||||
max_storage: int | None = None,
|
||||
share_enabled: bool | None = None,
|
||||
web_dav_enabled: bool | None = None,
|
||||
speed_limit: int | None = None,
|
||||
options: str | None = None
|
||||
) -> Optional["Group"]:
|
||||
"""
|
||||
更新用户组信息。
|
||||
|
||||
:param id: 用户组ID
|
||||
:type id: int
|
||||
:param name: 用户组名
|
||||
:type name: str | None
|
||||
:param policies: 允许的策略ID列表,逗号分隔
|
||||
:type policies: str | None
|
||||
:param max_storage: 最大存储空间(字节)
|
||||
:type max_storage: int | None
|
||||
:param share_enabled: 是否允许创建分享
|
||||
:type share_enabled: bool | None
|
||||
:param web_dav_enabled: 是否允许使用WebDAV
|
||||
:type web_dav_enabled: bool | None
|
||||
:param speed_limit: 速度限制 (KB/s), 0为不限制
|
||||
:type speed_limit: int | None
|
||||
:param options: 其他选项 (JSON格式)
|
||||
:type options: str | None
|
||||
:return: 更新后的用户组对象或 None
|
||||
:rtype: Optional[Group]
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(Group).where(Group.id == id)
|
||||
result = await session.exec(statement)
|
||||
group = result.one_or_none()
|
||||
|
||||
if not group:
|
||||
raise ValueError(f"Group with id {id} not found.")
|
||||
|
||||
if name is not None:
|
||||
group.name = name
|
||||
if policies is not None:
|
||||
group.policies = policies
|
||||
if max_storage is not None:
|
||||
group.max_storage = max_storage
|
||||
if share_enabled is not None:
|
||||
group.share_enabled = share_enabled
|
||||
if web_dav_enabled is not None:
|
||||
group.web_dav_enabled = web_dav_enabled
|
||||
if speed_limit is not None:
|
||||
group.speed_limit = speed_limit
|
||||
if options is not None:
|
||||
group.options = options
|
||||
|
||||
session.add(group)
|
||||
await session.commit()
|
||||
|
||||
return group
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
async def delete(
|
||||
id: int
|
||||
) -> None:
|
||||
"""
|
||||
删除用户组。
|
||||
|
||||
:param id: 用户组ID
|
||||
:type id: int
|
||||
"""
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(Group).where(Group.id == id)
|
||||
result = await session.exec(statement)
|
||||
group = result.one_or_none()
|
||||
|
||||
if group is None:
|
||||
raise ValueError(f"Group with id {id} not found.")
|
||||
|
||||
await session.delete(group)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
)
|
||||
@@ -73,7 +73,7 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
Setting(name="hot_share_num", value="10", type="share"),
|
||||
Setting(name="gravatar_server", value="https://www.gravatar.com/", type="avatar"),
|
||||
Setting(name="defaultTheme", value="#3f51b5", type="basic"),
|
||||
Setting(name="themes", value=ThemeModel().model_dump(), type="basic"),
|
||||
Setting(name="themes", value=ThemeModel().model_dump_json(), type="basic"),
|
||||
Setting(name="aria2_token", value="", type="aria2"),
|
||||
Setting(name="aria2_rpcurl", value="", type="aria2"),
|
||||
Setting(name="aria2_temp_path", value="", type="aria2"),
|
||||
@@ -118,110 +118,99 @@ Neue',Helvetica,Arial,sans-serif; box-sizing: border-box; font-size: 14px; verti
|
||||
|
||||
async def init_default_settings() -> None:
|
||||
from .setting import Setting
|
||||
|
||||
from .database import get_session
|
||||
from sqlalchemy import and_
|
||||
|
||||
log.info('初始化设置...')
|
||||
|
||||
try:
|
||||
|
||||
async for session in get_session():
|
||||
# 检查是否已经存在版本设置
|
||||
ver = await Setting.get(type="version", name=f"db_version_{BackendVersion}")
|
||||
if ver == "installed":
|
||||
ver = await Setting.get(
|
||||
session,
|
||||
and_(Setting.type == "version", Setting.name == f"db_version_{BackendVersion}")
|
||||
)
|
||||
if ver and ver.value == "installed":
|
||||
return
|
||||
else: raise ValueError("Database version mismatch or not installed.")
|
||||
except:
|
||||
for setting in default_settings:
|
||||
await Setting.add(
|
||||
type=setting.type,
|
||||
name=setting.name,
|
||||
value=setting.value
|
||||
)
|
||||
|
||||
# 批量添加默认设置
|
||||
await Setting.add(session, default_settings)
|
||||
|
||||
async def init_default_group() -> None:
|
||||
from .group import Group
|
||||
|
||||
from .group import Group, GroupOptions
|
||||
from .database import get_session
|
||||
|
||||
log.info('初始化用户组...')
|
||||
|
||||
try:
|
||||
|
||||
async for session in get_session():
|
||||
# 未找到初始管理组时,则创建
|
||||
if not await Group.get(id=1):
|
||||
await Group.create(
|
||||
if not await Group.get(session, Group.id == 1):
|
||||
admin_group = Group(
|
||||
name="管理员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
web_dav_enabled=True,
|
||||
options={
|
||||
"ArchiveDownload": True,
|
||||
"ArchiveTask": True,
|
||||
"ShareDownload": True,
|
||||
"Aria2": True,
|
||||
}
|
||||
admin=True,
|
||||
options=GroupOptions(
|
||||
archive_download=True,
|
||||
archive_task=True,
|
||||
share_download=True,
|
||||
aria2=True,
|
||||
).model_dump(),
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建管理员用户组: {e}")
|
||||
await admin_group.save(session)
|
||||
|
||||
try:
|
||||
# 未找到初始注册会员时,则创建
|
||||
if not await Group.get(id=2):
|
||||
await Group.create(
|
||||
if not await Group.get(session, Group.id == 2):
|
||||
member_group = Group(
|
||||
name="注册会员",
|
||||
max_storage=1 * 1024 * 1024 * 1024, # 1GB
|
||||
share_enabled=True,
|
||||
web_dav_enabled=True,
|
||||
options={
|
||||
"ShareDownload": True,
|
||||
}
|
||||
options=GroupOptions(share_download=True).model_dump(),
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建初始注册会员用户组: {e}")
|
||||
|
||||
try:
|
||||
await member_group.save(session)
|
||||
|
||||
# 未找到初始游客组时,则创建
|
||||
if not await Group.get(id=3):
|
||||
await Group.create(
|
||||
if not await Group.get(session, Group.id == 3):
|
||||
guest_group = Group(
|
||||
name="游客",
|
||||
policies="[]",
|
||||
share_enabled=False,
|
||||
web_dav_enabled=False,
|
||||
options={
|
||||
"ShareDownload": True,
|
||||
}
|
||||
options=GroupOptions(share_download=True).model_dump(),
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"无法创建初始游客用户组: {e}")
|
||||
await guest_group.save(session)
|
||||
|
||||
async def init_default_user() -> None:
|
||||
|
||||
log.info('初始化管理员用户...')
|
||||
|
||||
from .user import User
|
||||
from .group import Group
|
||||
from .database import get_session
|
||||
|
||||
log.info('初始化管理员用户...')
|
||||
|
||||
async for session in get_session():
|
||||
# 检查管理员用户是否存在
|
||||
admin_user = await User.get(session, User.id == 1)
|
||||
|
||||
if not admin_user:
|
||||
# 创建初始管理员用户
|
||||
|
||||
# 获取管理员组
|
||||
admin_group = await Group.get(id=1)
|
||||
admin_group = await Group.get(session, Group.id == 1)
|
||||
if not admin_group:
|
||||
raise RuntimeError("管理员用户组不存在,无法创建管理员用户")
|
||||
|
||||
# 生成管理员密码
|
||||
from pkg.password.pwd import Password
|
||||
admin_password = Password.generate(8)
|
||||
hashed_admin_password = Password.hash(admin_password)
|
||||
|
||||
admin_user = User(
|
||||
email="admin@yxqi.cn",
|
||||
username="admin",
|
||||
nick="admin",
|
||||
status=True, # 正常状态
|
||||
status=True,
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
)
|
||||
|
||||
admin_user = await admin_user.save(session)
|
||||
await admin_user.save(session)
|
||||
|
||||
log.info(f'初始管理员账号:[bold]admin@yxqi.cn[/bold]')
|
||||
log.info(f'初始管理员账号:[bold]admin[/bold]')
|
||||
log.info(f'初始管理员密码:[bold]{admin_password}[/bold]')
|
||||
@@ -39,7 +39,7 @@ class TokenModel(BaseModel):
|
||||
refresh_expires: datetime = Field(default=None, description="刷新令牌的过期时间")
|
||||
refresh_token: str = Field(default=None, description="刷新令牌")
|
||||
|
||||
class groupModel(BaseModel):
|
||||
class GroupModel(BaseModel):
|
||||
'''
|
||||
用户组模型
|
||||
'''
|
||||
@@ -58,7 +58,7 @@ class groupModel(BaseModel):
|
||||
selectNode: bool = Field(default=False, description="是否允许选择离线下载节点")
|
||||
advanceDelete: bool = Field(default=False, description="是否允许高级删除")
|
||||
|
||||
class userModel(BaseModel):
|
||||
class UserModel(BaseModel):
|
||||
'''
|
||||
用户模型
|
||||
'''
|
||||
@@ -71,7 +71,7 @@ class userModel(BaseModel):
|
||||
preferred_theme: ThemeModel = Field(default_factory=ThemeModel, description="用户首选主题")
|
||||
score: int = Field(default=0, description="用户积分")
|
||||
anonymous: bool = Field(default=False, description="是否为匿名用户")
|
||||
group: groupModel = Field(default_factory=None, description="用户所属用户组")
|
||||
group: GroupModel = Field(default_factory=None, description="用户所属用户组")
|
||||
tags: list = Field(default_factory=list, description="用户标签列表")
|
||||
|
||||
class SiteConfigModel(ResponseModel):
|
||||
|
||||
@@ -1,36 +1,36 @@
|
||||
|
||||
from typing import Optional, Literal
|
||||
from sqlmodel import Field, UniqueConstraint, Column, func, DateTime
|
||||
from sqlmodel import Field, UniqueConstraint
|
||||
from .base import TableBase
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
SETTINGS_TYPE = Literal[
|
||||
"auth",
|
||||
"authn",
|
||||
"avatar",
|
||||
"basic",
|
||||
"captcha",
|
||||
"cron",
|
||||
"file_edit",
|
||||
"login",
|
||||
"mail",
|
||||
"mail_template",
|
||||
"mobile",
|
||||
"path",
|
||||
"preview",
|
||||
"pwa",
|
||||
"register",
|
||||
"retry",
|
||||
"share",
|
||||
"slave",
|
||||
"task",
|
||||
"thumb",
|
||||
"timeout",
|
||||
"upload",
|
||||
"version",
|
||||
"view",
|
||||
"wopi"
|
||||
]
|
||||
class SettingsType(StrEnum):
|
||||
"""设置类型枚举"""
|
||||
|
||||
ARIA2 = "aria2"
|
||||
AUTH = "auth"
|
||||
AUTHN = "authn"
|
||||
AVATAR = "avatar"
|
||||
BASIC = "basic"
|
||||
CAPTCHA = "captcha"
|
||||
CRON = "cron"
|
||||
FILE_EDIT = "file_edit"
|
||||
LOGIN = "login"
|
||||
MAIL = "mail"
|
||||
MAIL_TEMPLATE = "mail_template"
|
||||
MOBILE = "mobile"
|
||||
PATH = "path"
|
||||
PREVIEW = "preview"
|
||||
PWA = "pwa"
|
||||
REGISTER = "register"
|
||||
RETRY = "retry"
|
||||
SHARE = "share"
|
||||
SLAVE = "slave"
|
||||
TASK = "task"
|
||||
THUMB = "thumb"
|
||||
TIMEOUT = "timeout"
|
||||
UPLOAD = "upload"
|
||||
VERSION = "version"
|
||||
VIEW = "view"
|
||||
WOPI = "wopi"
|
||||
|
||||
# 数据库模型
|
||||
class Setting(TableBase, table=True):
|
||||
@@ -38,149 +38,6 @@ class Setting(TableBase, table=True):
|
||||
|
||||
__table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),)
|
||||
|
||||
type: str = Field(max_length=255, description="设置类型/分组")
|
||||
type: SettingsType = Field(max_length=255, description="设置类型/分组")
|
||||
name: str = Field(max_length=255, description="设置项名称")
|
||||
value: str | None = Field(default=None, description="设置值")
|
||||
|
||||
@staticmethod
|
||||
async def add(
|
||||
type: SETTINGS_TYPE = None,
|
||||
name: str = None,
|
||||
value: str | None = None
|
||||
) -> None:
|
||||
"""
|
||||
向数据库内添加设置项目。
|
||||
|
||||
:param type: 设置类型/分组
|
||||
:type type: SETTINGS_TYPE
|
||||
:param name: 设置项名称
|
||||
:type name: str
|
||||
:param value: 设置值,默认为 None
|
||||
:type value: str | None
|
||||
"""
|
||||
from .database import get_session
|
||||
|
||||
if isinstance(value, (dict, list)):
|
||||
value = str(value)
|
||||
|
||||
async for session in get_session():
|
||||
new_setting = Setting(type=type, name=name, value=value)
|
||||
session.add(new_setting)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def get(
|
||||
type: SETTINGS_TYPE,
|
||||
name: str,
|
||||
format: Literal['int', 'float', 'bool', 'str'] = 'str'
|
||||
) -> Optional['Setting']:
|
||||
"""
|
||||
从数据库中获取指定类型和名称的设置项。
|
||||
|
||||
:param type: 设置类型/分组
|
||||
:type type: SETTINGS_TYPE
|
||||
:param name: 设置项名称
|
||||
:type name: str
|
||||
|
||||
:return: 返回设置项对象,如果不存在则返回 None
|
||||
:rtype: Optional[Setting]
|
||||
"""
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
async for session in get_session():
|
||||
statment = select(Setting).where(
|
||||
Setting.type == type,
|
||||
Setting.name == name
|
||||
)
|
||||
|
||||
statment = await session.exec(statment)
|
||||
result = statment.one_or_none()
|
||||
result = result.value if result else None
|
||||
|
||||
# 根据 format 参数转换结果类型
|
||||
if format == 'int':
|
||||
return int(result) if result is not None else None
|
||||
elif format == 'float':
|
||||
return float(result) if result is not None else None
|
||||
elif format == 'bool':
|
||||
return result.lower() in ['true', '1'] if isinstance(result, str) else bool(result)
|
||||
elif format == 'str':
|
||||
return str(result) if result is not None else None
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
@staticmethod
|
||||
async def set(
|
||||
type: SETTINGS_TYPE,
|
||||
name: str,
|
||||
value: str | None = None
|
||||
) -> None:
|
||||
"""
|
||||
更新指定类型和名称的设置项的值。
|
||||
|
||||
:param type: 设置类型/分组
|
||||
:type type: SETTINGS_TYPE
|
||||
:param name: 设置项名称
|
||||
:type name: str
|
||||
:param value: 新的设置值,默认为 None
|
||||
:type value: str | None
|
||||
|
||||
:raises ValueError: 如果设置项不存在,则抛出异常
|
||||
"""
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
if isinstance(value, (dict, list)):
|
||||
value = str(value)
|
||||
|
||||
async for session in get_session():
|
||||
statment = select(Setting).where(
|
||||
Setting.type == type,
|
||||
Setting.name == name
|
||||
)
|
||||
|
||||
result = await session.exec(statment)
|
||||
setting = result.one_or_none()
|
||||
|
||||
if not setting:
|
||||
raise ValueError(f"Setting {type}.{name} does not exist.")
|
||||
|
||||
# 设置项存在,更新数据
|
||||
setting.value = value
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def delete(
|
||||
type: SETTINGS_TYPE,
|
||||
name: str
|
||||
) -> None:
|
||||
"""
|
||||
删除指定类型和名称的设置项。
|
||||
|
||||
:param type: 设置类型/分组
|
||||
:type type: SETTINGS_TYPE
|
||||
:param name: 设置项名称
|
||||
:type name: str
|
||||
|
||||
:raises ValueError: 如果设置项不存在,则抛出异常
|
||||
"""
|
||||
from .database import get_session
|
||||
from sqlmodel import select, delete
|
||||
|
||||
async for session in get_session():
|
||||
statment = select(Setting).where(
|
||||
Setting.type == type,
|
||||
Setting.name == name
|
||||
)
|
||||
|
||||
result = await session.exec(statment)
|
||||
setting = result.one_or_none()
|
||||
|
||||
if not setting:
|
||||
raise ValueError(f"Setting {type}.{name} does not exist.")
|
||||
|
||||
# 设置项存在,删除数据
|
||||
await session.delete(setting)
|
||||
await session.commit()
|
||||
value: str | None = Field(default=None, description="设置值")
|
||||
@@ -31,7 +31,7 @@ class Share(TableBase, table=True):
|
||||
remain_downloads: int | None = Field(default=None, description="剩余下载次数 (NULL为不限制)")
|
||||
expires: datetime | None = Field(default=None, description="过期时间 (NULL为永不过期)")
|
||||
preview_enabled: bool = Field(default=True, sa_column_kwargs={"server_default": text("true")}, description="是否允许预览")
|
||||
source_name: str | None = Field(default=None, max_length=255, index=True, description="源名称(冗余字段,便于展示)")
|
||||
source_name: str | None = Field(default=None, max_length=255, description="源名称(冗余字段,便于展示)")
|
||||
score: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="兑换此分享所需的积分")
|
||||
|
||||
# 外键
|
||||
|
||||
@@ -3,25 +3,23 @@ from datetime import datetime
|
||||
from sqlmodel import Field, Relationship, UniqueConstraint
|
||||
from .base import TableBase
|
||||
|
||||
from .group import Group
|
||||
from .download import Download
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
from .order import Order
|
||||
from .share import Share
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag
|
||||
from .task import Task
|
||||
from .webdav import WebDAV
|
||||
if TYPE_CHECKING:
|
||||
from .group import Group
|
||||
from .download import Download
|
||||
from .file import File
|
||||
from .folder import Folder
|
||||
from .order import Order
|
||||
from .share import Share
|
||||
from .storage_pack import StoragePack
|
||||
from .tag import Tag
|
||||
from .task import Task
|
||||
from .webdav import WebDAV
|
||||
|
||||
class User(TableBase, table=True):
|
||||
"""用户模型"""
|
||||
|
||||
email: str = Field(max_length=100, unique=True, index=True)
|
||||
"""用户邮箱,唯一"""
|
||||
|
||||
phone: str | None = Field(default=None, nullable=True, index=True)
|
||||
"""用户手机号,唯一"""
|
||||
username: str = Field(max_length=50, unique=True, index=True)
|
||||
"""用户名,唯一"""
|
||||
|
||||
nick: str | None = Field(default=None, max_length=50)
|
||||
"""用户昵称"""
|
||||
|
||||
Reference in New Issue
Block a user