feat: add models for physical files, policies, and user management

- Implement PhysicalFile model to manage physical file references and reference counting.
- Create Policy model with associated options and group links for storage policies.
- Introduce Redeem and Report models for handling redeem codes and reports.
- Add Settings model for site configuration and user settings management.
- Develop Share model for sharing objects with unique codes and associated metadata.
- Implement SourceLink model for managing download links associated with objects.
- Create StoragePack model for managing user storage packages.
- Add Tag model for user-defined tags with manual and automatic types.
- Implement Task model for managing background tasks with status tracking.
- Develop User model with comprehensive user management features including authentication.
- Introduce UserAuthn model for managing WebAuthn credentials.
- Create WebDAV model for managing WebDAV accounts associated with users.
This commit is contained in:
2026-02-10 16:25:49 +08:00
parent 62c671e07b
commit 209cb24ab4
92 changed files with 3640 additions and 1444 deletions

View File

@@ -49,13 +49,13 @@ def main():
("itsdangerous", "签名工具"),
# 项目模块
("models", "数据库模型"),
("models.user", "用户模型"),
("models.group", "用户组模型"),
("models.object", "对象模型"),
("models.setting", "设置模型"),
("models.policy", "策略模型"),
("models.database", "数据库连接"),
("sqlmodels", "数据库模型"),
("sqlmodels.user", "用户模型"),
("sqlmodels.group", "用户组模型"),
("sqlmodels.object", "对象模型"),
("sqlmodels.setting", "设置模型"),
("sqlmodels.policy", "策略模型"),
("sqlmodels.database", "数据库连接"),
("utils.password.pwd", "密码工具"),
("utils.JWT.JWT", "JWT 工具"),
("service.user.login", "登录服务"),

View File

