Refactor models and routes for item management
Reorganized model structure by replacing 'object' and 'items' with a unified 'item' model using UUIDs, and moved base model logic into separate files. Updated routes to use the new item model and improved request/response handling. Enhanced user and setting models, added utility functions, and improved error handling throughout the codebase. Also added initial .idea project files and minor admin API improvements. Co-Authored-By: 砂糖橘 <54745033+Foxerine@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from . import token
|
||||
from .setting import Setting
|
||||
from .object import Object
|
||||
from .user import User
|
||||
from .response import DefaultResponse, TokenResponse, TokenData
|
||||
from .setting import Setting, SettingResponse
|
||||
from .item import Item, ItemDataResponse, ItemTypeEnum, ItemStatusEnum
|
||||
from .user import User, UserTypeEnum
|
||||
from .database import Database
|
||||
|
||||
151
model/base.py
151
model/base.py
@@ -1,151 +0,0 @@
|
||||
# model/base.py
|
||||
from datetime import datetime, timezone
|
||||
from typing import Type, TypeVar, Union, Literal, List
|
||||
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlmodel import SQLModel, Field, select, Relationship
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
|
||||
B = TypeVar('B', bound='TableBase')
|
||||
M = TypeVar('M', bound='SQLModel')
|
||||
|
||||
utcnow = lambda: datetime.now(tz=timezone.utc)
|
||||
|
||||
class TableBase(AsyncAttrs, SQLModel):
|
||||
__abstract__ = True
|
||||
|
||||
created_at: datetime = Field(
|
||||
default_factory=utcnow,
|
||||
description="创建时间",
|
||||
)
|
||||
updated_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
description="更新时间",
|
||||
sa_column_kwargs={"default": utcnow, "onupdate": utcnow},
|
||||
default_factory=utcnow
|
||||
)
|
||||
deleted_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="删除时间",
|
||||
sa_column={"nullable": True}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(
|
||||
cls: Type[B],
|
||||
session: AsyncSession,
|
||||
instances: B | List[B],
|
||||
refresh: bool = True
|
||||
) -> B | List[B]:
|
||||
is_list = isinstance(instances, list)
|
||||
if is_list:
|
||||
session.add_all(instances)
|
||||
else:
|
||||
session.add(instances)
|
||||
await session.commit()
|
||||
if refresh:
|
||||
if is_list:
|
||||
for i in instances:
|
||||
await session.refresh(i)
|
||||
else:
|
||||
await session.refresh(instances)
|
||||
return instances
|
||||
|
||||
async def save(
|
||||
self: B,
|
||||
session: AsyncSession,
|
||||
load: Union[Relationship, None] = None, # 设默认值,避免必须传
|
||||
):
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load) # 若该模型没有 id,请别用 load 模式
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
async def update(
|
||||
self: B,
|
||||
session: AsyncSession,
|
||||
other: M,
|
||||
extra_data: dict = None,
|
||||
exclude_unset: bool = True,
|
||||
) -> B:
|
||||
self.sqlmodel_update(
|
||||
other.model_dump(exclude_unset=exclude_unset),
|
||||
update=extra_data
|
||||
)
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def delete(
|
||||
cls: Type[B],
|
||||
session: AsyncSession,
|
||||
instance: B | list[B],
|
||||
) -> None:
|
||||
if isinstance(instance, list):
|
||||
for inst in instance:
|
||||
await session.delete(inst)
|
||||
else:
|
||||
await session.delete(instance)
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: Type[B],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None,
|
||||
*,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||
join: Type[B] | tuple[Type[B], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: Union[Relationship, None] = None,
|
||||
order_by: list[ClauseElement] | None = None
|
||||
) -> B | List[B] | None:
|
||||
statement = select(cls)
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
if join is not None:
|
||||
statement = statement.join(*join)
|
||||
if options:
|
||||
statement = statement.options(*options)
|
||||
if load:
|
||||
statement = statement.options(selectinload(load))
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await session.exec(statement)
|
||||
if fetch_mode == "one":
|
||||
return result.one()
|
||||
elif fetch_mode == "first":
|
||||
return result.first()
|
||||
elif fetch_mode == "all":
|
||||
return list(result.all())
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@classmethod
|
||||
async def get_exist_one(cls: Type[B], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> B:
|
||||
instance = await cls.get(session, cls.id == id, load=load)
|
||||
if not instance:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return instance
|
||||
|
||||
|
||||
# 需要“自增 id 主键”的模型才混入它;Setting 不混入
|
||||
class IdMixin(SQLModel):
|
||||
id: int | None = Field(default=None, primary_key=True, description="主键ID")
|
||||
3
model/base/__init__.py
Normal file
3
model/base/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .sqlmodel_base import SQLModelBase
|
||||
from .table_base import TableBase, UUIDTableBase
|
||||
|
||||
5
model/base/sqlmodel_base.py
Normal file
5
model/base/sqlmodel_base.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from pydantic import ConfigDict
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
class SQLModelBase(SQLModel):
|
||||
model_config = ConfigDict(use_attribute_docstrings=True)
|
||||
200
model/base/table_base.py
Normal file
200
model/base/table_base.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Union, TypeVar, Type, Literal, override, Optional, Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import DateTime, BinaryExpression, ClauseElement
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlmodel import Field, select, Relationship, SQLModel
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlalchemy.sql._typing import _OnClauseArgument
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
|
||||
T = TypeVar("T", bound="TableBase")
|
||||
M = TypeVar("M", bound="SQLModel")
|
||||
|
||||
now = lambda: datetime.now()
|
||||
now_date = lambda: datetime.now().date()
|
||||
|
||||
class TableBase(AsyncAttrs):
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
|
||||
created_at: datetime = Field(default_factory=now)
|
||||
updated_at: datetime = Field(
|
||||
sa_type=DateTime,
|
||||
sa_column_kwargs={"default": now, "onupdate": now},
|
||||
default_factory=now
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def add(cls: Type[T], session: AsyncSession, instances: T | list[T], refresh: bool = True) -> T | list[T]:
|
||||
"""
|
||||
新增一条记录
|
||||
:param session: 数据库会话
|
||||
:param instances:
|
||||
:param refresh:
|
||||
:return: 新增的实例对象
|
||||
|
||||
usage:
|
||||
item1 = Item(...)
|
||||
item2 = Item(...)
|
||||
|
||||
Item.add(session, [item1, item2])
|
||||
|
||||
item1_id = item1.id
|
||||
"""
|
||||
is_list = False
|
||||
if isinstance(instances, list):
|
||||
is_list = True
|
||||
session.add_all(instances)
|
||||
else:
|
||||
session.add(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
if refresh:
|
||||
if is_list:
|
||||
for instance in instances:
|
||||
await session.refresh(instance)
|
||||
else:
|
||||
await session.refresh(instances)
|
||||
|
||||
return instances
|
||||
|
||||
async def save(self: T, session: AsyncSession, load: Optional[Relationship] = None) -> T:
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
|
||||
if load is not None:
|
||||
cls = type(self)
|
||||
return await cls.get(session, cls.id == self.id, load=load)
|
||||
else:
|
||||
await session.refresh(self)
|
||||
return self
|
||||
|
||||
async def update(
|
||||
self: T,
|
||||
session: AsyncSession,
|
||||
other: M,
|
||||
extra_data: dict = None,
|
||||
exclude_unset: bool = True
|
||||
) -> T:
|
||||
"""
|
||||
更新记录
|
||||
:param session: 数据库会话
|
||||
:param other:
|
||||
:param extra_data:
|
||||
:param exclude_unset:
|
||||
:return:
|
||||
"""
|
||||
self.sqlmodel_update(other.model_dump(exclude_unset=exclude_unset), update=extra_data)
|
||||
|
||||
session.add(self)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(self)
|
||||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def delete(cls: Type[T], session: AsyncSession, instances: T | list[T]) -> None:
|
||||
"""
|
||||
删除一些记录
|
||||
:param session: 数据库会话
|
||||
:param instances:
|
||||
:return: None
|
||||
|
||||
usage:
|
||||
item1 = Item.get(...)
|
||||
item2 = Item.get(...)
|
||||
|
||||
Item.delete(session, [item1, item2])
|
||||
|
||||
"""
|
||||
if isinstance(instances, list):
|
||||
for instance in instances:
|
||||
await session.delete(instance)
|
||||
else:
|
||||
await session.delete(instances)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get(
|
||||
cls: Type[T],
|
||||
session: AsyncSession,
|
||||
condition: BinaryExpression | ClauseElement | None,
|
||||
*,
|
||||
offset: int | None = None,
|
||||
limit: int | None = None,
|
||||
fetch_mode: Literal["one", "first", "all"] = "first",
|
||||
join: Type[T] | tuple[Type[T], _OnClauseArgument] | None = None,
|
||||
options: list | None = None,
|
||||
load: Union[Relationship, None] = None,
|
||||
order_by: list[ClauseElement] | None = None
|
||||
) -> T | list[T] | None:
|
||||
"""
|
||||
异步获取模型实例
|
||||
|
||||
参数:
|
||||
session: 异步数据库会话
|
||||
condition: SQLAlchemy查询条件,如Model.id == 1
|
||||
offset: 结果偏移量
|
||||
limit: 结果数量限制
|
||||
options: 查询选项,如selectinload(Model.relation),异步访问关系属性必备,不然会报错
|
||||
fetch_mode: 获取模式 - "one"/"all"/"first"
|
||||
join: 要联接的模型类
|
||||
|
||||
返回:
|
||||
根据fetch_mode返回相应的查询结果
|
||||
"""
|
||||
statement = select(cls)
|
||||
|
||||
if condition is not None:
|
||||
statement = statement.where(condition)
|
||||
|
||||
if join is not None:
|
||||
statement = statement.join(*join)
|
||||
|
||||
if options:
|
||||
statement = statement.options(*options)
|
||||
|
||||
if load:
|
||||
statement = statement.options(selectinload(load))
|
||||
|
||||
if order_by is not None:
|
||||
statement = statement.order_by(*order_by)
|
||||
|
||||
if offset:
|
||||
statement = statement.offset(offset)
|
||||
|
||||
if limit:
|
||||
statement = statement.limit(limit)
|
||||
|
||||
result = await session.exec(statement)
|
||||
|
||||
if fetch_mode == "one":
|
||||
return result.one()
|
||||
elif fetch_mode == "first":
|
||||
return result.first()
|
||||
elif fetch_mode == "all":
|
||||
return list(result.all())
|
||||
else:
|
||||
raise ValueError(f"无效的 fetch_mode: {fetch_mode}")
|
||||
|
||||
@classmethod
|
||||
async def get_exist_one(cls: Type[T], session: AsyncSession, id: int, load: Union[Relationship, None] = None) -> T:
|
||||
"""此方法和 await session.get(cls, 主键)的区别就是当不存在时不返回None,
|
||||
而是会抛出fastapi 404 异常"""
|
||||
instance = await cls.get(session, cls.id == id, load=load)
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return instance
|
||||
|
||||
class UUIDTableBase(TableBase):
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
"""override"""
|
||||
|
||||
@override
|
||||
async def get_exist_one(cls: Type[T], session: AsyncSession, id: uuid.UUID, load: Union[Relationship, None] = None) -> T:
|
||||
return super().get_exist_one(session, id, load) # type: ignore
|
||||
@@ -71,4 +71,4 @@ class Database:
|
||||
|
||||
# For internal use, create a temporary context manager
|
||||
async with self.session_context() as session:
|
||||
await migration(session) # 执行迁移脚本
|
||||
await migration(session) # 执行迁移脚本
|
||||
|
||||
77
model/item.py
Normal file
77
model/item.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Self, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class ItemTypeEnum(StrEnum):
|
||||
normal = 'normal'
|
||||
car = 'car'
|
||||
|
||||
class ItemStatusEnum(StrEnum):
|
||||
ok = 'ok'
|
||||
lost = 'lost'
|
||||
|
||||
class ItemBase(SQLModelBase):
|
||||
type: ItemTypeEnum = ItemTypeEnum.normal
|
||||
"""物品的类型"""
|
||||
|
||||
name: str
|
||||
"""物品名称"""
|
||||
|
||||
icon: str | None = None
|
||||
"""物品图标"""
|
||||
|
||||
status: ItemStatusEnum = ItemStatusEnum.ok
|
||||
"""物品状态"""
|
||||
|
||||
phone: str | None = None
|
||||
"""联系电话"""
|
||||
|
||||
description: str | None = None
|
||||
"""物品描述"""
|
||||
|
||||
class Item(ItemBase, UUIDTableBase, table=True):
|
||||
expires_at: datetime | None = None
|
||||
"""物品过期时间"""
|
||||
|
||||
lost_at: datetime | None = None
|
||||
"""物品丢失的时间"""
|
||||
|
||||
find_ip: str | None = None
|
||||
"""最后一次发现的IP地址"""
|
||||
|
||||
user_id: UUID = Field(foreign_key='user.id', ondelete='CASCADE')
|
||||
"""所属用户ID"""
|
||||
|
||||
user: 'User' = Relationship(back_populates='items')
|
||||
|
||||
parent_item_id: UUID | None = Field(foreign_key='item.id', ondelete='RESTRICT')
|
||||
parent_item: Optional['Item'] = Relationship(back_populates='sub_items', sa_relationship_kwargs={'remote_side': 'Item.id'})
|
||||
sub_items: list['Item'] = Relationship(back_populates='parent_item', passive_deletes='all')
|
||||
|
||||
class ItemDataUpdateRequest(ItemBase):
|
||||
pass
|
||||
|
||||
class ItemDataResponse(ItemBase):
|
||||
expires_at: datetime | None = None
|
||||
"""物品过期时间"""
|
||||
|
||||
lost_at: datetime | None = None
|
||||
"""物品丢失的时间"""
|
||||
|
||||
class ItemDataResponseAdmin(ItemBase):
|
||||
expires_at: datetime | None = None
|
||||
"""物品过期时间"""
|
||||
|
||||
lost_at: datetime | None = None
|
||||
"""物品丢失的时间"""
|
||||
|
||||
user_id: UUID = Field(foreign_key='user.id')
|
||||
"""所属用户ID"""
|
||||
@@ -1,14 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Item(BaseModel):
|
||||
id: int
|
||||
type: str
|
||||
key: str
|
||||
name: str
|
||||
icon: str
|
||||
status: str
|
||||
phone: int
|
||||
lost_description: str | None
|
||||
find_ip: str | None
|
||||
create_time: str
|
||||
lost_time: str | None
|
||||
@@ -1,13 +1,12 @@
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
from .setting import Setting
|
||||
from .user import User
|
||||
from .user import User, UserTypeEnum
|
||||
from pkg import Password
|
||||
|
||||
default_settings: list[Setting] = [
|
||||
Setting(type='string', name='version', value='1.0.0'),
|
||||
Setting(type='int', name='jwt_token_exp', value='30'),
|
||||
Setting(type='string', name='server_chan_key', value=''),
|
||||
Setting(type='string', name='version', value='2.0.0'), # 版本号,用于考虑是否需要数据迁移
|
||||
Setting(type='int', name='jwt_token_exp', value='30'), # JWT Token 访问令牌
|
||||
Setting(type='string', name='server_chan_key', value=''), # Server 酱推送密钥
|
||||
]
|
||||
|
||||
async def migration(session):
|
||||
@@ -24,39 +23,32 @@ async def migration(session):
|
||||
names = [s.name for s in settings]
|
||||
existed_settings = await Setting.get(
|
||||
session,
|
||||
Setting.name.in_(names),
|
||||
fetch_mode="all"
|
||||
Setting.name in names,
|
||||
fetch_mode='all'
|
||||
)
|
||||
existed: set[str] = {s.name for s in (existed_settings or [])}
|
||||
|
||||
to_insert = [s for s in settings if s.name not in existed]
|
||||
if to_insert:
|
||||
await Setting.add(session, to_insert, refresh=False)
|
||||
|
||||
if await User.get(session, User.id == 1):
|
||||
# 已有超级管理员用户,说明不是第一次运行
|
||||
await Setting.add(session, to_insert)
|
||||
|
||||
# 修复数据库id为1的用户不是管理员的问题
|
||||
admin_user = await User.get(session, User.id == 1)
|
||||
if admin_user and not admin_user.is_admin:
|
||||
admin_user.is_admin = True
|
||||
await User.update(session, admin_user, refresh=False)
|
||||
|
||||
# 已有用户,直接返回
|
||||
return
|
||||
if not await User.get(session, User.role == UserTypeEnum.super_admin):
|
||||
# 生成初始密码与密钥
|
||||
admin_password = Password.generate()
|
||||
logger.warning("当前无管理员用户,已自动创建初始管理员用户:")
|
||||
logger.warning("邮箱: admin@yxqi.cn")
|
||||
logger.warning(f"密码: {admin_password}")
|
||||
|
||||
# 生成初始密码与密钥
|
||||
admin_password = Password.generate()
|
||||
logger.warning("当前无管理员用户,已自动创建初始管理员用户:")
|
||||
logger.warning("邮箱: admin@yxqi.cn")
|
||||
logger.warning(f"密码: {admin_password}")
|
||||
User._initializing = True
|
||||
|
||||
admin_user = User(
|
||||
id=1,
|
||||
email='admin@yxqi.cn',
|
||||
username='Admin',
|
||||
password=Password.hash(admin_password),
|
||||
is_admin=True
|
||||
)
|
||||
admin_user = User(
|
||||
email='admin@yxqi.cn',
|
||||
username='Admin',
|
||||
password=Password.hash(admin_password),
|
||||
role=UserTypeEnum.super_admin,
|
||||
_initializing=True
|
||||
)
|
||||
|
||||
await User.add(session, admin_user, refresh=False)
|
||||
await User.add(session, admin_user)
|
||||
|
||||
User._initializing = False
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
from typing import Literal, TYPE_CHECKING
|
||||
from sqlmodel import Field, Column, String, DateTime, Relationship
|
||||
from .base import TableBase, IdMixin
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import User
|
||||
|
||||
class Object(IdMixin, TableBase, table=True):
|
||||
|
||||
user_id: int = Field(foreign_key="user.id", index=True, nullable=False, description="所属用户ID")
|
||||
key: str = Field(index=True, nullable=False, unique=True, description="物品外部ID")
|
||||
type: Literal['normal', 'car'] = Field(
|
||||
default='normal',
|
||||
description="物品类型",
|
||||
sa_column=Column(
|
||||
String,
|
||||
default='normal',
|
||||
nullable=False
|
||||
)
|
||||
)
|
||||
name: str = Field(nullable=False, description="物品名称")
|
||||
icon: str | None = Field(default=None, description="物品图标")
|
||||
status: Literal['ok', 'lost'] = Field(
|
||||
default='ok',
|
||||
description="物品状态",
|
||||
sa_column=Column(
|
||||
String,
|
||||
default='ok',
|
||||
nullable=False
|
||||
)
|
||||
)
|
||||
phone: str | None = Field(default=None, description="联系电话")
|
||||
description: str | None = Field(default=None, description="物品描述")
|
||||
find_ip: str | None = Field(default=None, description="最后一次发现的IP地址")
|
||||
lost_at: datetime | None = Field(
|
||||
default=None,
|
||||
description="物品标记为丢失的时间",
|
||||
sa_column=Column(
|
||||
DateTime,
|
||||
nullable=True
|
||||
)
|
||||
)
|
||||
|
||||
user: "User" = Relationship(back_populates="objects")
|
||||
@@ -1,20 +1,15 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
|
||||
class DefaultResponse(BaseModel):
|
||||
code: int = 0
|
||||
data: dict | list | bool | None = None
|
||||
data: dict | list | bool | None
|
||||
msg: str = ""
|
||||
|
||||
class ObjectData(BaseModel):
|
||||
id: int
|
||||
type: Literal['normal', 'car']
|
||||
key: str
|
||||
name: str
|
||||
icon: str
|
||||
status: Literal['ok', 'lost']
|
||||
phone: str
|
||||
context: str | None = None
|
||||
lost_description: str | None = None
|
||||
create_time: str
|
||||
lost_time: str | None = None
|
||||
# FastAPI 鉴权返回模型
|
||||
class TokenResponse(BaseModel):
|
||||
access_token: str
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
from sqlmodel import Field
|
||||
from .base import TableBase
|
||||
from .base import TableBase, SQLModelBase
|
||||
|
||||
class Setting(TableBase, table=True):
|
||||
|
||||
type: str = Field(index=True, nullable=False, description="设置类型")
|
||||
name: str = Field(primary_key=True, nullable=False, description="设置名称") # name 为唯一主键
|
||||
value: str | None = Field(description="设置值")
|
||||
class SettingBase(SQLModelBase):
|
||||
type: str = Field(index=True)
|
||||
"""设置类型"""
|
||||
|
||||
name: str = Field(index=True, unique=True) # name 为唯一主键
|
||||
"""设置名称"""
|
||||
|
||||
value: str | None
|
||||
"""设置值"""
|
||||
|
||||
class Setting(SettingBase, TableBase, table=True):
|
||||
pass
|
||||
|
||||
class SettingResponse(SettingBase):
|
||||
pass
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
# FastAPI 鉴权模型
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
@@ -1,16 +1,81 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from sqlmodel import Field, Column, String, Boolean, Relationship
|
||||
from .base import TableBase, IdMixin
|
||||
from enum import StrEnum
|
||||
from typing import ClassVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .object import Object
|
||||
import sqlalchemy as sa
|
||||
from pydantic import EmailStr
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.orm.session import Session as SessionClass
|
||||
from sqlmodel import Field, Relationship
|
||||
|
||||
class User(IdMixin, TableBase, table=True):
|
||||
from .base import SQLModelBase, UUIDTableBase
|
||||
from .item import Item
|
||||
|
||||
email: str = Field(sa_column=Column(String(100), index=True, unique=True))
|
||||
username: str = Field(sa_column=Column(String(50), index=True, unique=True))
|
||||
password: str = Field(sa_column=Column(String(100)))
|
||||
|
||||
is_admin: bool = Field(default=False, sa_column=Column(Boolean, default=False))
|
||||
|
||||
objects: list["Object"] = Relationship(back_populates="user")
|
||||
class UserTypeEnum(StrEnum):
|
||||
normal_user = 'normal_user'
|
||||
admin = 'admin'
|
||||
super_admin = 'super_admin'
|
||||
|
||||
class UserBase(SQLModelBase):
|
||||
pass
|
||||
|
||||
class User(UserBase, UUIDTableBase, table=True):
|
||||
email: EmailStr = Field(index=True, unique=True)
|
||||
"""邮箱"""
|
||||
|
||||
username: str = Field(index=True, unique=True)
|
||||
"""用户名"""
|
||||
|
||||
password: str
|
||||
"""Argon2算法哈希后的密码"""
|
||||
|
||||
two_factor_secret: str | None = None
|
||||
"""两步验证的密钥"""
|
||||
|
||||
role: UserTypeEnum = Field(default=UserTypeEnum.normal_user, index=True)
|
||||
"""用户的权限等级"""
|
||||
|
||||
items: list[Item] = Relationship(back_populates='user', cascade_delete=True)
|
||||
"""物品关系"""
|
||||
|
||||
_initializing: ClassVar[bool] = False
|
||||
"""标记当前是否处于初始化阶段,初始化阶段允许创建 super_admin"""
|
||||
|
||||
@event.listens_for(SessionClass, "before_flush")
|
||||
def check_super_admin_immutability(session, flush_context, instances):
|
||||
"""
|
||||
在事务刷新到数据库前,集中检查所有关于 super_admin 的不合法操作。
|
||||
此监听器确保超级管理员的角色和存在性是不可变的。
|
||||
"""
|
||||
# 检查1: 禁止创建新的 super_admin
|
||||
for obj in session.new:
|
||||
if isinstance(obj, User) and obj.role == UserTypeEnum.super_admin and not User._initializing:
|
||||
raise ValueError("业务规则:不允许创建新的超级管理员。")
|
||||
|
||||
# 检查2: 禁止删除已存在的 super_admin
|
||||
for obj in session.deleted:
|
||||
if isinstance(obj, User):
|
||||
state = sa.inspect(obj)
|
||||
# 直接从对象被删除前的状态获取角色,避免不必要的 lazy load
|
||||
original_role = state.committed_state.get('role')
|
||||
if original_role == UserTypeEnum.super_admin:
|
||||
username = state.committed_state.get('username', f'(ID: {obj.id})')
|
||||
raise ValueError(f"业务规则:不允许删除超级管理员 '{username}'。")
|
||||
|
||||
# 检查3: 禁止与 super_admin 相关的角色变更
|
||||
for obj in session.dirty:
|
||||
if isinstance(obj, User):
|
||||
state = sa.inspect(obj)
|
||||
# 仅在 'role' 字段确实被修改时才进行检查
|
||||
if "role" in state.committed_state:
|
||||
history = state.attrs.role.history
|
||||
original_role = history.deleted[0]
|
||||
new_role = history.added[0]
|
||||
|
||||
# 场景 a: 禁止将 super_admin 降级
|
||||
if original_role == UserTypeEnum.super_admin:
|
||||
raise ValueError(f"业务规则:不允许将超级管理员 '{obj.username}' 的角色降级。")
|
||||
|
||||
# 场景 b: 禁止将任何用户提升为 super_admin
|
||||
if new_role == UserTypeEnum.super_admin:
|
||||
raise ValueError(f"业务规则:不允许将用户 '{obj.username}' 提升为超级管理员。")
|
||||
|
||||
Reference in New Issue
Block a user