1128 lines
39 KiB
Python
1128 lines
39 KiB
Python
"""Unit tests for strategy framework.
|
|
|
|
This module provides comprehensive tests for the strategy base classes,
|
|
registry, and factory components.
|
|
"""
|
|
|
|
from typing import Any, Dict, Optional
|
|
|
|
import pandas as pd
|
|
import pytest
|
|
from pydantic import ValidationError
|
|
|
|
from openclaw.strategy.base import (
|
|
Signal,
|
|
SignalType,
|
|
Strategy,
|
|
StrategyContext,
|
|
StrategyParameters,
|
|
)
|
|
from openclaw.strategy.buy import BuyParameters, BuyStrategy
|
|
from openclaw.strategy.factory import (
|
|
StrategyConfigurationError,
|
|
StrategyFactory,
|
|
StrategyFactoryError,
|
|
create_strategy,
|
|
create_strategy_from_config,
|
|
)
|
|
from openclaw.strategy.registry import (
|
|
StrategyNotFoundError,
|
|
StrategyRegistrationError,
|
|
clear_registry,
|
|
discover_strategies,
|
|
get_registered_strategies,
|
|
get_registry_stats,
|
|
get_strategy_class,
|
|
get_strategy_info,
|
|
get_strategies_by_tag,
|
|
is_strategy_registered,
|
|
register_strategy,
|
|
unregister_strategy,
|
|
)
|
|
from openclaw.strategy.sell import SellParameters, SellStrategy
|
|
from openclaw.strategy.select import SelectParameters, SelectResult, SelectStrategy
|
|
|
|
|
|
# =============================================================================
|
|
# Test Fixtures
|
|
# =============================================================================
|
|
|
|
|
|
@pytest.fixture
def sample_bar() -> pd.Series:
    """Provide one OHLCV market bar as a pandas Series.

    The bar opens at 100.0, ranges between 99.0 and 105.0, closes at 102.0
    and carries a volume of 1,000,000.
    """
    ohlcv = {
        "open": 100.0,
        "high": 105.0,
        "low": 99.0,
        "close": 102.0,
        "volume": 1000000,
    }
    return pd.Series(ohlcv)
|
|
|
|
|
|
@pytest.fixture
def sample_context() -> StrategyContext:
    """Provide a fresh AAPL context with 10,000 of equity and no positions.

    The equity curve is seeded with the starting equity and the bar index
    starts at zero.
    """
    starting_equity = 10000.0
    return StrategyContext(
        symbol="AAPL",
        equity=starting_equity,
        positions={},
        trades=[],
        equity_curve=[starting_equity],
        bar_index=0,
    )
|
|
|
|
|
|
@pytest.fixture
def strategy_factory() -> StrategyFactory:
    """Provide a new ``StrategyFactory`` built with default arguments."""
    factory = StrategyFactory()
    return factory
|
|
|
|
|
|
# =============================================================================
|
|
# Signal Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestSignal:
    """Checks for building ``Signal`` objects and validating their confidence."""

    def test_signal_creation(self) -> None:
        """Every explicitly supplied field can be read back from the signal."""
        supplied = {
            "signal_type": SignalType.BUY,
            "symbol": "AAPL",
            "price": 100.0,
            "quantity": 10.0,
            "confidence": 0.8,
        }

        signal = Signal(**supplied)

        for field_name, expected in supplied.items():
            assert getattr(signal, field_name) == expected

    def test_signal_invalid_confidence_low(self) -> None:
        """A confidence below zero is rejected with a ValueError."""
        too_low = -0.1
        with pytest.raises(ValueError, match="Confidence must be between 0.0 and 1.0"):
            Signal(signal_type=SignalType.BUY, symbol="AAPL", confidence=too_low)

    def test_signal_invalid_confidence_high(self) -> None:
        """A confidence above one is rejected with a ValueError."""
        too_high = 1.1
        with pytest.raises(ValueError, match="Confidence must be between 0.0 and 1.0"):
            Signal(signal_type=SignalType.BUY, symbol="AAPL", confidence=too_high)

    def test_signal_default_values(self) -> None:
        """Omitted optional fields fall back to their documented defaults."""
        minimal = Signal(signal_type=SignalType.SELL, symbol="MSFT")

        # Price and quantity are optional and unset by default.
        assert minimal.price is None
        assert minimal.quantity is None
        # Confidence and metadata have concrete default values.
        assert minimal.confidence == 0.5
        assert minimal.metadata == {}
