"""Unit tests for TraderAgent.

This module tests the TraderAgent class including market analysis,
signal generation, and trade execution, as well as paper trading,
trade history, performance statistics and the supporting data types.
"""

import asyncio
from unittest.mock import patch

import pytest

from openclaw.agents.base import ActivityType
from openclaw.agents.trader import (
    MarketAnalysis,
    SignalType,
    TradeResult,
    TradeSignal,
    TraderAgent,
)
from openclaw.core.economy import SurvivalStatus
class TestTraderAgentInitialization:
    """Construction of TraderAgent: defaults, explicit overrides and base class."""

    def test_default_initialization(self):
        """A trader built from only an id and capital uses the default settings."""
        trader = TraderAgent(agent_id="trader-1", initial_capital=10000.0)

        expected = {
            "agent_id": "trader-1",
            "balance": 10000.0,
            "skill_level": 0.5,
            "max_position_pct": 0.2,
        }
        for attr_name, attr_value in expected.items():
            assert getattr(trader, attr_name) == attr_value

        # Both the real and the paper ledgers start out empty.
        assert trader._trade_history == []
        assert trader._paper_trade_history == []

    def test_custom_initialization(self):
        """Explicit keyword arguments replace each default value."""
        overrides = dict(
            agent_id="trader-2",
            initial_capital=5000.0,
            skill_level=0.8,
            max_position_pct=0.3,
        )
        trader = TraderAgent(**overrides)

        assert trader.agent_id == overrides["agent_id"]
        assert trader.balance == overrides["initial_capital"]
        assert trader.skill_level == overrides["skill_level"]
        assert trader.max_position_pct == overrides["max_position_pct"]

    def test_inherits_from_base_agent(self):
        """Every TraderAgent instance is also a BaseAgent instance."""
        from openclaw.agents.base import BaseAgent

        trader = TraderAgent(agent_id="test", initial_capital=10000.0)
        assert isinstance(trader, BaseAgent)
class TestDecideActivity:
    """Test decide_activity method.

    ``decide_activity`` is a coroutine, so each call is driven with
    ``asyncio.run``. Its choice is randomized, so the preference tests sample
    it repeatedly; the autouse ``_seed_random`` fixture seeds the global
    ``random`` module so those samples are reproducible instead of flaky.
    """

    @pytest.fixture(autouse=True)
    def _seed_random(self):
        """Seed the global ``random`` module for one test, then restore its state.

        NOTE(review): this only pins the agent's choices if TraderAgent draws
        from the module-level ``random`` functions (other tests in this file
        patch ``random.random``, which suggests it does) -- confirm.
        """
        import random

        saved_state = random.getstate()
        random.seed(1234)
        yield
        # Restore the previous state so this seeding does not leak into other tests.
        random.setstate(saved_state)

    @pytest.fixture
    def agent(self):
        """Create a test agent."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_bankrupt_agent_only_rests(self, agent):
        """Test that bankrupt agent can only rest."""
        agent.economic_tracker.balance = 0  # Bankrupt

        result = asyncio.run(agent.decide_activity())

        assert result == ActivityType.REST

    def test_critical_status_prefers_learning(self, agent):
        """Test critical status leads to learning or paper trading."""
        agent.economic_tracker.balance = 3500.0  # Critical
        agent.state.skill_level = 0.5

        result = asyncio.run(agent.decide_activity())

        assert result in [ActivityType.LEARN, ActivityType.PAPER_TRADE]

    def test_thriving_status_prefers_trading(self, agent):
        """Test thriving status leads to trading."""
        agent.economic_tracker.balance = 20000.0  # Thriving

        # Sample several decisions because the choice is randomized.
        results = [asyncio.run(agent.decide_activity()) for _ in range(20)]

        # Most should be TRADE or ANALYZE.
        trade_like = [r for r in results if r in [ActivityType.TRADE, ActivityType.ANALYZE]]
        assert len(trade_like) >= 10  # At least half

    def test_struggling_status_more_paper_trading(self, agent):
        """Test struggling status prefers paper trading."""
        agent.economic_tracker.balance = 8500.0  # Struggling

        # Sample several decisions because the choice is randomized.
        results = [asyncio.run(agent.decide_activity()) for _ in range(20)]

        # At least a quarter of the samples should be paper trades.
        paper_trades = [r for r in results if r == ActivityType.PAPER_TRADE]
        assert len(paper_trades) >= 5
class TestAnalyzeMarket:
    """Test analyze_market method."""

    @pytest.fixture
    def agent(self):
        """Create a test agent."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_returns_market_analysis(self, agent):
        """Test that analyze_market returns a well-formed MarketAnalysis."""
        result = agent.analyze_market("AAPL")

        assert isinstance(result, MarketAnalysis)
        assert result.symbol == "AAPL"
        assert result.trend in ["uptrend", "downtrend", "sideways"]
        assert 0 <= result.volatility <= 1
        assert result.volume_trend in ["increasing", "decreasing"]
        assert result.support_level < result.resistance_level

    def test_indicators_present(self, agent):
        """Test that the expected technical indicators are present."""
        result = agent.analyze_market("TSLA")

        for indicator in ("rsi", "macd", "sma_20", "current_price"):
            assert indicator in result.indicators

    def test_high_skill_more_accurate(self):
        """Test that a high-skill agent produces consistent RSI readings."""
        high_skill_agent = TraderAgent(
            agent_id="high", initial_capital=10000.0, skill_level=0.9
        )

        # Repeated analyses of the same symbol should yield RSI in a tight range.
        rsis = [
            high_skill_agent.analyze_market("AAPL").indicators["rsi"]
            for _ in range(10)
        ]

        # RSIs should be within 30 points of each other (high skill = more accurate).
        assert max(rsis) - min(rsis) <= 30
