stock/tests/unit/test_base_agent.py
ZhangPeng 9aecdd036c Initial commit: OpenClaw Trading - AI多智能体量化交易系统
- 添加项目核心代码和配置
- 添加前端界面 (Next.js)
- 添加单元测试
- 更新 .gitignore 排除缓存和依赖
2026-02-27 03:47:40 +08:00

601 lines
20 KiB
Python

"""Unit tests for BaseAgent abstract base class.
This module tests the BaseAgent class including initialization,
economic tracking integration, event hooks, and state management.
"""
import asyncio
from typing import Any, Dict
from unittest.mock import MagicMock
import pytest
from openclaw.agents.base import (
ActivityType,
AgentState,
BaseAgent,
EventCallback,
)
from openclaw.core.economy import SurvivalStatus
class ConcreteBaseAgent(BaseAgent):
    """Minimal concrete BaseAgent used as a fixture (not a pytest test class).

    It implements the two abstract coroutines with fixed, trivial results so
    the rest of the BaseAgent machinery can be exercised in isolation.
    """

    async def decide_activity(self) -> ActivityType:
        """Always choose the ANALYZE activity."""
        chosen = ActivityType.ANALYZE
        return chosen

    async def analyze(self, symbol: str) -> Dict[str, Any]:
        """Return a neutral 'hold' signal for ``symbol``."""
        result: Dict[str, Any] = {"symbol": symbol, "signal": "hold"}
        return result
class TestBaseAgentInitialization:
    """Test agent initialization."""

    def test_default_initialization(self):
        """Test agent with default parameters."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)
        assert agent.agent_id == "test-agent"
        assert agent.balance == 10000.0
        assert agent.skill_level == 0.5  # Default
        assert agent.state.agent_id == "test-agent"
        assert agent.state.skill_level == 0.5
        assert agent.state.win_rate == 0.5
        assert agent.state.total_trades == 0
        assert agent.state.winning_trades == 0
        assert agent.state.unlocked_factors == []
        assert agent.state.current_activity is None
        assert agent.state.is_bankrupt is False

    def test_custom_initialization(self):
        """Test agent with custom skill level."""
        agent = ConcreteBaseAgent(
            agent_id="custom-agent",
            initial_capital=5000.0,
            skill_level=0.8,
        )
        assert agent.agent_id == "custom-agent"
        assert agent.balance == 5000.0
        assert agent.skill_level == 0.8
        assert agent.state.skill_level == 0.8

    def test_economic_tracker_integration(self):
        """Test that economic tracker is properly initialized."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)
        assert agent.economic_tracker.agent_id == "test-agent"
        assert agent.economic_tracker.initial_capital == 10000.0
        assert agent.economic_tracker.balance == 10000.0

    def test_event_hooks_initialized(self):
        """Test that every known event hook exists and starts with no callbacks."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)
        expected_events = (
            "on_trade",
            "on_learn",
            "on_bankrupt",
            "on_level_up",
            "on_factor_unlock",
        )
        for event in expected_events:
            assert event in agent._event_hooks, event
            # All hook lists should start empty, not just the first two.
            assert agent._event_hooks[event] == [], event

    def test_logger_initialized(self):
        """Test that logger is properly bound."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)
        assert agent.logger is not None
class TestBaseAgentProperties:
    """Tests for the convenience properties exposed by BaseAgent."""

    def test_balance_property(self):
        """balance follows the economic tracker, including cost deductions."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        assert agent.balance == 10000.0

        # Charging a decision cost through the tracker must lower the balance.
        tracker = agent.economic_tracker
        tracker.calculate_decision_cost(tokens_input=1000, tokens_output=500)
        assert agent.balance < 10000.0

    def test_survival_status_property(self):
        """survival_status reflects the balance relative to initial capital."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        # A full (100%) balance falls in the STRUGGLING band (>=80% threshold).
        assert agent.survival_status == SurvivalStatus.STRUGGLING

        # Raising the balance to 16000 promotes the agent to THRIVING.
        agent.economic_tracker.balance = 16000.0
        assert agent.survival_status == SurvivalStatus.THRIVING

    def test_skill_level_property(self):
        """skill_level reads through to the agent state."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        assert agent.skill_level == 0.5

        agent.state.skill_level = 0.9
        assert agent.skill_level == 0.9

    def test_win_rate_property(self):
        """win_rate starts at 0.5 and updates after a recorded trade."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        assert agent.win_rate == 0.5

        agent.record_trade(pnl=100.0, is_win=True)
        assert agent.win_rate == 1.0
class TestCanAfford:
    """Tests for BaseAgent.can_afford."""

    def test_can_afford_with_safety_buffer(self):
        """The default 20% buffer caps spending at balance / 1.2."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        # 10000 / 1.2 ~= 8333.33, so 8333 fits and 8500 does not.
        cases = ((8000.0, True), (8333.0, True), (8500.0, False))
        for amount, expected in cases:
            assert agent.can_afford(amount) is expected

    def test_can_afford_custom_buffer(self):
        """An explicit safety_buffer overrides the default one."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        cases = (
            # 50% buffer
            (6000.0, 1.5, True),
            (7000.0, 1.5, False),
            # No buffer
            (10000.0, 1.0, True),
            (10001.0, 1.0, False),
        )
        for amount, buffer, expected in cases:
            assert agent.can_afford(amount, safety_buffer=buffer) is expected

    def test_cannot_afford_when_bankrupt(self):
        """With a zero balance nothing is affordable."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.economic_tracker.balance = 0.0
        assert agent.can_afford(1.0) is False
class TestCheckSurvival:
    """Tests for BaseAgent.check_survival."""

    def test_survival_when_stable(self):
        """A freshly created agent survives and is not flagged bankrupt."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        survived = agent.check_survival()
        assert survived is True
        assert agent.state.is_bankrupt is False

    def test_bankruptcy_detection(self):
        """A balance of 10% of initial capital is treated as bankruptcy."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.economic_tracker.balance = 1000.0  # below the 30% threshold
        survived = agent.check_survival()
        assert survived is False
        assert agent.state.is_bankrupt is True

    def test_bankrupt_event_triggered_once(self):
        """on_bankrupt fires on the first failing check only."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        fired = []

        def on_bankrupt(agent_ref, **kwargs):
            fired.append(1)

        agent.register_hook("on_bankrupt", on_bankrupt)
        agent.economic_tracker.balance = 1000.0

        # Repeated checks while bankrupt must not re-fire the event.
        for _ in range(2):
            agent.check_survival()
            assert len(fired) == 1
class TestRecordTrade:
    """Tests for BaseAgent.record_trade."""

    def test_record_winning_trade(self):
        """A single win yields one trade, one winner and a 100% win rate."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.record_trade(is_win=True, pnl=150.0)
        state = agent.state
        assert state.total_trades == 1
        assert state.winning_trades == 1
        assert state.win_rate == 1.0

    def test_record_losing_trade(self):
        """A single loss yields one trade, no winners and a 0% win rate."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.record_trade(is_win=False, pnl=-100.0)
        state = agent.state
        assert state.total_trades == 1
        assert state.winning_trades == 0
        assert state.win_rate == 0.0

    def test_win_rate_calculation(self):
        """The win rate aggregates correctly over several trades."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        history = ((True, 100.0), (False, -50.0), (True, 75.0), (True, 120.0))
        for won, pnl in history:
            agent.record_trade(is_win=won, pnl=pnl)

        # 3 wins out of 4 trades -> 75%
        state = agent.state
        assert state.total_trades == 4
        assert state.winning_trades == 3
        assert state.win_rate == 0.75

    def test_trade_event_triggered(self):
        """on_trade receives the trade outcome as keyword arguments."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        captured = {}

        def on_trade(agent_ref, **kwargs):
            captured.update(kwargs)

        agent.register_hook("on_trade", on_trade)
        agent.record_trade(is_win=True, pnl=100.0)
        assert captured.get("is_win") is True
        assert captured.get("pnl") == 100.0
