"""Tests for BaseRepository.

Verifies the generic repository pattern's CRUD operations, filtering,
sorting and pagination, exercised through the concrete TaskRepository.
"""

import pytest
from datetime import datetime
from sqlmodel import Session, create_engine, SQLModel
from sqlalchemy.pool import StaticPool

from src.models.entities import ProjectDB, TaskDB
from src.repositories.base_repository import BaseRepository
from src.repositories.task_repository import TaskRepository


@pytest.fixture
def engine():
    """Create an in-memory SQLite engine with all SQLModel tables created.

    StaticPool reuses a single connection, so every session in the test sees
    the same in-memory database. The engine is disposed on teardown so the
    pooled connection is released instead of leaking across tests.
    """
    test_engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    SQLModel.metadata.create_all(test_engine)
    yield test_engine
    test_engine.dispose()


@pytest.fixture
def session(engine):
    """Open a session bound to the test engine; it is closed on teardown."""
    with Session(engine) as db_session:
        yield db_session


@pytest.fixture
def task_repository(session):
    """Create a TaskRepository bound to the test session."""
    return TaskRepository(session)


@pytest.fixture
def sample_tasks(session):
    """Insert ten sample tasks and return them.

    Layout for i = 0..9 (id ``task_{i}``):
      * type: "image" for even i, "video" for odd i
      * status: "pending" for i < 3, "processing" otherwise
      * user_id: "user_1" for i < 5, "user_2" otherwise
      * project_id: "project_1" for every task
      * created_at / updated_at: one shared base timestamp + i
    """
    # A single base timestamp makes created_at strictly increase with i, so
    # ordering-based assertions (task_0 .. task_9) are deterministic.
    base_ts = datetime.now().timestamp()
    tasks = [
        TaskDB(
            id=f"task_{i}",
            type="image" if i % 2 == 0 else "video",
            status="pending" if i < 3 else "processing",
            model="flux-dev",
            params={"prompt": f"test prompt {i}"},
            user_id="user_1" if i < 5 else "user_2",
            project_id="project_1",
            created_at=base_ts + i,
            updated_at=base_ts + i,
        )
        for i in range(10)
    ]
    session.add_all(tasks)
    session.commit()
    return tasks


class TestBaseRepositoryCRUD:
    """Basic CRUD operations."""

    def test_create(self, task_repository, session):
        """create() returns the new record, and the record can be read back."""
        now = datetime.now().timestamp()
        task = TaskDB(
            id="test_task",
            type="image",
            status="pending",
            model="flux-dev",
            params={"prompt": "test"},
            created_at=now,
            updated_at=now,
        )
        created = task_repository.create(task)
        assert created.id == "test_task"
        assert created.type == "image"
        assert created.status == "pending"
        # Check the record is actually retrievable, not just echoed back.
        assert task_repository.get("test_task") is not None

    def test_get(self, task_repository, sample_tasks):
        """get() returns a single record by primary key."""
        task = task_repository.get("task_0")
        assert task is not None
        assert task.id == "task_0"
        assert task.type == "image"

    def test_get_not_found(self, task_repository):
        """get() returns None for an unknown id."""
        task = task_repository.get("nonexistent")
        assert task is None

    def test_get_by_field(self, task_repository, sample_tasks):
        """get_by_field() returns a record matching the given field value."""
        task = task_repository.get_by_field("user_id", "user_1")
        assert task is not None
        assert task.user_id == "user_1"

    def test_update(self, task_repository, sample_tasks):
        """update() persists changes made to a loaded record."""
        task = task_repository.get("task_0")
        task.status = "completed"
        updated = task_repository.update(task)
        assert updated.status == "completed"
        # Re-read to confirm the change was persisted.
        retrieved = task_repository.get("task_0")
        assert retrieved.status == "completed"

    def test_update_by_id(self, task_repository, sample_tasks):
        """update_by_id() applies a dict of field changes to a record."""
        updated = task_repository.update_by_id("task_0", {"status": "failed"})
        assert updated is not None
        assert updated.status == "failed"

    def test_delete(self, task_repository, sample_tasks):
        """delete() removes the record and returns True."""
        result = task_repository.delete("task_0")
        assert result is True
        # The record must no longer be retrievable.
        task = task_repository.get("task_0")
        assert task is None

    def test_delete_not_found(self, task_repository):
        """delete() returns False for an unknown id."""
        result = task_repository.delete("nonexistent")
        assert result is False

    def test_soft_delete(self, task_repository, sample_tasks):
        """soft_delete() stamps deleted_at; get() still returns the row."""
        result = task_repository.soft_delete("task_0")
        assert result is True
        task = task_repository.get("task_0")
        assert task.deleted_at is not None

    def test_exists(self, task_repository, sample_tasks):
        """exists() reports whether a primary key is present."""
        assert task_repository.exists("task_0") is True
        assert task_repository.exists("nonexistent") is False

    def test_exists_by_field(self, task_repository, sample_tasks):
        """exists_by_field() reports whether any record has the field value."""
        assert task_repository.exists_by_field("user_id", "user_1") is True
        assert task_repository.exists_by_field("user_id", "nonexistent") is False


