diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 35410ca..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,8 +0,0 @@ -# 默认忽略的文件 -/shelf/ -/workspace.xml -# 基于编辑器的 HTTP 客户端请求 -/httpRequests/ -# Datasource local storage ignored files -/dataSources/ -/dataSources.local.xml diff --git a/.idea/Server.iml b/.idea/Server.iml deleted file mode 100644 index 050da18..0000000 --- a/.idea/Server.iml +++ /dev/null @@ -1,20 +0,0 @@ - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml deleted file mode 100644 index 6e20fd7..0000000 --- a/.idea/dataSources.xml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - sqlite.xerial - true - org.sqlite.JDBC - jdbc:sqlite:$PROJECT_DIR$/disknext.db - $ProjectFileDir$ - - - \ No newline at end of file diff --git a/.idea/encodings.xml b/.idea/encodings.xml deleted file mode 100644 index 97626ba..0000000 --- a/.idea/encodings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index 5027b3a..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,143 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml deleted file mode 100644 index 0831c9f..0000000 --- a/.idea/material_theme_project_new.xml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index df4982d..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index 8f3a104..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/_DiskNext.xml b/.idea/runConfigurations/_DiskNext.xml deleted file mode 100644 index cee6378..0000000 --- a/.idea/runConfigurations/_DiskNext.xml +++ /dev/null @@ -1,26 +0,0 @@ - - - - - \ No newline at end of file diff --git a/.idea/runConfigurations/_DiskNext2.xml b/.idea/runConfigurations/_DiskNext2.xml deleted file mode 100644 index a72c412..0000000 --- a/.idea/runConfigurations/_DiskNext2.xml +++ /dev/null @@ -1,21 +0,0 @@ - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/identifier.sqlite b/identifier.sqlite deleted file mode 100644 index e69de29..0000000 diff --git a/main.py b/main.py index e0d1944..2577be9 100644 --- a/main.py +++ b/main.py @@ -4,12 +4,17 @@ from pkg.conf import appmeta from models.database import init_db from models.migration import migration from pkg.lifespan import lifespan -from pkg.JWT import jwt +from pkg.JWT import JWT +from pkg.log import log # 添加初始化数据库启动项 lifespan.add_startup(init_db) lifespan.add_startup(migration) -lifespan.add_startup(jwt.load_secret_key) +lifespan.add_startup(JWT.load_secret_key) + +# 设置日志等级 +if appmeta.debug: + log.set_log_level(log.LogLevelEnum.DEBUG) # 创建应用实例并设置元数据 app = FastAPI( @@ -39,7 +44,7 @@ for router in routers.Router: # 启动时打印欢迎信息 if __name__ == "__main__": import uvicorn - + if appmeta.debug: uvicorn.run(app='main:app', host=appmeta.host, port=appmeta.port, reload=True) else: diff --git a/middleware/auth.py b/middleware/auth.py index 60717df..5e408c2 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -1,30 +1,54 @@ -from typing import Annotated, Literal -from fastapi import Depends -from pkg.JWT import jwt +from typing import Annotated, Optional +from fastapi import Depends, HTTPException +from models.user import User +from pkg.JWT import JWT +import jwt +from jwt import InvalidTokenError + +credentials_exception = HTTPException( + status_code=401, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, +) async def AuthRequired( - token: Annotated[str, Depends(jwt.oauth2_scheme)] -) -> Literal[True]: - ''' + token: Annotated[str, Depends(JWT.oauth2_scheme)] +) -> Optional["User"]: + """ AuthRequired 需要登录 - ''' - from models.user import User + """ + try: + payload = jwt.decode(token, JWT.SECRET_KEY, algorithms="HS256") + username = payload.get("sub") + + if username is None: + raise credentials_exception + + # 从数据库获取用户信息 + user = await User.get(email=username) + if not user: + raise credentials_exception + + return user + + except InvalidTokenError: + raise credentials_exception async def SignRequired( - token: Annotated[str, Depends(jwt.oauth2_scheme)] -) -> Literal[True]: - ''' - SignAuthRequired 需要登录并验证请求签名 - ''' - return True + token: Annotated[str, Depends(JWT.oauth2_scheme)] +) -> Optional["User"]: + """ + SignAuthRequired 需要验证请求签名 + """ + pass async def AdminRequired( - token: Annotated[str, Depends(jwt.oauth2_scheme)] -) -> Literal[True]: - ''' + token: Annotated[str, Depends(JWT.oauth2_scheme)] +) -> Optional["User"]: + """ 验证是否为管理员。 使用方法: >>> APIRouter(dependencies=[Depends(is_admin)]) - ''' + """ pass \ No newline at end of file diff --git a/models/base.py b/models/base.py index e1fd70a..2d4eb92 100644 --- a/models/base.py +++ b/models/base.py @@ -1,5 +1,3 @@ -# my_project/models/base.py - from typing import Optional from sqlmodel import SQLModel, Field diff --git a/models/database.py b/models/database.py index 9c5bf5f..b326872 100644 --- a/models/database.py +++ b/models/database.py @@ -10,11 +10,11 @@ from typing import AsyncGenerator ASYNC_DATABASE_URL = appmeta.database_url engine: AsyncEngine = create_async_engine( - ASYNC_DATABASE_URL, + ASYNC_DATABASE_URL, echo=appmeta.debug, - connect_args={"check_same_thread": False} - if ASYNC_DATABASE_URL.startswith("sqlite") - else None, + connect_args={ + "check_same_thread": False + } if ASYNC_DATABASE_URL.startswith("sqlite") else None, future=True, # pool_size=POOL_SIZE, # max_overflow=64, diff --git a/models/group.py b/models/group.py index 9ade27e..16f2974 100644 --- a/models/group.py +++ b/models/group.py @@ -92,6 +92,7 @@ class Group(BaseModel, table=True): try: session.add(group) await session.commit() + await session.refresh(group) except Exception as e: await session.rollback() raise e diff --git a/models/migration.py b/models/migration.py index 0099b74..2a747fc 100644 --- a/models/migration.py +++ b/models/migration.py @@ -212,7 +212,7 @@ async def init_default_user() -> None: admin_user = User( email="admin@yxqi.cn", nick="admin", - status=1, # 正常状态 + status=0, # 正常状态 group_id=admin_group.id, password=hashed_admin_password, ) diff --git a/models/response.py b/models/response.py index 824ad4e..1a7deb6 100644 --- a/models/response.py +++ b/models/response.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, Field from typing import Literal, Union, Optional +from datetime import datetime, timedelta, timezone from uuid import uuid4 class ResponseModel(BaseModel): @@ -7,6 +8,22 @@ class ResponseModel(BaseModel): data: Union[dict, list, str, int, float, None] = Field(None, description="响应数据") msg: Optional[str] = Field(default=None, description="响应消息,可以是错误消息或信息提示") instance_id: str = Field(default_factory=lambda: str(uuid4()), description="实例ID,用于标识请求的唯一性") + +class TokenModel(BaseModel): + access_expires: datetime = Field(default=None, description="访问令牌的过期时间") + access_token: str = Field(default=None, description="访问令牌") + refresh_expires: datetime = Field(default=None, description="刷新令牌的过期时间") + refresh_token: str = Field(default=None, description="刷新令牌") + +class userModel(ResponseModel): + id: str = Field(default=None, description="用户ID") + username: str = Field(default=None, description="用户名") + email: Optional[str] = Field(default=None, description="用户邮箱") + avatar: Optional[str] = Field(default=None, description="用户头像URL") + is_active: bool = Field(default=True, description="用户是否激活") + is_admin: bool = Field(default=False, description="用户是否为管理员") + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="账户创建时间") + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="账户更新时间") class SiteConfigModel(ResponseModel): title: str = Field(default="DiskNext", description="网站标题") diff --git a/models/user.py b/models/user.py index 3a9ca0a..34f9336 100644 --- a/models/user.py +++ b/models/user.py @@ -4,6 +4,8 @@ from typing import Optional, TYPE_CHECKING from datetime import datetime from sqlmodel import Field, Relationship, Column, func, DateTime from .base import BaseModel +from .database import get_session +from sqlmodel import select # TYPE_CHECKING 用于解决循环导入问题,只在类型检查时导入 if TYPE_CHECKING: @@ -130,6 +132,7 @@ class User(BaseModel, table=True): try: session.add(user) await session.commit() + await session.refresh(user) except Exception as e: await session.rollback() raise e @@ -150,10 +153,6 @@ class User(BaseModel, table=True): :return: 用户对象或 None :rtype: Optional[User] """ - - from .database import get_session - from sqlmodel import select - session = get_session() if id is None and email is None: @@ -193,10 +192,6 @@ class User(BaseModel, table=True): :return: 更新后的用户对象 :rtype: User """ - - from .database import get_session - from sqlmodel import select - async for session in get_session(): try: statement = select(User).where(User.id == id) @@ -248,9 +243,8 @@ class User(BaseModel, table=True): :param id: 用户ID :type id: int """ - - from .database import get_session - from sqlmodel import select + if id == 1: + raise ValueError("Cannot delete the default admin user with id 1.") async for session in get_session(): try: diff --git a/pkg/JWT/jwt.py b/pkg/JWT/jwt.py index 4fcca1c..0b9f526 100644 --- a/pkg/JWT/jwt.py +++ b/pkg/JWT/jwt.py @@ -1,5 +1,7 @@ from fastapi.security import OAuth2PasswordBearer from models.setting import Setting +from datetime import datetime, timedelta, timezone +import jwt oauth2_scheme = OAuth2PasswordBearer( scheme_name='获取 JWT Bearer 令牌', @@ -17,4 +19,26 @@ async def load_secret_key() -> None: :type key: str """ global SECRET_KEY - SECRET_KEY = await Setting.get(type='auth', name='secret_key') \ No newline at end of file + SECRET_KEY = await Setting.get(type='auth', name='secret_key') + +# 访问令牌 +def create_access_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]: + to_encode = data.copy() + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(hours=3) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256') + return encoded_jwt, expire + +# 刷新令牌 +def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]: + to_encode = data.copy() + if expires_delta: + expire = datetime.now(timezone.utc) + expires_delta + else: + expire = datetime.now(timezone.utc) + timedelta(days=30) + to_encode.update({"exp": expire, "token_type": "refresh"}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256') + return encoded_jwt, expire \ No newline at end of file diff --git a/pkg/log/log.py b/pkg/log/log.py index 50718af..40fb33b 100644 --- a/pkg/log/log.py +++ b/pkg/log/log.py @@ -1,7 +1,6 @@ from rich import print from rich.console import Console from rich.markdown import Markdown -from configparser import ConfigParser from typing import Literal, Optional, Dict, Union from enum import Enum import time @@ -17,10 +16,6 @@ class LogLevelEnum(str, Enum): # 默认日志级别 LogLevel = LogLevelEnum.INFO -# 日志文件路径 -LOG_FILE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs') -# 是否启用文件日志 -ENABLE_FILE_LOG = False def set_log_level(level: Union[str, LogLevelEnum]) -> None: """设置日志级别""" @@ -33,17 +28,6 @@ def set_log_level(level: Union[str, LogLevelEnum]) -> None: else: LogLevel = level -def enable_file_log(enable: bool = True) -> None: - """启用或禁用文件日志""" - global ENABLE_FILE_LOG - ENABLE_FILE_LOG = enable - if enable and not os.path.exists(LOG_FILE_PATH): - try: - os.makedirs(LOG_FILE_PATH) - except Exception as e: - print(f"[bold red]创建日志目录失败: {e}[/bold red]") - ENABLE_FILE_LOG = False - def truncate_path(full_path: str, marker: str = "HeyAuth") -> str: """截断路径,只保留从marker开始的部分""" try: @@ -114,17 +98,6 @@ def log(level: str = 'debug', message: str = ''): if should_log: print(log_message) - - # 文件日志记录 - if ENABLE_FILE_LOG: - try: - # 去除rich格式化标记 - clean_message = f"{level_value.upper()}\t{timestamp} From {filename}, line {lineno} {message}" - log_file = os.path.join(LOG_FILE_PATH, f"{time.strftime('%Y%m%d')}.log") - with open(log_file, 'a', encoding='utf-8') as f: - f.write(f"{clean_message}\n") - except Exception as e: - print(f"[bold red]写入日志文件失败: {e}[/bold red]") # 便捷日志函数 debug = lambda message: log('debug', message) @@ -133,31 +106,6 @@ warning = lambda message: log('warn', message) error = lambda message: log('error', message) success = lambda message: log('success', message) -def load_config(config_path: str) -> bool: - """从配置文件加载日志配置""" - try: - if not os.path.exists(config_path): - return False - - config = ConfigParser() - config.read(config_path, encoding='utf-8') - - if 'log' in config: - log_config = config['log'] - if 'level' in log_config: - set_log_level(log_config['level']) - if 'file_log' in log_config: - enable_file_log(log_config.getboolean('file_log')) - if 'log_path' in log_config: - global LOG_FILE_PATH - custom_path = log_config['log_path'] - if os.path.exists(custom_path) or os.makedirs(custom_path, exist_ok=True): - LOG_FILE_PATH = custom_path - return True - except Exception as e: - error(f"加载日志配置失败: {e}") - return False - def title(title: str = '海枫授权系统 HeyAuth', size: Optional[Literal['h1', 'h2', 'h3', 'h4', 'h5']] = 'h1'): """ 输出标题 @@ -204,8 +152,4 @@ if __name__ == '__main__': print("\n设置为DEBUG级别测试:") set_log_level(LogLevelEnum.DEBUG) - debug('这是一个debug日志') # 现在会显示 - - print("\n启用文件日志测试:") - enable_file_log() - info('此日志将同时记录到文件') \ No newline at end of file + debug('这是一个debug日志') # 现在会显示 \ No newline at end of file diff --git a/routers/controllers/user.py b/routers/controllers/user.py index 55fc19d..dd0cc4f 100644 --- a/routers/controllers/user.py +++ b/routers/controllers/user.py @@ -1,6 +1,9 @@ -from fastapi import APIRouter, Depends +from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException +from fastapi.security import OAuth2PasswordRequestForm from middleware.auth import SignRequired -from models.response import ResponseModel +from models.response import ResponseModel, TokenModel +from pkg.log import log user_router = APIRouter( prefix="/user", @@ -18,14 +21,25 @@ user_settings_router = APIRouter( summary='用户登录', description='User login endpoint.', ) -def router_user_session() -> ResponseModel: - """ - User login endpoint. +async def router_user_session( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()] +) -> TokenModel: - Returns: - dict: A dictionary containing user session information. - """ - pass + import service.user.login + + username = form_data.username + password = form_data.password + + user = await service.user.login.login(username=username, password=password) + + if user is None: + raise HTTPException(status_code=400, detail="Invalid username or password") + elif user == 1: + raise HTTPException(status_code=400, detail="User account is not fully registered") + elif user == 2: + raise HTTPException(status_code=403, detail="User account is banned") + + return user @user_router.post( path='/', diff --git a/service/user/login.py b/service/user/login.py index 8827549..160c494 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -1,15 +1,63 @@ from models.setting import Setting +from models.response import TokenModel +from models.user import User +from pkg.log import log async def login( username: str, password: str -): +) -> TokenModel | int | None: """ + 根据账号密码进行登录。 + + 如果登录成功,返回一个 TokenModel 对象,包含访问令牌和刷新令牌以及它们的过期时间。 + 如果登录异常,返回 `int` 状态码,`1` 为未完成注册,`2` 为账号被封禁。 + 如果登录失败,返回 `None`。 + + :param username: 用户名或邮箱 + :type username: str + :param password: 用户密码 + :type password: str + + :return: TokenModel 对象或状态码或 None + :rtype: TokenModel | int | None """ + from pkg.password.pwd import Password - isCaptchaRequired = await Setting.get(type='auth', name='login_captcha', type=bool) - captchaType = await Setting.get(type='auth', name='captcha_type', type=str) + isCaptchaRequired = await Setting.get(type='auth', name='login_captcha', format='bool') + captchaType = await Setting.get(type='auth', name='captcha_type', format='str') # [TODO] 验证码校验 - \ No newline at end of file + # 验证用户是否存在 + user = await User.get(email=username) + + if not user: + log.debug(f"Cannot find user with email: {username}") + return None + + # 验证密码是否正确 + if not Password.verify(user.password, password): + log.debug(f"Password verification failed for user: {username}") + return None + + # 验证用户是否可登录 + if user.status == 1: + # 未完成注册 + return 1 + elif user.status == 2: + # 账号已被封禁 + return 2 + + # 创建令牌 + from pkg.JWT.JWT import create_access_token, create_refresh_token + + access_token, access_expire = create_access_token(data={'sub': user.email}) + refresh_token, refresh_expire = create_refresh_token(data={'sub': user.email}) + + return TokenModel( + access_token=access_token, + access_expires=access_expire, + refresh_token=refresh_token, + refresh_expires=refresh_expire, + ) \ No newline at end of file diff --git a/tests/test_pkg_password.py b/tests/test_pkg_password.py new file mode 100644 index 0000000..2877230 --- /dev/null +++ b/tests/test_pkg_password.py @@ -0,0 +1,8 @@ +import pytest +from pkg.password.pwd import Password + +def test_password(): + for i in range(10): + password = Password.generate() + hashed_password = Password.hash(password) + assert Password.verify(hashed_password, password) \ No newline at end of file