- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents - Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow - Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.) - All HTTP clients migrated from sync requests to async httpx - Admin-managed API keys via environment variables - SSRF vulnerability fixed in ensure_url()
815 lines
27 KiB
Python
815 lines
27 KiB
Python
"""
|
|
Property-Based Tests for Task Management System
|
|
|
|
This module contains property-based tests that verify correctness properties
|
|
of the task management system across all possible inputs.
|
|
|
|
Properties tested:
|
|
- Property 1: Task management completeness (creation, tracking, update, cancellation)
|
|
- Property 2: Task persistence and recovery after restart
|
|
- Property 3: Task failure recording and retry mechanism
|
|
|
|
Requirements: 2.2, 2.3, 2.4
|
|
"""
|
|
import pytest
|
|
import asyncio
|
|
import time
|
|
from datetime import datetime
|
|
from unittest.mock import Mock, patch, AsyncMock
|
|
from hypothesis import given, strategies as st, assume, settings
|
|
from hypothesis.strategies import composite
|
|
from sqlmodel import Session, select
|
|
|
|
from src.config.database import engine, init_db
|
|
from src.models.entities import TaskDB
|
|
from src.models.schemas import Task
|
|
from src.services.task_manager import (
|
|
UnifiedTaskManager,
|
|
TaskPriority,
|
|
TaskConfig,
|
|
TaskItem
|
|
)
|
|
from src.services.provider.base import TaskStatus
|
|
from src.utils.errors import (
|
|
TaskQueueFullException,
|
|
TaskNotFoundException,
|
|
GenerationFailedException
|
|
)
|
|
|
|
|
|
# ============================================================================
|
|
# Hypothesis Strategies for Generating Test Data
|
|
# ============================================================================
|
|
|
|
@composite
def task_types(draw):
    """Draw one of the supported task kinds: ``"image"`` or ``"video"``."""
    supported_kinds = ["image", "video"]
    return draw(st.sampled_from(supported_kinds))
|
|
|
|
|
|
@composite
def task_statuses(draw):
    """Draw the string value of one of seven explicitly listed ``TaskStatus`` members."""
    status_members = (
        TaskStatus.PENDING,
        TaskStatus.PROCESSING,
        TaskStatus.SUCCEEDED,
        TaskStatus.FAILED,
        TaskStatus.TIMEOUT,
        TaskStatus.CANCELLED,
        TaskStatus.RETRYING,
    )
    status_values = [member.value for member in status_members]
    return draw(st.sampled_from(status_values))
|
|
|
|
|
|
@composite
def task_priorities(draw):
    """Draw any member of the ``TaskPriority`` enum."""
    every_priority = list(TaskPriority)
    return draw(st.sampled_from(every_priority))
|
|
|
|
|
|
@composite
def task_params(draw):
    """Draw a small task-parameter dict with 1-5 entries.

    Keys are unique, non-empty strings of at most 20 upper/lower-case letters.
    Each value is drawn independently as one of:
    a non-empty text of at most 100 chars (letters, digits, space separators),
    an int in [1, 1000], a finite float in [0.1, 100.0], or a bool.
    """
    key_strategy = st.text(
        min_size=1,
        max_size=20,
        alphabet=st.characters(whitelist_categories=('Lu', 'Ll')),
    )
    param_names = draw(st.lists(key_strategy, min_size=1, max_size=5, unique=True))

    value_strategy = st.one_of(
        st.text(
            min_size=1,
            max_size=100,
            alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Zs')),
        ),
        st.integers(min_value=1, max_value=1000),
        st.floats(min_value=0.1, max_value=100.0, allow_nan=False, allow_infinity=False),
        st.booleans(),
    )
    # One value per key, drawn in key order.
    return {name: draw(value_strategy) for name in param_names}
|
|
|
|
|
|
@composite
def model_ids(draw):
    """Draw a model ID that is recognisably test-generated.

    Every value is ``"pbt-"`` followed by 3-46 characters drawn from letters,
    digits and dash punctuation, so the total length (7-50) stays within the
    original 50-character upper bound.

    Why the prefix: every test in this module cleans up by deleting *all*
    ``TaskDB`` rows whose ``model`` equals the drawn value, using the engine
    imported from ``src.config.database``. A bare random ID could coincide
    with a model name on rows the test did not create; the prefix keeps that
    cleanup scoped to test data.
    """
    prefix = "pbt-"
    suffix = draw(st.text(
        min_size=3,
        max_size=50 - len(prefix),
        alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd')),
    ))
    return prefix + suffix
