"""Unit tests for TraderAgent.

This module tests the TraderAgent class including market analysis,
signal generation, and trade execution.
"""

import asyncio
from unittest.mock import patch

import pytest

from openclaw.agents.base import ActivityType, BaseAgent
from openclaw.agents.trader import (
    MarketAnalysis,
    SignalType,
    TradeResult,
    TradeSignal,
    TraderAgent,
)
from openclaw.core.economy import SurvivalStatus


@pytest.fixture
def agent():
    """Create a fresh test agent with 10000.0 of initial capital.

    Shared by every test class below that needs a default agent; each test
    gets its own instance (function-scoped fixture).
    """
    return TraderAgent(agent_id="test", initial_capital=10000.0)


class TestTraderAgentInitialization:
    """Test TraderAgent initialization."""

    def test_default_initialization(self):
        """Test agent with default parameters."""
        agent = TraderAgent(agent_id="trader-1", initial_capital=10000.0)

        assert agent.agent_id == "trader-1"
        assert agent.balance == 10000.0
        assert agent.skill_level == 0.5
        assert agent.max_position_pct == 0.2
        assert agent._trade_history == []
        assert agent._paper_trade_history == []

    def test_custom_initialization(self):
        """Test agent with custom parameters."""
        agent = TraderAgent(
            agent_id="trader-2",
            initial_capital=5000.0,
            skill_level=0.8,
            max_position_pct=0.3,
        )

        assert agent.agent_id == "trader-2"
        assert agent.balance == 5000.0
        assert agent.skill_level == 0.8
        assert agent.max_position_pct == 0.3

    def test_inherits_from_base_agent(self):
        """Test that TraderAgent inherits from BaseAgent."""
        agent = TraderAgent(agent_id="test", initial_capital=10000.0)
        assert isinstance(agent, BaseAgent)


class TestDecideActivity:
    """Test decide_activity method."""

    def test_bankrupt_agent_only_rests(self, agent):
        """Test that bankrupt agent can only rest."""
        agent.economic_tracker.balance = 0  # Bankrupt
        result = asyncio.run(agent.decide_activity())
        assert result == ActivityType.REST

    def test_critical_status_prefers_learning(self, agent):
        """Test critical status leads to learning."""
        agent.economic_tracker.balance = 3500.0  # Critical
        agent.state.skill_level = 0.5
        result = asyncio.run(agent.decide_activity())
        assert result in [ActivityType.LEARN, ActivityType.PAPER_TRADE]

    def test_thriving_status_prefers_trading(self, agent):
        """Test thriving status leads to trading."""
        agent.economic_tracker.balance = 20000.0  # Thriving
        # Run multiple times to account for randomness.
        # NOTE(review): this is a statistical threshold over 20 unseeded draws,
        # so it can flake if decide_activity is random — consider seeding.
        results = [asyncio.run(agent.decide_activity()) for _ in range(20)]
        # Most should be TRADE or ANALYZE
        trade_like = [
            r for r in results if r in [ActivityType.TRADE, ActivityType.ANALYZE]
        ]
        assert len(trade_like) >= 10  # At least half

    def test_struggling_status_more_paper_trading(self, agent):
        """Test struggling status prefers paper trading."""
        agent.economic_tracker.balance = 8500.0  # Struggling
        # Run multiple times (same statistical caveat as the thriving test).
        results = [asyncio.run(agent.decide_activity()) for _ in range(20)]
        # Some should be paper trade
        paper_trades = [r for r in results if r == ActivityType.PAPER_TRADE]
        assert len(paper_trades) >= 5  # At least some paper trades


class TestAnalyzeMarket:
    """Test analyze_market method."""

    def test_returns_market_analysis(self, agent):
        """Test that analyze_market returns MarketAnalysis."""
        result = agent.analyze_market("AAPL")

        assert isinstance(result, MarketAnalysis)
        assert result.symbol == "AAPL"
        assert result.trend in ["uptrend", "downtrend", "sideways"]
        assert 0 <= result.volatility <= 1
        assert result.volume_trend in ["increasing", "decreasing"]
        assert result.support_level < result.resistance_level

    def test_indicators_present(self, agent):
        """Test that technical indicators are present."""
        result = agent.analyze_market("TSLA")

        assert "rsi" in result.indicators
        assert "macd" in result.indicators
        assert "sma_20" in result.indicators
        assert "current_price" in result.indicators

    def test_high_skill_more_accurate(self):
        """Test that high skill produces more consistent analysis."""
        high_skill_agent = TraderAgent(
            agent_id="high", initial_capital=10000.0, skill_level=0.9
        )

        # Multiple analyses should have RSI in a tight range.
        rsis = [
            high_skill_agent.analyze_market("AAPL").indicators["rsi"]
            for _ in range(10)
        ]

        # RSIs should be within 30 points (high skill = more accurate)
        assert max(rsis) - min(rsis) <= 30


