601 lines
20 KiB
Python
601 lines
20 KiB
Python
"""Unit tests for BaseAgent abstract base class.
|
|
|
|
This module tests the BaseAgent class including initialization,
|
|
economic tracking integration, event hooks, and state management.
|
|
"""
|
|
|
|
import asyncio
|
|
from typing import Any, Dict
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from openclaw.agents.base import (
|
|
ActivityType,
|
|
AgentState,
|
|
BaseAgent,
|
|
EventCallback,
|
|
)
|
|
from openclaw.core.economy import SurvivalStatus
|
|
|
|
|
|
class ConcreteBaseAgent(BaseAgent):
    """Minimal concrete ``BaseAgent`` used as a test fixture.

    It supplies trivial implementations of the two abstract coroutines so the
    shared ``BaseAgent`` behaviour can be instantiated and exercised. This is
    a helper, not a pytest test class.
    """

    async def decide_activity(self) -> ActivityType:
        """Always choose ``ActivityType.ANALYZE``."""
        chosen = ActivityType.ANALYZE
        return chosen

    async def analyze(self, symbol: str) -> Dict[str, Any]:
        """Echo *symbol* back alongside a neutral ``"hold"`` signal."""
        result: Dict[str, Any] = dict(symbol=symbol, signal="hold")
        return result
|
|
|
|
|
|
class TestBaseAgentInitialization:
    """Test agent initialization."""

    def test_default_initialization(self):
        """Test agent with default parameters."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)

        assert agent.agent_id == "test-agent"
        assert agent.balance == 10000.0
        assert agent.skill_level == 0.5  # Default
        assert agent.state.agent_id == "test-agent"
        assert agent.state.skill_level == 0.5
        assert agent.state.win_rate == 0.5
        assert agent.state.total_trades == 0
        assert agent.state.winning_trades == 0
        assert agent.state.unlocked_factors == []
        assert agent.state.current_activity is None
        assert agent.state.is_bankrupt is False

    def test_custom_initialization(self):
        """Test agent with custom skill level."""
        agent = ConcreteBaseAgent(
            agent_id="custom-agent",
            initial_capital=5000.0,
            skill_level=0.8,
        )

        assert agent.agent_id == "custom-agent"
        assert agent.balance == 5000.0
        assert agent.skill_level == 0.8
        assert agent.state.skill_level == 0.8

    def test_economic_tracker_integration(self):
        """Test that economic tracker is properly initialized."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)

        assert agent.economic_tracker.agent_id == "test-agent"
        assert agent.economic_tracker.initial_capital == 10000.0
        assert agent.economic_tracker.balance == 10000.0

    def test_event_hooks_initialized(self):
        """Test that every supported event hook exists and starts with no callbacks."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)

        expected_events = (
            "on_trade",
            "on_learn",
            "on_bankrupt",
            "on_level_up",
            "on_factor_unlock",
        )
        for event in expected_events:
            assert event in agent._event_hooks, event
            # All hook lists should start empty, not just on_trade / on_learn.
            assert agent._event_hooks[event] == [], event

    def test_logger_initialized(self):
        """Test that logger is properly bound."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)

        assert agent.logger is not None
|
|
|
|
|
|
class TestBaseAgentProperties:
    """Exercise the convenience properties exposed by ``BaseAgent``."""

    def test_balance_property(self):
        """``balance`` mirrors the economic tracker, including deductions."""
        starting = 10000.0
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=starting)
        assert agent.balance == starting

        # Spending recorded on the tracker must show up through the property.
        agent.economic_tracker.calculate_decision_cost(
            tokens_input=1000, tokens_output=500
        )
        assert agent.balance < starting

    def test_survival_status_property(self):
        """``survival_status`` responds to changes in the tracker balance."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        # At 100%, status is struggling (>=80% threshold)
        assert agent.survival_status == SurvivalStatus.STRUGGLING

        # Raising the balance well above the start promotes the agent.
        tracker = agent.economic_tracker
        tracker.balance = 16000.0
        assert agent.survival_status == SurvivalStatus.THRIVING

    def test_skill_level_property(self):
        """``skill_level`` reads through to ``state.skill_level``."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        assert agent.skill_level == 0.5

        agent.state.skill_level = 0.9
        assert agent.skill_level == 0.9

    def test_win_rate_property(self):
        """``win_rate`` starts at 0.5 and updates after a recorded trade."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        assert agent.win_rate == 0.5

        agent.record_trade(is_win=True, pnl=100.0)
        assert agent.win_rate == 1.0
|
|
|
|
|
|
class TestCanAfford:
    """Tests for ``BaseAgent.can_afford``."""

    def test_can_afford_with_safety_buffer(self):
        """With the default 20% buffer, spend is capped at balance / 1.2."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        # 10000 / 1.2 ~ 8333.33, so 8333 fits but 8500 does not.
        cases = ((8000.0, True), (8333.0, True), (8500.0, False))
        for amount, expected in cases:
            assert agent.can_afford(amount) is expected, amount

    def test_can_afford_custom_buffer(self):
        """An explicit ``safety_buffer`` overrides the default."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        cases = (
            # 50% buffer
            (6000.0, 1.5, True),
            (7000.0, 1.5, False),
            # No buffer: the full balance is spendable.
            (10000.0, 1.0, True),
            (10001.0, 1.0, False),
        )
        for amount, buffer, expected in cases:
            verdict = agent.can_afford(amount, safety_buffer=buffer)
            assert verdict is expected, (amount, buffer)

    def test_cannot_afford_when_bankrupt(self):
        """With a zero balance even a tiny amount is unaffordable."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.economic_tracker.balance = 0.0

        verdict = agent.can_afford(1.0)
        assert verdict is False