|
|
|
|
|
|
@composite
def user_ids(draw):
    """Draw an optional user ID.

    Returns either ``None`` or a 5-50 character identifier made of letters,
    digits and dash punctuation.
    """
    identifier = st.text(
        min_size=5,
        max_size=50,
        alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd')),
    )
    optional_identifier = st.one_of(st.none(), identifier)
    return draw(optional_identifier)
|
|
|
|
|
|
@composite
def project_ids(draw):
    """Draw an optional project ID.

    Returns either ``None`` or a 5-50 character identifier made of letters,
    digits and dash punctuation.
    """
    identifier = st.text(
        min_size=5,
        max_size=50,
        alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Pd')),
    )
    optional_identifier = st.one_of(st.none(), identifier)
    return draw(optional_identifier)
|
|
|
|
|
|
# ============================================================================
|
|
# Property 1: Task Management Completeness
|
|
# ============================================================================
|
|
|
|
class TestProperty1TaskManagementCompleteness:
    """
    Property 1: Task management completeness.

    Verifies task creation, status tracking, status updates and cancellation.

    Validates: Requirements 2.2
    """

    @staticmethod
    def _delete_tasks_for_model(model):
        """Delete every ``TaskDB`` row whose ``model`` column equals *model*.

        Per-example cleanup: the generated model ID is the key that identifies
        rows created by one test example.
        """
        with Session(engine) as session:
            statement = select(TaskDB).where(TaskDB.model == model)
            for row in session.exec(statement).all():
                session.delete(row)
            session.commit()

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        priority=task_priorities(),
        user_id=user_ids(),
        project_id=project_ids()
    )
    @settings(max_examples=5, deadline=5000)  # 5 second deadline per example
    @pytest.mark.asyncio
    async def test_task_creation_stores_all_information(
        self, task_type, model, params, priority, user_id, project_id
    ):
        """
        Property: Task creation should store all provided information correctly.

        For any valid task parameters, the created task preserves type, model,
        params, user_id and project_id, starts PENDING with retry_count 0 and
        no start/completion timestamps, and can be retrieved by its ID.
        """
        init_db()
        manager = UnifiedTaskManager()

        try:
            task = await manager.create_task(
                task_type=task_type,
                model=model,
                params=params,
                priority=priority,
                user_id=user_id,
                project_id=project_id
            )

            # Verify task was created
            assert task is not None
            assert task.id is not None

            # Verify all information is preserved
            assert task.type == task_type
            assert task.model == model
            assert task.params == params
            assert task.user_id == user_id
            assert task.project_id == project_id

            # Verify initial status
            assert task.status == TaskStatus.PENDING.value
            assert task.retry_count == 0

            # Verify timestamps
            assert task.created_at is not None
            assert task.updated_at is not None
            assert task.started_at is None
            assert task.completed_at is None

            # Verify task can be retrieved
            retrieved_task = await manager.get_task(task.id)
            assert retrieved_task is not None
            assert retrieved_task.id == task.id
            assert retrieved_task.type == task_type
            assert retrieved_task.model == model

        finally:
            self._delete_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_task_status_tracking_through_lifecycle(
        self, task_type, model, params
    ):
        """
        Property: Task status should be trackable through its entire lifecycle.

        Status changes written directly to the database (PENDING -> PROCESSING
        -> SUCCEEDED) must be visible through ``manager.get_task``.
        """
        init_db()
        manager = UnifiedTaskManager()

        try:
            task = await manager.create_task(
                task_type=task_type,
                model=model,
                params=params
            )

            # Verify initial status
            assert task.status == TaskStatus.PENDING.value

            # Move the row to PROCESSING directly in the database
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.PROCESSING.value
                task_db.started_at = datetime.now().timestamp()
                session.commit()

            # Verify status update is visible through the manager
            updated_task = await manager.get_task(task.id)
            assert updated_task is not None
            assert updated_task.status == TaskStatus.PROCESSING.value
            assert updated_task.started_at is not None

            # Move the row to SUCCEEDED with a result payload
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.SUCCEEDED.value
                task_db.completed_at = datetime.now().timestamp()
                task_db.result = {"output": "test_result"}
                session.commit()

            # Verify final status
            final_task = await manager.get_task(task.id)
            assert final_task is not None
            assert final_task.status == TaskStatus.SUCCEEDED.value
            assert final_task.completed_at is not None
            assert final_task.result is not None

        finally:
            self._delete_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_task_cancellation_updates_status(
        self, task_type, model, params
    ):
        """
        Property: Task cancellation should update status correctly.

        Cancelling a PENDING task returns True, sets status to CANCELLED and
        records ``completed_at``; cancelling a SUCCEEDED task returns False.
        """
        init_db()
        manager = UnifiedTaskManager()

        try:
            task = await manager.create_task(
                task_type=task_type,
                model=model,
                params=params
            )

            # Verify task is pending
            assert task.status == TaskStatus.PENDING.value

            # Cancel task
            cancelled = await manager.cancel_task(task.id)
            assert cancelled is True

            # Verify status is CANCELLED
            cancelled_task = await manager.get_task(task.id)
            assert cancelled_task is not None
            assert cancelled_task.status == TaskStatus.CANCELLED.value
            assert cancelled_task.completed_at is not None

            # Force the task into a terminal (SUCCEEDED) state
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.SUCCEEDED.value
                session.commit()

            # Already completed tasks cannot be cancelled
            cancelled_again = await manager.cancel_task(task.id)
            assert cancelled_again is False

        finally:
            self._delete_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=10, deadline=None)
    @pytest.mark.asyncio
    async def test_nonexistent_task_operations_raise_exceptions(
        self, task_type, model, params
    ):
        """
        Property: Operations on nonexistent tasks should fail cleanly.

        For an ID that was never created, ``get_task`` returns None while
        ``cancel_task`` and ``retry_task`` raise ``TaskNotFoundException``.
        Only ``model`` is used (to build the ID); ``task_type`` and ``params``
        are drawn but unused.
        """
        init_db()
        manager = UnifiedTaskManager()

        # Build an ID that no created task can have
        nonexistent_id = f"nonexistent_{model}_{time.time()}"

        # get_task returns None
        task = await manager.get_task(nonexistent_id)
        assert task is None

        # cancel_task raises
        with pytest.raises(TaskNotFoundException):
            await manager.cancel_task(nonexistent_id)

        # retry_task raises
        with pytest.raises(TaskNotFoundException):
            await manager.retry_task(nonexistent_id)
