stock/tests/unit/test_data_source.py
ZhangPeng 9aecdd036c Initial commit: OpenClaw Trading - AI多智能体量化交易系统
- 添加项目核心代码和配置
- 添加前端界面 (Next.js)
- 添加单元测试
- 更新 .gitignore 排除缓存和依赖
2026-02-27 03:47:40 +08:00

483 lines
17 KiB
Python

"""Unit tests for data source interface and implementations."""
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
import pandas as pd
import pytest
from openclaw.data import (
DataNotAvailableError,
DataSource,
DataSourceError,
Interval,
OHLCVData,
RealtimeQuote,
YahooFinanceDataSource,
)
class TestDataSourceInterface:
    """Exercise the abstract DataSource contract and its value types."""

    @staticmethod
    def _build_bar(stamp: datetime) -> OHLCVData:
        """Return an OHLCVData bar with fixed prices stamped at ``stamp``."""
        return OHLCVData(
            timestamp=stamp,
            open=100.0,
            high=105.0,
            low=99.0,
            close=103.0,
            volume=1000000.0,
        )

    @staticmethod
    def _build_quote(stamp: datetime) -> RealtimeQuote:
        """Return an AAPL RealtimeQuote with fixed fields stamped at ``stamp``."""
        return RealtimeQuote(
            symbol="AAPL",
            price=150.0,
            bid=149.95,
            ask=150.05,
            bid_size=100,
            ask_size=200,
            volume=50000000.0,
            timestamp=stamp,
        )

    def test_data_source_error_inheritance(self) -> None:
        """DataSourceError must behave as an ordinary Exception."""
        assert issubclass(DataSourceError, Exception)

    def test_data_not_available_error_inheritance(self) -> None:
        """DataNotAvailableError must be catchable as a DataSourceError."""
        assert issubclass(DataNotAvailableError, DataSourceError)

    def test_interval_enum_values(self) -> None:
        """Each Interval member carries its expected string code."""
        # A list of pairs (not a dict) so every member is checked independently.
        expected_codes = [
            (Interval.MINUTE_1, "1m"),
            (Interval.MINUTE_5, "5m"),
            (Interval.MINUTE_15, "15m"),
            (Interval.MINUTE_30, "30m"),
            (Interval.HOUR_1, "1h"),
            (Interval.HOUR_4, "4h"),
            (Interval.DAY_1, "1d"),
            (Interval.WEEK_1, "1wk"),
            (Interval.MONTH_1, "1mo"),
        ]
        for member, code in expected_codes:
            assert member.value == code, f"{member!r} should map to {code!r}"

    def test_ohlcv_data_creation(self) -> None:
        """OHLCVData keeps every field exactly as constructed."""
        stamp = datetime.now()
        bar = self._build_bar(stamp)
        assert bar.timestamp == stamp
        assert (bar.open, bar.high, bar.low, bar.close, bar.volume) == (
            100.0,
            105.0,
            99.0,
            103.0,
            1000000.0,
        )

    def test_ohlcv_data_is_immutable(self) -> None:
        """Assigning to an OHLCVData field is rejected (frozen dataclass)."""
        bar = self._build_bar(datetime.now())
        with pytest.raises(AttributeError):
            bar.close = 110.0  # type: ignore[misc]

    def test_realtime_quote_creation(self) -> None:
        """RealtimeQuote keeps every field exactly as constructed."""
        stamp = datetime.now()
        quote = self._build_quote(stamp)
        assert quote.symbol == "AAPL"
        assert (quote.price, quote.bid, quote.ask) == (150.0, 149.95, 150.05)
        assert (quote.bid_size, quote.ask_size) == (100, 200)
        assert quote.volume == 50000000.0
        assert quote.timestamp == stamp

    def test_realtime_quote_is_immutable(self) -> None:
        """Assigning to a RealtimeQuote field is rejected (frozen dataclass)."""
        quote = self._build_quote(datetime.now())
        with pytest.raises(AttributeError):
            quote.price = 160.0  # type: ignore[misc]
