diff --git a/.idea/.gitignore b/.idea/.gitignore
deleted file mode 100644
index 35410ca..0000000
--- a/.idea/.gitignore
+++ /dev/null
@@ -1,8 +0,0 @@
-# 默认忽略的文件
-/shelf/
-/workspace.xml
-# 基于编辑器的 HTTP 客户端请求
-/httpRequests/
-# Datasource local storage ignored files
-/dataSources/
-/dataSources.local.xml
diff --git a/.idea/Server.iml b/.idea/Server.iml
deleted file mode 100644
index 050da18..0000000
--- a/.idea/Server.iml
+++ /dev/null
@@ -1,20 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml
deleted file mode 100644
index 6e20fd7..0000000
--- a/.idea/dataSources.xml
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
-
-
- sqlite.xerial
- true
- org.sqlite.JDBC
- jdbc:sqlite:$PROJECT_DIR$/disknext.db
- $ProjectFileDir$
-
-
-
\ No newline at end of file
diff --git a/.idea/encodings.xml b/.idea/encodings.xml
deleted file mode 100644
index 97626ba..0000000
--- a/.idea/encodings.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
deleted file mode 100644
index 5027b3a..0000000
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ /dev/null
@@ -1,143 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index 105ce2d..0000000
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/material_theme_project_new.xml b/.idea/material_theme_project_new.xml
deleted file mode 100644
index 0831c9f..0000000
--- a/.idea/material_theme_project_new.xml
+++ /dev/null
@@ -1,12 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index df4982d..0000000
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,7 +0,0 @@
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index 8f3a104..0000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/runConfigurations/_DiskNext.xml b/.idea/runConfigurations/_DiskNext.xml
deleted file mode 100644
index cee6378..0000000
--- a/.idea/runConfigurations/_DiskNext.xml
+++ /dev/null
@@ -1,26 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/runConfigurations/_DiskNext2.xml b/.idea/runConfigurations/_DiskNext2.xml
deleted file mode 100644
index a72c412..0000000
--- a/.idea/runConfigurations/_DiskNext2.xml
+++ /dev/null
@@ -1,21 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 35eb1dd..0000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/identifier.sqlite b/identifier.sqlite
deleted file mode 100644
index e69de29..0000000
diff --git a/main.py b/main.py
index e0d1944..2577be9 100644
--- a/main.py
+++ b/main.py
@@ -4,12 +4,17 @@ from pkg.conf import appmeta
from models.database import init_db
from models.migration import migration
from pkg.lifespan import lifespan
-from pkg.JWT import jwt
+from pkg.JWT import JWT
+from pkg.log import log
# 添加初始化数据库启动项
lifespan.add_startup(init_db)
lifespan.add_startup(migration)
-lifespan.add_startup(jwt.load_secret_key)
+lifespan.add_startup(JWT.load_secret_key)
+
+# 设置日志等级
+if appmeta.debug:
+ log.set_log_level(log.LogLevelEnum.DEBUG)
# 创建应用实例并设置元数据
app = FastAPI(
@@ -39,7 +44,7 @@ for router in routers.Router:
# 启动时打印欢迎信息
if __name__ == "__main__":
import uvicorn
-
+
if appmeta.debug:
uvicorn.run(app='main:app', host=appmeta.host, port=appmeta.port, reload=True)
else:
diff --git a/middleware/auth.py b/middleware/auth.py
index 60717df..5e408c2 100644
--- a/middleware/auth.py
+++ b/middleware/auth.py
@@ -1,30 +1,54 @@
-from typing import Annotated, Literal
-from fastapi import Depends
-from pkg.JWT import jwt
+from typing import Annotated, Optional
+from fastapi import Depends, HTTPException
+from models.user import User
+from pkg.JWT import JWT
+import jwt
+from jwt import InvalidTokenError
+
+credentials_exception = HTTPException(
+ status_code=401,
+ detail="Could not validate credentials",
+ headers={"WWW-Authenticate": "Bearer"},
+)
async def AuthRequired(
- token: Annotated[str, Depends(jwt.oauth2_scheme)]
-) -> Literal[True]:
- '''
+ token: Annotated[str, Depends(JWT.oauth2_scheme)]
+) -> Optional["User"]:
+ """
AuthRequired 需要登录
- '''
- from models.user import User
+ """
+ try:
+ payload = jwt.decode(token, JWT.SECRET_KEY, algorithms="HS256")
+ username = payload.get("sub")
+
+ if username is None:
+ raise credentials_exception
+
+ # 从数据库获取用户信息
+ user = await User.get(email=username)
+ if not user:
+ raise credentials_exception
+
+ return user
+
+ except InvalidTokenError:
+ raise credentials_exception
async def SignRequired(
- token: Annotated[str, Depends(jwt.oauth2_scheme)]
-) -> Literal[True]:
- '''
- SignAuthRequired 需要登录并验证请求签名
- '''
- return True
+ token: Annotated[str, Depends(JWT.oauth2_scheme)]
+) -> Optional["User"]:
+ """
+ SignAuthRequired 需要验证请求签名
+ """
+ pass
async def AdminRequired(
- token: Annotated[str, Depends(jwt.oauth2_scheme)]
-) -> Literal[True]:
- '''
+ token: Annotated[str, Depends(JWT.oauth2_scheme)]
+) -> Optional["User"]:
+ """
验证是否为管理员。
使用方法:
>>> APIRouter(dependencies=[Depends(is_admin)])
- '''
+ """
pass
\ No newline at end of file
diff --git a/models/base.py b/models/base.py
index e1fd70a..2d4eb92 100644
--- a/models/base.py
+++ b/models/base.py
@@ -1,5 +1,3 @@
-# my_project/models/base.py
-
from typing import Optional
from sqlmodel import SQLModel, Field
diff --git a/models/database.py b/models/database.py
index 9c5bf5f..b326872 100644
--- a/models/database.py
+++ b/models/database.py
@@ -10,11 +10,11 @@ from typing import AsyncGenerator
ASYNC_DATABASE_URL = appmeta.database_url
engine: AsyncEngine = create_async_engine(
- ASYNC_DATABASE_URL,
+ ASYNC_DATABASE_URL,
echo=appmeta.debug,
- connect_args={"check_same_thread": False}
- if ASYNC_DATABASE_URL.startswith("sqlite")
- else None,
+ connect_args={
+ "check_same_thread": False
+ } if ASYNC_DATABASE_URL.startswith("sqlite") else None,
future=True,
# pool_size=POOL_SIZE,
# max_overflow=64,
diff --git a/models/group.py b/models/group.py
index 9ade27e..16f2974 100644
--- a/models/group.py
+++ b/models/group.py
@@ -92,6 +92,7 @@ class Group(BaseModel, table=True):
try:
session.add(group)
await session.commit()
+ await session.refresh(group)
except Exception as e:
await session.rollback()
raise e
diff --git a/models/migration.py b/models/migration.py
index 0099b74..2a747fc 100644
--- a/models/migration.py
+++ b/models/migration.py
@@ -212,7 +212,7 @@ async def init_default_user() -> None:
admin_user = User(
email="admin@yxqi.cn",
nick="admin",
- status=1, # 正常状态
+ status=0, # 正常状态
group_id=admin_group.id,
password=hashed_admin_password,
)
diff --git a/models/response.py b/models/response.py
index 824ad4e..1a7deb6 100644
--- a/models/response.py
+++ b/models/response.py
@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field
from typing import Literal, Union, Optional
+from datetime import datetime, timedelta, timezone
from uuid import uuid4
class ResponseModel(BaseModel):
@@ -7,6 +8,22 @@ class ResponseModel(BaseModel):
data: Union[dict, list, str, int, float, None] = Field(None, description="响应数据")
msg: Optional[str] = Field(default=None, description="响应消息,可以是错误消息或信息提示")
instance_id: str = Field(default_factory=lambda: str(uuid4()), description="实例ID,用于标识请求的唯一性")
+
+class TokenModel(BaseModel):
+ access_expires: datetime = Field(default=None, description="访问令牌的过期时间")
+ access_token: str = Field(default=None, description="访问令牌")
+ refresh_expires: datetime = Field(default=None, description="刷新令牌的过期时间")
+ refresh_token: str = Field(default=None, description="刷新令牌")
+
+class userModel(ResponseModel):
+ id: str = Field(default=None, description="用户ID")
+ username: str = Field(default=None, description="用户名")
+ email: Optional[str] = Field(default=None, description="用户邮箱")
+ avatar: Optional[str] = Field(default=None, description="用户头像URL")
+ is_active: bool = Field(default=True, description="用户是否激活")
+ is_admin: bool = Field(default=False, description="用户是否为管理员")
+ created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="账户创建时间")
+ updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="账户更新时间")
class SiteConfigModel(ResponseModel):
title: str = Field(default="DiskNext", description="网站标题")
diff --git a/models/user.py b/models/user.py
index 3a9ca0a..34f9336 100644
--- a/models/user.py
+++ b/models/user.py
@@ -4,6 +4,8 @@ from typing import Optional, TYPE_CHECKING
from datetime import datetime
from sqlmodel import Field, Relationship, Column, func, DateTime
from .base import BaseModel
+from .database import get_session
+from sqlmodel import select
# TYPE_CHECKING 用于解决循环导入问题,只在类型检查时导入
if TYPE_CHECKING:
@@ -130,6 +132,7 @@ class User(BaseModel, table=True):
try:
session.add(user)
await session.commit()
+ await session.refresh(user)
except Exception as e:
await session.rollback()
raise e
@@ -150,10 +153,6 @@ class User(BaseModel, table=True):
:return: 用户对象或 None
:rtype: Optional[User]
"""
-
- from .database import get_session
- from sqlmodel import select
-
session = get_session()
if id is None and email is None:
@@ -193,10 +192,6 @@ class User(BaseModel, table=True):
:return: 更新后的用户对象
:rtype: User
"""
-
- from .database import get_session
- from sqlmodel import select
-
async for session in get_session():
try:
statement = select(User).where(User.id == id)
@@ -248,9 +243,8 @@ class User(BaseModel, table=True):
:param id: 用户ID
:type id: int
"""
-
- from .database import get_session
- from sqlmodel import select
+ if id == 1:
+ raise ValueError("Cannot delete the default admin user with id 1.")
async for session in get_session():
try:
diff --git a/pkg/JWT/jwt.py b/pkg/JWT/jwt.py
index 4fcca1c..0b9f526 100644
--- a/pkg/JWT/jwt.py
+++ b/pkg/JWT/jwt.py
@@ -1,5 +1,7 @@
from fastapi.security import OAuth2PasswordBearer
from models.setting import Setting
+from datetime import datetime, timedelta, timezone
+import jwt
oauth2_scheme = OAuth2PasswordBearer(
scheme_name='获取 JWT Bearer 令牌',
@@ -17,4 +19,26 @@ async def load_secret_key() -> None:
:type key: str
"""
global SECRET_KEY
- SECRET_KEY = await Setting.get(type='auth', name='secret_key')
\ No newline at end of file
+ SECRET_KEY = await Setting.get(type='auth', name='secret_key')
+
+# 访问令牌
+def create_access_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]:
+ to_encode = data.copy()
+ if expires_delta:
+ expire = datetime.now(timezone.utc) + expires_delta
+ else:
+ expire = datetime.now(timezone.utc) + timedelta(hours=3)
+ to_encode.update({"exp": expire})
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256')
+ return encoded_jwt, expire
+
+# 刷新令牌
+def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> tuple[str, datetime]:
+ to_encode = data.copy()
+ if expires_delta:
+ expire = datetime.now(timezone.utc) + expires_delta
+ else:
+ expire = datetime.now(timezone.utc) + timedelta(days=30)
+ to_encode.update({"exp": expire, "token_type": "refresh"})
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm='HS256')
+ return encoded_jwt, expire
\ No newline at end of file
diff --git a/pkg/log/log.py b/pkg/log/log.py
index 50718af..40fb33b 100644
--- a/pkg/log/log.py
+++ b/pkg/log/log.py
@@ -1,7 +1,6 @@
from rich import print
from rich.console import Console
from rich.markdown import Markdown
-from configparser import ConfigParser
from typing import Literal, Optional, Dict, Union
from enum import Enum
import time
@@ -17,10 +16,6 @@ class LogLevelEnum(str, Enum):
# 默认日志级别
LogLevel = LogLevelEnum.INFO
-# 日志文件路径
-LOG_FILE_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'logs')
-# 是否启用文件日志
-ENABLE_FILE_LOG = False
def set_log_level(level: Union[str, LogLevelEnum]) -> None:
"""设置日志级别"""
@@ -33,17 +28,6 @@ def set_log_level(level: Union[str, LogLevelEnum]) -> None:
else:
LogLevel = level
-def enable_file_log(enable: bool = True) -> None:
- """启用或禁用文件日志"""
- global ENABLE_FILE_LOG
- ENABLE_FILE_LOG = enable
- if enable and not os.path.exists(LOG_FILE_PATH):
- try:
- os.makedirs(LOG_FILE_PATH)
- except Exception as e:
- print(f"[bold red]创建日志目录失败: {e}[/bold red]")
- ENABLE_FILE_LOG = False
-
def truncate_path(full_path: str, marker: str = "HeyAuth") -> str:
"""截断路径,只保留从marker开始的部分"""
try:
@@ -114,17 +98,6 @@ def log(level: str = 'debug', message: str = ''):
if should_log:
print(log_message)
-
- # 文件日志记录
- if ENABLE_FILE_LOG:
- try:
- # 去除rich格式化标记
- clean_message = f"{level_value.upper()}\t{timestamp} From {filename}, line {lineno} {message}"
- log_file = os.path.join(LOG_FILE_PATH, f"{time.strftime('%Y%m%d')}.log")
- with open(log_file, 'a', encoding='utf-8') as f:
- f.write(f"{clean_message}\n")
- except Exception as e:
- print(f"[bold red]写入日志文件失败: {e}[/bold red]")
# 便捷日志函数
debug = lambda message: log('debug', message)
@@ -133,31 +106,6 @@ warning = lambda message: log('warn', message)
error = lambda message: log('error', message)
success = lambda message: log('success', message)
-def load_config(config_path: str) -> bool:
- """从配置文件加载日志配置"""
- try:
- if not os.path.exists(config_path):
- return False
-
- config = ConfigParser()
- config.read(config_path, encoding='utf-8')
-
- if 'log' in config:
- log_config = config['log']
- if 'level' in log_config:
- set_log_level(log_config['level'])
- if 'file_log' in log_config:
- enable_file_log(log_config.getboolean('file_log'))
- if 'log_path' in log_config:
- global LOG_FILE_PATH
- custom_path = log_config['log_path']
- if os.path.exists(custom_path) or os.makedirs(custom_path, exist_ok=True):
- LOG_FILE_PATH = custom_path
- return True
- except Exception as e:
- error(f"加载日志配置失败: {e}")
- return False
-
def title(title: str = '海枫授权系统 HeyAuth', size: Optional[Literal['h1', 'h2', 'h3', 'h4', 'h5']] = 'h1'):
"""
输出标题
@@ -204,8 +152,4 @@ if __name__ == '__main__':
print("\n设置为DEBUG级别测试:")
set_log_level(LogLevelEnum.DEBUG)
- debug('这是一个debug日志') # 现在会显示
-
- print("\n启用文件日志测试:")
- enable_file_log()
- info('此日志将同时记录到文件')
\ No newline at end of file
+ debug('这是一个debug日志') # 现在会显示
\ No newline at end of file
diff --git a/routers/controllers/user.py b/routers/controllers/user.py
index 55fc19d..dd0cc4f 100644
--- a/routers/controllers/user.py
+++ b/routers/controllers/user.py
@@ -1,6 +1,9 @@
-from fastapi import APIRouter, Depends
+from typing import Annotated
+from fastapi import APIRouter, Depends, HTTPException
+from fastapi.security import OAuth2PasswordRequestForm
from middleware.auth import SignRequired
-from models.response import ResponseModel
+from models.response import ResponseModel, TokenModel
+from pkg.log import log
user_router = APIRouter(
prefix="/user",
@@ -18,14 +21,25 @@ user_settings_router = APIRouter(
summary='用户登录',
description='User login endpoint.',
)
-def router_user_session() -> ResponseModel:
- """
- User login endpoint.
+async def router_user_session(
+ form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
+) -> TokenModel:
- Returns:
- dict: A dictionary containing user session information.
- """
- pass
+ import service.user.login
+
+ username = form_data.username
+ password = form_data.password
+
+ user = await service.user.login.login(username=username, password=password)
+
+ if user is None:
+ raise HTTPException(status_code=400, detail="Invalid username or password")
+ elif user == 1:
+ raise HTTPException(status_code=400, detail="User account is not fully registered")
+ elif user == 2:
+ raise HTTPException(status_code=403, detail="User account is banned")
+
+ return user
@user_router.post(
path='/',
diff --git a/service/user/login.py b/service/user/login.py
index 8827549..160c494 100644
--- a/service/user/login.py
+++ b/service/user/login.py
@@ -1,15 +1,63 @@
from models.setting import Setting
+from models.response import TokenModel
+from models.user import User
+from pkg.log import log
async def login(
username: str,
password: str
-):
+) -> TokenModel | int | None:
"""
+ 根据账号密码进行登录。
+
+ 如果登录成功,返回一个 TokenModel 对象,包含访问令牌和刷新令牌以及它们的过期时间。
+ 如果登录异常,返回 `int` 状态码,`1` 为未完成注册,`2` 为账号被封禁。
+ 如果登录失败,返回 `None`。
+
+ :param username: 用户名或邮箱
+ :type username: str
+ :param password: 用户密码
+ :type password: str
+
+ :return: TokenModel 对象或状态码或 None
+ :rtype: TokenModel | int | None
"""
+ from pkg.password.pwd import Password
- isCaptchaRequired = await Setting.get(type='auth', name='login_captcha', type=bool)
- captchaType = await Setting.get(type='auth', name='captcha_type', type=str)
+ isCaptchaRequired = await Setting.get(type='auth', name='login_captcha', format='bool')
+ captchaType = await Setting.get(type='auth', name='captcha_type', format='str')
# [TODO] 验证码校验
-
\ No newline at end of file
+ # 验证用户是否存在
+ user = await User.get(email=username)
+
+ if not user:
+ log.debug(f"Cannot find user with email: {username}")
+ return None
+
+ # 验证密码是否正确
+ if not Password.verify(user.password, password):
+ log.debug(f"Password verification failed for user: {username}")
+ return None
+
+ # 验证用户是否可登录
+ if user.status == 1:
+ # 未完成注册
+ return 1
+ elif user.status == 2:
+ # 账号已被封禁
+ return 2
+
+ # 创建令牌
+ from pkg.JWT.JWT import create_access_token, create_refresh_token
+
+ access_token, access_expire = create_access_token(data={'sub': user.email})
+ refresh_token, refresh_expire = create_refresh_token(data={'sub': user.email})
+
+ return TokenModel(
+ access_token=access_token,
+ access_expires=access_expire,
+ refresh_token=refresh_token,
+ refresh_expires=refresh_expire,
+ )
\ No newline at end of file
diff --git a/tests/test_pkg_password.py b/tests/test_pkg_password.py
new file mode 100644
index 0000000..2877230
--- /dev/null
+++ b/tests/test_pkg_password.py
@@ -0,0 +1,8 @@
+import pytest
+from pkg.password.pwd import Password
+
+def test_password():
+ for i in range(10):
+ password = Password.generate()
+ hashed_password = Password.hash(password)
+ assert Password.verify(hashed_password, password)
\ No newline at end of file