483 lines
17 KiB
Python
483 lines
17 KiB
Python
"""Unit tests for data source interface and implementations."""
|
|
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pandas as pd
|
|
import pytest
|
|
|
|
from openclaw.data import (
|
|
DataNotAvailableError,
|
|
DataSource,
|
|
DataSourceError,
|
|
Interval,
|
|
OHLCVData,
|
|
RealtimeQuote,
|
|
YahooFinanceDataSource,
|
|
)
|
|
|
|
|
|
class TestDataSourceInterface:
    """Checks for the abstract data-source contract and its value types."""

    @staticmethod
    def _make_bar(timestamp: datetime) -> OHLCVData:
        """Build one OHLCV bar with fixed sample prices at *timestamp*."""
        return OHLCVData(
            timestamp=timestamp,
            open=100.0,
            high=105.0,
            low=99.0,
            close=103.0,
            volume=1000000.0,
        )

    @staticmethod
    def _make_quote(timestamp: datetime) -> RealtimeQuote:
        """Build one AAPL quote with fixed sample bid/ask data at *timestamp*."""
        return RealtimeQuote(
            symbol="AAPL",
            price=150.0,
            bid=149.95,
            ask=150.05,
            bid_size=100,
            ask_size=200,
            volume=50000000.0,
            timestamp=timestamp,
        )

    def test_data_source_error_inheritance(self) -> None:
        """DataSourceError must be usable as an ordinary Exception."""
        assert issubclass(DataSourceError, Exception)

    def test_data_not_available_error_inheritance(self) -> None:
        """DataNotAvailableError must be catchable as a DataSourceError."""
        assert issubclass(DataNotAvailableError, DataSourceError)

    def test_interval_enum_values(self) -> None:
        """Every Interval member carries its expected string code."""
        # A list of pairs (not a dict) so that aliased members could not
        # silently collapse into a single check.
        expected_codes = [
            (Interval.MINUTE_1, "1m"),
            (Interval.MINUTE_5, "5m"),
            (Interval.MINUTE_15, "15m"),
            (Interval.MINUTE_30, "30m"),
            (Interval.HOUR_1, "1h"),
            (Interval.HOUR_4, "4h"),
            (Interval.DAY_1, "1d"),
            (Interval.WEEK_1, "1wk"),
            (Interval.MONTH_1, "1mo"),
        ]
        for member, code in expected_codes:
            assert member.value == code

    def test_ohlcv_data_creation(self) -> None:
        """OHLCVData keeps every constructor argument as given."""
        stamp = datetime.now()
        bar = self._make_bar(stamp)

        assert bar.timestamp == stamp
        assert (bar.open, bar.high, bar.low, bar.close) == (100.0, 105.0, 99.0, 103.0)
        assert bar.volume == 1000000.0

    def test_ohlcv_data_is_immutable(self) -> None:
        """Assigning to an OHLCVData field must fail (frozen/immutable)."""
        bar = self._make_bar(datetime.now())

        with pytest.raises(AttributeError):
            bar.close = 110.0  # type: ignore[misc]

    def test_realtime_quote_creation(self) -> None:
        """RealtimeQuote keeps every constructor argument as given."""
        stamp = datetime.now()
        quote = self._make_quote(stamp)

        assert quote.symbol == "AAPL"
        assert (quote.price, quote.bid, quote.ask) == (150.0, 149.95, 150.05)
        assert (quote.bid_size, quote.ask_size) == (100, 200)
        assert quote.volume == 50000000.0
        assert quote.timestamp == stamp

    def test_realtime_quote_is_immutable(self) -> None:
        """Assigning to a RealtimeQuote field must fail (frozen/immutable)."""
        quote = self._make_quote(datetime.now())

        with pytest.raises(AttributeError):
            quote.price = 160.0  # type: ignore[misc]
