用户登录

This commit is contained in:
2025-07-17 19:33:48 +08:00
parent 412565cda2
commit e98c46f44a
26 changed files with 187 additions and 385 deletions

8
.idea/.gitignore generated vendored
View File

@@ -1,8 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

20
.idea/Server.iml generated
View File

@@ -1,20 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
<excludeFolder url="file://$MODULE_DIR$/.VSCodeCounter" />
<excludeFolder url="file://$MODULE_DIR$/.venv" />
<excludeFolder url="file://$MODULE_DIR$/.vscode" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="py.test" />
</component>
</module>

12
.idea/dataSources.xml generated
View File

@@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourceManagerImpl" format="xml" multifile-model="true">
<data-source source="LOCAL" name="DiskNext" uuid="6f814477-212f-41c7-885f-862eeff6022a">
<driver-ref>sqlite.xerial</driver-ref>
<synchronize>true</synchronize>
<jdbc-driver>org.sqlite.JDBC</jdbc-driver>
<jdbc-url>jdbc:sqlite:$PROJECT_DIR$/disknext.db</jdbc-url>
<working-dir>$ProjectFileDir$</working-dir>
</data-source>
</component>
</project>

6
.idea/encodings.xml generated
View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Encoding">
<file url="PROJECT" charset="UTF-8" />
</component>
</project>

View File

