V1.1.3 安全性与稳定性的一些小改进

This commit is contained in:
2025-03-28 00:40:12 +08:00
parent b2078ad340
commit 0873fa1518
5 changed files with 46 additions and 47 deletions

7
.gitignore vendored
View File

@@ -1,4 +1,5 @@
data.db
# DataBase
*.db
# environment
.venv/
@@ -9,5 +10,5 @@ data.db
# Byte-compiled / optimized / DLL files
__pycache__/
# C extensions
*.so
# VsCodeCounter Data
.VSCodeCounter/

View File

@@ -9,21 +9,14 @@ Description: Findreve 后台管理 admin
Copyright (c) 2018-2024 by 于小丘Yuerchu, All Rights Reserved.
'''
from nicegui import ui, app
from typing import Optional
from nicegui import ui
from typing import Dict
import traceback
import model
import asyncio
import qrcode
import base64
from io import BytesIO
from PIL import Image
from fastapi import Request
import json
import requests
from tool import *
from fastapi.responses import RedirectResponse
from datetime import datetime

View File

@@ -11,24 +11,23 @@ Copyright (c) 2018-2024 by 于小丘Yuerchu, All Rights Reserved.
from nicegui import ui, app
from typing import Optional
import traceback
import asyncio
import model
import tool
from fastapi.responses import RedirectResponse
from fastapi import Request
def create() -> Optional[RedirectResponse]:
@ui.page('/login')
async def session(redirect_to: str = '/'):
async def session(request: Request, redirect_to: str = "/"):
# 检测是否已登录
if app.storage.user.get('authenticated', False):
ui.navigate.to(redirect_to)
return ui.navigate.to(redirect_to)
ui.page_title('登录 Findreve')
async def try_login() -> None:
app.storage.user.update({'authenticated': True})
# 跳转到用户上一页
ui.navigate.to(redirect_to)
ui.navigate.to(app.storage.user.get('referrer_path', '/'))
async def login():
if username.value == "" or password.value == "":

59
main.py
View File

@@ -9,12 +9,11 @@ Description: Findreve
Copyright (c) 2018-2024 by 于小丘Yuerchu, All Rights Reserved.
'''
from nicegui import app, ui, Client
from nicegui import app, ui
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.responses import RedirectResponse, JSONResponse
import hashlib
import inspect
import traceback
import notfound
import main_page
@@ -39,39 +38,51 @@ AUTH_CONFIG = {
"session_expire": 3600 # 会话过期时间
}
# 登录验证中间件 Login verification middleware
def is_restricted_route(path: str) -> bool:
"""判断路径是否为需要认证的受限路由"""
# NiceGUI 路由不受限制
if path.startswith('/_nicegui'):
return False
# 静态资源路径不受限制
if path.startswith('/static'):
return False
# 主题路径不受限制
if path.startswith('/theme'):
return False
# 后台路径始终受限
if path.startswith('/admin'):
return True
# 检查是否为受限的客户端页面路由
if path.startswith('/dash') or path.startswith('/user'):
return True
class AuthMiddleware(BaseHTTPMiddleware):
# 异步处理每个请求
async def dispatch(self, request: Request, call_next):
try:
logging.info(f"访问路径: {request.url.path},"
f"认证状态: {app.storage.user.get('authenticated')}")
if not app.storage.user.get('authenticated', False):
# 如果请求的路径不是nicegui的静态文件并且不在unrestricted_page_routes中
if not request.url.path.startwith('/_nicegui') \
and request.url.path in AUTH_CONFIG["restricted_routes"]:
logging.warning(f"未认证用户尝试访问: {request.url.path}")
# 记录用户想访问的路径 Record the user's intended path
app.storage.user['referrer_path'] = request.url.path
# 重定向到登录页面 Redirect to the login page
return RedirectResponse(f'/login?redirect_to={request.url.path}')
# 否则,继续处理请求 Otherwise, continue processing the request
return await call_next(request)
path = request.url.path
if is_restricted_route(path):
logging.warning(f"未认证用户尝试访问: {path}")
return RedirectResponse(f'/login?redirect_to={path}')
return await call_next(request)
except Exception as e:
# 记录错误日志
logging.error(f"认证中间件错误: {str(e)}")
# 返回适当的错误响应
return JSONResponse(
status_code=500,
content={"detail": "服务器内部错误"}
)
logging.error(f"服务器错误 Server error: {str(traceback.format_exc())}")
return JSONResponse(status_code=500, content={"detail": e})
# 添加中间件 Add middleware
app.add_middleware(AuthMiddleware)
# 添加静态文件目录
try:
app.add_static_files(url_path='/static', local_directory='static')
except RuntimeError:
logging.error('无法挂载静态目录')
# 启动函数 Startup function
def startup():

View File

@@ -14,12 +14,7 @@ import random
import hashlib
import binascii
import logging
import qrcode
from typing import Optional
from io import BytesIO
from pathlib import Path
import base64
from datetime import datetime, timezone, timedelta
from datetime import datetime, timezone
import os
def format_phone(phone: str, groups: list = None, separator: str = " ", private: bool = False) -> str: