stock/tests/unit/test_trader_agent.py
ZhangPeng 9aecdd036c Initial commit: OpenClaw Trading - AI多智能体量化交易系统
- 添加项目核心代码和配置
- 添加前端界面 (Next.js)
- 添加单元测试
- 更新 .gitignore 排除缓存和依赖
2026-02-27 03:47:40 +08:00

508 lines
16 KiB
Python

"""Unit tests for TraderAgent.
This module tests the TraderAgent class including market analysis,
signal generation, and trade execution.
"""
import asyncio
from unittest.mock import patch
import pytest
from openclaw.agents.base import ActivityType
from openclaw.agents.trader import (
MarketAnalysis,
SignalType,
TradeResult,
TradeSignal,
TraderAgent,
)
from openclaw.core.economy import SurvivalStatus
class TestTraderAgentInitialization:
    """Construction-time checks for TraderAgent."""

    def test_default_initialization(self):
        """Only the required arguments are given; the remaining attributes take their defaults."""
        trader = TraderAgent(agent_id="trader-1", initial_capital=10000.0)

        assert trader.agent_id == "trader-1"
        assert trader.balance == 10000.0
        assert trader.skill_level == 0.5
        assert trader.max_position_pct == 0.2
        # Neither history holds any entries on a fresh agent.
        assert trader._trade_history == []
        assert trader._paper_trade_history == []

    def test_custom_initialization(self):
        """Explicitly supplied constructor values are reflected on the agent."""
        settings = {
            "agent_id": "trader-2",
            "initial_capital": 5000.0,
            "skill_level": 0.8,
            "max_position_pct": 0.3,
        }
        trader = TraderAgent(**settings)

        assert trader.agent_id == "trader-2"
        assert trader.balance == 5000.0
        assert trader.skill_level == 0.8
        assert trader.max_position_pct == 0.3

    def test_inherits_from_base_agent(self):
        """A TraderAgent instance is also a BaseAgent instance."""
        from openclaw.agents.base import BaseAgent

        trader = TraderAgent(agent_id="test", initial_capital=10000.0)
        assert isinstance(trader, BaseAgent)
class TestDecideActivity:
    """Checks on the activity ``decide_activity`` returns at different balances."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent funded with 10000.0."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_bankrupt_agent_only_rests(self, agent):
        """With a zero balance the chosen activity is REST."""
        agent.economic_tracker.balance = 0  # Bankrupt
        chosen = asyncio.run(agent.decide_activity())
        assert chosen == ActivityType.REST

    def test_critical_status_prefers_learning(self, agent):
        """At a critical balance the agent learns or paper trades."""
        agent.economic_tracker.balance = 3500.0  # Critical
        agent.state.skill_level = 0.5
        chosen = asyncio.run(agent.decide_activity())
        assert chosen in (ActivityType.LEARN, ActivityType.PAPER_TRADE)

    def test_thriving_status_prefers_trading(self, agent):
        """At a thriving balance at least half of the decisions are TRADE or ANALYZE."""
        agent.economic_tracker.balance = 20000.0  # Thriving

        # The decision involves randomness, so sample it repeatedly.
        outcomes = []
        for _ in range(20):
            outcomes.append(asyncio.run(agent.decide_activity()))

        wanted = (ActivityType.TRADE, ActivityType.ANALYZE)
        trade_like_count = sum(1 for outcome in outcomes if outcome in wanted)
        assert trade_like_count >= 10  # At least half

    def test_struggling_status_more_paper_trading(self, agent):
        """At a struggling balance at least 5 of 20 decisions are PAPER_TRADE."""
        agent.economic_tracker.balance = 8500.0  # Struggling

        # Sample repeatedly for the same reason as above: the choice is random.
        outcomes = []
        for _ in range(20):
            outcomes.append(asyncio.run(agent.decide_activity()))

        paper_count = sum(
            1 for outcome in outcomes if outcome == ActivityType.PAPER_TRADE
        )
        assert paper_count >= 5  # At least some paper trades
class TestAnalyzeMarket:
    """Test analyze_market method."""

    @pytest.fixture
    def agent(self):
        """Create a test agent with 10000.0 of initial capital."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_returns_market_analysis(self, agent):
        """Test that analyze_market returns a well-formed MarketAnalysis."""
        result = agent.analyze_market("AAPL")

        assert isinstance(result, MarketAnalysis)
        assert result.symbol == "AAPL"
        assert result.trend in ["uptrend", "downtrend", "sideways"]
        assert 0 <= result.volatility <= 1
        assert result.volume_trend in ["increasing", "decreasing"]
        assert result.support_level < result.resistance_level

    def test_indicators_present(self, agent):
        """Test that the expected technical indicators are present."""
        result = agent.analyze_market("TSLA")

        for indicator in ("rsi", "macd", "sma_20", "current_price"):
            assert indicator in result.indicators

    def test_high_skill_more_accurate(self):
        """Test that a high-skill agent's RSI readings stay within a bounded spread."""
        high_skill_agent = TraderAgent(
            agent_id="high", initial_capital=10000.0, skill_level=0.9
        )

        # Collect the RSI reported by ten consecutive analyses of one symbol.
        rsis = [
            high_skill_agent.analyze_market("AAPL").indicators["rsi"]
            for _ in range(10)
        ]

        # The spread between the highest and lowest RSI must not exceed
        # 30 points (the tolerance this test enforces for a 0.9-skill agent).
        assert max(rsis) - min(rsis) <= 30
class TestGenerateSignal:
    """Checks on the TradeSignal produced by ``generate_signal``."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent funded with 10000.0."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    @staticmethod
    def _analysis(trend, volume_trend, rsi, macd):
        """Build an AAPL MarketAnalysis that varies only in the given fields.

        Volatility (0.2), support (90.0), resistance (110.0) and the
        current price (100.0) are the same for every scenario in this class.
        """
        return MarketAnalysis(
            symbol="AAPL",
            trend=trend,
            volatility=0.2,
            volume_trend=volume_trend,
            support_level=90.0,
            resistance_level=110.0,
            indicators={"rsi": rsi, "macd": macd, "current_price": 100.0},
        )

    def test_oversold_generates_buy(self, agent):
        """An RSI of 30 in a downtrend yields a BUY with confidence above 0.5."""
        snapshot = self._analysis("downtrend", "increasing", 30.0, 0.5)
        outcome = agent.generate_signal(snapshot)

        assert outcome.signal == SignalType.BUY
        assert outcome.confidence > 0.5
        assert "oversold" in outcome.reason.lower() or "RSI" in outcome.reason

    def test_overbought_generates_sell(self, agent):
        """An RSI of 70 in an uptrend yields a SELL with confidence above 0.5."""
        snapshot = self._analysis("uptrend", "increasing", 70.0, -0.5)
        outcome = agent.generate_signal(snapshot)

        assert outcome.signal == SignalType.SELL
        assert outcome.confidence > 0.5
        assert "overbought" in outcome.reason.lower() or "RSI" in outcome.reason

    def test_neutral_generates_hold(self, agent):
        """A mid-range RSI in a sideways market yields a HOLD with zero position."""
        snapshot = self._analysis("sideways", "flat", 50.0, 0.0)
        outcome = agent.generate_signal(snapshot)

        assert outcome.signal == SignalType.HOLD
        assert outcome.suggested_position == 0.0

    def test_suggested_position_based_on_confidence(self, agent):
        """The suggested position is positive and capped by max_position_pct of balance."""
        snapshot = self._analysis("uptrend", "increasing", 30.0, 0.5)
        outcome = agent.generate_signal(snapshot)

        assert outcome.suggested_position > 0
        # Upper bound: the agent's balance scaled by its maximum position share.
        ceiling = agent.balance * agent.max_position_pct
        assert outcome.suggested_position <= ceiling
class TestExecuteTrade:
    """Test execute_trade method."""

    @pytest.fixture
    def agent(self):
        """Create a test agent with 10000.0 of initial capital."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_trade_success(self, agent):
        """Test successful trade execution."""
        with patch("random.random", return_value=0.1):  # Force win
            result = agent.execute_trade("AAPL", SignalType.BUY, 1000.0)

        assert isinstance(result, TradeResult)
        assert result.symbol == "AAPL"
        assert result.signal == SignalType.BUY
        assert result.success is True
        assert result.fee > 0
        # The executed trade must have been recorded. (The previous check
        # also tested a string for membership in the list of trade records,
        # which could never be true and only obscured the intent.)
        assert agent._trade_history

    def test_trade_records_in_history(self, agent):
        """Test that trade is recorded in history."""
        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        assert len(agent._trade_history) == 1
        assert agent._trade_history[0].symbol == "AAPL"

    def test_trade_updates_stats(self, agent):
        """Test that trade updates agent statistics."""
        initial_trades = agent.state.total_trades

        with patch("random.random", return_value=0.1):  # Force win
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        assert agent.state.total_trades == initial_trades + 1

    def test_trade_deducts_costs(self, agent):
        """Test that executing a trade changes the balance."""
        initial_balance = agent.balance

        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        # Balance should change due to fees/PnL
        assert agent.balance != initial_balance

    def test_insufficient_funds_fails(self, agent):
        """Test that trade fails when insufficient funds."""
        # 50000.0 exceeds the fixture's 10000.0 of capital.
        result = agent.execute_trade("AAPL", SignalType.BUY, 50000.0)

        assert result.success is False
        # Lower-casing the message already covers every capitalisation.
        assert "insufficient" in result.message.lower()
class TestPaperTrade:
    """Checks on the simulated ``paper_trade`` path."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent funded with 10000.0."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_paper_trade_returns_result(self, agent):
        """A paper trade yields a successful TradeResult for the requested symbol."""
        outcome = agent.paper_trade("AAPL", SignalType.BUY, 1000.0)

        assert isinstance(outcome, TradeResult)
        assert outcome.symbol == "AAPL"
        assert outcome.success is True

    def test_paper_trade_records_in_separate_history(self, agent):
        """A paper trade lands in the paper history and leaves the real history empty."""
        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        assert len(agent._paper_trade_history) == 1
        assert len(agent._trade_history) == 0

    def test_paper_trade_minimal_cost(self, agent):
        """A paper trade reduces the balance by less than 0.1."""
        balance_before = agent.balance
        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        # Only a small cost is expected here, not the cost of a real trade.
        spent = balance_before - agent.balance
        assert spent < 0.1  # Very small cost
class TestAnalyze:
    """Checks on the asynchronous ``analyze`` method."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent funded with 10000.0."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_analyze_returns_dict(self, agent):
        """The result is a dict carrying the symbol and the expected report keys."""
        report = asyncio.run(agent.analyze("AAPL"))

        assert isinstance(report, dict)
        assert report["symbol"] == "AAPL"
        for key in ("signal", "confidence", "reason", "market_analysis", "cost"):
            assert key in report

    def test_analyze_deducts_cost(self, agent):
        """Running an analysis lowers the agent's balance."""
        balance_before = agent.balance
        asyncio.run(agent.analyze("AAPL"))
        assert agent.balance < balance_before

    def test_analyze_stores_last_analysis(self, agent):
        """The most recent analysis is kept on the agent after ``analyze`` runs."""
        # Nothing is stored before the first call.
        assert agent._last_analysis is None

        asyncio.run(agent.analyze("TSLA"))

        stored = agent._last_analysis
        assert stored is not None
        assert agent._last_analysis.symbol == "TSLA"
class TestTradeHistory:
    """Checks that the history accessors hand out independent lists."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent funded with 10000.0."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_get_trade_history_returns_copy(self, agent):
        """Mutating the list from get_trade_history leaves the agent's list untouched."""
        with patch("random.random", return_value=0.5):
            agent.execute_trade("AAPL", SignalType.BUY, 500.0)

        snapshot = agent.get_trade_history()
        snapshot.append(None)  # Tamper with the returned list only

        # The agent's own record still holds just the one trade.
        assert len(agent._trade_history) == 1

    def test_get_paper_trade_history_returns_copy(self, agent):
        """Mutating the list from get_paper_trade_history leaves the agent's list untouched."""
        agent.paper_trade("AAPL", SignalType.BUY, 500.0)

        snapshot = agent.get_paper_trade_history()
        snapshot.append(None)  # Tamper with the returned list only

        # The agent's own record still holds just the one paper trade.
        assert len(agent._paper_trade_history) == 1
class TestPerformanceStats:
    """Checks on the mapping returned by ``get_performance_stats``."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh agent funded with 10000.0."""
        return TraderAgent(agent_id="test", initial_capital=10000.0)

    def test_stats_structure(self, agent):
        """Every expected key is present in the stats of a fresh agent."""
        summary = agent.get_performance_stats()

        expected_keys = (
            "total_real_trades",
            "total_paper_trades",
            "real_pnl",
            "paper_pnl",
            "win_rate",
            "skill_level",
            "balance",
            "survival_status",
        )
        for key in expected_keys:
            assert key in summary

    def test_stats_with_trades(self, agent):
        """Two real trades and one paper trade are counted, with a win rate of 1.0."""
        with patch("random.random", return_value=0.1):
            agent.execute_trade("AAPL", SignalType.BUY, 1000.0)
            agent.execute_trade("TSLA", SignalType.SELL, 1000.0)
            agent.paper_trade("NVDA", SignalType.BUY, 1000.0)

        summary = agent.get_performance_stats()

        assert summary["total_real_trades"] == 2
        assert summary["total_paper_trades"] == 1
        assert summary["win_rate"] == 1.0  # Both won
class TestSignalType:
    """Checks on the SignalType enum."""

    def test_signal_values(self):
        """Each member compares equal to its lowercase string value."""
        expected = (
            (SignalType.BUY, "buy"),
            (SignalType.SELL, "sell"),
            (SignalType.HOLD, "hold"),
        )
        for member, text in expected:
            assert member == text
class TestTradeSignal:
    """Checks on the TradeSignal dataclass."""

    def test_trade_signal_creation(self):
        """Every constructor argument is readable back from the instance."""
        fields = {
            "symbol": "AAPL",
            "signal": SignalType.BUY,
            "confidence": 0.8,
            "reason": "RSI oversold",
            "suggested_position": 1000.0,
        }
        created = TradeSignal(**fields)

        # Each attribute is expected to echo the value it was constructed with.
        for name, value in fields.items():
            assert getattr(created, name) == value
class TestMarketAnalysis:
    """Checks on the MarketAnalysis dataclass."""

    def test_market_analysis_creation(self):
        """A MarketAnalysis can be built and exposes its symbol and trend."""
        fields = {
            "symbol": "AAPL",
            "trend": "uptrend",
            "volatility": 0.2,
            "volume_trend": "increasing",
            "support_level": 90.0,
            "resistance_level": 110.0,
            "indicators": {"rsi": 50.0},
        }
        created = MarketAnalysis(**fields)

        assert created.symbol == "AAPL"
        assert created.trend == "uptrend"
class TestTradeResult:
    """Checks on the TradeResult dataclass."""

    def test_trade_result_creation(self):
        """A result built without a timestamp still ends up with a non-empty one."""
        fields = {
            "symbol": "AAPL",
            "signal": SignalType.BUY,
            "success": True,
            "pnl": 100.0,
            "fee": 10.0,
            "message": "Success",
        }
        created = TradeResult(**fields)

        assert created.symbol == "AAPL"
        assert created.success is True
        assert created.pnl == 100.0
        assert created.timestamp != ""  # Auto-generated

    def test_trade_result_custom_timestamp(self):
        """An explicitly supplied timestamp is kept verbatim."""
        stamp = "2024-01-01T00:00:00"
        created = TradeResult(
            symbol="AAPL",
            signal=SignalType.BUY,
            success=True,
            pnl=100.0,
            fee=10.0,
            message="Success",
            timestamp=stamp,
        )

        assert created.timestamp == "2024-01-01T00:00:00"