493 lines
17 KiB
Python
493 lines
17 KiB
Python
"""Unit tests for TradingEconomicTracker."""
|
|
|
|
import json
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from openclaw.core.economy import (
|
|
BalanceHistoryEntry,
|
|
EconomicTrackerState,
|
|
SurvivalStatus,
|
|
TradeCostResult,
|
|
TradingEconomicTracker,
|
|
)
|
|
|
|
|
|
class TestTradingEconomicTrackerInitialization:
    """Checks on the state of a freshly constructed tracker."""

    def test_default_initialization(self):
        """A tracker built with only an agent id exposes the default state."""
        econ = TradingEconomicTracker(agent_id="test-agent")

        assert econ.agent_id == "test-agent"
        assert econ.initial_capital == 10000.0
        assert econ.balance == 10000.0
        # Nothing has been charged or realised yet.
        for counter in ("token_costs", "trade_costs", "realized_pnl"):
            assert getattr(econ, counter) == 0.0

    def test_custom_initialization(self):
        """Explicit constructor arguments are stored on the tracker."""
        rates = {
            "token_cost_per_1m_input": 3.0,
            "token_cost_per_1m_output": 12.0,
            "trade_fee_rate": 0.002,
            "data_cost_per_call": 0.02,
        }
        econ = TradingEconomicTracker(
            agent_id="custom-agent",
            initial_capital=5000.0,
            **rates,
        )

        assert econ.agent_id == "custom-agent"
        assert econ.initial_capital == 5000.0
        assert econ.balance == 5000.0
        for attr_name, configured in rates.items():
            assert getattr(econ, attr_name) == configured

    def test_thresholds_calculated_correctly(self):
        """Survival thresholds are multiples of the initial capital."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        expected_thresholds = (
            ("thriving", 15000.0),  # 1.5x
            ("stable", 11000.0),  # 1.1x
            ("struggling", 8000.0),  # 0.8x
            ("bankrupt", 3000.0),  # 0.3x
        )
        for level, amount in expected_thresholds:
            assert econ.thresholds[level] == amount

    def test_initial_balance_history(self):
        """The history starts with a single "Initial capital" entry."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        entries = econ.get_balance_history()
        assert len(entries) == 1

        opening = entries[0]
        assert opening.balance == 10000.0
        assert opening.change == 0.0
        assert opening.reason == "Initial capital"
|
|
|
|
|
|
class TestCalculateDecisionCost:
    """Checks on ``calculate_decision_cost`` (token and market-data charges)."""

    def test_token_cost_calculation(self):
        """Token usage is charged at the per-1M input/output rates."""
        econ = TradingEconomicTracker(agent_id="test")
        starting_balance = econ.balance

        # 1000 input tokens, 500 output tokens, no data calls.
        charged = econ.calculate_decision_cost(
            tokens_input=1000, tokens_output=500, market_data_calls=0
        )

        # (1000/1e6 * 2.5) + (500/1e6 * 10.0) = 0.0025 + 0.005 = 0.0075
        input_part = 1000 / 1e6 * 2.5
        output_part = 500 / 1e6 * 10.0
        expected = round(input_part + output_part, 4)

        assert charged == expected
        assert econ.token_costs == expected
        assert econ.balance == round(starting_balance - expected, 4)

    def test_market_data_cost(self):
        """Each market data call is charged at ``data_cost_per_call``."""
        econ = TradingEconomicTracker(agent_id="test", data_cost_per_call=0.01)

        charged = econ.calculate_decision_cost(
            tokens_input=0, tokens_output=0, market_data_calls=5
        )

        # 5 calls * 0.01 per call
        assert charged == 0.05

    def test_combined_costs(self):
        """Token and data charges are summed into one decision cost."""
        econ = TradingEconomicTracker(agent_id="test")

        usage = {
            "tokens_input": 1000000,  # 1M tokens
            "tokens_output": 500000,  # 500K tokens
            "market_data_calls": 10,
        }
        charged = econ.calculate_decision_cost(**usage)

        # (1.0 * 2.5) + (0.5 * 10.0) + (10 * 0.01) = 2.5 + 5.0 + 0.1 = 7.6
        expected = round(2.5 + 5.0 + 0.1, 4)
        assert charged == expected

    def test_precision_to_four_decimals(self):
        """The returned cost prints with at most four decimal digits."""
        econ = TradingEconomicTracker(agent_id="test")

        charged = econ.calculate_decision_cost(
            tokens_input=333333, tokens_output=333333, market_data_calls=3
        )

        # Inspect the text after the last "." of the printed value.
        _, _, fractional_digits = str(charged).rpartition(".")
        assert len(fractional_digits) <= 4

    def test_balance_history_updated(self):
        """A decision cost appends a "Decision cost" history entry."""
        econ = TradingEconomicTracker(agent_id="test")

        econ.calculate_decision_cost(
            tokens_input=1000, tokens_output=500, market_data_calls=2
        )

        entries = econ.get_balance_history()
        assert len(entries) == 2
        latest_reason = entries[1].reason
        assert "Decision cost" in latest_reason