class TestGenerateSignal:
    """Test generate_signal method."""

    def test_oversold_generates_buy(self, agent):
        """Test oversold condition generates buy signal."""
        analysis = MarketAnalysis(
            symbol="AAPL",
            trend="downtrend",
            volatility=0.2,
            volume_trend="increasing",
            support_level=90.0,
            resistance_level=110.0,
            indicators={"rsi": 30.0, "macd": 0.5, "current_price": 100.0},
        )

        signal = agent.generate_signal(analysis)

        assert signal.signal == SignalType.BUY
        assert signal.confidence > 0.5
        assert "oversold" in signal.reason.lower() or "RSI" in signal.reason

    def test_overbought_generates_sell(self, agent):
        """Test overbought condition generates sell signal."""
        analysis = MarketAnalysis(
            symbol="AAPL",
            trend="uptrend",
            volatility=0.2,
            volume_trend="increasing",
            support_level=90.0,
            resistance_level=110.0,
            indicators={"rsi": 70.0, "macd": -0.5, "current_price": 100.0},
        )

        signal = agent.generate_signal(analysis)

        assert signal.signal == SignalType.SELL
        assert signal.confidence > 0.5
        assert "overbought" in signal.reason.lower() or "RSI" in signal.reason

    def test_neutral_generates_hold(self, agent):
        """Test neutral condition generates hold signal."""
        analysis = MarketAnalysis(
            symbol="AAPL",
            trend="sideways",
            volatility=0.2,
            volume_trend="flat",
            support_level=90.0,
            resistance_level=110.0,
            indicators={"rsi": 50.0, "macd": 0.0, "current_price": 100.0},
        )

        signal = agent.generate_signal(analysis)

        assert signal.signal == SignalType.HOLD
        assert signal.suggested_position == 0.0

    def test_suggested_position_based_on_confidence(self, agent):
        """Test that position size is based on confidence."""
        analysis = MarketAnalysis(
            symbol="AAPL",
            trend="uptrend",
            volatility=0.2,
            volume_trend="increasing",
            support_level=90.0,
            resistance_level=110.0,
            indicators={"rsi": 30.0, "macd": 0.5, "current_price": 100.0},
        )

        signal = agent.generate_signal(analysis)

        assert signal.suggested_position > 0
        # Should be less than max_position_pct of balance
        assert signal.suggested_position <= agent.balance * agent.max_position_pct


class TestExecuteTrade:
    """Test execute_trade method."""

    def test_trade_success(self, agent):
        """Test successful trade execution."""
        with patch("random.random", return_value=0.1):  # Force win
            result = agent.execute_trade("AAPL", SignalType.BUY, 1000.0)

        assert isinstance(result, TradeResult)
        assert result.symbol == "AAPL"
        assert result.signal == SignalType.BUY
        assert result.success is True
        assert result.fee > 0
        # The fixture starts with an empty history, so exactly one trade
        # must have been recorded.
        assert len(agent._trade_history) == 1

    def test_trade_records_in_history(self, agent):
        """Test that trade is recorded in history."""
        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        assert len(agent._trade_history) == 1
        assert agent._trade_history[0].symbol == "AAPL"

    def test_trade_updates_stats(self, agent):
        """Test that trade updates agent statistics."""
        initial_trades = agent.state.total_trades

        with patch("random.random", return_value=0.1):  # Force win
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        assert agent.state.total_trades == initial_trades + 1

    def test_trade_deducts_costs(self, agent):
        """Test that trade deducts costs from balance."""
        initial_balance = agent.balance

        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        # Balance should change due to fees/PnL. Only inequality is checked
        # because a profitable trade could outweigh the fee.
        assert agent.balance != initial_balance

    def test_insufficient_funds_fails(self, agent):
        """Test that trade fails when insufficient funds."""
        result = agent.execute_trade("AAPL", SignalType.BUY, 50000.0)

        assert result.success is False
        assert "insufficient" in result.message.lower()


class TestPaperTrade:
    """Test paper_trade method."""

    def test_paper_trade_returns_result(self, agent):
        """Test that paper trade returns TradeResult."""
        result = agent.paper_trade("AAPL", SignalType.BUY, 1000.0)

        assert isinstance(result, TradeResult)
        assert result.symbol == "AAPL"
        assert result.success is True

    def test_paper_trade_records_in_separate_history(self, agent):
        """Test that paper trade is recorded separately."""
        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        assert len(agent._paper_trade_history) == 1
        assert len(agent._trade_history) == 0

    def test_paper_trade_minimal_cost(self, agent):
        """Test that paper trade only deducts minimal cost."""
        initial_balance = agent.balance
        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        # Should only deduct small data cost, not full trade cost
        balance_change = initial_balance - agent.balance
        assert balance_change < 0.1  # Very small cost


