Files
pixel/backend/tests/test_provider_fallback.py
张鹏 f9f4560459 Initial commit: Pixel AI comic/video creation platform
- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents
- Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow
- Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.)
- All HTTP clients migrated from sync requests to async httpx
- Admin-managed API keys via environment variables
- SSRF vulnerability fixed in ensure_url()
2026-04-29 01:20:12 +08:00

332 lines
11 KiB
Python

"""
Tests for Provider Fallback Mechanism
Tests the automatic failover functionality when primary providers fail.
Requirement 7.5: Implement the failover mechanism.
"""
import pytest
import asyncio
from unittest.mock import Mock, AsyncMock, patch
from src.services.provider.fallback import ProviderService
from src.services.provider.base import ServiceResponse, TaskStatus, GenerationResult
from src.services.provider.registry import ModelRegistry, ModelType, ServiceConfig, ServiceFactory
from src.utils.errors import ModelNotAvailableException
class MockImageService:
    """Mock image service for testing.

    Every generation call returns a ``ServiceResponse``:
    - if ``should_fail`` is true, a FAILED response whose error names the model;
    - otherwise a SUCCEEDED response with a single ``GenerationResult`` whose
      ``url`` and ``content`` embed ``model_name``, so tests can tell which
      model actually served the request.
    """

    def __init__(self, model_name: str, should_fail: bool = False, **kwargs):
        self.model_name = model_name
        # ``should_fail`` is a named parameter, so a ``should_fail=`` keyword
        # always binds here and can never end up in ``**kwargs``; this
        # attribute is the single source of truth for the failure switch.
        self.should_fail = should_fail
        self.config = {
            "provider": "mock",
            "type": "image"
        }
        # Remaining construction kwargs, kept only for inspection.
        self._kwargs = kwargs

    async def generate_image(self, prompt: str, **kwargs):
        """Return a FAILED response if configured to fail, else one fake image."""
        if self.should_fail:
            return ServiceResponse(
                status=TaskStatus.FAILED,
                error=f"Mock failure from {self.model_name}"
            )
        return ServiceResponse(
            status=TaskStatus.SUCCEEDED,
            results=[GenerationResult(
                url=f"http://example.com/{self.model_name}.jpg",
                content=f"Generated by {self.model_name}"
            )]
        )

    async def generate_image_from_image(self, prompt: str, image_inputs: list, **kwargs):
        """Image-to-image stub: ignores ``image_inputs`` and delegates to ``generate_image``."""
        return await self.generate_image(prompt, **kwargs)

    def mark_unhealthy(self):
        """No-op hook; presumably invoked by the fallback service on failure — verify."""
        pass
class MockVideoService:
    """Mock video service for testing.

    Every generation call returns a ``ServiceResponse``:
    - if ``should_fail`` is true, a FAILED response whose error names the model;
    - otherwise a SUCCEEDED response with a single ``GenerationResult`` whose
      ``url`` (``.mp4``) and ``content`` embed ``model_name``.
    """

    def __init__(self, model_name: str, should_fail: bool = False, **kwargs):
        self.model_name = model_name
        # ``should_fail`` is a named parameter, so it can never appear in
        # ``**kwargs``; this attribute alone decides whether calls fail.
        self.should_fail = should_fail
        self.config = {
            "provider": "mock",
            "type": "video"
        }
        # Remaining construction kwargs, kept only for inspection.
        self._kwargs = kwargs

    async def generate_video_from_text(self, prompt: str, **kwargs):
        """Return a FAILED response if configured to fail, else one fake video."""
        if self.should_fail:
            return ServiceResponse(
                status=TaskStatus.FAILED,
                error=f"Mock failure from {self.model_name}"
            )
        return ServiceResponse(
            status=TaskStatus.SUCCEEDED,
            results=[GenerationResult(
                url=f"http://example.com/{self.model_name}.mp4",
                content=f"Generated by {self.model_name}"
            )]
        )

    async def generate_video_from_image(self, image: str, prompt: str = "", **kwargs):
        """Image-to-video stub: ignores ``image`` and delegates to text-to-video."""
        return await self.generate_video_from_text(prompt, **kwargs)

    def mark_unhealthy(self):
        """No-op hook; presumably invoked by the fallback service on failure — verify."""
        pass