class TestImproveSkill:
    """Test improve_skill method."""

    def test_skill_improvement(self):
        """Test skill level improvement."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.5)
        agent.improve_skill(0.2)
        assert agent.skill_level == 0.7

    def test_skill_capped_at_one(self):
        """Test that skill cannot exceed 1.0."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.9)
        agent.improve_skill(0.2)
        assert agent.skill_level == 1.0

    def test_no_improvement_when_already_max(self):
        """Test that on_level_up does not fire when skill is already at max.

        The callback accepts the agent positionally, like every other hook in
        this module. A ``**kwargs``-only signature would make a wrongly fired
        event raise TypeError inside the hook. Hook errors are swallowed, so
        the test would then pass even though the event fired.
        """
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=1.0)
        event_triggered = False

        def on_level_up(agent_ref, **kwargs):
            nonlocal event_triggered
            event_triggered = True

        agent.register_hook("on_level_up", on_level_up)
        agent.improve_skill(0.1)
        # Skill stays capped and the event should not trigger when nothing changed.
        assert agent.skill_level == 1.0
        assert event_triggered is False

    def test_level_up_event_triggered(self):
        """Test that on_level_up event is triggered."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.5)
        event_data = {}

        def on_level_up(agent_ref, **kwargs):
            event_data.update(kwargs)

        agent.register_hook("on_level_up", on_level_up)
        agent.improve_skill(0.1)
        assert event_data.get("old_level") == 0.5
class TestUnlockFactor:
    """Tests for BaseAgent.unlock_factor."""

    def test_unlock_new_factor(self):
        """Unlocking a new factor succeeds and deducts its cost."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        unlocked = agent.unlock_factor("momentum", cost=500.0)
        assert unlocked is True
        assert "momentum" in agent.state.unlocked_factors
        assert agent.balance == 9500.0

    def test_unlock_already_unlocked(self):
        """Re-unlocking an owned factor returns True without charging again."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.unlock_factor("momentum", cost=500.0)
        second = agent.unlock_factor("momentum", cost=500.0)
        assert second is True
        assert agent.balance == 9500.0  # cost charged only once

    def test_cannot_afford_factor(self):
        """A factor costing more than the agent can afford is not unlocked."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=1000.0)
        unlocked = agent.unlock_factor("expensive", cost=5000.0)
        assert unlocked is False
        assert "expensive" not in agent.state.unlocked_factors

    def test_factor_unlock_event_triggered(self):
        """on_factor_unlock receives the factor name and cost."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        payload = {}

        def on_factor_unlock(agent_ref, **kwargs):
            payload.update(kwargs)

        agent.register_hook("on_factor_unlock", on_factor_unlock)
        agent.unlock_factor("momentum", cost=500.0)
        assert payload.get("factor_name") == "momentum"
        assert payload.get("cost") == 500.0
class TestIsFactorUnlocked:
    """Tests for BaseAgent.is_factor_unlocked."""

    def test_factor_unlocked(self):
        """A factor reports as unlocked after unlock_factor succeeds."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.unlock_factor("momentum", cost=100.0)
        unlocked = agent.is_factor_unlocked("momentum")
        assert unlocked is True

    def test_factor_not_unlocked(self):
        """A factor that was never unlocked reports as locked."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        unlocked = agent.is_factor_unlocked("momentum")
        assert unlocked is False
class TestEventHooks:
    """Tests for the register/unregister/dispatch hook system."""

    def test_register_hook(self):
        """A registered callback is stored under its event name."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        hook = MagicMock()
        agent.register_hook("on_trade", hook)
        assert hook in agent._event_hooks["on_trade"]

    def test_unregister_hook(self):
        """An unregistered callback is removed from its event list."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        hook = MagicMock()
        agent.register_hook("on_trade", hook)
        agent.unregister_hook("on_trade", hook)
        assert hook not in agent._event_hooks["on_trade"]

    def test_register_unknown_event_raises(self):
        """Registering for an unknown event name raises ValueError."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        with pytest.raises(ValueError, match="Unknown event"):
            agent.register_hook("on_unknown_event", MagicMock())

    def test_event_callback_receives_agent(self):
        """Callbacks get the emitting agent as their first argument."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        seen = []

        def callback(agent_ref, **kwargs):
            seen.append(agent_ref)

        agent.register_hook("on_trade", callback)
        agent.record_trade(is_win=True, pnl=100.0)
        # The most recent call must have passed this very agent.
        assert seen and seen[-1] is agent

    def test_multiple_hooks_same_event(self):
        """Every callback registered for an event is invoked."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        invoked = []

        def first(agent_ref, **kwargs):
            invoked.append("callback1")

        def second(agent_ref, **kwargs):
            invoked.append("callback2")

        agent.register_hook("on_trade", first)
        agent.register_hook("on_trade", second)
        agent.record_trade(is_win=True, pnl=100.0)
        assert {"callback1", "callback2"} <= set(invoked)

    def test_hook_error_handling(self):
        """A failing callback does not prevent later callbacks from running."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        invoked = []

        def failing(agent_ref, **kwargs):
            raise ValueError("Test error")

        def healthy(agent_ref, **kwargs):
            invoked.append("good")

        agent.register_hook("on_trade", failing)
        agent.register_hook("on_trade", healthy)
        # Must not propagate the ValueError raised by the first hook.
        agent.record_trade(is_win=True, pnl=100.0)
        assert "good" in invoked
class TestAbstractMethods:
    """Test abstract method requirements.

    The ``match`` patterns make sure each TypeError comes from the ABC
    machinery ("Can't instantiate abstract class ...") and names the missing
    method. This stops an unrelated TypeError, such as a changed constructor
    signature, from satisfying the test.
    """

    def test_cannot_instantiate_base(self):
        """Test that BaseAgent, or a subclass implementing nothing, cannot be instantiated."""

        class IncompleteAgent(BaseAgent):
            pass

        with pytest.raises(TypeError, match="abstract"):
            BaseAgent(agent_id="test", initial_capital=10000.0)
        with pytest.raises(TypeError, match="abstract"):
            IncompleteAgent(agent_id="test", initial_capital=10000.0)

    def test_decide_activity_must_be_implemented(self):
        """Test that decide_activity must be implemented."""

        class NoDecideAgent(BaseAgent):
            async def analyze(self, symbol: str) -> Dict[str, Any]:
                return {}

        with pytest.raises(TypeError, match="abstract.*decide_activity"):
            NoDecideAgent(agent_id="test", initial_capital=10000.0)

    def test_analyze_must_be_implemented(self):
        """Test that analyze must be implemented."""

        class NoAnalyzeAgent(BaseAgent):
            async def decide_activity(self) -> ActivityType:
                return ActivityType.REST

        with pytest.raises(TypeError, match="abstract.*analyze"):
            NoAnalyzeAgent(agent_id="test", initial_capital=10000.0)

    def test_decide_activity_returns_activity(self):
        """Test that decide_activity returns an ActivityType."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        result = asyncio.run(agent.decide_activity())
        assert isinstance(result, ActivityType)

    def test_analyze_returns_dict(self):
        """Test that analyze returns a dictionary."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        result = asyncio.run(agent.analyze("AAPL"))
        assert isinstance(result, dict)
        assert result["symbol"] == "AAPL"
class TestGetStatusDict:
    """Tests for BaseAgent.get_status_dict."""

    def test_status_dict_contains_required_fields(self):
        """A fresh agent's status dict carries all expected fields."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)
        status = agent.get_status_dict()
        expected = {
            "agent_id": "test-agent",
            "balance": 10000.0,
            "skill_level": 0.5,
            "win_rate": 0.5,
            "total_trades": 0,
            "unlocked_factors": 0,  # reported as a count
        }
        for key, value in expected.items():
            assert status[key] == value, key
        assert "status" in status
        assert status["is_bankrupt"] is False

    def test_status_dict_reflects_state(self):
        """The status dict is updated by trades and factor unlocks."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.record_trade(is_win=True, pnl=100.0)
        agent.unlock_factor("test_factor", cost=100.0)
        status = agent.get_status_dict()
        assert (status["total_trades"], status["win_rate"], status["unlocked_factors"]) == (1, 1.0, 1)
class TestRepr:
    """Tests for BaseAgent.__repr__."""

    def test_repr_contains_key_info(self):
        """repr mentions the class, agent id, balance and a 50% figure."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)
        text = repr(agent)

        def has_any(*needles):
            return any(needle in text for needle in needles)

        assert has_any("ConcreteBaseAgent", "Agent")
        assert "test-agent" in text
        assert has_any("$10,000.00", "$10000", "10000")
        assert has_any("50.0%", "50%", "0.5")
