修复数据库构建问题

This commit is contained in:
2025-10-03 10:00:22 +08:00
parent 3469ca9ab1
commit 643f19c1f1
7 changed files with 101 additions and 180 deletions

31
JWT.py
View File

@@ -1,12 +1,35 @@
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from model import database from model import Setting
import asyncio from model.database import Database
oauth2_scheme = OAuth2PasswordBearer( oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌', scheme_name='获取 JWT Bearer 令牌',
description='用于获取 JWT Bearer 令牌需要以表单的形式提交', description='用于获取 JWT Bearer 令牌,需要以表单的形式提交',
tokenUrl="/api/token" tokenUrl="/api/token"
) )
SECRET_KEY = asyncio.run(database.Database().get_setting('SECRET_KEY'))
ALGORITHM = "HS256" ALGORITHM = "HS256"
# 延迟加载 SECRET_KEY
_SECRET_KEY_CACHE = None
async def get_secret_key() -> str:
"""
获取 JWT 密钥
:return: JWT 密钥字符串
"""
global _SECRET_KEY_CACHE
if _SECRET_KEY_CACHE is None:
async with Database.get_session() as session:
setting = await Setting.get(
session=session,
condition=(Setting.name == 'SECRET_KEY')
)
if setting:
_SECRET_KEY_CACHE = setting.value
else:
raise RuntimeError("SECRET_KEY not found in database")
return _SECRET_KEY_CACHE

View File

@@ -1,15 +1,12 @@
# ~/models/database.py
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import aiosqlite
from datetime import datetime
from typing import Optional
from sqlmodel import SQLModel
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import sessionmaker
from typing import AsyncGenerator from typing import AsyncGenerator
import warnings from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from sqlmodel.ext.asyncio.session import AsyncSession
from .migration import migration from .migration import migration
ASYNC_DATABASE_URL = "sqlite+aiosqlite:///data.db" ASYNC_DATABASE_URL = "sqlite+aiosqlite:///data.db"
@@ -19,161 +16,38 @@ engine: AsyncEngine = create_async_engine(
echo=True, echo=True,
connect_args={ connect_args={
"check_same_thread": False "check_same_thread": False
} if ASYNC_DATABASE_URL.startswith("sqlite") else None, } if ASYNC_DATABASE_URL.startswith("sqlite") else {},
future=True, future=True,
# pool_size=POOL_SIZE, # pool_size=POOL_SIZE,
# max_overflow=64, # max_overflow=64,
) )
_async_session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) _async_session_factory = sessionmaker(
engine, class_=AsyncSession, expire_on_commit=False
)
# 数据库类 # 数据库类
class Database: class Database:
# Database 初始化方法 # Database 初始化方法
def __init__( def __init__(
self, # self 用于引用类的实例 self, # self 用于引用类的实例
db_path: str = "data.db" # db_path 数据库文件路径,默认为 data.db db_path: str = "data.db", # db_path 数据库文件路径,默认为 data.db
): ):
self.db_path = db_path self.db_path = db_path
@staticmethod @staticmethod
@asynccontextmanager
async def get_session() -> AsyncGenerator[AsyncSession, None]: async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency to get a database session."""
async with _async_session_factory() as session: async with _async_session_factory() as session:
yield session yield session
async def init_db( async def init_db(self, url: str = ASYNC_DATABASE_URL):
self,
url: str = ASYNC_DATABASE_URL
):
"""创建数据库结构""" """创建数据库结构"""
async with engine.begin() as conn: async with engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all) await conn.run_sync(SQLModel.metadata.create_all)
async with self.get_session() as session: # For internal use, create a temporary context manager
get_session_cm = asynccontextmanager(self.get_session)
async with get_session_cm() as session:
await migration(session) # 执行迁移脚本 await migration(session) # 执行迁移脚本
async def add_object(self, key: str, name: str, icon: str = None, phone: str = None):
"""
添加新对象
:param key: 序列号
:param name: 名称
:param icon: 图标
:param phone: 电话
"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("SELECT 1 FROM fr_objects WHERE key = ?", (key,)) as cursor:
if await cursor.fetchone():
raise ValueError(f"序列号 {key} 已存在")
now = datetime.now()
now = now.strftime("%Y-%m-%d %H:%M:%S")
await db.execute(
"INSERT INTO fr_objects (key, name, icon, phone, create_at, status) VALUES (?, ?, ?, ?, ?, 'ok')",
(key, name, icon, phone, now)
)
await db.commit()
async def update_object(
self,
id: int,
key: str = None,
name: str = None,
icon: str = None,
status: str = None,
phone: int = None,
lost_description: Optional[str] = None,
find_ip: Optional[str] = None,
lost_time: Optional[str] = None):
"""
更新对象信息
:param id: 对象ID
:param key: 序列号
:param name: 名称
:param icon: 图标
:param status: 状态
:param phone: 电话
:param lost_description: 丢失描述
:param find_ip: 发现IP
:param lost_time: 丢失时间
"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute("SELECT 1 FROM fr_objects WHERE id = ?", (id,)) as cursor:
if not await cursor.fetchone():
raise ValueError(f"ID {id} 不存在")
async with db.execute("SELECT 1 FROM fr_objects WHERE key = ? AND id != ?", (key, id)) as cursor:
if await cursor.fetchone():
raise ValueError(f"序列号 {key} 已存在")
await db.execute(
f"UPDATE fr_objects SET "
f"key = COALESCE(?, key), "
f"name = COALESCE(?, name), "
f"icon = COALESCE(?, icon), "
f"status = COALESCE(?, status), "
f"phone = COALESCE(?, phone), "
f"context = COALESCE(?, context), "
f"find_ip = COALESCE(?, find_ip), "
f"lost_at = COALESCE(?, lost_at) "
f"WHERE id = ?",
(key, name, icon, status, phone, lost_description, find_ip, lost_time, id)
)
await db.commit()
async def get_object(self, id: int = None, key: str = None):
"""
获取对象
:param id: 对象ID
:param key: 序列号
"""
async with aiosqlite.connect(self.db_path) as db:
if id is not None or key is not None:
async with db.execute(
"SELECT * FROM fr_objects WHERE id = ? OR key = ?", (id, key)
) as cursor:
return await cursor.fetchone()
else:
async with db.execute("SELECT * FROM fr_objects") as cursor:
return await cursor.fetchall()
async def delete_object(self, id: int):
"""
删除对象
:param id: 对象ID
"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute("DELETE FROM fr_objects WHERE id = ?", (id,))
await db.commit()
async def set_setting(self, name: str, value: str):
"""
设置配置项
:param name: 配置项名称
:param value: 配置项值
"""
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"INSERT OR REPLACE INTO fr_settings (name, value) VALUES (?, ?)",
(name, value)
)
await db.commit()
async def get_setting(self, name: str):
"""
获取配置项
:param name: 配置项名称
"""
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT value FROM fr_settings WHERE name = ?", (name,)
) as cursor:
result = await cursor.fetchone()
return result[0] if result else None