|
|
|
|
|
|
class TestCheckSurvival:
    """Tests for ``BaseAgent.check_survival`` and bankruptcy signalling."""

    def test_survival_when_stable(self):
        """A freshly funded agent survives and is not flagged bankrupt."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        survived = agent.check_survival()
        assert survived is True
        assert agent.state.is_bankrupt is False

    def test_bankruptcy_detection(self):
        """Dropping the balance low enough marks the agent bankrupt."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.economic_tracker.balance = 1000.0  # Below 30% threshold

        survived = agent.check_survival()
        assert survived is False
        assert agent.state.is_bankrupt is True

    def test_bankrupt_event_triggered_once(self):
        """``on_bankrupt`` fires on the first failed check only."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        fired = 0

        def on_bankrupt(agent_ref, **kwargs):
            nonlocal fired
            fired += 1

        agent.register_hook("on_bankrupt", on_bankrupt)
        agent.economic_tracker.balance = 1000.0

        # Repeated checks while still bankrupt must not re-emit the event.
        for _ in range(2):
            agent.check_survival()
            assert fired == 1
|
|
|
|
|
|
class TestRecordTrade:
    """Tests for the trade bookkeeping done by ``BaseAgent.record_trade``."""

    def test_record_winning_trade(self):
        """A single win counts as one trade and one win."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.record_trade(is_win=True, pnl=150.0)

        state = agent.state
        assert (state.total_trades, state.winning_trades) == (1, 1)
        assert state.win_rate == 1.0

    def test_record_losing_trade(self):
        """A single loss counts as one trade and zero wins."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.record_trade(is_win=False, pnl=-100.0)

        state = agent.state
        assert (state.total_trades, state.winning_trades) == (1, 0)
        assert state.win_rate == 0.0

    def test_win_rate_calculation(self):
        """The win rate is wins divided by total trades."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        outcomes = ((True, 100.0), (False, -50.0), (True, 75.0), (True, 120.0))
        for won, pnl in outcomes:
            agent.record_trade(is_win=won, pnl=pnl)

        # 3 wins out of 4 = 75%
        state = agent.state
        assert state.total_trades == 4
        assert state.winning_trades == 3
        assert state.win_rate == 0.75

    def test_trade_event_triggered(self):
        """``on_trade`` receives the trade outcome as keyword arguments."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        captured = {}

        def on_trade(agent_ref, **kwargs):
            captured.update(kwargs)

        agent.register_hook("on_trade", on_trade)
        agent.record_trade(is_win=True, pnl=100.0)

        assert captured.get("is_win") is True
        assert captured.get("pnl") == 100.0
|
|
|
|
|
|
class TestImproveSkill:
    """Test improve_skill method."""

    def test_skill_improvement(self):
        """Test skill level improvement."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.5)

        agent.improve_skill(0.2)

        # Compare approximately: the new level comes from a float addition.
        assert agent.skill_level == pytest.approx(0.7)

    def test_skill_capped_at_one(self):
        """Test that skill cannot exceed 1.0."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.9)

        agent.improve_skill(0.2)

        assert agent.skill_level == 1.0

    def test_no_improvement_when_already_max(self):
        """Test that no level-up event fires when skill is already at the maximum."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=1.0)

        event_triggered = False

        # Hooks are called with the agent as the first positional argument
        # (see test_event_callback_receives_agent). A callback that accepted
        # only **kwargs would raise TypeError if invoked, and because hook
        # errors are swallowed, this test would then pass even if the event
        # fired.
        def on_level_up(agent_ref, **kwargs):
            nonlocal event_triggered
            event_triggered = True

        agent.register_hook("on_level_up", on_level_up)
        agent.improve_skill(0.1)

        # Event should not trigger when no actual improvement
        assert event_triggered is False
        # Skill stays capped at the maximum.
        assert agent.skill_level == 1.0

    def test_level_up_event_triggered(self):
        """Test that on_level_up event is triggered."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.5)

        event_data = {}

        def on_level_up(agent_ref, **kwargs):
            event_data.update(kwargs)

        agent.register_hook("on_level_up", on_level_up)
        agent.improve_skill(0.1)

        assert event_data.get("old_level") == 0.5
|
|
|
|
|
|
class TestUnlockFactor:
    """Tests for ``BaseAgent.unlock_factor``."""

    def test_unlock_new_factor(self):
        """Unlocking a new factor succeeds, records it, and charges its cost."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        unlocked = agent.unlock_factor("momentum", cost=500.0)

        assert unlocked is True
        assert "momentum" in agent.state.unlocked_factors
        assert agent.balance == 9500.0

    def test_unlock_already_unlocked(self):
        """Re-unlocking a factor reports success without charging again."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        outcomes = [agent.unlock_factor("momentum", cost=500.0) for _ in range(2)]

        assert outcomes[-1] is True
        # Should not deduct cost again
        assert agent.balance == 9500.0

    def test_cannot_afford_factor(self):
        """A factor costing more than the agent can afford is refused."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=1000.0)

        unlocked = agent.unlock_factor("expensive", cost=5000.0)

        assert unlocked is False
        assert "expensive" not in agent.state.unlocked_factors

    def test_factor_unlock_event_triggered(self):
        """``on_factor_unlock`` receives the factor name and cost."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        captured = {}

        def on_factor_unlock(agent_ref, **kwargs):
            captured.update(kwargs)

        agent.register_hook("on_factor_unlock", on_factor_unlock)
        agent.unlock_factor("momentum", cost=500.0)

        assert captured.get("factor_name") == "momentum"
        assert captured.get("cost") == 500.0
|
|
|
|
|
|
class TestIsFactorUnlocked:
    """Tests for ``BaseAgent.is_factor_unlocked``."""

    def test_factor_unlocked(self):
        """A factor reports as unlocked once it has been unlocked."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.unlock_factor("momentum", cost=100.0)

        assert agent.is_factor_unlocked("momentum") is True

    def test_factor_not_unlocked(self):
        """A factor that was never unlocked reports as locked."""
        fresh = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        assert fresh.is_factor_unlocked("momentum") is False