|
|
|
|
|
|
class TestCalculateTradeCost:
    """Checks on ``calculate_trade_cost`` (fees, PnL and balance updates)."""

    def test_winning_trade(self):
        """A winning trade nets the win amount minus the fee."""
        econ = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001)
        starting_balance = econ.balance

        outcome = econ.calculate_trade_cost(
            trade_value=10000.0, is_win=True, win_amount=500.0, loss_amount=0.0
        )

        fee = 10.0  # 10000 * 0.001
        net_pnl = 490.0  # 500 - 10

        assert isinstance(outcome, TradeCostResult)
        assert outcome.fee == fee
        assert outcome.pnl == net_pnl
        assert outcome.balance == round(starting_balance + net_pnl, 4)
        assert outcome.status == econ.get_survival_status()

        # Tracker totals: the fee is accumulated, realized PnL is the gross win.
        assert econ.trade_costs == fee
        assert econ.realized_pnl == 500.0

    def test_losing_trade(self):
        """A losing trade nets the negative loss amount minus the fee."""
        econ = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001)
        starting_balance = econ.balance

        outcome = econ.calculate_trade_cost(
            trade_value=10000.0, is_win=False, win_amount=0.0, loss_amount=200.0
        )

        fee = 10.0  # 10000 * 0.001
        net_pnl = -210.0  # -200 - 10

        assert outcome.fee == fee
        assert outcome.pnl == net_pnl
        assert outcome.balance == round(starting_balance + net_pnl, 4)

        # The loss is recorded as a negative realized PnL, excluding the fee.
        assert econ.realized_pnl == -200.0

    def test_trade_fee_accumulation(self):
        """Fees from successive trades add up in ``trade_costs``."""
        econ = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001)

        econ.calculate_trade_cost(
            trade_value=10000.0, is_win=True, win_amount=100.0, loss_amount=0.0
        )
        econ.calculate_trade_cost(
            trade_value=5000.0, is_win=False, win_amount=0.0, loss_amount=50.0
        )

        # 10 (first trade) + 5 (second trade)
        assert econ.trade_costs == 15.0
|
|
|
|
|
|
class TestGetSurvivalStatus:
    """Checks on how ``get_survival_status`` maps a balance to a status."""

    def test_thriving_status(self):
        """Balances at or above 150% of initial capital are THRIVING."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        for level in (15000.0, 20000.0):
            econ.balance = level
            assert econ.get_survival_status() == SurvivalStatus.THRIVING

    def test_stable_status(self):
        """Balances in the 110%-149% band are STABLE."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        for level in (11000.0, 14000.0):
            econ.balance = level
            assert econ.get_survival_status() == SurvivalStatus.STABLE

    def test_struggling_status(self):
        """Balances in the 80%-109% band are STRUGGLING."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        for level in (8000.0, 10000.0):
            econ.balance = level
            assert econ.get_survival_status() == SurvivalStatus.STRUGGLING

    def test_critical_status(self):
        """Balances in the 30%-79% band are CRITICAL."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        for level in (3000.0, 5000.0):
            econ.balance = level
            assert econ.get_survival_status() == SurvivalStatus.CRITICAL

    def test_bankrupt_status(self):
        """Balances below 30% of initial capital are BANKRUPT."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        for level in (2999.99, 0.0):
            econ.balance = level
            assert econ.get_survival_status() == SurvivalStatus.BANKRUPT

    def test_boundary_conditions(self):
        """A balance exactly on a threshold belongs to the band it opens."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        exact_thresholds = [
            (15000.0, SurvivalStatus.THRIVING),  # thriving threshold
            (11000.0, SurvivalStatus.STABLE),  # stable threshold
            (8000.0, SurvivalStatus.STRUGGLING),  # struggling threshold
            (3000.0, SurvivalStatus.CRITICAL),  # bankrupt threshold
        ]
        for level, expected_status in exact_thresholds:
            econ.balance = level
            assert econ.get_survival_status() == expected_status
