用户登录
This commit is contained in:
@@ -1,5 +1,3 @@
|
||||
# my_project/models/base.py
|
||||
|
||||
from typing import Optional
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
|
||||
@@ -10,11 +10,11 @@ from typing import AsyncGenerator
|
||||
ASYNC_DATABASE_URL = appmeta.database_url
|
||||
|
||||
engine: AsyncEngine = create_async_engine(
|
||||
ASYNC_DATABASE_URL,
|
||||
ASYNC_DATABASE_URL,
|
||||
echo=appmeta.debug,
|
||||
connect_args={"check_same_thread": False}
|
||||
if ASYNC_DATABASE_URL.startswith("sqlite")
|
||||
else None,
|
||||
connect_args={
|
||||
"check_same_thread": False
|
||||
} if ASYNC_DATABASE_URL.startswith("sqlite") else None,
|
||||
future=True,
|
||||
# pool_size=POOL_SIZE,
|
||||
# max_overflow=64,
|
||||
|
||||
@@ -92,6 +92,7 @@ class Group(BaseModel, table=True):
|
||||
try:
|
||||
session.add(group)
|
||||
await session.commit()
|
||||
await session.refresh(group)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
|
||||
@@ -212,7 +212,7 @@ async def init_default_user() -> None:
|
||||
admin_user = User(
|
||||
email="admin@yxqi.cn",
|
||||
nick="admin",
|
||||
status=1, # 正常状态
|
||||
status=0, # 正常状态
|
||||
group_id=admin_group.id,
|
||||
password=hashed_admin_password,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Literal, Union, Optional
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
@@ -7,6 +8,22 @@ class ResponseModel(BaseModel):
|
||||
data: Union[dict, list, str, int, float, None] = Field(None, description="响应数据")
|
||||
msg: Optional[str] = Field(default=None, description="响应消息,可以是错误消息或信息提示")
|
||||
instance_id: str = Field(default_factory=lambda: str(uuid4()), description="实例ID,用于标识请求的唯一性")
|
||||
|
||||
class TokenModel(BaseModel):
|
||||
access_expires: datetime = Field(default=None, description="访问令牌的过期时间")
|
||||
access_token: str = Field(default=None, description="访问令牌")
|
||||
refresh_expires: datetime = Field(default=None, description="刷新令牌的过期时间")
|
||||
refresh_token: str = Field(default=None, description="刷新令牌")
|
||||
|
||||
class userModel(ResponseModel):
|
||||
id: str = Field(default=None, description="用户ID")
|
||||
username: str = Field(default=None, description="用户名")
|
||||
email: Optional[str] = Field(default=None, description="用户邮箱")
|
||||
avatar: Optional[str] = Field(default=None, description="用户头像URL")
|
||||
is_active: bool = Field(default=True, description="用户是否激活")
|
||||
is_admin: bool = Field(default=False, description="用户是否为管理员")
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="账户创建时间")
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), description="账户更新时间")
|
||||
|
||||
class SiteConfigModel(ResponseModel):
|
||||
title: str = Field(default="DiskNext", description="网站标题")
|
||||
|
||||
@@ -4,6 +4,8 @@ from typing import Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
from sqlmodel import Field, Relationship, Column, func, DateTime
|
||||
from .base import BaseModel
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
# TYPE_CHECKING 用于解决循环导入问题,只在类型检查时导入
|
||||
if TYPE_CHECKING:
|
||||
@@ -130,6 +132,7 @@ class User(BaseModel, table=True):
|
||||
try:
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
@@ -150,10 +153,6 @@ class User(BaseModel, table=True):
|
||||
:return: 用户对象或 None
|
||||
:rtype: Optional[User]
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
session = get_session()
|
||||
|
||||
if id is None and email is None:
|
||||
@@ -193,10 +192,6 @@ class User(BaseModel, table=True):
|
||||
:return: 更新后的用户对象
|
||||
:rtype: User
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
statement = select(User).where(User.id == id)
|
||||
@@ -248,9 +243,8 @@ class User(BaseModel, table=True):
|
||||
:param id: 用户ID
|
||||
:type id: int
|
||||
"""
|
||||
|
||||
from .database import get_session
|
||||
from sqlmodel import select
|
||||
if id == 1:
|
||||
raise ValueError("Cannot delete the default admin user with id 1.")
|
||||
|
||||
async for session in get_session():
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user