stock/tests/test_backtest_basic.py
ZhangPeng 9aecdd036c Initial commit: OpenClaw Trading - AI多智能体量化交易系统
- 添加项目核心代码和配置
- 添加前端界面 (Next.js)
- 添加单元测试
- 更新 .gitignore 排除缓存和依赖
2026-02-27 03:47:40 +08:00

478 lines
15 KiB
Python

"""Basic tests for the backtest engine.

This module contains tests for BacktestEngine initialization, reset and
result generation; TradeRecord, BacktestResult, Position and BacktestEvent
creation; and the slippage and commission models.
"""
from datetime import datetime, timedelta
import pytest
from openclaw.backtest.engine import (
BacktestEngine,
BacktestEvent,
BacktestResult,
CommissionModel,
EventType,
FixedCommissionModel,
FixedSlippageModel,
PercentageCommissionModel,
PercentageSlippageModel,
Position,
TradeRecord,
VolatilitySlippageModel,
)
class TestBacktestEngine:
    """Tests for BacktestEngine construction, reset and result generation."""

    def test_engine_initialization(self):
        """Test BacktestEngine initialization with default parameters."""
        start_date = datetime(2024, 1, 1)
        end_date = datetime(2024, 12, 31)
        initial_capital = 100000.0

        engine = BacktestEngine(
            initial_capital=initial_capital,
            start_date=start_date,
            end_date=end_date,
        )

        assert engine.initial_capital == initial_capital
        assert engine.current_equity == initial_capital
        assert engine.start_date == start_date
        assert engine.end_date == end_date
        assert engine.positions == {}
        assert engine.trades == []
        # A fresh engine's equity curve holds just the starting capital.
        assert engine.equity_curve == [initial_capital]

    def test_engine_initialization_with_custom_models(self):
        """Test BacktestEngine initialization with custom slippage and commission models."""
        start_date = datetime(2024, 1, 1)
        end_date = datetime(2024, 12, 31)
        slippage_model = FixedSlippageModel(fixed_amount=0.02)
        commission_model = FixedCommissionModel(fixed_amount=10.0)

        engine = BacktestEngine(
            initial_capital=50000.0,
            start_date=start_date,
            end_date=end_date,
            slippage_model=slippage_model,
            commission_model=commission_model,
        )

        assert isinstance(engine.slippage_model, FixedSlippageModel)
        assert isinstance(engine.commission_model, FixedCommissionModel)
        assert engine.slippage_model.fixed_amount == 0.02
        assert engine.commission_model.fixed_amount == 10.0

    def test_engine_reset(self):
        """Test that reset() restores every piece of mutable engine state.

        Every state container is dirtied before the reset; otherwise the
        ``positions`` and ``trades`` assertions would pass trivially on a
        freshly constructed engine.
        """
        engine = BacktestEngine(
            initial_capital=100000.0,
            start_date=datetime(2024, 1, 1),
            end_date=datetime(2024, 12, 31),
        )

        # Dirty all of the state that reset() is expected to clear.
        engine.current_equity = 50000.0
        engine.equity_curve.append(105000.0)
        engine.positions["AAPL"] = Position(
            symbol="AAPL",
            quantity=100.0,
            entry_price=150.0,
            entry_time=datetime(2024, 1, 1, 10, 0),
            side="long",
        )
        engine.trades.append(
            TradeRecord(
                symbol="AAPL",
                entry_time=datetime(2024, 1, 1),
                exit_time=datetime(2024, 1, 2),
                entry_price=150.0,
                exit_price=155.0,
                quantity=100.0,
                side="long",
                pnl=500.0,
                commission=5.0,
                slippage=2.0,
            )
        )
        # Sanity check: the containers really are non-empty before reset().
        assert engine.positions
        assert engine.trades

        engine.reset()

        # Starting capital is configuration, not run state; it must survive.
        assert engine.initial_capital == 100000.0
        assert engine.current_equity == engine.initial_capital
        assert engine.positions == {}
        assert engine.trades == []
        assert engine.equity_curve == [engine.initial_capital]

    def test_get_results_without_data(self):
        """Test getting results without running backtest."""
        engine = BacktestEngine(
            initial_capital=100000.0,
            start_date=datetime(2024, 1, 1),
            end_date=datetime(2024, 12, 31),
        )

        # The engine starts with a seeded equity curve, so results can be
        # produced before any data is processed; they simply show no change.
        result = engine.get_results()

        assert result.total_return == 0.0
        assert result.total_trades == 0