class TestYahooFinanceDataSource:
    """Synchronous tests for YahooFinanceDataSource helpers and cache handling."""

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Create a YahooFinanceDataSource with a 60-second cache TTL."""
        return YahooFinanceDataSource(cache_ttl=60)

    def test_initialization(self) -> None:
        """Test YahooFinanceDataSource initialization."""
        source = YahooFinanceDataSource(cache_ttl=120)
        assert source.name == "yahoo_finance"
        assert source._available is True  # Initial availability state
        # The constructor argument must actually be stored, not merely accepted.
        assert source._cache_ttl == 120

    def test_default_cache_ttl(self) -> None:
        """Test default cache TTL is 60 seconds."""
        source = YahooFinanceDataSource()
        assert source._cache_ttl == 60

    def test_custom_cache_ttl(self) -> None:
        """Test custom cache TTL can be set."""
        source = YahooFinanceDataSource(cache_ttl=300)
        assert source._cache_ttl == 300

    def test_get_cache_key(self, data_source: YahooFinanceDataSource) -> None:
        """Cache key embeds symbol, interval code and ISO start/end dates."""
        start = datetime(2024, 1, 1)
        end = datetime(2024, 1, 31)
        key = data_source._get_cache_key("AAPL", Interval.DAY_1, start, end)
        assert "AAPL" in key
        assert "1d" in key
        assert "2024-01-01" in key
        assert "2024-01-31" in key

    def test_get_cache_key_with_none(self, data_source: YahooFinanceDataSource) -> None:
        """Cache key generation tolerates missing (None) start/end dates."""
        key = data_source._get_cache_key("MSFT", Interval.HOUR_1, None, None)
        assert "MSFT" in key
        assert "1h" in key
        assert "None" in key

    def test_is_cache_valid(self, data_source: YahooFinanceDataSource) -> None:
        """Entries younger than the TTL are valid; older ones are not."""
        now = datetime.now()
        # Recent cache entry should be valid
        assert data_source._is_cache_valid(now) is True
        # 120s is twice the fixture's 60s TTL, so this entry must be stale.
        old_time = now - timedelta(seconds=120)
        assert data_source._is_cache_valid(old_time) is False

    def test_get_yfinance_interval(self, data_source: YahooFinanceDataSource) -> None:
        """Test interval mapping to yfinance format."""
        assert data_source._get_yfinance_interval(Interval.MINUTE_1) == "1m"
        assert data_source._get_yfinance_interval(Interval.MINUTE_5) == "5m"
        assert data_source._get_yfinance_interval(Interval.DAY_1) == "1d"
        assert data_source._get_yfinance_interval(Interval.WEEK_1) == "1wk"
        assert data_source._get_yfinance_interval(Interval.MONTH_1) == "1mo"

    def test_get_yfinance_interval_unsupported(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test unsupported interval raises error."""
        # A MagicMock stands in for an interval that is absent from the mapping.
        fake_interval = MagicMock()
        fake_interval.value = "fake"
        with pytest.raises(DataSourceError, match="Unsupported interval"):
            data_source._get_yfinance_interval(fake_interval)  # type: ignore[arg-type]

    def test_get_period_for_interval(self, data_source: YahooFinanceDataSource) -> None:
        """Test period selection for different intervals."""
        assert data_source._get_period_for_interval(Interval.MINUTE_1) == "5d"
        assert data_source._get_period_for_interval(Interval.MINUTE_5) == "1mo"
        assert data_source._get_period_for_interval(Interval.HOUR_1) == "3mo"
        assert data_source._get_period_for_interval(Interval.DAY_1) == "1y"
        assert data_source._get_period_for_interval(Interval.WEEK_1) == "5y"
        assert data_source._get_period_for_interval(Interval.MONTH_1) == "max"
        # HOUR_4 is a valid Interval member that is expected to have no
        # dedicated period and so fall through to the default period.
        assert data_source._get_period_for_interval(Interval.HOUR_4) == "6mo"

    def test_clear_cache(self, data_source: YahooFinanceDataSource) -> None:
        """Test cache clearing."""
        # Seed the cache directly with two placeholder entries.
        data_source._cache["key1"] = (pd.DataFrame(), datetime.now())
        data_source._cache["key2"] = (pd.DataFrame(), datetime.now())
        assert len(data_source._cache) == 2
        data_source.clear_cache()
        assert len(data_source._cache) == 0

    def test_get_cache_stats(self, data_source: YahooFinanceDataSource) -> None:
        """Cache stats report size, the fixture's TTL and the stored keys."""
        data_source._cache["key1"] = (pd.DataFrame(), datetime.now())
        stats = data_source.get_cache_stats()
        assert stats["size"] == 1
        assert stats["ttl_seconds"] == 60
        assert "key1" in stats["keys"]

    def test_set_availability(self, data_source: YahooFinanceDataSource) -> None:
        """Test availability status can be set."""
        data_source.set_availability(False)
        assert data_source._available is False
        data_source.set_availability(True)
        assert data_source._available is True

    def test_clear_expired_cache(self, data_source: YahooFinanceDataSource) -> None:
        """Test expired cache entries are cleared."""
        now = datetime.now()
        # Older than the fixture's 60s TTL, so it must be evicted.
        expired_time = now - timedelta(seconds=120)
        data_source._cache["fresh"] = (pd.DataFrame(), now)
        data_source._cache["expired"] = (pd.DataFrame(), expired_time)
        data_source._clear_expired_cache()
        assert "fresh" in data_source._cache
        assert "expired" not in data_source._cache
