286 lines
9.1 KiB
Python
286 lines
9.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
# pylint: disable=W0212
|
|
import asyncio
|
|
import time
|
|
import logging
|
|
from unittest.mock import MagicMock, AsyncMock, patch
|
|
import pytest
|
|
from backend.services.market import MarketService
|
|
from backend.data.polling_price_manager import PollingPriceManager
|
|
from backend.llm.models import RetryChatModel
|
|
|
|
|
|
class TestPollingPriceManager:
    """Unit tests for PollingPriceManager construction, subscriptions and polling."""

    def test_init(self):
        """Constructing with an API key and no provider defaults to finnhub and is not running."""
        manager = PollingPriceManager(api_key="test_key", poll_interval=30)

        assert manager.api_key == "test_key"
        assert manager.poll_interval == 30
        assert manager.provider == "finnhub"
        assert manager.running is False

    def test_init_yfinance(self):
        """The yfinance provider can be constructed without an API key."""
        manager = PollingPriceManager(provider="yfinance", poll_interval=15)

        assert manager.api_key is None
        assert manager.poll_interval == 15
        assert manager.provider == "yfinance"
        assert manager.running is False

    def test_subscribe(self):
        """Subscribed symbols are recorded in ``subscribed_symbols``."""
        manager = PollingPriceManager(api_key="test_key")
        manager.subscribe(["AAPL", "MSFT"])

        assert "AAPL" in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_unsubscribe(self):
        """Unsubscribing removes only the requested symbols."""
        manager = PollingPriceManager(api_key="test_key")
        manager.subscribe(["AAPL", "MSFT"])
        manager.unsubscribe(["AAPL"])

        assert "AAPL" not in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_add_price_callback(self):
        """Registered callbacks are stored in ``price_callbacks``."""
        manager = PollingPriceManager(api_key="test_key")
        callback = MagicMock()
        manager.add_price_callback(callback)

        assert callback in manager.price_callbacks

    @patch.object(PollingPriceManager, "_fetch_prices")
    def test_start_stop(self, _mock_fetch_prices):
        """start() sets ``running`` and stop() clears it; real fetching is patched out."""
        manager = PollingPriceManager(api_key="test_key", poll_interval=1)
        manager.subscribe(["AAPL"])

        manager.start()
        try:
            assert manager.running is True
            # Let the polling loop get going before shutting it down.
            time.sleep(0.1)
        finally:
            # Always stop, so a failed assertion does not leave the manager
            # running and interfering with later tests.
            manager.stop()

        assert manager.running is False

    def test_start_without_subscription(self):
        """start() with no subscribed symbols leaves the manager stopped."""
        manager = PollingPriceManager(api_key="test_key")
        manager.start()

        assert manager.running is False

    def test_get_latest_price(self):
        """get_latest_price() reads from ``latest_prices``."""
        manager = PollingPriceManager(api_key="test_key")
        manager.latest_prices["AAPL"] = 150.0

        assert manager.get_latest_price("AAPL") == 150.0

    def test_get_open_price(self):
        """get_open_price() reads from ``open_prices``."""
        manager = PollingPriceManager(api_key="test_key")
        manager.open_prices["AAPL"] = 148.0

        assert manager.get_open_price("AAPL") == 148.0

    def test_reset_open_prices(self):
        """reset_open_prices() empties ``open_prices``."""
        manager = PollingPriceManager(api_key="test_key")
        manager.open_prices["AAPL"] = 150.0

        manager.reset_open_prices()

        assert len(manager.open_prices) == 0

    def test_fetch_prices_suppresses_repeated_failures(self, caplog):
        """Consecutive failures are counted but not re-logged at WARNING every time."""
        failures = 3
        manager = PollingPriceManager(provider="yfinance", poll_interval=10)
        manager.subscribe(["AAPL"])

        with patch.object(manager, "_fetch_quote", side_effect=ValueError("empty quote")):
            with caplog.at_level(logging.DEBUG):
                for _ in range(failures):
                    manager._fetch_prices()

        assert manager._failure_counts["AAPL"] == failures

        failure_warnings = [
            record.getMessage()
            for record in caplog.records
            if record.levelno >= logging.WARNING
            and "Failed to fetch AAPL price: empty quote" in record.getMessage()
        ]
        # The failure must be surfaced at least once at WARNING level...
        assert failure_warnings
        # ...but repeats must be suppressed, otherwise this test would pass
        # even if every consecutive failure produced its own warning.
        assert len(failure_warnings) < failures

    def test_fetch_prices_logs_recovery_after_failure(self, caplog):
        """A success after a failure clears the failure counter and logs a recovery message."""
        manager = PollingPriceManager(provider="yfinance", poll_interval=10)
        manager.subscribe(["AAPL"])
        quote = {"c": 100.0, "o": 99.0, "h": 101.0, "l": 98.0, "pc": 99.5, "d": 0.5, "dp": 0.5, "t": 1}

        # First poll raises, second poll returns a quote payload.
        with patch.object(
            manager,
            "_fetch_quote",
            side_effect=[ValueError("temporary outage"), quote],
        ):
            with caplog.at_level(logging.INFO):
                manager._fetch_prices()
                manager._fetch_prices()

        assert "AAPL" not in manager._failure_counts
        assert any(
            "recovered after 1 consecutive failures" in record.getMessage()
            for record in caplog.records
        )
