"""Unit tests for TradingEconomicTracker.""" import json import tempfile from pathlib import Path import pytest from openclaw.core.economy import ( BalanceHistoryEntry, EconomicTrackerState, SurvivalStatus, TradeCostResult, TradingEconomicTracker, ) class TestTradingEconomicTrackerInitialization: """Test tracker initialization.""" def test_default_initialization(self): """Test tracker with default parameters.""" tracker = TradingEconomicTracker(agent_id="test-agent") assert tracker.agent_id == "test-agent" assert tracker.initial_capital == 10000.0 assert tracker.balance == 10000.0 assert tracker.token_costs == 0.0 assert tracker.trade_costs == 0.0 assert tracker.realized_pnl == 0.0 def test_custom_initialization(self): """Test tracker with custom parameters.""" tracker = TradingEconomicTracker( agent_id="custom-agent", initial_capital=5000.0, token_cost_per_1m_input=3.0, token_cost_per_1m_output=12.0, trade_fee_rate=0.002, data_cost_per_call=0.02, ) assert tracker.agent_id == "custom-agent" assert tracker.initial_capital == 5000.0 assert tracker.balance == 5000.0 assert tracker.token_cost_per_1m_input == 3.0 assert tracker.token_cost_per_1m_output == 12.0 assert tracker.trade_fee_rate == 0.002 assert tracker.data_cost_per_call == 0.02 def test_thresholds_calculated_correctly(self): """Test that survival thresholds are calculated from initial capital.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) assert tracker.thresholds["thriving"] == 15000.0 # 1.5x assert tracker.thresholds["stable"] == 11000.0 # 1.1x assert tracker.thresholds["struggling"] == 8000.0 # 0.8x assert tracker.thresholds["bankrupt"] == 3000.0 # 0.3x def test_initial_balance_history(self): """Test that initial balance history is recorded.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) history = tracker.get_balance_history() assert len(history) == 1 assert history[0].balance == 10000.0 assert history[0].change == 0.0 assert history[0].reason == "Initial capital" class TestCalculateDecisionCost: """Test decision cost calculation.""" def test_token_cost_calculation(self): """Test LLM token cost calculation.""" tracker = TradingEconomicTracker(agent_id="test") initial_balance = tracker.balance # 1000 input tokens, 500 output tokens, 0 data calls cost = tracker.calculate_decision_cost( tokens_input=1000, tokens_output=500, market_data_calls=0 ) # Expected: (1000/1e6 * 2.5) + (500/1e6 * 10.0) = 0.0025 + 0.005 = 0.0075 expected_cost = round(1000 / 1e6 * 2.5 + 500 / 1e6 * 10.0, 4) assert cost == expected_cost assert tracker.token_costs == expected_cost assert tracker.balance == round(initial_balance - expected_cost, 4) def test_market_data_cost(self): """Test market data API call cost.""" tracker = TradingEconomicTracker(agent_id="test", data_cost_per_call=0.01) cost = tracker.calculate_decision_cost( tokens_input=0, tokens_output=0, market_data_calls=5 ) # Expected: 5 * 0.01 = 0.05 assert cost == 0.05 def test_combined_costs(self): """Test combined token and data costs.""" tracker = TradingEconomicTracker(agent_id="test") cost = tracker.calculate_decision_cost( tokens_input=1000000, # 1M tokens tokens_output=500000, # 500K tokens market_data_calls=10, ) # Expected: (1.0 * 2.5) + (0.5 * 10.0) + (10 * 0.01) = 2.5 + 5.0 + 0.1 = 7.6 expected_cost = round(2.5 + 5.0 + 0.1, 4) assert cost == expected_cost def test_precision_to_four_decimals(self): """Test that costs are calculated with 4 decimal precision.""" tracker = TradingEconomicTracker(agent_id="test") cost = tracker.calculate_decision_cost( tokens_input=333333, tokens_output=333333, market_data_calls=3 ) # Should be rounded to 4 decimal places assert len(str(cost).split(".")[-1]) <= 4 def test_balance_history_updated(self): """Test that balance history is updated after decision cost.""" tracker = TradingEconomicTracker(agent_id="test") tracker.calculate_decision_cost( tokens_input=1000, tokens_output=500, market_data_calls=2 ) history = tracker.get_balance_history() assert len(history) == 2 assert "Decision cost" in history[1].reason class TestCalculateTradeCost: """Test trade cost calculation.""" def test_winning_trade(self): """Test cost calculation for winning trade.""" tracker = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001) initial_balance = tracker.balance result = tracker.calculate_trade_cost( trade_value=10000.0, is_win=True, win_amount=500.0, loss_amount=0.0 ) # Expected fee: 10000 * 0.001 = 10.0 # Expected PnL: 500 - 10 = 490.0 expected_fee = 10.0 expected_pnl = 490.0 assert isinstance(result, TradeCostResult) assert result.fee == expected_fee assert result.pnl == expected_pnl assert result.balance == round(initial_balance + expected_pnl, 4) assert result.status == tracker.get_survival_status() assert tracker.trade_costs == expected_fee assert tracker.realized_pnl == 500.0 def test_losing_trade(self): """Test cost calculation for losing trade.""" tracker = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001) initial_balance = tracker.balance result = tracker.calculate_trade_cost( trade_value=10000.0, is_win=False, win_amount=0.0, loss_amount=200.0 ) # Expected fee: 10000 * 0.001 = 10.0 # Expected PnL: -200 - 10 = -210.0 expected_fee = 10.0 expected_pnl = -210.0 assert result.fee == expected_fee assert result.pnl == expected_pnl assert result.balance == round(initial_balance + expected_pnl, 4) assert tracker.realized_pnl == -200.0 # Loss amount recorded (negative) def test_trade_fee_accumulation(self): """Test that trade fees accumulate correctly.""" tracker = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001) tracker.calculate_trade_cost( trade_value=10000.0, is_win=True, win_amount=100.0, loss_amount=0.0 ) tracker.calculate_trade_cost( trade_value=5000.0, is_win=False, win_amount=0.0, loss_amount=50.0 ) # Total fees: 10 + 5 = 15 assert tracker.trade_costs == 15.0 class TestGetSurvivalStatus: """Test survival status determination.""" def test_thriving_status(self): """Test THRIVING status at 150%+.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) tracker.balance = 15000.0 assert tracker.get_survival_status() == SurvivalStatus.THRIVING tracker.balance = 20000.0 assert tracker.get_survival_status() == SurvivalStatus.THRIVING def test_stable_status(self): """Test STABLE status at 110%-149%.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) tracker.balance = 11000.0 assert tracker.get_survival_status() == SurvivalStatus.STABLE tracker.balance = 14000.0 assert tracker.get_survival_status() == SurvivalStatus.STABLE def test_struggling_status(self): """Test STRUGGLING status at 80%-109%.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) tracker.balance = 8000.0 assert tracker.get_survival_status() == SurvivalStatus.STRUGGLING tracker.balance = 10000.0 assert tracker.get_survival_status() == SurvivalStatus.STRUGGLING def test_critical_status(self): """Test CRITICAL status at 30%-79%.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) tracker.balance = 3000.0 assert tracker.get_survival_status() == SurvivalStatus.CRITICAL tracker.balance = 5000.0 assert tracker.get_survival_status() == SurvivalStatus.CRITICAL def test_bankrupt_status(self): """Test BANKRUPT status below 30%.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) tracker.balance = 2999.99 assert tracker.get_survival_status() == SurvivalStatus.BANKRUPT tracker.balance = 0.0 assert tracker.get_survival_status() == SurvivalStatus.BANKRUPT def test_boundary_conditions(self): """Test exact boundary values.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) # Test exact threshold values tracker.balance = 15000.0 # thriving threshold assert tracker.get_survival_status() == SurvivalStatus.THRIVING tracker.balance = 11000.0 # stable threshold assert tracker.get_survival_status() == SurvivalStatus.STABLE tracker.balance = 8000.0 # struggling threshold assert tracker.get_survival_status() == SurvivalStatus.STRUGGLING tracker.balance = 3000.0 # bankrupt threshold assert tracker.get_survival_status() == SurvivalStatus.CRITICAL class TestBalanceHistory: """Test balance history tracking.""" def test_history_length(self): """Test that history grows with each transaction.""" tracker = TradingEconomicTracker(agent_id="test") assert len(tracker.get_balance_history()) == 1 # Initial tracker.calculate_decision_cost(tokens_input=1000, tokens_output=500) assert len(tracker.get_balance_history()) == 2 tracker.calculate_trade_cost( trade_value=1000.0, is_win=True, win_amount=50.0, loss_amount=0.0 ) assert len(tracker.get_balance_history()) == 3 def test_history_immutable(self): """Test that returned history doesn't affect internal state.""" tracker = TradingEconomicTracker(agent_id="test") history = tracker.get_balance_history() history.append( BalanceHistoryEntry( timestamp="test", balance=9999.0, change=0.0, reason="test" ) ) assert len(tracker.get_balance_history()) == 1 # Unchanged class TestPersistence: """Test save/load functionality.""" def test_save_to_file(self): """Test saving tracker state to JSONL file.""" tracker = TradingEconomicTracker(agent_id="test-agent", initial_capital=10000.0) tracker.calculate_decision_cost(tokens_input=1000000, tokens_output=500000) with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as f: filepath = f.name try: tracker.save_to_file(filepath) assert Path(filepath).exists() # Read and verify content with open(filepath) as f: line = f.readline().strip() data = json.loads(line) assert data["agent_id"] == "test-agent" assert data["balance"] == tracker.balance assert data["token_costs"] == tracker.token_costs finally: Path(filepath).unlink() def test_load_from_file(self): """Test loading tracker state from JSONL file.""" tracker = TradingEconomicTracker( agent_id="test-agent", initial_capital=10000.0, token_cost_per_1m_input=2.5, token_cost_per_1m_output=10.0, ) tracker.calculate_decision_cost(tokens_input=1000000, tokens_output=500000) tracker.calculate_trade_cost( trade_value=5000.0, is_win=True, win_amount=200.0, loss_amount=0.0 ) with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as f: filepath = f.name try: tracker.save_to_file(filepath) # Load and verify loaded = TradingEconomicTracker.load_from_file(filepath) assert loaded.agent_id == tracker.agent_id assert loaded.initial_capital == tracker.initial_capital assert loaded.balance == tracker.balance assert loaded.token_costs == tracker.token_costs assert loaded.trade_costs == tracker.trade_costs assert loaded.realized_pnl == tracker.realized_pnl assert len(loaded.get_balance_history()) == len(tracker.get_balance_history()) finally: Path(filepath).unlink() def test_load_latest_state(self): """Test that load_from_file returns the latest state.""" tracker1 = TradingEconomicTracker(agent_id="test", initial_capital=10000.0) with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as f: filepath = f.name try: tracker1.save_to_file(filepath) # Modify tracker and save again tracker1.calculate_decision_cost(tokens_input=1000000, tokens_output=0) tracker1.save_to_file(filepath) loaded = TradingEconomicTracker.load_from_file(filepath) assert loaded.token_costs == tracker1.token_costs finally: Path(filepath).unlink() def test_load_nonexistent_file(self): """Test loading from non-existent file raises error.""" with pytest.raises(FileNotFoundError): TradingEconomicTracker.load_from_file("/nonexistent/path/file.jsonl") def test_load_empty_file(self): """Test loading from empty file raises error.""" with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".jsonl") as f: filepath = f.name try: with pytest.raises(ValueError, match="empty"): TradingEconomicTracker.load_from_file(filepath) finally: Path(filepath).unlink() class TestProperties: """Test computed properties.""" def test_total_costs(self): """Test total_costs property.""" tracker = TradingEconomicTracker(agent_id="test") assert tracker.total_costs == 0.0 tracker.token_costs = 10.0 tracker.trade_costs = 5.0 assert tracker.total_costs == 15.0 def test_net_profit(self): """Test net_profit property.""" tracker = TradingEconomicTracker(agent_id="test") assert tracker.net_profit == 0.0 tracker.realized_pnl = 100.0 tracker.token_costs = 10.0 tracker.trade_costs = 5.0 assert tracker.net_profit == 85.0 class TestEdgeCases: """Test edge cases and error conditions.""" def test_balance_never_negative(self): """Test that balance never goes below zero.""" tracker = TradingEconomicTracker(agent_id="test", initial_capital=100.0) # Large losing trade tracker.calculate_trade_cost( trade_value=100000.0, is_win=False, win_amount=0.0, loss_amount=50000.0 ) assert tracker.balance == 0.0 def test_zero_value_trade(self): """Test trade with zero value.""" tracker = TradingEconomicTracker(agent_id="test") result = tracker.calculate_trade_cost( trade_value=0.0, is_win=True, win_amount=0.0, loss_amount=0.0 ) assert result.fee == 0.0 assert result.pnl == 0.0 def test_repr(self): """Test string representation.""" tracker = TradingEconomicTracker(agent_id="test-agent", initial_capital=10000.0) repr_str = repr(tracker) assert "test-agent" in repr_str assert "$10000.00" in repr_str or "10000.0" in repr_str # At exactly initial capital, status is struggling (>=80% threshold) assert "struggling" in repr_str class TestPydanticModels: """Test Pydantic model validation.""" def test_balance_history_entry_validation(self): """Test BalanceHistoryEntry validation.""" entry = BalanceHistoryEntry( timestamp="2024-01-01T00:00:00", balance=100.0, change=-10.0, reason="Test", ) assert entry.balance == 100.0 assert entry.change == -10.0 def test_trade_cost_result_validation(self): """Test TradeCostResult validation.""" result = TradeCostResult( fee=10.0, pnl=100.0, balance=1000.0, status=SurvivalStatus.STABLE ) assert result.fee == 10.0 assert result.status == SurvivalStatus.STABLE def test_survival_status_enum(self): """Test SurvivalStatus enum values.""" assert SurvivalStatus.THRIVING == "🚀 thriving" assert SurvivalStatus.STABLE == "💪 stable" assert SurvivalStatus.STRUGGLING == "⚠️ struggling" assert SurvivalStatus.CRITICAL == "🔴 critical" assert SurvivalStatus.BANKRUPT == "💀 bankrupt"