stock/tests/unit/test_economy.py
ZhangPeng 9aecdd036c Initial commit: OpenClaw Trading - AI多智能体量化交易系统
- 添加项目核心代码和配置
- 添加前端界面 (Next.js)
- 添加单元测试
- 更新 .gitignore 排除缓存和依赖
2026-02-27 03:47:40 +08:00

493 lines
17 KiB
Python

"""Unit tests for TradingEconomicTracker."""
import json
import tempfile
from pathlib import Path
import pytest
from openclaw.core.economy import (
BalanceHistoryEntry,
EconomicTrackerState,
SurvivalStatus,
TradeCostResult,
TradingEconomicTracker,
)
class TestTradingEconomicTrackerInitialization:
    """Construction-time state of TradingEconomicTracker."""

    def test_default_initialization(self):
        """Only agent_id is supplied, so every other field takes its default."""
        tracker = TradingEconomicTracker(agent_id="test-agent")

        expected_state = {
            "agent_id": "test-agent",
            "initial_capital": 10000.0,
            "balance": 10000.0,
            "token_costs": 0.0,
            "trade_costs": 0.0,
            "realized_pnl": 0.0,
        }
        for attribute, value in expected_state.items():
            assert getattr(tracker, attribute) == value, attribute

    def test_custom_initialization(self):
        """Explicit constructor arguments are stored exactly as given."""
        tracker = TradingEconomicTracker(
            agent_id="custom-agent",
            initial_capital=5000.0,
            token_cost_per_1m_input=3.0,
            token_cost_per_1m_output=12.0,
            trade_fee_rate=0.002,
            data_cost_per_call=0.02,
        )

        expected_state = {
            "agent_id": "custom-agent",
            "initial_capital": 5000.0,
            "balance": 5000.0,  # balance starts out equal to initial_capital
            "token_cost_per_1m_input": 3.0,
            "token_cost_per_1m_output": 12.0,
            "trade_fee_rate": 0.002,
            "data_cost_per_call": 0.02,
        }
        for attribute, value in expected_state.items():
            assert getattr(tracker, attribute) == value, attribute

    def test_thresholds_calculated_correctly(self):
        """Survival thresholds are fixed multiples of the starting capital."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        expected_thresholds = {
            "thriving": 15000.0,  # 1.5x
            "stable": 11000.0,  # 1.1x
            "struggling": 8000.0,  # 0.8x
            "bankrupt": 3000.0,  # 0.3x
        }
        for level, cutoff in expected_thresholds.items():
            assert tracker.thresholds[level] == cutoff, level

    def test_initial_balance_history(self):
        """A fresh tracker holds exactly one history entry for its seed capital."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)

        history = tracker.get_balance_history()
        assert len(history) == 1

        seed_entry = history[0]
        assert seed_entry.balance == 10000.0
        assert seed_entry.change == 0.0
        assert seed_entry.reason == "Initial capital"
class TestCalculateDecisionCost:
    """Test decision cost calculation (LLM tokens plus market-data calls)."""

    def test_token_cost_calculation(self):
        """Test LLM token cost calculation with the default per-1M-token prices."""
        tracker = TradingEconomicTracker(agent_id="test")
        initial_balance = tracker.balance
        # 1000 input tokens, 500 output tokens, 0 data calls
        cost = tracker.calculate_decision_cost(
            tokens_input=1000, tokens_output=500, market_data_calls=0
        )
        # Expected: (1000/1e6 * 2.5) + (500/1e6 * 10.0) = 0.0025 + 0.005 = 0.0075
        expected_cost = round(1000 / 1e6 * 2.5 + 500 / 1e6 * 10.0, 4)
        assert cost == expected_cost
        assert tracker.token_costs == expected_cost
        # The decision cost is debited from the balance.
        assert tracker.balance == round(initial_balance - expected_cost, 4)

    def test_market_data_cost(self):
        """Test market data API call cost (calls * data_cost_per_call)."""
        tracker = TradingEconomicTracker(agent_id="test", data_cost_per_call=0.01)
        cost = tracker.calculate_decision_cost(
            tokens_input=0, tokens_output=0, market_data_calls=5
        )
        # Expected: 5 * 0.01 = 0.05
        assert cost == 0.05

    def test_combined_costs(self):
        """Test that token and data costs are summed into one decision cost."""
        tracker = TradingEconomicTracker(agent_id="test")
        cost = tracker.calculate_decision_cost(
            tokens_input=1000000,  # 1M tokens
            tokens_output=500000,  # 500K tokens
            market_data_calls=10,
        )
        # Expected: (1.0 * 2.5) + (0.5 * 10.0) + (10 * 0.01) = 2.5 + 5.0 + 0.1 = 7.6
        expected_cost = round(2.5 + 5.0 + 0.1, 4)
        assert cost == expected_cost

    def test_precision_to_four_decimals(self):
        """Test that costs are rounded to at most 4 decimal places.

        A value already rounded to 4 places is a fixed point of ``round(x, 4)``,
        so this checks the number itself rather than counting characters in
        ``str(cost)``, which depends on float repr formatting (e.g. exponent
        notation for large magnitudes has no "." at all).
        """
        tracker = TradingEconomicTracker(agent_id="test")
        cost = tracker.calculate_decision_cost(
            tokens_input=333333, tokens_output=333333, market_data_calls=3
        )
        assert cost == round(cost, 4)

    def test_balance_history_updated(self):
        """Test that one balance history entry is appended per decision cost."""
        tracker = TradingEconomicTracker(agent_id="test")
        tracker.calculate_decision_cost(
            tokens_input=1000, tokens_output=500, market_data_calls=2
        )
        history = tracker.get_balance_history()
        # Entry 0 is the initial-capital seed; entry 1 is this decision.
        assert len(history) == 2
        assert "Decision cost" in history[1].reason