|
|
|
|
|
|
# =============================================================================
|
|
# StrategyContext Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestStrategyContext:
    """Checks for constructing ``StrategyContext`` with and without arguments."""

    def test_context_creation(self) -> None:
        """Explicitly supplied fields are stored on the context unchanged."""
        held = {"AAPL": {"quantity": 100}}

        context = StrategyContext(
            symbol="AAPL",
            equity=10000.0,
            positions=held,
            bar_index=5,
        )

        assert context.symbol == "AAPL"
        assert context.equity == 10000.0
        assert context.positions == {"AAPL": {"quantity": 100}}
        assert context.bar_index == 5

    def test_context_defaults(self) -> None:
        """A context built with no arguments exposes empty/zero defaults."""
        context = StrategyContext()

        expected_defaults = {
            "symbol": "",
            "equity": 0.0,
            "positions": {},
            "trades": [],
            "equity_curve": [],
            "bar_index": 0,
            "market_data": {},
        }
        for attribute, default in expected_defaults.items():
            assert getattr(context, attribute) == default
|
|
|
|
|
|
# =============================================================================
|
|
# StrategyParameters Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestStrategyParameters:
    """Checks for the base ``StrategyParameters`` model."""

    def test_base_parameters_creation(self) -> None:
        """The base parameter model can be built without any arguments."""
        assert StrategyParameters() is not None

    def test_base_parameters_forbid_extra(self) -> None:
        """Passing a field the model does not declare raises ValidationError."""
        unknown_fields = {"invalid_field": True}
        with pytest.raises(ValidationError):
            StrategyParameters(**unknown_fields)  # type: ignore
|
|
|
|
|
|
class TestBuyParameters:
    """Checks for ``BuyParameters`` defaults, overrides and validation."""

    def test_default_parameters(self) -> None:
        """A model built with no arguments carries the expected defaults."""
        defaults = BuyParameters()

        assert defaults.max_position_size == 0.1
        assert defaults.min_confidence == 0.5
        assert defaults.max_hold_bars == 0
        assert defaults.entry_threshold == 0.0

    def test_custom_parameters(self) -> None:
        """Explicitly supplied values replace the defaults."""
        overrides = {
            "max_position_size": 0.25,
            "min_confidence": 0.7,
            "max_hold_bars": 20,
            "entry_threshold": 0.05,
        }

        params = BuyParameters(**overrides)

        for field_name, expected in overrides.items():
            assert getattr(params, field_name) == expected

    def test_invalid_max_position_size(self) -> None:
        """Position sizes of 0 and 1.5 are both rejected."""
        for bad_size in (0, 1.5):
            with pytest.raises(ValidationError):
                BuyParameters(max_position_size=bad_size)

    def test_invalid_min_confidence(self) -> None:
        """Confidence floors of -0.1 and 1.1 are both rejected."""
        for bad_confidence in (-0.1, 1.1):
            with pytest.raises(ValidationError):
                BuyParameters(min_confidence=bad_confidence)
|
|
|
|
|
|
class TestSellParameters:
    """Checks for ``SellParameters`` defaults and stop-loss validation."""

    def test_default_parameters(self) -> None:
        """A model built with no arguments carries the expected defaults."""
        defaults = SellParameters()

        assert defaults.stop_loss_pct == 0.05
        assert defaults.take_profit_pct == 0.10
        # The trailing stop is opt-in, so it is unset by default.
        assert defaults.trailing_stop_pct is None
        assert defaults.min_confidence == 0.5
        assert defaults.exit_threshold == 0.0

    def test_stop_loss_validation(self) -> None:
        """A stop loss of 0.5 is accepted; -0.1 and 1.5 are rejected."""
        accepted = SellParameters(stop_loss_pct=0.5)
        assert accepted.stop_loss_pct == 0.5  # stored exactly as given

        for rejected_value in (-0.1, 1.5):
            with pytest.raises(ValidationError):
                SellParameters(stop_loss_pct=rejected_value)
|
|
|
|
|
|
class TestSelectParameters:
    """Checks for ``SelectParameters`` defaults and count validation."""

    def test_default_parameters(self) -> None:
        """A model built with no arguments carries the expected defaults."""
        defaults = SelectParameters()

        assert defaults.max_selections == 10
        assert defaults.min_score == 0.0
        # The remaining knobs are optional and unset by default.
        for optional_field in ("top_n", "filter_volume", "filter_price"):
            assert getattr(defaults, optional_field) is None

    def test_max_selections_validation(self) -> None:
        """Zero is rejected for both ``max_selections`` and ``top_n``."""
        for invalid_kwargs in ({"max_selections": 0}, {"top_n": 0}):
            with pytest.raises(ValidationError):
                SelectParameters(**invalid_kwargs)
|
|
|
|
|
|
# =============================================================================
|
|
# Concrete Strategy Implementations for Testing
|
|
# =============================================================================
|
|
|
|
|
|
class MockBuyStrategy(BuyStrategy):
    """Test double for ``BuyStrategy`` that wants to buy on every bar."""

    def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
        """Report a buy decision regardless of the bar or context."""
        wants_entry = True
        return wants_entry

    def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
        """Report a constant confidence of 0.8 regardless of the inputs."""
        fixed_confidence = 0.8
        return fixed_confidence
