"""Unit tests for the data source interface and its implementations."""

from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

from openclaw.data import (
    DataNotAvailableError,
    DataSource,
    DataSourceError,
    Interval,
    OHLCVData,
    RealtimeQuote,
    YahooFinanceDataSource,
)


def _ohlcv_frame(index_name: str, timestamps: list) -> pd.DataFrame:
    """Build a mock OHLCV frame with one identical bar per timestamp.

    The frame has ``Open``/``High``/``Low``/``Close``/``Volume`` columns and
    is indexed by a column called ``index_name`` holding ``timestamps``.
    """
    rows = len(timestamps)
    frame = pd.DataFrame(
        {
            index_name: timestamps,
            "Open": [100.0] * rows,
            "High": [105.0] * rows,
            "Low": [99.0] * rows,
            "Close": [103.0] * rows,
            "Volume": [1000000.0] * rows,
        }
    )
    return frame.set_index(index_name)


class TestDataSourceInterface:
    """Checks on the exception hierarchy, the Interval enum and the dataclasses."""

    def test_data_source_error_inheritance(self) -> None:
        """DataSourceError derives from Exception."""
        assert issubclass(DataSourceError, Exception)

    def test_data_not_available_error_inheritance(self) -> None:
        """DataNotAvailableError derives from DataSourceError."""
        assert issubclass(DataNotAvailableError, DataSourceError)

    def test_interval_enum_values(self) -> None:
        """Every Interval member carries the expected string value."""
        expected_values = [
            (Interval.MINUTE_1, "1m"),
            (Interval.MINUTE_5, "5m"),
            (Interval.MINUTE_15, "15m"),
            (Interval.MINUTE_30, "30m"),
            (Interval.HOUR_1, "1h"),
            (Interval.HOUR_4, "4h"),
            (Interval.DAY_1, "1d"),
            (Interval.WEEK_1, "1wk"),
            (Interval.MONTH_1, "1mo"),
        ]
        for member, expected in expected_values:
            assert member.value == expected

    def test_ohlcv_data_creation(self) -> None:
        """OHLCVData stores each constructor argument on the matching field."""
        field_values = {
            "timestamp": datetime.now(),
            "open": 100.0,
            "high": 105.0,
            "low": 99.0,
            "close": 103.0,
            "volume": 1000000.0,
        }
        data = OHLCVData(**field_values)

        for field_name, expected in field_values.items():
            assert getattr(data, field_name) == expected

    def test_ohlcv_data_is_immutable(self) -> None:
        """Assigning to a field of OHLCVData raises AttributeError."""
        data = OHLCVData(
            timestamp=datetime.now(),
            open=100.0,
            high=105.0,
            low=99.0,
            close=103.0,
            volume=1000000.0,
        )

        with pytest.raises(AttributeError):
            data.close = 110.0  # type: ignore[misc]

    def test_realtime_quote_creation(self) -> None:
        """RealtimeQuote stores each constructor argument on the matching field."""
        field_values = {
            "symbol": "AAPL",
            "price": 150.0,
            "bid": 149.95,
            "ask": 150.05,
            "bid_size": 100,
            "ask_size": 200,
            "volume": 50000000.0,
            "timestamp": datetime.now(),
        }
        quote = RealtimeQuote(**field_values)

        for field_name, expected in field_values.items():
            assert getattr(quote, field_name) == expected

    def test_realtime_quote_is_immutable(self) -> None:
        """Assigning to a field of RealtimeQuote raises AttributeError."""
        quote = RealtimeQuote(
            symbol="AAPL",
            price=150.0,
            bid=149.95,
            ask=150.05,
            bid_size=100,
            ask_size=200,
            volume=50000000.0,
            timestamp=datetime.now(),
        )

        with pytest.raises(AttributeError):
            quote.price = 160.0  # type: ignore[misc]