class TestCalculateTradeCost:
    """Fee and PnL accounting performed by calculate_trade_cost."""

    def test_winning_trade(self):
        """A win credits win_amount minus the proportional fee."""
        tracker = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001)
        opening_balance = tracker.balance

        outcome = tracker.calculate_trade_cost(
            trade_value=10000.0, is_win=True, win_amount=500.0, loss_amount=0.0
        )

        fee = 10.0  # 10000 traded * 0.001 fee rate
        net_pnl = 490.0  # 500 won - 10 fee
        assert isinstance(outcome, TradeCostResult)
        assert outcome.fee == fee
        assert outcome.pnl == net_pnl
        assert outcome.balance == round(opening_balance + net_pnl, 4)
        assert outcome.status == tracker.get_survival_status()
        # Running totals: fees accrue in trade_costs, realized_pnl stays gross.
        assert tracker.trade_costs == fee
        assert tracker.realized_pnl == 500.0

    def test_losing_trade(self):
        """A loss debits loss_amount plus the proportional fee."""
        tracker = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001)
        opening_balance = tracker.balance

        outcome = tracker.calculate_trade_cost(
            trade_value=10000.0, is_win=False, win_amount=0.0, loss_amount=200.0
        )

        fee = 10.0  # 10000 traded * 0.001 fee rate
        net_pnl = -210.0  # 200 lost + 10 fee
        assert outcome.fee == fee
        assert outcome.pnl == net_pnl
        assert outcome.balance == round(opening_balance + net_pnl, 4)
        assert tracker.realized_pnl == -200.0  # loss recorded as a negative amount

    def test_trade_fee_accumulation(self):
        """trade_costs is the running sum of every trade's fee."""
        tracker = TradingEconomicTracker(agent_id="test", trade_fee_rate=0.001)

        trades = (
            {"trade_value": 10000.0, "is_win": True, "win_amount": 100.0, "loss_amount": 0.0},
            {"trade_value": 5000.0, "is_win": False, "win_amount": 0.0, "loss_amount": 50.0},
        )
        for trade in trades:
            tracker.calculate_trade_cost(**trade)

        # 10000 * 0.001 + 5000 * 0.001 = 10 + 5
        assert tracker.trade_costs == 15.0