|
|
|
|
|
|
# ============================================================================
|
|
# Property 2: Task Persistence and Recovery
|
|
# ============================================================================
|
|
|
|
class TestProperty2TaskPersistenceRecovery:
    """
    Property 2: Task persistence and recovery.

    Verifies that task rows and their state survive across independent
    database sessions.

    NOTE(review): these tests write and re-read ``TaskDB`` rows through fresh
    ``Session`` objects only. They never construct ``UnifiedTaskManager``, so
    the manager's own restart/recovery path is not exercised here.

    Validates: Requirements 2.3
    """

    @staticmethod
    def _delete_tasks_for_model(model):
        """Delete every ``TaskDB`` row whose ``model`` column equals *model*.

        Per-example cleanup: the generated model ID is the key that identifies
        rows created by one test example.
        """
        with Session(engine) as session:
            statement = select(TaskDB).where(TaskDB.model == model)
            for row in session.exec(statement).all():
                session.delete(row)
            session.commit()

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        status=st.sampled_from([
            TaskStatus.PENDING.value,
            TaskStatus.PROCESSING.value,
            TaskStatus.RETRYING.value
        ])
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_pending_tasks_persisted_in_database(
        self, task_type, model, params, status
    ):
        """
        Property: Non-terminal tasks should be persisted in the database.

        A task stored as PENDING, PROCESSING or RETRYING is read back from a
        new session with the same type, model, params and status.
        """
        init_db()

        try:
            # Insert the task directly, simulating a task left over from before
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=status,
                retry_count=0,
                max_retries=3
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id
            assert task_id is not None

            # Read it back in an independent session
            with Session(engine) as session:
                retrieved_task = session.get(TaskDB, task_id)
                assert retrieved_task is not None
                assert retrieved_task.type == task_type
                assert retrieved_task.model == model
                assert retrieved_task.params == params
                assert retrieved_task.status == status

        finally:
            self._delete_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_completed_tasks_remain_in_database(
        self, task_type, model, params
    ):
        """
        Property: Completed tasks should remain in the database.

        A SUCCEEDED task with ``completed_at`` set is read back from a new
        session with the same status and a non-null ``completed_at``.
        """
        init_db()

        try:
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.SUCCEEDED.value,
                retry_count=0,
                max_retries=3,
                completed_at=datetime.now().timestamp()
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id
                original_status = task_db.status
            assert task_id is not None

            # Read it back in an independent session
            with Session(engine) as session:
                retrieved_task = session.get(TaskDB, task_id)
                assert retrieved_task is not None
                assert retrieved_task.status == original_status
                assert retrieved_task.completed_at is not None

        finally:
            self._delete_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        retry_count=st.integers(min_value=0, max_value=5)
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_task_state_preserved_in_database(
        self, task_type, model, params, retry_count
    ):
        """
        Property: Task state should be preserved in the database.

        A RETRYING task's type, model, params, retry_count and error message
        are read back unchanged from a new session. ``retry_count`` may exceed
        ``max_retries`` here; only storage is checked, not retry policy.
        """
        init_db()

        try:
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.RETRYING.value,
                retry_count=retry_count,
                max_retries=3,
                error="Previous error"
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id
            assert task_id is not None

            # Read it back in an independent session
            with Session(engine) as session:
                retrieved_task = session.get(TaskDB, task_id)
                assert retrieved_task is not None
                assert retrieved_task.type == task_type
                assert retrieved_task.model == model
                assert retrieved_task.params == params
                assert retrieved_task.retry_count == retry_count
                assert retrieved_task.error == "Previous error"

        finally:
            self._delete_tasks_for_model(model)