class TestGenerateSignal:
    """Tests for ``TraderAgent.generate_signal`` on hand-built analyses."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent with 10,000 of starting capital."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    @staticmethod
    def _make_analysis(trend, volume_trend, rsi, macd):
        """Build an AAPL MarketAnalysis with fixed volatility, levels and price."""
        return MarketAnalysis(
            symbol="AAPL",
            trend=trend,
            volatility=0.2,
            volume_trend=volume_trend,
            support_level=90.0,
            resistance_level=110.0,
            indicators={"rsi": rsi, "macd": macd, "current_price": 100.0},
        )

    def test_oversold_generates_buy(self, agent):
        """An oversold reading (RSI 30, positive MACD) yields a confident BUY."""
        snapshot = self._make_analysis("downtrend", "increasing", rsi=30.0, macd=0.5)

        outcome = agent.generate_signal(snapshot)

        assert outcome.signal == SignalType.BUY
        assert outcome.confidence > 0.5
        explanation = outcome.reason
        assert "oversold" in explanation.lower() or "RSI" in explanation

    def test_overbought_generates_sell(self, agent):
        """An overbought reading (RSI 70, negative MACD) yields a confident SELL."""
        snapshot = self._make_analysis("uptrend", "increasing", rsi=70.0, macd=-0.5)

        outcome = agent.generate_signal(snapshot)

        assert outcome.signal == SignalType.SELL
        assert outcome.confidence > 0.5
        explanation = outcome.reason
        assert "overbought" in explanation.lower() or "RSI" in explanation

    def test_neutral_generates_hold(self, agent):
        """A neutral reading (RSI 50, flat MACD) yields HOLD with no position."""
        snapshot = self._make_analysis("sideways", "flat", rsi=50.0, macd=0.0)

        outcome = agent.generate_signal(snapshot)

        assert outcome.signal == SignalType.HOLD
        assert outcome.suggested_position == 0.0

    def test_suggested_position_based_on_confidence(self, agent):
        """An actionable signal suggests a positive, capped position size."""
        snapshot = self._make_analysis("uptrend", "increasing", rsi=30.0, macd=0.5)

        outcome = agent.generate_signal(snapshot)

        cap = agent.balance * agent.max_position_pct
        assert outcome.suggested_position > 0
        # Never more than max_position_pct of the current balance.
        assert outcome.suggested_position <= cap
class TestExecuteTrade:
    """Test execute_trade method.

    Trade outcomes are made deterministic by patching ``random.random``;
    the cases marked ``# Force win`` pin it to 0.1.
    """

    @pytest.fixture
    def agent(self):
        """Create a test agent."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_trade_success(self, agent):
        """Test successful trade execution."""
        with patch("random.random", return_value=0.1):  # Force win
            result = agent.execute_trade("AAPL", SignalType.BUY, 1000.0)

        assert isinstance(result, TradeResult)
        assert result.symbol == "AAPL"
        assert result.signal == SignalType.BUY
        assert result.success is True
        assert result.fee > 0
        # The executed trade must be recorded exactly once in the real-trade history.
        assert len(agent._trade_history) == 1

    def test_trade_records_in_history(self, agent):
        """Test that trade is recorded in history."""
        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        assert len(agent._trade_history) == 1
        assert agent._trade_history[0].symbol == "AAPL"

    def test_trade_updates_stats(self, agent):
        """Test that trade updates agent statistics."""
        initial_trades = agent.state.total_trades

        with patch("random.random", return_value=0.1):  # Force win
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        assert agent.state.total_trades == initial_trades + 1

    def test_trade_deducts_costs(self, agent):
        """Test that executing a trade changes the balance (fees and PnL are applied)."""
        initial_balance = agent.balance

        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        # The direction depends on PnL, so only assert that the balance moved.
        assert agent.balance != initial_balance

    def test_insufficient_funds_fails(self, agent):
        """Test that trade fails when insufficient funds."""
        # 50,000 requested against a 10,000 starting balance.
        result = agent.execute_trade("AAPL", SignalType.BUY, 50000.0)

        assert result.success is False
        # Lower-casing already covers "Insufficient", so one check suffices.
        assert "insufficient" in result.message.lower()