|
|
|
|
|
|
class TestEventHooks:
    """Tests for registering, unregistering, and dispatching event hooks."""

    def test_register_hook(self):
        """A registered callback is stored under its event name."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        hook = MagicMock()

        agent.register_hook("on_trade", hook)

        assert hook in agent._event_hooks["on_trade"]

    def test_unregister_hook(self):
        """An unregistered callback is removed from its event list."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        hook = MagicMock()

        agent.register_hook("on_trade", hook)
        agent.unregister_hook("on_trade", hook)

        assert hook not in agent._event_hooks["on_trade"]

    def test_register_unknown_event_raises(self):
        """Registering a hook for an unsupported event raises ValueError."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        bogus_event = "on_unknown_event"

        with pytest.raises(ValueError, match="Unknown event"):
            agent.register_hook(bogus_event, MagicMock())

    def test_event_callback_receives_agent(self):
        """Callbacks get the emitting agent as their first argument."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        seen = []

        def hook(agent_ref, **kwargs):
            seen.append(agent_ref)

        agent.register_hook("on_trade", hook)
        agent.record_trade(is_win=True, pnl=100.0)

        received_agent = seen[-1] if seen else None
        assert received_agent is agent

    def test_multiple_hooks_same_event(self):
        """Every callback registered for an event is invoked."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        labels = ("callback1", "callback2")
        seen = []

        def make_hook(label):
            def hook(agent_ref, **kwargs):
                seen.append(label)

            return hook

        for label in labels:
            agent.register_hook("on_trade", make_hook(label))
        agent.record_trade(is_win=True, pnl=100.0)

        for label in labels:
            assert label in seen

    def test_hook_error_handling(self):
        """An exception in one hook does not prevent later hooks from running."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        seen = []

        def failing_hook(agent_ref, **kwargs):
            raise ValueError("Test error")

        def healthy_hook(agent_ref, **kwargs):
            seen.append("good")

        for hook in (failing_hook, healthy_hook):
            agent.register_hook("on_trade", hook)

        # Must not propagate the ValueError from failing_hook.
        agent.record_trade(is_win=True, pnl=100.0)

        assert "good" in seen
