- FastAPI backend with SQLModel, Alembic migrations, and AgentScope agents
- Next.js 15 frontend with React 19, Tailwind, Zustand, and React Flow
- Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.)
- All HTTP clients migrated from synchronous `requests` to async `httpx`
- Admin-managed API keys via environment variables
- SSRF vulnerability fixed in `ensure_url()`
349 lines · 11 KiB · Python
"""
|
||
测试 BaseRepository 功能
|
||
|
||
验证通用仓储模式的 CRUD 操作、过滤、排序和分页功能
|
||
"""
|
||
|
||
import pytest
|
||
from datetime import datetime
|
||
from sqlmodel import Session, create_engine, SQLModel
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from src.models.entities import ProjectDB, TaskDB
|
||
from src.repositories.base_repository import BaseRepository
|
||
from src.repositories.task_repository import TaskRepository
|
||
|
||
|
||
@pytest.fixture
def engine():
    """Provide an in-memory SQLite engine with every SQLModel table created.

    ``StaticPool`` together with ``check_same_thread=False`` makes the engine
    reuse a single connection, so the schema created here is the one later
    sessions in the same test see. The engine is disposed on teardown so the
    pooled connection is released instead of lingering for the rest of the
    test session.
    """
    test_engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(test_engine)
    yield test_engine
    # Teardown: release the pooled connection (and the in-memory database).
    test_engine.dispose()
|
||
|
||
|
||
@pytest.fixture
def session(engine):
    """Yield a SQLModel ``Session`` bound to the test engine.

    The session is closed during fixture teardown.
    """
    db_session = Session(engine)
    try:
        yield db_session
    finally:
        db_session.close()
|
||
|
||
|
||
@pytest.fixture
def task_repository(session):
    """Build the ``TaskRepository`` under test on top of the per-test session."""
    repository = TaskRepository(session)
    return repository
|
||
|
||
|
||
@pytest.fixture
def sample_tasks(session):
    """Insert ten ``TaskDB`` rows, commit them, and return the list.

    Layout, for index ``i`` in ``range(10)`` with id ``f"task_{i}"``:

    * ``type``: ``"image"`` for even ``i``, ``"video"`` for odd ``i`` (5 each)
    * ``status``: ``"pending"`` for ``i < 3``, otherwise ``"processing"``
    * ``user_id``: ``"user_1"`` for ``i < 5``, otherwise ``"user_2"``
    * ``project_id``: always ``"project_1"``
    * ``created_at`` / ``updated_at``: one shared base timestamp plus ``i``
      seconds, so ``created_at`` strictly increases with ``i``.
    """
    # Read the clock once so every row is offset from the same instant:
    # created_at == updated_at per row and rows are exactly 1 s apart.
    base_ts = datetime.now().timestamp()
    tasks = [
        TaskDB(
            id=f"task_{i}",
            type="image" if i % 2 == 0 else "video",
            status="pending" if i < 3 else "processing",
            model="flux-dev",
            params={"prompt": f"test prompt {i}"},
            user_id="user_1" if i < 5 else "user_2",
            project_id="project_1",
            created_at=base_ts + i,
            updated_at=base_ts + i,
        )
        for i in range(10)
    ]

    session.add_all(tasks)
    session.commit()

    return tasks
