添加Github登录,优化数据库模型

This commit is contained in:
2025-09-01 00:21:06 +08:00
parent f3a5ae9c40
commit 2a173c0566
23 changed files with 321 additions and 264 deletions

8
.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

17
.idea/Server.iml generated Normal file
View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.12 (Server)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
<component name="TestRunnerService">
<option name="PROJECT_TEST_RUNNER" value="py.test" />
</component>
</module>

View File

@@ -0,0 +1,30 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyInterpreterInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="PyMissingTypeHintsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true" />
<inspection_tool class="PyPep8Inspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredErrors">
<list>
<option value="E402" />
<option value="E713" />
<option value="E271" />
<option value="E302" />
<option value="E265" />
<option value="W292" />
</list>
</option>
</inspection_tool>
<inspection_tool class="PyPep8NamingInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false">
<option name="ignoredErrors">
<list>
<option value="N802" />
<option value="N806" />
</list>
</option>
</inspection_tool>
<inspection_tool class="PySingleQuotedDocstringInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="Stylelint" enabled="true" level="ERROR" enabled_by_default="true" />
</profile>
</component>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

17
.idea/material_theme_project_new.xml generated Normal file
View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MaterialThemeProjectNewConfig">
<option name="metadata">
<MTProjectMetadataState>
<option name="migrated" value="true" />
<option name="pristineConfig" value="false" />
<option name="userId" value="4dc9a07a:18f958e7499:-7ffe" />
</MTProjectMetadataState>
</option>
<option name="titleBarState">
<MTProjectTitleBarConfigState>
<option name="overrideColor" value="false" />
</MTProjectTitleBarConfigState>
</option>
</component>
</project>

7
.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.12 (Server)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (Server)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Server.iml" filepath="$PROJECT_DIR$/.idea/Server.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