class TestGetSurvivalStatus:
    """Balance-to-tier mapping of get_survival_status (initial capital 10,000).

    Tiers as fractions of initial capital: THRIVING >= 1.5, STABLE [1.1, 1.5),
    STRUGGLING [0.8, 1.1), CRITICAL [0.3, 0.8), BANKRUPT < 0.3.
    """

    def test_thriving_status(self):
        """THRIVING from 150% of initial capital upward."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)
        for balance in (15000.0, 20000.0):
            tracker.balance = balance
            assert tracker.get_survival_status() == SurvivalStatus.THRIVING, balance

    def test_stable_status(self):
        """STABLE from 110% up to (not including) 150%."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)
        for balance in (11000.0, 14000.0):
            tracker.balance = balance
            assert tracker.get_survival_status() == SurvivalStatus.STABLE, balance

    def test_struggling_status(self):
        """STRUGGLING from 80% up to (not including) 110%."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)
        for balance in (8000.0, 10000.0):
            tracker.balance = balance
            assert tracker.get_survival_status() == SurvivalStatus.STRUGGLING, balance

    def test_critical_status(self):
        """CRITICAL from 30% up to (not including) 80%."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)
        for balance in (3000.0, 5000.0):
            tracker.balance = balance
            assert tracker.get_survival_status() == SurvivalStatus.CRITICAL, balance

    def test_bankrupt_status(self):
        """BANKRUPT strictly below 30%, down to zero."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)
        for balance in (2999.99, 0.0):
            tracker.balance = balance
            assert tracker.get_survival_status() == SurvivalStatus.BANKRUPT, balance

    def test_boundary_conditions(self):
        """Each threshold value belongs to the tier that starts at it."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)
        boundaries = [
            (15000.0, SurvivalStatus.THRIVING),  # thriving threshold
            (11000.0, SurvivalStatus.STABLE),  # stable threshold
            (8000.0, SurvivalStatus.STRUGGLING),  # struggling threshold
            # Exactly at the bankrupt threshold is still CRITICAL.
            (3000.0, SurvivalStatus.CRITICAL),
        ]
        for balance, expected_status in boundaries:
            tracker.balance = balance
            assert tracker.get_survival_status() == expected_status, balance
class TestBalanceHistory:
    """Bookkeeping of the balance history log."""

    def test_history_length(self):
        """Each decision cost and each trade appends exactly one entry."""
        tracker = TradingEconomicTracker(agent_id="test")

        def entry_count():
            return len(tracker.get_balance_history())

        assert entry_count() == 1  # seeded with the initial capital
        tracker.calculate_decision_cost(tokens_input=1000, tokens_output=500)
        assert entry_count() == 2
        tracker.calculate_trade_cost(
            trade_value=1000.0, is_win=True, win_amount=50.0, loss_amount=0.0
        )
        assert entry_count() == 3

    def test_history_immutable(self):
        """Mutating the returned history list leaves the tracker's log untouched."""
        tracker = TradingEconomicTracker(agent_id="test")
        snapshot = tracker.get_balance_history()

        bogus_entry = BalanceHistoryEntry(
            timestamp="test", balance=9999.0, change=0.0, reason="test"
        )
        snapshot.append(bogus_entry)

        assert len(tracker.get_balance_history()) == 1  # still only the seed entry