@@ -1,143 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyInterpreterInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="PyMissingTypeHintsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="107">
<item index="0" class="java.lang.String" itemvalue="httpx" />
<item index="1" class="java.lang.String" itemvalue="nicegui" />
<item index="2" class="java.lang.String" itemvalue="yarg" />
<item index="3" class="java.lang.String" itemvalue="PyYAML" />
<item index="4" class="java.lang.String" itemvalue="pickleshare" />
<item index="5" class="java.lang.String" itemvalue="defusedxml" />
<item index="6" class="java.lang.String" itemvalue="executing" />
<item index="7" class="java.lang.String" itemvalue="pycparser" />
<item index="8" class="java.lang.String" itemvalue="markdown-it-py" />
<item index="9" class="java.lang.String" itemvalue="Pygments" />
<item index="10" class="java.lang.String" itemvalue="pycryptodomex" />
<item index="11" class="java.lang.String" itemvalue="starlette" />
<item index="12" class="java.lang.String" itemvalue="bleach" />
<item index="13" class="java.lang.String" itemvalue="docutils" />
<item index="14" class="java.lang.String" itemvalue="soupsieve" />
<item index="15" class="java.lang.String" itemvalue="uvicorn" />
<item index="16" class="java.lang.String" itemvalue="jsonschema" />
<item index="17" class="java.lang.String" itemvalue="pywin32" />
<item index="18" class="java.lang.String" itemvalue="pydantic" />
<item index="19" class="java.lang.String" itemvalue="python-engineio" />
<item index="20" class="java.lang.String" itemvalue="click" />
<item index="21" class="java.lang.String" itemvalue="nbconvert" />
<item index="22" class="java.lang.String" itemvalue="wsproto" />
<item index="23" class="java.lang.String" itemvalue="attrs" />
<item index="24" class="java.lang.String" itemvalue="jedi" />
<item index="25" class="java.lang.String" itemvalue="jupyterlab_pygments" />
<item index="26" class="java.lang.String" itemvalue="pydantic_core" />
<item index="27" class="java.lang.String" itemvalue="asttokens" />
<item index="28" class="java.lang.String" itemvalue="platformdirs" />
<item index="29" class="java.lang.String" itemvalue="tencentcloud-sdk-python-common" />
<item index="30" class="java.lang.String" itemvalue="httpcore" />
<item index="31" class="java.lang.String" itemvalue="idna" />
<item index="32" class="java.lang.String" itemvalue="referencing" />
<item index="33" class="java.lang.String" itemvalue="decorator" />
<item index="34" class="java.lang.String" itemvalue="cffi" />
<item index="35" class="java.lang.String" itemvalue="pandocfilters" />
<item index="36" class="java.lang.String" itemvalue="requests" />
<item index="37" class="java.lang.String" itemvalue="bidict" />
<item index="38" class="java.lang.String" itemvalue="sniffio" />
<item index="39" class="java.lang.String" itemvalue="vbuild" />
<item index="40" class="java.lang.String" itemvalue="pyOpenSSL" />
<item index="41" class="java.lang.String" itemvalue="stack-data" />
<item index="42" class="java.lang.String" itemvalue="tencentcloud-sdk-python-sms" />
<item index="43" class="java.lang.String" itemvalue="mdurl" />
<item index="44" class="java.lang.String" itemvalue="itsdangerous" />
<item index="45" class="java.lang.String" itemvalue="websockets" />
<item index="46" class="java.lang.String" itemvalue="annotated-types" />
<item index="47" class="java.lang.String" itemvalue="watchfiles" />
<item index="48" class="java.lang.String" itemvalue="tornado" />
<item index="49" class="java.lang.String" itemvalue="markdown2" />
<item index="50" class="java.lang.String" itemvalue="aiofiles" />
<item index="51" class="java.lang.String" itemvalue="python-multipart" />
<item index="52" class="java.lang.String" itemvalue="mistune" />
<item index="53" class="java.lang.String" itemvalue="email_validator" />
<item index="54" class="java.lang.String" itemvalue="typing_extensions" />
<item index="55" class="java.lang.String" itemvalue="ifaddr" />
<item index="56" class="java.lang.String" itemvalue="multidict" />
<item index="57" class="java.lang.String" itemvalue="yarl" />
<item index="58" class="java.lang.String" itemvalue="webencodings" />
<item index="59" class="java.lang.String" itemvalue="traitlets" />
<item index="60" class="java.lang.String" itemvalue="python-dateutil" />
<item index="61" class="java.lang.String" itemvalue="python-dotenv" />
<item index="62" class="java.lang.String" itemvalue="h11" />
<item index="63" class="java.lang.String" itemvalue="nbclient" />
<item index="64" class="java.lang.String" itemvalue="MarkupSafe" />
<item index="65" class="java.lang.String" itemvalue="tinycss2" />
<item index="66" class="java.lang.String" itemvalue="httptools" />
<item index="67" class="java.lang.String" itemvalue="frozenlist" />
<item index="68" class="java.lang.String" itemvalue="docopt" />
<item index="69" class="java.lang.String" itemvalue="pyzmq" />
<item index="70" class="java.lang.String" itemvalue="certifi" />
<item index="71" class="java.lang.String" itemvalue="anyio" />
<item index="72" class="java.lang.String" itemvalue="beautifulsoup4" />
<item index="73" class="java.lang.String" itemvalue="dnspython" />
<item index="74" class="java.lang.String" itemvalue="pscript" />
<item index="75" class="java.lang.String" itemvalue="jupyter_client" />
<item index="76" class="java.lang.String" itemvalue="pure_eval" />
<item index="77" class="java.lang.String" itemvalue="cryptography" />
<item index="78" class="java.lang.String" itemvalue="orjson" />
<item index="79" class="java.lang.String" itemvalue="python-socketio" />
<item index="80" class="java.lang.String" itemvalue="backcall" />
<item index="81" class="java.lang.String" itemvalue="charset-normalizer" />
<item index="82" class="java.lang.String" itemvalue="shellingham" />
<item index="83" class="java.lang.String" itemvalue="simple-websocket" />
<item index="84" class="java.lang.String" itemvalue="matplotlib-inline" />
<item index="85" class="java.lang.String" itemvalue="wcwidth" />
<item index="86" class="java.lang.String" itemvalue="jupyter_core" />
<item index="87" class="java.lang.String" itemvalue="Jinja2" />
<item index="88" class="java.lang.String" itemvalue="jsonschema-specifications" />
<item index="89" class="java.lang.String" itemvalue="rpds-py" />
<item index="90" class="java.lang.String" itemvalue="urllib3" />
<item index="91" class="java.lang.String" itemvalue="fastapi-cli" />
<item index="92" class="java.lang.String" itemvalue="six" />
<item index="93" class="java.lang.String" itemvalue="typer" />
<item index="94" class="java.lang.String" itemvalue="prompt_toolkit" />
<item index="95" class="java.lang.String" itemvalue="parso" />
<item index="96" class="java.lang.String" itemvalue="python-alipay-sdk" />
<item index="97" class="java.lang.String" itemvalue="nbformat" />
<item index="98" class="java.lang.String" itemvalue="ipython" />
<item index="99" class="java.lang.String" itemvalue="rich" />
<item index="100" class="java.lang.String" itemvalue="packaging" />
<item index="101" class="java.lang.String" itemvalue="pipreqs" />
<item index="102" class="java.lang.String" itemvalue="fastjsonschema" />
<item index="103" class="java.lang.String" itemvalue="fastapi" />
<item index="104" class="java.lang.String" itemvalue="colorama" />
<item index="105" class="java.lang.String" itemvalue="aiohttp" />
<item index="106" class="java.lang.String" itemvalue="aiosignal" />
</list>
</value>
</option>
</inspection_tool>
<inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="E402" />
<option value="E713" />
<option value="E271" />
<option value="E302" />
<option value="E265" />
</list>
</option>
</inspection_tool>
<inspection_tool class="PyPep8NamingInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false">
<option name="ignoredErrors">
<list>
<option value="N802" />
<option value="N806" />
</list>
</option>
</inspection_tool>
<inspection_tool class="Stylelint" enabled="true" level="ERROR" enabled_by_default="true" />
</profile>
</component>