class TestTradeRecord:
    """Tests for the TradeRecord dataclass and its derived cost/PnL properties."""

    @staticmethod
    def _make_trade(**overrides):
        """Build a TradeRecord for a 100-share AAPL long round trip.

        Any keyword argument replaces the corresponding default field.
        """
        fields = dict(
            symbol="AAPL",
            entry_time=datetime(2024, 1, 1),
            exit_time=datetime(2024, 1, 2),
            entry_price=150.0,
            exit_price=155.0,
            quantity=100.0,
            side="long",
            pnl=500.0,
            commission=5.0,
            slippage=2.0,
        )
        fields.update(overrides)
        return TradeRecord(**fields)

    def test_trade_record_creation(self):
        """Test that every TradeRecord field is stored as given."""
        opened = datetime(2024, 1, 1, 10, 0)
        closed = datetime(2024, 1, 2, 10, 0)
        trade = self._make_trade(entry_time=opened, exit_time=closed)

        expected_fields = {
            "symbol": "AAPL",
            "entry_time": opened,
            "exit_time": closed,
            "entry_price": 150.0,
            "exit_price": 155.0,
            "quantity": 100.0,
            "side": "long",
            "pnl": 500.0,
            "commission": 5.0,
            "slippage": 2.0,
        }
        for attr, value in expected_fields.items():
            assert getattr(trade, attr) == value, attr

    def test_trade_record_total_cost(self):
        """Test TradeRecord.total_cost sums commission and slippage."""
        trade = self._make_trade()
        assert trade.total_cost == 7.0  # 5.0 commission + 2.0 slippage

    def test_trade_record_net_pnl(self):
        """Test TradeRecord.net_pnl subtracts total cost from gross pnl."""
        trade = self._make_trade()
        assert trade.net_pnl == 493.0  # 500.0 pnl - 7.0 total cost

    def test_trade_record_short_position(self):
        """Test a TradeRecord describing a short round trip."""
        trade = self._make_trade(
            symbol="TSLA",
            entry_price=200.0,
            exit_price=190.0,
            quantity=50.0,
            side="short",
            slippage=1.0,
        )
        assert trade.side == "short"
        assert trade.net_pnl == 494.0  # 500.0 - (5.0 + 1.0)
