Files
evotraders/backend/tests/test_market_service.py

532 lines
16 KiB
Python

# -*- coding: utf-8 -*-
# pylint: disable=W0212
import asyncio
import time
import logging
from unittest.mock import MagicMock, AsyncMock, patch
import pytest
from backend.services.market import MarketService
from backend.data.mock_price_manager import MockPriceManager
from backend.data.polling_price_manager import PollingPriceManager
from backend.llm.models import RetryChatModel
class TestMockPriceManager:
    """Unit tests for ``MockPriceManager``.

    Covers construction, subscription bookkeeping, callback registration,
    price generation, the start/stop lifecycle and the price accessors.
    """

    def test_init_default(self):
        """A manager built without arguments exposes the default settings."""
        manager = MockPriceManager()
        assert manager.poll_interval == 10
        assert manager.volatility == 0.5
        assert manager.running is False
        assert len(manager.subscribed_symbols) == 0

    def test_init_custom(self):
        """Constructor arguments override the defaults."""
        manager = MockPriceManager(poll_interval=5, volatility=1.0)
        assert manager.poll_interval == 5
        assert manager.volatility == 1.0

    def test_subscribe(self):
        """Subscribing known tickers registers them with their base prices."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL", "MSFT"])
        assert "AAPL" in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols
        assert manager.base_prices["AAPL"] == 237.50  # default price
        assert manager.base_prices["MSFT"] == 425.30  # default price

    def test_subscribe_with_base_prices(self):
        """An explicit base price seeds the base, open and latest prices."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        assert manager.base_prices["AAPL"] == 100.0
        assert manager.open_prices["AAPL"] == 100.0
        assert manager.latest_prices["AAPL"] == 100.0

    def test_subscribe_unknown_symbol(self):
        """A ticker without a preset price still gets a positive base price."""
        manager = MockPriceManager()
        manager.subscribe(["UNKNOWN"])
        assert "UNKNOWN" in manager.subscribed_symbols
        assert manager.base_prices["UNKNOWN"] > 0  # random price generated

    def test_unsubscribe(self):
        """Unsubscribing removes only the requested ticker."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL", "MSFT"])
        manager.unsubscribe(["AAPL"])
        assert "AAPL" not in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_add_price_callback(self):
        """A registered callback is stored in ``price_callbacks``."""
        manager = MockPriceManager()
        callback = MagicMock()
        manager.add_price_callback(callback)
        assert callback in manager.price_callbacks

    def test_generate_price_update_within_bounds(self):
        """Generated prices stay inside the 90-110 band for a 100.0 open."""
        manager = MockPriceManager(volatility=0.5)
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        for _ in range(100):
            new_price = manager._generate_price_update("AAPL")
            # Should be within +/-10% of open
            assert 90.0 <= new_price <= 110.0

    def test_update_prices_triggers_callback(self):
        """One update pass calls the callback once with a price payload."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        callback = MagicMock()
        manager.add_price_callback(callback)
        manager._update_prices()
        callback.assert_called_once()
        # First positional argument of the single recorded call.
        call_args = callback.call_args[0][0]
        assert call_args["symbol"] == "AAPL"
        assert "price" in call_args
        assert "timestamp" in call_args

    def test_start_stop(self):
        """``start`` flips ``running`` on and ``stop`` flips it back off."""
        manager = MockPriceManager(poll_interval=1)
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        manager.start()
        # stop() lives in ``finally`` so a failing assertion cannot leave the
        # started manager running for the remainder of the test session.
        try:
            assert manager.running is True
            time.sleep(0.1)  # let thread start
        finally:
            manager.stop()
        assert manager.running is False

    def test_start_without_subscription(self):
        """``start`` is a no-op when nothing has been subscribed."""
        manager = MockPriceManager()
        manager.start()
        # should not start without subscriptions
        assert manager.running is False

    def test_get_latest_price(self):
        """The latest price of a freshly subscribed ticker is its base price."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        price = manager.get_latest_price("AAPL")
        assert price == 100.0

    def test_get_latest_price_unknown(self):
        """Asking for a ticker that was never subscribed returns ``None``."""
        manager = MockPriceManager()
        price = manager.get_latest_price("UNKNOWN")
        assert price is None

    def test_get_all_latest_prices(self):
        """The bulk accessor returns the latest price for every ticker."""
        manager = MockPriceManager()
        manager.subscribe(
            ["AAPL", "MSFT"],
            base_prices={"AAPL": 100.0, "MSFT": 200.0},
        )
        prices = manager.get_all_latest_prices()
        assert prices["AAPL"] == 100.0
        assert prices["MSFT"] == 200.0

    def test_reset_open_prices(self):
        """Resetting open prices moves the open away from its old value."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        manager.latest_prices["AAPL"] = 105.0
        manager.reset_open_prices()
        # Open price should change (based on latest with small gap)
        assert manager.open_prices["AAPL"] != 100.0

    def test_set_base_price(self):
        """``set_base_price`` rewrites the base, open and latest prices."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        manager.set_base_price("AAPL", 150.0)
        assert manager.base_prices["AAPL"] == 150.0
        assert manager.open_prices["AAPL"] == 150.0
        assert manager.latest_prices["AAPL"] == 150.0
class TestPollingPriceManager:
    """Unit tests for ``PollingPriceManager``.

    Covers construction per provider, subscription bookkeeping, the
    start/stop lifecycle, price accessors and failure/recovery logging in
    ``_fetch_prices``.
    """

    def test_init(self):
        """With only an API key the provider defaults to ``finnhub``."""
        manager = PollingPriceManager(api_key="test_key", poll_interval=30)
        assert manager.api_key == "test_key"
        assert manager.poll_interval == 30
        assert manager.provider == "finnhub"
        assert manager.running is False

    def test_init_yfinance(self):
        """The ``yfinance`` provider can be built without an API key."""
        manager = PollingPriceManager(provider="yfinance", poll_interval=15)
        assert manager.api_key is None
        assert manager.poll_interval == 15
        assert manager.provider == "yfinance"
        assert manager.running is False

    def test_subscribe(self):
        """Subscribed tickers show up in ``subscribed_symbols``."""
        manager = PollingPriceManager(api_key="test_key")
        manager.subscribe(["AAPL", "MSFT"])
        assert "AAPL" in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_unsubscribe(self):
        """Unsubscribing removes only the requested ticker."""
        manager = PollingPriceManager(api_key="test_key")
        manager.subscribe(["AAPL", "MSFT"])
        manager.unsubscribe(["AAPL"])
        assert "AAPL" not in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_add_price_callback(self):
        """A registered callback is stored in ``price_callbacks``."""
        manager = PollingPriceManager(api_key="test_key")
        callback = MagicMock()
        manager.add_price_callback(callback)
        assert callback in manager.price_callbacks

    @patch.object(PollingPriceManager, "_fetch_prices")
    def test_start_stop(self, _mock_fetch_prices):
        """``start`` flips ``running`` on and ``stop`` flips it back off.

        ``_fetch_prices`` is patched out so the test never performs a real
        price fetch.
        """
        manager = PollingPriceManager(api_key="test_key", poll_interval=1)
        manager.subscribe(["AAPL"])
        manager.start()
        # stop() lives in ``finally`` so a failing assertion cannot leave the
        # started manager running for the remainder of the test session.
        try:
            assert manager.running is True
            time.sleep(0.1)
        finally:
            manager.stop()
        assert manager.running is False

    def test_start_without_subscription(self):
        """``start`` is a no-op when nothing has been subscribed."""
        manager = PollingPriceManager(api_key="test_key")
        manager.start()
        assert manager.running is False

    def test_get_latest_price(self):
        """``get_latest_price`` reads from ``latest_prices``."""
        manager = PollingPriceManager(api_key="test_key")
        manager.latest_prices["AAPL"] = 150.0
        price = manager.get_latest_price("AAPL")
        assert price == 150.0

    def test_get_open_price(self):
        """``get_open_price`` reads from ``open_prices``."""
        manager = PollingPriceManager(api_key="test_key")
        manager.open_prices["AAPL"] = 148.0
        price = manager.get_open_price("AAPL")
        assert price == 148.0

    def test_reset_open_prices(self):
        """Resetting open prices empties the ``open_prices`` mapping."""
        manager = PollingPriceManager(api_key="test_key")
        manager.open_prices["AAPL"] = 150.0
        manager.reset_open_prices()
        assert len(manager.open_prices) == 0

    def test_fetch_prices_suppresses_repeated_failures(self, caplog):
        """Three failing fetches are counted and the failure is warned about.

        ``_fetch_quote`` is forced to raise on every call; the test checks
        the per-symbol failure counter and that at least one record at
        WARNING level or above carries the failure message.
        """
        manager = PollingPriceManager(provider="yfinance", poll_interval=10)
        manager.subscribe(["AAPL"])
        failing_quote = patch.object(
            manager,
            "_fetch_quote",
            side_effect=ValueError("empty quote"),
        )
        with failing_quote, caplog.at_level(logging.DEBUG):
            for _ in range(3):
                manager._fetch_prices()
        assert manager._failure_counts["AAPL"] == 3
        warning_messages = [
            record.message
            for record in caplog.records
            if record.levelno >= logging.WARNING
        ]
        assert any(
            "Failed to fetch AAPL price: empty quote" in message
            for message in warning_messages
        )

    def test_fetch_prices_logs_recovery_after_failure(self, caplog):
        """A success after one failure clears the counter and logs recovery."""
        manager = PollingPriceManager(provider="yfinance", poll_interval=10)
        manager.subscribe(["AAPL"])
        # First call raises, second call returns a complete quote payload.
        quote_outcomes = [
            ValueError("temporary outage"),
            {
                "c": 100.0,
                "o": 99.0,
                "h": 101.0,
                "l": 98.0,
                "pc": 99.5,
                "d": 0.5,
                "dp": 0.5,
                "t": 1,
            },
        ]
        flaky_quote = patch.object(
            manager,
            "_fetch_quote",
            side_effect=quote_outcomes,
        )
        with flaky_quote, caplog.at_level(logging.INFO):
            manager._fetch_prices()
            manager._fetch_prices()
        assert "AAPL" not in manager._failure_counts
        assert any(
            "recovered after 1 consecutive failures" in record.message
            for record in caplog.records
        )
class TestRetryChatModel:
    """Tests for the ``RetryChatModel`` wrapper around an async model."""

    @pytest.mark.asyncio
    async def test_async_retry_recovers_from_disconnect(self):
        """A single disconnect error is retried and the next result returned."""
        call_count = 0

        class FakeAsyncModel:
            """Async callable that fails on its first call only."""

            model_name = "fake-async-model"

            async def __call__(self, *args, **kwargs):
                nonlocal call_count
                call_count += 1
                # Only the very first invocation simulates a dropped connection.
                if call_count == 1:
                    raise RuntimeError("Server disconnected")
                return {"ok": True}

        retrying_model = RetryChatModel(
            FakeAsyncModel(),
            max_retries=2,
            initial_delay=0.01,
        )
        response = await retrying_model("hello")

        assert response == {"ok": True}
        # One failed attempt plus one successful retry.
        assert call_count == 2
class TestMarketService:
    """Unit tests for ``MarketService``.

    Covers construction, provider selection in real mode, the start/stop
    lifecycle, cached price accessors and price-update broadcasting.
    """

    def test_init_mock_mode(self):
        """Constructor arguments are stored and the service starts stopped."""
        service = MarketService(
            tickers=["AAPL", "MSFT"],
            poll_interval=10,
            mock_mode=True,
        )
        assert service.tickers == ["AAPL", "MSFT"]
        assert service.poll_interval == 10
        assert service.mock_mode is True
        assert service.running is False

    def test_init_real_mode(self):
        """Real mode keeps the supplied API key."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=False,
            api_key="test_key",
        )
        assert service.mock_mode is False
        assert service.api_key == "test_key"

    @patch(
        "backend.services.market.get_data_sources",
        return_value=["yfinance", "local_csv"],
    )
    @patch.object(PollingPriceManager, "start")
    def test_start_real_mode_with_yfinance(self, _mock_start, _mock_sources):
        """With ``yfinance`` listed first, a yfinance polling manager is used."""
        service = MarketService(
            tickers=["AAPL"],
            poll_interval=10,
            mock_mode=False,
        )
        service._start_real_mode()
        assert isinstance(service._price_manager, PollingPriceManager)
        assert service._price_manager.provider == "yfinance"

    @patch(
        "backend.services.market.get_data_sources",
        return_value=["financial_datasets", "yfinance", "local_csv"],
    )
    @patch.object(PollingPriceManager, "start")
    def test_start_real_mode_uses_first_supported_live_provider(
        self,
        _mock_start,
        _mock_sources,
    ):
        """``yfinance`` is chosen when it follows ``financial_datasets``."""
        service = MarketService(
            tickers=["AAPL"],
            poll_interval=10,
            mock_mode=False,
        )
        service._start_real_mode()
        assert isinstance(service._price_manager, PollingPriceManager)
        assert service._price_manager.provider == "yfinance"

    @pytest.mark.asyncio
    async def test_start_mock_mode(self):
        """Starting in mock mode creates a running ``MockPriceManager``."""
        service = MarketService(
            tickers=["AAPL"],
            poll_interval=10,
            mock_mode=True,
        )
        broadcast_func = AsyncMock()
        await service.start(broadcast_func)
        # stop() lives in ``finally`` so a failing assertion cannot leave the
        # started service running for the remainder of the test session.
        try:
            assert service.running is True
            assert service._price_manager is not None
            assert isinstance(service._price_manager, MockPriceManager)
        finally:
            service.stop()

    @patch(
        "backend.services.market.get_data_sources",
        return_value=["finnhub", "yfinance"],
    )
    @pytest.mark.asyncio
    async def test_start_real_mode_without_api_key(self, _mock_sources):
        """Real mode with ``finnhub`` first and no API key raises ``ValueError``."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=False,
            api_key=None,
        )
        broadcast_func = AsyncMock()
        with pytest.raises(ValueError) as excinfo:
            await service.start(broadcast_func)
        assert "API key required" in str(excinfo.value)

    @pytest.mark.asyncio
    async def test_start_already_running(self):
        """Calling ``start`` on a running service does not raise."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )
        broadcast_func = AsyncMock()
        await service.start(broadcast_func)
        # Guarantee shutdown even if an assertion or the second start fails.
        try:
            assert service.running is True
            # Start again should not fail
            await service.start(broadcast_func)
        finally:
            service.stop()

    def test_stop(self):
        """``stop`` clears ``running`` and drops the price manager."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )
        service.running = True
        service._price_manager = MagicMock()
        service.stop()
        assert service.running is False
        assert service._price_manager is None

    def test_stop_when_not_running(self):
        """``stop`` on a service that never started is harmless."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )
        # Should not raise
        service.stop()
        assert service.running is False

    def test_get_price_sync(self):
        """``get_price_sync`` returns the cached ``price`` field."""
        service = MarketService(tickers=["AAPL"], mock_mode=True)
        service.cache["AAPL"] = {"price": 150.0, "open": 148.0}
        price = service.get_price_sync("AAPL")
        assert price == 150.0

    def test_get_price_sync_not_found(self):
        """``get_price_sync`` returns ``None`` for a ticker not in the cache."""
        service = MarketService(tickers=["AAPL"], mock_mode=True)
        price = service.get_price_sync("MSFT")
        assert price is None

    def test_get_all_prices(self):
        """``get_all_prices`` maps each cached ticker to its price."""
        service = MarketService(tickers=["AAPL", "MSFT"], mock_mode=True)
        service.cache["AAPL"] = {"price": 150.0}
        service.cache["MSFT"] = {"price": 400.0}
        prices = service.get_all_prices()
        assert prices["AAPL"] == 150.0
        assert prices["MSFT"] == 400.0

    @pytest.mark.asyncio
    async def test_broadcast_price_update(self):
        """A price update is broadcast once as a ``price_update`` message."""
        service = MarketService(tickers=["AAPL"], mock_mode=True)
        service._broadcast_func = AsyncMock()
        price_data = {
            "symbol": "AAPL",
            "price": 150.0,
            "open": 148.0,
            "timestamp": 1234567890,
        }
        await service._broadcast_price_update(price_data)
        service._broadcast_func.assert_called_once()
        # First positional argument of the single recorded call.
        call_args = service._broadcast_func.call_args[0][0]
        assert call_args["type"] == "price_update"
        assert call_args["symbol"] == "AAPL"
        assert call_args["price"] == 150.0

    @pytest.mark.asyncio
    async def test_broadcast_price_update_no_func(self):
        """Broadcasting without a registered broadcast function is a no-op."""
        service = MarketService(tickers=["AAPL"], mock_mode=True)
        service._broadcast_func = None
        price_data = {"symbol": "AAPL", "price": 150.0, "open": 148.0}
        # Should not raise
        await service._broadcast_price_update(price_data)

    @pytest.mark.asyncio
    async def test_price_callback_thread_safety(self):
        """A running mock service delivers price updates to the broadcaster."""
        service = MarketService(
            tickers=["AAPL"],
            poll_interval=1,
            mock_mode=True,
        )
        received_prices = []

        async def capture_broadcast(msg):
            received_prices.append(msg)

        await service.start(capture_broadcast)
        # Wait for at least one price update
        await asyncio.sleep(1.5)
        service.stop()
        # Should have received at least one price update
        assert len(received_prices) >= 1
        assert received_prices[0]["type"] == "price_update"
class TestMarketServiceIntegration:
    """End-to-end check of ``MarketService`` running in mock mode."""

    @pytest.mark.asyncio
    async def test_full_mock_cycle(self):
        """Run the mock service briefly and validate every broadcast message."""
        service = MarketService(
            tickers=["AAPL", "MSFT"],
            poll_interval=1,
            mock_mode=True,
        )
        received = []

        async def record_message(msg):
            received.append(msg)

        await service.start(record_message)
        # Give the 1-second poller time to emit price updates.
        await asyncio.sleep(2.5)
        service.stop()

        # Multiple updates are expected within the wait window.
        assert len(received) >= 2

        # Every message must be a price update carrying the expected fields.
        expected_fields = ("symbol", "price", "ret")
        for message in received:
            assert message["type"] == "price_update"
            for field in expected_fields:
                assert field in message

        # At least one of the subscribed tickers must have been reported.
        tickers_seen = {message["symbol"] for message in received}
        assert tickers_seen & {"AAPL", "MSFT"}
if __name__ == "__main__":
    # Direct execution: hand this file to pytest with verbose output.
    pytest.main(args=[__file__, "-v"])