|
||
|
||
|
||
class TestBaseRepositoryCRUD:
    """Basic CRUD operations: create, get, update, delete, soft delete, exists."""

    def test_create(self, task_repository, session):
        """create() returns the stored entity, which can then be fetched by id."""
        task = TaskDB(
            id="test_task",
            type="image",
            status="pending",
            model="flux-dev",
            params={"prompt": "test"},
            created_at=datetime.now().timestamp(),
            updated_at=datetime.now().timestamp(),
        )

        created = task_repository.create(task)

        assert created.id == "test_task"
        assert created.type == "image"
        assert created.status == "pending"
        # Like test_update / test_delete, confirm via a repository re-read.
        assert task_repository.get("test_task") is not None

    def test_get(self, task_repository, sample_tasks):
        """get() returns the record with the given primary key."""
        task = task_repository.get("task_0")

        assert task is not None
        assert task.id == "task_0"
        assert task.type == "image"

    def test_get_not_found(self, task_repository):
        """get() returns None for an id that does not exist."""
        task = task_repository.get("nonexistent")

        assert task is None

    def test_get_by_field(self, task_repository, sample_tasks):
        """get_by_field() returns a record whose field equals the given value."""
        task = task_repository.get_by_field("user_id", "user_1")

        assert task is not None
        assert task.user_id == "user_1"

    def test_update(self, task_repository, sample_tasks):
        """update() saves attribute changes made on a fetched entity."""
        task = task_repository.get("task_0")
        task.status = "completed"

        updated = task_repository.update(task)

        assert updated.status == "completed"

        # Verify the update is visible when the record is fetched again.
        retrieved = task_repository.get("task_0")
        assert retrieved.status == "completed"

    def test_update_by_id(self, task_repository, sample_tasks):
        """update_by_id() applies a dict of field values to the record with that id."""
        updated = task_repository.update_by_id("task_0", {"status": "failed"})

        assert updated is not None
        assert updated.status == "failed"

        # Mirror test_update: the change must also show up on a re-read.
        retrieved = task_repository.get("task_0")
        assert retrieved.status == "failed"

    def test_delete(self, task_repository, sample_tasks):
        """delete() returns True and the record can no longer be fetched."""
        result = task_repository.delete("task_0")

        assert result is True

        # Verify it is gone.
        task = task_repository.get("task_0")
        assert task is None

    def test_delete_not_found(self, task_repository):
        """delete() returns False for an id that does not exist."""
        result = task_repository.delete("nonexistent")

        assert result is False

    def test_soft_delete(self, task_repository, sample_tasks):
        """soft_delete() returns True and sets deleted_at on the record."""
        result = task_repository.soft_delete("task_0")

        assert result is True

        # The record is still returned by get(), with deleted_at populated.
        task = task_repository.get("task_0")
        assert task.deleted_at is not None

    def test_exists(self, task_repository, sample_tasks):
        """exists() reports whether a record with the given id is present."""
        assert task_repository.exists("task_0") is True
        assert task_repository.exists("nonexistent") is False

    def test_exists_by_field(self, task_repository, sample_tasks):
        """exists_by_field() reports whether any record has field == value."""
        assert task_repository.exists_by_field("user_id", "user_1") is True
        assert task_repository.exists_by_field("user_id", "nonexistent") is False
|
||
|
||
|
||
class TestBaseRepositoryFiltering:
    """Filtering through the ``filters`` argument of list()."""

    def test_list_with_simple_filter(self, task_repository, sample_tasks):
        """A plain ``field: value`` entry filters on equality."""
        tasks = task_repository.list(filters={"type": "image"})

        assert len(tasks) == 5
        assert all(task.type == "image" for task in tasks)

    def test_list_with_multiple_filters(self, task_repository, sample_tasks):
        """Several entries are combined: every condition must hold."""
        tasks = task_repository.list(filters={
            "type": "image",
            "status": "pending",
        })

        # Even index AND index < 3 -> task_0 and task_2.
        assert {task.id for task in tasks} == {"task_0", "task_2"}
        assert all(task.type == "image" and task.status == "pending" for task in tasks)

    def test_list_with_gt_filter(self, task_repository, sample_tasks):
        """``field__gt`` keeps only rows whose value is strictly greater."""
        # created_at strictly increases with the task index, so the rows newer
        # than task_4 are exactly task_5 .. task_9 (task_4 itself is excluded).
        threshold = task_repository.get("task_4").created_at

        tasks = task_repository.list(filters={"created_at__gt": threshold})

        assert {task.id for task in tasks} == {f"task_{i}" for i in range(5, 10)}
        assert all(task.created_at > threshold for task in tasks)

    def test_list_with_in_filter(self, task_repository, sample_tasks):
        """``field__in`` matches any value in the given list."""
        tasks = task_repository.list(filters={
            "status__in": ["pending", "processing"],
        })

        assert len(tasks) == 10

    def test_list_with_like_filter(self, task_repository, sample_tasks):
        """``field__like`` applies a SQL LIKE pattern."""
        # Note: SQLite's LIKE is case-insensitive.
        tasks = task_repository.list(filters={"model__like": "%flux%"})

        assert len(tasks) == 10
|
||
|
||
|
||
class TestBaseRepositorySorting:
    """Ordering via the ``sort_by`` / ``sort_order`` arguments of list()."""

    def test_list_with_asc_sort(self, task_repository, sample_tasks):
        """Ascending sort on created_at returns the oldest task first."""
        tasks = task_repository.list(sort_by="created_at", sort_order="asc")

        # Pin the full result: a bare pairwise check passes vacuously on an
        # empty list. created_at increases with the task index.
        assert [task.id for task in tasks] == [f"task_{i}" for i in range(10)]
        for earlier, later in zip(tasks, tasks[1:]):
            assert earlier.created_at <= later.created_at

    def test_list_with_desc_sort(self, task_repository, sample_tasks):
        """Descending sort on created_at returns the newest task first."""
        tasks = task_repository.list(sort_by="created_at", sort_order="desc")

        assert [task.id for task in tasks] == [f"task_{i}" for i in reversed(range(10))]
        for earlier, later in zip(tasks, tasks[1:]):
            assert earlier.created_at >= later.created_at