class TestBacktestResult:
    """Tests for the BacktestResult container."""

    @staticmethod
    def _result_kwargs(**overrides):
        """Return BacktestResult constructor kwargs for a representative run.

        Any keyword argument replaces the corresponding default entry.
        """
        kwargs = dict(
            start_date=datetime(2024, 1, 1),
            end_date=datetime(2024, 12, 31),
            initial_capital=100000.0,
            final_equity=110000.0,
            total_return=10.0,
            total_trades=50,
            winning_trades=30,
            losing_trades=20,
            win_rate=60.0,
            avg_win=500.0,
            avg_loss=-200.0,
            profit_factor=2.5,
            sharpe_ratio=1.2,
            max_drawdown=5.0,
            max_drawdown_duration=10,
            volatility=15.0,
            calmar_ratio=2.0,
        )
        kwargs.update(overrides)
        return kwargs

    def test_backtest_result_creation(self):
        """Test that every constructor argument is stored unchanged."""
        kwargs = self._result_kwargs()
        equity_curve = [100000.0, 101000.0, 110000.0]

        result = BacktestResult(**kwargs, equity_curve=equity_curve)

        # Check every scalar field, not just a subset of them.
        for name, value in kwargs.items():
            assert getattr(result, name) == value, name
        # list() keeps the comparison meaningful whether the curve is stored
        # as a plain list or as another sequence type.
        assert list(result.equity_curve) == equity_curve

    def test_backtest_result_to_dict(self):
        """Test BacktestResult to_dict method."""
        result = BacktestResult(**self._result_kwargs())

        result_dict = result.to_dict()

        assert isinstance(result_dict, dict)
        assert result_dict["initial_capital"] == 100000.0
        # Percentage metrics are rendered as strings with two decimals.
        assert result_dict["total_return"] == "10.00%"
        assert result_dict["win_rate"] == "60.00%"
        assert "start_date" in result_dict
        assert "end_date" in result_dict

    def test_backtest_result_with_trades(self):
        """Test BacktestResult with trade records."""
        trade1 = TradeRecord(
            symbol="AAPL",
            entry_time=datetime(2024, 1, 1),
            exit_time=datetime(2024, 1, 2),
            entry_price=150.0,
            exit_price=155.0,
            quantity=100.0,
            side="long",
            pnl=500.0,
            commission=5.0,
            slippage=2.0,
        )

        # A single winning trade: no losses, so ratio metrics are infinite.
        result = BacktestResult(
            **self._result_kwargs(
                total_trades=1,
                winning_trades=1,
                losing_trades=0,
                win_rate=100.0,
                avg_win=493.0,
                avg_loss=0.0,
                profit_factor=float("inf"),
                sharpe_ratio=1.5,
                max_drawdown=0.0,
                max_drawdown_duration=0,
                volatility=10.0,
                calmar_ratio=float("inf"),
            ),
            trades=[trade1],
        )

        assert len(result.trades) == 1
        assert result.trades[0].symbol == "AAPL"
class TestPosition:
    """Tests for the Position dataclass and its valuation methods."""

    @staticmethod
    def _make_position(side, entry_time=datetime(2024, 1, 1)):
        """Return a 100-share AAPL position entered at 150.0 on ``side``."""
        return Position(
            symbol="AAPL",
            quantity=100.0,
            entry_price=150.0,
            entry_time=entry_time,
            side=side,
        )

    def test_position_creation(self):
        """Test creating a Position."""
        opened = datetime(2024, 1, 1, 10, 0)
        position = self._make_position("long", entry_time=opened)

        assert position.symbol == "AAPL"
        assert position.quantity == 100.0
        assert position.entry_price == 150.0
        assert position.entry_time == opened
        assert position.side == "long"

    def test_position_market_value(self):
        """Test Position.market_value for a long position."""
        position = self._make_position("long")

        # market_value() is a regular method that takes the current price,
        # not an attribute; 100 shares at 160.0 are worth 16,000.
        assert position.market_value(160.0) == 16000.0

    def test_position_unrealized_pnl_long(self):
        """Test Position unrealized_pnl for long position."""
        position = self._make_position("long")

        # A long position gains when the price rises above the entry price.
        assert position.unrealized_pnl(160.0) == 1000.0
        assert position.unrealized_pnl(140.0) == -1000.0

    def test_position_unrealized_pnl_short(self):
        """Test Position unrealized_pnl for short position."""
        position = self._make_position("short")

        # A short position gains when the price falls below the entry price.
        assert position.unrealized_pnl(140.0) == 1000.0
        assert position.unrealized_pnl(160.0) == -1000.0
class TestBacktestEvent:
    """Tests for BacktestEvent and the EventType enumeration."""

    def test_event_creation(self):
        """Test that a BacktestEvent stores its type, timestamp and payload."""
        when = datetime(2024, 1, 1, 10, 0)
        payload = {"price": 150.0, "volume": 1000}

        event = BacktestEvent(
            event_type=EventType.BAR_OPEN,
            timestamp=when,
            data=payload,
        )

        assert event.event_type == EventType.BAR_OPEN
        assert event.timestamp == when
        assert event.data == payload

    def test_event_types(self):
        """Test that each expected event type exists under its own name."""
        expected_names = (
            "BAR_OPEN",
            "BAR_CLOSE",
            "SIGNAL",
            "ORDER",
            "TRADE",
            "END_OF_DATA",
        )
        for member_name in expected_names:
            assert getattr(EventType, member_name).name == member_name
