"""Unit tests for BaseAgent abstract base class. This module tests the BaseAgent class including initialization, economic tracking integration, event hooks, and state management. """ import asyncio from typing import Any, Dict from unittest.mock import MagicMock import pytest from openclaw.agents.base import ( ActivityType, AgentState, BaseAgent, EventCallback, ) from openclaw.core.economy import SurvivalStatus class ConcreteBaseAgent(BaseAgent): """Concrete agent implementation for testing BaseAgent (not a pytest test class).""" async def decide_activity(self) -> ActivityType: """Return default activity.""" return ActivityType.ANALYZE async def analyze(self, symbol: str) -> Dict[str, Any]: """Return default analysis.""" return {"symbol": symbol, "signal": "hold"} class TestBaseAgentInitialization: """Test agent initialization.""" def test_default_initialization(self): """Test agent with default parameters.""" agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0) assert agent.agent_id == "test-agent" assert agent.balance == 10000.0 assert agent.skill_level == 0.5 # Default assert agent.state.agent_id == "test-agent" assert agent.state.skill_level == 0.5 assert agent.state.win_rate == 0.5 assert agent.state.total_trades == 0 assert agent.state.winning_trades == 0 assert agent.state.unlocked_factors == [] assert agent.state.current_activity is None assert agent.state.is_bankrupt is False def test_custom_initialization(self): """Test agent with custom skill level.""" agent = ConcreteBaseAgent( agent_id="custom-agent", initial_capital=5000.0, skill_level=0.8, ) assert agent.agent_id == "custom-agent" assert agent.balance == 5000.0 assert agent.skill_level == 0.8 assert agent.state.skill_level == 0.8 def test_economic_tracker_integration(self): """Test that economic tracker is properly initialized.""" agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0) assert agent.economic_tracker.agent_id == "test-agent" assert agent.economic_tracker.initial_capital == 10000.0 assert agent.economic_tracker.balance == 10000.0 def test_event_hooks_initialized(self): """Test that event hooks are initialized.""" agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0) assert "on_trade" in agent._event_hooks assert "on_learn" in agent._event_hooks assert "on_bankrupt" in agent._event_hooks assert "on_level_up" in agent._event_hooks assert "on_factor_unlock" in agent._event_hooks # All should start empty assert agent._event_hooks["on_trade"] == [] assert agent._event_hooks["on_learn"] == [] def test_logger_initialized(self): """Test that logger is properly bound.""" agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0) assert agent.logger is not None class TestBaseAgentProperties: """Test agent properties.""" def test_balance_property(self): """Test balance property reflects economic tracker.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) assert agent.balance == 10000.0 # Modify through tracker agent.economic_tracker.calculate_decision_cost( tokens_input=1000, tokens_output=500 ) assert agent.balance < 10000.0 def test_survival_status_property(self): """Test survival_status property.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) # At 100%, status is struggling (>=80% threshold) assert agent.survival_status == SurvivalStatus.STRUGGLING # Boost balance agent.economic_tracker.balance = 16000.0 assert agent.survival_status == SurvivalStatus.THRIVING def test_skill_level_property(self): """Test skill_level property.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) assert agent.skill_level == 0.5 agent.state.skill_level = 0.9 assert agent.skill_level == 0.9 def test_win_rate_property(self): """Test win_rate property.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) assert agent.win_rate == 0.5 agent.record_trade(is_win=True, pnl=100.0) assert agent.win_rate == 1.0 class TestCanAfford: """Test can_afford method.""" def test_can_afford_with_safety_buffer(self): """Test affordability check with default 20% safety buffer.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) # With 20% buffer, can afford amount up to balance/1.2 # balance/1.2 = 10000/1.2 = 8333.33 assert agent.can_afford(8000.0) is True assert agent.can_afford(8333.0) is True assert agent.can_afford(8500.0) is False def test_can_afford_custom_buffer(self): """Test affordability with custom safety buffer.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) # 50% buffer assert agent.can_afford(6000.0, safety_buffer=1.5) is True assert agent.can_afford(7000.0, safety_buffer=1.5) is False # No buffer assert agent.can_afford(10000.0, safety_buffer=1.0) is True assert agent.can_afford(10001.0, safety_buffer=1.0) is False def test_cannot_afford_when_bankrupt(self): """Test that bankrupt agent cannot afford anything.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) agent.economic_tracker.balance = 0.0 assert agent.can_afford(1.0) is False class TestCheckSurvival: """Test check_survival method.""" def test_survival_when_stable(self): """Test survival check when agent is not bankrupt.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) assert agent.check_survival() is True assert agent.state.is_bankrupt is False def test_bankruptcy_detection(self): """Test bankruptcy detection.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) agent.economic_tracker.balance = 1000.0 # Below 30% threshold assert agent.check_survival() is False assert agent.state.is_bankrupt is True def test_bankrupt_event_triggered_once(self): """Test that on_bankrupt event is triggered only once.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) # Track event calls calls = [] def on_bankrupt(agent_ref, **kwargs): calls.append(1) agent.register_hook("on_bankrupt", on_bankrupt) # Set bankrupt agent.economic_tracker.balance = 1000.0 # First check agent.check_survival() assert len(calls) == 1 # Second check - should not trigger again agent.check_survival() assert len(calls) == 1 class TestRecordTrade: """Test record_trade method.""" def test_record_winning_trade(self): """Test recording a winning trade.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) agent.record_trade(is_win=True, pnl=150.0) assert agent.state.total_trades == 1 assert agent.state.winning_trades == 1 assert agent.state.win_rate == 1.0 def test_record_losing_trade(self): """Test recording a losing trade.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) agent.record_trade(is_win=False, pnl=-100.0) assert agent.state.total_trades == 1 assert agent.state.winning_trades == 0 assert agent.state.win_rate == 0.0 def test_win_rate_calculation(self): """Test win rate calculation across multiple trades.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) agent.record_trade(is_win=True, pnl=100.0) agent.record_trade(is_win=False, pnl=-50.0) agent.record_trade(is_win=True, pnl=75.0) agent.record_trade(is_win=True, pnl=120.0) # 3 wins out of 4 = 75% assert agent.state.total_trades == 4 assert agent.state.winning_trades == 3 assert agent.state.win_rate == 0.75 def test_trade_event_triggered(self): """Test that on_trade event is triggered.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) event_data = {} def on_trade(agent_ref, **kwargs): event_data.update(kwargs) agent.register_hook("on_trade", on_trade) agent.record_trade(is_win=True, pnl=100.0) assert event_data.get("is_win") is True assert event_data.get("pnl") == 100.0 class TestImproveSkill: """Test improve_skill method.""" def test_skill_improvement(self): """Test skill level improvement.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.5) agent.improve_skill(0.2) assert agent.skill_level == 0.7 def test_skill_capped_at_one(self): """Test that skill cannot exceed 1.0.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.9) agent.improve_skill(0.2) assert agent.skill_level == 1.0 def test_no_improvement_when_already_max(self): """Test no improvement when already at max level.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=1.0) event_triggered = False def on_level_up(**kwargs): nonlocal event_triggered event_triggered = True agent.register_hook("on_level_up", on_level_up) agent.improve_skill(0.1) # Event should not trigger when no actual improvement assert event_triggered is False def test_level_up_event_triggered(self): """Test that on_level_up event is triggered.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0, skill_level=0.5) event_data = {} def on_level_up(agent_ref, **kwargs): event_data.update(kwargs) agent.register_hook("on_level_up", on_level_up) agent.improve_skill(0.1) assert event_data.get("old_level") == 0.5 class TestUnlockFactor: """Test unlock_factor method.""" def test_unlock_new_factor(self): """Test unlocking a new factor.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) result = agent.unlock_factor("momentum", cost=500.0) assert result is True assert "momentum" in agent.state.unlocked_factors assert agent.balance == 9500.0 def test_unlock_already_unlocked(self): """Test unlocking an already unlocked factor returns True.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) agent.unlock_factor("momentum", cost=500.0) result = agent.unlock_factor("momentum", cost=500.0) assert result is True # Should not deduct cost again assert agent.balance == 9500.0 def test_cannot_afford_factor(self): """Test unlocking when cannot afford.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=1000.0) result = agent.unlock_factor("expensive", cost=5000.0) assert result is False assert "expensive" not in agent.state.unlocked_factors def test_factor_unlock_event_triggered(self): """Test that on_factor_unlock event is triggered.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) event_data = {} def on_factor_unlock(agent_ref, **kwargs): event_data.update(kwargs) agent.register_hook("on_factor_unlock", on_factor_unlock) agent.unlock_factor("momentum", cost=500.0) assert event_data.get("factor_name") == "momentum" assert event_data.get("cost") == 500.0 class TestIsFactorUnlocked: """Test is_factor_unlocked method.""" def test_factor_unlocked(self): """Test checking unlocked factor.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) agent.unlock_factor("momentum", cost=100.0) assert agent.is_factor_unlocked("momentum") is True def test_factor_not_unlocked(self): """Test checking factor not unlocked.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) assert agent.is_factor_unlocked("momentum") is False class TestEventHooks: """Test event hook system.""" def test_register_hook(self): """Test registering event hooks.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) callback = MagicMock() agent.register_hook("on_trade", callback) assert callback in agent._event_hooks["on_trade"] def test_unregister_hook(self): """Test unregistering event hooks.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) callback = MagicMock() agent.register_hook("on_trade", callback) agent.unregister_hook("on_trade", callback) assert callback not in agent._event_hooks["on_trade"] def test_register_unknown_event_raises(self): """Test that registering unknown event raises ValueError.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) with pytest.raises(ValueError, match="Unknown event"): agent.register_hook("on_unknown_event", MagicMock()) def test_event_callback_receives_agent(self): """Test that callback receives agent reference.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) received_agent = None def callback(agent_ref, **kwargs): nonlocal received_agent received_agent = agent_ref agent.register_hook("on_trade", callback) agent.record_trade(is_win=True, pnl=100.0) assert received_agent is agent def test_multiple_hooks_same_event(self): """Test multiple hooks for the same event.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) calls = [] def callback1(agent_ref, **kwargs): calls.append("callback1") def callback2(agent_ref, **kwargs): calls.append("callback2") agent.register_hook("on_trade", callback1) agent.register_hook("on_trade", callback2) agent.record_trade(is_win=True, pnl=100.0) assert "callback1" in calls assert "callback2" in calls def test_hook_error_handling(self): """Test that hook errors don't stop other hooks.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) calls = [] def error_callback(agent_ref, **kwargs): raise ValueError("Test error") def good_callback(agent_ref, **kwargs): calls.append("good") agent.register_hook("on_trade", error_callback) agent.register_hook("on_trade", good_callback) # Should not raise agent.record_trade(is_win=True, pnl=100.0) assert "good" in calls class TestAbstractMethods: """Test abstract method requirements.""" def test_cannot_instantiate_base(self): """Test that BaseAgent cannot be instantiated directly.""" class IncompleteAgent(BaseAgent): pass with pytest.raises(TypeError): IncompleteAgent(agent_id="test", initial_capital=10000.0) def test_decide_activity_must_be_implemented(self): """Test that decide_activity must be implemented.""" class NoDecideAgent(BaseAgent): async def analyze(self, symbol: str) -> Dict[str, Any]: return {} with pytest.raises(TypeError): NoDecideAgent(agent_id="test", initial_capital=10000.0) def test_analyze_must_be_implemented(self): """Test that analyze must be implemented.""" class NoAnalyzeAgent(BaseAgent): async def decide_activity(self) -> ActivityType: return ActivityType.REST with pytest.raises(TypeError): NoAnalyzeAgent(agent_id="test", initial_capital=10000.0) def test_decide_activity_returns_activity(self): """Test that decide_activity returns an ActivityType.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) result = asyncio.run(agent.decide_activity()) assert isinstance(result, ActivityType) def test_analyze_returns_dict(self): """Test that analyze returns a dictionary.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) result = asyncio.run(agent.analyze("AAPL")) assert isinstance(result, dict) assert result["symbol"] == "AAPL" class TestGetStatusDict: """Test get_status_dict method.""" def test_status_dict_contains_required_fields(self): """Test that status dict has all required fields.""" agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0) status = agent.get_status_dict() assert status["agent_id"] == "test-agent" assert status["balance"] == 10000.0 assert "status" in status assert status["skill_level"] == 0.5 assert status["win_rate"] == 0.5 assert status["total_trades"] == 0 assert status["unlocked_factors"] == 0 assert status["is_bankrupt"] is False def test_status_dict_reflects_state(self): """Test that status dict reflects current state.""" agent = ConcreteBaseAgent(agent_id="test", initial_capital=10000.0) agent.record_trade(is_win=True, pnl=100.0) agent.unlock_factor("test_factor", cost=100.0) status = agent.get_status_dict() assert status["total_trades"] == 1 assert status["win_rate"] == 1.0 assert status["unlocked_factors"] == 1 class TestRepr: """Test __repr__ method.""" def test_repr_contains_key_info(self): """Test that repr contains key information.""" agent = ConcreteBaseAgent(agent_id="test-agent", initial_capital=10000.0) repr_str = repr(agent) assert "ConcreteBaseAgent" in repr_str or "Agent" in repr_str assert "test-agent" in repr_str assert "$10,000.00" in repr_str or "$10000" in repr_str or "10000" in repr_str assert "50.0%" in repr_str or "50%" in repr_str or "0.5" in repr_str class TestAgentState: """Test AgentState dataclass.""" def test_default_state(self): """Test default AgentState values.""" state = AgentState(agent_id="test") assert state.agent_id == "test" assert state.skill_level == 0.5 assert state.win_rate == 0.5 assert state.total_trades == 0 assert state.winning_trades == 0 assert state.unlocked_factors == [] assert state.current_activity is None assert state.is_bankrupt is False def test_state_with_custom_values(self): """Test AgentState with custom values.""" state = AgentState( agent_id="test", skill_level=0.9, win_rate=0.75, total_trades=100, ) assert state.skill_level == 0.9 assert state.win_rate == 0.75 assert state.total_trades == 100 class TestActivityType: """Test ActivityType enum.""" def test_activity_types(self): """Test all activity type values.""" assert ActivityType.TRADE == "trade" assert ActivityType.LEARN == "learn" assert ActivityType.ANALYZE == "analyze" assert ActivityType.REST == "rest" assert ActivityType.PAPER_TRADE == "paper_trade" def test_activity_type_comparison(self): """Test activity type comparison.""" assert ActivityType.TRADE == "trade" assert ActivityType.TRADE != "learn"