# Project context: FastAPI backend with SQLModel, Alembic migrations, and AgentScope
# agents; Next.js 15 frontend with React 19, Tailwind, Zustand, and React Flow;
# multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.);
# all HTTP clients migrated from sync requests to async httpx; admin-managed API
# keys via environment variables; SSRF vulnerability fixed in ensure_url().
"""
Property-based tests for the AI provider fallback mechanism.

This module contains property-based tests that verify correctness properties
of the AI provider fallback system across Hypothesis-generated inputs.

Properties tested:
- Property 15: AI provider failover (AI提供商故障转移) - verifies the failover mechanism

Validates: Requirements 7.5
"""
|
|
import pytest
|
|
import asyncio
|
|
from typing import List, Optional
|
|
from unittest.mock import Mock, AsyncMock, patch
|
|
from hypothesis import given, strategies as st, assume, settings, HealthCheck
|
|
from hypothesis.strategies import composite
|
|
|
|
from src.services.provider.fallback import ProviderService
|
|
from src.services.provider.base import ServiceResponse, TaskStatus, GenerationResult
|
|
from src.services.provider.registry import ModelRegistry, ModelType, ServiceConfig, ServiceFactory
|
|
from src.services.provider.health import health_monitor, HealthStatus, HealthCheckResult
|
|
from src.utils.errors import ModelNotAvailableException
|
|
from datetime import datetime
|
|
|
|
|
|
# ============================================================================
|
|
# Mock Services for Testing
|
|
# ============================================================================
|
|
|
|
class MockProvider:
    """Stand-in AI provider whose outcome is controlled by constructor flags.

    Every generation entry point funnels into :meth:`generate_image`, so
    ``call_count`` counts calls across all operations.

    Args:
        model_name: Name embedded in error messages and result URLs.
        should_fail: When true, return a FAILED ``ServiceResponse``.
        fail_with_exception: When true, raise a plain ``Exception``; this is
            checked before ``should_fail``.
        **kwargs: Any extra keyword arguments, stored on ``_kwargs``.
    """

    def __init__(self, model_name: str, should_fail: bool = False,
                 fail_with_exception: bool = False, **kwargs):
        self.model_name = model_name
        self.should_fail = should_fail
        self.fail_with_exception = fail_with_exception
        self.call_count = 0
        self.config = dict(provider="mock", type="image")
        self._kwargs = kwargs

    async def generate_image(self, prompt: str, **kwargs):
        """Simulate image generation according to the configured failure mode."""
        self.call_count += 1
        extra = self._kwargs

        # Entries in _kwargs win over the attributes. Because should_fail and
        # fail_with_exception are named __init__ parameters, they are normally
        # bound there rather than landing in _kwargs.
        raise_error = extra.get('fail_with_exception', self.fail_with_exception)
        if raise_error:
            raise Exception(f"Provider {self.model_name} failed with exception")

        if extra.get('should_fail', self.should_fail):
            return ServiceResponse(
                status=TaskStatus.FAILED,
                error=f"Mock failure from {self.model_name}"
            )

        image = GenerationResult(
            url=f"http://example.com/{self.model_name}.jpg",
            content=f"Generated by {self.model_name}"
        )
        return ServiceResponse(status=TaskStatus.SUCCEEDED, results=[image])

    async def generate_video_from_text(self, prompt: str, **kwargs):
        """Text-to-video; behaves exactly like :meth:`generate_image`."""
        return await self.generate_image(prompt, **kwargs)

    async def generate_video_from_image(self, image: str, prompt: str = "", **kwargs):
        """Image-to-video; ignores ``image`` and delegates to :meth:`generate_image`."""
        return await self.generate_image(prompt, **kwargs)

    async def generate_image_from_image(self, prompt: str, image_inputs: list, **kwargs):
        """Image-to-image; ignores ``image_inputs`` and delegates to :meth:`generate_image`."""
        return await self.generate_image(prompt, **kwargs)

    async def generate_text(self, prompt: str, **kwargs):
        """Text generation; behaves exactly like :meth:`generate_image`."""
        return await self.generate_image(prompt, **kwargs)

    def mark_unhealthy(self):
        """No-op hook: the mock keeps no health state of its own."""
|
|
|
|
|
|
# ============================================================================
|
|
# Hypothesis Strategies for Generating Test Data
|
|
# ============================================================================
|
|
|
|
@composite
def model_names(draw):
    """Draw a model identifier such as ``"model-7"`` or ``"service-42"``.

    The prefix is one of model/provider/service and the numeric suffix lies
    in 1..100 inclusive.
    """
    kind = draw(st.sampled_from(["model", "provider", "service"]))
    number = draw(st.integers(min_value=1, max_value=100))
    return "{}-{}".format(kind, number)
