"""Tests for the provider fallback mechanism.

Exercises the automatic failover performed by ``ProviderService`` when the
primary provider fails.

Requirement 7.5: implement a failover mechanism.
"""

import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch

from src.services.provider.fallback import ProviderService
from src.services.provider.base import ServiceResponse, TaskStatus, GenerationResult
from src.services.provider.registry import ModelRegistry, ModelType, ServiceConfig, ServiceFactory
from src.utils.errors import ModelNotAvailableException


class MockImageService:
    """Mock image service for testing.

    Every generation call returns a FAILED response when ``should_fail`` is
    true, otherwise a single SUCCEEDED result whose url/content embed
    ``model_name`` so tests can tell which provider answered.
    """

    def __init__(self, model_name: str, should_fail: bool = False, **kwargs):
        self.model_name = model_name
        self.should_fail = should_fail
        self.config = {
            "provider": "mock",
            "type": "image",
        }
        self._kwargs = kwargs

    async def generate_image(self, prompt: str, **kwargs):
        # ``should_fail`` is a named __init__ parameter, so a ``should_fail=``
        # keyword from the factory is always bound there and never ends up in
        # ``self._kwargs``; read the attribute directly.
        if self.should_fail:
            return ServiceResponse(
                status=TaskStatus.FAILED,
                error=f"Mock failure from {self.model_name}",
            )
        return ServiceResponse(
            status=TaskStatus.SUCCEEDED,
            results=[
                GenerationResult(
                    url=f"http://example.com/{self.model_name}.jpg",
                    content=f"Generated by {self.model_name}",
                )
            ],
        )

    async def generate_image_from_image(self, prompt: str, image_inputs: list, **kwargs):
        # Image-to-image behaves exactly like text-to-image for the mock.
        return await self.generate_image(prompt, **kwargs)

    def mark_unhealthy(self):
        pass


class MockVideoService:
    """Mock video service for testing.

    Mirrors ``MockImageService``: fails every call when ``should_fail`` is
    true, otherwise returns one result tagged with ``model_name``.
    """

    def __init__(self, model_name: str, should_fail: bool = False, **kwargs):
        self.model_name = model_name
        self.should_fail = should_fail
        self.config = {
            "provider": "mock",
            "type": "video",
        }
        self._kwargs = kwargs

    async def generate_video_from_text(self, prompt: str, **kwargs):
        # See MockImageService.generate_image: ``should_fail`` never lands in
        # ``self._kwargs``, so the attribute is the single source of truth.
        if self.should_fail:
            return ServiceResponse(
                status=TaskStatus.FAILED,
                error=f"Mock failure from {self.model_name}",
            )
        return ServiceResponse(
            status=TaskStatus.SUCCEEDED,
            results=[
                GenerationResult(
                    url=f"http://example.com/{self.model_name}.mp4",
                    content=f"Generated by {self.model_name}",
                )
            ],
        )

    async def generate_video_from_image(self, image: str, prompt: str = "", **kwargs):
        # Image-to-video behaves exactly like text-to-video for the mock.
        return await self.generate_video_from_text(prompt, **kwargs)

    def mark_unhealthy(self):
        pass


def _register_mock_service(name, service_cls, model_type, type_name, should_fail, is_default=False):
    """Register one mock service factory under ``name`` in ``ModelRegistry``.

    Args:
        name: Model id; also passed as the service's ``model_name`` argument.
        service_cls: Mock class the factory instantiates.
        model_type: ``ModelType`` the factory is registered under.
        type_name: Value for ``ServiceConfig.type`` ("image" or "video").
        should_fail: Forwarded to the service constructor via config kwargs.
        is_default: Whether this model is the default for ``model_type``.
    """
    config = ServiceConfig(
        id=name,
        module="test",
        class_name=service_cls.__name__,
        name=name,
        args=[name],
        kwargs={"should_fail": should_fail},
        type=type_name,
        provider="mock",
        enabled=True,
        is_default=is_default,
    )
    factory = ServiceFactory(config, service_cls)
    ModelRegistry.register_factory(name, factory, model_type, is_default=is_default)