View File

@@ -1,4 +1,4 @@
from typing import Sequence from loguru import logger
from sqlmodel import select from sqlmodel import select
from .setting import Setting from .setting import Setting
import tool import tool
@@ -13,9 +13,13 @@ async def migration(session):
# 先准备基础配置 # 先准备基础配置
settings: list[Setting] = default_settings.copy() settings: list[Setting] = default_settings.copy()
if await Setting.get(session, Setting.name == 'version'):
# 已有数据,说明不是第一次运行,直接返回
return
# 生成初始密码与密钥 # 生成初始密码与密钥
admin_password = tool.generate_password() admin_password = tool.generate_password()
print(f"密码(请牢记,后续不再显示): {admin_password}") logger.warning(f"密码(请牢记,后续不再显示): {admin_password}")
settings.append(Setting(type='string', name='password', value=tool.hash_password(admin_password))) settings.append(Setting(type='string', name='password', value=tool.hash_password(admin_password)))
settings.append(Setting(type='string', name='SECRET_KEY', value=tool.generate_password(64))) settings.append(Setting(type='string', name='SECRET_KEY', value=tool.generate_password(64)))

View File

@@ -14,4 +14,7 @@ class ObjectData(BaseModel):
icon: str icon: str
status: Literal['ok', 'lost'] status: Literal['ok', 'lost']
phone: str phone: str
context: Optional[str] = None context: str | None = None
lost_description: str | None = None
create_time: str
lost_time: str | None = None

Binary file not shown.

View File