@pytest.fixture
def setup_mock_services():
    """Register mock image/video services in ``ModelRegistry`` for one test.

    Image models: mock-image-1 (default, always fails), mock-image-2 and
    mock-image-3 (succeed). Video models: mock-video-1 (default, always
    fails), mock-video-2 (succeeds).

    The registry tables and the health monitor's status table are snapshotted
    before the test and restored afterwards. Registry entries and health marks
    set by one test therefore do not leak into other tests, including tests in
    other modules that rely on the real registry.
    """
    from src.services.provider.health import health_monitor

    saved_factories = ModelRegistry._factories
    saved_defaults = ModelRegistry._defaults
    # NOTE(review): assumes ``_health_status`` is a dict-like mapping
    # (supports copy/clear/update); the tests below already call ``.clear()``.
    saved_health = health_monitor._health_status.copy()

    ModelRegistry._factories = {}
    ModelRegistry._defaults = {}
    health_monitor._health_status.clear()

    # (model type, service class, type string, [(name, should_fail), ...]);
    # the first entry of each list is registered as that type's default.
    service_specs = [
        (ModelType.IMAGE, MockImageService, "image", [
            ("mock-image-1", True),   # Primary - will fail
            ("mock-image-2", False),  # Fallback 1 - will succeed
            ("mock-image-3", False),  # Fallback 2 - not needed
        ]),
        (ModelType.VIDEO, MockVideoService, "video", [
            ("mock-video-1", True),   # Primary - will fail
            ("mock-video-2", False),  # Fallback 1 - will succeed
        ]),
    ]
    try:
        for model_type, service_cls, type_name, entries in service_specs:
            for i, (name, should_fail) in enumerate(entries):
                is_default = i == 0
                config = ServiceConfig(
                    id=name,
                    module="test",
                    class_name=service_cls.__name__,
                    name=name,
                    args=[name],
                    kwargs={"should_fail": should_fail},
                    type=type_name,
                    provider="mock",
                    enabled=True,
                    is_default=is_default
                )
                factory = ServiceFactory(config, service_cls)
                ModelRegistry.register_factory(name, factory, model_type, is_default=is_default)
        yield
    finally:
        # Restore the exact pre-test state even if registration or the test fails.
        ModelRegistry._factories = saved_factories
        ModelRegistry._defaults = saved_defaults
        health_monitor._health_status.clear()
        health_monitor._health_status.update(saved_health)
@pytest.mark.asyncio
async def test_fallback_on_primary_failure(setup_mock_services):
    """A failing primary hands the request to the first listed fallback (mock-image-2)."""
    expected_model = "mock-image-2"
    response = await ProviderService.generate_with_fallback(
        primary_model="mock-image-1",
        fallback_models=[expected_model, "mock-image-3"],
        operation="generate_image",
        prompt="test prompt",
    )

    assert response.status == TaskStatus.SUCCEEDED
    assert len(response.results) == 1
    served = response.results[0]
    assert expected_model in served.url
    assert expected_model in served.content
@pytest.mark.asyncio
async def test_fallback_all_providers_fail(setup_mock_services):
    """ModelNotAvailableException is raised once the primary and every fallback fail."""
    # Replace the fixture's factories with three image services that always fail.
    failing_models = ["fail-1", "fail-2", "fail-3"]
    ModelRegistry._factories = {}
    for model_name in failing_models:
        ModelRegistry.register_factory(
            model_name,
            ServiceFactory(
                ServiceConfig(
                    id=model_name,
                    module="test",
                    class_name="MockImageService",
                    name=model_name,
                    args=[model_name],
                    kwargs={"should_fail": True},
                    type="image",
                    provider="mock",
                    enabled=True,
                ),
                MockImageService,
            ),
            ModelType.IMAGE,
        )

    primary, *fallbacks = failing_models
    with pytest.raises(ModelNotAvailableException):
        await ProviderService.generate_with_fallback(
            primary_model=primary,
            fallback_models=fallbacks,
            operation="generate_image",
            prompt="test prompt",
        )
@pytest.mark.asyncio
async def test_generate_image_with_fallback(setup_mock_services):
    """The image convenience wrapper falls back from mock-image-1 to mock-image-2."""
    primary, fallback = "mock-image-1", "mock-image-2"
    response = await ProviderService.generate_image_with_fallback(
        primary_model=primary,
        fallback_models=[fallback],
        prompt="test prompt",
        size="1024*1024",
    )
    assert response.status == TaskStatus.SUCCEEDED
    assert fallback in response.results[0].url
@pytest.mark.asyncio
async def test_generate_video_with_fallback(setup_mock_services):
    """The video convenience wrapper falls back from mock-video-1 to mock-video-2."""
    primary, fallback = "mock-video-1", "mock-video-2"
    response = await ProviderService.generate_video_with_fallback(
        primary_model=primary,
        fallback_models=[fallback],
        prompt="test prompt",
        duration=5,
    )
    assert response.status == TaskStatus.SUCCEEDED
    assert fallback in response.results[0].url
