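# stock/tests/unit/test_strategy_base.py
# ZhangPeng 9aecdd036c Initial commit: OpenClaw Trading - AI多智能体量化交易系统
#   (OpenClaw Trading - AI multi-agent quantitative trading system)
# - 添加项目核心代码和配置 (add core project code and configuration)
# - 添加前端界面 (Next.js) (add front-end UI (Next.js))
# - 添加单元测试 (add unit tests)
# - 更新 .gitignore 排除缓存和依赖 (update .gitignore to exclude caches and dependencies)
# 2026-02-27 03:47:40 +08:00
#
# 1128 lines
# 39 KiB
# Python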

"""Unit tests for strategy framework.
This module provides comprehensive tests for the strategy base classes,
registry, and factory components.
"""
from typing import Any, Dict, Optional
import pandas as pd
import pytest
from pydantic import ValidationError
from openclaw.strategy.base import (
Signal,
SignalType,
Strategy,
StrategyContext,
StrategyParameters,
)
from openclaw.strategy.buy import BuyParameters, BuyStrategy
from openclaw.strategy.factory import (
StrategyConfigurationError,
StrategyFactory,
StrategyFactoryError,
create_strategy,
create_strategy_from_config,
)
from openclaw.strategy.registry import (
StrategyNotFoundError,
StrategyRegistrationError,
clear_registry,
discover_strategies,
get_registered_strategies,
get_registry_stats,
get_strategy_class,
get_strategy_info,
get_strategies_by_tag,
is_strategy_registered,
register_strategy,
unregister_strategy,
)
from openclaw.strategy.sell import SellParameters, SellStrategy
from openclaw.strategy.select import SelectParameters, SelectResult, SelectStrategy
# =============================================================================
# Test Fixtures
# =============================================================================
@pytest.fixture
def sample_bar() -> pd.Series:
    """Provide one OHLCV bar that opens at 100.0 and closes at 102.0."""
    ohlcv = dict(
        open=100.0,
        high=105.0,
        low=99.0,
        close=102.0,
        volume=1000000,
    )
    return pd.Series(ohlcv)
@pytest.fixture
def sample_context() -> StrategyContext:
    """Provide an AAPL context holding 10,000 equity with no positions or trades."""
    starting_equity = 10000.0
    return StrategyContext(
        symbol="AAPL",
        equity=starting_equity,
        positions={},
        trades=[],
        equity_curve=[starting_equity],
        bar_index=0,
    )
@pytest.fixture
def strategy_factory() -> StrategyFactory:
    """Provide a fresh StrategyFactory instance to each test."""
    factory = StrategyFactory()
    return factory
# =============================================================================
# Signal Tests
# =============================================================================
class TestSignal:
    """Tests for the Signal class."""

    # ``pytest.raises(match=...)`` applies ``re.search``; the dots are escaped
    # so the pattern only accepts the literal "0.0" and "1.0" in the message.
    _CONFIDENCE_ERROR = r"Confidence must be between 0\.0 and 1\.0"

    def test_signal_creation(self) -> None:
        """Every explicitly passed field is stored on the signal."""
        signal = Signal(
            signal_type=SignalType.BUY,
            symbol="AAPL",
            price=100.0,
            quantity=10.0,
            confidence=0.8,
        )
        assert signal.signal_type == SignalType.BUY
        assert signal.symbol == "AAPL"
        assert signal.price == 100.0
        assert signal.quantity == 10.0
        assert signal.confidence == 0.8

    def test_signal_invalid_confidence_low(self) -> None:
        """A confidence below 0.0 is rejected with ValueError."""
        with pytest.raises(ValueError, match=self._CONFIDENCE_ERROR):
            Signal(signal_type=SignalType.BUY, symbol="AAPL", confidence=-0.1)

    def test_signal_invalid_confidence_high(self) -> None:
        """A confidence above 1.0 is rejected with ValueError."""
        with pytest.raises(ValueError, match=self._CONFIDENCE_ERROR):
            Signal(signal_type=SignalType.BUY, symbol="AAPL", confidence=1.1)

    def test_signal_default_values(self) -> None:
        """Omitted fields fall back to None / 0.5 / an empty metadata dict."""
        signal = Signal(signal_type=SignalType.SELL, symbol="MSFT")
        assert signal.price is None
        assert signal.quantity is None
        assert signal.confidence == 0.5
        assert signal.metadata == {}
# =============================================================================
# StrategyContext Tests
# =============================================================================
class TestStrategyContext:
    """Tests covering how StrategyContext is constructed."""

    def test_context_creation(self) -> None:
        """Constructor arguments are exposed unchanged as attributes."""
        holdings = {"AAPL": {"quantity": 100}}
        ctx = StrategyContext(
            symbol="AAPL",
            equity=10000.0,
            positions=holdings,
            bar_index=5,
        )
        assert ctx.symbol == "AAPL"
        assert ctx.equity == 10000.0
        assert ctx.positions == {"AAPL": {"quantity": 100}}
        assert ctx.bar_index == 5

    def test_context_defaults(self) -> None:
        """A context built with no arguments uses empty / zero defaults."""
        ctx = StrategyContext()
        assert ctx.symbol == ""
        assert ctx.equity == 0.0
        assert ctx.bar_index == 0
        for dict_field in ("positions", "market_data"):
            assert getattr(ctx, dict_field) == {}
        for list_field in ("trades", "equity_curve"):
            assert getattr(ctx, list_field) == []
# =============================================================================
# StrategyParameters Tests
# =============================================================================
class TestStrategyParameters:
    """Tests for the StrategyParameters base model."""

    def test_base_parameters_creation(self) -> None:
        """The base parameter model can be built without arguments."""
        assert StrategyParameters() is not None

    def test_base_parameters_forbid_extra(self) -> None:
        """Unknown keyword arguments fail validation."""
        unexpected = {"invalid_field": True}
        with pytest.raises(ValidationError):
            StrategyParameters(**unexpected)
class TestBuyParameters:
    """Tests for BuyParameters defaults and range validation."""

    def test_default_parameters(self) -> None:
        """A bare BuyParameters carries the expected default values."""
        params = BuyParameters()
        expected = {
            "max_position_size": 0.1,
            "min_confidence": 0.5,
            "max_hold_bars": 0,
            "entry_threshold": 0.0,
        }
        for field_name, value in expected.items():
            assert getattr(params, field_name) == value

    def test_custom_parameters(self) -> None:
        """Explicitly supplied values override every default."""
        overrides = {
            "max_position_size": 0.25,
            "min_confidence": 0.7,
            "max_hold_bars": 20,
            "entry_threshold": 0.05,
        }
        params = BuyParameters(**overrides)
        for field_name, value in overrides.items():
            assert getattr(params, field_name) == value

    def test_invalid_max_position_size(self) -> None:
        """A position size of zero or above 1.0 is rejected."""
        for bad_size in (0, 1.5):
            with pytest.raises(ValidationError):
                BuyParameters(max_position_size=bad_size)

    def test_invalid_min_confidence(self) -> None:
        """A minimum confidence outside [0, 1] is rejected."""
        for bad_confidence in (-0.1, 1.1):
            with pytest.raises(ValidationError):
                BuyParameters(min_confidence=bad_confidence)
class TestSellParameters:
    """Tests for SellParameters defaults and stop-loss validation."""

    def test_default_parameters(self) -> None:
        """A bare SellParameters carries the expected default values."""
        params = SellParameters()
        assert params.trailing_stop_pct is None
        for field_name, value in (
            ("stop_loss_pct", 0.05),
            ("take_profit_pct", 0.10),
            ("min_confidence", 0.5),
            ("exit_threshold", 0.0),
        ):
            assert getattr(params, field_name) == value

    def test_stop_loss_validation(self) -> None:
        """Stop-loss values are range-checked by the model."""
        # 0.5 must be accepted and stored unchanged. An earlier note called 50%
        # the cap; presumably it is the field's inclusive upper bound (TODO
        # confirm against SellParameters).
        accepted = SellParameters(stop_loss_pct=0.5)
        assert accepted.stop_loss_pct == 0.5

        for rejected in (-0.1, 1.5):
            with pytest.raises(ValidationError):
                SellParameters(stop_loss_pct=rejected)
class TestSelectParameters:
    """Tests for SelectParameters defaults and validation."""

    def test_default_parameters(self) -> None:
        """Defaults: 10 selections, zero minimum score, optional limits unset."""
        params = SelectParameters()
        assert params.max_selections == 10
        assert params.min_score == 0.0
        for optional_field in ("top_n", "filter_volume", "filter_price"):
            assert getattr(params, optional_field) is None

    def test_max_selections_validation(self) -> None:
        """Zero is rejected for both max_selections and top_n."""
        for field_name in ("max_selections", "top_n"):
            with pytest.raises(ValidationError):
                SelectParameters(**{field_name: 0})
# =============================================================================
# Concrete Strategy Implementations for Testing
# =============================================================================
class MockBuyStrategy(BuyStrategy):
    """Buy-strategy test double: its buy check always passes at confidence 0.8."""

    def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
        """Report a buy opportunity regardless of the inputs."""
        del data, context  # the decision ignores both arguments
        return True

    def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
        """Return the constant confidence 0.8."""
        del data, context
        return 0.8
class MockSellStrategy(SellStrategy):
    """Sell-strategy test double: its sell check always passes at confidence 0.7."""

    def _should_sell(self, data: pd.Series, context: StrategyContext, position: Any) -> bool:
        """Report an exit regardless of the inputs."""
        del data, context, position  # the decision ignores all arguments
        return True

    def _calculate_sell_confidence(
        self, data: pd.Series, context: StrategyContext, position: Any
    ) -> float:
        """Return the constant confidence 0.7."""
        del data, context, position
        return 0.7
class MockSelectStrategy(SelectStrategy):
    """Select-strategy test double that scores each symbol by its length."""

    def calculate_score(self, symbol: str, data: pd.DataFrame) -> float:
        """Score ``symbol`` by its character count; ``data`` is ignored."""
        del data
        return float(len(symbol))
# =============================================================================
# Strategy Base Tests
# =============================================================================
class TestStrategyBase:
    """Tests for the Strategy base-class lifecycle, exercised via MockBuyStrategy."""

    def test_strategy_initialization(self) -> None:
        """A freshly constructed strategy is neither initialized nor active."""
        strategy = MockBuyStrategy(name="test_strategy")
        assert strategy.name == "test_strategy"
        assert not strategy.is_initialized
        assert not strategy.is_active

    def test_strategy_initialize(self) -> None:
        """initialize() marks the strategy both initialized and active."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        assert strategy.is_initialized
        assert strategy.is_active

    def test_strategy_double_initialize(self) -> None:
        """A second initialize() call is tolerated and leaves the strategy usable."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        strategy.initialize()  # Should log a warning but not raise
        # Pin the observable outcome: the repeated call must not knock the
        # strategy out of its initialized/active state.
        assert strategy.is_initialized
        assert strategy.is_active

    def test_strategy_shutdown(self) -> None:
        """shutdown() deactivates the strategy but keeps it initialized."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        strategy.shutdown()
        assert strategy.is_initialized  # Still initialized
        assert not strategy.is_active  # But not active

    def test_strategy_context_manager(self) -> None:
        """Using the strategy as a context manager initializes and then deactivates it."""
        with MockBuyStrategy(name="test_strategy") as strategy:
            assert strategy.is_initialized
            assert strategy.is_active
        assert not strategy.is_active

    def test_strategy_process_bar_not_initialized(
        self, sample_bar: pd.Series, sample_context: StrategyContext
    ) -> None:
        """process_bar() before initialize() raises RuntimeError."""
        strategy = MockBuyStrategy(name="test_strategy")
        with pytest.raises(RuntimeError, match="not initialized"):
            strategy.process_bar(sample_bar, sample_context)

    def test_strategy_process_bar_not_active(
        self, sample_bar: pd.Series, sample_context: StrategyContext
    ) -> None:
        """process_bar() after shutdown() returns None instead of a signal."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        strategy.shutdown()
        result = strategy.process_bar(sample_bar, sample_context)
        assert result is None

    def test_strategy_signal_counting(
        self, sample_bar: pd.Series, sample_context: StrategyContext
    ) -> None:
        """Each processed bar that yields a signal increments signals_generated."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        assert strategy.signals_generated == 0
        strategy.process_bar(sample_bar, sample_context)
        assert strategy.signals_generated == 1
        strategy.process_bar(sample_bar, sample_context)
        assert strategy.signals_generated == 2

    def test_strategy_get_state(self) -> None:
        """get_state() reports name, description and initialization status."""
        strategy = MockBuyStrategy(name="test_strategy", description="Test strategy")
        state = strategy.get_state()
        assert state["name"] == "test_strategy"
        assert state["description"] == "Test strategy"
        assert not state["initialized"]

    def test_strategy_reset(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """reset() clears the signal counter."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        strategy.process_bar(sample_bar, sample_context)
        assert strategy.signals_generated == 1
        strategy.reset()
        assert strategy.signals_generated == 0
# =============================================================================
# BuyStrategy Tests
# =============================================================================
class TestBuyStrategy:
    """Tests for BuyStrategy behaviour, exercised via MockBuyStrategy."""

    def test_buy_strategy_creation(self) -> None:
        """Name, parameters and description passed at construction are kept."""
        sizing = BuyParameters(max_position_size=0.2)
        buyer = MockBuyStrategy(
            name="test_buy",
            parameters=sizing,
            description="Test buy strategy",
        )
        assert buyer.name == "test_buy"
        assert buyer.parameters.max_position_size == 0.2
        assert buyer.description == "Test buy strategy"

    def test_buy_signal_generation(
        self, sample_bar: pd.Series, sample_context: StrategyContext
    ) -> None:
        """An initialized buyer emits a BUY signal for the context's symbol."""
        buyer = MockBuyStrategy(name="test_buy")
        buyer.initialize()
        emitted = buyer.process_bar(sample_bar, sample_context)
        assert emitted is not None
        assert emitted.signal_type == SignalType.BUY
        assert emitted.symbol == "AAPL"
        assert emitted.confidence == 0.8

    def test_buy_position_size_calculation(
        self, sample_bar: pd.Series, sample_context: StrategyContext
    ) -> None:
        """Position size is equity * max_position_size / close price."""
        sizing = BuyParameters(max_position_size=0.1)
        buyer = MockBuyStrategy(name="test_buy", parameters=sizing)
        quantity = buyer._calculate_position_size(sample_bar, sample_context)
        # 10% of the 10,000 equity, bought at the 102.0 close.
        expected = (10000.0 * 0.1) / 102.0
        assert quantity == pytest.approx(expected, rel=1e-5)

    def test_buy_stats(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """get_buy_stats() counts buy signals as they are generated."""
        buyer = MockBuyStrategy(name="test_buy")
        buyer.initialize()
        before = buyer.get_buy_stats()
        assert before["buy_signals_generated"] == 0
        assert before["positions_entered"] == 0
        buyer.process_bar(sample_bar, sample_context)
        after = buyer.get_buy_stats()
        assert after["buy_signals_generated"] == 1
# =============================================================================
# SellStrategy Tests
# =============================================================================
class TestSellStrategy:
    """Tests for SellStrategy behaviour, exercised via MockSellStrategy."""

    @staticmethod
    def _position(**attributes: Any) -> Any:
        """Build an ad-hoc ``Position`` object exposing ``attributes`` as attributes."""
        return type("Position", (), attributes)()

    def test_sell_strategy_creation(self) -> None:
        """Name and parameters passed at construction are kept."""
        seller = MockSellStrategy(
            name="test_sell",
            parameters=SellParameters(stop_loss_pct=0.03),
        )
        assert seller.name == "test_sell"
        assert seller.parameters.stop_loss_pct == 0.03

    def test_sell_signal_generation_no_position(
        self, sample_bar: pd.Series, sample_context: StrategyContext
    ) -> None:
        """Without an open position no sell signal is produced."""
        seller = MockSellStrategy(name="test_sell")
        seller.initialize()
        # The fixture context holds no positions.
        assert seller.process_bar(sample_bar, sample_context) is None

    def test_sell_signal_generation_with_position(
        self, sample_bar: pd.Series, sample_context: StrategyContext
    ) -> None:
        """With an open AAPL position a SELL signal is produced."""
        seller = MockSellStrategy(name="test_sell")
        seller.initialize()
        sample_context.positions["AAPL"] = self._position(quantity=100, entry_price=100.0)
        emitted = seller.process_bar(sample_bar, sample_context)
        assert emitted is not None
        assert emitted.signal_type == SignalType.SELL

    def test_stop_loss_check(self, sample_bar: pd.Series) -> None:
        """A loss beyond the default 5% stop triggers the stop-loss check."""
        seller = MockSellStrategy(name="test_sell")
        # Entry 110, close 102: loss = (110 - 102) / 110 ~= 7.27% > 5%.
        assert seller._check_stop_loss(sample_bar, self._position(entry_price=110.0))

    def test_take_profit_check(self, sample_bar: pd.Series) -> None:
        """A gain beyond the default 10% target triggers the take-profit check."""
        seller = MockSellStrategy(name="test_sell")
        # Entry 90, close 102: gain = (102 - 90) / 90 ~= 13.3% > 10%.
        assert seller._check_take_profit(sample_bar, self._position(entry_price=90.0))

    def test_trailing_stop(self, sample_bar: pd.Series) -> None:
        """A pullback larger than the trail from the recorded high triggers a stop."""
        seller = MockSellStrategy(
            name="test_sell",
            parameters=SellParameters(trailing_stop_pct=0.02),
        )
        # Record a high-water mark of 110 for AAPL.
        seller._update_trailing_stop("AAPL", pd.Series({"high": 110.0}))
        assert seller._highest_price_seen["AAPL"] == 110.0
        # A low of 107 sits ~2.7% below the 110 peak, past the 2% trail.
        pullback_bar = pd.Series({"low": 107.0})
        assert seller._check_trailing_stop(pullback_bar, self._position(symbol="AAPL"))
# =============================================================================
# SelectStrategy Tests
# =============================================================================
class TestSelectStrategy:
    """Tests for SelectStrategy ranking and filtering, via MockSelectStrategy."""

    def test_select_strategy_creation(self) -> None:
        """Name and parameters passed at construction are kept."""
        strategy = MockSelectStrategy(
            name="test_select",
            parameters=SelectParameters(max_selections=5),
        )
        assert strategy.name == "test_select"
        assert strategy.parameters.max_selections == 5

    def test_select_result_creation(self) -> None:
        """SelectResult stores symbol, score, selection flag and rank."""
        result = SelectResult(symbol="AAPL", score=10.5, selected=True, rank=1)
        assert result.symbol == "AAPL"
        assert result.score == 10.5
        assert result.selected
        assert result.rank == 1

    def test_select_result_empty_symbol(self) -> None:
        """An empty symbol is rejected with ValueError."""
        with pytest.raises(ValueError, match="Symbol cannot be empty"):
            SelectResult(symbol="")

    def test_select_from_universe(self) -> None:
        """Every candidate is scored and results come back highest score first."""
        strategy = MockSelectStrategy(name="test_select")
        universe = {
            "A": pd.DataFrame({"close": [1, 2, 3]}),
            "BB": pd.DataFrame({"close": [1, 2, 3]}),
            "CCC": pd.DataFrame({"close": [1, 2, 3]}),
        }
        results = strategy.select(universe)
        assert len(results) == 3
        # MockSelectStrategy scores by symbol length, so longest first.
        assert results[0].symbol == "CCC"  # length 3
        assert results[1].symbol == "BB"  # length 2
        assert results[2].symbol == "A"  # length 1

    def test_select_with_filters(self) -> None:
        """Candidates failing filter_price are reported but not selected."""
        strategy = MockSelectStrategy(
            name="test_select",
            parameters=SelectParameters(filter_price=5.0),
        )
        universe = {
            "A": pd.DataFrame({"close": [1, 2, 3]}),  # closes below 5.0: filtered out
            "B": pd.DataFrame({"close": [10, 11, 12]}),  # closes above 5.0: included
        }
        results = strategy.select(universe)
        # Index by symbol so a missing entry fails with a clear assertion
        # rather than a bare StopIteration escaping from next().
        by_symbol = {r.symbol: r for r in results}
        assert "A" in by_symbol
        assert "B" in by_symbol
        assert not by_symbol["A"].selected  # Filtered out
        assert by_symbol["B"].selected  # Included

    def test_select_max_selections(self) -> None:
        """No more than max_selections candidates are marked selected."""
        strategy = MockSelectStrategy(
            name="test_select",
            parameters=SelectParameters(max_selections=2),
        )
        universe = {
            "A": pd.DataFrame({"close": [1]}),
            "BB": pd.DataFrame({"close": [1]}),
            "CCC": pd.DataFrame({"close": [1]}),
            "DDDD": pd.DataFrame({"close": [1]}),
        }
        results = strategy.select(universe)
        selected = [r for r in results if r.selected]
        assert len(selected) == 2

    def test_select_top_n(self) -> None:
        """top_n limits how many candidates are marked selected."""
        strategy = MockSelectStrategy(
            name="test_select",
            parameters=SelectParameters(top_n=2),
        )
        universe = {
            "A": pd.DataFrame({"close": [1]}),
            "BB": pd.DataFrame({"close": [1]}),
            "CCC": pd.DataFrame({"close": [1]}),
        }
        results = strategy.select(universe)
        selected = [r for r in results if r.selected]
        assert len(selected) == 2

    def test_get_top_selections(self) -> None:
        """get_top_selections() returns at most ``n`` results."""
        strategy = MockSelectStrategy(name="test_select")
        results = [
            SelectResult(symbol="A", score=10, selected=True),
            SelectResult(symbol="B", score=8, selected=True),
            SelectResult(symbol="C", score=6, selected=True),
        ]
        top = strategy.get_top_selections(results, n=2)
        assert len(top) == 2

    def test_select_stats(self) -> None:
        """Selection statistics count runs and average candidate counts."""
        strategy = MockSelectStrategy(name="test_select")
        universe = {
            "A": pd.DataFrame({"close": [1]}),
            "BB": pd.DataFrame({"close": [1]}),
        }
        strategy.select(universe)
        stats = strategy.get_selection_stats()
        assert stats["selections_made"] == 1
        assert stats["avg_candidates"] == 2.0
# =============================================================================
# Registry Tests
# =============================================================================
class TestRegistry:
    """Tests for the module-level strategy registry functions."""

    class _AlwaysBuy(BuyStrategy):
        """Shared body for the throwaway strategies registered below."""

        def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
            return True

        def _calculate_buy_confidence(
            self, data: pd.Series, context: StrategyContext
        ) -> float:
            return 0.5

    def setup_method(self) -> None:
        """Start every test from an empty registry."""
        clear_registry()

    def teardown_method(self) -> None:
        """Leave no registrations behind for later tests."""
        clear_registry()

    def test_register_strategy(self) -> None:
        """Decorating a class records it together with its metadata."""

        @register_strategy(
            name="test_strategy",
            description="A test strategy",
            tags=["test", "mock"],
        )
        class DecoratedStrategy(self._AlwaysBuy):
            pass

        assert "test_strategy" in get_registered_strategies()
        details = get_strategy_info("test_strategy")
        assert details["description"] == "A test strategy"
        assert "test" in details["tags"]

    def test_register_duplicate(self) -> None:
        """Registering a second class under a taken name is an error."""

        @register_strategy(name="dup_strategy")
        class FirstStrategy(self._AlwaysBuy):
            pass

        with pytest.raises(StrategyRegistrationError, match="already registered"):

            @register_strategy(name="dup_strategy")
            class SecondStrategy(self._AlwaysBuy):
                pass

    def test_register_non_strategy(self) -> None:
        """Only Strategy subclasses may be registered."""
        with pytest.raises(StrategyRegistrationError, match="must inherit from Strategy"):

            @register_strategy(name="invalid")
            class NotAStrategy:
                pass

    def test_unregister_strategy(self) -> None:
        """unregister_strategy() removes an entry and reports success."""

        @register_strategy(name="to_remove")
        class DisposableStrategy(self._AlwaysBuy):
            pass

        assert "to_remove" in get_registered_strategies()
        assert unregister_strategy("to_remove")
        assert "to_remove" not in get_registered_strategies()

    def test_unregister_not_found(self) -> None:
        """Unregistering an unknown name reports failure."""
        assert not unregister_strategy("non_existent")

    def test_get_strategy_class(self) -> None:
        """The registered class object is returned by name."""

        @register_strategy(name="my_strategy")
        class MyStrategy(self._AlwaysBuy):
            pass

        assert get_strategy_class("my_strategy").__name__ == "MyStrategy"

    def test_get_strategy_class_not_found(self) -> None:
        """Looking up an unknown name raises StrategyNotFoundError."""
        with pytest.raises(StrategyNotFoundError):
            get_strategy_class("non_existent")

    def test_is_strategy_registered(self) -> None:
        """is_strategy_registered() distinguishes known from unknown names."""

        @register_strategy(name="check_me")
        class CheckedStrategy(self._AlwaysBuy):
            pass

        assert is_strategy_registered("check_me")
        assert not is_strategy_registered("not_registered")

    def test_get_strategies_by_tag(self) -> None:
        """A strategy is listed under each of its tags."""

        @register_strategy(name="tagged_strategy", tags=["momentum", "trend"])
        class TaggedStrategy(self._AlwaysBuy):
            pass

        for tag in ("momentum", "trend"):
            assert "tagged_strategy" in get_strategies_by_tag(tag)

    def test_registry_stats(self) -> None:
        """Registry statistics reflect registered names and tags."""

        @register_strategy(name="stats_strategy", tags=["test"])
        class CountedStrategy(self._AlwaysBuy):
            pass

        summary = get_registry_stats()
        assert summary["total_strategies"] == 1
        assert "stats_strategy" in summary["strategy_names"]
        assert "test" in summary["unique_tags"]

    def test_clear_registry(self) -> None:
        """clear_registry() drops every registration."""

        @register_strategy(name="temp")
        class ShortLivedStrategy(self._AlwaysBuy):
            pass

        assert len(get_registered_strategies()) > 0
        clear_registry()
        assert len(get_registered_strategies()) == 0
# =============================================================================
# Factory Tests
# =============================================================================
class TestFactory:
    """Tests for StrategyFactory and the module-level factory helpers."""

    def setup_method(self) -> None:
        """Reset the registry, then register a buy strategy as "mock_buy"."""
        clear_registry()

        @register_strategy(name="mock_buy")
        class RegisteredBuyStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(
                self, data: pd.Series, context: StrategyContext
            ) -> float:
                return 0.8

    def teardown_method(self) -> None:
        """Remove the registrations made in setup_method."""
        clear_registry()

    def test_factory_create(self, strategy_factory: StrategyFactory) -> None:
        """A registered strategy is built by name with dict parameters applied."""
        created = strategy_factory.create(
            name="mock_buy",
            parameters={"max_position_size": 0.2},
        )
        assert created.name == "mock_buy"
        assert created.parameters.max_position_size == 0.2

    def test_factory_create_not_found(self, strategy_factory: StrategyFactory) -> None:
        """An unknown name is reported as StrategyFactoryError."""
        with pytest.raises(StrategyFactoryError, match="not found"):
            strategy_factory.create("non_existent")

    def test_factory_create_with_class(self, strategy_factory: StrategyFactory) -> None:
        """An explicit strategy_class bypasses the registry lookup."""
        created = strategy_factory.create(
            name="explicit_strategy",
            strategy_class=MockBuyStrategy,
            parameters={"max_position_size": 0.3},
        )
        assert created.name == "explicit_strategy"
        assert created.parameters.max_position_size == 0.3

    def test_factory_create_from_config(self, strategy_factory: StrategyFactory) -> None:
        """A config dict supplies name, parameters and description."""
        settings = {
            "name": "mock_buy",
            "parameters": {"max_position_size": 0.15},
            "description": "Created from config",
        }
        created = strategy_factory.create_from_config(settings)
        assert created.name == "mock_buy"
        assert created.parameters.max_position_size == 0.15
        assert created.description == "Created from config"

    def test_factory_invalid_config(self, strategy_factory: StrategyFactory) -> None:
        """A config without a name is rejected."""
        with pytest.raises(StrategyConfigurationError):
            strategy_factory.create_from_config({})  # Missing name

    def test_factory_create_buy_strategy(self, strategy_factory: StrategyFactory) -> None:
        """create_buy_strategy() returns a BuyStrategy instance."""
        created = strategy_factory.create_buy_strategy(
            name="mock_buy",
            parameters={"max_position_size": 0.2},
        )
        assert isinstance(created, BuyStrategy)

    def test_factory_create_buy_strategy_wrong_type(
        self, strategy_factory: StrategyFactory
    ) -> None:
        """create_buy_strategy() refuses a non-buy strategy class."""
        with pytest.raises(StrategyFactoryError, match="not a BuyStrategy"):
            strategy_factory.create_buy_strategy(
                name="test",
                strategy_class=MockSellStrategy,  # type: ignore
            )

    def test_factory_create_sell_strategy(self, strategy_factory: StrategyFactory) -> None:
        """create_sell_strategy() returns a SellStrategy instance."""
        created = strategy_factory.create_sell_strategy(
            name="test_sell",
            strategy_class=MockSellStrategy,
        )
        assert isinstance(created, SellStrategy)

    def test_factory_create_select_strategy(self, strategy_factory: StrategyFactory) -> None:
        """create_select_strategy() returns a SelectStrategy instance."""
        created = strategy_factory.create_select_strategy(
            name="test_select",
            strategy_class=MockSelectStrategy,
        )
        assert isinstance(created, SelectStrategy)

    def test_convenience_function_create_strategy(self) -> None:
        """The module-level create_strategy() helper builds a strategy."""
        created = create_strategy(
            name="test",
            strategy_class=MockBuyStrategy,
        )
        assert isinstance(created, BuyStrategy)

    def test_convenience_function_create_from_config(self) -> None:
        """create_strategy_from_config() resolves ``strategy_type`` via the registry."""
        settings = {
            "name": "test",
            "strategy_type": "mock_buy",
        }
        created = create_strategy_from_config(settings)
        assert created.name == "test"
# =============================================================================
# Edge Cases and Integration Tests
# =============================================================================
class TestEdgeCases:
    """Edge-case tests spanning signals, selection, parameters and context."""

    def test_signal_types(self) -> None:
        """A Signal can be built for every SignalType member."""
        for member in SignalType:
            built = Signal(signal_type=member, symbol="TEST")
            assert built.signal_type == member

    def test_empty_universe_selection(self) -> None:
        """Selecting from an empty universe yields an empty list."""
        assert MockSelectStrategy(name="test").select({}) == []

    def test_strategy_with_invalid_parameters(self) -> None:
        """A negative position size fails validation."""
        with pytest.raises(ValidationError):
            BuyParameters(max_position_size=-1)

    def test_context_with_custom_data(self) -> None:
        """Arbitrary custom_data is retrievable from the context."""
        payload = {"indicator_value": 42, "threshold": 0.5}
        context = StrategyContext(custom_data=payload)
        assert context.custom_data["indicator_value"] == 42

    def test_signal_metadata(self) -> None:
        """Metadata passed to a Signal is readable back by key."""
        details = {
            "indicator": "RSI",
            "value": 30.5,
            "threshold": 30.0,
        }
        signal = Signal(signal_type=SignalType.BUY, symbol="AAPL", metadata=details)
        assert signal.metadata["indicator"] == "RSI"
        assert signal.metadata["value"] == 30.5
# =============================================================================
# Integration Tests
# =============================================================================
class TestIntegration:
    """End-to-end tests combining the registry, the factory and the strategy lifecycle."""

    def setup_method(self) -> None:
        """Start each test from an empty registry."""
        clear_registry()

    def teardown_method(self) -> None:
        """Remove any strategies registered by the test."""
        clear_registry()

    def test_full_strategy_lifecycle(self) -> None:
        """Register, create via factory, initialize, process bars, then shut down."""

        @register_strategy(name="lifecycle_test")
        class LifecycleStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                # Buy only when the bar closes strictly above 100.
                return data.get("close", 0) > 100

            def _calculate_buy_confidence(
                self, data: pd.Series, context: StrategyContext
            ) -> float:
                return 0.75

        # Create via factory
        factory = StrategyFactory()
        strategy = factory.create("lifecycle_test", parameters={"max_position_size": 0.2})

        # Initialize
        strategy.initialize()
        assert strategy.is_initialized

        # Process bars
        context = StrategyContext(symbol="AAPL", equity=10000.0)
        bar1 = pd.Series({"open": 98, "high": 102, "low": 97, "close": 99, "volume": 1000})
        signal1 = strategy.process_bar(bar1, context)
        assert signal1 is None  # close 99 <= 100

        bar2 = pd.Series({"open": 101, "high": 105, "low": 100, "close": 102, "volume": 1500})
        signal2 = strategy.process_bar(bar2, context)
        assert signal2 is not None  # close 102 > 100
        assert signal2.signal_type == SignalType.BUY
        assert signal2.confidence == 0.75

        # Shutdown
        strategy.shutdown()
        assert not strategy.is_active

    def test_registry_factory_integration(self) -> None:
        """A registered strategy is discoverable by tag info and buildable by name."""

        @register_strategy(name="integration_test", tags=["integration"])
        class IntegrationStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(
                self, data: pd.Series, context: StrategyContext
            ) -> float:
                return 0.9

        # Verify registration
        assert is_strategy_registered("integration_test")

        # Get info
        info = get_strategy_info("integration_test")
        assert "integration" in info["tags"]

        # Create via factory
        factory = StrategyFactory()
        strategy = factory.create("integration_test")
        assert isinstance(strategy, BuyStrategy)
        assert strategy.name == "integration_test"

    def test_multiple_strategy_types(self) -> None:
        """One buy, one sell and one select strategy coexist in the registry."""
        # The local class names deliberately avoid the ``Test*`` prefix so they
        # do not shadow this module's TestBuyStrategy / TestSellStrategy /
        # TestSelectStrategy test classes inside this method.

        @register_strategy(name="buyer")
        class AlwaysBuyStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(
                self, data: pd.Series, context: StrategyContext
            ) -> float:
                return 0.8

        @register_strategy(name="seller")
        class AlwaysSellStrategy(SellStrategy):
            def _should_sell(
                self, data: pd.Series, context: StrategyContext, position: Any
            ) -> bool:
                return True

            def _calculate_sell_confidence(
                self, data: pd.Series, context: StrategyContext, position: Any
            ) -> float:
                return 0.7

        @register_strategy(name="selector")
        class SymbolLengthSelectStrategy(SelectStrategy):
            def calculate_score(self, symbol: str, data: pd.DataFrame) -> float:
                return float(len(symbol))

        factory = StrategyFactory()

        # Create each type
        buy_strategy = factory.create_buy_strategy("buyer")
        sell_strategy = factory.create_sell_strategy("seller")
        select_strategy = factory.create_select_strategy("selector")

        assert isinstance(buy_strategy, BuyStrategy)
        assert isinstance(sell_strategy, SellStrategy)
        assert isinstance(select_strategy, SelectStrategy)

        # Check registry
        stats = get_registry_stats()
        assert stats["total_strategies"] == 3