|
|
|
|
|
|
class TestRetryChatModel:
    """Tests for the RetryChatModel wrapper."""

    @pytest.mark.asyncio
    async def test_async_retry_recovers_from_disconnect(self):
        """A disconnect on the first call is retried and the second call's result is returned."""
        call_log = []

        class FakeAsyncModel:
            model_name = "fake-async-model"

            async def __call__(self, *args, **kwargs):
                call_log.append((args, kwargs))
                # Only the very first invocation fails; later ones succeed.
                if len(call_log) == 1:
                    raise RuntimeError("Server disconnected")
                return {"ok": True}

        wrapped = RetryChatModel(FakeAsyncModel(), max_retries=2, initial_delay=0.01)
        outcome = await wrapped("hello")

        assert outcome == {"ok": True}
        assert len(call_log) == 2
|
|
|
|
|
|
class TestMarketService:
    """Unit tests for MarketService lifecycle, price lookups and broadcasting."""

    @patch("backend.services.market.get_data_sources", return_value=["yfinance", "local_csv"])
    @patch.object(PollingPriceManager, "start")
    def test_start_real_mode_with_yfinance(self, _mock_start, _mock_sources):
        """With yfinance configured, real mode builds a yfinance PollingPriceManager.

        ``PollingPriceManager.start`` is patched so no polling actually begins.
        """
        service = MarketService(tickers=["AAPL"], poll_interval=10)

        service._start_real_mode()

        assert isinstance(service._price_manager, PollingPriceManager)
        assert service._price_manager.provider == "yfinance"

    @patch(
        "backend.services.market.get_data_sources",
        return_value=["financial_datasets", "yfinance", "local_csv"],
    )
    @patch.object(PollingPriceManager, "start")
    def test_start_real_mode_uses_first_supported_live_provider(self, _mock_start, _mock_sources):
        """A leading source the poller cannot use is skipped in favour of yfinance."""
        service = MarketService(tickers=["AAPL"], poll_interval=10)

        service._start_real_mode()

        assert isinstance(service._price_manager, PollingPriceManager)
        assert service._price_manager.provider == "yfinance"

    @pytest.mark.asyncio
    @patch("backend.services.market.get_data_sources", return_value=["finnhub", "yfinance"])
    async def test_start_real_mode_without_api_key(self, _mock_sources):
        """Starting with finnhub as the first source but no API key raises ValueError."""
        service = MarketService(tickers=["AAPL"], api_key=None)
        broadcast_func = AsyncMock()

        with pytest.raises(ValueError, match="API key required"):
            await service.start(broadcast_func)

    @pytest.mark.asyncio
    async def test_start_already_running(self):
        """Calling start() on a running service must not raise and must keep it running."""
        service = MarketService(tickers=["AAPL"], backtest_mode=True)
        broadcast_func = AsyncMock()

        # First start, in backtest mode.
        await service.start(broadcast_func)
        try:
            assert service.running is True

            # A second start must not raise or stop the service.
            await service.start(broadcast_func)
            assert service.running is True
        finally:
            # Always stop, even if an assertion above failed.
            service.stop()

    def test_stop(self):
        """stop() clears ``running`` and drops the price manager reference."""
        service = MarketService(tickers=["AAPL"], backtest_mode=True)
        service.running = True
        service._price_manager = MagicMock()

        service.stop()

        assert service.running is False
        assert service._price_manager is None

    def test_stop_when_not_running(self):
        """stop() on a service that was never started is a safe call."""
        service = MarketService(tickers=["AAPL"], backtest_mode=True)

        service.stop()  # Should not raise.

        assert service.running is False

    def test_get_price_sync(self):
        """get_price_sync() returns the cached ``price`` field."""
        service = MarketService(tickers=["AAPL"], backtest_mode=True)
        service.cache["AAPL"] = {"price": 150.0, "open": 148.0}

        assert service.get_price_sync("AAPL") == 150.0

    def test_get_price_sync_not_found(self):
        """get_price_sync() returns None for a symbol that is not cached."""
        service = MarketService(tickers=["AAPL"], backtest_mode=True)

        assert service.get_price_sync("MSFT") is None

    def test_get_all_prices(self):
        """get_all_prices() maps every cached symbol to its ``price``."""
        service = MarketService(tickers=["AAPL", "MSFT"], backtest_mode=True)
        service.cache["AAPL"] = {"price": 150.0}
        service.cache["MSFT"] = {"price": 400.0}

        prices = service.get_all_prices()

        assert prices["AAPL"] == 150.0
        assert prices["MSFT"] == 400.0

    @pytest.mark.asyncio
    async def test_broadcast_price_update(self):
        """A price update is awaited on the broadcast function as a ``price_update`` message."""
        service = MarketService(tickers=["AAPL"], backtest_mode=True)
        service._broadcast_func = AsyncMock()

        price_data = {
            "symbol": "AAPL",
            "price": 150.0,
            "open": 148.0,
            "timestamp": 1234567890,
        }

        await service._broadcast_price_update(price_data)

        # assert_awaited_once (not assert_called_once) also catches the bug
        # where the coroutine is created but never awaited.
        service._broadcast_func.assert_awaited_once()
        message = service._broadcast_func.await_args.args[0]
        assert message["type"] == "price_update"
        assert message["symbol"] == "AAPL"
        assert message["price"] == 150.0

    @pytest.mark.asyncio
    async def test_broadcast_price_update_no_func(self):
        """With no broadcast function configured, broadcasting is a safe call."""
        service = MarketService(tickers=["AAPL"], backtest_mode=True)
        service._broadcast_func = None

        price_data = {"symbol": "AAPL", "price": 150.0, "open": 148.0}

        await service._broadcast_price_update(price_data)  # Should not raise.
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests directly. Propagate pytest's exit status so
    # shell/CI callers see a non-zero code when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
|