|
|
|
|
|
|
class TestAbstractMethods:
    """Test abstract method requirements."""

    def test_cannot_instantiate_base(self):
        """Test that BaseAgent, and a subclass implementing nothing, cannot be instantiated."""

        class IncompleteAgent(BaseAgent):
            pass

        # The docstring promises BaseAgent itself is not instantiable; check it
        # directly. IncompleteAgent inherits BaseAgent's __init__, so the same
        # arguments are valid for both.
        with pytest.raises(TypeError):
            BaseAgent(agent_id="test", initial_capital=10000.0)

        # A subclass that leaves every abstract method unimplemented is also
        # rejected.
        with pytest.raises(TypeError):
            IncompleteAgent(agent_id="test", initial_capital=10000.0)

    def test_decide_activity_must_be_implemented(self):
        """Test that decide_activity must be implemented."""

        class NoDecideAgent(BaseAgent):
            async def analyze(self, symbol: str) -> Dict[str, Any]:
                return {}

        with pytest.raises(TypeError):
            NoDecideAgent(agent_id="test", initial_capital=10000.0)

    def test_analyze_must_be_implemented(self):
        """Test that analyze must be implemented."""

        class NoAnalyzeAgent(BaseAgent):
            async def decide_activity(self) -> ActivityType:
                return ActivityType.REST

        with pytest.raises(TypeError):
            NoAnalyzeAgent(agent_id="test", initial_capital=10000.0)

    def test_decide_activity_returns_activity(self):
        """Test that decide_activity returns an ActivityType."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        result = asyncio.run(agent.decide_activity())

        assert isinstance(result, ActivityType)

    def test_analyze_returns_dict(self):
        """Test that analyze returns a dictionary."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)

        result = asyncio.run(agent.analyze("AAPL"))

        assert isinstance(result, dict)
        assert result["symbol"] == "AAPL"
|
|
|
|
|
|
class TestGetStatusDict:
    """Tests for the summary produced by ``BaseAgent.get_status_dict``."""

    def test_status_dict_contains_required_fields(self):
        """A fresh agent's status dict exposes every expected field."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)
        status = agent.get_status_dict()

        expected = {
            "agent_id": "test-agent",
            "balance": 10000.0,
            "skill_level": 0.5,
            "win_rate": 0.5,
            "total_trades": 0,
            "unlocked_factors": 0,
        }
        for key, value in expected.items():
            assert status[key] == value, key

        assert "status" in status
        assert status["is_bankrupt"] is False

    def test_status_dict_reflects_state(self):
        """Trades and unlocked factors show up in the status dict."""
        agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0)
        agent.record_trade(is_win=True, pnl=100.0)
        agent.unlock_factor("test_factor", cost=100.0)

        snapshot = agent.get_status_dict()

        assert snapshot["total_trades"] == 1
        assert snapshot["win_rate"] == 1.0
        assert snapshot["unlocked_factors"] == 1
|
|
|
|
|
|
class TestRepr:
    """Tests for ``BaseAgent.__repr__``."""

    def test_repr_contains_key_info(self):
        """The repr mentions the class, id, balance, and win rate in some form."""
        agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0)
        text = repr(agent)

        # Each group lists acceptable spellings; at least one must appear.
        name_forms = ("ConcreteBaseAgent", "Agent")
        balance_forms = ("$10,000.00", "$10000", "10000")
        rate_forms = ("50.0%", "50%", "0.5")

        assert any(form in text for form in name_forms)
        assert "test-agent" in text
        assert any(form in text for form in balance_forms)
        assert any(form in text for form in rate_forms)
|
|
|
|
|
|
class TestAgentState:
    """Tests for the ``AgentState`` dataclass."""

    def test_default_state(self):
        """Unspecified fields take their documented defaults."""
        state = AgentState(agent_id="test")

        defaults = {
            "agent_id": "test",
            "skill_level": 0.5,
            "win_rate": 0.5,
            "total_trades": 0,
            "winning_trades": 0,
            "unlocked_factors": [],
        }
        for field_name, expected in defaults.items():
            assert getattr(state, field_name) == expected, field_name

        assert state.current_activity is None
        assert state.is_bankrupt is False

    def test_state_with_custom_values(self):
        """Explicit keyword arguments override the defaults."""
        overrides = {"skill_level": 0.9, "win_rate": 0.75, "total_trades": 100}
        state = AgentState(agent_id="test", **overrides)

        for field_name, expected in overrides.items():
            assert getattr(state, field_name) == expected, field_name
|
|
|
|
|
|
class TestActivityType:
    """Tests for the ``ActivityType`` enum."""

    def test_activity_types(self):
        """Each member compares equal to its string value."""
        expected_values = (
            (ActivityType.TRADE, "trade"),
            (ActivityType.LEARN, "learn"),
            (ActivityType.ANALYZE, "analyze"),
            (ActivityType.REST, "rest"),
            (ActivityType.PAPER_TRADE, "paper_trade"),
        )
        for member, value in expected_values:
            assert member == value, value

    def test_activity_type_comparison(self):
        """Members match their own value and differ from other values."""
        trade = ActivityType.TRADE
        assert trade == "trade"
        assert trade != "learn"
|