@pytest.mark.asyncio
async def test_auto_detect_fallback_models(setup_mock_services):
    """With ``fallback_models=None`` the service picks suitable fallbacks itself."""
    response = await ProviderService.generate_image_with_fallback(
        primary_model="mock-image-1",
        fallback_models=None,  # Auto-detect
        prompt="test prompt"
    )
    assert response.status == TaskStatus.SUCCEEDED
    # mock-image-1 always fails and only image mocks implement generate_image,
    # so a successful result must come from one of the other image mocks.
    assert response.results, "successful response carried no results"
    served_url = response.results[0].url
    assert any(name in served_url for name in ("mock-image-2", "mock-image-3")), served_url
@pytest.mark.asyncio
async def test_fallback_with_image_to_image(setup_mock_services):
    """Passing reference images (``image_inputs``) still falls back to mock-image-2."""
    reference_images = ["http://example.com/ref.jpg"]
    response = await ProviderService.generate_image_with_fallback(
        primary_model="mock-image-1",
        fallback_models=["mock-image-2"],
        prompt="test prompt",
        image_inputs=reference_images,
    )
    assert response.status == TaskStatus.SUCCEEDED
    assert "mock-image-2" in response.results[0].url
@pytest.mark.asyncio
async def test_fallback_with_video_from_image(setup_mock_services):
    """Passing a source ``image`` still falls back to mock-video-2."""
    source_frame = "http://example.com/frame.jpg"
    response = await ProviderService.generate_video_with_fallback(
        primary_model="mock-video-1",
        fallback_models=["mock-video-2"],
        prompt="test prompt",
        image=source_frame,
    )
    assert response.status == TaskStatus.SUCCEEDED
    assert "mock-video-2" in response.results[0].url
def test_configure_fallback(setup_mock_services):
    """configure_fallback accepts a fallback list for a registered model.

    Only the absence of errors is checked. Per the original author's notes,
    the fallback list lands in a config dict, while ServiceFactory builds new
    instances on demand, so reading it back here is not a reliable check.
    A persistent config store would make that verification possible.
    """
    assert ModelRegistry.get_config("mock-image-1") is not None

    ProviderService.configure_fallback(
        model_id="mock-image-1",
        fallback_models=["mock-image-2", "mock-image-3"],
    )
def test_get_fallback_config_not_configured(setup_mock_services):
    """A model with no configured fallbacks reports ``None``."""
    assert ProviderService.get_fallback_config("mock-image-2") is None
@pytest.mark.asyncio
async def test_fallback_skips_unhealthy_models(setup_mock_services):
    """Fallback candidates reported UNHEALTHY by the health monitor are skipped."""
    from src.services.provider.health import health_monitor, HealthStatus, HealthCheckResult
    from datetime import datetime

    # Snapshot the global health table so the UNHEALTHY mark placed on
    # mock-image-2 below does not leak into other tests, which expect
    # mock-image-2 to be an eligible fallback.
    # NOTE(review): assumes ``_health_status`` is dict-like (copy/clear/update).
    saved_health = health_monitor._health_status.copy()
    try:
        # Per the original author, 3+ failed checks are needed before the
        # monitor reports UNHEALTHY; 5 gives headroom. The assertion below
        # confirms the status actually flipped.
        for _ in range(5):
            result = HealthCheckResult(
                status=HealthStatus.UNHEALTHY,
                latency_ms=0.0,
                timestamp=datetime.now(),
                error="Test unhealthy"
            )
            health_monitor.update_health("mock-image-2", result)

        health = health_monitor.get_health("mock-image-2")
        assert health is not None
        assert health.status == HealthStatus.UNHEALTHY, f"Expected UNHEALTHY but got {health.status}"

        # Primary fails and mock-image-2 is skipped, so mock-image-3 must serve.
        response = await ProviderService.generate_image_with_fallback(
            primary_model="mock-image-1",
            fallback_models=["mock-image-2", "mock-image-3"],
            prompt="test prompt"
        )
        assert response.status == TaskStatus.SUCCEEDED
        assert "mock-image-3" in response.results[0].url
    finally:
        health_monitor._health_status.clear()
        health_monitor._health_status.update(saved_health)
@pytest.mark.asyncio
async def test_primary_success_no_fallback(setup_mock_services):
    """A succeeding primary serves the request itself; no fallback is consulted."""
    from src.services.provider.health import health_monitor

    # Start from an empty health table so a mark left by an earlier test
    # cannot make the primary look unhealthy.
    health_monitor._health_status.clear()

    healthy_primary = "mock-image-2"
    response = await ProviderService.generate_with_fallback(
        primary_model=healthy_primary,
        fallback_models=["mock-image-3"],
        operation="generate_image",
        prompt="test prompt",
    )
    assert response.status == TaskStatus.SUCCEEDED
    assert healthy_primary in response.results[0].url  # served by the primary
if __name__ == "__main__":
pytest.main([__file__, "-v"])