class TestAnalyze:
    """Test analyze method (async)."""

    def test_analyze_returns_dict(self, agent):
        """Test that analyze returns a dictionary."""
        result = asyncio.run(agent.analyze("AAPL"))

        assert isinstance(result, dict)
        assert result["symbol"] == "AAPL"
        assert "signal" in result
        assert "confidence" in result
        assert "reason" in result
        assert "market_analysis" in result
        assert "cost" in result

    def test_analyze_deducts_cost(self, agent):
        """Test that analyze deducts decision cost."""
        initial_balance = agent.balance
        asyncio.run(agent.analyze("AAPL"))
        assert agent.balance < initial_balance

    def test_analyze_stores_last_analysis(self, agent):
        """Test that analyze stores the analysis."""
        assert agent._last_analysis is None
        asyncio.run(agent.analyze("TSLA"))
        assert agent._last_analysis is not None
        assert agent._last_analysis.symbol == "TSLA"


class TestTradeHistory:
    """Test trade history methods."""

    def test_get_trade_history_returns_copy(self, agent):
        """Test that get_trade_history returns a copy."""
        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        history = agent.get_trade_history()
        history.append(None)  # Modify the copy

        # Original should be unchanged
        assert len(agent._trade_history) == 1

    def test_get_paper_trade_history_returns_copy(self, agent):
        """Test that get_paper_trade_history returns a copy."""
        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        history = agent.get_paper_trade_history()
        history.append(None)  # Modify the copy

        # Original should be unchanged
        assert len(agent._paper_trade_history) == 1


class TestPerformanceStats:
    """Test performance statistics."""

    def test_stats_structure(self, agent):
        """Test that stats contains expected keys."""
        stats = agent.get_performance_stats()

        assert "total_real_trades" in stats
        assert "total_paper_trades" in stats
        assert "real_pnl" in stats
        assert "paper_pnl" in stats
        assert "win_rate" in stats
        assert "skill_level" in stats
        assert "balance" in stats
        assert "survival_status" in stats

    def test_stats_with_trades(self, agent):
        """Test stats calculation with trades."""
        with patch("random.random", return_value=0.1):
            agent.execute_trade("AAPL", SignalType.BUY, 1000.0)
            agent.execute_trade("TSLA", SignalType.SELL, 1000.0)
            agent.paper_trade("NVDA", SignalType.BUY, 1000.0)

        stats = agent.get_performance_stats()

        assert stats["total_real_trades"] == 2
        assert stats["total_paper_trades"] == 1
        assert stats["win_rate"] == 1.0  # Both won


class TestSignalType:
    """Test SignalType enum."""

    def test_signal_values(self):
        """Test signal type values."""
        assert SignalType.BUY == "buy"
        assert SignalType.SELL == "sell"
        assert SignalType.HOLD == "hold"


class TestTradeSignal:
    """Test TradeSignal dataclass."""

    def test_trade_signal_creation(self):
        """Test creating a TradeSignal."""
        signal = TradeSignal(
            symbol="AAPL",
            signal=SignalType.BUY,
            confidence=0.8,
            reason="RSI oversold",
            suggested_position=1000.0,
        )

        assert signal.symbol == "AAPL"
        assert signal.signal == SignalType.BUY
        assert signal.confidence == 0.8
        assert signal.reason == "RSI oversold"
        assert signal.suggested_position == 1000.0


class TestMarketAnalysis:
    """Test MarketAnalysis dataclass."""

    def test_market_analysis_creation(self):
        """Test creating a MarketAnalysis."""
        analysis = MarketAnalysis(
            symbol="AAPL",
            trend="uptrend",
            volatility=0.2,
            volume_trend="increasing",
            support_level=90.0,
            resistance_level=110.0,
            indicators={"rsi": 50.0},
        )

        assert analysis.symbol == "AAPL"
        assert analysis.trend == "uptrend"


class TestTradeResult:
    """Test TradeResult dataclass."""

    def test_trade_result_creation(self):
        """Test creating a TradeResult."""
        result = TradeResult(
            symbol="AAPL",
            signal=SignalType.BUY,
            success=True,
            pnl=100.0,
            fee=10.0,
            message="Success",
        )

        assert result.symbol == "AAPL"
        assert result.success is True
        assert result.pnl == 100.0
        assert result.timestamp != ""  # Auto-generated

    def test_trade_result_custom_timestamp(self):
        """Test TradeResult with custom timestamp."""
        result = TradeResult(
            symbol="AAPL",
            signal=SignalType.BUY,
            success=True,
            pnl=100.0,
            fee=10.0,
            message="Success",
            timestamp="2024-01-01T00:00:00",
        )

        assert result.timestamp == "2024-01-01T00:00:00"