"""
Property-Based Tests for AI Provider Fallback Mechanism

This module contains property-based tests that verify correctness properties
of the AI provider fallback system across generated inputs.

Properties tested:
- Property 15: AI provider failover - verifies the failover mechanism

Validates: Requirements 7.5
"""

import asyncio
from contextlib import contextmanager
from datetime import datetime
from typing import List, Optional
from unittest.mock import Mock, AsyncMock, patch

import pytest
from hypothesis import given, strategies as st, assume, settings, HealthCheck
from hypothesis.strategies import composite

from src.services.provider.fallback import ProviderService
from src.services.provider.base import ServiceResponse, TaskStatus, GenerationResult
from src.services.provider.registry import ModelRegistry, ModelType, ServiceConfig, ServiceFactory
from src.services.provider.health import health_monitor, HealthStatus, HealthCheckResult
from src.utils.errors import ModelNotAvailableException


# ============================================================================
# Mock Services for Testing
# ============================================================================

class MockProvider:
    """Mock AI provider for testing fallback behavior.

    Every generation method delegates to ``generate_image``, which records the
    call and then, depending on the constructor flags, raises an exception,
    returns a FAILED ``ServiceResponse``, or returns a SUCCEEDED one whose
    single result URL is ``http://example.com/<model_name>.jpg``.
    """

    # Ordered log of model names whose generation methods were invoked. It is
    # shared by all instances on purpose: it cannot be seen from here whether
    # ModelRegistry caches provider instances or builds a new one per lookup,
    # so a per-instance counter read back via ModelRegistry.get() could miss
    # calls. Reset on entry to _isolated_registry().
    call_log: List[str] = []

    def __init__(self, model_name: str, should_fail: bool = False,
                 fail_with_exception: bool = False, **kwargs):
        """
        Args:
            model_name: Name embedded in result URLs and error messages.
            should_fail: Return a FAILED response instead of succeeding.
            fail_with_exception: Raise an exception instead of responding
                (takes precedence over ``should_fail``).
            **kwargs: Any other constructor kwargs; stored but unused.
        """
        self.model_name = model_name
        self.should_fail = should_fail
        self.fail_with_exception = fail_with_exception
        self.call_count = 0
        self.config = {
            "provider": "mock",
            "type": "image"
        }
        # Remaining constructor kwargs, kept for inspection only. Note that
        # "should_fail" / "fail_with_exception" can never land here: Python
        # always binds them to the named parameters above.
        self._kwargs = kwargs

    async def generate_image(self, prompt: str, **kwargs):
        """Mock image generation (also backs every other generation method)."""
        self.call_count += 1
        MockProvider.call_log.append(self.model_name)

        if self.fail_with_exception:
            raise Exception(f"Provider {self.model_name} failed with exception")

        if self.should_fail:
            return ServiceResponse(
                status=TaskStatus.FAILED,
                error=f"Mock failure from {self.model_name}"
            )

        return ServiceResponse(
            status=TaskStatus.SUCCEEDED,
            results=[GenerationResult(
                url=f"http://example.com/{self.model_name}.jpg",
                content=f"Generated by {self.model_name}"
            )]
        )

    async def generate_video_from_text(self, prompt: str, **kwargs):
        """Mock video generation from text."""
        return await self.generate_image(prompt, **kwargs)

    async def generate_video_from_image(self, image: str, prompt: str = "", **kwargs):
        """Mock video generation from image."""
        return await self.generate_image(prompt, **kwargs)

    async def generate_image_from_image(self, prompt: str, image_inputs: list, **kwargs):
        """Mock image-to-image generation."""
        return await self.generate_image(prompt, **kwargs)

    async def generate_text(self, prompt: str, **kwargs):
        """Mock text generation."""
        return await self.generate_image(prompt, **kwargs)

    def mark_unhealthy(self):
        """Mark provider as unhealthy (no-op in the mock)."""
        pass


# ============================================================================
# Test Helpers
# ============================================================================