|
|
|
|
|
|
class TestBalanceHistory:
    """Checks on the balance history log kept by the tracker."""

    def test_history_length(self):
        """Each decision cost or trade adds exactly one history entry."""
        econ = TradingEconomicTracker(agent_id="test")

        # Only the initial entry exists before any transaction.
        assert len(econ.get_balance_history()) == 1

        econ.calculate_decision_cost(tokens_input=1000, tokens_output=500)
        assert len(econ.get_balance_history()) == 2

        econ.calculate_trade_cost(
            trade_value=1000.0, is_win=True, win_amount=50.0, loss_amount=0.0
        )
        assert len(econ.get_balance_history()) == 3

    def test_history_immutable(self):
        """Appending to the returned list leaves the tracker's history unchanged."""
        econ = TradingEconomicTracker(agent_id="test")

        snapshot = econ.get_balance_history()
        intruder = BalanceHistoryEntry(
            timestamp="test", balance=9999.0, change=0.0, reason="test"
        )
        snapshot.append(intruder)

        # A fresh read still holds only the initial entry.
        assert len(econ.get_balance_history()) == 1
|
|
|
|
|
|
class TestPersistence:
    """Checks on ``save_to_file`` / ``load_from_file`` round trips."""

    @staticmethod
    def _reserve_jsonl_path():
        """Create an empty ``.jsonl`` temp file and return its path string.

        The file is created with ``delete=False``, so each test removes it
        itself in a ``finally`` clause.
        """
        with tempfile.NamedTemporaryFile(
            mode="w", delete=False, suffix=".jsonl"
        ) as handle:
            return handle.name

    def test_save_to_file(self):
        """Saving writes a JSON line carrying the tracker's current state."""
        econ = TradingEconomicTracker(agent_id="test-agent", initial_capital=10000.0)
        econ.calculate_decision_cost(tokens_input=1000000, tokens_output=500000)

        filepath = self._reserve_jsonl_path()
        try:
            econ.save_to_file(filepath)

            assert Path(filepath).exists()

            # Decode the first line of the file and compare it to the tracker.
            with open(filepath) as saved:
                first_line = saved.readline().strip()
            record = json.loads(first_line)

            assert record["agent_id"] == "test-agent"
            assert record["balance"] == econ.balance
            assert record["token_costs"] == econ.token_costs
        finally:
            Path(filepath).unlink()

    def test_load_from_file(self):
        """A loaded tracker matches the one that was saved."""
        econ = TradingEconomicTracker(
            agent_id="test-agent",
            initial_capital=10000.0,
            token_cost_per_1m_input=2.5,
            token_cost_per_1m_output=10.0,
        )
        econ.calculate_decision_cost(tokens_input=1000000, tokens_output=500000)
        econ.calculate_trade_cost(
            trade_value=5000.0, is_win=True, win_amount=200.0, loss_amount=0.0
        )

        filepath = self._reserve_jsonl_path()
        try:
            econ.save_to_file(filepath)

            restored = TradingEconomicTracker.load_from_file(filepath)

            compared_fields = (
                "agent_id",
                "initial_capital",
                "balance",
                "token_costs",
                "trade_costs",
                "realized_pnl",
            )
            for field_name in compared_fields:
                assert getattr(restored, field_name) == getattr(econ, field_name)

            restored_entries = restored.get_balance_history()
            assert len(restored_entries) == len(econ.get_balance_history())
        finally:
            Path(filepath).unlink()

    def test_load_latest_state(self):
        """After two saves to one file, loading yields the later state."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        filepath = self._reserve_jsonl_path()
        try:
            econ.save_to_file(filepath)

            # Change the tracker, then save a second time to the same path.
            econ.calculate_decision_cost(tokens_input=1000000, tokens_output=0)
            econ.save_to_file(filepath)

            restored = TradingEconomicTracker.load_from_file(filepath)

            assert restored.token_costs == econ.token_costs
        finally:
            Path(filepath).unlink()

    def test_load_nonexistent_file(self):
        """Loading from a missing path raises ``FileNotFoundError``."""
        missing_path = "/nonexistent/path/file.jsonl"
        with pytest.raises(FileNotFoundError):
            TradingEconomicTracker.load_from_file(missing_path)

    def test_load_empty_file(self):
        """Loading an empty file raises ``ValueError`` matching "empty"."""
        filepath = self._reserve_jsonl_path()
        try:
            with pytest.raises(ValueError, match="empty"):
                TradingEconomicTracker.load_from_file(filepath)
        finally:
            Path(filepath).unlink()
|
|
|
|
|
|
class TestProperties:
    """Checks on the tracker's derived ``total_costs`` / ``net_profit`` values."""

    def test_total_costs(self):
        """``total_costs`` is the sum of token costs and trade costs."""
        econ = TradingEconomicTracker(agent_id="test")
        assert econ.total_costs == 0.0

        econ.token_costs = 10.0
        econ.trade_costs = 5.0
        assert econ.total_costs == 15.0

    def test_net_profit(self):
        """``net_profit`` is realized PnL minus token and trade costs."""
        econ = TradingEconomicTracker(agent_id="test")
        assert econ.net_profit == 0.0

        econ.realized_pnl = 100.0
        econ.token_costs = 10.0
        econ.trade_costs = 5.0

        # 100 - 10 - 5
        assert econ.net_profit == 85.0
