"""
Property-Based Tests for Task Management System

This module contains property-based tests that verify correctness properties
of the task management system across all possible inputs.

Properties tested:
- Property 1: Task management completeness (creation, tracking, update, cancellation)
- Property 2: Task persistence and recovery after restart
- Property 3: Task failure recording and retry mechanism

Requirements: 2.2, 2.3, 2.4
"""

import pytest
import asyncio
import time
from datetime import datetime
from unittest.mock import Mock, patch, AsyncMock
from hypothesis import given, strategies as st, assume, settings
from hypothesis.strategies import composite
from sqlmodel import Session, select

from src.config.database import engine, init_db
from src.models.entities import TaskDB
from src.models.schemas import Task
from src.services.task_manager import (
    UnifiedTaskManager,
    TaskPriority,
    TaskConfig,
    TaskItem
)
from src.services.provider.base import TaskStatus
from src.utils.errors import (
    TaskQueueFullException,
    TaskNotFoundException,
    GenerationFailedException
)


# ============================================================================
# Hypothesis Strategies for Generating Test Data
# ============================================================================

@composite
def task_types(draw):
    """Generate valid task types."""
    return draw(st.sampled_from(["image", "video"]))


@composite
def task_statuses(draw):
    """Generate valid task statuses (the string values of TaskStatus)."""
    return draw(st.sampled_from([
        TaskStatus.PENDING.value,
        TaskStatus.PROCESSING.value,
        TaskStatus.SUCCEEDED.value,
        TaskStatus.FAILED.value,
        TaskStatus.TIMEOUT.value,
        TaskStatus.CANCELLED.value,
        TaskStatus.RETRYING.value
    ]))


@composite
def task_priorities(draw):
    """Generate valid task priorities (members of TaskPriority)."""
    return draw(st.sampled_from(list(TaskPriority)))