@@ -12,16 +12,25 @@
目前正处于 `OMEGA` 实验阶段,比 `Alpha` 版还更早期,仅供测试。 目前正处于 `OMEGA` 实验阶段,比 `Alpha` 版还更早期,仅供测试。
## 特性
- 支持将文件存储到本地、远程节点、OneDrive 以及 S3兼容API、阿里云OSS等
- 内置离线下载服务亦可对接Aria2/qBittorrent 下载文件,可用多个节点分担下载任务
- 在线压缩/解压缩文件,支持批量下载
- 部署方便,开箱即用,亦可通过配置获得强大的生态能力
- 可信、现代化的安全能力JWT令牌、OAuth2、WebAuthn、全盘加密
- 兼容 WebDAV、Subsonic 接口
- 支持多用户、多群组,分级管理权限俱全
- 强大的分享链接管理支持分享页README渲染、媒体元数据展示
- 在线预览/编辑多种文件包括但不限于视频、图片、音频、PDF、ePub、Office、Markdown、图表等
- 自定义主题色、深浅色主题、PWA、i18n
## :alembic: 技术栈 ## :alembic: 技术栈
* [Python ](https://www.python.org/) + [FastAPI](https://fastapi.tiangolo.com/) * [Python](https://www.python.org/) + [FastAPI](https://fastapi.tiangolo.com/)
<!-- * [React](https://github.com/facebook/react) + [Redux](https://github.com/reduxjs/redux) + [Material-UI](https://github.com/mui-org/material-ui) --> <!-- * [React](https://github.com/facebook/react) + [Redux](https://github.com/reduxjs/redux) + [Material-UI](https://github.com/mui-org/material-ui) -->
## :scroll: 许可证 ## :scroll: 许可证
GPL V3 GPL V3
---
> GitHub [@Yuerchu](https://github.com/Yuerchu) &nbsp;&middot;&nbsp;
> Twitter [@LaBoyXiaoXin](https://twitter.com/LaBoyXiaoXin)

View File

@@ -5,7 +5,7 @@ from models.database import init_db
from models.migration import migration from models.migration import migration
from pkg.lifespan import lifespan from pkg.lifespan import lifespan
from pkg.JWT import JWT from pkg.JWT import JWT
from pkg.log import log from pkg.log import log, set_log_level
# 添加初始化数据库启动项 # 添加初始化数据库启动项
lifespan.add_startup(init_db) lifespan.add_startup(init_db)
@@ -14,7 +14,9 @@ lifespan.add_startup(JWT.load_secret_key)
# 设置日志等级 # 设置日志等级
if appmeta.debug: if appmeta.debug:
log.set_log_level(log.LogLevelEnum.DEBUG) set_log_level('DEBUG')
else:
set_log_level('INFO')
# 创建应用实例并设置元数据 # 创建应用实例并设置元数据
app = FastAPI( app = FastAPI(

View File

@@ -213,7 +213,7 @@ async def init_default_user() -> None:
admin_user = User( admin_user = User(
email="admin@yxqi.cn", email="admin@yxqi.cn",
nick="admin", nick="admin",
status=0, # 正常状态 status=True, # 正常状态
group_id=admin_group.id, group_id=admin_group.id,
password=hashed_admin_password, password=hashed_admin_password,
) )

View File

@@ -30,3 +30,10 @@ class Policy(BaseModel, table=True):
# 关系 # 关系
files: List["File"] = Relationship(back_populates="policy") files: List["File"] = Relationship(back_populates="policy")
folders: List["Folder"] = Relationship(back_populates="policy") folders: List["Folder"] = Relationship(back_populates="policy")
@staticmethod
async def create(
policy: Optional["Policy"] = None,
**kwargs
):
pass

17
models/request.py Normal file
View File

@@ -0,0 +1,17 @@
"""
请求模型定义
"""
from pydantic import BaseModel, Field
from typing import Literal, Union, Optional
from datetime import datetime, timezone
from uuid import uuid4
class LoginRequest(BaseModel):
"""
登录请求模型
"""
username: str = Field(..., description="用户名或邮箱")
password: str = Field(..., description="用户密码")
captcha: Optional[str] = Field(None, description="验证码")
twoFaCode: Optional[str] = Field(None, description="两步验证代码")

View File

@@ -1,3 +1,7 @@
"""
响应模型定义
"""
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Literal, Union, Optional from typing import Literal, Union, Optional
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -102,3 +106,32 @@ class UserSettingModel(BaseModel):
themes: dict = Field(default_factory=dict, description="用户主题配置") themes: dict = Field(default_factory=dict, description="用户主题配置")
two_factor: bool = Field(default=False, description="是否启用两步验证") two_factor: bool = Field(default=False, description="是否启用两步验证")
uid: int = Field(default=0, description="用户UID") uid: int = Field(default=0, description="用户UID")
class FoldObjectModel(BaseModel):
id: str = Field(default=..., description="对象ID")
name: str = Field(default=..., description="对象名称")
path: str = Field(default=..., description="对象路径")
thumb: bool = Field(default=False, description="是否有缩略图")
size: int = Field(default=None, description="对象大小,单位字节")
type: Literal['file', 'folder'] = Field(default=..., description="对象类型file表示文件folder表示文件夹")
date: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="对象创建或修改时间")
create_date: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="对象创建时间")
source_enabled: bool = Field(default=False, description="是否启用离线下载源")
class PolicyModel(BaseModel):
'''
存储策略模型
'''
id: str = Field(default=..., description="策略ID")
name: str = Field(default=..., description="策略名称")
type: Literal['local', 'qiniu', 'tencent', 'aliyun', 'onedrive', 'google_drive', 'dropbox', 'webdav', 'remote'] = Field(default=..., description="存储类型")
max_size: int = Field(default=0, description="单文件最大限制单位字节0表示不限制")
file_type: list = Field(default_factory=list, description="允许的文件类型列表,空列表表示不限制")
class DirectoryModel(BaseModel):
'''
目录模型
'''
parent: str = Field(default=..., description="父目录ID")
objects: list[FoldObjectModel] = Field(default_factory=list, description="目录下的对象列表")
policy: PolicyModel = Field(default_factory=PolicyModel, description="存储策略")

View File

@@ -26,7 +26,7 @@ class User(BaseModel, table=True):
email: str = Field(max_length=100, unique=True, index=True, description="用户邮箱,唯一") email: str = Field(max_length=100, unique=True, index=True, description="用户邮箱,唯一")
nick: Optional[str] = Field(default=None, max_length=50, description="用户昵称") nick: Optional[str] = Field(default=None, max_length=50, description="用户昵称")
password: str = Field(max_length=255, description="用户密码(加密后)") password: str = Field(max_length=255, description="用户密码(加密后)")
status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="用户状态: 0=正常, 1=未激活, 2=封禁") status: Optional[bool] = Field(default=None, sa_column_kwargs={"server_default": "0"}, description="用户状态: True=正常, None=未激活, False=封禁")
storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已用存储空间(字节)") storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="已用存储空间(字节)")
two_factor: Optional[str] = Field(default=None, max_length=255, description="两步验证密钥") two_factor: Optional[str] = Field(default=None, max_length=255, description="两步验证密钥")
avatar: Optional[str] = Field(default=None, max_length=255, description="头像地址") avatar: Optional[str] = Field(default=None, max_length=255, description="头像地址")
@@ -64,17 +64,7 @@ class User(BaseModel, table=True):
@staticmethod @staticmethod
async def create( async def create(
user: Optional["User"] = None, user: Optional["User"] = None,
email: str = None, **kwargs
nick: Optional[str] = None,
password: str = None,
status: int = 0,
two_factor: Optional[str] = None,
avatar: Optional[str] = None,
options: Optional[str] = None,
authn: Optional[str] = None,
open_id: Optional[str] = None,
score: int = 0,
phone: Optional[str] = None
): ):
""" """
向数据库内添加用户。 向数据库内添加用户。
@@ -83,19 +73,7 @@ class User(BaseModel, table=True):
:type user: User :type user: User
""" """
if not user: if not user:
user = User( user = User(**kwargs)
email=email,
nick=nick,
password=password,
status=status,
two_factor=two_factor,
avatar=avatar,
options=options,
authn=authn,
open_id=open_id,
score=score,
phone=phone
)
from .database import get_session from .database import get_session

View File

@@ -13,7 +13,7 @@ BackendVersion = "0.0.1"
IsPro = False IsPro = False
debug: bool = os.getenv("DEBUG", "false").lower() in ("true", "1", "yes") debug: bool = os.getenv("DEBUG", "false").lower() in ("true", "1", "yes") or False
if debug: if debug:
log.info("Debug mode is enabled. This is not recommended for production use.") log.info("Debug mode is enabled. This is not recommended for production use.")

2
pkg/log/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
from .log_handle import logger as log
from .log_handle import set_log_level

View File

@@ -1,155 +0,0 @@
from rich import print
from rich.console import Console
from rich.markdown import Markdown
from typing import Literal, Optional, Dict, Union
from enum import Enum
import time
import os
import inspect
class LogLevelEnum(str, Enum):
DEBUG = 'debug'
INFO = 'info'
WARNING = 'warning'
ERROR = 'error'
SUCCESS = 'success'
# 默认日志级别
LogLevel = LogLevelEnum.INFO
def set_log_level(level: Union[str, LogLevelEnum]) -> None:
"""设置日志级别"""
global LogLevel
if isinstance(level, str):
try:
LogLevel = LogLevelEnum(level.lower())
except ValueError:
print(f"[bold red]无效的日志级别: {level},使用默认级别: {LogLevel}[/bold red]")
else:
LogLevel = level
def truncate_path(full_path: str, marker: str = "HeyAuth") -> str:
"""截断路径只保留从marker开始的部分"""
try:
marker_index = full_path.find(marker)
if marker_index != -1:
return '.' + full_path[marker_index + len(marker):]
return full_path
except Exception:
return full_path
def get_caller_info(depth: int = 2) -> tuple:
"""获取调用者信息"""
try:
frame = inspect.currentframe()
# 向上查找指定深度的调用帧
for _ in range(depth):
if frame.f_back is None:
break
frame = frame.f_back
filename = frame.f_code.co_filename
lineno = frame.f_lineno
return truncate_path(filename), lineno
except Exception:
return "<unknown>", 0
finally:
# 确保引用被释放
del frame
def log(level: str = 'debug', message: str = ''):
"""
输出日志
---
通过传入的`level`和`message`参数,输出不同级别的日志信息。<br>
`level`参数为日志级别,支持`红色error`、`紫色info`、`绿色success`、`黄色warning`、`淡蓝色debug`。<br>
`message`参数为日志信息。<br>
"""
level_colors: Dict[str, str] = {
'debug': '[bold cyan][DEBUG][/bold cyan]',
'info': '[bold blue][INFO][/bold blue]',
'warning': '[bold yellow][WARN][/bold yellow]',
'error': '[bold red][ERROR][/bold red]',
'success': '[bold green][SUCCESS][/bold green]'
}
level_value = level.lower()
lv = level_colors.get(level_value, '[bold magenta][UNKNOWN][/bold magenta]')
# 获取调用者信息
filename, lineno = get_caller_info(3) # 考虑lambda调用和包装函数深度为3
timestamp = time.strftime('%Y/%m/%d %H:%M:%S %p', time.localtime())
log_message = f"{lv}\t{timestamp} [bold]From {filename}, line {lineno}[/bold] {message}"
# 根据日志级别判断是否输出
global LogLevel
should_log = False
if level_value == 'debug' and LogLevel == LogLevelEnum.DEBUG:
should_log = True
elif level_value == 'info' and LogLevel in [LogLevelEnum.DEBUG, LogLevelEnum.INFO]:
should_log = True
elif level_value == 'warning' and LogLevel in [LogLevelEnum.DEBUG, LogLevelEnum.INFO, LogLevelEnum.WARNING]:
should_log = True
elif level_value == 'error':
should_log = True
elif level_value == 'success':
should_log = False
if should_log:
print(log_message)
# 便捷日志函数
debug = lambda message: log('debug', message)
info = lambda message: log('info', message)
warning = lambda message: log('warn', message)
error = lambda message: log('error', message)
success = lambda message: log('success', message)
def title(title: str = '海枫授权系统 HeyAuth', size: Optional[Literal['h1', 'h2', 'h3', 'h4', 'h5']] = 'h1'):
"""
输出标题
---
通过传入的`title`参数,输出一个整行的标题。<br>
`title`参数为标题内容。<br>
"""
try:
console = Console()
markdown_sizes = {
'h1': '# ',
'h2': '## ',
'h3': '### ',
'h4': '#### ',
'h5': '##### '
}
markdown_tag = markdown_sizes.get(size, '# ')
console.print(Markdown(markdown_tag + title))
except Exception as e:
error(f"输出标题失败: {e}")
finally:
if 'console' in locals():
del console
if True:
pass
if __name__ == '__main__':
# 测试代码
title('海枫授权系统 日志组件测试', 'h1')
title('测试h2标题', 'h2')
title('测试h3标题', 'h3')
title('测试h4标题', 'h4')
title('测试h5标题', 'h5')
print("\n默认日志级别(INFO)测试:")
debug('这是一个debug日志') # 不会显示
info('这是一个info日志')
warning('这是一个warning日志')
error('这是一个error日志')
success('这是一个success日志')
print("\n设置为DEBUG级别测试:")
set_log_level(LogLevelEnum.DEBUG)
debug('这是一个debug日志') # 现在会显示

34
pkg/log/log_handle.py Normal file
View File

@@ -0,0 +1,34 @@
import logging
from rich.logging import RichHandler
FOTMAT = "%(message)s"
logging.basicConfig(
level="NOTSET",
format=FOTMAT,
datefmt="[%X]",
handlers=[RichHandler(rich_tracebacks=True)],
)
logger = logging.getLogger("rich")
def set_log_level(level: str):
"""
设置日志等级。
:param level: 日志等级 (DEBUG, INFO, WARNING, ERROR, CRITICAL)
:type level: str
"""
level = level.upper()
if level == "DEBUG":
logger.setLevel(logging.DEBUG)
elif level == "INFO":
logger.setLevel(logging.INFO)
elif level == "WARNING":
logger.setLevel(logging.WARNING)
elif level == "ERROR":
logger.setLevel(logging.ERROR)
elif level == "CRITICAL":
logger.setLevel(logging.CRITICAL)
else:
logger.setLevel(logging.INFO)
logger.warning(f"未知的日志等级 '{level}',已设置为默认等级 'INFO'")

View File

@@ -1,6 +1,6 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from middleware.auth import SignRequired from middleware.auth import SignRequired
from models.response import ResponseModel from models import response
directory_router = APIRouter( directory_router = APIRouter(
prefix="/directory", prefix="/directory",
@@ -13,7 +13,7 @@ directory_router = APIRouter(
description='Create a directory endpoint.', description='Create a directory endpoint.',
dependencies=[Depends(SignRequired)] dependencies=[Depends(SignRequired)]
) )
def router_directory_create() -> ResponseModel: def router_directory_create() -> response.ResponseModel:
""" """
Create a directory endpoint. Create a directory endpoint.
@@ -28,7 +28,7 @@ def router_directory_create() -> ResponseModel:
description='Get directory contents endpoint.', description='Get directory contents endpoint.',
dependencies=[Depends(SignRequired)] dependencies=[Depends(SignRequired)]
) )
def router_directory_get(path: str) -> ResponseModel: def router_directory_get(path: str) -> response.ResponseModel:
""" """
Get directory contents endpoint. Get directory contents endpoint.
@@ -38,4 +38,8 @@ def router_directory_get(path: str) -> ResponseModel:
Returns: Returns:
ResponseModel: A model containing the response data for the directory contents. ResponseModel: A model containing the response data for the directory contents.
""" """
pass return response.ResponseModel(
data=response.DirectoryModel(
)
)

View File

@@ -1,13 +1,26 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from middleware.auth import AuthRequired, SignRequired from middleware.auth import AuthRequired, AuthRequired
import models import models
from models.response import ResponseModel, TokenModel, userModel, groupModel, UserSettingModel
from deprecated import deprecated from deprecated import deprecated
from pkg.log import log from pkg.log import log
import service import service
from webauthn import (
generate_registration_options,
verify_authentication_response,
options_to_json,
base64url_to_bytes,
)
from webauthn.helpers import options_to_json_dict
from webauthn.helpers.structs import (
PublicKeyCredentialDescriptor,
UserVerificationRequirement,
)
user_router = APIRouter( user_router = APIRouter(
prefix="/user", prefix="/user",
tags=["user"], tags=["user"],
@@ -16,7 +29,7 @@ user_router = APIRouter(
user_settings_router = APIRouter( user_settings_router = APIRouter(
prefix='/user/settings', prefix='/user/settings',
tags=["user", "user_settings"], tags=["user", "user_settings"],
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
@user_router.post( @user_router.post(
@@ -26,11 +39,15 @@ user_settings_router = APIRouter(
) )
async def router_user_session( async def router_user_session(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()] form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
) -> TokenModel: ) -> models.response.TokenModel:
username = form_data.username username = form_data.username
password = form_data.password password = form_data.password
is_login, detail = await service.user.Login(username=username, password=password) is_login, detail = await service.user.Login(
models.request.LoginRequest(
username=username, password=password
)
)
if not is_login: if not is_login:
if detail in ["User not found", "Incorrect password"]: if detail in ["User not found", "Incorrect password"]:
@@ -41,7 +58,7 @@ async def router_user_session(
raise HTTPException(status_code=403, detail="User account is banned") raise HTTPException(status_code=403, detail="User account is banned")
else: else:
raise HTTPException(status_code=500, detail="Internal server error during login") raise HTTPException(status_code=500, detail="Internal server error during login")
if isinstance(detail, TokenModel): if isinstance(detail, models.response.TokenModel):
return detail return detail
else: else:
log.error(f"Unexpected return type from login service: {type(detail)}") log.error(f"Unexpected return type from login service: {type(detail)}")
@@ -52,7 +69,7 @@ async def router_user_session(
summary='用户注册', summary='用户注册',
description='User registration endpoint.', description='User registration endpoint.',
) )
def router_user_register() -> ResponseModel: def router_user_register() -> models.response.ResponseModel:
""" """
User registration endpoint. User registration endpoint.
@@ -66,7 +83,7 @@ def router_user_register() -> ResponseModel:
summary='用两步验证登录', summary='用两步验证登录',
description='Two-factor authentication login endpoint.', description='Two-factor authentication login endpoint.',
) )
def router_user_2fa() -> ResponseModel: def router_user_2fa() -> models.response.ResponseModel:
""" """
Two-factor authentication login endpoint. Two-factor authentication login endpoint.
@@ -80,9 +97,9 @@ def router_user_2fa() -> ResponseModel:
summary='发送验证码邮件', summary='发送验证码邮件',
description='Send a verification code email.', description='Send a verification code email.',
) )
def router_user_email_code() -> ResponseModel: def router_user_email_code() -> models.response.ResponseModel:
""" """
Send a pas Send a verification code email.
Returns: Returns:
dict: A dictionary containing information about the password reset email. dict: A dictionary containing information about the password reset email.
@@ -98,7 +115,7 @@ def router_user_email_code() -> ResponseModel:
summary='通过邮件里的链接重设密码', summary='通过邮件里的链接重设密码',
description='Reset password via email link.', description='Reset password via email link.',
) )
def router_user_reset_patch() -> ResponseModel: def router_user_reset_patch() -> models.response.ResponseModel:
""" """
Reset password via email link. Reset password via email link.
@@ -112,7 +129,7 @@ def router_user_reset_patch() -> ResponseModel:
summary='初始化QQ登录', summary='初始化QQ登录',
description='Initialize QQ login for a user.', description='Initialize QQ login for a user.',
) )
def router_user_qq() -> ResponseModel: def router_user_qq() -> models.response.ResponseModel:
""" """
Initialize QQ login for a user. Initialize QQ login for a user.
@@ -126,16 +143,8 @@ def router_user_qq() -> ResponseModel:
summary='WebAuthn登录初始化', summary='WebAuthn登录初始化',
description='Initialize WebAuthn login for a user.', description='Initialize WebAuthn login for a user.',
) )
def router_user_authn(username: str) -> ResponseModel: async def router_user_authn(username: str) -> models.response.ResponseModel:
"""
Initialize WebAuthn login for a user.
Args:
username (str): The username of the user.
Returns:
dict: A dictionary containing WebAuthn initialization information.
"""
pass pass
@user_router.post( @user_router.post(
@@ -143,7 +152,7 @@ def router_user_authn(username: str) -> ResponseModel:
summary='WebAuthn登录', summary='WebAuthn登录',
description='Finish WebAuthn login for a user.', description='Finish WebAuthn login for a user.',
) )
def router_user_authn_finish(username: str) -> ResponseModel: def router_user_authn_finish(username: str) -> models.response.ResponseModel:
""" """
Finish WebAuthn login for a user. Finish WebAuthn login for a user.
@@ -160,7 +169,7 @@ def router_user_authn_finish(username: str) -> ResponseModel:
summary='获取用户主页展示用分享', summary='获取用户主页展示用分享',
description='Get user profile for display.', description='Get user profile for display.',
) )
def router_user_profile(id: str) -> ResponseModel: def router_user_profile(id: str) -> models.response.ResponseModel:
""" """
Get user profile for display. Get user profile for display.
@@ -177,7 +186,7 @@ def router_user_profile(id: str) -> ResponseModel:
summary='获取用户头像', summary='获取用户头像',
description='Get user avatar by ID and size.', description='Get user avatar by ID and size.',
) )
def router_user_avatar(id: str, size: int = 128) -> ResponseModel: def router_user_avatar(id: str, size: int = 128) -> models.response.ResponseModel:
""" """
Get user avatar by ID and size. Get user avatar by ID and size.
@@ -199,28 +208,28 @@ def router_user_avatar(id: str, size: int = 128) -> ResponseModel:
summary='获取用户信息', summary='获取用户信息',
description='Get user information.', description='Get user information.',
dependencies=[Depends(dependency=AuthRequired)], dependencies=[Depends(dependency=AuthRequired)],
response_model=ResponseModel, response_model=models.response.ResponseModel,
) )
async def router_user_me( async def router_user_me(
user: Annotated[models.user.User, Depends(AuthRequired)], user: Annotated[models.user.User, Depends(AuthRequired)],
) -> ResponseModel: ) -> models.response.ResponseModel:
""" """
获取用户信息. 获取用户信息.
:return: ResponseModel containing user information. :return: response.ResponseModel containing user information.
:rtype: ResponseModel :rtype: response.ResponseModel
""" """
group = await models.Group.get(id=user.group_id) group = await models.Group.get(id=user.group_id)
user_group = groupModel( user_group = models.response.groupModel(
id=group.id, id=group.id,
name=group.name, name=group.name,
allowShare=group.share_enabled, allowShare=group.share_enabled,
) )
users = userModel( users = models.response.userModel(
id=user.id, id=user.id,
username=user.email, username=user.email,
nickname=user.nick, nickname=user.nick,
@@ -231,7 +240,7 @@ async def router_user_me(
).model_dump() ).model_dump()
return ResponseModel( return models.response.ResponseModel(
data=users data=users
) )
@@ -239,18 +248,18 @@ async def router_user_me(
path='/storage', path='/storage',
summary='存储信息', summary='存储信息',
description='Get user storage information.', description='Get user storage information.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_storage( def router_user_storage(
user: Annotated[models.user.User, Depends(AuthRequired)], user: Annotated[models.user.User, Depends(AuthRequired)],
) -> ResponseModel: ) -> models.response.ResponseModel:
""" """
Get user storage information. Get user storage information.
Returns: Returns:
dict: A dictionary containing user storage information. dict: A dictionary containing user storage information.
""" """
return ResponseModel( return models.response.ResponseModel(
data={ data={
"used": 0, "used": 0,
"free": 0, "free": 0,
@@ -262,24 +271,40 @@ def router_user_storage(
path='/authn/start', path='/authn/start',
summary='WebAuthn登录初始化', summary='WebAuthn登录初始化',
description='Initialize WebAuthn login for a user.', description='Initialize WebAuthn login for a user.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_authn_start() -> ResponseModel: async def router_user_authn_start(
user: Annotated[models.user.User, Depends(AuthRequired)]
) -> models.response.ResponseModel:
""" """
Initialize WebAuthn login for a user. Initialize WebAuthn login for a user.
Returns: Returns:
dict: A dictionary containing WebAuthn initialization information. dict: A dictionary containing WebAuthn initialization information.
""" """
pass # [TODO] 检查 WebAuthn 是否开启,用户是否有注册过 WebAuthn 设备等
if not await models.Setting.get(type="authn", name="authn_enabled", format="bool"):
raise HTTPException(status_code=400, detail="WebAuthn is not enabled")
options = generate_registration_options(
rp_id=await models.Setting.get(type="basic", name="siteURL"),
rp_name=await models.Setting.get(type="basic", name="siteTitle"),
user_name=user.email,
user_display_name=user.nick or user.email,
)
return models.response.ResponseModel(
data=options_to_json_dict(options)
)
@user_router.put( @user_router.put(
path='/authn/finish', path='/authn/finish',
summary='WebAuthn登录', summary='WebAuthn登录',
description='Finish WebAuthn login for a user.', description='Finish WebAuthn login for a user.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_authn_finish() -> ResponseModel: def router_user_authn_finish() -> models.response.ResponseModel:
""" """
Finish WebAuthn login for a user. Finish WebAuthn login for a user.
@@ -293,7 +318,7 @@ def router_user_authn_finish() -> ResponseModel:
summary='获取用户可选存储策略', summary='获取用户可选存储策略',
description='Get user selectable storage policies.', description='Get user selectable storage policies.',
) )
def router_user_settings_policies() -> ResponseModel: def router_user_settings_policies() -> models.response.ResponseModel:
""" """
Get user selectable storage policies. Get user selectable storage policies.
@@ -306,9 +331,9 @@ def router_user_settings_policies() -> ResponseModel:
path='/nodes', path='/nodes',
summary='获取用户可选节点', summary='获取用户可选节点',
description='Get user selectable nodes.', description='Get user selectable nodes.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_nodes() -> ResponseModel: def router_user_settings_nodes() -> models.response.ResponseModel:
""" """
Get user selectable nodes. Get user selectable nodes.
@@ -321,9 +346,9 @@ def router_user_settings_nodes() -> ResponseModel:
path='/tasks', path='/tasks',
summary='任务队列', summary='任务队列',
description='Get user task queue.', description='Get user task queue.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_tasks() -> ResponseModel: def router_user_settings_tasks() -> models.response.ResponseModel:
""" """
Get user task queue. Get user task queue.
@@ -336,24 +361,24 @@ def router_user_settings_tasks() -> ResponseModel:
path='/', path='/',
summary='获取当前用户设定', summary='获取当前用户设定',
description='Get current user settings.', description='Get current user settings.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings() -> ResponseModel: def router_user_settings() -> models.response.ResponseModel:
""" """
Get current user settings. Get current user settings.
Returns: Returns:
dict: A dictionary containing the current user settings. dict: A dictionary containing the current user settings.
""" """
return ResponseModel(data=UserSettingModel().model_dump()) return models.response.ResponseModel(data=models.response.UserSettingModel().model_dump())
@user_settings_router.post( @user_settings_router.post(
path='/avatar', path='/avatar',
summary='从文件上传头像', summary='从文件上传头像',
description='Upload user avatar from file.', description='Upload user avatar from file.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_avatar() -> ResponseModel: def router_user_settings_avatar() -> models.response.ResponseModel:
""" """
Upload user avatar from file. Upload user avatar from file.
@@ -366,9 +391,9 @@ def router_user_settings_avatar() -> ResponseModel:
path='/avatar', path='/avatar',
summary='设定为Gravatar头像', summary='设定为Gravatar头像',
description='Set user avatar to Gravatar.', description='Set user avatar to Gravatar.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_avatar_gravatar() -> ResponseModel: def router_user_settings_avatar_gravatar() -> models.response.ResponseModel:
""" """
Set user avatar to Gravatar. Set user avatar to Gravatar.
@@ -381,9 +406,9 @@ def router_user_settings_avatar_gravatar() -> ResponseModel:
path='/{option}', path='/{option}',
summary='更新用户设定', summary='更新用户设定',
description='Update user settings.', description='Update user settings.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_patch(option: str) -> ResponseModel: def router_user_settings_patch(option: str) -> models.response.ResponseModel:
""" """
Update user settings. Update user settings.
@@ -399,9 +424,9 @@ def router_user_settings_patch(option: str) -> ResponseModel:
path='/2fa', path='/2fa',
summary='获取两步验证初始化信息', summary='获取两步验证初始化信息',
description='Get two-factor authentication initialization information.', description='Get two-factor authentication initialization information.',
dependencies=[Depends(SignRequired)], dependencies=[Depends(AuthRequired)],
) )
def router_user_settings_2fa() -> ResponseModel: def router_user_settings_2fa() -> models.response.ResponseModel:
""" """
Get two-factor authentication initialization information. Get two-factor authentication initialization information.

6
service/oauth/qq.py Normal file
View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
import aiohttp
async def get_access_token(
code: str
)

View File

@@ -1,15 +1,11 @@
from typing import Optional from typing import Optional
from models.setting import Setting from models.setting import Setting
from models.request import LoginRequest
from models.response import TokenModel from models.response import TokenModel
from models.user import User from models.user import User
from pkg.log import log from pkg.log import log
async def Login( async def Login(LoginRequest: LoginRequest) -> tuple[bool, TokenModel | str]:
username: str,
password: str,
captcha: Optional[str] = None,
twoFaCode: Optional[str] = None
) -> tuple[bool, TokenModel | str]:
""" """
根据账号密码进行登录。 根据账号密码进行登录。
@@ -23,7 +19,7 @@ async def Login(
:type password: str :type password: str
:param captcha: 验证码 :param captcha: 验证码
:type captcha: Optional[str] :type captcha: Optional[str]
:param twoFaCode: 二次验证代码 :param twoFaCode: 两步验证代码
:type twoFaCode: Optional[str] :type twoFaCode: Optional[str]
:return: TokenModel 对象或状态码或 None :return: TokenModel 对象或状态码或 None
@@ -37,22 +33,22 @@ async def Login(
# [TODO] 验证码校验 # [TODO] 验证码校验
# 验证用户是否存在 # 验证用户是否存在
user = await User.get(email=username) user = await User.get(email=LoginRequest.username)
if not user: if not user:
log.debug(f"Cannot find user with email: {username}") log.debug(f"Cannot find user with email: {LoginRequest.username}")
return False, "User not found" return False, "User not found"
# 验证密码是否正确 # 验证密码是否正确
if not Password.verify(user.password, password): if not Password.verify(user.password, LoginRequest.password):
log.debug(f"Password verification failed for user: {username}") log.debug(f"Password verification failed for user: {LoginRequest.username}")
return False, "Incorrect password" return False, "Incorrect password"
# 验证用户是否可登录 # 验证用户是否可登录
if user.status == 1: if user.status == None:
# 未完成注册 # 未完成注册
return False, "Need to complete registration" return False, "Need to complete registration"
elif user.status == 2: elif user.status == False:
# 账号已被封禁 # 账号已被封禁
return False, "Account is banned" return False, "Account is banned"