@contextmanager
def _isolated_registry():
    """Run the enclosed code against an empty model registry and health monitor.

    On entry the registry's ``_factories`` / ``_defaults`` are replaced with
    empty dicts, health status is cleared and ``MockProvider.call_log`` is
    reset. On exit the registry's previous dicts are restored so mock
    registrations do not leak into other test modules, and health status is
    cleared again (entries that existed before entry are not restored).
    """
    saved_factories = getattr(ModelRegistry, "_factories", {})
    saved_defaults = getattr(ModelRegistry, "_defaults", {})
    ModelRegistry._factories = {}
    ModelRegistry._defaults = {}
    health_monitor._health_status.clear()
    MockProvider.call_log.clear()
    try:
        yield
    finally:
        ModelRegistry._factories = saved_factories
        ModelRegistry._defaults = saved_defaults
        health_monitor._health_status.clear()


def _register_mock(model_name: str, *, is_default: bool = False, **mock_kwargs) -> None:
    """Register a ``MockProvider`` factory for ``model_name`` as an image model.

    Args:
        model_name: Registry id; also passed as the mock's ``model_name``.
        is_default: Register as the default image model.
        **mock_kwargs: Forwarded as the ``ServiceConfig.kwargs`` used to build
            the mock (e.g. ``should_fail=True``, ``fail_with_exception=True``).
    """
    config = ServiceConfig(
        id=model_name,
        module="test",
        class_name="MockProvider",
        name=model_name,
        args=[model_name],
        kwargs=mock_kwargs,
        type="image",
        provider="mock",
        enabled=True,
        is_default=is_default,
    )
    factory = ServiceFactory(config, MockProvider)
    ModelRegistry.register_factory(
        model_name, factory, ModelType.IMAGE, is_default=is_default
    )


def _produced_by(response, model_name: str) -> bool:
    """Return True if the first result in ``response`` came from ``model_name``.

    MockProvider URLs end in ``/<model_name>.jpg``; matching that exact suffix
    avoids substring false positives such as "model-1" inside "model-10".
    """
    return response.results[0].url.endswith(f"/{model_name}.jpg")


# ============================================================================
# Hypothesis Strategies for Generating Test Data
# ============================================================================

@composite
def model_names(draw):
    """Generate valid model names such as ``model-7`` or ``service-42``."""
    prefix = draw(st.sampled_from(["model", "provider", "service"]))
    suffix = draw(st.integers(min_value=1, max_value=100))
    return f"{prefix}-{suffix}"


@composite
def prompts(draw):
    """Generate non-empty prompts (1-200 characters)."""
    return draw(st.text(min_size=1, max_size=200))


@composite
def fallback_chain(draw, min_size=1, max_size=5, exclude=None):
    """Generate a chain of unique fallback model names.

    Args:
        min_size: Minimum chain length.
        max_size: Maximum chain length.
        exclude: Optional iterable of names that must not appear in the chain.
    """
    excluded = frozenset(exclude) if exclude else frozenset()
    candidates = model_names().filter(lambda name: name not in excluded)
    return draw(st.lists(candidates, min_size=min_size, max_size=max_size, unique=True))


@composite
def failure_pattern(draw, num_models):
    """
    Generate a failure pattern for a list of models.

    Returns a list of booleans indicating which models should fail.
    The last entry is always False so at least one model succeeds.
    """
    if num_models == 0:
        return []
    # Arbitrary failures for all but the last model...
    failures = [draw(st.booleans()) for _ in range(num_models - 1)]
    # ...and the last model always succeeds so fallback eventually works.
    failures.append(False)
    return failures


@composite
def all_fail_pattern(draw, num_models):
    """Generate a pattern where all models fail."""
    return [True] * num_models


# ============================================================================
# Property 15: AI Provider Fallback
# ============================================================================