View File

@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

View File

@@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MaterialThemeProjectNewConfig">
<option name="metadata">
<MTProjectMetadataState>
<option name="migrated" value="true" />
<option name="pristineConfig" value="false" />
<option name="userId" value="4dc9a07a:18f958e7499:-7ffe" />
</MTProjectMetadataState>
</option>
</component>
</project>

7
.idea/misc.xml generated
View File

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.12 (Server)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (Server)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated
View File

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Server.iml" filepath="$PROJECT_DIR$/.idea/Server.iml" />
</modules>
</component>
</project>

View File

@@ -1,26 +0,0 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="启动 DiskNext" type="PythonConfigurationType" factoryName="Python">
<module name="Server" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="Python 3.12 (Server)" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
</component>

View File

@@ -1,21 +0,0 @@
<component name="ProjectRunConfigurationManager">
<configuration default="false" name="测试 DiskNext" type="tests" factoryName="py.test">
<module name="Server" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="Python 3.12 (Server)" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="_new_keywords" value="&quot;&quot;" />
<option name="_new_parameters" value="&quot;&quot;" />
<option name="_new_additionalArguments" value="&quot;&quot;" />
<option name="_new_target" value="&quot;pytest&quot;" />
<option name="_new_targetType" value="&quot;PATH&quot;" />
<method v="2" />
</configuration>
</component>

6
.idea/vcs.xml generated
View File

@@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

11
main.py
View File