class TestPaperTrade:
    """Tests for ``TraderAgent.paper_trade`` (simulated trades)."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent with 10,000 of starting capital."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_paper_trade_returns_result(self, agent):
        """A paper trade reports a successful TradeResult for the given symbol."""
        outcome = agent.paper_trade("AAPL", SignalType.BUY, 1000.0)

        assert isinstance(outcome, TradeResult)
        assert outcome.symbol == "AAPL"
        assert outcome.success is True

    def test_paper_trade_records_in_separate_history(self, agent):
        """Paper trades go to their own ledger and never touch the real one."""
        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        paper_count = len(agent._paper_trade_history)
        real_count = len(agent._trade_history)
        assert paper_count == 1
        assert real_count == 0

    def test_paper_trade_minimal_cost(self, agent):
        """A paper trade costs far less than a real one (under 0.1 of balance)."""
        starting = agent.balance

        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        spent = starting - agent.balance
        assert spent < 0.1
class TestAnalyze:
    """Tests for the ``TraderAgent.analyze`` coroutine."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent with 10,000 of starting capital."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_analyze_returns_dict(self, agent):
        """The report is a dict keyed by symbol, signal details and cost."""
        report = asyncio.run(agent.analyze("AAPL"))

        assert isinstance(report, dict)
        assert report["symbol"] == "AAPL"
        for required in ("signal", "confidence", "reason", "market_analysis", "cost"):
            assert required in report

    def test_analyze_deducts_cost(self, agent):
        """Running an analysis lowers the agent's balance."""
        before = agent.balance

        asyncio.run(agent.analyze("AAPL"))

        after = agent.balance
        assert after < before

    def test_analyze_stores_last_analysis(self, agent):
        """The most recent analysis is kept on the agent."""
        assert agent._last_analysis is None

        asyncio.run(agent.analyze("TSLA"))

        stored = agent._last_analysis
        assert stored is not None
        assert stored.symbol == "TSLA"