@@ -23,12 +23,12 @@ from sqlalchemy.orm import sessionmaker
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from main import app
from models.database import get_session
from models.group import Group, GroupOptions
from models.migration import migration
from models.object import Object, ObjectType
from models.policy import Policy, PolicyType
from models.user import User
from sqlmodels.database import get_session
from sqlmodels.group import Group, GroupOptions
from sqlmodels.migration import migration
from sqlmodels.object import Object, ObjectType
from sqlmodels.policy import Policy, PolicyType
from sqlmodels.user import User
from utils.JWT.JWT import create_access_token
from utils.password.pwd import Password
@@ -153,7 +153,7 @@ def override_get_session(db_session: AsyncSession):
@pytest_asyncio.fixture(scope="function")
async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
"""
创建测试用户并返回 {id, username, password, token}
创建测试用户并返回 {id, email, password, token}
创建一个普通用户,包含用户组、存储策略和根目录。
"""
@@ -190,7 +190,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
# 创建测试用户
password = "test_password_123"
user = User(
username="testuser",
email="testuser@test.local",
nickname="测试用户",
password=Password.hash(password),
status=True,
@@ -202,7 +202,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
# 创建用户根目录
root_folder = Object(
name=user.username,
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,
@@ -216,7 +216,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
return {
"id": user.id,
"username": user.username,
"email": user.email,
"password": password,
"token": access_token,
"group_id": group.id,
@@ -227,7 +227,7 @@ async def test_user(db_session: AsyncSession) -> dict[str, str | UUID]:
@pytest_asyncio.fixture(scope="function")
async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
"""
获取管理员用户 {id, username, token}
获取管理员用户 {id, email, token}
创建具有管理员权限的用户。
"""
@@ -267,7 +267,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
# 创建管理员用户
password = "admin_password_456"
admin = User(
username="admin",
email="admin@disknext.local",
nickname="管理员",
password=Password.hash(password),
status=True,
@@ -279,7 +279,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
# 创建管理员根目录
root_folder = Object(
name=admin.username,
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=admin.id,
@@ -293,7 +293,7 @@ async def admin_user(db_session: AsyncSession) -> dict[str, str | UUID]:
return {
"id": admin.id,
"username": admin.username,
"email": admin.email,
"password": password,
"token": access_token,
"group_id": admin_group.id,

View File

@@ -8,9 +8,9 @@ from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User
from models.group import Group
from models.object import Object, ObjectType
from sqlmodels.user import User
from sqlmodels.group import Group
from sqlmodels.object import Object, ObjectType
from tests.fixtures import UserFactory, GroupFactory, ObjectFactory
@@ -24,13 +24,13 @@ async def test_user_factory(db_session: AsyncSession):
user = await UserFactory.create(
db_session,
group_id=group.id,
username="testuser",
email="testuser@test.local",
password="password123"
)
# 验证
assert user.id is not None
assert user.username == "testuser"
assert user.email == "testuser@test.local"
assert user.group_id == group.id
assert user.status is True
@@ -51,7 +51,7 @@ async def test_group_factory(db_session: AsyncSession):
async def test_object_factory(db_session: AsyncSession):
"""测试对象工厂的基本功能"""
# 准备依赖
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = await GroupFactory.create(db_session)
user = await UserFactory.create(db_session, group_id=group.id)
@@ -102,7 +102,7 @@ async def test_conftest_fixtures(
"""测试 conftest.py 中的 fixtures"""
# 验证 test_user fixture
assert test_user["id"] is not None
assert test_user["username"] == "testuser"
assert test_user["email"] == "testuser@test.local"
assert test_user["token"] is not None
# 验证 auth_headers fixture
@@ -112,7 +112,7 @@ async def test_conftest_fixtures(
# 验证用户在数据库中存在
user = await User.get(db_session, User.id == test_user["id"])
assert user is not None
assert user.username == test_user["username"]
assert user.email == test_user["email"]
@pytest.mark.integration
@@ -145,7 +145,7 @@ async def test_test_directory_fixture(
@pytest.mark.integration
async def test_nested_structure_factory(db_session: AsyncSession):
"""测试嵌套结构工厂"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
# 准备依赖
group = await GroupFactory.create(db_session)

View File

@@ -5,7 +5,7 @@
"""
from sqlmodel.ext.asyncio.session import AsyncSession
from models.group import Group, GroupOptions
from sqlmodels.group import Group, GroupOptions
class GroupFactory:

View File

@@ -7,8 +7,8 @@ from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from models.object import Object, ObjectType
from models.user import User
from sqlmodels.object import Object, ObjectType
from sqlmodels.user import User
class ObjectFactory:
@@ -119,7 +119,7 @@ class ObjectFactory:
Object: 创建的根目录实例
"""
root = Object(
name=user.username,
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,

View File

@@ -7,7 +7,7 @@ from uuid import UUID
from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User
from sqlmodels.user import User
from utils.password.pwd import Password
@@ -18,7 +18,7 @@ class UserFactory:
async def create(
session: AsyncSession,
group_id: UUID,
username: str | None = None,
email: str | None = None,
password: str | None = None,
**kwargs
) -> User:
@@ -28,7 +28,7 @@ class UserFactory:
参数:
session: 数据库会话
group_id: 用户组UUID
username: 用户(默认: test_user_{随机}
email: 用户邮箱(默认: test_user_{随机}@test.local
password: 明文密码(默认: password123
**kwargs: 其他用户字段
@@ -37,15 +37,15 @@ class UserFactory:
"""
import uuid
if username is None:
username = f"test_user_{uuid.uuid4().hex[:8]}"
if email is None:
email = f"test_user_{uuid.uuid4().hex[:8]}@test.local"
if password is None:
password = "password123"
user = User(
username=username,
nickname=kwargs.get("nickname", username),
email=email,
nickname=kwargs.get("nickname", email),
password=Password.hash(password),
status=kwargs.get("status", True),
storage=kwargs.get("storage", 0),
@@ -67,7 +67,7 @@ class UserFactory:
async def create_admin(
session: AsyncSession,
admin_group_id: UUID,
username: str | None = None,
email: str | None = None,
password: str | None = None
) -> User:
"""
@@ -76,7 +76,7 @@ class UserFactory:
参数:
session: 数据库会话
admin_group_id: 管理员组UUID
username: 用户(默认: admin_{随机}
email: 用户邮箱(默认: admin_{随机}@disknext.local
password: 明文密码(默认: admin_password
返回:
@@ -84,15 +84,15 @@ class UserFactory:
"""
import uuid
if username is None:
username = f"admin_{uuid.uuid4().hex[:8]}"
if email is None:
email = f"admin_{uuid.uuid4().hex[:8]}@disknext.local"
if password is None:
password = "admin_password"
admin = User(
username=username,
nickname=f"管理员 {username}",
email=email,
nickname=f"管理员 {email}",
password=Password.hash(password),
status=True,
storage=0,
@@ -108,7 +108,7 @@ class UserFactory:
async def create_banned(
session: AsyncSession,
group_id: UUID,
username: str | None = None
email: str | None = None
) -> User:
"""
创建被封禁用户
@@ -116,19 +116,19 @@ class UserFactory:
参数:
session: 数据库会话
group_id: 用户组UUID
username: 用户(默认: banned_user_{随机}
email: 用户邮箱(默认: banned_user_{随机}@test.local
返回:
User: 创建的被封禁用户实例
"""
import uuid
if username is None:
username = f"banned_user_{uuid.uuid4().hex[:8]}"
if email is None:
email = f"banned_user_{uuid.uuid4().hex[:8]}@test.local"
banned_user = User(
username=username,
nickname=f"封禁用户 {username}",
email=email,
nickname=f"封禁用户 {email}",
password=Password.hash("banned_password"),
status=False, # 封禁状态
storage=0,
@@ -145,7 +145,7 @@ class UserFactory:
session: AsyncSession,
group_id: UUID,
storage_bytes: int,
username: str | None = None
email: str | None = None
) -> User:
"""
创建已使用指定存储空间的用户
@@ -154,19 +154,19 @@ class UserFactory:
session: 数据库会话
group_id: 用户组UUID
storage_bytes: 已使用的存储空间(字节)
username: 用户(默认: storage_user_{随机}
email: 用户邮箱(默认: storage_user_{随机}@test.local
返回:
User: 创建的用户实例
"""
import uuid
if username is None:
username = f"storage_user_{uuid.uuid4().hex[:8]}"
if email is None:
email = f"storage_user_{uuid.uuid4().hex[:8]}@test.local"
user = User(
username=username,
nickname=username,
email=email,
nickname=email,
password=Password.hash("password123"),
status=True,
storage=storage_bytes,

View File

@@ -124,7 +124,7 @@ async def test_admin_get_user_list_contains_user_data(
if len(users) > 0:
user = users[0]
assert "id" in user
assert "username" in user
assert "email" in user
@pytest.mark.asyncio
@@ -132,7 +132,7 @@ async def test_admin_create_user_requires_auth(async_client: AsyncClient):
"""测试创建用户需要认证"""
response = await async_client.post(
"/api/admin/user/create",
json={"username": "newadminuser", "password": "pass123"}
json={"email": "newadminuser@test.local", "password": "pass123"}
)
assert response.status_code == 401
@@ -146,7 +146,7 @@ async def test_admin_create_user_requires_admin(
response = await async_client.post(
"/api/admin/user/create",
headers=auth_headers,
json={"username": "newadminuser", "password": "pass123"}
json={"email": "newadminuser@test.local", "password": "pass123"}
)
assert response.status_code == 403

View File

@@ -11,7 +11,7 @@ from uuid import UUID
@pytest.mark.asyncio
async def test_directory_requires_auth(async_client: AsyncClient):
"""测试获取目录需要认证"""
response = await async_client.get("/api/directory/testuser")
response = await async_client.get("/api/directory/")
assert response.status_code == 401
@@ -24,7 +24,7 @@ async def test_directory_get_root(
):
"""测试获取用户根目录"""
response = await async_client.get(
"/api/directory/testuser",
"/api/directory/",
headers=auth_headers
)
assert response.status_code == 200
@@ -45,7 +45,7 @@ async def test_directory_get_nested(
):
"""测试获取嵌套目录"""
response = await async_client.get(
"/api/directory/testuser/docs",
"/api/directory/docs",
headers=auth_headers
)
assert response.status_code == 200
@@ -63,7 +63,7 @@ async def test_directory_get_contains_children(
):
"""测试目录包含子对象"""
response = await async_client.get(
"/api/directory/testuser/docs",
"/api/directory/docs",
headers=auth_headers
)
assert response.status_code == 200
@@ -75,19 +75,6 @@ async def test_directory_get_contains_children(
assert len(objects) >= 1
@pytest.mark.asyncio
async def test_directory_forbidden_other_user(
async_client: AsyncClient,
auth_headers: dict[str, str]
):
"""测试访问他人目录返回 403"""
response = await async_client.get(
"/api/directory/admin",
headers=auth_headers
)
assert response.status_code == 403
@pytest.mark.asyncio
async def test_directory_not_found(
async_client: AsyncClient,
@@ -95,23 +82,23 @@ async def test_directory_not_found(
):
"""测试目录不存在返回 404"""
response = await async_client.get(
"/api/directory/testuser/nonexistent",
"/api/directory/nonexistent",
headers=auth_headers
)
assert response.status_code == 404
@pytest.mark.asyncio
async def test_directory_empty_path_returns_400(
async def test_directory_root_returns_200(
async_client: AsyncClient,
auth_headers: dict[str, str]
):
"""测试空路径返回 400"""
"""测试根目录端点返回 200"""
response = await async_client.get(
"/api/directory/",
headers=auth_headers
)
assert response.status_code == 400
assert response.status_code == 200
@pytest.mark.asyncio
@@ -121,7 +108,7 @@ async def test_directory_response_includes_policy(
):
"""测试目录响应包含存储策略"""
response = await async_client.get(
"/api/directory/testuser",
"/api/directory/",
headers=auth_headers
)
assert response.status_code == 200
@@ -284,7 +271,7 @@ async def test_directory_create_other_user_parent(
"""测试在他人目录下创建目录返回 404"""
# 先用管理员账号获取管理员的根目录ID
admin_response = await async_client.get(
"/api/directory/admin",
"/api/directory/",
headers=admin_headers
)
assert admin_response.status_code == 200

View File

@@ -16,7 +16,7 @@ async def test_user_login_success(
response = await async_client.post(
"/api/user/session",
data={
"username": test_user_info["username"],
"username": test_user_info["email"],
"password": test_user_info["password"],
}
)
@@ -38,7 +38,7 @@ async def test_user_login_wrong_password(
response = await async_client.post(
"/api/user/session",
data={
"username": test_user_info["username"],
"username": test_user_info["email"],
"password": "wrongpassword",
}
)
@@ -51,7 +51,7 @@ async def test_user_login_nonexistent_user(async_client: AsyncClient):
response = await async_client.post(
"/api/user/session",
data={
"username": "nonexistent",
"username": "nonexistent@test.local",
"password": "anypassword",
}
)
@@ -67,7 +67,7 @@ async def test_user_login_user_banned(
response = await async_client.post(
"/api/user/session",
data={
"username": banned_user_info["username"],
"username": banned_user_info["email"],
"password": banned_user_info["password"],
}
)
@@ -82,7 +82,7 @@ async def test_user_register_success(async_client: AsyncClient):
response = await async_client.post(
"/api/user/",
json={
"username": "newuser",
"email": "newuser@test.local",
"password": "newpass123",
}
)
@@ -91,20 +91,20 @@ async def test_user_register_success(async_client: AsyncClient):
data = response.json()
assert "data" in data
assert "user_id" in data["data"]
assert "username" in data["data"]
assert data["data"]["username"] == "newuser"
assert "email" in data["data"]
assert data["data"]["email"] == "newuser@test.local"
@pytest.mark.asyncio
async def test_user_register_duplicate_username(
async def test_user_register_duplicate_email(
async_client: AsyncClient,
test_user_info: dict[str, str]
):
"""测试重复用户名返回 400"""
"""测试重复邮箱返回 400"""
response = await async_client.post(
"/api/user/",
json={
"username": test_user_info["username"],
"email": test_user_info["email"],
"password": "anypassword",
}
)
@@ -143,8 +143,8 @@ async def test_user_me_returns_user_info(
assert "data" in data
user_data = data["data"]
assert "id" in user_data
assert "username" in user_data
assert user_data["username"] == "testuser"
assert "email" in user_data
assert user_data["email"] == "testuser@test.local"
assert "group" in user_data
assert "tags" in user_data

View File

@@ -22,7 +22,7 @@ from sqlalchemy.orm import sessionmaker
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from main import app
from models import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from sqlmodels import Group, GroupOptions, Object, ObjectType, Policy, PolicyType, Setting, SettingsType, User
from utils import Password
from utils.JWT import create_access_token
from utils.JWT import JWT
@@ -92,6 +92,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
Setting(type=SettingsType.VIEW, name="home_view_method", value="list"),
Setting(type=SettingsType.VIEW, name="share_view_method", value="grid"),
Setting(type=SettingsType.AUTHN, name="authn_enabled", value="0"),
Setting(type=SettingsType.CAPTCHA, name="captcha_type", value="default"),
Setting(type=SettingsType.CAPTCHA, name="captcha_ReCaptchaKey", value=""),
Setting(type=SettingsType.CAPTCHA, name="captcha_CloudflareKey", value=""),
Setting(type=SettingsType.REGISTER, name="register_enabled", value="1"),
@@ -180,7 +181,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
# 6. 创建测试用户
test_user = User(
id=uuid4(),
username="testuser",
email="testuser@test.local",
password=Password.hash("testpass123"),
nickname="测试用户",
status=True,
@@ -194,7 +195,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
admin_user = User(
id=uuid4(),
username="admin",
email="admin@disknext.local",
password=Password.hash("adminpass123"),
nickname="管理员",
status=True,
@@ -208,7 +209,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
banned_user = User(
id=uuid4(),
username="banneduser",
email="banneduser@test.local",
password=Password.hash("banned123"),
nickname="封禁用户",
status=False, # 封禁状态
@@ -230,7 +231,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
# 7. 创建用户根目录
test_user_root = Object(
id=uuid4(),
name=test_user.username,
name="/",
type=ObjectType.FOLDER,
owner_id=test_user.id,
parent_id=None,
@@ -241,7 +242,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
admin_user_root = Object(
id=uuid4(),
name=admin_user.username,
name="/",
type=ObjectType.FOLDER,
owner_id=admin_user.id,
parent_id=None,
@@ -264,7 +265,7 @@ async def initialized_db(test_session: AsyncSession) -> AsyncSession:
def test_user_info() -> dict[str, str]:
"""测试用户信息"""
return {
"username": "testuser",
"email": "testuser@test.local",
"password": "testpass123",
}
@@ -273,7 +274,7 @@ def test_user_info() -> dict[str, str]:
def admin_user_info() -> dict[str, str]:
"""管理员用户信息"""
return {
"username": "admin",
"email": "admin@disknext.local",
"password": "adminpass123",
}
@@ -282,7 +283,7 @@ def admin_user_info() -> dict[str, str]:
def banned_user_info() -> dict[str, str]:
"""封禁用户信息"""
return {
"username": "banneduser",
"email": "banneduser@test.local",
"password": "banned123",
}
@@ -293,7 +294,7 @@ def banned_user_info() -> dict[str, str]:
def test_user_token(test_user_info: dict[str, str]) -> str:
"""生成测试用户的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": test_user_info["username"]},
data={"sub": test_user_info["email"]},
expires_delta=timedelta(hours=1),
)
return token
@@ -303,7 +304,7 @@ def test_user_token(test_user_info: dict[str, str]) -> str:
def admin_user_token(admin_user_info: dict[str, str]) -> str:
"""生成管理员的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": admin_user_info["username"]},
data={"sub": admin_user_info["email"]},
expires_delta=timedelta(hours=1),
)
return token
@@ -313,7 +314,7 @@ def admin_user_token(admin_user_info: dict[str, str]) -> str:
def expired_token() -> str:
"""生成过期的JWT token"""
token, _ = JWT.create_access_token(
data={"sub": "testuser"},
data={"sub": "testuser@test.local"},
expires_delta=timedelta(seconds=-1), # 已过期
)
return token
@@ -362,7 +363,7 @@ async def test_directory_structure(initialized_db: AsyncSession) -> dict[str, UU
"""创建测试目录结构"""
# 获取测试用户和根目录
test_user = await User.get(initialized_db, User.username == "testuser")
test_user = await User.get(initialized_db, User.email == "testuser@test.local")
test_user_root = await Object.get_root(initialized_db, test_user.id)
default_policy = await Policy.get(initialized_db, Policy.name == "本地存储")

View File

@@ -83,7 +83,7 @@ async def test_auth_required_token_without_sub(async_client: AsyncClient):
async def test_auth_required_nonexistent_user_token(async_client: AsyncClient):
"""测试用户不存在的token返回 401"""
token, _ = JWT.create_access_token(
data={"sub": "nonexistent_user"},
data={"sub": "nonexistent_user@test.local"},
expires_delta=timedelta(hours=1)
)
@@ -178,12 +178,12 @@ async def test_auth_on_directory_endpoint(
):
"""测试目录端点应用认证"""
# 无认证
response_no_auth = await async_client.get("/api/directory/testuser")
response_no_auth = await async_client.get("/api/directory/")
assert response_no_auth.status_code == 401
# 有认证
response_with_auth = await async_client.get(
"/api/directory/testuser",
"/api/directory/",
headers=auth_headers
)
assert response_with_auth.status_code == 200
@@ -235,7 +235,7 @@ async def test_auth_on_storage_endpoint(
async def test_refresh_token_format(test_user_info: dict[str, str]):
"""测试刷新token格式正确"""
refresh_token, _ = JWT.create_refresh_token(
data={"sub": test_user_info["username"]},
data={"sub": test_user_info["email"]},
expires_delta=timedelta(days=7)
)
@@ -247,7 +247,7 @@ async def test_refresh_token_format(test_user_info: dict[str, str]):
async def test_access_token_format(test_user_info: dict[str, str]):
"""测试访问token格式正确"""
access_token, expires = JWT.create_access_token(
data={"sub": test_user_info["username"]},
data={"sub": test_user_info["email"]},
expires_delta=timedelta(hours=1)
)

View File

@@ -3,14 +3,14 @@ import pytest
@pytest.mark.asyncio
async def test_initialize_db():
"""测试创建数据库结构"""
from models import database
from sqlmodels import database
await database.init_db(url='sqlite:///:memory:')
@pytest.fixture
async def db_session():
"""测试获取数据库连接Session"""
from models import database
from sqlmodels import database
await database.init_db(url='sqlite:///:memory:')
@@ -20,8 +20,8 @@ async def db_session():
@pytest.mark.asyncio
async def test_migration():
"""测试数据库创建并初始化配置"""
from models import migration
from models import database
from sqlmodels import migration
from sqlmodels import database
await database.init_db(url='sqlite:///:memory:')

View File

@@ -3,8 +3,8 @@ import pytest
@pytest.mark.asyncio
async def test_group_curd():
"""测试数据库的增删改查"""
from models import database, migration
from models.group import Group
from sqlmodels import database, migration
from sqlmodels.group import Group
await database.init_db(url='sqlite+aiosqlite:///:memory:')

View File

@@ -3,8 +3,8 @@ import pytest
@pytest.mark.asyncio
async def test_settings_curd():
"""测试数据库的增删改查"""
from models import database
from models.setting import Setting
from sqlmodels import database
from sqlmodels.setting import Setting
await database.init_db(url='sqlite:///:memory:')

View File

@@ -3,9 +3,9 @@ import pytest
@pytest.mark.asyncio
async def test_user_curd():
"""测试数据库的增删改查"""
from models import database, migration
from models.group import Group
from models.user import User
from sqlmodels import database, migration
from sqlmodels.group import Group
from sqlmodels.user import User
await database.init_db(url='sqlite+aiosqlite:///:memory:')
@@ -17,7 +17,7 @@ async def test_user_curd():
created_group = await test_user_group.save(session)
test_user = User(
username='test_user',
email='test_user@test.local',
password='test_password',
group_id=created_group.id
)
@@ -27,7 +27,7 @@ async def test_user_curd():
# 验证用户是否存在
assert created_user.id is not None
assert created_user.username == 'test_user'
assert created_user.email == 'test_user@test.local'
assert created_user.password == 'test_password'
assert created_user.group_id == created_group.id
@@ -35,18 +35,18 @@ async def test_user_curd():
fetched_user = await User.get(session, User.id == created_user.id)
assert fetched_user is not None
assert fetched_user.username == 'test_user'
assert fetched_user.email == 'test_user@test.local'
assert fetched_user.password == 'test_password'
assert fetched_user.group_id == created_group.id
# 测试改 Update
updated_user = await fetched_user.update(
session,
{"username": "updated_user", "password": "updated_password"}
{"email": "updated_user@test.local", "password": "updated_password"}
)
assert updated_user is not None
assert updated_user.username == 'updated_user'
assert updated_user.email == 'updated_user@test.local'
assert updated_user.password == 'updated_password'
# 测试删除 Delete

View File

@@ -8,8 +8,8 @@ import pytest
from fastapi import HTTPException
from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User
from models.group import Group
from sqlmodels.user import User
from sqlmodels.group import Group
@pytest.mark.asyncio
@@ -62,7 +62,7 @@ async def test_table_base_update(db_session: AsyncSession):
group = await group.save(db_session)
# 更新数据
from models.group import GroupBase
from sqlmodels.group import GroupBase
update_data = GroupBase(name="更新后名称")
updated_group = await group.update(db_session, update_data)
@@ -200,7 +200,7 @@ async def test_timestamps_auto_update(db_session: AsyncSession):
await asyncio.sleep(0.1)
# 更新记录
from models.group import GroupBase
from sqlmodels.group import GroupBase
update_data = GroupBase(name="更新后的名称")
group = await group.update(db_session, update_data)

View File

@@ -4,7 +4,7 @@ Group 和 GroupOptions 模型的单元测试
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from models.group import Group, GroupOptions, GroupResponse
from sqlmodels.group import Group, GroupOptions, GroupResponse
@pytest.mark.asyncio

View File

@@ -5,21 +5,21 @@ import pytest
from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession
from models.object import Object, ObjectType
from models.user import User
from models.group import Group
from sqlmodels.object import Object, ObjectType
from sqlmodels.user import User
from sqlmodels.group import Group
@pytest.mark.asyncio
async def test_object_create_folder(db_session: AsyncSession):
"""测试创建目录"""
# 创建必要的依赖数据
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="testuser", password="password", group_id=group.id)
user = User(email="testuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(
@@ -48,12 +48,12 @@ async def test_object_create_folder(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_create_file(db_session: AsyncSession):
"""测试创建文件"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="testuser", password="password", group_id=group.id)
user = User(email="testuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(
@@ -65,7 +65,7 @@ async def test_object_create_file(db_session: AsyncSession):
# 创建根目录
root = Object(
name=user.username,
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,
@@ -81,7 +81,6 @@ async def test_object_create_file(db_session: AsyncSession):
owner_id=user.id,
policy_id=policy.id,
size=1024,
source_name="test_source.txt"
)
file = await file.save(db_session)
@@ -89,18 +88,17 @@ async def test_object_create_file(db_session: AsyncSession):
assert file.name == "test.txt"
assert file.type == ObjectType.FILE
assert file.size == 1024
assert file.source_name == "test_source.txt"
@pytest.mark.asyncio
async def test_object_is_file_property(db_session: AsyncSession):
"""测试 is_file 属性"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="testuser", password="password", group_id=group.id)
user = User(email="testuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -122,12 +120,12 @@ async def test_object_is_file_property(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_is_folder_property(db_session: AsyncSession):
"""测试 is_folder 属性"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="testuser", password="password", group_id=group.id)
user = User(email="testuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -148,12 +146,12 @@ async def test_object_is_folder_property(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_get_root(db_session: AsyncSession):
"""测试 get_root() 方法"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="rootuser", password="password", group_id=group.id)
user = User(email="rootuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -161,7 +159,7 @@ async def test_object_get_root(db_session: AsyncSession):
# 创建根目录
root = Object(
name=user.username,
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,
@@ -180,12 +178,12 @@ async def test_object_get_root(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_get_by_path_root(db_session: AsyncSession):
"""测试获取根目录"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="pathuser", password="password", group_id=group.id)
user = User(email="pathuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -193,7 +191,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
# 创建根目录
root = Object(
name=user.username,
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,
@@ -202,7 +200,7 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
root = await root.save(db_session)
# 通过路径获取根目录
result = await Object.get_by_path(db_session, user.id, "/pathuser", user.username)
result = await Object.get_by_path(db_session, user.id, "/")
assert result is not None
assert result.id == root.id
@@ -211,12 +209,12 @@ async def test_object_get_by_path_root(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_get_by_path_nested(db_session: AsyncSession):
"""测试获取嵌套路径"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="nesteduser", password="password", group_id=group.id)
user = User(email="nesteduser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -224,7 +222,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
# 创建目录结构: root -> docs -> work -> project
root = Object(
name=user.username,
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,
@@ -263,8 +261,7 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
result = await Object.get_by_path(
db_session,
user.id,
"/nesteduser/docs/work/project",
user.username
"/docs/work/project",
)
assert result is not None
@@ -275,12 +272,12 @@ async def test_object_get_by_path_nested(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_get_by_path_not_found(db_session: AsyncSession):
"""测试路径不存在"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="notfounduser", password="password", group_id=group.id)
user = User(email="notfounduser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -288,7 +285,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
# 创建根目录
root = Object(
name=user.username,
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,
@@ -300,8 +297,7 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
result = await Object.get_by_path(
db_session,
user.id,
"/notfounduser/nonexistent",
user.username
"/nonexistent",
)
assert result is None
@@ -310,12 +306,12 @@ async def test_object_get_by_path_not_found(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_get_children(db_session: AsyncSession):
"""测试 get_children() 方法"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="childrenuser", password="password", group_id=group.id)
user = User(email="childrenuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -362,12 +358,12 @@ async def test_object_get_children(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_parent_child_relationship(db_session: AsyncSession):
"""测试父子关系"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="reluser", password="password", group_id=group.id)
user = User(email="reluser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -407,12 +403,12 @@ async def test_object_parent_child_relationship(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_object_unique_constraint(db_session: AsyncSession):
"""测试同目录名称唯一约束"""
from models.policy import Policy, PolicyType
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(username="uniqueuser", password="password", group_id=group.id)
user = User(email="uniqueuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
@@ -450,3 +446,64 @@ async def test_object_unique_constraint(db_session: AsyncSession):
with pytest.raises(IntegrityError):
await file2.save(db_session)
@pytest.mark.asyncio
async def test_object_get_full_path(db_session: AsyncSession):
"""测试 get_full_path() 方法"""
from sqlmodels.policy import Policy, PolicyType
group = Group(name="测试组")
group = await group.save(db_session)
user = User(email="pathuser", password="password", group_id=group.id)
user = await user.save(db_session)
policy = Policy(name="本地策略", type=PolicyType.LOCAL, server="/tmp/test")
policy = await policy.save(db_session)
# 创建目录结构: root -> docs -> images -> photo.jpg
root = Object(
name="/",
type=ObjectType.FOLDER,
parent_id=None,
owner_id=user.id,
policy_id=policy.id
)
root = await root.save(db_session)
docs = Object(
name="docs",
type=ObjectType.FOLDER,
parent_id=root.id,
owner_id=user.id,
policy_id=policy.id
)
docs = await docs.save(db_session)
images = Object(
name="images",
type=ObjectType.FOLDER,
parent_id=docs.id,
owner_id=user.id,
policy_id=policy.id
)
images = await images.save(db_session)
photo = Object(
name="photo.jpg",
type=ObjectType.FILE,
parent_id=images.id,
owner_id=user.id,
policy_id=policy.id,
size=2048
)
photo = await photo.save(db_session)
# 测试完整路径
full_path = await photo.get_full_path(db_session)
assert full_path == "/docs/images/photo.jpg"
# 测试根目录的 full_path
root_path = await root.get_full_path(db_session)
assert root_path == "/"

View File

@@ -5,7 +5,7 @@ import pytest
from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession
from models.setting import Setting, SettingsType
from sqlmodels.setting import Setting, SettingsType
@pytest.mark.asyncio
@@ -113,7 +113,7 @@ async def test_setting_update_value(db_session: AsyncSession):
setting = await setting.save(db_session)
# 更新值
from models.base import SQLModelBase
from sqlmodels.base import SQLModelBase
class SettingUpdate(SQLModelBase):
value: str | None = None

View File

@@ -0,0 +1,273 @@
"""
DiskNextURI 模型的单元测试
"""
import pytest
from sqlmodels.uri import DiskNextURI, FileSystemNamespace
class TestDiskNextURIParse:
"""测试 URI 解析"""
def test_parse_my_root(self):
"""测试解析个人空间根目录"""
uri = DiskNextURI.parse("disknext://my/")
assert uri.namespace == FileSystemNamespace.MY
assert uri.path == "/"
assert uri.fs_id is None
assert uri.password is None
assert uri.is_root is True
def test_parse_my_with_path(self):
"""测试解析个人空间带路径"""
uri = DiskNextURI.parse("disknext://my/docs/readme.md")
assert uri.namespace == FileSystemNamespace.MY
assert uri.path == "/docs/readme.md"
assert uri.fs_id is None
assert uri.path_parts == ["docs", "readme.md"]
assert uri.is_root is False
def test_parse_my_with_fs_id(self):
"""测试解析带 fs_id 的个人空间"""
uri = DiskNextURI.parse("disknext://some-uuid@my/docs")
assert uri.namespace == FileSystemNamespace.MY
assert uri.fs_id == "some-uuid"
assert uri.path == "/docs"
def test_parse_share_with_code(self):
"""测试解析分享链接"""
uri = DiskNextURI.parse("disknext://abc123@share/")
assert uri.namespace == FileSystemNamespace.SHARE
assert uri.fs_id == "abc123"
assert uri.path == "/"
assert uri.password is None
def test_parse_share_with_password(self):
"""测试解析带密码的分享链接"""
uri = DiskNextURI.parse("disknext://abc123:mypass@share/sub/dir")
assert uri.namespace == FileSystemNamespace.SHARE
assert uri.fs_id == "abc123"
assert uri.password == "mypass"
assert uri.path == "/sub/dir"
def test_parse_trash(self):
"""测试解析回收站"""
uri = DiskNextURI.parse("disknext://trash/")
assert uri.namespace == FileSystemNamespace.TRASH
assert uri.is_root is True
def test_parse_with_query(self):
"""测试解析带查询参数的 URI"""
uri = DiskNextURI.parse("disknext://my/?name=report&type=file")
assert uri.namespace == FileSystemNamespace.MY
assert uri.query is not None
assert uri.query["name"] == "report"
assert uri.query["type"] == "file"
def test_parse_invalid_scheme(self):
"""测试无效的协议前缀"""
with pytest.raises(ValueError, match="disknext://"):
DiskNextURI.parse("http://my/docs")
def test_parse_invalid_namespace(self):
"""测试无效的命名空间"""
with pytest.raises(ValueError, match="无效的命名空间"):
DiskNextURI.parse("disknext://invalid/docs")
def test_parse_no_namespace(self):
"""测试缺少命名空间"""
with pytest.raises(ValueError):
DiskNextURI.parse("disknext://")
class TestDiskNextURIBuild:
"""测试 URI 构建"""
def test_build_simple(self):
"""测试简单构建"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
assert uri.namespace == FileSystemNamespace.MY
assert uri.path == "/"
assert uri.fs_id is None
def test_build_with_path(self):
"""测试带路径构建"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
assert uri.path == "/docs/readme.md"
def test_build_path_auto_prefix(self):
"""测试路径自动添加 / 前缀"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="docs/readme.md")
assert uri.path == "/docs/readme.md"
def test_build_with_fs_id(self):
"""测试带 fs_id 构建"""
uri = DiskNextURI.build(
FileSystemNamespace.SHARE,
fs_id="abc123",
password="secret",
)
assert uri.fs_id == "abc123"
assert uri.password == "secret"
class TestDiskNextURIToString:
"""测试 URI 序列化"""
def test_to_string_simple(self):
"""测试简单序列化"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
assert uri.to_string() == "disknext://my/"
def test_to_string_with_path(self):
"""测试带路径序列化"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
assert uri.to_string() == "disknext://my/docs/readme.md"
def test_to_string_with_fs_id(self):
"""测试带 fs_id 序列化"""
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="uuid-123")
assert uri.to_string() == "disknext://uuid-123@my/"
def test_to_string_with_password(self):
"""测试带密码序列化"""
uri = DiskNextURI.build(
FileSystemNamespace.SHARE,
fs_id="code",
password="pass",
)
assert uri.to_string() == "disknext://code:pass@share/"
def test_to_string_roundtrip(self):
"""测试序列化-反序列化往返"""
original = "disknext://abc123:pass@share/sub/dir"
uri = DiskNextURI.parse(original)
result = uri.to_string()
assert result == original
class TestDiskNextURIId:
"""测试 id() 方法"""
def test_id_with_fs_id(self):
"""测试有 fs_id 时返回 fs_id"""
uri = DiskNextURI.build(FileSystemNamespace.MY, fs_id="my-uuid")
assert uri.id("default") == "my-uuid"
def test_id_without_fs_id(self):
"""测试无 fs_id 时返回默认值"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
assert uri.id("default-uuid") == "default-uuid"
def test_id_without_fs_id_no_default(self):
"""测试无 fs_id 且无默认值时返回 None"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
assert uri.id() is None
class TestDiskNextURIJoin:
"""测试 join() 方法"""
def test_join_single(self):
"""测试拼接单个路径元素"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
joined = uri.join("readme.md")
assert joined.path == "/docs/readme.md"
def test_join_multiple(self):
"""测试拼接多个路径元素"""
uri = DiskNextURI.build(FileSystemNamespace.MY)
joined = uri.join("docs", "work", "report.pdf")
assert joined.path == "/docs/work/report.pdf"
def test_join_preserves_metadata(self):
"""测试 join 保留 namespace 和 fs_id"""
uri = DiskNextURI.build(FileSystemNamespace.SHARE, fs_id="code123")
joined = uri.join("sub")
assert joined.namespace == FileSystemNamespace.SHARE
assert joined.fs_id == "code123"
class TestDiskNextURIDirUri:
"""测试 dir_uri() 方法"""
def test_dir_uri_file(self):
"""测试获取文件的父目录 URI"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
parent = uri.dir_uri()
assert parent.path == "/docs/"
def test_dir_uri_root(self):
"""测试根目录的 dir_uri 返回自身"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
parent = uri.dir_uri()
assert parent.path == "/"
class TestDiskNextURIRoot:
"""测试 root() 方法"""
def test_root_resets_path(self):
"""测试 root 重置路径"""
uri = DiskNextURI.build(
FileSystemNamespace.MY,
path="/docs/work/report.pdf",
fs_id="uuid-123",
)
root = uri.root()
assert root.path == "/"
assert root.fs_id == "uuid-123"
assert root.namespace == FileSystemNamespace.MY
class TestDiskNextURIName:
"""测试 name() 方法"""
def test_name_file(self):
"""测试获取文件名"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/readme.md")
assert uri.name() == "readme.md"
def test_name_directory(self):
"""测试获取目录名"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work")
assert uri.name() == "work"
def test_name_root(self):
"""测试根目录的 name 返回空字符串"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
assert uri.name() == ""
class TestDiskNextURIProperties:
"""测试属性方法"""
def test_path_parts(self):
"""测试路径分割"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs/work/report.pdf")
assert uri.path_parts == ["docs", "work", "report.pdf"]
def test_path_parts_root(self):
"""测试根路径分割"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
assert uri.path_parts == []
def test_is_root_true(self):
"""测试 is_root 为真"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/")
assert uri.is_root is True
def test_is_root_false(self):
"""测试 is_root 为假"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
assert uri.is_root is False
def test_str_representation(self):
"""测试字符串表示"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
assert str(uri) == "disknext://my/docs"
def test_repr(self):
"""测试 repr"""
uri = DiskNextURI.build(FileSystemNamespace.MY, path="/docs")
assert "disknext://my/docs" in repr(uri)

View File

@@ -5,8 +5,8 @@ import pytest
from sqlalchemy.exc import IntegrityError
from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User, ThemeType, UserPublic
from models.group import Group
from sqlmodels.user import User, ThemeType, UserPublic
from sqlmodels.group import Group
@pytest.mark.asyncio
@@ -18,7 +18,7 @@ async def test_user_create(db_session: AsyncSession):
# 创建用户
user = User(
username="testuser",
email="testuser@test.local",
nickname="测试用户",
password="hashed_password",
group_id=group.id
@@ -26,7 +26,7 @@ async def test_user_create(db_session: AsyncSession):
user = await user.save(db_session)
assert user.id is not None
assert user.username == "testuser"
assert user.email == "testuser@test.local"
assert user.nickname == "测试用户"
assert user.status is True
assert user.storage == 0
@@ -34,15 +34,15 @@ async def test_user_create(db_session: AsyncSession):
@pytest.mark.asyncio
async def test_user_unique_username(db_session: AsyncSession):
"""测试用户名唯一约束"""
async def test_user_unique_email(db_session: AsyncSession):
"""测试邮箱唯一约束"""
# 创建用户组
group = Group(name="默认组")
group = await group.save(db_session)
# 创建第一个用户
user1 = User(
username="duplicate",
email="duplicate@test.local",
password="password1",
group_id=group.id
)
@@ -50,7 +50,7 @@ async def test_user_unique_username(db_session: AsyncSession):
# 尝试创建同名用户
user2 = User(
username="duplicate",
email="duplicate@test.local",
password="password2",
group_id=group.id
)
@@ -68,7 +68,7 @@ async def test_user_to_public(db_session: AsyncSession):
# 创建用户
user = User(
username="publicuser",
email="publicuser@test.local",
nickname="公开用户",
password="secret_password",
storage=1024,
@@ -82,7 +82,7 @@ async def test_user_to_public(db_session: AsyncSession):
assert isinstance(public_user, UserPublic)
assert public_user.id == user.id
assert public_user.username == "publicuser"
assert public_user.email == "publicuser@test.local"
# 注意: UserPublic.nick 字段名与 User.nickname 不同,
# model_validate 不会自动映射,所以 nick 为 None
# 这是已知的设计问题,需要在 UserPublic 中添加别名或重命名字段
@@ -101,7 +101,7 @@ async def test_user_group_relationship(db_session: AsyncSession):
# 创建用户
user = User(
username="vipuser",
email="vipuser@test.local",
password="password",
group_id=group.id
)
@@ -125,7 +125,7 @@ async def test_user_status_default(db_session: AsyncSession):
group = await group.save(db_session)
user = User(
username="defaultuser",
email="defaultuser@test.local",
password="password",
group_id=group.id
)
@@ -141,7 +141,7 @@ async def test_user_storage_default(db_session: AsyncSession):
group = await group.save(db_session)
user = User(
username="storageuser",
email="storageuser@test.local",
password="password",
group_id=group.id
)
@@ -158,7 +158,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试默认值
user1 = User(
username="user1",
email="user1@test.local",
password="password",
group_id=group.id
)
@@ -167,7 +167,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试设置为 LIGHT
user2 = User(
username="user2",
email="user2@test.local",
password="password",
theme=ThemeType.LIGHT,
group_id=group.id
@@ -177,7 +177,7 @@ async def test_user_theme_enum(db_session: AsyncSession):
# 测试设置为 DARK
user3 = User(
username="user3",
email="user3@test.local",
password="password",
theme=ThemeType.DARK,
group_id=group.id

View File

@@ -4,8 +4,8 @@ Login 服务的单元测试
import pytest
from sqlmodel.ext.asyncio.session import AsyncSession
from models.user import User, LoginRequest, TokenResponse
from models.group import Group
from sqlmodels.user import User, LoginRequest, TokenResponse
from sqlmodels.group import Group
from service.user.login import login
from utils.password.pwd import Password
@@ -20,7 +20,7 @@ async def setup_user(db_session: AsyncSession):
# 创建正常用户
plain_password = "secure_password_123"
user = User(
username="loginuser",
email="loginuser@test.local",
password=Password.hash(plain_password),
status=True,
group_id=group.id
@@ -41,7 +41,7 @@ async def setup_banned_user(db_session: AsyncSession):
group = await group.save(db_session)
user = User(
username="banneduser",
email="banneduser@test.local",
password=Password.hash("password"),
status=False, # 封禁状态
group_id=group.id
@@ -61,7 +61,7 @@ async def setup_2fa_user(db_session: AsyncSession):
secret = pyotp.random_base32()
user = User(
username="2fauser",
email="2fauser@test.local",
password=Password.hash("password"),
status=True,
two_factor=secret,
@@ -82,7 +82,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
user_data = setup_user
login_request = LoginRequest(
username="loginuser",
email="loginuser@test.local",
password=user_data["password"]
)
@@ -99,7 +99,7 @@ async def test_login_success(db_session: AsyncSession, setup_user):
async def test_login_user_not_found(db_session: AsyncSession):
"""测试用户不存在"""
login_request = LoginRequest(
username="nonexistent_user",
email="nonexistent@test.local",
password="any_password"
)
@@ -112,7 +112,7 @@ async def test_login_user_not_found(db_session: AsyncSession):
async def test_login_wrong_password(db_session: AsyncSession, setup_user):
"""测试密码错误"""
login_request = LoginRequest(
username="loginuser",
email="loginuser@test.local",
password="wrong_password"
)
@@ -125,7 +125,7 @@ async def test_login_wrong_password(db_session: AsyncSession, setup_user):
async def test_login_user_banned(db_session: AsyncSession, setup_banned_user):
"""测试用户被封禁"""
login_request = LoginRequest(
username="banneduser",
email="banneduser@test.local",
password="password"
)
@@ -140,7 +140,7 @@ async def test_login_2fa_required(db_session: AsyncSession, setup_2fa_user):
user_data = setup_2fa_user
login_request = LoginRequest(
username="2fauser",
email="2fauser@test.local",
password=user_data["password"]
# 未提供 two_fa_code
)
@@ -156,7 +156,7 @@ async def test_login_2fa_invalid(db_session: AsyncSession, setup_2fa_user):
user_data = setup_2fa_user
login_request = LoginRequest(
username="2fauser",
email="2fauser@test.local",
password=user_data["password"],
two_fa_code="000000" # 错误的验证码
)
@@ -179,7 +179,7 @@ async def test_login_2fa_success(db_session: AsyncSession, setup_2fa_user):
valid_code = totp.now()
login_request = LoginRequest(
username="2fauser",
email="2fauser@test.local",
password=user_data["password"],
two_fa_code=valid_code
)
@@ -198,7 +198,7 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
user_data = setup_user
login_request = LoginRequest(
username="loginuser",
email="loginuser@test.local",
password=user_data["password"]
)
@@ -217,17 +217,17 @@ async def test_login_returns_valid_tokens(db_session: AsyncSession, setup_user):
@pytest.mark.asyncio
async def test_login_case_sensitive_username(db_session: AsyncSession, setup_user):
"""测试用户名大小写敏感"""
async def test_login_case_sensitive_email(db_session: AsyncSession, setup_user):
"""测试邮箱大小写敏感"""
user_data = setup_user
# 使用大写用户名登录(如果数据库是 loginuser
# 使用大写邮箱登录
login_request = LoginRequest(
username="LOGINUSER",
email="LOGINUSER@TEST.LOCAL",
password=user_data["password"]
)
result = await login(db_session, login_request)
# 应该失败,因为用户名大小写不匹配
# 应该失败,因为邮箱大小写不匹配
assert result is None

View File

@@ -72,9 +72,9 @@ def test_password_verify_expired():
@pytest.mark.asyncio
async def test_totp_generate():
"""测试 TOTP 密钥生成"""
username = "testuser"
email = "testuser@test.local"
response = await Password.generate_totp(username)
response = await Password.generate_totp(email)
assert response.setup_token is not None
assert response.uri is not None
@@ -82,7 +82,7 @@ async def test_totp_generate():
assert isinstance(response.uri, str)
# TOTP URI 格式: otpauth://totp/...
assert response.uri.startswith("otpauth://totp/")
assert username in response.uri
assert email in response.uri
def test_totp_verify_valid():