class TestYahooFinanceDataSourceAsync:
    """Async tests for YahooFinanceDataSource."""

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Create a YahooFinanceDataSource instance for testing."""
        return YahooFinanceDataSource(cache_ttl=60)

    @staticmethod
    def _make_raw_frame(
        timestamps: "list[datetime]", index_name: str = "Date"
    ) -> pd.DataFrame:
        """Build the raw frame that ``_fetch_yfinance_data`` is mocked to return.

        One row per timestamp, with constant capitalised Open/High/Low/Close/
        Volume columns, indexed by ``index_name``. Daily tests use "Date";
        the intraday test uses "Datetime".
        """
        rows = len(timestamps)
        frame = pd.DataFrame({
            index_name: list(timestamps),
            "Open": [100.0] * rows,
            "High": [105.0] * rows,
            "Low": [99.0] * rows,
            "Close": [103.0] * rows,
            "Volume": [1000000.0] * rows,
        })
        return frame.set_index(index_name)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_returns_dataframe(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """fetch_ohlcv returns a DataFrame with normalised lowercase columns."""
        mock_df = self._make_raw_frame([datetime(2024, 1, 1)])
        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=mock_df
        ):
            df = await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
        assert isinstance(df, pd.DataFrame)
        for column in ("timestamp", "open", "high", "low", "close", "volume"):
            assert column in df.columns, f"missing column {column!r}"

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_empty_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test fetch_ohlcv raises error when data is empty."""
        empty_df = pd.DataFrame()
        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=empty_df
        ):
            with pytest.raises(DataNotAvailableError, match="No data available"):
                await data_source.fetch_ohlcv("INVALID", Interval.DAY_1)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_caches_result(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A repeated identical request is served from the cache."""
        mock_df = self._make_raw_frame([datetime(2024, 1, 1)])
        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=mock_df
        ) as mock_fetch:
            # First call should fetch
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
            assert mock_fetch.call_count == 1
            # Second call should use cache
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
            assert mock_fetch.call_count == 1  # No additional fetch

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_limit(self, data_source: YahooFinanceDataSource) -> None:
        """Test fetch_ohlcv respects limit parameter."""
        dates = [datetime(2024, 1, day) for day in range(1, 11)]
        mock_df = self._make_raw_frame(dates)
        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=mock_df
        ):
            df = await data_source.fetch_ohlcv("AAPL", Interval.DAY_1, limit=5)
        assert len(df) == 5

    @pytest.mark.asyncio
    async def test_fetch_realtime_returns_quote(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """fetch_realtime maps ticker-info fields onto a RealtimeQuote."""
        mock_info = {
            "currentPrice": 150.0,
            "bid": 149.95,
            "ask": 150.05,
            "bidSize": 100,
            "askSize": 200,
            "volume": 50000000.0,
        }
        with patch.object(
            data_source, "_fetch_ticker_info", return_value=mock_info
        ):
            quote = await data_source.fetch_realtime("AAPL")
        assert isinstance(quote, RealtimeQuote)
        assert quote.symbol == "AAPL"
        assert quote.price == 150.0
        assert quote.bid == 149.95
        assert quote.ask == 150.05
        assert quote.bid_size == 100
        assert quote.ask_size == 200
        assert quote.volume == 50000000.0

    @pytest.mark.asyncio
    async def test_fetch_realtime_empty_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test fetch_realtime raises error when data is empty."""
        with patch.object(data_source, "_fetch_ticker_info", return_value={}):
            with pytest.raises(
                DataNotAvailableError, match="No real-time data available"
            ):
                await data_source.fetch_realtime("INVALID")

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_with_datetime_index(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test fetch_ohlcv handles Datetime index (intraday data)."""
        mock_df = self._make_raw_frame(
            [datetime(2024, 1, 1, 9, 30)], index_name="Datetime"
        )
        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=mock_df
        ):
            df = await data_source.fetch_ohlcv("AAPL", Interval.MINUTE_5)
        assert "timestamp" in df.columns
        assert "open" in df.columns
class TestYahooFinanceErrorHandling:
    """Verify YahooFinanceDataSource surfaces failures as DataSourceError."""

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Provide a fresh source with a 60-second cache TTL."""
        return YahooFinanceDataSource(cache_ttl=60)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_error_wrapped(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A generic exception from the OHLCV fetch is wrapped in DataSourceError."""
        failure = Exception("Network error")
        with patch.object(
            data_source, "_fetch_yfinance_data", side_effect=failure
        ), pytest.raises(DataSourceError, match="Failed to fetch data"):
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)

    @pytest.mark.asyncio
    async def test_fetch_realtime_error_wrapped(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A generic exception from the ticker-info fetch is wrapped in DataSourceError."""
        failure = Exception("Network error")
        with patch.object(
            data_source, "_fetch_ticker_info", side_effect=failure
        ), pytest.raises(DataSourceError, match="Failed to fetch real-time data"):
            await data_source.fetch_realtime("AAPL")

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_missing_columns_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A raw frame lacking High/Low/Close/Volume is rejected."""
        # Only the "Open" price column is present; the rest are deliberately absent.
        partial_frame = pd.DataFrame(
            {"Date": [datetime(2024, 1, 1)], "Open": [100.0]}
        ).set_index("Date")
        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=partial_frame
        ), pytest.raises(DataSourceError, match="Missing required column"):
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
class TestDataSourceExports:
    """Test that all expected exports are available."""

    def test_all_exports_available(self) -> None:
        """Every public data-source name is listed in ``openclaw.data.__all__``."""
        from openclaw.data import __all__ as exports

        expected = {
            "DataSource",
            "DataSourceError",
            "DataNotAvailableError",
            "Interval",
            "OHLCVData",
            "RealtimeQuote",
            "YahooFinanceDataSource",
        }
        # Collect every missing name so a single failure reports all of them,
        # instead of stopping at the first absent export.
        missing = expected.difference(exports)
        assert not missing, f"not in __all__: {sorted(missing)}"