class TestTradeHistory:
    """The public history accessors must hand out copies, not the internal lists."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent with 10,000 of starting capital."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    @staticmethod
    def _assert_mutation_isolated(public_copy, internal):
        """Append to *public_copy* and check *internal* still holds one entry."""
        public_copy.append(None)
        assert len(internal) == 1

    def test_get_trade_history_returns_copy(self, agent):
        """Mutating get_trade_history()'s result leaves the real ledger intact."""
        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        self._assert_mutation_isolated(agent.get_trade_history(), agent._trade_history)

    def test_get_paper_trade_history_returns_copy(self, agent):
        """Mutating get_paper_trade_history()'s result leaves the paper ledger intact."""
        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        self._assert_mutation_isolated(
            agent.get_paper_trade_history(), agent._paper_trade_history
        )
class TestPerformanceStats:
    """Tests for ``TraderAgent.get_performance_stats``."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent with 10,000 of starting capital."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_stats_structure(self, agent):
        """The stats mapping exposes every expected key."""
        report = agent.get_performance_stats()

        required_keys = (
            "total_real_trades",
            "total_paper_trades",
            "real_pnl",
            "paper_pnl",
            "win_rate",
            "skill_level",
            "balance",
            "survival_status",
        )
        for key in required_keys:
            assert key in report

    def test_stats_with_trades(self, agent):
        """Counts and win rate reflect two won real trades and one paper trade."""
        with patch("random.random", return_value=0.1):  # pin both outcomes to wins
            for ticker, side in (("AAPL", SignalType.BUY), ("TSLA", SignalType.SELL)):
                agent.execute_trade(ticker, side, 1000.0)

        agent.paper_trade("NVDA", SignalType.BUY, 1000.0)

        report = agent.get_performance_stats()

        assert report["total_real_trades"] == 2
        assert report["total_paper_trades"] == 1
        assert report["win_rate"] == 1.0  # every real trade won
class TestSignalType:
    """Tests for the SignalType enumeration."""

    def test_signal_values(self):
        """Each SignalType member compares equal to its lowercase string value."""
        expected_pairs = (
            (SignalType.BUY, "buy"),
            (SignalType.SELL, "sell"),
            (SignalType.HOLD, "hold"),
        )
        for member, text in expected_pairs:
            assert member == text
class TestTradeSignal:
    """Tests for the TradeSignal dataclass."""

    def test_trade_signal_creation(self):
        """Every constructor argument is stored on the matching attribute."""
        field_values = {
            "symbol": "AAPL",
            "signal": SignalType.BUY,
            "confidence": 0.8,
            "reason": "RSI oversold",
            "suggested_position": 1000.0,
        }

        built = TradeSignal(**field_values)

        for field_name, value in field_values.items():
            assert getattr(built, field_name) == value
class TestMarketAnalysis:
    """Tests for the MarketAnalysis dataclass."""

    def test_market_analysis_creation(self):
        """Symbol and trend round-trip through the constructor."""
        snapshot = MarketAnalysis(
            symbol="AAPL",
            trend="uptrend",
            volatility=0.2,
            volume_trend="increasing",
            support_level=90.0,
            resistance_level=110.0,
            indicators={"rsi": 50.0},
        )

        stored_symbol, stored_trend = snapshot.symbol, snapshot.trend
        assert stored_symbol == "AAPL"
        assert stored_trend == "uptrend"
class TestTradeResult:
    """Tests for the TradeResult dataclass."""

    @staticmethod
    def _kwargs(**extra):
        """Return constructor arguments for a successful AAPL buy, plus *extra*."""
        params = {
            "symbol": "AAPL",
            "signal": SignalType.BUY,
            "success": True,
            "pnl": 100.0,
            "fee": 10.0,
            "message": "Success",
        }
        params.update(extra)
        return params

    def test_trade_result_creation(self):
        """Fields are stored and a timestamp is filled in when none is given."""
        outcome = TradeResult(**self._kwargs())

        assert outcome.symbol == "AAPL"
        assert outcome.success is True
        assert outcome.pnl == 100.0
        assert outcome.timestamp != ""  # populated automatically

    def test_trade_result_custom_timestamp(self):
        """An explicitly supplied timestamp is kept verbatim."""
        stamp = "2024-01-01T00:00:00"

        outcome = TradeResult(**self._kwargs(timestamp=stamp))

        assert outcome.timestamp == "2024-01-01T00:00:00"