class TestBaseRepositoryFiltering:
    """Filtering via the ``filters`` argument of list()."""

    def test_list_with_simple_filter(self, task_repository, sample_tasks):
        """A single equality filter."""
        tasks = task_repository.list(filters={"type": "image"})
        assert len(tasks) == 5
        assert all(task.type == "image" for task in tasks)

    def test_list_with_multiple_filters(self, task_repository, sample_tasks):
        """Multiple equality filters are combined with AND."""
        tasks = task_repository.list(filters={
            "type": "image",
            "status": "pending",
        })
        assert len(tasks) == 2  # task_0 and task_2
        assert all(
            task.type == "image" and task.status == "pending" for task in tasks
        )

    def test_list_with_gt_filter(self, task_repository, sample_tasks):
        """``field__gt`` keeps only records strictly greater than the value."""
        # NOTE(review): assumes BaseRepository parses the ``__gt`` suffix the
        # same way it parses ``__in`` / ``__like`` -- confirm in base_repository.
        threshold = sample_tasks[4].created_at
        tasks = task_repository.list(filters={"created_at__gt": threshold})
        assert len(tasks) == 5
        assert {task.id for task in tasks} == {f"task_{i}" for i in range(5, 10)}

    def test_list_with_in_filter(self, task_repository, sample_tasks):
        """``field__in`` matches any value in the given list."""
        tasks = task_repository.list(filters={
            "status__in": ["pending", "processing"],
        })
        assert len(tasks) == 10
        # Every sample status is listed above, so also check that IN excludes
        # unlisted values: only task_0..task_2 are "pending".
        tasks = task_repository.list(filters={"status__in": ["pending", "failed"]})
        assert len(tasks) == 3
        assert all(task.status == "pending" for task in tasks)

    def test_list_with_like_filter(self, task_repository, sample_tasks):
        """``field__like`` applies an SQL LIKE pattern."""
        # Note: SQLite's LIKE is case-insensitive for ASCII characters.
        tasks = task_repository.list(filters={"model__like": "%flux%"})
        assert len(tasks) == 10
        # A pattern that no sample model contains must match nothing.
        tasks = task_repository.list(filters={"model__like": "%sdxl%"})
        assert len(tasks) == 0


class TestBaseRepositorySorting:
    """Sorting via ``sort_by`` / ``sort_order``."""

    def test_list_with_asc_sort(self, task_repository, sample_tasks):
        """Ascending sort by created_at."""
        tasks = task_repository.list(sort_by="created_at", sort_order="asc")
        # Guard against a vacuous pass on an empty result.
        assert len(tasks) == 10
        for earlier, later in zip(tasks, tasks[1:]):
            assert earlier.created_at <= later.created_at
        assert [task.id for task in tasks] == [f"task_{i}" for i in range(10)]

    def test_list_with_desc_sort(self, task_repository, sample_tasks):
        """Descending sort by created_at."""
        tasks = task_repository.list(sort_by="created_at", sort_order="desc")
        assert len(tasks) == 10
        for earlier, later in zip(tasks, tasks[1:]):
            assert earlier.created_at >= later.created_at
        assert [task.id for task in tasks] == [f"task_{i}" for i in reversed(range(10))]