|
|
|
|
|
|
class TestEdgeCases:
    """Boundary inputs and the tracker's string representation."""

    def test_balance_never_negative(self):
        """A loss larger than the balance leaves the balance at zero."""
        econ = TradingEconomicTracker(agent_id="test", initial_capital=100.0)

        # The loss and the fee both far exceed the 100.0 starting capital.
        econ.calculate_trade_cost(
            trade_value=100000.0, is_win=False, win_amount=0.0, loss_amount=50000.0
        )

        assert econ.balance == 0.0

    def test_zero_value_trade(self):
        """A zero-value trade yields a zero fee and zero PnL."""
        econ = TradingEconomicTracker(agent_id="test")

        outcome = econ.calculate_trade_cost(
            trade_value=0.0, is_win=True, win_amount=0.0, loss_amount=0.0
        )

        assert outcome.fee == 0.0
        assert outcome.pnl == 0.0

    def test_repr(self):
        """``repr`` mentions the agent id, the balance and the status."""
        econ = TradingEconomicTracker(agent_id="test-agent", initial_capital=10000.0)

        text = repr(econ)

        assert "test-agent" in text
        # Either balance formatting is accepted.
        assert any(rendering in text for rendering in ("$10000.00", "10000.0"))
        # At exactly initial capital, status is struggling (>=80% threshold).
        assert "struggling" in text
|
|
|
|
|
|
class TestPydanticModels:
    """Construction checks for the economy data models and status enum."""

    def test_balance_history_entry_validation(self):
        """``BalanceHistoryEntry`` keeps the balance and change it was given."""
        fields = {
            "timestamp": "2024-01-01T00:00:00",
            "balance": 100.0,
            "change": -10.0,
            "reason": "Test",
        }
        entry = BalanceHistoryEntry(**fields)

        assert entry.balance == 100.0
        assert entry.change == -10.0

    def test_trade_cost_result_validation(self):
        """``TradeCostResult`` keeps the fee and status it was given."""
        outcome = TradeCostResult(
            fee=10.0, pnl=100.0, balance=1000.0, status=SurvivalStatus.STABLE
        )

        assert outcome.fee == 10.0
        assert outcome.status == SurvivalStatus.STABLE

    def test_survival_status_enum(self):
        """Each ``SurvivalStatus`` member equals its emoji-prefixed label."""
        expected_labels = [
            (SurvivalStatus.THRIVING, "🚀 thriving"),
            (SurvivalStatus.STABLE, "💪 stable"),
            (SurvivalStatus.STRUGGLING, "⚠️ struggling"),
            (SurvivalStatus.CRITICAL, "🔴 critical"),
            (SurvivalStatus.BANKRUPT, "💀 bankrupt"),
        ]
        for member, label in expected_labels:
            assert member == label
|