diff --git a/main.py b/main.py index 584421c..61f0403 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ from pkg.conf import appmeta from models.database import init_db from models.migration import migration from pkg.lifespan import lifespan -from pkg.JWT import JWT +from pkg.JWT import jwt as JWT from pkg.log import log, set_log_level # 添加初始化数据库启动项 diff --git a/middleware/auth.py b/middleware/auth.py index 5e408c2..a0e6bf0 100644 --- a/middleware/auth.py +++ b/middleware/auth.py @@ -1,7 +1,7 @@ from typing import Annotated, Optional from fastapi import Depends, HTTPException from models.user import User -from pkg.JWT import JWT +from pkg.JWT import jwt as JWT import jwt from jwt import InvalidTokenError diff --git a/models/__init__.py b/models/__init__.py index fc977d1..b0b0142 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -3,7 +3,7 @@ from . import response # 将所有模型导入到这个包的命名空间中 -from .base import BaseModel +from .base import TableBase from .download import Download from .file import File from .folder import Folder @@ -24,7 +24,7 @@ from .webdav import WebDAV # 可以定义一个 __all__ 列表来明确指定可以被 from .models import * 导入的内容 __all__ = [ - "BaseModel", "Download", "File", "Folder", "Group", "Node", "Order", + "TableBase", "Download", "File", "Folder", "Group", "Node", "Order", "Policy", "Redeem", "Report", "Setting", "Share", "SourceLink", "StoragePack", "Tag", "Task", "User", "WebDAV" ] \ No newline at end of file diff --git a/models/base.py b/models/base.py index 9e318a8..91d7c63 100644 --- a/models/base.py +++ b/models/base.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncAttrs utcnow = lambda: datetime.now(tz=timezone.utc) -class BaseModel(SQLModel, AsyncAttrs): +class TableBase(SQLModel, AsyncAttrs): __abstract__ = True id: Optional[int] = Field(default=None, primary_key=True, description="主键ID") diff --git a/models/database.py b/models/database.py index b326872..53b4b51 100644 --- a/models/database.py +++ b/models/database.py @@ -1,5 +1,3 @@ -# my_project/database.py - from sqlmodel import SQLModel from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from sqlmodel.ext.asyncio.session import AsyncSession diff --git a/models/download.py b/models/download.py index 8e1077b..d0b0f3e 100644 --- a/models/download.py +++ b/models/download.py @@ -2,7 +2,7 @@ from typing import Optional, TYPE_CHECKING from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: @@ -10,7 +10,7 @@ if TYPE_CHECKING: from .task import Task from .node import Node -class Download(BaseModel, table=True): +class Download(TableBase, table=True): __tablename__ = 'downloads' status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="下载状态: 0=进行中, 1=完成, 2=错误") diff --git a/models/file.py b/models/file.py index 33f1d9f..8c43ea8 100644 --- a/models/file.py +++ b/models/file.py @@ -2,7 +2,7 @@ from typing import Optional, TYPE_CHECKING from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: @@ -11,7 +11,7 @@ if TYPE_CHECKING: from .policy import Policy from .source_link import SourceLink -class File(BaseModel, table=True): +class File(TableBase, table=True): __tablename__ = 'files' name: str = Field(max_length=255, description="文件名") diff --git a/models/folder.py b/models/folder.py index 9eaa541..b2f66a4 100644 --- a/models/folder.py +++ b/models/folder.py @@ -2,7 +2,7 @@ from typing import Optional, List, TYPE_CHECKING from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: @@ -10,7 +10,7 @@ if TYPE_CHECKING: from .policy import Policy from .file import File -class Folder(BaseModel, table=True): +class Folder(TableBase, table=True): __tablename__ = 'folders' __table_args__ = (UniqueConstraint("name", "parent_id", name="uq_folder_name_parent"),) diff --git a/models/group.py b/models/group.py index 9f2609e..6e175c5 100644 --- a/models/group.py +++ b/models/group.py @@ -3,13 +3,13 @@ from tokenize import group from typing import Optional, List, TYPE_CHECKING from sqlmodel import Field, Relationship, text, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .user import User -class Group(BaseModel, table=True): +class Group(TableBase, table=True): __tablename__ = 'groups' name: str = Field(max_length=255, unique=True, description="用户组名") @@ -17,6 +17,7 @@ class Group(BaseModel, table=True): max_storage: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="最大存储空间(字节)") share_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否允许创建分享") web_dav_enabled: bool = Field(default=False, sa_column_kwargs={"server_default": text("false")}, description="是否允许使用WebDAV") + admin: bool = Field(default=False, description="是否为管理员组") speed_limit: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="速度限制 (KB/s), 0为不限制") options: Optional[str] = Field(default=None, description="其他选项 (JSON格式)") diff --git a/models/node.py b/models/node.py index 8a39e9a..8618099 100644 --- a/models/node.py +++ b/models/node.py @@ -2,13 +2,13 @@ from typing import Optional, TYPE_CHECKING from sqlmodel import Field, Relationship, text, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .download import Download -class Node(BaseModel, table=True): +class Node(TableBase, table=True): __tablename__ = 'nodes' status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="节点状态: 0=正常, 1=离线") diff --git a/models/order.py b/models/order.py index 2ebb0c5..dae2106 100644 --- a/models/order.py +++ b/models/order.py @@ -2,13 +2,13 @@ from typing import Optional, TYPE_CHECKING from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .user import User -class Order(BaseModel, table=True): +class Order(TableBase, table=True): __tablename__ = 'orders' order_no: str = Field(max_length=255, unique=True, index=True, description="订单号,唯一") diff --git a/models/policy.py b/models/policy.py index 4909762..ac8e19f 100644 --- a/models/policy.py +++ b/models/policy.py @@ -2,14 +2,14 @@ from typing import Optional, List, TYPE_CHECKING from sqlmodel import Field, Relationship, text, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .file import File from .folder import Folder -class Policy(BaseModel, table=True): +class Policy(TableBase, table=True): __tablename__ = 'policies' name: str = Field(max_length=255, unique=True, description="策略名称") diff --git a/models/redeem.py b/models/redeem.py index 4b6e594..d23043a 100644 --- a/models/redeem.py +++ b/models/redeem.py @@ -2,10 +2,10 @@ from typing import Optional from sqlmodel import Field, text, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime -class Redeem(BaseModel, table=True): +class Redeem(TableBase, table=True): __tablename__ = 'redeems' type: int = Field(description="兑换码类型") diff --git a/models/report.py b/models/report.py index 764a038..1f081a1 100644 --- a/models/report.py +++ b/models/report.py @@ -2,13 +2,13 @@ from typing import Optional, TYPE_CHECKING from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .share import Share -class Report(BaseModel, table=True): +class Report(TableBase, table=True): __tablename__ = 'reports' reason: int = Field(description="举报原因代码") diff --git a/models/response.py b/models/response.py index fdb76df..9bdc2db 100644 --- a/models/response.py +++ b/models/response.py @@ -107,7 +107,7 @@ class UserSettingModel(BaseModel): two_factor: bool = Field(default=False, description="是否启用两步验证") uid: int = Field(default=0, description="用户UID") -class FoldObjectModel(BaseModel): +class ObjectModel(BaseModel): id: str = Field(default=..., description="对象ID") name: str = Field(default=..., description="对象名称") path: str = Field(default=..., description="对象路径") @@ -133,5 +133,5 @@ class DirectoryModel(BaseModel): 目录模型 ''' parent: str = Field(default=..., description="父目录ID") - objects: list[FoldObjectModel] = Field(default_factory=list, description="目录下的对象列表") + objects: list[ObjectModel] = Field(default_factory=list, description="目录下的对象列表") policy: PolicyModel = Field(default_factory=PolicyModel, description="存储策略") \ No newline at end of file diff --git a/models/setting.py b/models/setting.py index f6cb13c..5e1227a 100644 --- a/models/setting.py +++ b/models/setting.py @@ -2,7 +2,7 @@ from typing import Optional, Literal from sqlmodel import Field, UniqueConstraint, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime SETTINGS_TYPE = Literal[ @@ -34,7 +34,7 @@ SETTINGS_TYPE = Literal[ ] # 数据库模型 -class Setting(BaseModel, table=True): +class Setting(TableBase, table=True): __tablename__ = 'settings' __table_args__ = (UniqueConstraint("type", "name", name="uq_setting_type_name"),) diff --git a/models/share.py b/models/share.py index 2f4e2bf..13a249e 100644 --- a/models/share.py +++ b/models/share.py @@ -3,14 +3,14 @@ from typing import Optional, TYPE_CHECKING from datetime import datetime from sqlmodel import Field, Relationship, text, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .user import User from .report import Report -class Share(BaseModel, table=True): +class Share(TableBase, table=True): __tablename__ = 'shares' password: Optional[str] = Field(default=None, max_length=255, description="分享密码(加密后)") diff --git a/models/source_link.py b/models/source_link.py index 9c5d3fb..c34ee58 100644 --- a/models/source_link.py +++ b/models/source_link.py @@ -2,13 +2,13 @@ from typing import TYPE_CHECKING, Optional from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .file import File -class SourceLink(BaseModel, table=True): +class SourceLink(TableBase, table=True): __tablename__ = 'source_links' name: str = Field(max_length=255, description="链接名称") diff --git a/models/storage_pack.py b/models/storage_pack.py index 2052ef2..6c3b8c6 100644 --- a/models/storage_pack.py +++ b/models/storage_pack.py @@ -3,13 +3,13 @@ from typing import Optional, TYPE_CHECKING from datetime import datetime from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .user import User -class StoragePack(BaseModel, table=True): +class StoragePack(TableBase, table=True): __tablename__ = 'storage_packs' name: str = Field(max_length=255, description="容量包名称") diff --git a/models/tag.py b/models/tag.py index 9831813..519fe63 100644 --- a/models/tag.py +++ b/models/tag.py @@ -2,13 +2,13 @@ from typing import Optional, TYPE_CHECKING from sqlmodel import Field, Relationship, UniqueConstraint, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .user import User -class Tag(BaseModel, table=True): +class Tag(TableBase, table=True): __tablename__ = 'tags' __table_args__ = (UniqueConstraint("name", "user_id", name="uq_tag_name_user"),) diff --git a/models/task.py b/models/task.py index 8952df2..c6a81f4 100644 --- a/models/task.py +++ b/models/task.py @@ -2,14 +2,14 @@ from typing import Optional, TYPE_CHECKING from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from datetime import datetime if TYPE_CHECKING: from .user import User from .download import Download -class Task(BaseModel, table=True): +class Task(TableBase, table=True): __tablename__ = 'tasks' status: int = Field(default=0, sa_column_kwargs={"server_default": "0"}, description="任务状态: 0=排队中, 1=处理中, 2=完成, 3=错误") diff --git a/models/user.py b/models/user.py index 176e597..9972bbd 100644 --- a/models/user.py +++ b/models/user.py @@ -3,7 +3,7 @@ from typing import Optional, TYPE_CHECKING from datetime import datetime from sqlmodel import Field, Relationship, Column, func, DateTime -from .base import BaseModel +from .base import TableBase from .database import get_session from sqlmodel import select @@ -20,7 +20,7 @@ if TYPE_CHECKING: from .task import Task from .webdav import WebDAV -class User(BaseModel, table=True): +class User(TableBase, table=True): __tablename__ = 'users' email: str = Field(max_length=100, unique=True, index=True, description="用户邮箱,唯一") diff --git a/models/webdav.py b/models/webdav.py index 209ac32..53d360c 100644 --- a/models/webdav.py +++ b/models/webdav.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING from sqlmodel import Field, Relationship, UniqueConstraint, text, Column, func, DateTime -from .base import BaseModel +from .base import TableBase if TYPE_CHECKING: from .user import User -class WebDAV(BaseModel, table=True): +class WebDAV(TableBase, table=True): __tablename__ = 'webdavs' __table_args__ = (UniqueConstraint("name", "user_id", name="uq_webdav_name_user"),) diff --git a/service/user/login.py b/service/user/login.py index 29da851..8eb0ff4 100644 --- a/service/user/login.py +++ b/service/user/login.py @@ -1,4 +1,4 @@ -from typing import Optional +from pkg.JWT.jwt import create_access_token, create_refresh_token from models.setting import Setting from models.request import LoginRequest from models.response import TokenModel @@ -53,7 +53,6 @@ async def Login(LoginRequest: LoginRequest) -> tuple[bool, TokenModel | str]: return False, "Account is banned" # 创建令牌 - from pkg.JWT.JWT import create_access_token, create_refresh_token access_token, access_expire = create_access_token(data={'sub': user.email}) refresh_token, refresh_expire = create_refresh_token(data={'sub': user.email})