|
|
|
|
|
|
class TestYahooFinanceDataSource:
    """Tests for YahooFinanceDataSource implementation (synchronous helpers)."""

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Create a YahooFinanceDataSource instance with a 60-second cache TTL."""
        return YahooFinanceDataSource(cache_ttl=60)

    def test_initialization(self) -> None:
        """Test YahooFinanceDataSource initialization."""
        source = YahooFinanceDataSource(cache_ttl=120)

        assert source.name == "yahoo_finance"
        assert source._available is True  # Initial availability state
        # The explicit TTL passed to the constructor must be retained.
        assert source._cache_ttl == 120

    def test_default_cache_ttl(self) -> None:
        """Test default cache TTL is 60 seconds."""
        source = YahooFinanceDataSource()
        assert source._cache_ttl == 60

    def test_custom_cache_ttl(self) -> None:
        """Test custom cache TTL can be set."""
        source = YahooFinanceDataSource(cache_ttl=300)
        assert source._cache_ttl == 300

    def test_get_cache_key(self, data_source: YahooFinanceDataSource) -> None:
        """Test cache key embeds symbol, interval code and both dates."""
        start = datetime(2024, 1, 1)
        end = datetime(2024, 1, 31)

        key = data_source._get_cache_key("AAPL", Interval.DAY_1, start, end)

        assert "AAPL" in key
        assert "1d" in key
        assert "2024-01-01" in key
        assert "2024-01-31" in key

    def test_get_cache_key_with_none(self, data_source: YahooFinanceDataSource) -> None:
        """Test cache key generation with None dates."""
        key = data_source._get_cache_key("MSFT", Interval.HOUR_1, None, None)

        assert "MSFT" in key
        assert "1h" in key
        assert "None" in key

    def test_is_cache_valid(self, data_source: YahooFinanceDataSource) -> None:
        """Test cache validity check against the fixture's 60-second TTL."""
        now = datetime.now()

        # Recent cache entry should be valid
        assert data_source._is_cache_valid(now) is True

        # An entry twice as old as the 60 s TTL should be invalid
        old_time = now - timedelta(seconds=120)
        assert data_source._is_cache_valid(old_time) is False

    def test_get_yfinance_interval(self, data_source: YahooFinanceDataSource) -> None:
        """Test interval mapping to yfinance format."""
        assert data_source._get_yfinance_interval(Interval.MINUTE_1) == "1m"
        assert data_source._get_yfinance_interval(Interval.MINUTE_5) == "5m"
        assert data_source._get_yfinance_interval(Interval.DAY_1) == "1d"
        assert data_source._get_yfinance_interval(Interval.WEEK_1) == "1wk"
        assert data_source._get_yfinance_interval(Interval.MONTH_1) == "1mo"

    def test_get_yfinance_interval_unsupported(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test unsupported interval raises error."""
        # A stand-in object that is not a key of the interval map.
        fake_interval = MagicMock()
        fake_interval.value = "fake"

        with pytest.raises(DataSourceError, match="Unsupported interval"):
            data_source._get_yfinance_interval(fake_interval)  # type: ignore[arg-type]

    def test_get_period_for_interval(self, data_source: YahooFinanceDataSource) -> None:
        """Test period selection for different intervals."""
        assert data_source._get_period_for_interval(Interval.MINUTE_1) == "5d"
        assert data_source._get_period_for_interval(Interval.MINUTE_5) == "1mo"
        assert data_source._get_period_for_interval(Interval.HOUR_1) == "3mo"
        assert data_source._get_period_for_interval(Interval.DAY_1) == "1y"
        assert data_source._get_period_for_interval(Interval.WEEK_1) == "5y"
        assert data_source._get_period_for_interval(Interval.MONTH_1) == "max"
        # HOUR_4 is a valid Interval member; this expectation assumes it has no
        # explicit period entry and therefore receives the fallback period.
        assert data_source._get_period_for_interval(Interval.HOUR_4) == "6mo"

    def test_clear_cache(self, data_source: YahooFinanceDataSource) -> None:
        """Test cache clearing."""
        # Add some mock data to cache
        data_source._cache["key1"] = (pd.DataFrame(), datetime.now())
        data_source._cache["key2"] = (pd.DataFrame(), datetime.now())

        assert len(data_source._cache) == 2

        data_source.clear_cache()

        assert len(data_source._cache) == 0

    def test_get_cache_stats(self, data_source: YahooFinanceDataSource) -> None:
        """Test cache statistics reflect size, TTL and stored keys."""
        data_source._cache["key1"] = (pd.DataFrame(), datetime.now())

        stats = data_source.get_cache_stats()

        assert stats["size"] == 1
        assert stats["ttl_seconds"] == 60  # matches the fixture's cache_ttl
        assert "key1" in stats["keys"]

    def test_set_availability(self, data_source: YahooFinanceDataSource) -> None:
        """Test availability status can be toggled both ways."""
        data_source.set_availability(False)
        assert data_source._available is False

        data_source.set_availability(True)
        assert data_source._available is True

    def test_clear_expired_cache(self, data_source: YahooFinanceDataSource) -> None:
        """Test expired cache entries are cleared while fresh ones survive."""
        now = datetime.now()
        # Twice the fixture's 60 s TTL, so this entry is unambiguously expired.
        expired_time = now - timedelta(seconds=120)

        data_source._cache["fresh"] = (pd.DataFrame(), now)
        data_source._cache["expired"] = (pd.DataFrame(), expired_time)

        data_source._clear_expired_cache()

        assert "fresh" in data_source._cache
        assert "expired" not in data_source._cache
|
|
|
|
|
|
class TestYahooFinanceDataSourceAsync:
    """Async tests for YahooFinanceDataSource."""

    # Normalized column names fetch_ohlcv is expected to return.
    OHLCV_COLUMNS = ("timestamp", "open", "high", "low", "close", "volume")

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Create a YahooFinanceDataSource instance for testing."""
        return YahooFinanceDataSource(cache_ttl=60)

    @staticmethod
    def _make_yf_frame(index_name: str, timestamps: "list[datetime]") -> pd.DataFrame:
        """Build a yfinance-style frame to feed the patched fetch helper.

        Columns are capitalized ("Open", "High", ...) and the time column
        named *index_name* ("Date" or "Datetime") becomes the index. Every row
        uses the same sample prices (100/105/99/103) and volume (1e6).
        """
        rows = len(timestamps)
        frame = pd.DataFrame({
            index_name: timestamps,
            "Open": [100.0] * rows,
            "High": [105.0] * rows,
            "Low": [99.0] * rows,
            "Close": [103.0] * rows,
            "Volume": [1000000.0] * rows,
        })
        return frame.set_index(index_name)

    def _assert_ohlcv_columns(self, df: pd.DataFrame) -> None:
        """Fail listing every normalized OHLCV column missing from *df*."""
        missing = [col for col in self.OHLCV_COLUMNS if col not in df.columns]
        assert not missing, f"missing columns: {missing}"

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_returns_dataframe(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test fetch_ohlcv returns a DataFrame with normalized columns."""
        mock_df = self._make_yf_frame("Date", [datetime(2024, 1, 1)])

        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=mock_df
        ):
            df = await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)

        assert isinstance(df, pd.DataFrame)
        self._assert_ohlcv_columns(df)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_empty_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test fetch_ohlcv raises error when data is empty."""
        empty_df = pd.DataFrame()

        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=empty_df
        ):
            with pytest.raises(DataNotAvailableError, match="No data available"):
                await data_source.fetch_ohlcv("INVALID", Interval.DAY_1)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_caches_result(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test a repeated identical request is served from the cache."""
        mock_df = self._make_yf_frame("Date", [datetime(2024, 1, 1)])

        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=mock_df
        ) as mock_fetch:
            # First call should fetch
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
            assert mock_fetch.call_count == 1

            # Second call should use cache
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
            assert mock_fetch.call_count == 1  # No additional fetch

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_limit(self, data_source: YahooFinanceDataSource) -> None:
        """Test fetch_ohlcv respects limit parameter."""
        dates = [datetime(2024, 1, day) for day in range(1, 11)]
        mock_df = self._make_yf_frame("Date", dates)

        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=mock_df
        ):
            df = await data_source.fetch_ohlcv("AAPL", Interval.DAY_1, limit=5)

        assert len(df) == 5

    @pytest.mark.asyncio
    async def test_fetch_realtime_returns_quote(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test fetch_realtime maps ticker info fields onto a RealtimeQuote."""
        mock_info = {
            "currentPrice": 150.0,
            "bid": 149.95,
            "ask": 150.05,
            "bidSize": 100,
            "askSize": 200,
            "volume": 50000000.0,
        }

        with patch.object(
            data_source, "_fetch_ticker_info", return_value=mock_info
        ):
            quote = await data_source.fetch_realtime("AAPL")

        assert isinstance(quote, RealtimeQuote)
        assert quote.symbol == "AAPL"
        assert quote.price == 150.0
        assert quote.bid == 149.95
        assert quote.ask == 150.05
        assert quote.bid_size == 100
        assert quote.ask_size == 200
        assert quote.volume == 50000000.0

    @pytest.mark.asyncio
    async def test_fetch_realtime_empty_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test fetch_realtime raises error when data is empty."""
        with patch.object(data_source, "_fetch_ticker_info", return_value={}):
            with pytest.raises(
                DataNotAvailableError, match="No real-time data available"
            ):
                await data_source.fetch_realtime("INVALID")

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_with_datetime_index(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Test fetch_ohlcv handles a "Datetime" index (intraday data)."""
        mock_df = self._make_yf_frame("Datetime", [datetime(2024, 1, 1, 9, 30)])

        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=mock_df
        ):
            df = await data_source.fetch_ohlcv("AAPL", Interval.MINUTE_5)

        # Intraday frames must normalize to the same columns as daily frames.
        self._assert_ohlcv_columns(df)
|
|
|
|
|
|
class TestYahooFinanceErrorHandling:
    """Failure-path tests: upstream problems must surface as DataSourceError."""

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Provide a fresh source with a 60-second cache TTL for each test."""
        return YahooFinanceDataSource(cache_ttl=60)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_error_wrapped(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A raw exception from the OHLCV fetch surfaces as DataSourceError."""
        failure = Exception("Network error")
        with patch.object(
            data_source, "_fetch_yfinance_data", side_effect=failure
        ), pytest.raises(DataSourceError, match="Failed to fetch data"):
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)

    @pytest.mark.asyncio
    async def test_fetch_realtime_error_wrapped(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A raw exception from the ticker-info fetch surfaces as DataSourceError."""
        failure = Exception("Network error")
        with patch.object(
            data_source, "_fetch_ticker_info", side_effect=failure
        ), pytest.raises(DataSourceError, match="Failed to fetch real-time data"):
            await data_source.fetch_realtime("AAPL")

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_missing_columns_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A frame lacking required OHLCV columns is rejected."""
        # Only "Open" is present; High, Low, Close and Volume are absent.
        partial_frame = pd.DataFrame(
            {"Date": [datetime(2024, 1, 1)], "Open": [100.0]}
        ).set_index("Date")

        with patch.object(
            data_source, "_fetch_yfinance_data", return_value=partial_frame
        ), pytest.raises(DataSourceError, match="Missing required column"):
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
|
|
|
|
|
|
class TestDataSourceExports:
    """Test that all expected exports are available."""

    def test_all_exports_available(self) -> None:
        """Every expected public name must be listed in openclaw.data.__all__.

        All missing names are collected and reported in a single failure,
        rather than the test stopping at the first absent name.
        """
        from openclaw.data import __all__ as exports

        expected = (
            "DataSource",
            "DataSourceError",
            "DataNotAvailableError",
            "Interval",
            "OHLCVData",
            "RealtimeQuote",
            "YahooFinanceDataSource",
        )

        missing = [name for name in expected if name not in exports]
        assert not missing, f"not in __all__: {missing}"
|