Initial commit: Pixel AI comic/video creation platform
- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents - Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow - Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.) - All HTTP clients migrated from sync requests to async httpx - Admin-managed API keys via environment variables - SSRF vulnerability fixed in ensure_url()
This commit is contained in:
814
backend/tests/test_task_management_properties.py
Normal file
814
backend/tests/test_task_management_properties.py
Normal file
@@ -0,0 +1,814 @@
|
||||
"""
|
||||
Property-Based Tests for Task Management System
|
||||
|
||||
This module contains property-based tests that verify correctness properties
|
||||
of the task management system across all possible inputs.
|
||||
|
||||
Properties tested:
|
||||
- Property 1: Task management completeness (creation, tracking, update, cancellation)
|
||||
- Property 2: Task persistence and recovery after restart
|
||||
- Property 3: Task failure recording and retry mechanism
|
||||
|
||||
Requirements: 2.2, 2.3, 2.4
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
from unittest.mock import Mock, patch, AsyncMock
|
||||
from hypothesis import given, strategies as st, assume, settings
|
||||
from hypothesis.strategies import composite
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from src.config.database import engine, init_db
|
||||
from src.models.entities import TaskDB
|
||||
from src.models.schemas import Task
|
||||
from src.services.task_manager import (
|
||||
UnifiedTaskManager,
|
||||
TaskPriority,
|
||||
TaskConfig,
|
||||
TaskItem
|
||||
)
|
||||
from src.services.provider.base import TaskStatus
|
||||
from src.utils.errors import (
|
||||
TaskQueueFullException,
|
||||
TaskNotFoundException,
|
||||
GenerationFailedException
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Hypothesis Strategies for Generating Test Data
|
||||
# ============================================================================
|
||||
|
||||
@composite
|
||||
def task_types(draw):
|
||||
"""Generate valid task types"""
|
||||
return draw(st.sampled_from(["image", "video"]))
|
||||
|
||||
|
||||
@composite
|
||||
def task_statuses(draw):
|
||||
"""Generate valid task statuses"""
|
||||
return draw(st.sampled_from([
|
||||
TaskStatus.PENDING.value,
|
||||
TaskStatus.PROCESSING.value,
|
||||
TaskStatus.SUCCEEDED.value,
|
||||
TaskStatus.FAILED.value,
|
||||
TaskStatus.TIMEOUT.value,
|
||||
TaskStatus.CANCELLED.value,
|
||||
TaskStatus.RETRYING.value
|
||||
]))
|
||||
|
||||
|
||||
@composite
|
||||
def task_priorities(draw):
|
||||
"""Generate valid task priorities"""
|
||||
return draw(st.sampled_from(list(TaskPriority)))
|
||||
|
||||
|
||||
@composite
|
||||
def task_params(draw):
|
||||
"""Generate task parameters"""
|
||||
# Generate simple dictionaries with string keys and various value types
|
||||
keys = draw(st.lists(st.text(min_size=1, max_size=20, alphabet=st.characters(whitelist_categories=('Lu', 'Ll'))), min_size=1, max_size=5, unique=True))
|
||||
values = []
|
||||
for _ in keys:
|
||||
value = draw(st.one_of(
|
||||
st.text(min_size=1, max_size=100, alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Zs'))),
|
||||
st.integers(min_value=1, max_value=1000),
|
||||
st.floats(min_value=0.1, max_value=100.0, allow_nan=False, allow_infinity=False),
|
||||
st.booleans()
|
||||
))
|
||||
values.append(value)
|
||||
return dict(zip(keys, values))
|
||||
|
||||
|
||||
@composite
|
||||
def model_ids(draw):
|
||||
"""Generate model IDs"""
|
||||
return draw(st.text(min_size=3, max_size=50, alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd'))))
|
||||
|
||||
|
||||
@composite
|
||||
def user_ids(draw):
|
||||
"""Generate user IDs"""
|
||||
return draw(st.one_of(
|
||||
st.none(),
|
||||
st.text(min_size=5, max_size=50, alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd')))
|
||||
))
|
||||
|
||||
|
||||
@composite
|
||||
def project_ids(draw):
|
||||
"""Generate project IDs"""
|
||||
return draw(st.one_of(
|
||||
st.none(),
|
||||
st.text(min_size=5, max_size=50, alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd')))
|
||||
))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 1: Task Management Completeness
|
||||
# ============================================================================
|
||||
|
||||
class TestProperty1TaskManagementCompleteness:
|
||||
"""
|
||||
Property 1: 任务管理完整性
|
||||
|
||||
验证任务创建、跟踪、更新、取消操作
|
||||
Validates: Requirements 2.2
|
||||
"""
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params(),
|
||||
priority=task_priorities(),
|
||||
user_id=user_ids(),
|
||||
project_id=project_ids()
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000) # 5 second deadline per example
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_creation_stores_all_information(
|
||||
self, task_type, model, params, priority, user_id, project_id
|
||||
):
|
||||
"""
|
||||
Property: Task creation should store all provided information correctly
|
||||
|
||||
For any valid task parameters, the created task should preserve all information
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
# Create task manager
|
||||
manager = UnifiedTaskManager()
|
||||
|
||||
try:
|
||||
# Create task
|
||||
task = await manager.create_task(
|
||||
task_type=task_type,
|
||||
model=model,
|
||||
params=params,
|
||||
priority=priority,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
|
||||
# Verify task was created
|
||||
assert task is not None
|
||||
assert task.id is not None
|
||||
|
||||
# Verify all information is preserved
|
||||
assert task.type == task_type
|
||||
assert task.model == model
|
||||
assert task.params == params
|
||||
assert task.user_id == user_id
|
||||
assert task.project_id == project_id
|
||||
|
||||
# Verify initial status
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
assert task.retry_count == 0
|
||||
|
||||
# Verify timestamps
|
||||
assert task.created_at is not None
|
||||
assert task.updated_at is not None
|
||||
assert task.started_at is None
|
||||
assert task.completed_at is None
|
||||
|
||||
# Verify task can be retrieved
|
||||
retrieved_task = await manager.get_task(task.id)
|
||||
assert retrieved_task is not None
|
||||
assert retrieved_task.id == task.id
|
||||
assert retrieved_task.type == task_type
|
||||
assert retrieved_task.model == model
|
||||
|
||||
finally:
|
||||
# Cleanup: remove task from database
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params()
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_status_tracking_through_lifecycle(
|
||||
self, task_type, model, params
|
||||
):
|
||||
"""
|
||||
Property: Task status should be trackable through its entire lifecycle
|
||||
|
||||
For any task, status updates should be reflected in the database
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
# Create task manager
|
||||
manager = UnifiedTaskManager()
|
||||
|
||||
try:
|
||||
# Create task
|
||||
task = await manager.create_task(
|
||||
task_type=task_type,
|
||||
model=model,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Verify initial status
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
|
||||
# Update status to PROCESSING
|
||||
with Session(engine) as session:
|
||||
task_db = session.get(TaskDB, task.id)
|
||||
assert task_db is not None
|
||||
task_db.status = TaskStatus.PROCESSING.value
|
||||
task_db.started_at = datetime.now().timestamp()
|
||||
session.commit()
|
||||
|
||||
# Verify status update
|
||||
updated_task = await manager.get_task(task.id)
|
||||
assert updated_task.status == TaskStatus.PROCESSING.value
|
||||
assert updated_task.started_at is not None
|
||||
|
||||
# Update status to SUCCEEDED
|
||||
with Session(engine) as session:
|
||||
task_db = session.get(TaskDB, task.id)
|
||||
task_db.status = TaskStatus.SUCCEEDED.value
|
||||
task_db.completed_at = datetime.now().timestamp()
|
||||
task_db.result = {"output": "test_result"}
|
||||
session.commit()
|
||||
|
||||
# Verify final status
|
||||
final_task = await manager.get_task(task.id)
|
||||
assert final_task.status == TaskStatus.SUCCEEDED.value
|
||||
assert final_task.completed_at is not None
|
||||
assert final_task.result is not None
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params()
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_cancellation_updates_status(
|
||||
self, task_type, model, params
|
||||
):
|
||||
"""
|
||||
Property: Task cancellation should update status correctly
|
||||
|
||||
For any pending or processing task, cancellation should set status to CANCELLED
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
# Create task manager
|
||||
manager = UnifiedTaskManager()
|
||||
|
||||
try:
|
||||
# Create task
|
||||
task = await manager.create_task(
|
||||
task_type=task_type,
|
||||
model=model,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Verify task is pending
|
||||
assert task.status == TaskStatus.PENDING.value
|
||||
|
||||
# Cancel task
|
||||
cancelled = await manager.cancel_task(task.id)
|
||||
assert cancelled is True
|
||||
|
||||
# Verify status is CANCELLED
|
||||
cancelled_task = await manager.get_task(task.id)
|
||||
assert cancelled_task.status == TaskStatus.CANCELLED.value
|
||||
assert cancelled_task.completed_at is not None
|
||||
|
||||
# Verify already completed tasks cannot be cancelled
|
||||
with Session(engine) as session:
|
||||
task_db = session.get(TaskDB, task.id)
|
||||
task_db.status = TaskStatus.SUCCEEDED.value
|
||||
session.commit()
|
||||
|
||||
# Try to cancel again
|
||||
cancelled_again = await manager.cancel_task(task.id)
|
||||
assert cancelled_again is False
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params()
|
||||
)
|
||||
@settings(max_examples=10, deadline=None)
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonexistent_task_operations_raise_exceptions(
|
||||
self, task_type, model, params
|
||||
):
|
||||
"""
|
||||
Property: Operations on nonexistent tasks should raise appropriate exceptions
|
||||
|
||||
For any nonexistent task ID, operations should fail with TaskNotFoundException
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
# Create task manager
|
||||
manager = UnifiedTaskManager()
|
||||
|
||||
# Generate a random task ID that doesn't exist
|
||||
nonexistent_id = f"nonexistent_{model}_{time.time()}"
|
||||
|
||||
# Verify get_task returns None
|
||||
task = await manager.get_task(nonexistent_id)
|
||||
assert task is None
|
||||
|
||||
# Verify cancel_task raises exception
|
||||
with pytest.raises(TaskNotFoundException):
|
||||
await manager.cancel_task(nonexistent_id)
|
||||
|
||||
# Verify retry_task raises exception
|
||||
with pytest.raises(TaskNotFoundException):
|
||||
await manager.retry_task(nonexistent_id)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 2: Task Persistence and Recovery
|
||||
# ============================================================================
|
||||
|
||||
class TestProperty2TaskPersistenceRecovery:
|
||||
"""
|
||||
Property 2: 任务持久化恢复
|
||||
|
||||
验证重启后任务恢复
|
||||
Validates: Requirements 2.3
|
||||
"""
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params(),
|
||||
status=st.sampled_from([
|
||||
TaskStatus.PENDING.value,
|
||||
TaskStatus.PROCESSING.value,
|
||||
TaskStatus.RETRYING.value
|
||||
])
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending_tasks_persisted_in_database(
|
||||
self, task_type, model, params, status
|
||||
):
|
||||
"""
|
||||
Property: Pending tasks should be persisted in database
|
||||
|
||||
For any task in non-terminal state, it should be stored in database
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
try:
|
||||
# Create task directly in database (simulating existing task)
|
||||
task_db = TaskDB(
|
||||
type=task_type,
|
||||
model=model,
|
||||
params=params,
|
||||
status=status,
|
||||
retry_count=0,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(task_db)
|
||||
session.commit()
|
||||
session.refresh(task_db)
|
||||
task_id = task_db.id
|
||||
|
||||
# Verify task is persisted
|
||||
with Session(engine) as session:
|
||||
retrieved_task = session.get(TaskDB, task_id)
|
||||
assert retrieved_task is not None
|
||||
assert retrieved_task.type == task_type
|
||||
assert retrieved_task.model == model
|
||||
assert retrieved_task.params == params
|
||||
assert retrieved_task.status == status
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params()
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_completed_tasks_remain_in_database(
|
||||
self, task_type, model, params
|
||||
):
|
||||
"""
|
||||
Property: Completed tasks should remain in database
|
||||
|
||||
For any task in terminal state, it should be stored in database
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
try:
|
||||
# Create completed task in database
|
||||
task_db = TaskDB(
|
||||
type=task_type,
|
||||
model=model,
|
||||
params=params,
|
||||
status=TaskStatus.SUCCEEDED.value,
|
||||
retry_count=0,
|
||||
max_retries=3,
|
||||
completed_at=datetime.now().timestamp()
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(task_db)
|
||||
session.commit()
|
||||
session.refresh(task_db)
|
||||
task_id = task_db.id
|
||||
original_status = task_db.status
|
||||
|
||||
# Verify task persists in database
|
||||
with Session(engine) as session:
|
||||
retrieved_task = session.get(TaskDB, task_id)
|
||||
assert retrieved_task is not None
|
||||
assert retrieved_task.status == original_status
|
||||
assert retrieved_task.completed_at is not None
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params(),
|
||||
retry_count=st.integers(min_value=0, max_value=5)
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_task_state_preserved_in_database(
|
||||
self, task_type, model, params, retry_count
|
||||
):
|
||||
"""
|
||||
Property: Task state should be preserved in database
|
||||
|
||||
For any task, all state information should be persisted
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
try:
|
||||
# Create task with specific state
|
||||
task_db = TaskDB(
|
||||
type=task_type,
|
||||
model=model,
|
||||
params=params,
|
||||
status=TaskStatus.RETRYING.value,
|
||||
retry_count=retry_count,
|
||||
max_retries=3,
|
||||
error="Previous error"
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(task_db)
|
||||
session.commit()
|
||||
session.refresh(task_db)
|
||||
task_id = task_db.id
|
||||
|
||||
# Verify all state is preserved in database
|
||||
with Session(engine) as session:
|
||||
retrieved_task = session.get(TaskDB, task_id)
|
||||
assert retrieved_task is not None
|
||||
assert retrieved_task.type == task_type
|
||||
assert retrieved_task.model == model
|
||||
assert retrieved_task.params == params
|
||||
assert retrieved_task.retry_count == retry_count
|
||||
assert retrieved_task.error == "Previous error"
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 3: Task Failure Recording and Retry
|
||||
# ============================================================================
|
||||
|
||||
class TestProperty3TaskFailureRecordingRetry:
|
||||
"""
|
||||
Property 3: 任务失败记录
|
||||
|
||||
验证失败任务的错误记录和重试
|
||||
Validates: Requirements 2.4
|
||||
"""
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params(),
|
||||
error_message=st.text(min_size=5, max_size=200, alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Zs', 'Po')))
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_task_records_error_details(
|
||||
self, task_type, model, params, error_message
|
||||
):
|
||||
"""
|
||||
Property: Failed tasks should record error details
|
||||
|
||||
For any task failure, error message and details should be stored
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
try:
|
||||
# Create task
|
||||
task_db = TaskDB(
|
||||
type=task_type,
|
||||
model=model,
|
||||
params=params,
|
||||
status=TaskStatus.PROCESSING.value,
|
||||
retry_count=0,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(task_db)
|
||||
session.commit()
|
||||
session.refresh(task_db)
|
||||
task_id = task_db.id
|
||||
|
||||
# Simulate failure
|
||||
with Session(engine) as session:
|
||||
task_db = session.get(TaskDB, task_id)
|
||||
task_db.status = TaskStatus.FAILED.value
|
||||
task_db.error = error_message
|
||||
task_db.completed_at = datetime.now().timestamp()
|
||||
session.commit()
|
||||
|
||||
# Verify error is recorded
|
||||
manager = UnifiedTaskManager()
|
||||
failed_task = await manager.get_task(task_id)
|
||||
|
||||
assert failed_task is not None
|
||||
assert failed_task.status == TaskStatus.FAILED.value
|
||||
assert failed_task.error == error_message
|
||||
assert failed_task.completed_at is not None
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params(),
|
||||
max_retries=st.integers(min_value=1, max_value=5)
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_increments_retry_count(
|
||||
self, task_type, model, params, max_retries
|
||||
):
|
||||
"""
|
||||
Property: Retry should increment retry count correctly
|
||||
|
||||
For any task retry, retry_count should increase by 1
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
try:
|
||||
# Create task
|
||||
task_db = TaskDB(
|
||||
type=task_type,
|
||||
model=model,
|
||||
params=params,
|
||||
status=TaskStatus.FAILED.value,
|
||||
retry_count=0,
|
||||
max_retries=max_retries,
|
||||
error="Initial failure"
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(task_db)
|
||||
session.commit()
|
||||
session.refresh(task_db)
|
||||
task_id = task_db.id
|
||||
|
||||
# Retry task
|
||||
manager = UnifiedTaskManager()
|
||||
success = await manager.retry_task(task_id)
|
||||
assert success is True
|
||||
|
||||
# Verify retry count is reset (manual retry resets count)
|
||||
retried_task = await manager.get_task(task_id)
|
||||
assert retried_task is not None
|
||||
assert retried_task.status == TaskStatus.PENDING.value
|
||||
assert retried_task.retry_count == 0 # Manual retry resets count
|
||||
assert retried_task.error is None # Error cleared on retry
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params()
|
||||
)
|
||||
@settings(max_examples=5, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_history_preserved_in_result(
|
||||
self, task_type, model, params
|
||||
):
|
||||
"""
|
||||
Property: Retry history should be preserved in task result
|
||||
|
||||
For any task with retries, retry history should be stored
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
try:
|
||||
# Create task with retry history
|
||||
retry_history = [
|
||||
{
|
||||
'attempt': 1,
|
||||
'timestamp': datetime.now().timestamp(),
|
||||
'duration': 5.2,
|
||||
'error': 'First failure',
|
||||
'error_type': 'GenerationFailedException'
|
||||
},
|
||||
{
|
||||
'attempt': 2,
|
||||
'timestamp': datetime.now().timestamp(),
|
||||
'duration': 3.8,
|
||||
'error': 'Second failure',
|
||||
'error_type': 'TimeoutError'
|
||||
}
|
||||
]
|
||||
|
||||
task_db = TaskDB(
|
||||
type=task_type,
|
||||
model=model,
|
||||
params=params,
|
||||
status=TaskStatus.FAILED.value,
|
||||
retry_count=2,
|
||||
max_retries=3,
|
||||
result={'retry_history': retry_history},
|
||||
error="Final failure"
|
||||
)
|
||||
|
||||
with Session(engine) as session:
|
||||
session.add(task_db)
|
||||
session.commit()
|
||||
session.refresh(task_db)
|
||||
task_id = task_db.id
|
||||
|
||||
# Verify retry history is preserved
|
||||
manager = UnifiedTaskManager()
|
||||
task = await manager.get_task(task_id)
|
||||
|
||||
assert task is not None
|
||||
assert task.result is not None
|
||||
assert 'retry_history' in task.result
|
||||
assert len(task.result['retry_history']) == 2
|
||||
assert task.result['retry_history'][0]['attempt'] == 1
|
||||
assert task.result['retry_history'][1]['attempt'] == 2
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
@given(
|
||||
task_type=task_types(),
|
||||
model=model_ids(),
|
||||
params=task_params()
|
||||
)
|
||||
@settings(max_examples=3, deadline=5000)
|
||||
@pytest.mark.asyncio
|
||||
async def test_only_failed_tasks_can_be_retried(
|
||||
self, task_type, model, params
|
||||
):
|
||||
"""
|
||||
Property: Only failed or timeout tasks can be manually retried
|
||||
|
||||
For any task not in FAILED or TIMEOUT state, retry should return False
|
||||
"""
|
||||
# Initialize database
|
||||
init_db()
|
||||
|
||||
manager = UnifiedTaskManager()
|
||||
|
||||
try:
|
||||
# Test with PENDING task
|
||||
task = await manager.create_task(
|
||||
task_type=task_type,
|
||||
model=model,
|
||||
params=params
|
||||
)
|
||||
|
||||
# Try to retry pending task
|
||||
success = await manager.retry_task(task.id)
|
||||
assert success is False
|
||||
|
||||
# Verify status unchanged
|
||||
task_after = await manager.get_task(task.id)
|
||||
assert task_after.status == TaskStatus.PENDING.value
|
||||
|
||||
# Update to SUCCEEDED
|
||||
with Session(engine) as session:
|
||||
task_db = session.get(TaskDB, task.id)
|
||||
task_db.status = TaskStatus.SUCCEEDED.value
|
||||
task_db.completed_at = datetime.now().timestamp()
|
||||
session.commit()
|
||||
|
||||
# Try to retry succeeded task
|
||||
success = await manager.retry_task(task.id)
|
||||
assert success is False
|
||||
|
||||
# Update to FAILED
|
||||
with Session(engine) as session:
|
||||
task_db = session.get(TaskDB, task.id)
|
||||
task_db.status = TaskStatus.FAILED.value
|
||||
session.commit()
|
||||
|
||||
# Now retry should work
|
||||
success = await manager.retry_task(task.id)
|
||||
assert success is True
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
with Session(engine) as session:
|
||||
statement = select(TaskDB).where(TaskDB.model == model)
|
||||
tasks = session.exec(statement).all()
|
||||
for t in tasks:
|
||||
session.delete(t)
|
||||
session.commit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
Reference in New Issue
Block a user