class TestAgentState:
    """Tests for the AgentState dataclass."""

    def test_default_state(self):
        """Only agent_id is required; everything else has a default."""
        state = AgentState(agent_id="test")
        assert state.agent_id == "test"
        assert (state.skill_level, state.win_rate) == (0.5, 0.5)
        assert (state.total_trades, state.winning_trades) == (0, 0)
        assert state.unlocked_factors == []
        assert state.current_activity is None
        assert state.is_bankrupt is False

    def test_state_with_custom_values(self):
        """Explicit constructor values override the defaults."""
        overrides = {"skill_level": 0.9, "win_rate": 0.75, "total_trades": 100}
        state = AgentState(agent_id="test", **overrides)
        for field_name, value in overrides.items():
            assert getattr(state, field_name) == value, field_name
class TestActivityType:
    """Tests for the ActivityType enum."""

    def test_activity_types(self):
        """Each member compares equal to its string value."""
        pairs = (
            (ActivityType.TRADE, "trade"),
            (ActivityType.LEARN, "learn"),
            (ActivityType.ANALYZE, "analyze"),
            (ActivityType.REST, "rest"),
            (ActivityType.PAPER_TRADE, "paper_trade"),
        )
        for member, value in pairs:
            assert member == value, value

    def test_activity_type_comparison(self):
        """Members equal their own value and differ from other values."""
        trade = ActivityType.TRADE
        assert trade == "trade"
        assert trade != "learn"