|
|
|
|
|
|
@composite
def prompts(draw):
    """Draw a non-empty prompt string of at most 200 characters."""
    text_strategy = st.text(min_size=1, max_size=200)
    prompt = draw(text_strategy)
    return prompt
|
|
|
|
|
|
@composite
def fallback_chain(draw, min_size=1, max_size=5, exclude=None):
    """Draw a chain of distinct fallback model names.

    Args:
        draw: Hypothesis draw function (supplied by ``@composite``).
        min_size: Minimum chain length.
        max_size: Maximum chain length.
        exclude: Optional iterable of names that must not appear in the chain
            (for example the primary model).

    Returns:
        A list of between ``min_size`` and ``max_size`` unique names produced
        by :func:`model_names`, none of which is in ``exclude``.
    """
    excluded = frozenset(exclude) if exclude else frozenset()

    # Let Hypothesis enforce uniqueness and exclusion itself instead of
    # re-drawing inside a hand-rolled while loop; this shrinks more cleanly
    # and avoids O(n) list-membership checks.
    candidates = model_names()
    if excluded:
        candidates = candidates.filter(lambda name: name not in excluded)
    return draw(
        st.lists(candidates, min_size=min_size, max_size=max_size, unique=True)
    )
|
|
|
|
|
|
@composite
def failure_pattern(draw, num_models):
    """
    Draw a per-model failure pattern for ``num_models`` models.

    Returns a list of booleans where ``True`` means "this model fails".
    The final entry is always ``False``, so at least one model (the last)
    succeeds and the fallback chain can eventually complete.
    """
    if num_models == 0:
        return []

    # Random outcomes for every model except the last, which always succeeds.
    return [draw(st.booleans()) for _ in range(num_models - 1)] + [False]
|
|
|
|
|
|
@composite
def all_fail_pattern(draw, num_models):
    """Return a pattern in which every one of ``num_models`` models fails.

    Deterministic: nothing is drawn, so ``draw`` is unused.
    """
    return [True for _ in range(num_models)]