|
||
|
||
|
||
class TestBaseRepositoryPagination:
    """skip/limit pagination, list_paginated() and count()."""

    def test_list_with_pagination(self, task_repository, sample_tasks):
        """Consecutive skip/limit windows return disjoint, correctly-sized pages."""
        # SQL gives no row-order guarantee without ORDER BY, so sort explicitly;
        # otherwise the page boundaries (and the overlap check) are not stable.
        page1 = task_repository.list(skip=0, limit=3, sort_by="created_at", sort_order="asc")
        assert [task.id for task in page1] == ["task_0", "task_1", "task_2"]

        page2 = task_repository.list(skip=3, limit=3, sort_by="created_at", sort_order="asc")
        assert [task.id for task in page2] == ["task_3", "task_4", "task_5"]

        # The two pages must not share any record.
        page1_ids = {task.id for task in page1}
        page2_ids = {task.id for task in page2}
        assert not page1_ids & page2_ids

    def test_list_paginated(self, task_repository, sample_tasks):
        """list_paginated() returns one page of records plus the overall total."""
        records, total = task_repository.list_paginated(page=1, page_size=3)

        assert len(records) == 3
        assert total == 10

        # Second page.
        records, total = task_repository.list_paginated(page=2, page_size=3)
        assert len(records) == 3
        assert total == 10

    def test_count(self, task_repository, sample_tasks):
        """count() counts all records, or only those matching ``filters``."""
        total = task_repository.count()
        assert total == 10

        # Count with a filter.
        count = task_repository.count(filters={"type": "image"})
        assert count == 5
|
||
|
||
|
||
class TestBaseRepositoryBatchOperations:
    """Bulk write operations."""

    def test_create_many(self, task_repository, session):
        """create_many() returns every entity given, and each can be fetched by id."""
        batch = [
            TaskDB(
                id=f"batch_task_{n}",
                type="image",
                status="pending",
                model="flux-dev",
                params={},
                created_at=datetime.now().timestamp(),
                updated_at=datetime.now().timestamp(),
            )
            for n in range(5)
        ]

        created = task_repository.create_many(batch)
        assert len(created) == 5

        # Each returned entity must be retrievable through the repository.
        assert all(task_repository.get(entity.id) is not None for entity in created)
|
||
|
||
|
||
class TestTaskRepositorySpecific:
    """Task-specific convenience queries on TaskRepository."""

    def test_list_by_status(self, task_repository, sample_tasks):
        """list_by_status() returns only tasks in the requested status."""
        pending = task_repository.list_by_status("pending")

        assert len(pending) == 3
        assert {t.status for t in pending} == {"pending"}

    def test_list_by_user(self, task_repository, sample_tasks):
        """list_by_user() returns only tasks owned by the given user."""
        owned = task_repository.list_by_user("user_1")

        assert len(owned) == 5
        assert {t.user_id for t in owned} == {"user_1"}

    def test_list_by_project(self, task_repository, sample_tasks):
        """list_by_project() returns only tasks belonging to the given project."""
        in_project = task_repository.list_by_project("project_1")

        assert len(in_project) == 10
        assert {t.project_id for t in in_project} == {"project_1"}

    def test_count_by_status(self, task_repository, sample_tasks):
        """count_by_status() counts tasks in the requested status."""
        assert task_repository.count_by_status("pending") == 3

    def test_count_by_user(self, task_repository, sample_tasks):
        """count_by_user() counts tasks owned by the given user."""
        assert task_repository.count_by_user("user_1") == 5
|
||
|
||
|
||
class TestQueryPerformanceTracking:
    """Query statistics reported by the repository."""

    def test_query_stats_tracking(self, task_repository, sample_tasks):
        """get_query_stats() reflects the calls made and their accumulated timing."""
        # Issue three repository calls: list, get and count.
        task_repository.list(limit=5)
        task_repository.get("task_0")
        task_repository.count()

        stats = task_repository.get_query_stats()
        query_count = stats["query_count"]
        total_time = stats["total_query_time"]
        avg_time = stats["avg_query_time"]

        assert query_count >= 3
        # NOTE(review): the strict "> 0" timing checks assume a timer with
        # sub-millisecond resolution; confirm the repository does not use a
        # coarse clock, or these may flake for very fast in-memory queries.
        assert total_time > 0
        assert avg_time > 0
|