class TestYahooFinanceDataSource:
    """Synchronous tests for YahooFinanceDataSource helpers and cache handling."""

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Provide a YahooFinanceDataSource with a 60 second cache TTL."""
        return YahooFinanceDataSource(cache_ttl=60)

    def test_initialization(self) -> None:
        """A new source reports its name and starts out available."""
        source = YahooFinanceDataSource(cache_ttl=120)

        assert source.name == "yahoo_finance"
        # Initial availability state
        assert source._available is True

    def test_default_cache_ttl(self) -> None:
        """Without arguments the cache TTL is 60 seconds."""
        assert YahooFinanceDataSource()._cache_ttl == 60

    def test_custom_cache_ttl(self) -> None:
        """An explicit cache_ttl is stored as given."""
        assert YahooFinanceDataSource(cache_ttl=300)._cache_ttl == 300

    def test_get_cache_key(self, data_source: YahooFinanceDataSource) -> None:
        """The cache key contains symbol, interval value and both dates."""
        key = data_source._get_cache_key(
            "AAPL",
            Interval.DAY_1,
            datetime(2024, 1, 1),
            datetime(2024, 1, 31),
        )

        for fragment in ("AAPL", "1d", "2024-01-01", "2024-01-31"):
            assert fragment in key

    def test_get_cache_key_with_none(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """The cache key is still built when both dates are None."""
        key = data_source._get_cache_key("MSFT", Interval.HOUR_1, None, None)

        for fragment in ("MSFT", "1h", "None"):
            assert fragment in key

    def test_is_cache_valid(self, data_source: YahooFinanceDataSource) -> None:
        """A fresh timestamp is valid; one older than the TTL is not."""
        current = datetime.now()
        stale = current - timedelta(seconds=120)

        # Recent cache entry should be valid
        assert data_source._is_cache_valid(current) is True
        # Old cache entry should be invalid
        assert data_source._is_cache_valid(stale) is False

    def test_get_yfinance_interval(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Intervals map onto the expected yfinance interval strings."""
        expected_mapping = [
            (Interval.MINUTE_1, "1m"),
            (Interval.MINUTE_5, "5m"),
            (Interval.DAY_1, "1d"),
            (Interval.WEEK_1, "1wk"),
            (Interval.MONTH_1, "1mo"),
        ]
        for interval, expected in expected_mapping:
            assert data_source._get_yfinance_interval(interval) == expected

    def test_get_yfinance_interval_unsupported(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """An interval missing from the mapping raises DataSourceError."""
        # Create a fake interval not in the map
        fake_interval = MagicMock(value="fake")

        with pytest.raises(DataSourceError, match="Unsupported interval"):
            data_source._get_yfinance_interval(fake_interval)  # type: ignore[arg-type]

    def test_get_period_for_interval(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Each interval selects the expected look-back period."""
        expected_periods = [
            (Interval.MINUTE_1, "5d"),
            (Interval.MINUTE_5, "1mo"),
            (Interval.HOUR_1, "3mo"),
            (Interval.DAY_1, "1y"),
            (Interval.WEEK_1, "5y"),
            (Interval.MONTH_1, "max"),
            # Default for unknown interval
            (Interval.HOUR_4, "6mo"),
        ]
        for interval, expected in expected_periods:
            assert data_source._get_period_for_interval(interval) == expected

    def test_clear_cache(self, data_source: YahooFinanceDataSource) -> None:
        """clear_cache empties the cache dictionary."""
        # Add some mock data to cache
        for cache_key in ("key1", "key2"):
            data_source._cache[cache_key] = (pd.DataFrame(), datetime.now())
        assert len(data_source._cache) == 2

        data_source.clear_cache()

        assert len(data_source._cache) == 0

    def test_get_cache_stats(self, data_source: YahooFinanceDataSource) -> None:
        """get_cache_stats reports size, TTL and the cached keys."""
        data_source._cache["key1"] = (pd.DataFrame(), datetime.now())

        stats = data_source.get_cache_stats()

        assert stats["size"] == 1
        assert stats["ttl_seconds"] == 60
        assert "key1" in stats["keys"]

    def test_set_availability(self, data_source: YahooFinanceDataSource) -> None:
        """set_availability toggles the stored availability flag."""
        for flag in (False, True):
            data_source.set_availability(flag)
            assert data_source._available is flag

    def test_clear_expired_cache(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Only entries older than the TTL are dropped from the cache."""
        current = datetime.now()
        data_source._cache["fresh"] = (pd.DataFrame(), current)
        data_source._cache["expired"] = (
            pd.DataFrame(),
            current - timedelta(seconds=120),
        )

        data_source._clear_expired_cache()

        assert "fresh" in data_source._cache
        assert "expired" not in data_source._cache


class TestYahooFinanceDataSourceAsync:
    """Tests for the coroutine API of YahooFinanceDataSource."""

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Provide a YahooFinanceDataSource with a 60 second cache TTL."""
        return YahooFinanceDataSource(cache_ttl=60)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_returns_dataframe(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """fetch_ohlcv yields a DataFrame with lower-case OHLCV columns."""
        # Create mock DataFrame
        patcher = patch.object(
            data_source,
            "_fetch_yfinance_data",
            return_value=_ohlcv_frame("Date", [datetime(2024, 1, 1)]),
        )
        with patcher:
            df = await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)

        assert isinstance(df, pd.DataFrame)
        for column in ("timestamp", "open", "high", "low", "close", "volume"):
            assert column in df.columns

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_empty_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """An empty frame from the backend raises DataNotAvailableError."""
        patcher = patch.object(
            data_source, "_fetch_yfinance_data", return_value=pd.DataFrame()
        )
        with patcher, pytest.raises(
            DataNotAvailableError, match="No data available"
        ):
            await data_source.fetch_ohlcv("INVALID", Interval.DAY_1)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_caches_result(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A repeated identical request is answered without a second fetch."""
        patcher = patch.object(
            data_source,
            "_fetch_yfinance_data",
            return_value=_ohlcv_frame("Date", [datetime(2024, 1, 1)]),
        )
        with patcher as mock_fetch:
            # First call should fetch
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
            assert mock_fetch.call_count == 1

            # Second call should use cache
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)
            assert mock_fetch.call_count == 1  # No additional fetch

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_limit(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """The limit argument caps the number of returned rows."""
        ten_days = [datetime(2024, 1, day) for day in range(1, 11)]
        patcher = patch.object(
            data_source,
            "_fetch_yfinance_data",
            return_value=_ohlcv_frame("Date", ten_days),
        )
        with patcher:
            df = await data_source.fetch_ohlcv("AAPL", Interval.DAY_1, limit=5)

        assert len(df) == 5

    @pytest.mark.asyncio
    async def test_fetch_realtime_returns_quote(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """fetch_realtime converts ticker info into a RealtimeQuote."""
        mock_info = {
            "currentPrice": 150.0,
            "bid": 149.95,
            "ask": 150.05,
            "bidSize": 100,
            "askSize": 200,
            "volume": 50000000.0,
        }
        expected_fields = {
            "symbol": "AAPL",
            "price": 150.0,
            "bid": 149.95,
            "ask": 150.05,
            "bid_size": 100,
            "ask_size": 200,
            "volume": 50000000.0,
        }

        patcher = patch.object(
            data_source, "_fetch_ticker_info", return_value=mock_info
        )
        with patcher:
            quote = await data_source.fetch_realtime("AAPL")

        assert isinstance(quote, RealtimeQuote)
        for field_name, expected in expected_fields.items():
            assert getattr(quote, field_name) == expected

    @pytest.mark.asyncio
    async def test_fetch_realtime_empty_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """Empty ticker info raises DataNotAvailableError."""
        patcher = patch.object(
            data_source, "_fetch_ticker_info", return_value={}
        )
        with patcher, pytest.raises(
            DataNotAvailableError, match="No real-time data available"
        ):
            await data_source.fetch_realtime("INVALID")

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_with_datetime_index(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A frame indexed by "Datetime" (intraday data) is handled too."""
        patcher = patch.object(
            data_source,
            "_fetch_yfinance_data",
            return_value=_ohlcv_frame("Datetime", [datetime(2024, 1, 1, 9, 30)]),
        )
        with patcher:
            df = await data_source.fetch_ohlcv("AAPL", Interval.MINUTE_5)

        assert "timestamp" in df.columns
        assert "open" in df.columns


class TestYahooFinanceErrorHandling:
    """Tests for how YahooFinanceDataSource reports backend failures."""

    @pytest.fixture
    def data_source(self) -> YahooFinanceDataSource:
        """Provide a YahooFinanceDataSource with a 60 second cache TTL."""
        return YahooFinanceDataSource(cache_ttl=60)

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_error_wrapped(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A backend exception in fetch_ohlcv surfaces as DataSourceError."""
        patcher = patch.object(
            data_source,
            "_fetch_yfinance_data",
            side_effect=Exception("Network error"),
        )
        with patcher, pytest.raises(
            DataSourceError, match="Failed to fetch data"
        ):
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)

    @pytest.mark.asyncio
    async def test_fetch_realtime_error_wrapped(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A backend exception in fetch_realtime surfaces as DataSourceError."""
        patcher = patch.object(
            data_source,
            "_fetch_ticker_info",
            side_effect=Exception("Network error"),
        )
        with patcher, pytest.raises(
            DataSourceError, match="Failed to fetch real-time data"
        ):
            await data_source.fetch_realtime("AAPL")

    @pytest.mark.asyncio
    async def test_fetch_ohlcv_missing_columns_raises_error(
        self, data_source: YahooFinanceDataSource
    ) -> None:
        """A frame lacking required columns raises DataSourceError."""
        # DataFrame with missing columns: High, Low, Close, Volume are absent
        incomplete_df = pd.DataFrame(
            {
                "Date": [datetime(2024, 1, 1)],
                "Open": [100.0],
            }
        ).set_index("Date")

        patcher = patch.object(
            data_source, "_fetch_yfinance_data", return_value=incomplete_df
        )
        with patcher, pytest.raises(
            DataSourceError, match="Missing required column"
        ):
            await data_source.fetch_ohlcv("AAPL", Interval.DAY_1)


class TestDataSourceExports:
    """Checks that openclaw.data exposes the expected public names."""

    def test_all_exports_available(self) -> None:
        """Every expected name is listed in openclaw.data.__all__."""
        from openclaw.data import __all__ as exports

        expected = (
            "DataSource",
            "DataSourceError",
            "DataNotAvailableError",
            "Interval",
            "OHLCVData",
            "RealtimeQuote",
            "YahooFinanceDataSource",
        )
        for item in expected:
            assert item in exports, f"{item} not in __all__"