Initial commit: Pixel AI comic/video creation platform
- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents - Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow - Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.) - All HTTP clients migrated from sync requests to async httpx - Admin-managed API keys via environment variables - SSRF vulnerability fixed in ensure_url()
This commit is contained in:
654
backend/tests/test_provider_fallback_properties.py
Normal file
654
backend/tests/test_provider_fallback_properties.py
Normal file
@@ -0,0 +1,654 @@
|
||||
"""
|
||||
Property-Based Tests for AI Provider Fallback Mechanism
|
||||
|
||||
This module contains property-based tests that verify correctness properties
|
||||
of the AI provider fallback system across all possible inputs.
|
||||
|
||||
Properties tested:
|
||||
- Property 15: AI提供商故障转移 - 验证故障转移机制
|
||||
|
||||
Validates: Requirements 7.5
|
||||
"""
|
||||
import pytest
|
||||
import asyncio
|
||||
from typing import List, Optional
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
from hypothesis import given, strategies as st, assume, settings, HealthCheck
|
||||
from hypothesis.strategies import composite
|
||||
|
||||
from src.services.provider.fallback import ProviderService
|
||||
from src.services.provider.base import ServiceResponse, TaskStatus, GenerationResult
|
||||
from src.services.provider.registry import ModelRegistry, ModelType, ServiceConfig, ServiceFactory
|
||||
from src.services.provider.health import health_monitor, HealthStatus, HealthCheckResult
|
||||
from src.utils.errors import ModelNotAvailableException
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mock Services for Testing
|
||||
# ============================================================================
|
||||
|
||||
class MockProvider:
|
||||
"""Mock AI provider for testing fallback behavior"""
|
||||
|
||||
def __init__(self, model_name: str, should_fail: bool = False,
|
||||
fail_with_exception: bool = False, **kwargs):
|
||||
self.model_name = model_name
|
||||
self.should_fail = should_fail
|
||||
self.fail_with_exception = fail_with_exception
|
||||
self.call_count = 0
|
||||
self.config = {
|
||||
"provider": "mock",
|
||||
"type": "image"
|
||||
}
|
||||
self._kwargs = kwargs
|
||||
|
||||
async def generate_image(self, prompt: str, **kwargs):
|
||||
"""Mock image generation"""
|
||||
self.call_count += 1
|
||||
|
||||
# Check if should fail based on kwargs or instance setting
|
||||
should_fail = self._kwargs.get('should_fail', self.should_fail)
|
||||
fail_with_exception = self._kwargs.get('fail_with_exception', self.fail_with_exception)
|
||||
|
||||
if fail_with_exception:
|
||||
raise Exception(f"Provider {self.model_name} failed with exception")
|
||||
|
||||
if should_fail:
|
||||
return ServiceResponse(
|
||||
status=TaskStatus.FAILED,
|
||||
error=f"Mock failure from {self.model_name}"
|
||||
)
|
||||
|
||||
return ServiceResponse(
|
||||
status=TaskStatus.SUCCEEDED,
|
||||
results=[GenerationResult(
|
||||
url=f"http://example.com/{self.model_name}.jpg",
|
||||
content=f"Generated by {self.model_name}"
|
||||
)]
|
||||
)
|
||||
|
||||
async def generate_video_from_text(self, prompt: str, **kwargs):
|
||||
"""Mock video generation from text"""
|
||||
return await self.generate_image(prompt, **kwargs)
|
||||
|
||||
async def generate_video_from_image(self, image: str, prompt: str = "", **kwargs):
|
||||
"""Mock video generation from image"""
|
||||
return await self.generate_image(prompt, **kwargs)
|
||||
|
||||
async def generate_image_from_image(self, prompt: str, image_inputs: list, **kwargs):
|
||||
"""Mock image-to-image generation"""
|
||||
return await self.generate_image(prompt, **kwargs)
|
||||
|
||||
async def generate_text(self, prompt: str, **kwargs):
|
||||
"""Mock text generation"""
|
||||
return await self.generate_image(prompt, **kwargs)
|
||||
|
||||
def mark_unhealthy(self):
|
||||
"""Mark provider as unhealthy"""
|
||||
pass
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Hypothesis Strategies for Generating Test Data
|
||||
# ============================================================================
|
||||
|
||||
@composite
|
||||
def model_names(draw):
|
||||
"""Generate valid model names"""
|
||||
prefix = draw(st.sampled_from(["model", "provider", "service"]))
|
||||
suffix = draw(st.integers(min_value=1, max_value=100))
|
||||
return f"{prefix}-{suffix}"
|
||||
|
||||
|
||||
@composite
|
||||
def prompts(draw):
|
||||
"""Generate prompts for generation"""
|
||||
return draw(st.text(min_size=1, max_size=200))
|
||||
|
||||
|
||||
@composite
|
||||
def fallback_chain(draw, min_size=1, max_size=5, exclude=None):
|
||||
"""Generate a chain of fallback models, excluding specified models"""
|
||||
size = draw(st.integers(min_value=min_size, max_value=max_size))
|
||||
models = []
|
||||
exclude_set = set(exclude) if exclude else set()
|
||||
|
||||
for i in range(size):
|
||||
model = draw(model_names())
|
||||
# Ensure unique model names and not in exclude list
|
||||
while model in models or model in exclude_set:
|
||||
model = draw(model_names())
|
||||
models.append(model)
|
||||
return models
|
||||
|
||||
|
||||
@composite
|
||||
def failure_pattern(draw, num_models):
|
||||
"""
|
||||
Generate a failure pattern for a list of models.
|
||||
Returns a list of booleans indicating which models should fail.
|
||||
Ensures at least one model succeeds (last one).
|
||||
"""
|
||||
if num_models == 0:
|
||||
return []
|
||||
|
||||
# Generate failures for all but the last model
|
||||
failures = [draw(st.booleans()) for _ in range(num_models - 1)]
|
||||
# Last model always succeeds to ensure fallback eventually works
|
||||
failures.append(False)
|
||||
return failures
|
||||
|
||||
|
||||
@composite
|
||||
def all_fail_pattern(draw, num_models):
|
||||
"""Generate a pattern where all models fail"""
|
||||
return [True] * num_models
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Property 15: AI Provider Fallback
|
||||
# ============================================================================
|
||||
|
||||
class TestProperty15AIProviderFallback:
|
||||
"""
|
||||
Property 15: AI提供商故障转移
|
||||
|
||||
验证故障转移机制
|
||||
Validates: Requirements 7.5
|
||||
"""
|
||||
|
||||
@given(
|
||||
primary_model=model_names(),
|
||||
prompt=prompts()
|
||||
)
|
||||
@settings(max_examples=50, deadline=None)
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_succeeds_when_primary_fails(
|
||||
self, primary_model, prompt
|
||||
):
|
||||
"""
|
||||
Property: When primary provider fails, system should automatically
|
||||
switch to fallback provider and succeed.
|
||||
|
||||
For any primary model and list of fallback models, if primary fails
|
||||
but at least one fallback succeeds, the operation should succeed.
|
||||
"""
|
||||
# Generate fallback models excluding the primary
|
||||
fallback_models = ["fallback-1", "fallback-2", "fallback-3"]
|
||||
|
||||
# Clear registry
|
||||
ModelRegistry._factories = {}
|
||||
ModelRegistry._defaults = {}
|
||||
health_monitor._health_status.clear()
|
||||
|
||||
# Register primary (will fail)
|
||||
primary_config = ServiceConfig(
|
||||
id=primary_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=primary_model,
|
||||
args=[primary_model],
|
||||
kwargs={"should_fail": True},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True,
|
||||
is_default=True
|
||||
)
|
||||
primary_factory = ServiceFactory(primary_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
primary_model, primary_factory, ModelType.IMAGE, is_default=True
|
||||
)
|
||||
|
||||
# Register fallbacks (first one succeeds, rest don't matter)
|
||||
for i, fallback_model in enumerate(fallback_models):
|
||||
fallback_config = ServiceConfig(
|
||||
id=fallback_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=fallback_model,
|
||||
args=[fallback_model],
|
||||
kwargs={"should_fail": False}, # First fallback succeeds
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True
|
||||
)
|
||||
fallback_factory = ServiceFactory(fallback_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
fallback_model, fallback_factory, ModelType.IMAGE
|
||||
)
|
||||
|
||||
# Execute with fallback
|
||||
response = await ProviderService.generate_with_fallback(
|
||||
primary_model=primary_model,
|
||||
fallback_models=fallback_models,
|
||||
operation="generate_image",
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# Verify success
|
||||
assert response.status == TaskStatus.SUCCEEDED, \
|
||||
"Fallback should succeed when primary fails but fallback works"
|
||||
assert len(response.results) > 0, \
|
||||
"Successful fallback should return results"
|
||||
|
||||
# Verify the result came from a fallback model (not primary)
|
||||
result_url = response.results[0].url
|
||||
assert primary_model not in result_url, \
|
||||
f"Result should not come from failed primary model {primary_model}"
|
||||
|
||||
# Verify result came from one of the fallback models
|
||||
assert any(fb_model in result_url for fb_model in fallback_models), \
|
||||
f"Result should come from one of the fallback models {fallback_models}"
|
||||
|
||||
@given(
|
||||
primary_model=model_names(),
|
||||
prompt=prompts()
|
||||
)
|
||||
@settings(max_examples=50, deadline=None)
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_raises_exception_when_all_fail(
|
||||
self, primary_model, prompt
|
||||
):
|
||||
"""
|
||||
Property: When all providers fail, system should raise
|
||||
ModelNotAvailableException.
|
||||
|
||||
For any primary model and list of fallback models, if all fail,
|
||||
the operation should raise an exception.
|
||||
"""
|
||||
# Generate fallback models excluding the primary
|
||||
fallback_models = ["fallback-fail-1", "fallback-fail-2"]
|
||||
|
||||
# Clear registry
|
||||
ModelRegistry._factories = {}
|
||||
ModelRegistry._defaults = {}
|
||||
health_monitor._health_status.clear()
|
||||
|
||||
# Register all models as failing
|
||||
all_models = [primary_model] + fallback_models
|
||||
for model in all_models:
|
||||
config = ServiceConfig(
|
||||
id=model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=model,
|
||||
args=[model],
|
||||
kwargs={"should_fail": True},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True,
|
||||
is_default=(model == primary_model)
|
||||
)
|
||||
factory = ServiceFactory(config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
model, factory, ModelType.IMAGE,
|
||||
is_default=(model == primary_model)
|
||||
)
|
||||
|
||||
# Execute with fallback - should raise exception
|
||||
with pytest.raises(ModelNotAvailableException) as exc_info:
|
||||
await ProviderService.generate_with_fallback(
|
||||
primary_model=primary_model,
|
||||
fallback_models=fallback_models,
|
||||
operation="generate_image",
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# Verify exception contains relevant information
|
||||
exception_str = str(exc_info.value)
|
||||
assert "All providers failed" in exception_str or \
|
||||
"Model is not available" in exception_str, \
|
||||
"Exception should indicate all providers failed"
|
||||
|
||||
@given(
|
||||
primary_model=model_names(),
|
||||
prompt=prompts(),
|
||||
success_index=st.integers(min_value=0, max_value=2)
|
||||
)
|
||||
@settings(max_examples=30, deadline=None, suppress_health_check=[HealthCheck.large_base_example])
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_tries_models_in_order(
|
||||
self, primary_model, prompt, success_index
|
||||
):
|
||||
"""
|
||||
Property: Fallback should try models in the specified order and stop
|
||||
at the first successful one.
|
||||
|
||||
For any chain of models, the system should try them in order and
|
||||
return the result from the first successful model.
|
||||
"""
|
||||
# Fixed fallback chain
|
||||
fallback_models = ["fallback-order-1", "fallback-order-2", "fallback-order-3"]
|
||||
|
||||
# Adjust success_index to be within bounds
|
||||
success_index = min(success_index, len(fallback_models) - 1)
|
||||
|
||||
# Clear registry
|
||||
ModelRegistry._factories = {}
|
||||
ModelRegistry._defaults = {}
|
||||
health_monitor._health_status.clear()
|
||||
|
||||
# Register primary (will fail)
|
||||
primary_config = ServiceConfig(
|
||||
id=primary_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=primary_model,
|
||||
args=[primary_model],
|
||||
kwargs={"should_fail": True},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True,
|
||||
is_default=True
|
||||
)
|
||||
primary_factory = ServiceFactory(primary_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
primary_model, primary_factory, ModelType.IMAGE, is_default=True
|
||||
)
|
||||
|
||||
# Register fallbacks - only the one at success_index succeeds
|
||||
for i, fallback_model in enumerate(fallback_models):
|
||||
should_fail = (i < success_index) # Fail until success_index
|
||||
fallback_config = ServiceConfig(
|
||||
id=fallback_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=fallback_model,
|
||||
args=[fallback_model],
|
||||
kwargs={"should_fail": should_fail},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True
|
||||
)
|
||||
fallback_factory = ServiceFactory(fallback_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
fallback_model, fallback_factory, ModelType.IMAGE
|
||||
)
|
||||
|
||||
# Execute with fallback
|
||||
response = await ProviderService.generate_with_fallback(
|
||||
primary_model=primary_model,
|
||||
fallback_models=fallback_models,
|
||||
operation="generate_image",
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# Verify success
|
||||
assert response.status == TaskStatus.SUCCEEDED
|
||||
|
||||
# Verify the result came from the expected model
|
||||
expected_model = fallback_models[success_index]
|
||||
result_url = response.results[0].url
|
||||
assert expected_model in result_url, \
|
||||
f"Result should come from model at index {success_index}: {expected_model}"
|
||||
|
||||
# Verify models after success_index were not tried
|
||||
for i in range(success_index + 1, len(fallback_models)):
|
||||
later_model = fallback_models[i]
|
||||
# Get the service instance to check call count
|
||||
service = ModelRegistry.get(later_model)
|
||||
if service and hasattr(service, 'call_count'):
|
||||
assert service.call_count == 0, \
|
||||
f"Model {later_model} at index {i} should not be called after success at {success_index}"
|
||||
|
||||
@given(
|
||||
primary_model=model_names(),
|
||||
prompt=prompts()
|
||||
)
|
||||
@settings(max_examples=30, deadline=None)
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_fallback_when_primary_succeeds(
|
||||
self, primary_model, prompt
|
||||
):
|
||||
"""
|
||||
Property: When primary provider succeeds, fallback models should not
|
||||
be tried.
|
||||
|
||||
For any primary model that succeeds, no fallback should occur.
|
||||
"""
|
||||
# Clear registry
|
||||
ModelRegistry._factories = {}
|
||||
ModelRegistry._defaults = {}
|
||||
health_monitor._health_status.clear()
|
||||
|
||||
# Register primary (will succeed)
|
||||
primary_config = ServiceConfig(
|
||||
id=primary_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=primary_model,
|
||||
args=[primary_model],
|
||||
kwargs={"should_fail": False},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True,
|
||||
is_default=True
|
||||
)
|
||||
primary_factory = ServiceFactory(primary_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
primary_model, primary_factory, ModelType.IMAGE, is_default=True
|
||||
)
|
||||
|
||||
# Register some fallback models
|
||||
fallback_models = ["fallback-1", "fallback-2"]
|
||||
for fallback_model in fallback_models:
|
||||
fallback_config = ServiceConfig(
|
||||
id=fallback_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=fallback_model,
|
||||
args=[fallback_model],
|
||||
kwargs={"should_fail": False},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True
|
||||
)
|
||||
fallback_factory = ServiceFactory(fallback_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
fallback_model, fallback_factory, ModelType.IMAGE
|
||||
)
|
||||
|
||||
# Execute with fallback
|
||||
response = await ProviderService.generate_with_fallback(
|
||||
primary_model=primary_model,
|
||||
fallback_models=fallback_models,
|
||||
operation="generate_image",
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# Verify success
|
||||
assert response.status == TaskStatus.SUCCEEDED
|
||||
|
||||
# Verify result came from primary
|
||||
result_url = response.results[0].url
|
||||
assert primary_model in result_url, \
|
||||
f"Result should come from primary model {primary_model}"
|
||||
|
||||
# Verify fallback models were not called
|
||||
for fallback_model in fallback_models:
|
||||
service = ModelRegistry.get(fallback_model)
|
||||
if service and hasattr(service, 'call_count'):
|
||||
assert service.call_count == 0, \
|
||||
f"Fallback model {fallback_model} should not be called when primary succeeds"
|
||||
|
||||
@given(
|
||||
primary_model=model_names(),
|
||||
fallback_models=fallback_chain(min_size=1, max_size=3),
|
||||
prompt=prompts()
|
||||
)
|
||||
@settings(max_examples=30, deadline=None)
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_skips_unhealthy_models(
|
||||
self, primary_model, fallback_models, prompt
|
||||
):
|
||||
"""
|
||||
Property: Fallback should skip models marked as unhealthy and try
|
||||
the next available model.
|
||||
|
||||
For any chain of models where some are unhealthy, the system should
|
||||
skip unhealthy ones and use the first healthy model.
|
||||
"""
|
||||
# Clear registry
|
||||
ModelRegistry._factories = {}
|
||||
ModelRegistry._defaults = {}
|
||||
health_monitor._health_status.clear()
|
||||
|
||||
# Register primary (will fail)
|
||||
primary_config = ServiceConfig(
|
||||
id=primary_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=primary_model,
|
||||
args=[primary_model],
|
||||
kwargs={"should_fail": True},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True,
|
||||
is_default=True
|
||||
)
|
||||
primary_factory = ServiceFactory(primary_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
primary_model, primary_factory, ModelType.IMAGE, is_default=True
|
||||
)
|
||||
|
||||
# Register fallbacks - all will succeed if called
|
||||
for fallback_model in fallback_models:
|
||||
fallback_config = ServiceConfig(
|
||||
id=fallback_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=fallback_model,
|
||||
args=[fallback_model],
|
||||
kwargs={"should_fail": False},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True
|
||||
)
|
||||
fallback_factory = ServiceFactory(fallback_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
fallback_model, fallback_factory, ModelType.IMAGE
|
||||
)
|
||||
|
||||
# Mark first fallback as unhealthy (if there are multiple)
|
||||
if len(fallback_models) > 1:
|
||||
first_fallback = fallback_models[0]
|
||||
for _ in range(5): # Multiple failures to trigger UNHEALTHY
|
||||
result = HealthCheckResult(
|
||||
status=HealthStatus.UNHEALTHY,
|
||||
latency_ms=0.0,
|
||||
timestamp=datetime.now(),
|
||||
error="Test unhealthy"
|
||||
)
|
||||
health_monitor.update_health(first_fallback, result)
|
||||
|
||||
# Verify it's marked as unhealthy
|
||||
health = health_monitor.get_health(first_fallback)
|
||||
assert health is not None and health.status == HealthStatus.UNHEALTHY
|
||||
|
||||
# Execute with fallback
|
||||
response = await ProviderService.generate_with_fallback(
|
||||
primary_model=primary_model,
|
||||
fallback_models=fallback_models,
|
||||
operation="generate_image",
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# Verify success
|
||||
assert response.status == TaskStatus.SUCCEEDED
|
||||
|
||||
# If we had multiple fallbacks and marked first as unhealthy,
|
||||
# verify result came from second fallback
|
||||
if len(fallback_models) > 1:
|
||||
result_url = response.results[0].url
|
||||
first_fallback = fallback_models[0]
|
||||
assert not result_url.endswith(f"/{first_fallback}.jpg"), \
|
||||
f"Result should not come from unhealthy model {first_fallback}"
|
||||
|
||||
# Should come from one of the healthy fallbacks
|
||||
healthy_fallbacks = fallback_models[1:]
|
||||
assert any(result_url.endswith(f"/{fb}.jpg") for fb in healthy_fallbacks), \
|
||||
f"Result should come from one of the healthy fallbacks {healthy_fallbacks}"
|
||||
|
||||
@given(
|
||||
primary_model=model_names(),
|
||||
prompt=prompts()
|
||||
)
|
||||
@settings(max_examples=30, deadline=None)
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_handles_exceptions(
|
||||
self, primary_model, prompt
|
||||
):
|
||||
"""
|
||||
Property: Fallback should handle exceptions from providers and
|
||||
continue to next provider.
|
||||
|
||||
For any provider that raises an exception, the system should catch it
|
||||
and try the next provider in the chain.
|
||||
"""
|
||||
# Fixed fallback models
|
||||
fallback_models = ["fallback-exc-1", "fallback-exc-2"]
|
||||
|
||||
# Clear registry
|
||||
ModelRegistry._factories = {}
|
||||
ModelRegistry._defaults = {}
|
||||
health_monitor._health_status.clear()
|
||||
|
||||
# Register primary (will raise exception)
|
||||
primary_config = ServiceConfig(
|
||||
id=primary_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=primary_model,
|
||||
args=[primary_model],
|
||||
kwargs={"fail_with_exception": True},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True,
|
||||
is_default=True
|
||||
)
|
||||
primary_factory = ServiceFactory(primary_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
primary_model, primary_factory, ModelType.IMAGE, is_default=True
|
||||
)
|
||||
|
||||
# Register fallbacks - first one succeeds
|
||||
for i, fallback_model in enumerate(fallback_models):
|
||||
fallback_config = ServiceConfig(
|
||||
id=fallback_model,
|
||||
module="test",
|
||||
class_name="MockProvider",
|
||||
name=fallback_model,
|
||||
args=[fallback_model],
|
||||
kwargs={"should_fail": False},
|
||||
type="image",
|
||||
provider="mock",
|
||||
enabled=True
|
||||
)
|
||||
fallback_factory = ServiceFactory(fallback_config, MockProvider)
|
||||
ModelRegistry.register_factory(
|
||||
fallback_model, fallback_factory, ModelType.IMAGE
|
||||
)
|
||||
|
||||
# Execute with fallback - should handle exception and succeed with fallback
|
||||
response = await ProviderService.generate_with_fallback(
|
||||
primary_model=primary_model,
|
||||
fallback_models=fallback_models,
|
||||
operation="generate_image",
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
# Verify success despite exception from primary
|
||||
assert response.status == TaskStatus.SUCCEEDED, \
|
||||
"Fallback should succeed even when primary raises exception"
|
||||
|
||||
# Verify result came from fallback
|
||||
result_url = response.results[0].url
|
||||
assert primary_model not in result_url, \
|
||||
f"Result should not come from failed primary {primary_model}"
|
||||
assert any(fb in result_url for fb in fallback_models), \
|
||||
f"Result should come from one of the fallbacks {fallback_models}"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
Reference in New Issue
Block a user