|
|
|
|
|
|
# ============================================================================
|
|
# Property 15: AI Provider Fallback
|
|
# ============================================================================
|
|
|
|
class TestProperty15AIProviderFallback:
    """
    Property 15: AI provider failover.

    Verifies that ProviderService.generate_with_fallback moves on to fallback
    models when the primary fails, tries them in order, skips unhealthy ones,
    tolerates provider exceptions, and raises ModelNotAvailableException when
    every provider fails.

    Validates: Requirements 7.5
    """

    # ------------------------------------------------------------------
    # Helpers (not collected by pytest: names do not start with "test")
    # ------------------------------------------------------------------

    @staticmethod
    def _reset_registry():
        """Clear the shared model registry and health-monitor state.

        The registry lives in ModelRegistry class attributes and health data
        in the shared ``health_monitor`` instance, so every Hypothesis example
        resets both before registering its own mocks.
        """
        ModelRegistry._factories = {}
        ModelRegistry._defaults = {}
        health_monitor._health_status.clear()

    @staticmethod
    def _register_mock(model, provider_kwargs, is_default=False):
        """Register a MockProvider-backed image model under the id ``model``.

        Args:
            model: Model id; also used as the service name and as the single
                positional ``args`` entry (MockProvider's ``model_name``).
            provider_kwargs: Dict stored as ``ServiceConfig.kwargs``, e.g.
                ``{"should_fail": True}``. Presumably ServiceFactory passes it
                to MockProvider's constructor - confirm against ServiceFactory.
            is_default: Whether this is the default image model.
        """
        config = ServiceConfig(
            id=model,
            module="test",
            class_name="MockProvider",
            name=model,
            args=[model],
            kwargs=provider_kwargs,
            type="image",
            provider="mock",
            enabled=True,
            is_default=is_default
        )
        factory = ServiceFactory(config, MockProvider)
        ModelRegistry.register_factory(
            model, factory, ModelType.IMAGE, is_default=is_default
        )

    @staticmethod
    def _mark_unhealthy(model, reports=5):
        """Record ``reports`` consecutive UNHEALTHY health checks for ``model``."""
        for _ in range(reports):
            result = HealthCheckResult(
                status=HealthStatus.UNHEALTHY,
                latency_ms=0.0,
                timestamp=datetime.now(),
                error="Test unhealthy"
            )
            health_monitor.update_health(model, result)

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @given(
        primary_model=model_names(),
        prompt=prompts()
    )
    @settings(max_examples=50, deadline=None)
    @pytest.mark.asyncio
    async def test_fallback_succeeds_when_primary_fails(
        self, primary_model, prompt
    ):
        """
        Property: When the primary provider fails, the system should
        automatically switch to a fallback provider and succeed.

        For any primary model, if the primary fails but at least one
        fallback succeeds, the operation should succeed with a fallback result.
        """
        # Fixed names: the "fallback-" prefix never collides with the
        # model/provider/service names produced by model_names().
        fallback_models = ["fallback-1", "fallback-2", "fallback-3"]

        self._reset_registry()

        # Primary returns a FAILED response
        self._register_mock(primary_model, {"should_fail": True}, is_default=True)

        # Every fallback succeeds if it is called
        for fallback_model in fallback_models:
            self._register_mock(fallback_model, {"should_fail": False})

        response = await ProviderService.generate_with_fallback(
            primary_model=primary_model,
            fallback_models=fallback_models,
            operation="generate_image",
            prompt=prompt
        )

        assert response.status == TaskStatus.SUCCEEDED, \
            "Fallback should succeed when primary fails but fallback works"
        assert len(response.results) > 0, \
            "Successful fallback should return results"

        # The result must come from a fallback model, not the failed primary
        result_url = response.results[0].url
        assert primary_model not in result_url, \
            f"Result should not come from failed primary model {primary_model}"
        assert any(fb_model in result_url for fb_model in fallback_models), \
            f"Result should come from one of the fallback models {fallback_models}"

    @given(
        primary_model=model_names(),
        prompt=prompts()
    )
    @settings(max_examples=50, deadline=None)
    @pytest.mark.asyncio
    async def test_fallback_raises_exception_when_all_fail(
        self, primary_model, prompt
    ):
        """
        Property: When all providers fail, the system should raise
        ModelNotAvailableException.

        For any primary model and list of fallback models, if all fail,
        the operation should raise an exception.
        """
        # Fixed names that cannot collide with model_names() output
        fallback_models = ["fallback-fail-1", "fallback-fail-2"]

        self._reset_registry()

        # Every model, primary included, returns a FAILED response
        for model in [primary_model] + fallback_models:
            self._register_mock(
                model, {"should_fail": True}, is_default=(model == primary_model)
            )

        with pytest.raises(ModelNotAvailableException) as exc_info:
            await ProviderService.generate_with_fallback(
                primary_model=primary_model,
                fallback_models=fallback_models,
                operation="generate_image",
                prompt=prompt
            )

        exception_str = str(exc_info.value)
        assert "All providers failed" in exception_str or \
            "Model is not available" in exception_str, \
            "Exception should indicate all providers failed"

    @given(
        primary_model=model_names(),
        prompt=prompts(),
        success_index=st.integers(min_value=0, max_value=2)
    )
    @settings(max_examples=30, deadline=None, suppress_health_check=[HealthCheck.large_base_example])
    @pytest.mark.asyncio
    async def test_fallback_tries_models_in_order(
        self, primary_model, prompt, success_index
    ):
        """
        Property: Fallback should try models in the specified order and stop
        at the first successful one.

        For any chain of models, the system should try them in order and
        return the result from the first successful model.
        """
        fallback_models = ["fallback-order-1", "fallback-order-2", "fallback-order-3"]

        # Defensive clamp in case the strategy bounds and the chain drift apart
        success_index = min(success_index, len(fallback_models) - 1)

        self._reset_registry()

        # Primary returns a FAILED response
        self._register_mock(primary_model, {"should_fail": True}, is_default=True)

        # Fallbacks before success_index fail; the rest succeed
        for i, fallback_model in enumerate(fallback_models):
            self._register_mock(fallback_model, {"should_fail": i < success_index})

        response = await ProviderService.generate_with_fallback(
            primary_model=primary_model,
            fallback_models=fallback_models,
            operation="generate_image",
            prompt=prompt
        )

        assert response.status == TaskStatus.SUCCEEDED

        # The result must come from the first succeeding model in the chain
        expected_model = fallback_models[success_index]
        result_url = response.results[0].url
        assert expected_model in result_url, \
            f"Result should come from model at index {success_index}: {expected_model}"

        # Models after success_index should not have been tried.
        # NOTE(review): only meaningful if ModelRegistry.get returns the same
        # instance generate_with_fallback used; if it builds a fresh one,
        # call_count is trivially 0 - confirm against ModelRegistry.
        for i in range(success_index + 1, len(fallback_models)):
            later_model = fallback_models[i]
            service = ModelRegistry.get(later_model)
            if service and hasattr(service, 'call_count'):
                assert service.call_count == 0, \
                    f"Model {later_model} at index {i} should not be called after success at {success_index}"

    @given(
        primary_model=model_names(),
        prompt=prompts()
    )
    @settings(max_examples=30, deadline=None)
    @pytest.mark.asyncio
    async def test_no_fallback_when_primary_succeeds(
        self, primary_model, prompt
    ):
        """
        Property: When the primary provider succeeds, fallback models should
        not be tried.

        For any primary model that succeeds, no fallback should occur.
        """
        self._reset_registry()

        # Primary succeeds
        self._register_mock(primary_model, {"should_fail": False}, is_default=True)

        # Fallbacks would also succeed, so a wrongly-used fallback is detectable
        # only by its URL and call_count
        fallback_models = ["fallback-1", "fallback-2"]
        for fallback_model in fallback_models:
            self._register_mock(fallback_model, {"should_fail": False})

        response = await ProviderService.generate_with_fallback(
            primary_model=primary_model,
            fallback_models=fallback_models,
            operation="generate_image",
            prompt=prompt
        )

        assert response.status == TaskStatus.SUCCEEDED

        result_url = response.results[0].url
        assert primary_model in result_url, \
            f"Result should come from primary model {primary_model}"

        # NOTE(review): same caveat as in test_fallback_tries_models_in_order -
        # relies on ModelRegistry.get returning the instance that was used.
        for fallback_model in fallback_models:
            service = ModelRegistry.get(fallback_model)
            if service and hasattr(service, 'call_count'):
                assert service.call_count == 0, \
                    f"Fallback model {fallback_model} should not be called when primary succeeds"

    @given(
        primary_model=model_names(),
        fallback_models=fallback_chain(min_size=1, max_size=3),
        prompt=prompts()
    )
    @settings(max_examples=30, deadline=None)
    @pytest.mark.asyncio
    async def test_fallback_skips_unhealthy_models(
        self, primary_model, fallback_models, prompt
    ):
        """
        Property: Fallback should skip models marked as unhealthy and try
        the next available model.

        For any chain of models where some are unhealthy, the system should
        skip unhealthy ones and use the first healthy model.
        """
        # primary_model and fallback_models are drawn independently, and
        # Hypothesis readily produces the same small name (e.g. "model-1") for
        # both. A colliding fallback would re-register the primary's id with a
        # succeeding provider, making the assertions below flaky or meaningless.
        assume(primary_model not in fallback_models)

        self._reset_registry()

        # Primary returns a FAILED response
        self._register_mock(primary_model, {"should_fail": True}, is_default=True)

        # Every fallback succeeds if it is called
        for fallback_model in fallback_models:
            self._register_mock(fallback_model, {"should_fail": False})

        # With more than one fallback, mark the first as unhealthy
        if len(fallback_models) > 1:
            first_fallback = fallback_models[0]
            # Several reports, to be sure the monitor reaches UNHEALTHY
            self._mark_unhealthy(first_fallback, reports=5)

            health = health_monitor.get_health(first_fallback)
            assert health is not None and health.status == HealthStatus.UNHEALTHY

        response = await ProviderService.generate_with_fallback(
            primary_model=primary_model,
            fallback_models=fallback_models,
            operation="generate_image",
            prompt=prompt
        )

        assert response.status == TaskStatus.SUCCEEDED

        # With an unhealthy first fallback, the result must come from a later one
        if len(fallback_models) > 1:
            result_url = response.results[0].url
            first_fallback = fallback_models[0]
            assert not result_url.endswith(f"/{first_fallback}.jpg"), \
                f"Result should not come from unhealthy model {first_fallback}"

            healthy_fallbacks = fallback_models[1:]
            assert any(result_url.endswith(f"/{fb}.jpg") for fb in healthy_fallbacks), \
                f"Result should come from one of the healthy fallbacks {healthy_fallbacks}"

    @given(
        primary_model=model_names(),
        prompt=prompts()
    )
    @settings(max_examples=30, deadline=None)
    @pytest.mark.asyncio
    async def test_fallback_handles_exceptions(
        self, primary_model, prompt
    ):
        """
        Property: Fallback should handle exceptions from providers and
        continue to the next provider.

        For any provider that raises an exception, the system should catch it
        and try the next provider in the chain.
        """
        fallback_models = ["fallback-exc-1", "fallback-exc-2"]

        self._reset_registry()

        # Primary raises instead of returning a FAILED response
        self._register_mock(primary_model, {"fail_with_exception": True}, is_default=True)

        # Every fallback succeeds if it is called
        for fallback_model in fallback_models:
            self._register_mock(fallback_model, {"should_fail": False})

        # Should absorb the primary's exception and succeed via a fallback
        response = await ProviderService.generate_with_fallback(
            primary_model=primary_model,
            fallback_models=fallback_models,
            operation="generate_image",
            prompt=prompt
        )

        assert response.status == TaskStatus.SUCCEEDED, \
            "Fallback should succeed even when primary raises exception"

        result_url = response.results[0].url
        assert primary_model not in result_url, \
            f"Result should not come from failed primary {primary_model}"
        assert any(fb in result_url for fb in fallback_models), \
            f"Result should come from one of the fallbacks {fallback_models}"
|
|
|
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of exiting; propagate it so a
    # direct run (e.g. from a CI script) fails when the tests fail.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
|