class TestBaseRepositoryPagination:
    """Pagination via skip/limit, list_paginated() and count()."""

    def test_list_with_pagination(self, task_repository, sample_tasks):
        """Consecutive skip/limit pages do not overlap."""
        # Sort explicitly: without ORDER BY, SQL row order is unspecified, so
        # page contents (and their disjointness) would not be guaranteed.
        page1 = task_repository.list(
            skip=0, limit=3, sort_by="created_at", sort_order="asc"
        )
        assert len(page1) == 3
        page2 = task_repository.list(
            skip=3, limit=3, sort_by="created_at", sort_order="asc"
        )
        assert len(page2) == 3
        page1_ids = {task.id for task in page1}
        page2_ids = {task.id for task in page2}
        assert page1_ids.isdisjoint(page2_ids)

    def test_list_paginated(self, task_repository, sample_tasks):
        """list_paginated() returns (records, total) for 1-based pages."""
        records, total = task_repository.list_paginated(page=1, page_size=3)
        assert len(records) == 3
        assert total == 10
        records, total = task_repository.list_paginated(page=2, page_size=3)
        assert len(records) == 3
        assert total == 10

    def test_count(self, task_repository, sample_tasks):
        """count() with and without filters."""
        total = task_repository.count()
        assert total == 10
        count = task_repository.count(filters={"type": "image"})
        assert count == 5


class TestBaseRepositoryBatchOperations:
    """Batch operations."""

    def test_create_many(self, task_repository, session):
        """create_many() inserts every record, and each can be read back."""
        now = datetime.now().timestamp()
        tasks = [
            TaskDB(
                id=f"batch_task_{i}",
                type="image",
                status="pending",
                model="flux-dev",
                params={},
                created_at=now,
                updated_at=now,
            )
            for i in range(5)
        ]
        created = task_repository.create_many(tasks)
        assert len(created) == 5
        for task in created:
            retrieved = task_repository.get(task.id)
            assert retrieved is not None


class TestTaskRepositorySpecific:
    """TaskRepository-specific query helpers."""

    def test_list_by_status(self, task_repository, sample_tasks):
        """list_by_status() returns only tasks with that status."""
        tasks = task_repository.list_by_status("pending")
        assert len(tasks) == 3
        assert all(task.status == "pending" for task in tasks)

    def test_list_by_user(self, task_repository, sample_tasks):
        """list_by_user() returns only that user's tasks."""
        tasks = task_repository.list_by_user("user_1")
        assert len(tasks) == 5
        assert all(task.user_id == "user_1" for task in tasks)

    def test_list_by_project(self, task_repository, sample_tasks):
        """list_by_project() returns only that project's tasks."""
        tasks = task_repository.list_by_project("project_1")
        assert len(tasks) == 10
        assert all(task.project_id == "project_1" for task in tasks)

    def test_count_by_status(self, task_repository, sample_tasks):
        """count_by_status() counts tasks with that status."""
        count = task_repository.count_by_status("pending")
        assert count == 3

    def test_count_by_user(self, task_repository, sample_tasks):
        """count_by_user() counts that user's tasks."""
        count = task_repository.count_by_user("user_1")
        assert count == 5


class TestQueryPerformanceTracking:
    """Query performance statistics."""

    def test_query_stats_tracking(self, task_repository, sample_tasks):
        """get_query_stats() reflects the queries that were executed."""
        task_repository.list(limit=5)
        task_repository.get("task_0")
        task_repository.count()

        stats = task_repository.get_query_stats()
        assert stats["query_count"] >= 3
        # NOTE(review): the strict ``> 0`` checks may be flaky if the
        # repository's timer has coarse resolution -- confirm which clock it uses.
        assert stats["total_query_time"] > 0
        assert stats["avg_query_time"] > 0