class TestPersistence:
    """Test save/load functionality.

    Each test works inside a ``tempfile.TemporaryDirectory`` so the scratch
    file is removed automatically when the ``with`` block exits, even if an
    assertion fails part-way through.
    """

    def test_save_to_file(self):
        """Test saving tracker state to a JSONL file."""
        tracker = TradingEconomicTracker(agent_id="test-agent", initial_capital=10000.0)
        tracker.calculate_decision_cost(tokens_input=1000000, tokens_output=500000)
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.jsonl"
            # Save into an already-existing empty file, as before.
            path.touch()
            tracker.save_to_file(str(path))
            assert path.exists()
            # The first line holds the single snapshot written above.
            with path.open() as fh:
                data = json.loads(fh.readline().strip())
        assert data["agent_id"] == "test-agent"
        assert data["balance"] == tracker.balance
        assert data["token_costs"] == tracker.token_costs

    def test_load_from_file(self):
        """Test that a saved tracker round-trips through load_from_file."""
        tracker = TradingEconomicTracker(
            agent_id="test-agent",
            initial_capital=10000.0,
            token_cost_per_1m_input=2.5,
            token_cost_per_1m_output=10.0,
        )
        tracker.calculate_decision_cost(tokens_input=1000000, tokens_output=500000)
        tracker.calculate_trade_cost(
            trade_value=5000.0, is_win=True, win_amount=200.0, loss_amount=0.0
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.jsonl"
            path.touch()
            tracker.save_to_file(str(path))
            loaded = TradingEconomicTracker.load_from_file(str(path))
        assert loaded.agent_id == tracker.agent_id
        assert loaded.initial_capital == tracker.initial_capital
        assert loaded.balance == tracker.balance
        assert loaded.token_costs == tracker.token_costs
        assert loaded.trade_costs == tracker.trade_costs
        assert loaded.realized_pnl == tracker.realized_pnl
        assert len(loaded.get_balance_history()) == len(tracker.get_balance_history())

    def test_load_latest_state(self):
        """Test that load_from_file returns the most recently saved state."""
        tracker1 = TradingEconomicTracker(agent_id="test", initial_capital=10000.0)
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "state.jsonl"
            path.touch()
            tracker1.save_to_file(str(path))
            # Modify tracker and save again; the second snapshot must win.
            tracker1.calculate_decision_cost(tokens_input=1000000, tokens_output=0)
            tracker1.save_to_file(str(path))
            loaded = TradingEconomicTracker.load_from_file(str(path))
        assert loaded.token_costs == tracker1.token_costs
        assert loaded.balance == tracker1.balance

    def test_load_nonexistent_file(self):
        """Test loading from a non-existent file raises FileNotFoundError."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # A path under a fresh, empty directory is guaranteed not to exist.
            missing = Path(tmpdir) / "nonexistent" / "file.jsonl"
            with pytest.raises(FileNotFoundError):
                TradingEconomicTracker.load_from_file(str(missing))

    def test_load_empty_file(self):
        """Test loading from an empty file raises ValueError mentioning 'empty'."""
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "empty.jsonl"
            path.touch()
            with pytest.raises(ValueError, match="empty"):
                TradingEconomicTracker.load_from_file(str(path))
class TestProperties:
    """Values derived from the tracker's running totals."""

    def test_total_costs(self):
        """total_costs is token_costs plus trade_costs."""
        tracker = TradingEconomicTracker(agent_id="test")
        assert tracker.total_costs == 0.0  # nothing spent yet

        tracker.token_costs, tracker.trade_costs = 10.0, 5.0
        assert tracker.total_costs == 15.0

    def test_net_profit(self):
        """net_profit is realized_pnl minus token and trade costs."""
        tracker = TradingEconomicTracker(agent_id="test")
        assert tracker.net_profit == 0.0  # fresh tracker: no PnL, no costs

        tracker.realized_pnl, tracker.token_costs, tracker.trade_costs = 100.0, 10.0, 5.0
        # 100 realized - 10 token costs - 5 trade costs
        assert tracker.net_profit == 85.0
class TestEdgeCases:
    """Test edge cases and error conditions."""

    def test_balance_never_negative(self):
        """Test that balance is clamped at zero rather than going negative."""
        tracker = TradingEconomicTracker(agent_id="test", initial_capital=100.0)
        # A loss far larger than the whole account.
        tracker.calculate_trade_cost(
            trade_value=100000.0, is_win=False, win_amount=0.0, loss_amount=50000.0
        )
        assert tracker.balance == 0.0
        # A zero balance is below the 30% threshold, so the agent is bankrupt.
        assert tracker.get_survival_status() == SurvivalStatus.BANKRUPT

    def test_zero_value_trade(self):
        """Test that a trade of zero value incurs no fee and no PnL."""
        tracker = TradingEconomicTracker(agent_id="test")
        result = tracker.calculate_trade_cost(
            trade_value=0.0, is_win=True, win_amount=0.0, loss_amount=0.0
        )
        assert result.fee == 0.0
        assert result.pnl == 0.0

    def test_repr(self):
        """Test that repr includes the agent id, balance and survival status."""
        tracker = TradingEconomicTracker(agent_id="test-agent", initial_capital=10000.0)
        repr_str = repr(tracker)
        assert "test-agent" in repr_str
        # "10000.0" is a substring of both "$10000.00" and "10000.0", so this
        # single check accepts either balance rendering.
        assert "10000.0" in repr_str
        # At exactly initial capital, status is struggling (>= 80%, < 110%).
        assert "struggling" in repr_str
class TestPydanticModels:
    """Construction of the result/history models and the status enum labels."""

    def test_balance_history_entry_validation(self):
        """A BalanceHistoryEntry built from valid fields exposes them unchanged."""
        record = BalanceHistoryEntry(
            timestamp="2024-01-01T00:00:00",
            balance=100.0,
            change=-10.0,
            reason="Test",
        )
        assert record.balance == 100.0
        assert record.change == -10.0

    def test_trade_cost_result_validation(self):
        """A TradeCostResult keeps its fee and enum status as given."""
        summary = TradeCostResult(
            fee=10.0, pnl=100.0, balance=1000.0, status=SurvivalStatus.STABLE
        )
        assert summary.fee == 10.0
        assert summary.status == SurvivalStatus.STABLE

    def test_survival_status_enum(self):
        """Each SurvivalStatus member compares equal to its emoji-prefixed label."""
        expected_labels = [
            (SurvivalStatus.THRIVING, "🚀 thriving"),
            (SurvivalStatus.STABLE, "💪 stable"),
            (SurvivalStatus.STRUGGLING, "⚠️ struggling"),
            (SurvivalStatus.CRITICAL, "🔴 critical"),
            (SurvivalStatus.BANKRUPT, "💀 bankrupt"),
        ]
        for member, label in expected_labels:
            assert member == label, label