class TestProperty15AIProviderFallback:
    """
    Property 15: AI provider failover - verifies the failover mechanism.

    Validates: Requirements 7.5
    """

    @given(
        primary_model=model_names(),
        prompt=prompts()
    )
    @settings(max_examples=50, deadline=None)
    @pytest.mark.asyncio
    async def test_fallback_succeeds_when_primary_fails(
        self,
        primary_model,
        prompt
    ):
        """
        Property: When the primary provider fails, the system should
        automatically switch to a fallback provider and succeed.

        For any primary model that fails, with a fixed chain of working
        fallbacks, the operation should succeed with a fallback's result.
        """
        fallback_models = ["fallback-1", "fallback-2", "fallback-3"]

        with _isolated_registry():
            _register_mock(primary_model, is_default=True, should_fail=True)
            # All fallbacks work; the first one tried should win.
            for fallback_model in fallback_models:
                _register_mock(fallback_model, should_fail=False)

            response = await ProviderService.generate_with_fallback(
                primary_model=primary_model,
                fallback_models=fallback_models,
                operation="generate_image",
                prompt=prompt
            )

            assert response.status == TaskStatus.SUCCEEDED, \
                "Fallback should succeed when primary fails but fallback works"
            assert len(response.results) > 0, \
                "Successful fallback should return results"

            # The primary must actually have been tried, otherwise this is not
            # a failover.
            assert primary_model in MockProvider.call_log, \
                f"Primary model {primary_model} should be attempted first"

            assert not _produced_by(response, primary_model), \
                f"Result should not come from failed primary model {primary_model}"
            assert any(_produced_by(response, fb) for fb in fallback_models), \
                f"Result should come from one of the fallback models {fallback_models}"

    @given(
        primary_model=model_names(),
        prompt=prompts()
    )
    @settings(max_examples=50, deadline=None)
    @pytest.mark.asyncio
    async def test_fallback_raises_exception_when_all_fail(
        self,
        primary_model,
        prompt
    ):
        """
        Property: When all providers fail, the system should raise
        ModelNotAvailableException.

        For any primary model, if it and every fallback fail, the operation
        should raise an exception.
        """
        fallback_models = ["fallback-fail-1", "fallback-fail-2"]

        with _isolated_registry():
            for model in [primary_model] + fallback_models:
                _register_mock(
                    model,
                    is_default=(model == primary_model),
                    should_fail=True
                )

            with pytest.raises(ModelNotAvailableException) as exc_info:
                await ProviderService.generate_with_fallback(
                    primary_model=primary_model,
                    fallback_models=fallback_models,
                    operation="generate_image",
                    prompt=prompt
                )

            exception_str = str(exc_info.value)
            assert "All providers failed" in exception_str or \
                   "Model is not available" in exception_str, \
                "Exception should indicate all providers failed"

    @given(
        primary_model=model_names(),
        prompt=prompts(),
        success_index=st.integers(min_value=0, max_value=2)
    )
    @settings(max_examples=30, deadline=None, suppress_health_check=[HealthCheck.large_base_example])
    @pytest.mark.asyncio
    async def test_fallback_tries_models_in_order(
        self,
        primary_model,
        prompt,
        success_index
    ):
        """
        Property: Fallback should try models in the specified order and stop
        at the first successful one.

        For any chain of models, the system should try them in order and
        return the result from the first successful model.
        """
        fallback_models = ["fallback-order-1", "fallback-order-2", "fallback-order-3"]
        # Keep success_index within the chain even if the strategy bounds drift.
        success_index = min(success_index, len(fallback_models) - 1)

        with _isolated_registry():
            _register_mock(primary_model, is_default=True, should_fail=True)
            # Fallbacks before success_index fail; the rest succeed.
            for i, fallback_model in enumerate(fallback_models):
                _register_mock(fallback_model, should_fail=(i < success_index))

            response = await ProviderService.generate_with_fallback(
                primary_model=primary_model,
                fallback_models=fallback_models,
                operation="generate_image",
                prompt=prompt
            )

            assert response.status == TaskStatus.SUCCEEDED

            expected_model = fallback_models[success_index]
            assert _produced_by(response, expected_model), \
                f"Result should come from model at index {success_index}: {expected_model}"

            # Models after the first success must never have been invoked.
            for i, later_model in enumerate(fallback_models[success_index + 1:],
                                            start=success_index + 1):
                assert later_model not in MockProvider.call_log, \
                    f"Model {later_model} at index {i} should not be called after success at {success_index}"

    @given(
        primary_model=model_names(),
        prompt=prompts()
    )
    @settings(max_examples=30, deadline=None)
    @pytest.mark.asyncio
    async def test_no_fallback_when_primary_succeeds(
        self,
        primary_model,
        prompt
    ):
        """
        Property: When the primary provider succeeds, fallback models should
        not be tried.

        For any primary model that succeeds, no fallback should occur.
        """
        fallback_models = ["fallback-1", "fallback-2"]

        with _isolated_registry():
            _register_mock(primary_model, is_default=True, should_fail=False)
            for fallback_model in fallback_models:
                _register_mock(fallback_model, should_fail=False)

            response = await ProviderService.generate_with_fallback(
                primary_model=primary_model,
                fallback_models=fallback_models,
                operation="generate_image",
                prompt=prompt
            )

            assert response.status == TaskStatus.SUCCEEDED
            assert _produced_by(response, primary_model), \
                f"Result should come from primary model {primary_model}"

            for fallback_model in fallback_models:
                assert fallback_model not in MockProvider.call_log, \
                    f"Fallback model {fallback_model} should not be called when primary succeeds"

    @given(
        primary_model=model_names(),
        fallback_models=fallback_chain(min_size=1, max_size=3),
        prompt=prompts()
    )
    @settings(max_examples=30, deadline=None)
    @pytest.mark.asyncio
    async def test_fallback_skips_unhealthy_models(
        self,
        primary_model,
        fallback_models,
        prompt
    ):
        """
        Property: Fallback should skip models marked as unhealthy and try the
        next available model.

        For any chain of models where some are unhealthy, the system should
        skip unhealthy ones and use the first healthy model.
        """
        # Both values are drawn from the same name space independently, so
        # they can collide; re-registering the name as a fallback would
        # silently overwrite the failing primary and invalidate the property.
        assume(primary_model not in fallback_models)

        with _isolated_registry():
            _register_mock(primary_model, is_default=True, should_fail=True)
            # Every fallback would succeed if called.
            for fallback_model in fallback_models:
                _register_mock(fallback_model, should_fail=False)

            # With more than one fallback, mark the first one as unhealthy.
            if len(fallback_models) > 1:
                first_fallback = fallback_models[0]
                for _ in range(5):  # Repeated failures to reach UNHEALTHY
                    result = HealthCheckResult(
                        status=HealthStatus.UNHEALTHY,
                        latency_ms=0.0,
                        timestamp=datetime.now(),
                        error="Test unhealthy"
                    )
                    health_monitor.update_health(first_fallback, result)

                health = health_monitor.get_health(first_fallback)
                assert health is not None and health.status == HealthStatus.UNHEALTHY

            response = await ProviderService.generate_with_fallback(
                primary_model=primary_model,
                fallback_models=fallback_models,
                operation="generate_image",
                prompt=prompt
            )

            assert response.status == TaskStatus.SUCCEEDED

            if len(fallback_models) > 1:
                first_fallback = fallback_models[0]
                assert not _produced_by(response, first_fallback), \
                    f"Result should not come from unhealthy model {first_fallback}"

                healthy_fallbacks = fallback_models[1:]
                assert any(_produced_by(response, fb) for fb in healthy_fallbacks), \
                    f"Result should come from one of the healthy fallbacks {healthy_fallbacks}"

    @given(
        primary_model=model_names(),
        prompt=prompts()
    )
    @settings(max_examples=30, deadline=None)
    @pytest.mark.asyncio
    async def test_fallback_handles_exceptions(
        self,
        primary_model,
        prompt
    ):
        """
        Property: Fallback should handle exceptions from providers and
        continue to the next provider.

        For any provider that raises an exception, the system should catch it
        and try the next provider in the chain.
        """
        fallback_models = ["fallback-exc-1", "fallback-exc-2"]

        with _isolated_registry():
            # Primary raises instead of returning a FAILED response.
            _register_mock(primary_model, is_default=True, fail_with_exception=True)
            for fallback_model in fallback_models:
                _register_mock(fallback_model, should_fail=False)

            response = await ProviderService.generate_with_fallback(
                primary_model=primary_model,
                fallback_models=fallback_models,
                operation="generate_image",
                prompt=prompt
            )

            assert response.status == TaskStatus.SUCCEEDED, \
                "Fallback should succeed even when primary raises exception"

            assert not _produced_by(response, primary_model), \
                f"Result should not come from failed primary {primary_model}"
            assert any(_produced_by(response, fb) for fb in fallback_models), \
                f"Result should come from one of the fallbacks {fallback_models}"


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])