@@ -4,12 +4,17 @@ from pkg.conf import appmeta
from models.database import init_db
from models.migration import migration
from pkg.lifespan import lifespan
from pkg.JWT import jwt
from pkg.JWT import JWT
from pkg.log import log
# 添加初始化数据库启动项
lifespan.add_startup(init_db)
lifespan.add_startup(migration)
lifespan.add_startup(jwt.load_secret_key)
lifespan.add_startup(JWT.load_secret_key)
# 设置日志等级
if appmeta.debug:
log.set_log_level(log.LogLevelEnum.DEBUG)
# 创建应用实例并设置元数据
app = FastAPI(
@@ -39,7 +44,7 @@ for router in routers.Router:
# 启动时打印欢迎信息
if __name__ == "__main__":
import uvicorn
if appmeta.debug:
uvicorn.run(app='main:app', host=appmeta.host, port=appmeta.port, reload=True)
else:

View File

@@ -1,30 +1,54 @@
from typing import Annotated, Literal
from fastapi import Depends
from pkg.JWT import jwt
from typing import Annotated, Optional
from fastapi import Depends, HTTPException
from models.user import User
from pkg.JWT import JWT
import jwt
from jwt import InvalidTokenError
credentials_exception = HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
async def AuthRequired(
token: Annotated[str, Depends(jwt.oauth2_scheme)]
) -> Literal[True]:
'''
token: Annotated[str, Depends(JWT.oauth2_scheme)]
) -> Optional["User"]:
"""
AuthRequired 需要登录
'''
from models.user import User
"""
try:
payload = jwt.decode(token, JWT.SECRET_KEY, algorithms="HS256")
username = payload.get("sub")
if username is None:
raise credentials_exception
# 从数据库获取用户信息
user = await User.get(email=username)
if not user:
raise credentials_exception
return user
except InvalidTokenError:
raise credentials_exception
async def SignRequired(
token: Annotated[str, Depends(jwt.oauth2_scheme)]
) -> Literal[True]:
'''
SignAuthRequired 需要登录并验证请求签名
'''
return True
token: Annotated[str, Depends(JWT.oauth2_scheme)]
) -> Optional["User"]:
"""
SignAuthRequired 需要验证请求签名
"""
pass
async def AdminRequired(
token: Annotated[str, Depends(jwt.oauth2_scheme)]
) -> Literal[True]:
'''
token: Annotated[str, Depends(JWT.oauth2_scheme)]
) -> Optional["User"]:
"""
验证是否为管理员。
使用方法:
>>> APIRouter(dependencies=[Depends(is_admin)])
'''
"""
pass

View File

@@ -1,5 +1,3 @@
# my_project/models/base.py
from typing import Optional
from sqlmodel import SQLModel, Field

View File

@@ -10,11 +10,11 @@ from typing import AsyncGenerator
ASYNC_DATABASE_URL = appmeta.database_url
engine: AsyncEngine = create_async_engine(
ASYNC_DATABASE_URL,
ASYNC_DATABASE_URL,
echo=appmeta.debug,
connect_args={"check_same_thread": False}
if ASYNC_DATABASE_URL.startswith("sqlite")
else None,
connect_args={
"check_same_thread": False
} if ASYNC_DATABASE_URL.startswith("sqlite") else None,
future=True,
# pool_size=POOL_SIZE,
# max_overflow=64,

View File

@@ -92,6 +92,7 @@ class Group(BaseModel, table=True):
try:
session.add(group)
await session.commit()
await session.refresh(group)
except Exception as e:
await session.rollback()
raise e

View File

@@ -212,7 +212,7 @@ async def init_default_user() -> None:
admin_user = User(
email="admin@yxqi.cn",
nick="admin",
status=1, # 正常状态
status=0, # 正常状态
group_id=admin_group.id,
password=hashed_admin_password,
)

View File

@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from typing import Literal, Union, Optional
from datetime import datetime, timedelta, timezone
from uuid import uuid4
class ResponseModel(BaseModel):
@@ -7,6 +8,22 @@ class ResponseModel(BaseModel):
data: Union[dict, list, str, int, float, None] = Field(None, description="响应数据")
msg: Optional[str] = Field(default=None, description="响应消息,可以是错误消息或信息提示")
instance_id: str = Field(default_factory=lambda: str(uuid4()), description="实例ID用于标识请求的唯一性")
class TokenModel(BaseModel):
access_expires: datetime = Field(default=None, description="访问令牌的过期时间")
access_token: str = Field(default=None, description="访问令牌")
refresh_expires: datetime = Field(default=None, description="刷新令牌的过期时间")
refresh_token: str = Field(default=None, description="刷新令牌")
class userModel(ResponseModel):
id: str = Field(default=None, description="用户ID")
username: str = Field(default=None, description="用户名")
email: Optional[str] = Field(default=None, description="用户邮箱")
avatar: Optional[str] = Field(default=None, description="用户头像URL")
is_active: bool = Field(default=True, description="用户是否激活")
is_admin: bool = Field(default=False, description="用户是否为管理员")
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="账户创建时间")
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="账户更新时间")
class SiteConfigModel(ResponseModel):
title: str = Field(default="DiskNext", description="网站标题")

View File

@@ -4,6 +4,8 @@ from typing import Optional, TYPE_CHECKING
from datetime import datetime
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
from .database import get_session
from sqlmodel import select
# TYPE_CHECKING 用于解决循环导入问题,只在类型检查时导入
if TYPE_CHECKING:
@@ -130,6 +132,7 @@ class User(BaseModel, table=True):
try:
session.add(user)
await session.commit()
await session.refresh(user)
except Exception as e:
await session.rollback()
raise e
@@ -150,10 +153,6 @@ class User(BaseModel, table=True):
:return: 用户对象或 None
:rtype: Optional[User]
"""
from .database import get_session
from sqlmodel import select
session = get_session()
if id is None and email is None:
@@ -193,10 +192,6 @@ class User(BaseModel, table=True):
:return: 更新后的用户对象
:rtype: User
"""
from .database import get_session
from sqlmodel import select
async for session in get_session():
try:
statement = select(User).where(User.id == id)
@@ -248,9 +243,8 @@ class User(BaseModel, table=True):
:param id: 用户ID
:type id: int
"""
from .database import get_session
from sqlmodel import select
if id == 1:
raise ValueError("Cannot delete the default admin user with id 1.")
async for session in get_session():
try:

View File

@@ -1,5 +1,7 @@
from fastapi.security import OAuth2PasswordBearer
from models.setting import Setting
from datetime import datetime, timedelta, timezone
import jwt
oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌',
@@ -17,4 +19,26 @@ async def load_secret_key() -> None:
:type key: str
"""
global SECRET_KEY
SECRET_KEY = await Setting.get(type='auth', name='secret_key')
SECRET_KEY = await Setting.get(type='auth', name='secret_key')
# 访问令牌
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]:
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(hours=3)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256')
return encoded_jwt, expire
# 刷新令牌
def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]:
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(days=30)
to_encode.update({"exp": expire, "token_type": "refresh"})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256')
return encoded_jwt, expire

View File

@@ -1,7 +1,6 @@
from rich import print
from rich.console import Console
from rich.markdown import Markdown
from configparser import ConfigParser
from typing import Literal, Optional, Dict, Union
from enum import Enum
import time
@@ -17,10 +16,6 @@ class LogLevelEnum(str, Enum):
# 默认日志级别
LogLevel = LogLevelEnum.INFO
# 日志文件路径
LOG_FILE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
# 是否启用文件日志
ENABLE_FILE_LOG = False
def set_log_level(level: Union[str, LogLevelEnum]) -> None:
"""设置日志级别"""
@@ -33,17 +28,6 @@ def set_log_level(level: Union[str, LogLevelEnum]) -> None:
else:
LogLevel = level
def enable_file_log(enable: bool = True) -> None:
"""启用或禁用文件日志"""
global ENABLE_FILE_LOG
ENABLE_FILE_LOG = enable
if enable and not os.path.exists(LOG_FILE_PATH):
try:
os.makedirs(LOG_FILE_PATH)
except Exception as e:
print(f"[bold red]创建日志目录失败: {e}[/bold red]")
ENABLE_FILE_LOG = False
def truncate_path(full_path: str, marker: str = "HeyAuth") -> str:
"""截断路径只保留从marker开始的部分"""
try:
@@ -114,17 +98,6 @@ def log(level: str = 'debug', message: str = ''):
if should_log:
print(log_message)
# 文件日志记录
if ENABLE_FILE_LOG:
try:
# 去除rich格式化标记
clean_message = f"{level_value.upper()}\t{timestamp} From {filename}, line {lineno} {message}"
log_file = os.path.join(LOG_FILE_PATH, f"{time.strftime('%Y%m%d')}.log")
with open(log_file, 'a', encoding='utf-8') as f:
f.write(f"{clean_message}\n")
except Exception as e:
print(f"[bold red]写入日志文件失败: {e}[/bold red]")
# 便捷日志函数
debug = lambda message: log('debug', message)
@@ -133,31 +106,6 @@ warning = lambda message: log('warn', message)
error = lambda message: log('error', message)
success = lambda message: log('success', message)
def load_config(config_path: str) -> bool:
"""从配置文件加载日志配置"""
try:
if not os.path.exists(config_path):
return False
config = ConfigParser()
config.read(config_path, encoding='utf-8')
if 'log' in config:
log_config = config['log']
if 'level' in log_config:
set_log_level(log_config['level'])
if 'file_log' in log_config:
enable_file_log(log_config.getboolean('file_log'))
if 'log_path' in log_config:
global LOG_FILE_PATH
custom_path = log_config['log_path']
if os.path.exists(custom_path) or os.makedirs(custom_path, exist_ok=True):
LOG_FILE_PATH = custom_path
return True
except Exception as e:
error(f"加载日志配置失败: {e}")
return False
def title(title: str = '海枫授权系统 HeyAuth', size: Optional[Literal['h1', 'h2', 'h3', 'h4', 'h5']] = 'h1'):
"""
输出标题
@@ -204,8 +152,4 @@ if __name__ == '__main__':
print("\n设置为DEBUG级别测试:")
set_log_level(LogLevelEnum.DEBUG)
debug('这是一个debug日志') # 现在会显示
print("\n启用文件日志测试:")
enable_file_log()
info('此日志将同时记录到文件')
debug('这是一个debug日志') # 现在会显示

View File

@@ -1,6 +1,9 @@
from fastapi import APIRouter, Depends
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from middleware.auth import SignRequired
from models.response import ResponseModel
from models.response import ResponseModel, TokenModel
from pkg.log import log
user_router = APIRouter(
prefix="/user",
@@ -18,14 +21,25 @@ user_settings_router = APIRouter(
summary='用户登录',
description='User login endpoint.',
)
def router_user_session() -> ResponseModel:
"""
User login endpoint.
async def router_user_session(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
) -> TokenModel:
Returns:
dict: A dictionary containing user session information.
"""
pass
import service.user.login
username = form_data.username
password = form_data.password
user = await service.user.login.login(username=username, password=password)
if user is None:
raise HTTPException(status_code=400, detail="Invalid username or password")
elif user == 1:
raise HTTPException(status_code=400, detail="User account is not fully registered")
elif user == 2:
raise HTTPException(status_code=403, detail="User account is banned")
return user
@user_router.post(
path='/',

View File

@@ -1,15 +1,63 @@
from models.setting import Setting
from models.response import TokenModel
from models.user import User
from pkg.log import log
async def login(
username: str,
password: str
):
) -> TokenModel | int | None:
"""
根据账号密码进行登录。
如果登录成功,返回一个 TokenModel 对象,包含访问令牌和刷新令牌以及它们的过期时间。
如果登录异常,返回 `int` 状态码,`1` 为未完成注册,`2` 为账号被封禁。
如果登录失败,返回 `None`。
:param username: 用户名或邮箱
:type username: str
:param password: 用户密码
:type password: str
:return: TokenModel 对象或状态码或 None
:rtype: TokenModel | int | None
"""
from pkg.password.pwd import Password
isCaptchaRequired = await Setting.get(type='auth', name='login_captcha', type=bool)
captchaType = await Setting.get(type='auth', name='captcha_type', type=str)
isCaptchaRequired = await Setting.get(type='auth', name='login_captcha', format='bool')
captchaType = await Setting.get(type='auth', name='captcha_type', format='str')
# [TODO] 验证码校验
# 验证用户是否存在
user = await User.get(email=username)
if not user:
log.debug(f"Cannot find user with email: {username}")
return None
# 验证密码是否正确
if not Password.verify(user.password, password):
log.debug(f"Password verification failed for user: {username}")
return None
# 验证用户是否可登录
if user.status == 1:
# 未完成注册
return 1
elif user.status == 2:
# 账号已被封禁
return 2
# 创建令牌
from pkg.JWT.JWT import create_access_token, create_refresh_token
access_token, access_expire = create_access_token(data={'sub': user.email})
refresh_token, refresh_expire = create_refresh_token(data={'sub': user.email})
return TokenModel(
access_token=access_token,
access_expires=access_expire,
refresh_token=refresh_token,
refresh_expires=refresh_expire,
)

View File

@@ -0,0 +1,8 @@
import pytest
from pkg.password.pwd import Password
def test_password():
for i in range(10):
password = Password.generate()
hashed_password = Password.hash(password)
assert Password.verify(hashed_password, password)