@composite
def task_params(draw):
    """Generate task parameters.

    Produces a dict with 1-5 unique letter-only string keys. Each value is
    drawn independently from text, a small positive int, a finite float in
    [0.1, 100.0], or a bool.
    """
    # Generate simple dictionaries with string keys and various value types
    keys = draw(st.lists(
        st.text(min_size=1, max_size=20,
                alphabet=st.characters(whitelist_categories=('Lu', 'Ll'))),
        min_size=1, max_size=5, unique=True
    ))
    value_strategy = st.one_of(
        st.text(min_size=1, max_size=100,
                alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Zs'))),
        st.integers(min_value=1, max_value=1000),
        st.floats(min_value=0.1, max_value=100.0, allow_nan=False, allow_infinity=False),
        st.booleans()
    )
    # One value per key, drawn in key order
    return {key: draw(value_strategy) for key in keys}


@composite
def model_ids(draw):
    """Generate model IDs (3-50 chars of letters, digits and dashes)."""
    return draw(st.text(min_size=3, max_size=50,
                        alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd'))))


@composite
def user_ids(draw):
    """Generate user IDs, or None for tasks without a user."""
    return draw(st.one_of(
        st.none(),
        st.text(min_size=5, max_size=50,
                alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd')))
    ))


@composite
def project_ids(draw):
    """Generate project IDs, or None for tasks without a project."""
    return draw(st.one_of(
        st.none(),
        st.text(min_size=5, max_size=50,
                alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd')))
    ))


# ============================================================================
# Shared test helpers
# ============================================================================

def _cleanup_tasks_for_model(model):
    """Delete every TaskDB row whose ``model`` column equals *model*.

    Shared teardown for the tests below, which call it from their ``finally``
    blocks so rows are removed even when an assertion fails.

    NOTE(review): rows are matched on the model string alone, so any other row
    in the database behind ``engine`` that happens to use the same model string
    is deleted as well. Consider scoping cleanup to the ids each test created.
    """
    with Session(engine) as session:
        statement = select(TaskDB).where(TaskDB.model == model)
        for task_db in session.exec(statement).all():
            session.delete(task_db)
        session.commit()


# ============================================================================
# Property 1: Task Management Completeness
# ============================================================================

class TestProperty1TaskManagementCompleteness:
    """
    Property 1: Task management completeness.
    Verifies task creation, tracking, update, and cancellation operations.

    Validates: Requirements 2.2
    """

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        priority=task_priorities(),
        user_id=user_ids(),
        project_id=project_ids()
    )
    @settings(max_examples=5, deadline=5000)  # 5 second deadline per example
    @pytest.mark.asyncio
    async def test_task_creation_stores_all_information(
        self, task_type, model, params, priority, user_id, project_id
    ):
        """
        Property: Task creation should store all provided information correctly.

        For any valid task parameters, the created task should preserve all
        information, start in PENDING with no retries, and be retrievable by id.
        """
        # Initialize database
        init_db()

        # Create task manager
        manager = UnifiedTaskManager()

        try:
            # Create task
            task = await manager.create_task(
                task_type=task_type,
                model=model,
                params=params,
                priority=priority,
                user_id=user_id,
                project_id=project_id
            )

            # Verify task was created
            assert task is not None
            assert task.id is not None

            # Verify all information is preserved
            assert task.type == task_type
            assert task.model == model
            assert task.params == params
            assert task.user_id == user_id
            assert task.project_id == project_id

            # Verify initial status
            assert task.status == TaskStatus.PENDING.value
            assert task.retry_count == 0

            # Verify timestamps: creation/update are set, start/completion are not yet
            assert task.created_at is not None
            assert task.updated_at is not None
            assert task.started_at is None
            assert task.completed_at is None

            # Verify task can be retrieved
            retrieved_task = await manager.get_task(task.id)
            assert retrieved_task is not None
            assert retrieved_task.id == task.id
            assert retrieved_task.type == task_type
            assert retrieved_task.model == model
        finally:
            # Cleanup: remove task from database
            _cleanup_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_task_status_tracking_through_lifecycle(
        self, task_type, model, params
    ):
        """
        Property: Task status should be trackable through its entire lifecycle.

        For any task, status updates written to the database (PENDING ->
        PROCESSING -> SUCCEEDED) should be reflected by manager.get_task.
        """
        # Initialize database
        init_db()

        # Create task manager
        manager = UnifiedTaskManager()

        try:
            # Create task
            task = await manager.create_task(
                task_type=task_type,
                model=model,
                params=params
            )

            # Verify initial status
            assert task.status == TaskStatus.PENDING.value

            # Update status to PROCESSING directly in the database
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.PROCESSING.value
                task_db.started_at = datetime.now().timestamp()
                session.commit()

            # Verify status update is visible through the manager
            updated_task = await manager.get_task(task.id)
            assert updated_task.status == TaskStatus.PROCESSING.value
            assert updated_task.started_at is not None

            # Update status to SUCCEEDED directly in the database
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.SUCCEEDED.value
                task_db.completed_at = datetime.now().timestamp()
                task_db.result = {"output": "test_result"}
                session.commit()

            # Verify final status
            final_task = await manager.get_task(task.id)
            assert final_task.status == TaskStatus.SUCCEEDED.value
            assert final_task.completed_at is not None
            assert final_task.result is not None
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_task_cancellation_updates_status(
        self, task_type, model, params
    ):
        """
        Property: Task cancellation should update status correctly.

        For a PENDING task, cancel_task should return True, set the status to
        CANCELLED and record completed_at. For a SUCCEEDED task, cancel_task
        should return False.
        """
        # Initialize database
        init_db()

        # Create task manager
        manager = UnifiedTaskManager()

        try:
            # Create task
            task = await manager.create_task(
                task_type=task_type,
                model=model,
                params=params
            )

            # Verify task is pending
            assert task.status == TaskStatus.PENDING.value

            # Cancel task
            cancelled = await manager.cancel_task(task.id)
            assert cancelled is True

            # Verify status is CANCELLED
            cancelled_task = await manager.get_task(task.id)
            assert cancelled_task.status == TaskStatus.CANCELLED.value
            assert cancelled_task.completed_at is not None

            # Force the task into a terminal SUCCEEDED state
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.SUCCEEDED.value
                session.commit()

            # Verify an already completed (SUCCEEDED) task cannot be cancelled
            cancelled_again = await manager.cancel_task(task.id)
            assert cancelled_again is False
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=10, deadline=None)
    @pytest.mark.asyncio
    async def test_nonexistent_task_operations_raise_exceptions(
        self, task_type, model, params
    ):
        """
        Property: Operations on nonexistent tasks fail predictably.

        For a nonexistent task ID, get_task returns None, while cancel_task and
        retry_task raise TaskNotFoundException.

        Only ``model`` is used (to build the ID); ``task_type`` and ``params``
        are drawn but not used by this test.
        """
        # Initialize database
        init_db()

        # Create task manager
        manager = UnifiedTaskManager()

        # Build an ID that is not expected to exist: a fixed prefix plus the
        # generated model and the current time
        nonexistent_id = f"nonexistent_{model}_{time.time()}"

        # Verify get_task returns None
        task = await manager.get_task(nonexistent_id)
        assert task is None

        # Verify cancel_task raises exception
        with pytest.raises(TaskNotFoundException):
            await manager.cancel_task(nonexistent_id)

        # Verify retry_task raises exception
        with pytest.raises(TaskNotFoundException):
            await manager.retry_task(nonexistent_id)


# ============================================================================
# Property 2: Task Persistence and Recovery
# ============================================================================

class TestProperty2TaskPersistenceRecovery:
    """
    Property 2: Task persistence and recovery.
    Verifies that tasks are recoverable after a restart.

    Note: these tests write TaskDB rows directly and read them back through a
    fresh Session; they do not instantiate UnifiedTaskManager, so only the
    database persistence side of recovery is exercised here.

    Validates: Requirements 2.3
    """

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        status=st.sampled_from([
            TaskStatus.PENDING.value,
            TaskStatus.PROCESSING.value,
            TaskStatus.RETRYING.value
        ])
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_pending_tasks_persisted_in_database(
        self, task_type, model, params, status
    ):
        """
        Property: Pending tasks should be persisted in database.

        For any task in a non-terminal state (PENDING, PROCESSING, RETRYING),
        its type, model, params and status should round-trip through the database.
        """
        # Initialize database
        init_db()

        try:
            # Create task directly in database (simulating existing task)
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=status,
                retry_count=0,
                max_retries=3
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id

            # Verify task is persisted (read back in a new session)
            with Session(engine) as session:
                retrieved_task = session.get(TaskDB, task_id)
                assert retrieved_task is not None
                assert retrieved_task.type == task_type
                assert retrieved_task.model == model
                assert retrieved_task.params == params
                assert retrieved_task.status == status
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_completed_tasks_remain_in_database(
        self, task_type, model, params
    ):
        """
        Property: Completed tasks should remain in database.

        For a task stored in the SUCCEEDED state, its status and completed_at
        should be preserved when read back.
        """
        # Initialize database
        init_db()

        try:
            # Create completed task in database
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.SUCCEEDED.value,
                retry_count=0,
                max_retries=3,
                completed_at=datetime.now().timestamp()
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id
                original_status = task_db.status

            # Verify task persists in database
            with Session(engine) as session:
                retrieved_task = session.get(TaskDB, task_id)
                assert retrieved_task is not None
                assert retrieved_task.status == original_status
                assert retrieved_task.completed_at is not None
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        retry_count=st.integers(min_value=0, max_value=5)
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_task_state_preserved_in_database(
        self, task_type, model, params, retry_count
    ):
        """
        Property: Task state should be preserved in database.

        For a RETRYING task with any retry_count in [0, 5] and an error message,
        type, model, params, retry_count and error should all round-trip.
        """
        # Initialize database
        init_db()

        try:
            # Create task with specific state
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.RETRYING.value,
                retry_count=retry_count,
                max_retries=3,
                error="Previous error"
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id

            # Verify all state is preserved in database
            with Session(engine) as session:
                retrieved_task = session.get(TaskDB, task_id)
                assert retrieved_task is not None
                assert retrieved_task.type == task_type
                assert retrieved_task.model == model
                assert retrieved_task.params == params
                assert retrieved_task.retry_count == retry_count
                assert retrieved_task.error == "Previous error"
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)


# ============================================================================
# Property 3: Task Failure Recording and Retry
# ============================================================================

class TestProperty3TaskFailureRecordingRetry:
    """
    Property 3: Task failure recording.
    Verifies error recording and retry behaviour for failed tasks.

    Validates: Requirements 2.4
    """

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        error_message=st.text(
            min_size=5, max_size=200,
            alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Zs', 'Po'))
        )
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_failed_task_records_error_details(
        self, task_type, model, params, error_message
    ):
        """
        Property: Failed tasks should record error details.

        For any task moved from PROCESSING to FAILED with an error message,
        manager.get_task should report the FAILED status, the exact error
        message and a completed_at timestamp.
        """
        # Initialize database
        init_db()

        try:
            # Create task
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.PROCESSING.value,
                retry_count=0,
                max_retries=3
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id

            # Simulate failure
            with Session(engine) as session:
                task_db = session.get(TaskDB, task_id)
                assert task_db is not None
                task_db.status = TaskStatus.FAILED.value
                task_db.error = error_message
                task_db.completed_at = datetime.now().timestamp()
                session.commit()

            # Verify error is recorded
            manager = UnifiedTaskManager()
            failed_task = await manager.get_task(task_id)
            assert failed_task is not None
            assert failed_task.status == TaskStatus.FAILED.value
            assert failed_task.error == error_message
            assert failed_task.completed_at is not None
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        max_retries=st.integers(min_value=1, max_value=5)
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_retry_increments_retry_count(
        self, task_type, model, params, max_retries
    ):
        """
        Property: A manual retry resets a failed task's retry state.

        For any FAILED task (with any max_retries in [1, 5]), retry_task should
        return True and leave the task in PENDING with retry_count reset to 0
        and the previous error cleared.

        (The test name is historical; the asserted behaviour is a reset, not
        an increment.)
        """
        # Initialize database
        init_db()

        try:
            # Create task
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.FAILED.value,
                retry_count=0,
                max_retries=max_retries,
                error="Initial failure"
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id

            # Retry task
            manager = UnifiedTaskManager()
            success = await manager.retry_task(task_id)
            assert success is True

            # Verify retry state is reset (manual retry resets count)
            retried_task = await manager.get_task(task_id)
            assert retried_task is not None
            assert retried_task.status == TaskStatus.PENDING.value
            assert retried_task.retry_count == 0  # Manual retry resets count
            assert retried_task.error is None  # Error cleared on retry
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_retry_history_preserved_in_result(
        self, task_type, model, params
    ):
        """
        Property: Retry history should be preserved in task result.

        For a task whose result holds a 'retry_history' list, manager.get_task
        should return that list intact (length and attempt ordering).
        """
        # Initialize database
        init_db()

        try:
            # Create task with retry history
            retry_history = [
                {
                    'attempt': 1,
                    'timestamp': datetime.now().timestamp(),
                    'duration': 5.2,
                    'error': 'First failure',
                    'error_type': 'GenerationFailedException'
                },
                {
                    'attempt': 2,
                    'timestamp': datetime.now().timestamp(),
                    'duration': 3.8,
                    'error': 'Second failure',
                    'error_type': 'TimeoutError'
                }
            ]

            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.FAILED.value,
                retry_count=2,
                max_retries=3,
                result={'retry_history': retry_history},
                error="Final failure"
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id

            # Verify retry history is preserved
            manager = UnifiedTaskManager()
            task = await manager.get_task(task_id)
            assert task is not None
            assert task.result is not None
            assert 'retry_history' in task.result
            assert len(task.result['retry_history']) == 2
            assert task.result['retry_history'][0]['attempt'] == 1
            assert task.result['retry_history'][1]['attempt'] == 2
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=3, deadline=5000)
    @pytest.mark.asyncio
    async def test_only_failed_tasks_can_be_retried(
        self, task_type, model, params
    ):
        """
        Property: Only failed (or timed-out) tasks can be manually retried.

        retry_task should return False for a PENDING task (leaving it PENDING)
        and for a SUCCEEDED task, and return True once the task is FAILED.

        NOTE(review): the TIMEOUT case is not exercised by this test.
        """
        # Initialize database
        init_db()

        manager = UnifiedTaskManager()

        try:
            # Test with PENDING task
            task = await manager.create_task(
                task_type=task_type,
                model=model,
                params=params
            )

            # Try to retry pending task
            success = await manager.retry_task(task.id)
            assert success is False

            # Verify status unchanged
            task_after = await manager.get_task(task.id)
            assert task_after.status == TaskStatus.PENDING.value

            # Update to SUCCEEDED
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.SUCCEEDED.value
                task_db.completed_at = datetime.now().timestamp()
                session.commit()

            # Try to retry succeeded task
            success = await manager.retry_task(task.id)
            assert success is False

            # Update to FAILED
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.FAILED.value
                session.commit()

            # Now retry should work
            success = await manager.retry_task(task.id)
            assert success is True
        finally:
            # Cleanup
            _cleanup_tasks_for_model(model)


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])