Files
pixel/backend/tests/test_base_repository.py
张鹏 f9f4560459 Initial commit: Pixel AI comic/video creation platform
- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents
- Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow
- Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.)
- All HTTP clients migrated from sync requests to async httpx
- Admin-managed API keys via environment variables
- SSRF vulnerability fixed in ensure_url()
2026-04-29 01:20:12 +08:00

349 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
测试 BaseRepository 功能
验证通用仓储模式的 CRUD 操作、过滤、排序和分页功能
"""
import pytest
from datetime import datetime
from sqlmodel import Session, create_engine, SQLModel
from sqlalchemy.pool import StaticPool
from src.models.entities import ProjectDB, TaskDB
from src.repositories.base_repository import BaseRepository
from src.repositories.task_repository import TaskRepository
@pytest.fixture
def engine():
    """Provide a fresh in-memory SQLite engine with all SQLModel tables created.

    ``StaticPool`` together with ``check_same_thread=False`` makes every
    session reuse the same single connection, so the in-memory database (and
    the schema created here) is shared by the ``session`` fixture.

    The engine is disposed on teardown so its pooled connection is released
    deterministically rather than whenever the object is garbage-collected.
    """
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(engine)
    yield engine
    engine.dispose()
@pytest.fixture
def session(engine):
    """Yield a SQLModel session bound to the test engine; closed on teardown."""
    with Session(engine) as db_session:
        yield db_session
@pytest.fixture
def task_repository(session):
    """Build the TaskRepository under test on top of the per-test session."""
    repository = TaskRepository(session)
    return repository
@pytest.fixture
def sample_tasks(session):
    """Insert ten TaskDB rows (task_0 .. task_9), commit, and return them.

    Layout that the tests in this module rely on:
      * type:       "image" for even i, "video" for odd i   -> 5 of each
      * status:     "pending" for i < 3, else "processing"  -> 3 pending, 7 processing
      * user_id:    "user_1" for i < 5, else "user_2"       -> 5 per user
      * project_id: always "project_1"
      * model:      always "flux-dev"
      * created_at / updated_at: one shared base timestamp + i, so rows are
        strictly ordered by i and each row's two timestamps are equal.
    """
    # Read the clock once so the per-row offsets are exact and both timestamp
    # columns of a row agree.
    base_ts = datetime.now().timestamp()
    tasks = [
        TaskDB(
            id=f"task_{i}",
            type="image" if i % 2 == 0 else "video",
            status="pending" if i < 3 else "processing",
            model="flux-dev",
            params={"prompt": f"test prompt {i}"},
            user_id="user_1" if i < 5 else "user_2",
            project_id="project_1",
            created_at=base_ts + i,
            updated_at=base_ts + i,
        )
        for i in range(10)
    ]
    session.add_all(tasks)
    session.commit()
    return tasks
class TestBaseRepositoryCRUD:
    """Basic create / read / update / delete behaviour of the repository."""

    def test_create(self, task_repository, session):
        """create() returns the new entity, and the row can be read back."""
        task = TaskDB(
            id="test_task",
            type="image",
            status="pending",
            model="flux-dev",
            params={"prompt": "test"},
            created_at=datetime.now().timestamp(),
            updated_at=datetime.now().timestamp(),
        )
        created = task_repository.create(task)
        assert created.id == "test_task"
        assert created.type == "image"
        assert created.status == "pending"
        # The returned object alone does not prove the row was stored;
        # fetch it back through the repository.
        assert task_repository.get("test_task") is not None

    def test_get(self, task_repository, sample_tasks):
        """get() returns the record with the given primary key."""
        task = task_repository.get("task_0")
        assert task is not None
        assert task.id == "task_0"
        assert task.type == "image"

    def test_get_not_found(self, task_repository):
        """get() returns None for an unknown id."""
        task = task_repository.get("nonexistent")
        assert task is None

    def test_get_by_field(self, task_repository, sample_tasks):
        """get_by_field() returns a record whose field equals the value."""
        task = task_repository.get_by_field("user_id", "user_1")
        assert task is not None
        assert task.user_id == "user_1"

    def test_update(self, task_repository, sample_tasks):
        """update() applies attribute changes made on a loaded entity."""
        task = task_repository.get("task_0")
        # Fail with a clear assertion rather than an AttributeError on None.
        assert task is not None
        task.status = "completed"
        updated = task_repository.update(task)
        assert updated.status == "completed"
        # The change must be visible when re-read through the repository.
        retrieved = task_repository.get("task_0")
        assert retrieved is not None
        assert retrieved.status == "completed"

    def test_update_by_id(self, task_repository, sample_tasks):
        """update_by_id() applies a dict of field values to the record."""
        updated = task_repository.update_by_id("task_0", {"status": "failed"})
        assert updated is not None
        assert updated.status == "failed"
        # The change must be visible when re-read through the repository.
        retrieved = task_repository.get("task_0")
        assert retrieved is not None
        assert retrieved.status == "failed"

    def test_delete(self, task_repository, sample_tasks):
        """delete() removes the record and reports success."""
        result = task_repository.delete("task_0")
        assert result is True
        # The record must no longer be retrievable.
        task = task_repository.get("task_0")
        assert task is None

    def test_delete_not_found(self, task_repository):
        """delete() reports False for an unknown id."""
        result = task_repository.delete("nonexistent")
        assert result is False

    def test_soft_delete(self, task_repository, sample_tasks):
        """soft_delete() marks the record by setting deleted_at."""
        result = task_repository.soft_delete("task_0")
        assert result is True
        # This test expects get() to still return soft-deleted rows.
        task = task_repository.get("task_0")
        assert task is not None
        assert task.deleted_at is not None

    def test_exists(self, task_repository, sample_tasks):
        """exists() reports whether a primary key is present."""
        assert task_repository.exists("task_0") is True
        assert task_repository.exists("nonexistent") is False

    def test_exists_by_field(self, task_repository, sample_tasks):
        """exists_by_field() reports whether any record has the field value."""
        assert task_repository.exists_by_field("user_id", "user_1") is True
        assert task_repository.exists_by_field("user_id", "nonexistent") is False
class TestBaseRepositoryFiltering:
    """Filtering through the ``filters`` argument of list()."""

    def test_list_with_simple_filter(self, task_repository, sample_tasks):
        """A single equality filter returns only matching rows."""
        tasks = task_repository.list(filters={"type": "image"})
        assert len(tasks) == 5
        assert all(task.type == "image" for task in tasks)

    def test_list_with_multiple_filters(self, task_repository, sample_tasks):
        """Several filters must all match (the expected result is their intersection)."""
        tasks = task_repository.list(filters={
            "type": "image",
            "status": "pending",
        })
        assert len(tasks) == 2  # task_0 and task_2
        assert all(task.type == "image" and task.status == "pending" for task in tasks)

    def test_list_with_gt_filter(self, task_repository, sample_tasks):
        """Equality filter on user_id.

        NOTE(review): despite its name, this test does not exercise a ``__gt``
        operator. It only checks that filtering on user_id returns the five
        user_1 rows. TODO: add real ``__gt`` coverage (e.g. on created_at)
        once the repository's supported operator set is confirmed.
        """
        tasks = task_repository.list(filters={"user_id": "user_1"})
        assert len(tasks) == 5
        assert all(task.user_id == "user_1" for task in tasks)

    def test_list_with_in_filter(self, task_repository, sample_tasks):
        """``__in`` keeps only rows whose value is one of the listed values."""
        # Every fixture row is either "pending" or "processing", so listing both
        # would match all 10 rows and could not detect a broken IN filter.
        # "completed" matches nothing, leaving only the 3 pending tasks.
        tasks = task_repository.list(filters={
            "status__in": ["pending", "completed"]
        })
        assert len(tasks) == 3
        assert all(task.status == "pending" for task in tasks)

    def test_list_with_like_filter(self, task_repository, sample_tasks):
        """``__like`` applies SQL LIKE pattern matching."""
        # Note: SQLite's LIKE is case-insensitive for ASCII characters by default.
        tasks = task_repository.list(filters={"model__like": "%flux%"})
        assert len(tasks) == 10
        # Every fixture row uses "flux-dev", so the positive case alone cannot
        # tell a working filter from a no-op; a non-matching pattern must
        # return nothing.
        no_match = task_repository.list(filters={"model__like": "%no-such-model%"})
        assert len(no_match) == 0
class TestBaseRepositorySorting:
    """Ordering through sort_by / sort_order."""

    def test_list_with_asc_sort(self, task_repository, sample_tasks):
        """sort_order="asc" returns rows in non-decreasing created_at order."""
        tasks = task_repository.list(sort_by="created_at", sort_order="asc")
        # Without this guard, the pairwise check below passes trivially on an
        # empty or single-row result.
        assert len(tasks) == len(sample_tasks)
        for earlier, later in zip(tasks, tasks[1:]):
            assert earlier.created_at <= later.created_at

    def test_list_with_desc_sort(self, task_repository, sample_tasks):
        """sort_order="desc" returns rows in non-increasing created_at order."""
        tasks = task_repository.list(sort_by="created_at", sort_order="desc")
        # Without this guard, the pairwise check below passes trivially on an
        # empty or single-row result.
        assert len(tasks) == len(sample_tasks)
        for earlier, later in zip(tasks, tasks[1:]):
            assert earlier.created_at >= later.created_at
class TestBaseRepositoryPagination:
    """Pagination through skip/limit, list_paginated(), and count()."""

    def test_list_with_pagination(self, task_repository, sample_tasks):
        """Consecutive skip/limit windows return full, non-overlapping pages."""
        # Sort explicitly: without ORDER BY, SQL does not guarantee that two
        # separate LIMIT/OFFSET queries see rows in the same order, so the
        # disjointness check below would depend on unspecified behaviour.
        page1 = task_repository.list(
            skip=0, limit=3, sort_by="created_at", sort_order="asc"
        )
        assert len(page1) == 3
        page2 = task_repository.list(
            skip=3, limit=3, sort_by="created_at", sort_order="asc"
        )
        assert len(page2) == 3
        # The two pages must not share any record.
        page1_ids = {task.id for task in page1}
        page2_ids = {task.id for task in page2}
        assert not page1_ids & page2_ids

    def test_list_paginated(self, task_repository, sample_tasks):
        """list_paginated() returns (records, total) for a 1-based page number."""
        records, total = task_repository.list_paginated(page=1, page_size=3)
        assert len(records) == 3
        assert total == 10
        # Second page: still a full page, and the total is unchanged.
        records, total = task_repository.list_paginated(page=2, page_size=3)
        assert len(records) == 3
        assert total == 10

    def test_count(self, task_repository, sample_tasks):
        """count() returns the number of rows, optionally filtered."""
        total = task_repository.count()
        assert total == 10
        # Count with a filter applied.
        count = task_repository.count(filters={"type": "image"})
        assert count == 5
class TestBaseRepositoryBatchOperations:
    """Bulk insertion through create_many()."""

    def test_create_many(self, task_repository, session):
        """create_many() returns every entity, and each can be fetched by id."""

        def build(index):
            # One minimal image task per index.
            return TaskDB(
                id=f"batch_task_{index}",
                type="image",
                status="pending",
                model="flux-dev",
                params={},
                created_at=datetime.now().timestamp(),
                updated_at=datetime.now().timestamp(),
            )

        batch = [build(index) for index in range(5)]
        created = task_repository.create_many(batch)
        assert len(created) == 5

        # Every returned entity must be retrievable from the repository.
        missing = [item.id for item in created if task_repository.get(item.id) is None]
        assert not missing
class TestTaskRepositorySpecific:
    """Convenience query methods that only TaskRepository provides."""

    def test_list_by_status(self, task_repository, sample_tasks):
        """list_by_status() returns exactly the tasks with that status."""
        pending = task_repository.list_by_status("pending")
        assert len(pending) == 3
        assert {item.status for item in pending} == {"pending"}

    def test_list_by_user(self, task_repository, sample_tasks):
        """list_by_user() returns exactly the tasks owned by that user."""
        owned = task_repository.list_by_user("user_1")
        assert len(owned) == 5
        assert {item.user_id for item in owned} == {"user_1"}

    def test_list_by_project(self, task_repository, sample_tasks):
        """list_by_project() returns exactly the tasks in that project."""
        in_project = task_repository.list_by_project("project_1")
        assert len(in_project) == 10
        assert {item.project_id for item in in_project} == {"project_1"}

    def test_count_by_status(self, task_repository, sample_tasks):
        """count_by_status() counts the tasks with the given status."""
        assert task_repository.count_by_status("pending") == 3

    def test_count_by_user(self, task_repository, sample_tasks):
        """count_by_user() counts the tasks owned by the given user."""
        assert task_repository.count_by_user("user_1") == 5
class TestQueryPerformanceTracking:
    """Query statistics collected by the repository."""

    def test_query_stats_tracking(self, task_repository, sample_tasks):
        """get_query_stats() reflects the queries issued through the repository."""
        # Issue three different kinds of query: list, get, and count.
        task_repository.list(limit=5)
        task_repository.get("task_0")
        task_repository.count()

        stats = task_repository.get_query_stats()
        assert stats["query_count"] >= 3
        # NOTE(review): these strict "> 0" checks assume the repository's timer
        # can resolve very fast in-memory queries; a coarse clock could make
        # them flaky. Confirm against the timing implementation.
        for key in ("total_query_time", "avg_query_time"):
            assert stats[key] > 0