|
|
|
|
|
|
# ============================================================================
|
|
# Property 3: Task Failure Recording and Retry
|
|
# ============================================================================
|
|
|
|
class TestProperty3TaskFailureRecordingRetry:
    """
    Property 3: Task failure recording.

    Verifies that failed tasks record their error details and can be retried.

    Validates: Requirements 2.4
    """

    @staticmethod
    def _delete_tasks_for_model(model):
        """Delete every ``TaskDB`` row whose ``model`` column equals *model*.

        Per-example cleanup: the generated model ID is the key that identifies
        rows created by one test example.
        """
        with Session(engine) as session:
            statement = select(TaskDB).where(TaskDB.model == model)
            for row in session.exec(statement).all():
                session.delete(row)
            session.commit()

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        error_message=st.text(min_size=5, max_size=200, alphabet=st.characters(whitelist_categories=('Lu', 'Ll', 'Nd', 'Zs', 'Po')))
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_failed_task_records_error_details(
        self, task_type, model, params, error_message
    ):
        """
        Property: Failed tasks should record error details.

        After a PROCESSING task is marked FAILED with an error message, the
        manager reports status FAILED, the exact error text and a non-null
        ``completed_at``.
        """
        init_db()

        try:
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.PROCESSING.value,
                retry_count=0,
                max_retries=3
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id

            # Simulate failure
            with Session(engine) as session:
                task_db = session.get(TaskDB, task_id)
                assert task_db is not None
                task_db.status = TaskStatus.FAILED.value
                task_db.error = error_message
                task_db.completed_at = datetime.now().timestamp()
                session.commit()

            # Verify error is recorded
            manager = UnifiedTaskManager()
            failed_task = await manager.get_task(task_id)

            assert failed_task is not None
            assert failed_task.status == TaskStatus.FAILED.value
            assert failed_task.error == error_message
            assert failed_task.completed_at is not None

        finally:
            self._delete_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params(),
        max_retries=st.integers(min_value=1, max_value=5)
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_retry_increments_retry_count(
        self, task_type, model, params, max_retries
    ):
        """
        Property: A manual retry re-queues a failed task with a fresh state.

        For any FAILED task, ``retry_task`` returns True and the task comes
        back as PENDING with ``retry_count`` reset to 0 and ``error`` cleared.
        (The method name is historical: the asserted behavior is a reset, not
        an increment.)
        """
        init_db()

        try:
            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.FAILED.value,
                retry_count=0,
                max_retries=max_retries,
                error="Initial failure"
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id

            # Retry task
            manager = UnifiedTaskManager()
            success = await manager.retry_task(task_id)
            assert success is True

            # Manual retry resets the count and clears the error
            retried_task = await manager.get_task(task_id)
            assert retried_task is not None
            assert retried_task.status == TaskStatus.PENDING.value
            assert retried_task.retry_count == 0
            assert retried_task.error is None

        finally:
            self._delete_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=5, deadline=5000)
    @pytest.mark.asyncio
    async def test_retry_history_preserved_in_result(
        self, task_type, model, params
    ):
        """
        Property: Retry history should be preserved in the task result.

        A ``retry_history`` list stored in ``result`` is returned intact (same
        length and attempt order) by ``manager.get_task``.
        """
        init_db()

        try:
            retry_history = [
                {
                    'attempt': 1,
                    'timestamp': datetime.now().timestamp(),
                    'duration': 5.2,
                    'error': 'First failure',
                    'error_type': 'GenerationFailedException'
                },
                {
                    'attempt': 2,
                    'timestamp': datetime.now().timestamp(),
                    'duration': 3.8,
                    'error': 'Second failure',
                    'error_type': 'TimeoutError'
                }
            ]

            task_db = TaskDB(
                type=task_type,
                model=model,
                params=params,
                status=TaskStatus.FAILED.value,
                retry_count=2,
                max_retries=3,
                result={'retry_history': retry_history},
                error="Final failure"
            )

            with Session(engine) as session:
                session.add(task_db)
                session.commit()
                session.refresh(task_db)
                task_id = task_db.id

            # Verify retry history is preserved
            manager = UnifiedTaskManager()
            task = await manager.get_task(task_id)

            assert task is not None
            assert task.result is not None
            assert 'retry_history' in task.result
            assert len(task.result['retry_history']) == 2
            assert task.result['retry_history'][0]['attempt'] == 1
            assert task.result['retry_history'][1]['attempt'] == 2

        finally:
            self._delete_tasks_for_model(model)

    @given(
        task_type=task_types(),
        model=model_ids(),
        params=task_params()
    )
    @settings(max_examples=3, deadline=5000)
    @pytest.mark.asyncio
    async def test_only_failed_tasks_can_be_retried(
        self, task_type, model, params
    ):
        """
        Property: Only failed (or timed-out) tasks can be manually retried.

        ``retry_task`` returns False for PENDING and SUCCEEDED tasks (leaving
        a PENDING task's status unchanged) and True once the task is FAILED.
        NOTE(review): the TIMEOUT case is not exercised by this test.
        """
        init_db()
        manager = UnifiedTaskManager()

        try:
            # Start from a PENDING task
            task = await manager.create_task(
                task_type=task_type,
                model=model,
                params=params
            )

            # Retrying a pending task is refused
            success = await manager.retry_task(task.id)
            assert success is False

            # Status unchanged
            task_after = await manager.get_task(task.id)
            assert task_after is not None
            assert task_after.status == TaskStatus.PENDING.value

            # Move to SUCCEEDED
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.SUCCEEDED.value
                task_db.completed_at = datetime.now().timestamp()
                session.commit()

            # Retrying a succeeded task is refused
            success = await manager.retry_task(task.id)
            assert success is False

            # Move to FAILED
            with Session(engine) as session:
                task_db = session.get(TaskDB, task.id)
                assert task_db is not None
                task_db.status = TaskStatus.FAILED.value
                session.commit()

            # Now retry should work
            success = await manager.retry_task(task.id)
            assert success is True

        finally:
            self._delete_tasks_for_model(model)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly: verbose output, short tracebacks.
    pytest_args = [__file__, "-v", "--tb=short"]
    pytest.main(pytest_args)
|