feat: add models for physical files, policies, and user management
- Implement PhysicalFile model to manage physical file references and reference counting. - Create Policy model with associated options and group links for storage policies. - Introduce Redeem and Report models for handling redeem codes and reports. - Add Settings model for site configuration and user settings management. - Develop Share model for sharing objects with unique codes and associated metadata. - Implement SourceLink model for managing download links associated with objects. - Create StoragePack model for managing user storage packages. - Add Tag model for user-defined tags with manual and automatic types. - Implement Task model for managing background tasks with status tracking. - Develop User model with comprehensive user management features including authentication. - Introduce UserAuthn model for managing WebAuthn credentials. - Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
@@ -49,13 +49,13 @@ def main():
|
||||
("itsdangerous", "签名工具"),
|
||||
|
||||
# 项目模块
|
||||
("models", "数据库模型"),
|
||||
("models.user", "用户模型"),
|
||||
("models.group", "用户组模型"),
|
||||
("models.object", "对象模型"),
|
||||
("models.setting", "设置模型"),
|
||||
("models.policy", "策略模型"),
|
||||
("models.database", "数据库连接"),
|
||||
("sqlmodels", "数据库模型"),
|
||||
("sqlmodels.user", "用户模型"),
|
||||
("sqlmodels.group", "用户组模型"),
|
||||
("sqlmodels.object", "对象模型"),
|
||||
("sqlmodels.setting", "设置模型"),
|
||||
("sqlmodels.policy", "策略模型"),
|
||||
("sqlmodels.database", "数据库连接"),
|
||||
("utils.password.pwd", "密码工具"),
|
||||
("utils.JWT.JWT", "JWT 工具"),
|
||||
("service.user.login", "登录服务"),
|
||||
|
||||
@@ -23,12 +23,12 @@ from sqlalchemy.orm import sessionmaker
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
|
||||
from main import app
|
||||
from models.database import get_session
|
||||
from models.group import Group, GroupOptions
|
||||
from models.migration import migration
|
||||
from models.object import Object, ObjectType
|
||||
from models.policy import Policy, PolicyType
|
||||
from models.user import User
|
||||
from sqlmodels.database import get_session
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
from sqlmodels.migration import migration
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
from sqlmodels.user import User
|
||||
from utils.JWT.JWT import create_access_token
|
||||
from utils.password.pwd import Password
|
||||
|
||||
@@ -153,7 +153,7 @@ def override_get_session(db_session: AsyncSession):
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
"""
|
||||
创建测试用户并返回 {id, username, password, token}
|
||||
创建测试用户并返回 {id, email, password, token}
|
||||
|
||||
创建一个普通用户,包含用户组、存储策略和根目录。
|
||||
"""
|
||||
@@ -190,7 +190,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
# 创建测试用户
|
||||
password = "test_password_123"
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
@@ -202,7 +202,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
# 创建用户根目录
|
||||
root_folder = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -216,7 +216,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"group_id": group.id,
|
||||
@@ -227,7 +227,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
"""
|
||||
获取管理员用户 {id, username, token}
|
||||
获取管理员用户 {id, email, token}
|
||||
|
||||
创建具有管理员权限的用户。
|
||||
"""
|
||||
@@ -267,7 +267,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
# 创建管理员用户
|
||||
password = "admin_password_456"
|
||||
admin = User(
|
||||
username="admin",
|
||||
email="admin@disknext.local",
|
||||
nickname="管理员",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
@@ -279,7 +279,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
# 创建管理员根目录
|
||||
root_folder = Object(
|
||||
name=admin.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=admin.id,
|
||||
@@ -293,7 +293,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
|
||||
|
||||
return {
|
||||
"id": admin.id,
|
||||
"username": admin.username,
|
||||
"email": admin.email,
|
||||
"password": password,
|
||||
"token": access_token,
|
||||
"group_id": admin_group.id,
|
||||
|
||||
@@ -8,9 +8,9 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from models.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from tests.fixtures import UserFactory, GroupFactory, ObjectFactory
|
||||
|
||||
|
||||
@@ -24,13 +24,13 @@ async def test_user_factory(db_session: AsyncSession):
|
||||
user = await UserFactory.create(
|
||||
db_session,
|
||||
group_id=group.id,
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
password="password123"
|
||||
)
|
||||
|
||||
# 验证
|
||||
assert user.id is not None
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "testuser@test.local"
|
||||
assert user.group_id == group.id
|
||||
assert user.status is True
|
||||
|
||||
@@ -51,7 +51,7 @@ async def test_group_factory(db_session: AsyncSession):
|
||||
async def test_object_factory(db_session: AsyncSession):
|
||||
"""测试对象工厂的基本功能"""
|
||||
# 准备依赖
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = await GroupFactory.create(db_session)
|
||||
user = await UserFactory.create(db_session, group_id=group.id)
|
||||
@@ -102,7 +102,7 @@ async def test_conftest_fixtures(
|
||||
"""测试 conftest.py 中的 fixtures"""
|
||||
# 验证 test_user fixture
|
||||
assert test_user["id"] is not None
|
||||
assert test_user["username"] == "testuser"
|
||||
assert test_user["email"] == "testuser@test.local"
|
||||
assert test_user["token"] is not None
|
||||
|
||||
# 验证 auth_headers fixture
|
||||
@@ -112,7 +112,7 @@ async def test_conftest_fixtures(
|
||||
# 验证用户在数据库中存在
|
||||
user = await User.get(db_session, User.id == test_user["id"])
|
||||
assert user is not None
|
||||
assert user.username == test_user["username"]
|
||||
assert user.email == test_user["email"]
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
@@ -145,7 +145,7 @@ async def test_test_directory_fixture(
|
||||
@pytest.mark.integration
|
||||
async def test_nested_structure_factory(db_session: AsyncSession):
|
||||
"""测试嵌套结构工厂"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
# 准备依赖
|
||||
group = await GroupFactory.create(db_session)
|
||||
|
||||
2
tests/fixtures/groups.py
vendored
2
tests/fixtures/groups.py
vendored
@@ -5,7 +5,7 @@
|
||||
"""
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.group import Group, GroupOptions
|
||||
from sqlmodels.group import Group, GroupOptions
|
||||
|
||||
|
||||
class GroupFactory:
|
||||
|
||||
6
tests/fixtures/objects.py
vendored
6
tests/fixtures/objects.py
vendored
@@ -7,8 +7,8 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.object import Object, ObjectType
|
||||
from models.user import User
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
|
||||
|
||||
class ObjectFactory:
|
||||
@@ -119,7 +119,7 @@ class ObjectFactory:
|
||||
Object: 创建的根目录实例
|
||||
"""
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
|
||||
50
tests/fixtures/users.py
vendored
50
tests/fixtures/users.py
vendored
@@ -7,7 +7,7 @@ from uuid import UUID
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from sqlmodels.user import User
|
||||
from utils.password.pwd import Password
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class UserFactory:
|
||||
async def create(
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
username: str | None = None,
|
||||
email: str | None = None,
|
||||
password: str | None = None,
|
||||
**kwargs
|
||||
) -> User:
|
||||
@@ -28,7 +28,7 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
username: 用户名(默认: test_user_{随机})
|
||||
email: 用户邮箱(默认: test_user_{随机}@test.local)
|
||||
password: 明文密码(默认: password123)
|
||||
**kwargs: 其他用户字段
|
||||
|
||||
@@ -37,15 +37,15 @@ class UserFactory:
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"test_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"test_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
if password is None:
|
||||
password = "password123"
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
nickname=kwargs.get("nickname", username),
|
||||
email=email,
|
||||
nickname=kwargs.get("nickname", email),
|
||||
password=Password.hash(password),
|
||||
status=kwargs.get("status", True),
|
||||
storage=kwargs.get("storage", 0),
|
||||
@@ -67,7 +67,7 @@ class UserFactory:
|
||||
async def create_admin(
|
||||
session: AsyncSession,
|
||||
admin_group_id: UUID,
|
||||
username: str | None = None,
|
||||
email: str | None = None,
|
||||
password: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
@@ -76,7 +76,7 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
admin_group_id: 管理员组UUID
|
||||
username: 用户名(默认: admin_{随机})
|
||||
email: 用户邮箱(默认: admin_{随机}@disknext.local)
|
||||
password: 明文密码(默认: admin_password)
|
||||
|
||||
返回:
|
||||
@@ -84,15 +84,15 @@ class UserFactory:
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"admin_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"admin_{uuid.uuid4().hex[:8]}@disknext.local"
|
||||
|
||||
if password is None:
|
||||
password = "admin_password"
|
||||
|
||||
admin = User(
|
||||
username=username,
|
||||
nickname=f"管理员 {username}",
|
||||
email=email,
|
||||
nickname=f"管理员 {email}",
|
||||
password=Password.hash(password),
|
||||
status=True,
|
||||
storage=0,
|
||||
@@ -108,7 +108,7 @@ class UserFactory:
|
||||
async def create_banned(
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
username: str | None = None
|
||||
email: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
创建被封禁用户
|
||||
@@ -116,19 +116,19 @@ class UserFactory:
|
||||
参数:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
username: 用户名(默认: banned_user_{随机})
|
||||
email: 用户邮箱(默认: banned_user_{随机}@test.local)
|
||||
|
||||
返回:
|
||||
User: 创建的被封禁用户实例
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"banned_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"banned_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
banned_user = User(
|
||||
username=username,
|
||||
nickname=f"封禁用户 {username}",
|
||||
email=email,
|
||||
nickname=f"封禁用户 {email}",
|
||||
password=Password.hash("banned_password"),
|
||||
status=False, # 封禁状态
|
||||
storage=0,
|
||||
@@ -145,7 +145,7 @@ class UserFactory:
|
||||
session: AsyncSession,
|
||||
group_id: UUID,
|
||||
storage_bytes: int,
|
||||
username: str | None = None
|
||||
email: str | None = None
|
||||
) -> User:
|
||||
"""
|
||||
创建已使用指定存储空间的用户
|
||||
@@ -154,19 +154,19 @@ class UserFactory:
|
||||
session: 数据库会话
|
||||
group_id: 用户组UUID
|
||||
storage_bytes: 已使用的存储空间(字节)
|
||||
username: 用户名(默认: storage_user_{随机})
|
||||
email: 用户邮箱(默认: storage_user_{随机}@test.local)
|
||||
|
||||
返回:
|
||||
User: 创建的用户实例
|
||||
"""
|
||||
import uuid
|
||||
|
||||
if username is None:
|
||||
username = f"storage_user_{uuid.uuid4().hex[:8]}"
|
||||
if email is None:
|
||||
email = f"storage_user_{uuid.uuid4().hex[:8]}@test.local"
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
nickname=username,
|
||||
email=email,
|
||||
nickname=email,
|
||||
password=Password.hash("password123"),
|
||||
status=True,
|
||||
storage=storage_bytes,
|
||||
|
||||
@@ -124,7 +124,7 @@ async def test_admin_get_user_list_contains_user_data(
|
||||
if len(users) > 0:
|
||||
user = users[0]
|
||||
assert "id" in user
|
||||
assert "username" in user
|
||||
assert "email" in user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -132,7 +132,7 @@ async def test_admin_create_user_requires_auth(async_client: AsyncClient):
|
||||
"""测试创建用户需要认证"""
|
||||
response = await async_client.post(
|
||||
"/api/admin/user/create",
|
||||
json={"username": "newadminuser", "password": "pass123"}
|
||||
json={"email": "newadminuser@test.local", "password": "pass123"}
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
@@ -146,7 +146,7 @@ async def test_admin_create_user_requires_admin(
|
||||
response = await async_client.post(
|
||||
"/api/admin/user/create",
|
||||
headers=auth_headers,
|
||||
json={"username": "newadminuser", "password": "pass123"}
|
||||
json={"email": "newadminuser@test.local", "password": "pass123"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from uuid import UUID
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_requires_auth(async_client: AsyncClient):
|
||||
"""测试获取目录需要认证"""
|
||||
response = await async_client.get("/api/directory/testuser")
|
||||
response = await async_client.get("/api/directory/")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ async def test_directory_get_root(
|
||||
):
|
||||
"""测试获取用户根目录"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -45,7 +45,7 @@ async def test_directory_get_nested(
|
||||
):
|
||||
"""测试获取嵌套目录"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/docs",
|
||||
"/api/directory/docs",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -63,7 +63,7 @@ async def test_directory_get_contains_children(
|
||||
):
|
||||
"""测试目录包含子对象"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/docs",
|
||||
"/api/directory/docs",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -75,19 +75,6 @@ async def test_directory_get_contains_children(
|
||||
assert len(objects) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_forbidden_other_user(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str]
|
||||
):
|
||||
"""测试访问他人目录返回 403"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/admin",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_not_found(
|
||||
async_client: AsyncClient,
|
||||
@@ -95,23 +82,23 @@ async def test_directory_not_found(
|
||||
):
|
||||
"""测试目录不存在返回 404"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser/nonexistent",
|
||||
"/api/directory/nonexistent",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_directory_empty_path_returns_400(
|
||||
async def test_directory_root_returns_200(
|
||||
async_client: AsyncClient,
|
||||
auth_headers: dict[str, str]
|
||||
):
|
||||
"""测试空路径返回 400"""
|
||||
"""测试根目录端点返回 200"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 400
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -121,7 +108,7 @@ async def test_directory_response_includes_policy(
|
||||
):
|
||||
"""测试目录响应包含存储策略"""
|
||||
response = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response.status_code == 200
|
||||
@@ -284,7 +271,7 @@ async def test_directory_create_other_user_parent(
|
||||
"""测试在他人目录下创建目录返回 404"""
|
||||
# 先用管理员账号获取管理员的根目录ID
|
||||
admin_response = await async_client.get(
|
||||
"/api/directory/admin",
|
||||
"/api/directory/",
|
||||
headers=admin_headers
|
||||
)
|
||||
assert admin_response.status_code == 200
|
||||
|
||||
@@ -16,7 +16,7 @@ async def test_user_login_success(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["username"],
|
||||
"username": test_user_info["email"],
|
||||
"password": test_user_info["password"],
|
||||
}
|
||||
)
|
||||
@@ -38,7 +38,7 @@ async def test_user_login_wrong_password(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": test_user_info["username"],
|
||||
"username": test_user_info["email"],
|
||||
"password": "wrongpassword",
|
||||
}
|
||||
)
|
||||
@@ -51,7 +51,7 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient):
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": "nonexistent",
|
||||
"username": "nonexistent@test.local",
|
||||
"password": "anypassword",
|
||||
}
|
||||
)
|
||||
@@ -67,7 +67,7 @@ async def test_user_login_user_banned(
|
||||
response = await async_client.post(
|
||||
"/api/user/session",
|
||||
data={
|
||||
"username": banned_user_info["username"],
|
||||
"username": banned_user_info["email"],
|
||||
"password": banned_user_info["password"],
|
||||
}
|
||||
)
|
||||
@@ -82,7 +82,7 @@ async def test_user_register_success(async_client: AsyncClient):
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"username": "newuser",
|
||||
"email": "newuser@test.local",
|
||||
"password": "newpass123",
|
||||
}
|
||||
)
|
||||
@@ -91,20 +91,20 @@ async def test_user_register_success(async_client: AsyncClient):
|
||||
data = response.json()
|
||||
assert "data" in data
|
||||
assert "user_id" in data["data"]
|
||||
assert "username" in data["data"]
|
||||
assert data["data"]["username"] == "newuser"
|
||||
assert "email" in data["data"]
|
||||
assert data["data"]["email"] == "newuser@test.local"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_register_duplicate_username(
|
||||
async def test_user_register_duplicate_email(
|
||||
async_client: AsyncClient,
|
||||
test_user_info: dict[str, str]
|
||||
):
|
||||
"""测试重复用户名返回 400"""
|
||||
"""测试重复邮箱返回 400"""
|
||||
response = await async_client.post(
|
||||
"/api/user/",
|
||||
json={
|
||||
"username": test_user_info["username"],
|
||||
"email": test_user_info["email"],
|
||||
"password": "anypassword",
|
||||
}
|
||||
)
|
||||
@@ -143,8 +143,8 @@ async def test_user_me_returns_user_info(
|
||||
assert "data" in data
|
||||
user_data = data["data"]
|
||||
assert "id" in user_data
|
||||
assert "username" in user_data
|
||||
assert user_data["username"] == "testuser"
|
||||
assert "email" in user_data
|
||||
assert user_data["email"] == "testuser@test.local"
|
||||
assert "group" in user_data
|
||||
assert "tags" in user_data
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from sqlalchemy.orm import sessionmaker
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
|
||||
from main import app
|
||||
from models import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
|
||||
from utils import Password
|
||||
from utils.JWT import create_access_token
|
||||
from utils.JWT import JWT
|
||||
@@ -92,6 +92,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
Setting(type=SettingsType.VIEW, name="home_view_method", value="list"),
|
||||
Setting(type=SettingsType.VIEW, name="share_view_method", value="grid"),
|
||||
Setting(type=SettingsType.AUTHN, name="authn_enabled", value="0"),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_type", value="default"),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_ReCaptchaKey", value=""),
|
||||
Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""),
|
||||
Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"),
|
||||
@@ -180,7 +181,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 6. 创建测试用户
|
||||
test_user = User(
|
||||
id=uuid4(),
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
password=Password.hash("testpass123"),
|
||||
nickname="测试用户",
|
||||
status=True,
|
||||
@@ -194,7 +195,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
admin_user = User(
|
||||
id=uuid4(),
|
||||
username="admin",
|
||||
email="admin@disknext.local",
|
||||
password=Password.hash("adminpass123"),
|
||||
nickname="管理员",
|
||||
status=True,
|
||||
@@ -208,7 +209,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
banned_user = User(
|
||||
id=uuid4(),
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("banned123"),
|
||||
nickname="封禁用户",
|
||||
status=False, # 封禁状态
|
||||
@@ -230,7 +231,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
# 7. 创建用户根目录
|
||||
test_user_root = Object(
|
||||
id=uuid4(),
|
||||
name=test_user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=test_user.id,
|
||||
parent_id=None,
|
||||
@@ -241,7 +242,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
|
||||
admin_user_root = Object(
|
||||
id=uuid4(),
|
||||
name=admin_user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
owner_id=admin_user.id,
|
||||
parent_id=None,
|
||||
@@ -264,7 +265,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
|
||||
def test_user_info() -> dict[str, str]:
|
||||
"""测试用户信息"""
|
||||
return {
|
||||
"username": "testuser",
|
||||
"email": "testuser@test.local",
|
||||
"password": "testpass123",
|
||||
}
|
||||
|
||||
@@ -273,7 +274,7 @@ def test_user_info() -> dict[str, str]:
|
||||
def admin_user_info() -> dict[str, str]:
|
||||
"""管理员用户信息"""
|
||||
return {
|
||||
"username": "admin",
|
||||
"email": "admin@disknext.local",
|
||||
"password": "adminpass123",
|
||||
}
|
||||
|
||||
@@ -282,7 +283,7 @@ def admin_user_info() -> dict[str, str]:
|
||||
def banned_user_info() -> dict[str, str]:
|
||||
"""封禁用户信息"""
|
||||
return {
|
||||
"username": "banneduser",
|
||||
"email": "banneduser@test.local",
|
||||
"password": "banned123",
|
||||
}
|
||||
|
||||
@@ -293,7 +294,7 @@ def banned_user_info() -> dict[str, str]:
|
||||
def test_user_token(test_user_info: dict[str, str]) -> str:
|
||||
"""生成测试用户的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
@@ -303,7 +304,7 @@ def test_user_token(test_user_info: dict[str, str]) -> str:
|
||||
def admin_user_token(admin_user_info: dict[str, str]) -> str:
|
||||
"""生成管理员的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": admin_user_info["username"]},
|
||||
data={"sub": admin_user_info["email"]},
|
||||
expires_delta=timedelta(hours=1),
|
||||
)
|
||||
return token
|
||||
@@ -313,7 +314,7 @@ def admin_user_token(admin_user_info: dict[str, str]) -> str:
|
||||
def expired_token() -> str:
|
||||
"""生成过期的JWT token"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "testuser"},
|
||||
data={"sub": "testuser@test.local"},
|
||||
expires_delta=timedelta(seconds=-1), # 已过期
|
||||
)
|
||||
return token
|
||||
@@ -362,7 +363,7 @@ async def test_directory_structure(initialized_db: AsyncSession) -> dict[str, UU
|
||||
"""创建测试目录结构"""
|
||||
|
||||
# 获取测试用户和根目录
|
||||
test_user = await User.get(initialized_db, User.username == "testuser")
|
||||
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
|
||||
test_user_root = await Object.get_root(initialized_db, test_user.id)
|
||||
|
||||
default_policy = await Policy.get(initialized_db, Policy.name == "本地存储")
|
||||
|
||||
@@ -83,7 +83,7 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
|
||||
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
|
||||
"""测试用户不存在的token返回 401"""
|
||||
token, _ = JWT.create_access_token(
|
||||
data={"sub": "nonexistent_user"},
|
||||
data={"sub": "nonexistent_user@test.local"},
|
||||
expires_delta=timedelta(hours=1)
|
||||
)
|
||||
|
||||
@@ -178,12 +178,12 @@ async def test_auth_on_directory_endpoint(
|
||||
):
|
||||
"""测试目录端点应用认证"""
|
||||
# 无认证
|
||||
response_no_auth = await async_client.get("/api/directory/testuser")
|
||||
response_no_auth = await async_client.get("/api/directory/")
|
||||
assert response_no_auth.status_code == 401
|
||||
|
||||
# 有认证
|
||||
response_with_auth = await async_client.get(
|
||||
"/api/directory/testuser",
|
||||
"/api/directory/",
|
||||
headers=auth_headers
|
||||
)
|
||||
assert response_with_auth.status_code == 200
|
||||
@@ -235,7 +235,7 @@ async def test_auth_on_storage_endpoint(
|
||||
async def test_refresh_token_format(test_user_info: dict[str, str]):
|
||||
"""测试刷新token格式正确"""
|
||||
refresh_token, _ = JWT.create_refresh_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(days=7)
|
||||
)
|
||||
|
||||
@@ -247,7 +247,7 @@ async def test_refresh_token_format(test_user_info: dict[str, str]):
|
||||
async def test_access_token_format(test_user_info: dict[str, str]):
|
||||
"""测试访问token格式正确"""
|
||||
access_token, expires = JWT.create_access_token(
|
||||
data={"sub": test_user_info["username"]},
|
||||
data={"sub": test_user_info["email"]},
|
||||
expires_delta=timedelta(hours=1)
|
||||
)
|
||||
|
||||
|
||||
@@ -3,14 +3,14 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_db():
|
||||
"""测试创建数据库结构"""
|
||||
from models import database
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session():
|
||||
"""测试获取数据库连接Session"""
|
||||
from models import database
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
@@ -20,8 +20,8 @@ async def db_session():
|
||||
@pytest.mark.asyncio
|
||||
async def test_migration():
|
||||
"""测试数据库创建并初始化配置"""
|
||||
from models import migration
|
||||
from models import database
|
||||
from sqlmodels import migration
|
||||
from sqlmodels import database
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database, migration
|
||||
from models.group import Group
|
||||
from sqlmodels import database, migration
|
||||
from sqlmodels.group import Group
|
||||
|
||||
await database.init_db(url='sqlite+aiosqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database
|
||||
from models.setting import Setting
|
||||
from sqlmodels import database
|
||||
from sqlmodels.setting import Setting
|
||||
|
||||
await database.init_db(url='sqlite:///:memory:')
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ import pytest
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_curd():
|
||||
"""测试数据库的增删改查"""
|
||||
from models import database, migration
|
||||
from models.group import Group
|
||||
from models.user import User
|
||||
from sqlmodels import database, migration
|
||||
from sqlmodels.group import Group
|
||||
from sqlmodels.user import User
|
||||
|
||||
await database.init_db(url='sqlite+aiosqlite:///:memory:')
|
||||
|
||||
@@ -17,7 +17,7 @@ async def test_user_curd():
|
||||
created_group = await test_user_group.save(session)
|
||||
|
||||
test_user = User(
|
||||
username='test_user',
|
||||
email='test_user@test.local',
|
||||
password='test_password',
|
||||
group_id=created_group.id
|
||||
)
|
||||
@@ -27,7 +27,7 @@ async def test_user_curd():
|
||||
|
||||
# 验证用户是否存在
|
||||
assert created_user.id is not None
|
||||
assert created_user.username == 'test_user'
|
||||
assert created_user.email == 'test_user@test.local'
|
||||
assert created_user.password == 'test_password'
|
||||
assert created_user.group_id == created_group.id
|
||||
|
||||
@@ -35,18 +35,18 @@ async def test_user_curd():
|
||||
fetched_user = await User.get(session, User.id == created_user.id)
|
||||
|
||||
assert fetched_user is not None
|
||||
assert fetched_user.username == 'test_user'
|
||||
assert fetched_user.email == 'test_user@test.local'
|
||||
assert fetched_user.password == 'test_password'
|
||||
assert fetched_user.group_id == created_group.id
|
||||
|
||||
# 测试改 Update
|
||||
updated_user = await fetched_user.update(
|
||||
session,
|
||||
{"username": "updated_user", "password": "updated_password"}
|
||||
{"email": "updated_user@test.local", "password": "updated_password"}
|
||||
)
|
||||
|
||||
assert updated_user is not None
|
||||
assert updated_user.username == 'updated_user'
|
||||
assert updated_user.email == 'updated_user@test.local'
|
||||
assert updated_user.password == 'updated_password'
|
||||
|
||||
# 测试删除 Delete
|
||||
|
||||
@@ -8,8 +8,8 @@ import pytest
|
||||
from fastapi import HTTPException
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -62,7 +62,7 @@ async def test_table_base_update(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
# 更新数据
|
||||
from models.group import GroupBase
|
||||
from sqlmodels.group import GroupBase
|
||||
update_data = GroupBase(name="更新后名称")
|
||||
updated_group = await group.update(db_session, update_data)
|
||||
|
||||
@@ -200,7 +200,7 @@ async def test_timestamps_auto_update(db_session: AsyncSession):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 更新记录
|
||||
from models.group import GroupBase
|
||||
from sqlmodels.group import GroupBase
|
||||
update_data = GroupBase(name="更新后的名称")
|
||||
group = await group.update(db_session, update_data)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ Group 和 GroupOptions 模型的单元测试
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.group import Group, GroupOptions, GroupResponse
|
||||
from sqlmodels.group import Group, GroupOptions, GroupResponse
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -5,21 +5,21 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.object import Object, ObjectType
|
||||
from models.user import User
|
||||
from models.group import Group
|
||||
from sqlmodels.object import Object, ObjectType
|
||||
from sqlmodels.user import User
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_create_folder(db_session: AsyncSession):
|
||||
"""测试创建目录"""
|
||||
# 创建必要的依赖数据
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -48,12 +48,12 @@ async def test_object_create_folder(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_create_file(db_session: AsyncSession):
|
||||
"""测试创建文件"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(
|
||||
@@ -65,7 +65,7 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -81,7 +81,6 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id,
|
||||
size=1024,
|
||||
source_name="test_source.txt"
|
||||
)
|
||||
file = await file.save(db_session)
|
||||
|
||||
@@ -89,18 +88,17 @@ async def test_object_create_file(db_session: AsyncSession):
|
||||
assert file.name == "test.txt"
|
||||
assert file.type == ObjectType.FILE
|
||||
assert file.size == 1024
|
||||
assert file.source_name == "test_source.txt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_is_file_property(db_session: AsyncSession):
|
||||
"""测试 is_file 属性"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -122,12 +120,12 @@ async def test_object_is_file_property(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_is_folder_property(db_session: AsyncSession):
|
||||
"""测试 is_folder 属性"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="testuser", password="password", group_id=group.id)
|
||||
user = User(email="testuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -148,12 +146,12 @@ async def test_object_is_folder_property(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_root(db_session: AsyncSession):
|
||||
"""测试 get_root() 方法"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="rootuser", password="password", group_id=group.id)
|
||||
user = User(email="rootuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -161,7 +159,7 @@ async def test_object_get_root(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -180,12 +178,12 @@ async def test_object_get_root(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
"""测试获取根目录"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="pathuser", password="password", group_id=group.id)
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -193,7 +191,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -202,7 +200,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
root = await root.save(db_session)
|
||||
|
||||
# 通过路径获取根目录
|
||||
result = await Object.get_by_path(db_session, user.id, "/pathuser", user.username)
|
||||
result = await Object.get_by_path(db_session, user.id, "/")
|
||||
|
||||
assert result is not None
|
||||
assert result.id == root.id
|
||||
@@ -211,12 +209,12 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
"""测试获取嵌套路径"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="nesteduser", password="password", group_id=group.id)
|
||||
user = User(email="nesteduser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -224,7 +222,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
|
||||
# 创建目录结构: root -> docs -> work -> project
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -263,8 +261,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
result = await Object.get_by_path(
|
||||
db_session,
|
||||
user.id,
|
||||
"/nesteduser/docs/work/project",
|
||||
user.username
|
||||
"/docs/work/project",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
@@ -275,12 +272,12 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
"""测试路径不存在"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="notfounduser", password="password", group_id=group.id)
|
||||
user = User(email="notfounduser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -288,7 +285,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
|
||||
# 创建根目录
|
||||
root = Object(
|
||||
name=user.username,
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
@@ -300,8 +297,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
result = await Object.get_by_path(
|
||||
db_session,
|
||||
user.id,
|
||||
"/notfounduser/nonexistent",
|
||||
user.username
|
||||
"/nonexistent",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
@@ -310,12 +306,12 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_children(db_session: AsyncSession):
|
||||
"""测试 get_children() 方法"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="childrenuser", password="password", group_id=group.id)
|
||||
user = User(email="childrenuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -362,12 +358,12 @@ async def test_object_get_children(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_parent_child_relationship(db_session: AsyncSession):
|
||||
"""测试父子关系"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="reluser", password="password", group_id=group.id)
|
||||
user = User(email="reluser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -407,12 +403,12 @@ async def test_object_parent_child_relationship(db_session: AsyncSession):
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_unique_constraint(db_session: AsyncSession):
|
||||
"""测试同目录名称唯一约束"""
|
||||
from models.policy import Policy, PolicyType
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(username="uniqueuser", password="password", group_id=group.id)
|
||||
user = User(email="uniqueuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
@@ -450,3 +446,64 @@ async def test_object_unique_constraint(db_session: AsyncSession):
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
await file2.save(db_session)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_object_get_full_path(db_session: AsyncSession):
|
||||
"""测试 get_full_path() 方法"""
|
||||
from sqlmodels.policy import Policy, PolicyType
|
||||
|
||||
group = Group(name="测试组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(email="pathuser", password="password", group_id=group.id)
|
||||
user = await user.save(db_session)
|
||||
|
||||
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
|
||||
policy = await policy.save(db_session)
|
||||
|
||||
# 创建目录结构: root -> docs -> images -> photo.jpg
|
||||
root = Object(
|
||||
name="/",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=None,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
root = await root.save(db_session)
|
||||
|
||||
docs = Object(
|
||||
name="docs",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=root.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
docs = await docs.save(db_session)
|
||||
|
||||
images = Object(
|
||||
name="images",
|
||||
type=ObjectType.FOLDER,
|
||||
parent_id=docs.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id
|
||||
)
|
||||
images = await images.save(db_session)
|
||||
|
||||
photo = Object(
|
||||
name="photo.jpg",
|
||||
type=ObjectType.FILE,
|
||||
parent_id=images.id,
|
||||
owner_id=user.id,
|
||||
policy_id=policy.id,
|
||||
size=2048
|
||||
)
|
||||
photo = await photo.save(db_session)
|
||||
|
||||
# 测试完整路径
|
||||
full_path = await photo.get_full_path(db_session)
|
||||
assert full_path == "/docs/images/photo.jpg"
|
||||
|
||||
# 测试根目录的 full_path
|
||||
root_path = await root.get_full_path(db_session)
|
||||
assert root_path == "/"
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.setting import Setting, SettingsType
|
||||
from sqlmodels.setting import Setting, SettingsType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession):
|
||||
setting = await setting.save(db_session)
|
||||
|
||||
# 更新值
|
||||
from models.base import SQLModelBase
|
||||
from sqlmodels.base import SQLModelBase
|
||||
|
||||
class SettingUpdate(SQLModelBase):
|
||||
value: str | None = None
|
||||
|
||||
273
tests/unit/models/test_uri.py
Normal file
273
tests/unit/models/test_uri.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
DiskNextURI 模型的单元测试
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from sqlmodels.uri import DiskNextURI, FileSystemNamespace
|
||||
|
||||
|
||||
class TestDiskNextURIParse:
|
||||
"""测试 URI 解析"""
|
||||
|
||||
def test_parse_my_root(self):
|
||||
"""测试解析个人空间根目录"""
|
||||
uri = DiskNextURI.parse("disknext://my/")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/"
|
||||
assert uri.fs_id is None
|
||||
assert uri.password is None
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_parse_my_with_path(self):
|
||||
"""测试解析个人空间带路径"""
|
||||
uri = DiskNextURI.parse("disknext://my/docs/readme.md")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/docs/readme.md"
|
||||
assert uri.fs_id is None
|
||||
assert uri.path_parts == ["docs", "readme.md"]
|
||||
assert uri.is_root is False
|
||||
|
||||
def test_parse_my_with_fs_id(self):
|
||||
"""测试解析带 fs_id 的个人空间"""
|
||||
uri = DiskNextURI.parse("disknext://some-uuid@my/docs")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.fs_id == "some-uuid"
|
||||
assert uri.path == "/docs"
|
||||
|
||||
def test_parse_share_with_code(self):
|
||||
"""测试解析分享链接"""
|
||||
uri = DiskNextURI.parse("disknext://abc123@share/")
|
||||
assert uri.namespace == FileSystemNamespace.SHARE
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.path == "/"
|
||||
assert uri.password is None
|
||||
|
||||
def test_parse_share_with_password(self):
|
||||
"""测试解析带密码的分享链接"""
|
||||
uri = DiskNextURI.parse("disknext://abc123:mypass@share/sub/dir")
|
||||
assert uri.namespace == FileSystemNamespace.SHARE
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.password == "mypass"
|
||||
assert uri.path == "/sub/dir"
|
||||
|
||||
def test_parse_trash(self):
|
||||
"""测试解析回收站"""
|
||||
uri = DiskNextURI.parse("disknext://trash/")
|
||||
assert uri.namespace == FileSystemNamespace.TRASH
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_parse_with_query(self):
|
||||
"""测试解析带查询参数的 URI"""
|
||||
uri = DiskNextURI.parse("disknext://my/?name=report&type=file")
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.query is not None
|
||||
assert uri.query["name"] == "report"
|
||||
assert uri.query["type"] == "file"
|
||||
|
||||
def test_parse_invalid_scheme(self):
|
||||
"""测试无效的协议前缀"""
|
||||
with pytest.raises(ValueError, match="disknext://"):
|
||||
DiskNextURI.parse("http://my/docs")
|
||||
|
||||
def test_parse_invalid_namespace(self):
|
||||
"""测试无效的命名空间"""
|
||||
with pytest.raises(ValueError, match="无效的命名空间"):
|
||||
DiskNextURI.parse("disknext://invalid/docs")
|
||||
|
||||
def test_parse_no_namespace(self):
|
||||
"""测试缺少命名空间"""
|
||||
with pytest.raises(ValueError):
|
||||
DiskNextURI.parse("disknext://")
|
||||
|
||||
|
||||
class TestDiskNextURIBuild:
|
||||
"""测试 URI 构建"""
|
||||
|
||||
def test_build_simple(self):
|
||||
"""测试简单构建"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.namespace == FileSystemNamespace.MY
|
||||
assert uri.path == "/"
|
||||
assert uri.fs_id is None
|
||||
|
||||
def test_build_with_path(self):
|
||||
"""测试带路径构建"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.path == "/docs/readme.md"
|
||||
|
||||
def test_build_path_auto_prefix(self):
|
||||
"""测试路径自动添加 / 前缀"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="docs/readme.md")
|
||||
assert uri.path == "/docs/readme.md"
|
||||
|
||||
def test_build_with_fs_id(self):
|
||||
"""测试带 fs_id 构建"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.SHARE,
|
||||
fs_id="abc123",
|
||||
password="secret",
|
||||
)
|
||||
assert uri.fs_id == "abc123"
|
||||
assert uri.password == "secret"
|
||||
|
||||
|
||||
class TestDiskNextURIToString:
|
||||
"""测试 URI 序列化"""
|
||||
|
||||
def test_to_string_simple(self):
|
||||
"""测试简单序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.to_string() == "disknext://my/"
|
||||
|
||||
def test_to_string_with_path(self):
|
||||
"""测试带路径序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.to_string() == "disknext://my/docs/readme.md"
|
||||
|
||||
def test_to_string_with_fs_id(self):
|
||||
"""测试带 fs_id 序列化"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="uuid-123")
|
||||
assert uri.to_string() == "disknext://uuid-123@my/"
|
||||
|
||||
def test_to_string_with_password(self):
|
||||
"""测试带密码序列化"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.SHARE,
|
||||
fs_id="code",
|
||||
password="pass",
|
||||
)
|
||||
assert uri.to_string() == "disknext://code:pass@share/"
|
||||
|
||||
def test_to_string_roundtrip(self):
|
||||
"""测试序列化-反序列化往返"""
|
||||
original = "disknext://abc123:pass@share/sub/dir"
|
||||
uri = DiskNextURI.parse(original)
|
||||
result = uri.to_string()
|
||||
assert result == original
|
||||
|
||||
|
||||
class TestDiskNextURIId:
|
||||
"""测试 id() 方法"""
|
||||
|
||||
def test_id_with_fs_id(self):
|
||||
"""测试有 fs_id 时返回 fs_id"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="my-uuid")
|
||||
assert uri.id("default") == "my-uuid"
|
||||
|
||||
def test_id_without_fs_id(self):
|
||||
"""测试无 fs_id 时返回默认值"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.id("default-uuid") == "default-uuid"
|
||||
|
||||
def test_id_without_fs_id_no_default(self):
|
||||
"""测试无 fs_id 且无默认值时返回 None"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
assert uri.id() is None
|
||||
|
||||
|
||||
class TestDiskNextURIJoin:
|
||||
"""测试 join() 方法"""
|
||||
|
||||
def test_join_single(self):
|
||||
"""测试拼接单个路径元素"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
joined = uri.join("readme.md")
|
||||
assert joined.path == "/docs/readme.md"
|
||||
|
||||
def test_join_multiple(self):
|
||||
"""测试拼接多个路径元素"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY)
|
||||
joined = uri.join("docs", "work", "report.pdf")
|
||||
assert joined.path == "/docs/work/report.pdf"
|
||||
|
||||
def test_join_preserves_metadata(self):
|
||||
"""测试 join 保留 namespace 和 fs_id"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.SHARE, fs_id="code123")
|
||||
joined = uri.join("sub")
|
||||
assert joined.namespace == FileSystemNamespace.SHARE
|
||||
assert joined.fs_id == "code123"
|
||||
|
||||
|
||||
class TestDiskNextURIDirUri:
|
||||
"""测试 dir_uri() 方法"""
|
||||
|
||||
def test_dir_uri_file(self):
|
||||
"""测试获取文件的父目录 URI"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
parent = uri.dir_uri()
|
||||
assert parent.path == "/docs/"
|
||||
|
||||
def test_dir_uri_root(self):
|
||||
"""测试根目录的 dir_uri 返回自身"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
parent = uri.dir_uri()
|
||||
assert parent.path == "/"
|
||||
|
||||
|
||||
class TestDiskNextURIRoot:
|
||||
"""测试 root() 方法"""
|
||||
|
||||
def test_root_resets_path(self):
|
||||
"""测试 root 重置路径"""
|
||||
uri = DiskNextURI.build(
|
||||
FileSystemNamespace.MY,
|
||||
path="/docs/work/report.pdf",
|
||||
fs_id="uuid-123",
|
||||
)
|
||||
root = uri.root()
|
||||
assert root.path == "/"
|
||||
assert root.fs_id == "uuid-123"
|
||||
assert root.namespace == FileSystemNamespace.MY
|
||||
|
||||
|
||||
class TestDiskNextURIName:
|
||||
"""测试 name() 方法"""
|
||||
|
||||
def test_name_file(self):
|
||||
"""测试获取文件名"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
|
||||
assert uri.name() == "readme.md"
|
||||
|
||||
def test_name_directory(self):
|
||||
"""测试获取目录名"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work")
|
||||
assert uri.name() == "work"
|
||||
|
||||
def test_name_root(self):
|
||||
"""测试根目录的 name 返回空字符串"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.name() == ""
|
||||
|
||||
|
||||
class TestDiskNextURIProperties:
|
||||
"""测试属性方法"""
|
||||
|
||||
def test_path_parts(self):
|
||||
"""测试路径分割"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work/report.pdf")
|
||||
assert uri.path_parts == ["docs", "work", "report.pdf"]
|
||||
|
||||
def test_path_parts_root(self):
|
||||
"""测试根路径分割"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.path_parts == []
|
||||
|
||||
def test_is_root_true(self):
|
||||
"""测试 is_root 为真"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
|
||||
assert uri.is_root is True
|
||||
|
||||
def test_is_root_false(self):
|
||||
"""测试 is_root 为假"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert uri.is_root is False
|
||||
|
||||
def test_str_representation(self):
|
||||
"""测试字符串表示"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert str(uri) == "disknext://my/docs"
|
||||
|
||||
def test_repr(self):
|
||||
"""测试 repr"""
|
||||
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
|
||||
assert "disknext://my/docs" in repr(uri)
|
||||
@@ -5,8 +5,8 @@ import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User, ThemeType, UserPublic
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User, ThemeType, UserPublic
|
||||
from sqlmodels.group import Group
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -18,7 +18,7 @@ async def test_user_create(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="testuser",
|
||||
email="testuser@test.local",
|
||||
nickname="测试用户",
|
||||
password="hashed_password",
|
||||
group_id=group.id
|
||||
@@ -26,7 +26,7 @@ async def test_user_create(db_session: AsyncSession):
|
||||
user = await user.save(db_session)
|
||||
|
||||
assert user.id is not None
|
||||
assert user.username == "testuser"
|
||||
assert user.email == "testuser@test.local"
|
||||
assert user.nickname == "测试用户"
|
||||
assert user.status is True
|
||||
assert user.storage == 0
|
||||
@@ -34,15 +34,15 @@ async def test_user_create(db_session: AsyncSession):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_unique_username(db_session: AsyncSession):
|
||||
"""测试用户名唯一约束"""
|
||||
async def test_user_unique_email(db_session: AsyncSession):
|
||||
"""测试邮箱唯一约束"""
|
||||
# 创建用户组
|
||||
group = Group(name="默认组")
|
||||
group = await group.save(db_session)
|
||||
|
||||
# 创建第一个用户
|
||||
user1 = User(
|
||||
username="duplicate",
|
||||
email="duplicate@test.local",
|
||||
password="password1",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -50,7 +50,7 @@ async def test_user_unique_username(db_session: AsyncSession):
|
||||
|
||||
# 尝试创建同名用户
|
||||
user2 = User(
|
||||
username="duplicate",
|
||||
email="duplicate@test.local",
|
||||
password="password2",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -68,7 +68,7 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="publicuser",
|
||||
email="publicuser@test.local",
|
||||
nickname="公开用户",
|
||||
password="secret_password",
|
||||
storage=1024,
|
||||
@@ -82,7 +82,7 @@ async def test_user_to_public(db_session: AsyncSession):
|
||||
|
||||
assert isinstance(public_user, UserPublic)
|
||||
assert public_user.id == user.id
|
||||
assert public_user.username == "publicuser"
|
||||
assert public_user.email == "publicuser@test.local"
|
||||
# 注意: UserPublic.nick 字段名与 User.nickname 不同,
|
||||
# model_validate 不会自动映射,所以 nick 为 None
|
||||
# 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段
|
||||
@@ -101,7 +101,7 @@ async def test_user_group_relationship(db_session: AsyncSession):
|
||||
|
||||
# 创建用户
|
||||
user = User(
|
||||
username="vipuser",
|
||||
email="vipuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -125,7 +125,7 @@ async def test_user_status_default(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="defaultuser",
|
||||
email="defaultuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -141,7 +141,7 @@ async def test_user_storage_default(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="storageuser",
|
||||
email="storageuser@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -158,7 +158,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试默认值
|
||||
user1 = User(
|
||||
username="user1",
|
||||
email="user1@test.local",
|
||||
password="password",
|
||||
group_id=group.id
|
||||
)
|
||||
@@ -167,7 +167,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试设置为 LIGHT
|
||||
user2 = User(
|
||||
username="user2",
|
||||
email="user2@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.LIGHT,
|
||||
group_id=group.id
|
||||
@@ -177,7 +177,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
|
||||
|
||||
# 测试设置为 DARK
|
||||
user3 = User(
|
||||
username="user3",
|
||||
email="user3@test.local",
|
||||
password="password",
|
||||
theme=ThemeType.DARK,
|
||||
group_id=group.id
|
||||
|
||||
@@ -4,8 +4,8 @@ Login 服务的单元测试
|
||||
import pytest
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
|
||||
from models.user import User, LoginRequest, TokenResponse
|
||||
from models.group import Group
|
||||
from sqlmodels.user import User, LoginRequest, TokenResponse
|
||||
from sqlmodels.group import Group
|
||||
from service.user.login import login
|
||||
from utils.password.pwd import Password
|
||||
|
||||
@@ -20,7 +20,7 @@ async def setup_user(db_session: AsyncSession):
|
||||
# 创建正常用户
|
||||
plain_password = "secure_password_123"
|
||||
user = User(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=Password.hash(plain_password),
|
||||
status=True,
|
||||
group_id=group.id
|
||||
@@ -41,7 +41,7 @@ async def setup_banned_user(db_session: AsyncSession):
|
||||
group = await group.save(db_session)
|
||||
|
||||
user = User(
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=False, # 封禁状态
|
||||
group_id=group.id
|
||||
@@ -61,7 +61,7 @@ async def setup_2fa_user(db_session: AsyncSession):
|
||||
|
||||
secret = pyotp.random_base32()
|
||||
user = User(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=Password.hash("password"),
|
||||
status=True,
|
||||
two_factor=secret,
|
||||
@@ -82,7 +82,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
|
||||
async def test_login_user_not_found(db_session: AsyncSession):
|
||||
"""测试用户不存在"""
|
||||
login_request = LoginRequest(
|
||||
username="nonexistent_user",
|
||||
email="nonexistent@test.local",
|
||||
password="any_password"
|
||||
)
|
||||
|
||||
@@ -112,7 +112,7 @@ async def test_login_user_not_found(db_session: AsyncSession):
|
||||
async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
"""测试密码错误"""
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password="wrong_password"
|
||||
)
|
||||
|
||||
@@ -125,7 +125,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user):
|
||||
async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
|
||||
"""测试用户被封禁"""
|
||||
login_request = LoginRequest(
|
||||
username="banneduser",
|
||||
email="banneduser@test.local",
|
||||
password="password"
|
||||
)
|
||||
|
||||
@@ -140,7 +140,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"]
|
||||
# 未提供 two_fa_code
|
||||
)
|
||||
@@ -156,7 +156,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
|
||||
user_data = setup_2fa_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code="000000" # 错误的验证码
|
||||
)
|
||||
@@ -179,7 +179,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
|
||||
valid_code = totp.now()
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="2fauser",
|
||||
email="2fauser@test.local",
|
||||
password=user_data["password"],
|
||||
two_fa_code=valid_code
|
||||
)
|
||||
@@ -198,7 +198,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
user_data = setup_user
|
||||
|
||||
login_request = LoginRequest(
|
||||
username="loginuser",
|
||||
email="loginuser@test.local",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
@@ -217,17 +217,17 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_case_sensitive_username(db_session: AsyncSession, setup_user):
|
||||
"""测试用户名大小写敏感"""
|
||||
async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user):
|
||||
"""测试邮箱大小写敏感"""
|
||||
user_data = setup_user
|
||||
|
||||
# 使用大写用户名登录(如果数据库是 loginuser)
|
||||
# 使用大写邮箱登录
|
||||
login_request = LoginRequest(
|
||||
username="LOGINUSER",
|
||||
email="LOGINUSER@TEST.LOCAL",
|
||||
password=user_data["password"]
|
||||
)
|
||||
|
||||
result = await login(db_session, login_request)
|
||||
|
||||
# 应该失败,因为用户名大小写不匹配
|
||||
# 应该失败,因为邮箱大小写不匹配
|
||||
assert result is None
|
||||
|
||||
@@ -72,9 +72,9 @@ def test_password_verify_expired():
|
||||
@pytest.mark.asyncio
|
||||
async def test_totp_generate():
|
||||
"""测试 TOTP 密钥生成"""
|
||||
username = "testuser"
|
||||
email = "testuser@test.local"
|
||||
|
||||
response = await Password.generate_totp(username)
|
||||
response = await Password.generate_totp(email)
|
||||
|
||||
assert response.setup_token is not None
|
||||
assert response.uri is not None
|
||||
@@ -82,7 +82,7 @@ async def test_totp_generate():
|
||||
assert isinstance(response.uri, str)
|
||||
# TOTP URI 格式: otpauth://totp/...
|
||||
assert response.uri.startswith("otpauth://totp/")
|
||||
assert username in response.uri
|
||||
assert email in response.uri
|
||||
|
||||
|
||||
def test_totp_verify_valid():
|
||||
|
||||
Reference in New Issue
Block a user