def _clear_health_status():
    """Forget every health record held by the global ``health_monitor``."""
    from src.services.provider.health import health_monitor

    health_monitor._health_status.clear()


@pytest.fixture
def setup_mock_services():
    """Populate ``ModelRegistry`` with mock image/video services.

    Registers three image models (the default one fails) and two video models
    (the default one fails). The previous registry contents are restored on
    teardown. Health status is cleared before and after each test so
    "unhealthy" marks cannot leak between tests.
    """
    # Keep whatever was registered before (e.g. real providers) so it can be
    # put back afterwards instead of leaving an empty registry behind.
    saved_factories = ModelRegistry._factories
    saved_defaults = ModelRegistry._defaults
    ModelRegistry._factories = {}
    ModelRegistry._defaults = {}
    _clear_health_status()
    try:
        image_services = [
            ("mock-image-1", True),   # Primary - will fail
            ("mock-image-2", False),  # Fallback 1 - will succeed
            ("mock-image-3", False),  # Fallback 2 - not needed
        ]
        for index, (name, should_fail) in enumerate(image_services):
            _register_mock_service(
                name, MockImageService, ModelType.IMAGE, "image", should_fail,
                is_default=(index == 0),
            )

        video_services = [
            ("mock-video-1", True),   # Primary - will fail
            ("mock-video-2", False),  # Fallback 1 - will succeed
        ]
        for index, (name, should_fail) in enumerate(video_services):
            _register_mock_service(
                name, MockVideoService, ModelType.VIDEO, "video", should_fail,
                is_default=(index == 0),
            )

        yield
    finally:
        ModelRegistry._factories = saved_factories
        ModelRegistry._defaults = saved_defaults
        _clear_health_status()


@pytest.mark.asyncio
async def test_fallback_on_primary_failure(setup_mock_services):
    """Test that fallback works when primary provider fails"""
    response = await ProviderService.generate_with_fallback(
        primary_model="mock-image-1",
        fallback_models=["mock-image-2", "mock-image-3"],
        operation="generate_image",
        prompt="test prompt",
    )

    assert response.status == TaskStatus.SUCCEEDED
    assert len(response.results) == 1
    assert "mock-image-2" in response.results[0].url
    assert "mock-image-2" in response.results[0].content


@pytest.mark.asyncio
async def test_fallback_all_providers_fail(setup_mock_services):
    """Test that exception is raised when all providers fail"""
    # Replace the fixture's services with failing ones only; reset defaults
    # too so they do not point at factories that no longer exist.
    ModelRegistry._factories = {}
    ModelRegistry._defaults = {}
    for name in ["fail-1", "fail-2", "fail-3"]:
        _register_mock_service(name, MockImageService, ModelType.IMAGE, "image", True)

    with pytest.raises(ModelNotAvailableException):
        await ProviderService.generate_with_fallback(
            primary_model="fail-1",
            fallback_models=["fail-2", "fail-3"],
            operation="generate_image",
            prompt="test prompt",
        )


@pytest.mark.asyncio
async def test_generate_image_with_fallback(setup_mock_services):
    """Test convenience method for image generation with fallback"""
    response = await ProviderService.generate_image_with_fallback(
        primary_model="mock-image-1",
        fallback_models=["mock-image-2"],
        prompt="test prompt",
        size="1024*1024",
    )

    assert response.status == TaskStatus.SUCCEEDED
    assert "mock-image-2" in response.results[0].url


@pytest.mark.asyncio
async def test_generate_video_with_fallback(setup_mock_services):
    """Test convenience method for video generation with fallback"""
    response = await ProviderService.generate_video_with_fallback(
        primary_model="mock-video-1",
        fallback_models=["mock-video-2"],
        prompt="test prompt",
        duration=5,
    )

    assert response.status == TaskStatus.SUCCEEDED
    assert "mock-video-2" in response.results[0].url