class TestSlippageModels:
    """Tests for slippage models.

    Slippage values are floats computed by the models, so they are compared
    with ``pytest.approx`` rather than exact equality.
    """

    @staticmethod
    def _slippage(model):
        """Return ``model``'s slippage for a 100-share buy at 100.0.

        The trade value is 10,000; volatility is 0.02 and volume is 100,000.
        """
        return model.calculate_slippage(
            price=100.0,
            quantity=100.0,
            side="buy",
            volatility=0.02,
            volume=100000.0,
        )

    def test_fixed_slippage_model(self):
        """Test FixedSlippageModel calculation."""
        model = FixedSlippageModel(fixed_amount=0.01)
        # 0.01 * 100 (price and quantity are both 100 in this scenario)
        assert self._slippage(model) == pytest.approx(1.0)

    def test_percentage_slippage_model(self):
        """Test PercentageSlippageModel calculation."""
        model = PercentageSlippageModel(percentage=0.001)  # 0.1%
        expected = 100.0 * 100.0 * 0.001  # 10.0
        assert self._slippage(model) == pytest.approx(expected)

    def test_volatility_slippage_model(self):
        """Test VolatilitySlippageModel calculation."""
        model = VolatilitySlippageModel(
            base_percentage=0.0005,
            volatility_multiplier=1.0,
        )
        # trade_value * base * (1 + vol * multiplier)
        expected = 10000.0 * 0.0005 * (1 + 0.02 * 1.0)
        assert self._slippage(model) == pytest.approx(expected, abs=1e-4)
class TestCommissionModels:
    """Tests for commission models."""

    def test_fixed_commission_model(self):
        """Test FixedCommissionModel charges the same amount for any trade size."""
        model = FixedCommissionModel(fixed_amount=5.0)

        assert model.calculate_commission(price=100.0, quantity=100.0) == 5.0
        assert model.calculate_commission(price=10.0, quantity=1.0) == 5.0

    def test_percentage_commission_model(self):
        """Test PercentageCommissionModel where the rate exactly meets the minimum."""
        model = PercentageCommissionModel(
            percentage=0.001,  # 0.1%
            min_commission=1.0,
        )
        commission = model.calculate_commission(price=100.0, quantity=10.0)
        # trade_value = 1000, 0.1% = 1.0, which coincides with the minimum,
        # so this case alone cannot tell which branch produced the result.
        expected = max(100.0 * 10.0 * 0.001, 1.0)
        assert commission == pytest.approx(expected)

    def test_percentage_commission_above_minimum(self):
        """Test the percentage rate is charged when it exceeds the minimum."""
        model = PercentageCommissionModel(
            percentage=0.001,  # 0.1%
            min_commission=1.0,
        )
        commission = model.calculate_commission(price=100.0, quantity=100.0)
        # trade_value = 10000, 0.1% = 10.0, above the 1.0 minimum
        assert commission == pytest.approx(10.0)

    def test_percentage_commission_below_minimum(self):
        """Test the minimum is charged when the percentage fee falls below it."""
        model = PercentageCommissionModel(
            percentage=0.001,  # 0.1%
            min_commission=1.0,
        )
        commission = model.calculate_commission(price=10.0, quantity=10.0)
        # trade_value = 100, 0.1% = 0.1, below the 1.0 minimum
        assert commission == pytest.approx(1.0)

    def test_percentage_commission_with_max(self):
        """Test PercentageCommissionModel with maximum."""
        model = PercentageCommissionModel(
            percentage=0.01,  # 1%
            min_commission=1.0,
            max_commission=50.0,
        )
        commission = model.calculate_commission(price=1000.0, quantity=100.0)
        # trade_value = 100000, 1% = 1000, but max is 50
        assert commission == 50.0