From 643f19c1f192bc5b8d53178b3eadf065306476b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=8E=E5=B0=8F=E4=B8=98?= Date: Fri, 3 Oct 2025 10:00:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E6=9E=84=E5=BB=BA=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- JWT.py | 35 ++++++++-- model/database.py | 168 ++++++--------------------------------------- model/migration.py | 8 ++- model/response.py | 5 +- requirements.txt | Bin 2100 -> 2318 bytes routes/admin.py | 35 ++++++---- routes/session.py | 30 +++++--- 7 files changed, 101 insertions(+), 180 deletions(-) diff --git a/JWT.py b/JWT.py index 1a343f5..7fcd52a 100644 --- a/JWT.py +++ b/JWT.py @@ -1,12 +1,35 @@ from fastapi.security import OAuth2PasswordBearer -from model import database -import asyncio +from model import Setting +from model.database import Database oauth2_scheme = OAuth2PasswordBearer( scheme_name='获取 JWT Bearer 令牌', - description='用于获取 JWT Bearer 令牌,需要以表单的形式提交', + description='用于获取 JWT Bearer 令牌,需要以表单的形式提交', tokenUrl="/api/token" - ) +) -SECRET_KEY = asyncio.run(database.Database().get_setting('SECRET_KEY')) -ALGORITHM = "HS256" \ No newline at end of file +ALGORITHM = "HS256" + +# 延迟加载 SECRET_KEY +_SECRET_KEY_CACHE = None + +async def get_secret_key() -> str: + """ + 获取 JWT 密钥 + + :return: JWT 密钥字符串 + """ + global _SECRET_KEY_CACHE + + if _SECRET_KEY_CACHE is None: + async with Database.get_session() as session: + setting = await Setting.get( + session=session, + condition=(Setting.name == 'SECRET_KEY') + ) + if setting: + _SECRET_KEY_CACHE = setting.value + else: + raise RuntimeError("SECRET_KEY not found in database") + + return _SECRET_KEY_CACHE \ No newline at end of file diff --git a/model/database.py b/model/database.py index 244a872..cbcb623 100644 --- a/model/database.py +++ b/model/database.py @@ -1,15 +1,12 @@ +# ~/models/database.py from contextlib import asynccontextmanager -import aiosqlite -from datetime import datetime -from typing import Optional - -from sqlmodel import SQLModel -from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlmodel.ext.asyncio.session import AsyncSession -from sqlalchemy.orm import sessionmaker from typing import AsyncGenerator -import warnings +from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import SQLModel +from sqlmodel.ext.asyncio.session import AsyncSession + from .migration import migration ASYNC_DATABASE_URL = "sqlite+aiosqlite:///data.db" @@ -19,161 +16,38 @@ engine: AsyncEngine = create_async_engine( echo=True, connect_args={ "check_same_thread": False - } if ASYNC_DATABASE_URL.startswith("sqlite") else None, + } if ASYNC_DATABASE_URL.startswith("sqlite") else {}, future=True, # pool_size=POOL_SIZE, # max_overflow=64, ) -_async_session_factory = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) +_async_session_factory = sessionmaker( + engine, class_=AsyncSession, expire_on_commit=False +) + # 数据库类 class Database: - # Database 初始化方法 def __init__( - self, # self 用于引用类的实例 - db_path: str = "data.db" # db_path 数据库文件路径,默认为 data.db + self, # self 用于引用类的实例 + db_path: str = "data.db", # db_path 数据库文件路径,默认为 data.db ): self.db_path = db_path - + @staticmethod - @asynccontextmanager async def get_session() -> AsyncGenerator[AsyncSession, None]: + """FastAPI dependency to get a database session.""" async with _async_session_factory() as session: yield session - async def init_db( - self, - url: str = ASYNC_DATABASE_URL - ): + async def init_db(self, url: str = ASYNC_DATABASE_URL): """创建数据库结构""" async with engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - - async with self.get_session() as session: - await migration(session) # 执行迁移脚本 - - async def add_object(self, key: str, name: str, icon: str = None, phone: str = None): - """ - 添加新对象 - - :param key: 序列号 - :param name: 名称 - :param icon: 图标 - :param phone: 电话 - """ - async with aiosqlite.connect(self.db_path) as db: - async with db.execute("SELECT 1 FROM fr_objects WHERE key = ?", (key,)) as cursor: - if await cursor.fetchone(): - raise ValueError(f"序列号 {key} 已存在") - - now = datetime.now() - now = now.strftime("%Y-%m-%d %H:%M:%S") - await db.execute( - "INSERT INTO fr_objects (key, name, icon, phone, create_at, status) VALUES (?, ?, ?, ?, ?, 'ok')", - (key, name, icon, phone, now) - ) - await db.commit() - - async def update_object( - self, - id: int, - key: str = None, - name: str = None, - icon: str = None, - status: str = None, - phone: int = None, - lost_description: Optional[str] = None, - find_ip: Optional[str] = None, - lost_time: Optional[str] = None): - """ - 更新对象信息 - - :param id: 对象ID - :param key: 序列号 - :param name: 名称 - :param icon: 图标 - :param status: 状态 - :param phone: 电话 - :param lost_description: 丢失描述 - :param find_ip: 发现IP - :param lost_time: 丢失时间 - """ - async with aiosqlite.connect(self.db_path) as db: - async with db.execute("SELECT 1 FROM fr_objects WHERE id = ?", (id,)) as cursor: - if not await cursor.fetchone(): - raise ValueError(f"ID {id} 不存在") - - async with db.execute("SELECT 1 FROM fr_objects WHERE key = ? AND id != ?", (key, id)) as cursor: - if await cursor.fetchone(): - raise ValueError(f"序列号 {key} 已存在") - - await db.execute( - f"UPDATE fr_objects SET " - f"key = COALESCE(?, key), " - f"name = COALESCE(?, name), " - f"icon = COALESCE(?, icon), " - f"status = COALESCE(?, status), " - f"phone = COALESCE(?, phone), " - f"context = COALESCE(?, context), " - f"find_ip = COALESCE(?, find_ip), " - f"lost_at = COALESCE(?, lost_at) " - f"WHERE id = ?", - (key, name, icon, status, phone, lost_description, find_ip, lost_time, id) - ) - await db.commit() - - async def get_object(self, id: int = None, key: str = None): - """ - 获取对象 - - :param id: 对象ID - :param key: 序列号 - """ - async with aiosqlite.connect(self.db_path) as db: - if id is not None or key is not None: - async with db.execute( - "SELECT * FROM fr_objects WHERE id = ? OR key = ?", (id, key) - ) as cursor: - return await cursor.fetchone() - else: - async with db.execute("SELECT * FROM fr_objects") as cursor: - return await cursor.fetchall() - - async def delete_object(self, id: int): - """ - 删除对象 - - :param id: 对象ID - """ - async with aiosqlite.connect(self.db_path) as db: - await db.execute("DELETE FROM fr_objects WHERE id = ?", (id,)) - await db.commit() - - async def set_setting(self, name: str, value: str): - """ - 设置配置项 - - :param name: 配置项名称 - :param value: 配置项值 - """ - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - "INSERT OR REPLACE INTO fr_settings (name, value) VALUES (?, ?)", - (name, value) - ) - await db.commit() - - async def get_setting(self, name: str): - """ - 获取配置项 - - :param name: 配置项名称 - """ - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( - "SELECT value FROM fr_settings WHERE name = ?", (name,) - ) as cursor: - result = await cursor.fetchone() - return result[0] if result else None \ No newline at end of file + + # For internal use, create a temporary context manager + get_session_cm = asynccontextmanager(self.get_session) + async with get_session_cm() as session: + await migration(session) # 执行迁移脚本 \ No newline at end of file diff --git a/model/migration.py b/model/migration.py index c829735..a903a56 100644 --- a/model/migration.py +++ b/model/migration.py @@ -1,4 +1,4 @@ -from typing import Sequence +from loguru import logger from sqlmodel import select from .setting import Setting import tool @@ -13,9 +13,13 @@ async def migration(session): # 先准备基础配置 settings: list[Setting] = default_settings.copy() + if await Setting.get(session, Setting.name == 'version'): + # 已有数据,说明不是第一次运行,直接返回 + return + # 生成初始密码与密钥 admin_password = tool.generate_password() - print(f"密码(请牢记,后续不再显示): {admin_password}") + logger.warning(f"密码(请牢记,后续不再显示): {admin_password}") settings.append(Setting(type='string', name='password', value=tool.hash_password(admin_password))) settings.append(Setting(type='string', name='SECRET_KEY', value=tool.generate_password(64))) diff --git a/model/response.py b/model/response.py index a87dd32..bf6169a 100644 --- a/model/response.py +++ b/model/response.py @@ -14,4 +14,7 @@ class ObjectData(BaseModel): icon: str status: Literal['ok', 'lost'] phone: str - context: Optional[str] = None \ No newline at end of file + context: str | None = None + lost_description: str | None = None + create_time: str + lost_time: str | None = None \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 85b9a57e250855ee123e56a41d2a38a87200e41d..b1bd31fab1ba32531aceeebdc535456dd2550a39 100644 GIT binary patch delta 293 zcmX|+%?`m(5QR^vpGYKrR+U&{Cup&cxODhUH%Zhao6Ji~xAGom9o8vys8egnSbxboQ3!gm zt=1o5hyl8cXk)>ct1qfzsD&0Ir;)+|J2ERqnv(MVffUm~Gq}dHT1DkB`5Y5hKepK} Nhi3onp{+a;oev|ZH--QJ delta 104 zcmeAZ+9I%_ifQs3rfrkUnB6AdV$PGyW5{GkW=Lg7XD9`-Y=O{-L65qE0%qH(+Rgp<#C Literal[True]: +async def is_admin( + token: Annotated[str, Depends(JWT.oauth2_scheme)], + session: Annotated[AsyncSession, Depends(database.Database.get_session)] +) -> Literal[True]: ''' 验证是否为管理员。 @@ -24,9 +29,9 @@ async def is_admin(token: Annotated[str, Depends(JWT.oauth2_scheme)]) -> Literal ) try: - payload = jwt.decode(token, JWT.SECRET_KEY, algorithms=[JWT.ALGORITHM]) + payload = jwt.decode(token, JWT.get_secret_key(), algorithms=[JWT.ALGORITHM]) username = payload.get("sub") - if username is None or not await database.Database().get_setting('account') == username: + if username is None or not await Setting.get(session, Setting.name == 'account') == username: raise credentials_exception else: return True @@ -64,8 +69,8 @@ async def verity_admin() -> DefaultResponse: response_description='物品信息列表' ) async def get_items( - id: Optional[int] = Query(default=None, ge=1, description='物品ID'), - key: Optional[str] = Query(default=None, description='物品序列号')): + id: int | None = Query(default=None, ge=1, description='物品ID'), + key: str | None = Query(default=None, description='物品序列号')): ''' 获得物品信息。 @@ -80,7 +85,6 @@ async def get_items( items = results item = [] for i in items: - print(i) item.append(Item( id=i[0], type=i[1], @@ -144,14 +148,15 @@ async def add_items( ) async def update_items( id: int = Query(ge=1), - key: Optional[str] = None, - name: Optional[str] = None, - icon: Optional[str] = None, - status: Optional[str] = None, - phone: Optional[int] = None, - lost_description: Optional[str] = None, - find_ip: Optional[str] = None, - lost_time: Optional[str] = None) -> DefaultResponse: + key: str | None = None, + name: str | None = None, + icon: str | None = None, + status: str | None = None, + phone: int | None = None, + lost_description: str | None = None, + find_ip: str | None = None, + lost_time: str | None = None + ) -> DefaultResponse: ''' 更新物品信息。 diff --git a/routes/session.py b/routes/session.py index 6e006f9..10c091f 100644 --- a/routes/session.py +++ b/routes/session.py @@ -5,31 +5,38 @@ from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm from fastapi import APIRouter import jwt, JWT +from sqlmodel.ext.asyncio.session import AsyncSession +from tool import verify_password +from loguru import logger from model.token import Token from model import Setting, database -from tool import verify_password Router = APIRouter(tags=["令牌 session"]) # 创建令牌 -def create_access_token(data: dict, expires_delta: timedelta | None = None): +async def create_access_token(data: dict, expires_delta: timedelta | None = None): to_encode = data.copy() if expires_delta: expire = datetime.now(timezone.utc) + expires_delta else: expire = datetime.now(timezone.utc) + timedelta(minutes=15) to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, JWT.SECRET_KEY, algorithm='HS256') + encoded_jwt = jwt.encode(to_encode, key=await JWT.get_secret_key(), algorithm='HS256') return encoded_jwt # 验证账号密码 -async def authenticate_user(username: str, password: str): +async def authenticate_user(session: AsyncSession, username: str, password: str): # 验证账号和密码 - account = await Setting.get('setting', 'account') - stored_password = await Setting.get('setting', 'password') + account = await Setting.get(session, Setting.name == 'account') + stored_password = await Setting.get(session, Setting.name == 'password') - if account != username or not verify_password(stored_password, password): + if not account or not stored_password: + logger.error("Account or password not set in settings.") + return False + + if account != username or not verify_password(stored_password.value, password): + logger.error("Invalid username or password.") return False return {'is_authenticated': True} @@ -44,8 +51,13 @@ async def authenticate_user(username: str, password: str): ) async def login_for_access_token( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + session: Annotated[AsyncSession, Depends(database.Database.get_session)], ) -> Token: - user = await authenticate_user(form_data.username, form_data.password) + user = await authenticate_user( + session=session, + username=form_data.username, + password=form_data.password + ) if not user: raise HTTPException( status_code=401, @@ -53,7 +65,7 @@ async def login_for_access_token( headers={"WWW-Authenticate": "Bearer"}, ) access_token_expires = timedelta(hours=1) - access_token = create_access_token( + access_token = await create_access_token( data={"sub": form_data.username}, expires_delta=access_token_expires ) return Token(access_token=access_token, token_type="bearer") \ No newline at end of file