@pytest.mark.asyncio
async def test_auto_detect_fallback_models(setup_mock_services):
    """Test automatic detection of suitable fallback models"""
    # Test with None fallback_models - should auto-detect
    response = await ProviderService.generate_image_with_fallback(
        primary_model="mock-image-1",
        fallback_models=None,  # Auto-detect
        prompt="test prompt",
    )

    assert response.status == TaskStatus.SUCCEEDED
    # Only mock-image-2 / mock-image-3 can produce a successful image here.
    assert any(
        name in response.results[0].url for name in ("mock-image-2", "mock-image-3")
    )


@pytest.mark.asyncio
async def test_fallback_with_image_to_image(setup_mock_services):
    """Test fallback with image-to-image generation"""
    response = await ProviderService.generate_image_with_fallback(
        primary_model="mock-image-1",
        fallback_models=["mock-image-2"],
        prompt="test prompt",
        image_inputs=["http://example.com/ref.jpg"],
    )

    assert response.status == TaskStatus.SUCCEEDED
    assert "mock-image-2" in response.results[0].url


@pytest.mark.asyncio
async def test_fallback_with_video_from_image(setup_mock_services):
    """Test fallback with image-to-video generation"""
    response = await ProviderService.generate_video_with_fallback(
        primary_model="mock-video-1",
        fallback_models=["mock-video-2"],
        prompt="test prompt",
        image="http://example.com/frame.jpg",
    )

    assert response.status == TaskStatus.SUCCEEDED
    assert "mock-video-2" in response.results[0].url


def test_configure_fallback(setup_mock_services):
    """Test configuring fallback models for a service"""
    config = ModelRegistry.get_config("mock-image-1")
    assert config is not None

    ProviderService.configure_fallback(
        model_id="mock-image-1",
        fallback_models=["mock-image-2", "mock-image-3"],
    )

    # NOTE: The current implementation stores the fallback list in a config
    # dict, but ServiceFactory creates new instances, so the stored value is
    # not asserted here. This only verifies the call does not raise.
    # TODO: assert on the persisted configuration once it is stored durably.


def test_get_fallback_config_not_configured(setup_mock_services):
    """Test getting fallback config for unconfigured model"""
    fallback = ProviderService.get_fallback_config("mock-image-2")
    assert fallback is None


@pytest.mark.asyncio
async def test_fallback_skips_unhealthy_models(setup_mock_services):
    """Test that fallback skips models marked as unhealthy"""
    from src.services.provider.health import health_monitor, HealthStatus, HealthCheckResult
    from datetime import datetime

    # Mark mock-image-2 as unhealthy by simulating repeated failed health
    # checks (3+ failures are needed to reach UNHEALTHY). The fixture clears
    # this state again on teardown.
    for _ in range(5):
        result = HealthCheckResult(
            status=HealthStatus.UNHEALTHY,
            latency_ms=0.0,
            timestamp=datetime.now(),
            error="Test unhealthy",
        )
        health_monitor.update_health("mock-image-2", result)

    health = health_monitor.get_health("mock-image-2")
    assert health is not None
    assert health.status == HealthStatus.UNHEALTHY, f"Expected UNHEALTHY but got {health.status}"

    # Should skip mock-image-2 and use mock-image-3
    response = await ProviderService.generate_image_with_fallback(
        primary_model="mock-image-1",
        fallback_models=["mock-image-2", "mock-image-3"],
        prompt="test prompt",
    )

    assert response.status == TaskStatus.SUCCEEDED
    assert "mock-image-3" in response.results[0].url


@pytest.mark.asyncio
async def test_primary_success_no_fallback(setup_mock_services):
    """Test that fallback is not used when primary succeeds"""
    # Health status was already cleared by the fixture, so mock-image-2 is
    # not carrying an "unhealthy" mark from another test.
    response = await ProviderService.generate_with_fallback(
        primary_model="mock-image-2",
        fallback_models=["mock-image-3"],
        operation="generate_image",
        prompt="test prompt",
    )

    assert response.status == TaskStatus.SUCCEEDED
    assert "mock-image-2" in response.results[0].url  # Primary was used


if __name__ == "__main__":
    pytest.main([__file__, "-v"])