|
|
|
|
|
|
class MockSellStrategy(SellStrategy):
    """Test double for ``SellStrategy`` that wants to sell on every bar."""

    def _should_sell(self, data: pd.Series, context: StrategyContext, position: Any) -> bool:
        """Report a sell decision regardless of the bar, context or position."""
        wants_exit = True
        return wants_exit

    def _calculate_sell_confidence(self, data: pd.Series, context: StrategyContext, position: Any) -> float:
        """Report a constant confidence of 0.7 regardless of the inputs."""
        fixed_confidence = 0.7
        return fixed_confidence
|
|
|
|
|
|
class MockSelectStrategy(SelectStrategy):
    """Test double for ``SelectStrategy`` that scores symbols by name length."""

    def calculate_score(self, symbol: str, data: pd.DataFrame) -> float:
        """Score a symbol as the number of characters in its name.

        The market data is ignored, which makes the ranking fully
        predictable from the symbols alone.
        """
        name_length = len(symbol)
        return float(name_length)
|
|
|
|
|
|
# =============================================================================
|
|
# Strategy Base Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestStrategyBase:
    """Tests for the lifecycle and bookkeeping shared by all strategies.

    ``MockBuyStrategy`` is used as the concrete stand-in because it emits a
    signal on every processed bar.
    """

    def test_strategy_initialization(self) -> None:
        """A freshly constructed strategy is neither initialized nor active."""
        strategy = MockBuyStrategy(name="test_strategy")

        assert strategy.name == "test_strategy"
        assert not strategy.is_initialized
        assert not strategy.is_active

    def test_strategy_initialize(self) -> None:
        """``initialize()`` marks the strategy as initialized and active."""
        strategy = MockBuyStrategy(name="test_strategy")

        strategy.initialize()

        assert strategy.is_initialized
        assert strategy.is_active

    def test_strategy_double_initialize(self) -> None:
        """Calling ``initialize()`` twice must not raise or disturb the state."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        strategy.initialize()  # Should log warning but not fail

        # Pin the observable state so the test cannot pass vacuously: the
        # redundant call must leave the strategy initialized and active.
        assert strategy.is_initialized
        assert strategy.is_active

    def test_strategy_shutdown(self) -> None:
        """``shutdown()`` deactivates the strategy but keeps it initialized."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        strategy.shutdown()

        assert strategy.is_initialized  # Still initialized
        assert not strategy.is_active  # But not active

    def test_strategy_context_manager(self) -> None:
        """Entering the context initializes; leaving it deactivates."""
        with MockBuyStrategy(name="test_strategy") as strategy:
            assert strategy.is_initialized
            assert strategy.is_active

        assert not strategy.is_active

    def test_strategy_process_bar_not_initialized(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """Processing a bar before ``initialize()`` raises RuntimeError."""
        strategy = MockBuyStrategy(name="test_strategy")

        with pytest.raises(RuntimeError, match="not initialized"):
            strategy.process_bar(sample_bar, sample_context)

    def test_strategy_process_bar_not_active(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """A shut-down strategy returns ``None`` instead of a signal."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()
        strategy.shutdown()

        result = strategy.process_bar(sample_bar, sample_context)
        assert result is None

    def test_strategy_signal_counting(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """``signals_generated`` grows by one for each processed bar."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()

        assert strategy.signals_generated == 0

        strategy.process_bar(sample_bar, sample_context)
        assert strategy.signals_generated == 1

        strategy.process_bar(sample_bar, sample_context)
        assert strategy.signals_generated == 2

    def test_strategy_get_state(self) -> None:
        """``get_state()`` reports name, description and initialized flag."""
        strategy = MockBuyStrategy(name="test_strategy", description="Test strategy")
        state = strategy.get_state()

        assert state["name"] == "test_strategy"
        assert state["description"] == "Test strategy"
        assert not state["initialized"]

    def test_strategy_reset(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """``reset()`` clears the generated-signal counter back to zero."""
        strategy = MockBuyStrategy(name="test_strategy")
        strategy.initialize()

        strategy.process_bar(sample_bar, sample_context)
        assert strategy.signals_generated == 1

        strategy.reset()
        assert strategy.signals_generated == 0
|
|
|
|
|
|
# =============================================================================
|
|
# BuyStrategy Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestBuyStrategy:
    """Checks for buy-side behaviour, exercised through ``MockBuyStrategy``."""

    def test_buy_strategy_creation(self) -> None:
        """Name, parameters and description are stored on the strategy."""
        sizing = BuyParameters(max_position_size=0.2)

        strategy = MockBuyStrategy(
            name="test_buy",
            parameters=sizing,
            description="Test buy strategy",
        )

        assert strategy.name == "test_buy"
        assert strategy.parameters.max_position_size == 0.2
        assert strategy.description == "Test buy strategy"

    def test_buy_signal_generation(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """Processing a bar yields a BUY signal for the context's symbol."""
        strategy = MockBuyStrategy(name="test_buy")
        strategy.initialize()

        emitted = strategy.process_bar(sample_bar, sample_context)

        assert emitted is not None
        assert emitted.signal_type == SignalType.BUY
        assert emitted.symbol == "AAPL"
        # The mock always reports a confidence of 0.8.
        assert emitted.confidence == 0.8

    def test_buy_position_size_calculation(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """Position size equals the equity fraction divided by the bar's close."""
        strategy = MockBuyStrategy(
            name="test_buy",
            parameters=BuyParameters(max_position_size=0.1),
        )

        quantity = strategy._calculate_position_size(sample_bar, sample_context)

        # 10% of the fixture's 10,000 equity, spent at the fixture close of 102.
        budget = 10000.0 * 0.1
        expected = budget / 102.0
        assert quantity == pytest.approx(expected, rel=1e-5)

    def test_buy_stats(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """Buy statistics start at zero and count each generated signal."""
        strategy = MockBuyStrategy(name="test_buy")
        strategy.initialize()

        before = strategy.get_buy_stats()
        assert before["buy_signals_generated"] == 0
        assert before["positions_entered"] == 0

        strategy.process_bar(sample_bar, sample_context)

        after = strategy.get_buy_stats()
        assert after["buy_signals_generated"] == 1
|
|
|
|
|
|
# =============================================================================
|
|
# SellStrategy Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestSellStrategy:
    """Checks for sell-side behaviour, exercised through ``MockSellStrategy``."""

    def test_sell_strategy_creation(self) -> None:
        """Name and parameters are stored on the strategy."""
        exits = SellParameters(stop_loss_pct=0.03)

        strategy = MockSellStrategy(name="test_sell", parameters=exits)

        assert strategy.name == "test_sell"
        assert strategy.parameters.stop_loss_pct == 0.03

    def test_sell_signal_generation_no_position(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """With nothing held for the symbol, no sell signal is produced."""
        strategy = MockSellStrategy(name="test_sell")
        strategy.initialize()

        # The fixture context starts with an empty positions mapping.
        emitted = strategy.process_bar(sample_bar, sample_context)
        assert emitted is None

    def test_sell_signal_generation_with_position(self, sample_bar: pd.Series, sample_context: StrategyContext) -> None:
        """With a held position for the symbol, a SELL signal is produced."""
        strategy = MockSellStrategy(name="test_sell")
        strategy.initialize()

        # Fabricate a lightweight position object and attach it to the context.
        position_type = type("Position", (), {
            "quantity": 100,
            "entry_price": 100.0,
        })
        sample_context.positions["AAPL"] = position_type()

        emitted = strategy.process_bar(sample_bar, sample_context)

        assert emitted is not None
        assert emitted.signal_type == SignalType.SELL

    def test_stop_loss_check(self, sample_bar: pd.Series) -> None:
        """An entry at 110 against the fixture bar trips the stop loss."""
        strategy = MockSellStrategy(name="test_sell")
        position_type = type("Position", (), {"entry_price": 110.0})
        position = position_type()

        # The fixture bar closes at 102, about 7.27% below the 110 entry;
        # that exceeds the default 5% stop, so the check should fire.
        assert strategy._check_stop_loss(sample_bar, position)

    def test_take_profit_check(self, sample_bar: pd.Series) -> None:
        """An entry at 90 against the fixture bar trips the take profit."""
        strategy = MockSellStrategy(name="test_sell")
        position_type = type("Position", (), {"entry_price": 90.0})
        position = position_type()

        # The fixture bar closes at 102, about 13.3% above the 90 entry;
        # that exceeds the default 10% target, so the check should fire.
        assert strategy._check_take_profit(sample_bar, position)

    def test_trailing_stop(self, sample_bar: pd.Series) -> None:
        """A pullback larger than the trailing percentage trips the stop."""
        strategy = MockSellStrategy(
            name="test_sell",
            parameters=SellParameters(trailing_stop_pct=0.02),
        )

        # Record a new high-water mark for the symbol.
        peak_bar = pd.Series({"high": 110.0})
        strategy._update_trailing_stop("AAPL", peak_bar)
        assert strategy._highest_price_seen["AAPL"] == 110.0

        # A low of 107 is roughly a 2.7% pullback from 110, beyond the 2% trail.
        pullback_bar = pd.Series({"low": 107.0})
        position_type = type("Position", (), {"symbol": "AAPL"})
        position = position_type()

        assert strategy._check_trailing_stop(pullback_bar, position)
|
|
|
|
|
|
# =============================================================================
|
|
# SelectStrategy Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestSelectStrategy:
    """Checks for universe selection, exercised through ``MockSelectStrategy``.

    The mock scores each symbol by the length of its name, so longer
    symbols are expected to rank higher.
    """

    def test_select_strategy_creation(self) -> None:
        """Name and parameters are stored on the strategy."""
        limits = SelectParameters(max_selections=5)

        strategy = MockSelectStrategy(name="test_select", parameters=limits)

        assert strategy.name == "test_select"
        assert strategy.parameters.max_selections == 5

    def test_select_result_creation(self) -> None:
        """A ``SelectResult`` exposes the values it was built with."""
        outcome = SelectResult(symbol="AAPL", score=10.5, selected=True, rank=1)

        assert outcome.symbol == "AAPL"
        assert outcome.score == 10.5
        assert outcome.selected
        assert outcome.rank == 1

    def test_select_result_empty_symbol(self) -> None:
        """An empty symbol is rejected with a ValueError."""
        blank = ""
        with pytest.raises(ValueError, match="Symbol cannot be empty"):
            SelectResult(symbol=blank)

    def test_select_from_universe(self) -> None:
        """Results cover every symbol and are ordered by descending score."""
        strategy = MockSelectStrategy(name="test_select")

        universe = {
            ticker: pd.DataFrame({"close": [1, 2, 3]})
            for ticker in ("A", "BB", "CCC")
        }

        results = strategy.select(universe)

        assert len(results) == 3
        # Longest symbol first, because the mock scores by name length.
        assert results[0].symbol == "CCC"  # length 3
        assert results[1].symbol == "BB"  # length 2
        assert results[2].symbol == "A"  # length 1

    def test_select_with_filters(self) -> None:
        """Symbols priced below ``filter_price`` are reported but not selected."""
        strategy = MockSelectStrategy(
            name="test_select",
            parameters=SelectParameters(filter_price=5.0),
        )

        cheap_prices = pd.DataFrame({"close": [1, 2, 3]})  # price < 5, filtered out
        pricey_prices = pd.DataFrame({"close": [10, 11, 12]})  # price > 5, included
        universe = {"A": cheap_prices, "B": pricey_prices}

        results = strategy.select(universe)

        cheap = next(entry for entry in results if entry.symbol == "A")
        pricey = next(entry for entry in results if entry.symbol == "B")

        assert not cheap.selected  # Filtered out
        assert pricey.selected  # Included

    def test_select_max_selections(self) -> None:
        """No more than ``max_selections`` symbols are marked selected."""
        strategy = MockSelectStrategy(
            name="test_select",
            parameters=SelectParameters(max_selections=2),
        )

        universe = {
            ticker: pd.DataFrame({"close": [1]})
            for ticker in ("A", "BB", "CCC", "DDDD")
        }

        results = strategy.select(universe)
        chosen = [entry for entry in results if entry.selected]

        assert len(chosen) == 2

    def test_select_top_n(self) -> None:
        """With ``top_n=2`` exactly two of three symbols are marked selected."""
        strategy = MockSelectStrategy(
            name="test_select",
            parameters=SelectParameters(top_n=2),
        )

        universe = {
            ticker: pd.DataFrame({"close": [1]})
            for ticker in ("A", "BB", "CCC")
        }

        results = strategy.select(universe)
        chosen = [entry for entry in results if entry.selected]

        assert len(chosen) == 2

    def test_get_top_selections(self) -> None:
        """``get_top_selections`` returns at most ``n`` entries."""
        strategy = MockSelectStrategy(name="test_select")

        candidates = [
            SelectResult(symbol="A", score=10, selected=True),
            SelectResult(symbol="B", score=8, selected=True),
            SelectResult(symbol="C", score=6, selected=True),
        ]

        best = strategy.get_top_selections(candidates, n=2)
        assert len(best) == 2

    def test_select_stats(self) -> None:
        """One ``select`` call over two symbols is reflected in the stats."""
        strategy = MockSelectStrategy(name="test_select")

        universe = {
            ticker: pd.DataFrame({"close": [1]})
            for ticker in ("A", "BB")
        }

        strategy.select(universe)

        stats = strategy.get_selection_stats()
        assert stats["selections_made"] == 1
        assert stats["avg_candidates"] == 2.0
|
|
|
|
|
|
# =============================================================================
|
|
# Registry Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestRegistry:
    """Checks for registering, looking up and removing strategy classes.

    The registry is emptied before and after every test so registrations
    made by one test cannot leak into another.
    """

    def setup_method(self) -> None:
        """Start each test from an empty registry."""
        clear_registry()

    def teardown_method(self) -> None:
        """Leave an empty registry behind after each test."""
        clear_registry()

    def test_register_strategy(self) -> None:
        """Registration records the name, description and tags."""
        registration = register_strategy(
            name="test_strategy",
            description="A test strategy",
            tags=["test", "mock"],
        )

        # Minimal concrete buy strategy used only as a registration target.
        @registration
        class TestStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.5

        assert "test_strategy" in get_registered_strategies()

        recorded = get_strategy_info("test_strategy")
        assert recorded["description"] == "A test strategy"
        assert "test" in recorded["tags"]

    def test_register_duplicate(self) -> None:
        """Registering a second class under a taken name raises an error."""
        # First registration claims the name.
        @register_strategy(name="dup_strategy")
        class Strategy1(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.5

        # Second registration under the same name must be refused.
        with pytest.raises(StrategyRegistrationError, match="already registered"):
            @register_strategy(name="dup_strategy")
            class Strategy2(BuyStrategy):
                def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                    return True

                def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                    return 0.5

    def test_register_non_strategy(self) -> None:
        """A class that is not a ``Strategy`` subclass cannot be registered."""
        with pytest.raises(StrategyRegistrationError, match="must inherit from Strategy"):
            @register_strategy(name="invalid")
            class NotAStrategy:
                pass

    def test_unregister_strategy(self) -> None:
        """Unregistering a known name returns truthy and removes the entry."""
        @register_strategy(name="to_remove")
        class TempStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.5

        assert "to_remove" in get_registered_strategies()

        removed = unregister_strategy("to_remove")
        assert removed
        assert "to_remove" not in get_registered_strategies()

    def test_unregister_not_found(self) -> None:
        """Unregistering an unknown name returns a falsy result."""
        removed = unregister_strategy("non_existent")
        assert not removed

    def test_get_strategy_class(self) -> None:
        """Looking up a registered name returns the registered class."""
        @register_strategy(name="my_strategy")
        class MyStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.5

        looked_up = get_strategy_class("my_strategy")
        assert looked_up.__name__ == "MyStrategy"

    def test_get_strategy_class_not_found(self) -> None:
        """Looking up an unknown name raises ``StrategyNotFoundError``."""
        missing_name = "non_existent"
        with pytest.raises(StrategyNotFoundError):
            get_strategy_class(missing_name)

    def test_is_strategy_registered(self) -> None:
        """The membership check is true only for registered names."""
        @register_strategy(name="check_me")
        class CheckStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.5

        assert is_strategy_registered("check_me")
        assert not is_strategy_registered("not_registered")

    def test_get_strategies_by_tag(self) -> None:
        """A strategy is discoverable through each of its tags."""
        registration = register_strategy(name="tagged_strategy", tags=["momentum", "trend"])

        @registration
        class TaggedStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.5

        tagged_momentum = get_strategies_by_tag("momentum")
        assert "tagged_strategy" in tagged_momentum

        tagged_trend = get_strategies_by_tag("trend")
        assert "tagged_strategy" in tagged_trend

    def test_registry_stats(self) -> None:
        """Registry statistics reflect a single tagged registration."""
        registration = register_strategy(name="stats_strategy", tags=["test"])

        @registration
        class StatsStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.5

        summary = get_registry_stats()
        assert summary["total_strategies"] == 1
        assert "stats_strategy" in summary["strategy_names"]
        assert "test" in summary["unique_tags"]

    def test_clear_registry(self) -> None:
        """Clearing the registry removes every registered strategy."""
        @register_strategy(name="temp")
        class TempStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.5

        assert len(get_registered_strategies()) > 0
        clear_registry()
        assert len(get_registered_strategies()) == 0
|
|
|
|
|
|
# =============================================================================
|
|
# Factory Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestFactory:
    """Checks for ``StrategyFactory`` and the module-level creation helpers.

    Each test runs against a registry containing exactly one strategy,
    registered under the name ``mock_buy``.
    """

    def setup_method(self) -> None:
        """Reset the registry and register the ``mock_buy`` strategy."""
        clear_registry()

        # Always-buy strategy with a fixed 0.8 confidence, registered by name.
        @register_strategy(name="mock_buy")
        class RegisteredBuyStrategy(BuyStrategy):
            def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
                return True

            def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
                return 0.8

    def teardown_method(self) -> None:
        """Leave an empty registry behind after each test."""
        clear_registry()

    def test_factory_create(self, strategy_factory: StrategyFactory) -> None:
        """A registered name plus a parameter dict yields a configured strategy."""
        built = strategy_factory.create(
            name="mock_buy",
            parameters={"max_position_size": 0.2},
        )

        assert built.name == "mock_buy"
        assert built.parameters.max_position_size == 0.2

    def test_factory_create_not_found(self, strategy_factory: StrategyFactory) -> None:
        """Creating from an unregistered name raises ``StrategyFactoryError``."""
        unknown_name = "non_existent"
        with pytest.raises(StrategyFactoryError, match="not found"):
            strategy_factory.create(unknown_name)

    def test_factory_create_with_class(self, strategy_factory: StrategyFactory) -> None:
        """An explicit class can be instantiated under an arbitrary name."""
        built = strategy_factory.create(
            name="explicit_strategy",
            strategy_class=MockBuyStrategy,
            parameters={"max_position_size": 0.3},
        )

        assert built.name == "explicit_strategy"
        assert built.parameters.max_position_size == 0.3

    def test_factory_create_from_config(self, strategy_factory: StrategyFactory) -> None:
        """A config dict supplies the name, parameters and description."""
        config = {
            "name": "mock_buy",
            "parameters": {"max_position_size": 0.15},
            "description": "Created from config",
        }

        built = strategy_factory.create_from_config(config)

        assert built.name == "mock_buy"
        assert built.parameters.max_position_size == 0.15
        assert built.description == "Created from config"

    def test_factory_invalid_config(self, strategy_factory: StrategyFactory) -> None:
        """A config without a name raises ``StrategyConfigurationError``."""
        nameless_config = {}  # Missing name
        with pytest.raises(StrategyConfigurationError):
            strategy_factory.create_from_config(nameless_config)

    def test_factory_create_buy_strategy(self, strategy_factory: StrategyFactory) -> None:
        """The buy-specific creator returns a ``BuyStrategy`` instance."""
        built = strategy_factory.create_buy_strategy(
            name="mock_buy",
            parameters={"max_position_size": 0.2},
        )

        assert isinstance(built, BuyStrategy)

    def test_factory_create_buy_strategy_wrong_type(self, strategy_factory: StrategyFactory) -> None:
        """Passing a sell strategy class to the buy creator is rejected."""
        with pytest.raises(StrategyFactoryError, match="not a BuyStrategy"):
            strategy_factory.create_buy_strategy(
                name="test",
                strategy_class=MockSellStrategy,  # type: ignore
            )

    def test_factory_create_sell_strategy(self, strategy_factory: StrategyFactory) -> None:
        """The sell-specific creator returns a ``SellStrategy`` instance."""
        built = strategy_factory.create_sell_strategy(
            name="test_sell",
            strategy_class=MockSellStrategy,
        )

        assert isinstance(built, SellStrategy)

    def test_factory_create_select_strategy(self, strategy_factory: StrategyFactory) -> None:
        """The select-specific creator returns a ``SelectStrategy`` instance."""
        built = strategy_factory.create_select_strategy(
            name="test_select",
            strategy_class=MockSelectStrategy,
        )

        assert isinstance(built, SelectStrategy)

    def test_convenience_function_create_strategy(self) -> None:
        """The module-level ``create_strategy`` helper builds from a class."""
        built = create_strategy(
            name="test",
            strategy_class=MockBuyStrategy,
        )

        assert isinstance(built, BuyStrategy)

    def test_convenience_function_create_from_config(self) -> None:
        """The module-level config helper honours the configured name."""
        config = {
            "name": "test",
            "strategy_type": "mock_buy",
        }

        built = create_strategy_from_config(config)
        assert built.name == "test"
|
|
|
|
|
|
# =============================================================================
|
|
# Edge Cases and Integration Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestEdgeCases:
    """Boundary-condition checks for signals, contexts, parameters and selection."""

    def test_signal_types(self) -> None:
        """Each ``SignalType`` member is kept as-is on a ``Signal`` built from it."""
        for member in SignalType:
            built = Signal(signal_type=member, symbol="TEST")
            assert built.signal_type == member

    def test_empty_universe_selection(self) -> None:
        """Selecting from an empty universe yields an empty list."""
        selector = MockSelectStrategy(name="test")
        picked = selector.select({})
        assert picked == []

    def test_strategy_with_invalid_parameters(self) -> None:
        """A negative ``max_position_size`` is rejected with ``ValidationError``."""
        expect_rejection = pytest.raises(ValidationError)
        with expect_rejection:
            BuyParameters(max_position_size=-1)

    def test_context_with_custom_data(self) -> None:
        """Entries passed as ``custom_data`` can be read back from the context."""
        extras = {"indicator_value": 42, "threshold": 0.5}
        ctx = StrategyContext(custom_data=extras)

        assert ctx.custom_data["indicator_value"] == 42

    def test_signal_metadata(self) -> None:
        """Metadata supplied to a ``Signal`` can be read back by key."""
        details = {
            "indicator": "RSI",
            "value": 30.5,
            "threshold": 30.0,
        }
        buy_signal = Signal(signal_type=SignalType.BUY, symbol="AAPL", metadata=details)

        assert buy_signal.metadata["indicator"] == "RSI"
        assert buy_signal.metadata["value"] == 30.5
|
|
|
|
|
|
# =============================================================================
|
|
# Integration Tests
|
|
# =============================================================================
|
|
|
|
|
|
class TestIntegration:
|
|
"""Integration tests for strategy framework."""
|
|
|
|
def setup_method(self) -> None:
    """Empty the strategy registry so the test starts without prior registrations."""
    clear_registry()
|
|
|
|
def teardown_method(self) -> None:
    """Empty the strategy registry again so registrations made by the test do not linger."""
    clear_registry()
|
|
|
|
def test_full_strategy_lifecycle(self) -> None:
    """Drive one registered buy strategy from creation through shutdown.

    Steps: register a strategy class, build it by name via
    ``StrategyFactory``, initialize it, feed it one bar that must not
    produce a signal and one that must, then shut it down.
    """

    @register_strategy(name="lifecycle_test")
    class LifecycleStrategy(BuyStrategy):
        def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
            # Comparing a value pulled out of a Series yields ``numpy.bool_``;
            # coerce it so the declared ``bool`` return type actually holds.
            return bool(data.get("close", 0) > 100)

        def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
            return 0.75

    # Create via factory
    factory = StrategyFactory()
    strategy = factory.create("lifecycle_test", parameters={"max_position_size": 0.2})

    # Initialize
    strategy.initialize()
    assert strategy.is_initialized

    # Process bars
    context = StrategyContext(symbol="AAPL", equity=10000.0)

    bar1 = pd.Series({"open": 98, "high": 102, "low": 97, "close": 99, "volume": 1000})
    signal1 = strategy.process_bar(bar1, context)
    assert signal1 is None  # Close (99) is not above 100

    bar2 = pd.Series({"open": 101, "high": 105, "low": 100, "close": 102, "volume": 1500})
    signal2 = strategy.process_bar(bar2, context)
    assert signal2 is not None  # Close (102) is above 100
    assert signal2.signal_type == SignalType.BUY
    assert signal2.confidence == 0.75

    # Shutdown
    strategy.shutdown()
    assert not strategy.is_active
|
|
|
|
def test_registry_factory_integration(self) -> None:
    """A strategy registered with a tag can be looked up and built by its name."""

    @register_strategy(name="integration_test", tags=["integration"])
    class IntegrationStrategy(BuyStrategy):
        def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
            return True

        def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
            return 0.9

    registered_name = "integration_test"

    # The decorator should have put the class in the registry.
    assert is_strategy_registered(registered_name)

    # The registry info should carry the tag given at registration.
    strategy_info = get_strategy_info(registered_name)
    assert "integration" in strategy_info["tags"]

    # The factory should resolve the same name to an instance.
    built = StrategyFactory().create(registered_name)

    assert isinstance(built, BuyStrategy)
    assert built.name == "integration_test"
|
|
|
|
def test_multiple_strategy_types(self) -> None:
    """Buy, sell and select strategies can be registered and built side by side."""

    # One minimal implementation per strategy family.
    @register_strategy(name="buyer")
    class TestBuyStrategy(BuyStrategy):
        def _should_buy(self, data: pd.Series, context: StrategyContext) -> bool:
            return True

        def _calculate_buy_confidence(self, data: pd.Series, context: StrategyContext) -> float:
            return 0.8

    @register_strategy(name="seller")
    class TestSellStrategy(SellStrategy):
        def _should_sell(self, data: pd.Series, context: StrategyContext, position: Any) -> bool:
            return True

        def _calculate_sell_confidence(self, data: pd.Series, context: StrategyContext, position: Any) -> float:
            return 0.7

    @register_strategy(name="selector")
    class TestSelectStrategy(SelectStrategy):
        def calculate_score(self, symbol: str, data: pd.DataFrame) -> float:
            return float(len(symbol))

    builder = StrategyFactory()

    # Build every instance first, each through its dedicated factory method.
    buyer = builder.create_buy_strategy("buyer")
    seller = builder.create_sell_strategy("seller")
    selector = builder.create_select_strategy("selector")

    built_and_expected = (
        (buyer, BuyStrategy),
        (seller, SellStrategy),
        (selector, SelectStrategy),
    )
    for instance, expected_base in built_and_expected:
        assert isinstance(instance, expected_base)

    # The registry should report the three strategies registered above.
    registry_stats = get_registry_stats()
    assert registry_stats["total_strategies"] == 3
|