@@ -1,5 +1,5 @@
from fastapi import APIRouter from fastapi import APIRouter
from typing import Annotated, Literal, Optional from typing import Annotated, Literal
from fastapi import Depends, Query from fastapi import Depends, Query
from fastapi import HTTPException from fastapi import HTTPException
import JWT import JWT
@@ -8,9 +8,14 @@ from jwt import InvalidTokenError
from model import database from model import database
from model.response import DefaultResponse from model.response import DefaultResponse
from model.items import Item from model.items import Item
from sqlmodel.ext.asyncio.session import AsyncSession
from model import Setting
# 验证是否为管理员 # 验证是否为管理员
async def is_admin(token: Annotated[str, Depends(JWT.oauth2_scheme)]) -> Literal[True]: async def is_admin(
token: Annotated[str, Depends(JWT.oauth2_scheme)],
session: Annotated[AsyncSession, Depends(database.Database.get_session)]
) -> Literal[True]:
''' '''
验证是否为管理员。 验证是否为管理员。
@@ -24,9 +29,9 @@ async def is_admin(token: Annotated[str, Depends(JWT.oauth2_scheme)]) -> Literal
) )
try: try:
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=[JWT.ALGORITHM]) payload = jwt.decode(token, JWT.get_secret_key(), algorithms=[JWT.ALGORITHM])
username = payload.get("sub") username = payload.get("sub")
if username is None or not await database.Database().get_setting('account') == username: if username is None or not await Setting.get(session, Setting.name == 'account') == username:
raise credentials_exception raise credentials_exception
else: else:
return True return True
@@ -64,8 +69,8 @@ async def verity_admin() -> DefaultResponse:
response_description='物品信息列表' response_description='物品信息列表'
) )
async def get_items( async def get_items(
id: Optional[int] = Query(default=None, ge=1, description='物品ID'), id: int | None = Query(default=None, ge=1, description='物品ID'),
key: Optional[str] = Query(default=None, description='物品序列号')): key: str | None = Query(default=None, description='物品序列号')):
''' '''
获得物品信息。 获得物品信息。
@@ -80,7 +85,6 @@ async def get_items(
items = results items = results
item = [] item = []
for i in items: for i in items:
print(i)
item.append(Item( item.append(Item(
id=i[0], id=i[0],
type=i[1], type=i[1],
@@ -144,14 +148,15 @@ async def add_items(
) )
async def update_items( async def update_items(
id: int = Query(ge=1), id: int = Query(ge=1),
key: Optional[str] = None, key: str | None = None,
name: Optional[str] = None, name: str | None = None,
icon: Optional[str] = None, icon: str | None = None,
status: Optional[str] = None, status: str | None = None,
phone: Optional[int] = None, phone: int | None = None,
lost_description: Optional[str] = None, lost_description: str | None = None,
find_ip: Optional[str] = None, find_ip: str | None = None,
lost_time: Optional[str] = None) -> DefaultResponse: lost_time: str | None = None
) -> DefaultResponse:
''' '''
更新物品信息。 更新物品信息。

View File

@@ -5,31 +5,38 @@ from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from fastapi import APIRouter from fastapi import APIRouter
import jwt, JWT import jwt, JWT
from sqlmodel.ext.asyncio.session import AsyncSession
from tool import verify_password
from loguru import logger
from model.token import Token from model.token import Token
from model import Setting, database from model import Setting, database
from tool import verify_password
Router = APIRouter(tags=["令牌 session"]) Router = APIRouter(tags=["令牌 session"])
# 创建令牌 # 创建令牌
def create_access_token(data: dict, expires_delta: timedelta | None = None): async def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy() to_encode = data.copy()
if expires_delta: if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta expire = datetime.now(timezone.utc) + expires_delta
else: else:
expire = datetime.now(timezone.utc) + timedelta(minutes=15) expire = datetime.now(timezone.utc) + timedelta(minutes=15)
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, JWT.SECRET_KEY, algorithm='HS256') encoded_jwt = jwt.encode(to_encode, key=await JWT.get_secret_key(), algorithm='HS256')
return encoded_jwt return encoded_jwt
# 验证账号密码 # 验证账号密码
async def authenticate_user(username: str, password: str): async def authenticate_user(session: AsyncSession, username: str, password: str):
# 验证账号和密码 # 验证账号和密码
account = await Setting.get('setting', 'account') account = await Setting.get(session, Setting.name == 'account')
stored_password = await Setting.get('setting', 'password') stored_password = await Setting.get(session, Setting.name == 'password')
if account != username or not verify_password(stored_password, password): if not account or not stored_password:
logger.error("Account or password not set in settings.")
return False
if account != username or not verify_password(stored_password.value, password):
logger.error("Invalid username or password.")
return False return False
return {'is_authenticated': True} return {'is_authenticated': True}
@@ -44,8 +51,13 @@ async def authenticate_user(username: str, password: str):
) )
async def login_for_access_token( async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
session: Annotated[AsyncSession, Depends(database.Database.get_session)],
) -> Token: ) -> Token:
user = await authenticate_user(form_data.username, form_data.password) user = await authenticate_user(
session=session,
username=form_data.username,
password=form_data.password
)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=401, status_code=401,
@@ -53,7 +65,7 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
access_token_expires = timedelta(hours=1) access_token_expires = timedelta(hours=1)
access_token = create_access_token( access_token = await create_access_token(
data={"sub": form_data.username}, expires_delta=access_token_expires data={"sub": form_data.username}, expires_delta=access_token_expires
) )
return Token(access_token=access_token, token_type="bearer") return Token(access_token=access_token, token_type="bearer")