Files
evotraders/backend/tests/test_gateway_explain_handlers.py

385 lines
12 KiB
Python

# -*- coding: utf-8 -*-
import json
from types import SimpleNamespace
import pytest
from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module
class DummyWebSocket:
    """In-memory websocket double that records every payload passed to ``send``.

    Payloads are decoded from JSON before being stored, so tests can assert on
    plain Python objects.
    """

    def __init__(self):
        # Decoded payloads, in the order ``send`` received them.
        self.messages = []

    async def send(self, payload: str):
        """Decode *payload* as JSON and append the result to ``messages``."""
        decoded = json.loads(payload)
        self.messages.append(decoded)
class DummyStateSync:
    """Minimal state-sync double.

    It exposes a fixed ``current_date`` in ``state`` and records every system
    message it is sent. Broadcast wiring and state updates are accepted and
    ignored.
    """

    def __init__(self, current_date="2026-03-16"):
        # Messages delivered through ``on_system_message``, in arrival order.
        self.system_messages = []
        self.state = dict(current_date=current_date)

    def set_broadcast_fn(self, _fn):
        """Accept a broadcast callback and discard it."""
        return None

    def update_state(self, *_args, **_kwargs):
        """Accept any state update and ignore it."""
        return None

    async def on_system_message(self, message):
        """Record *message* in ``system_messages``."""
        self.system_messages.extend([message])
class FakeMarketStore:
    """Recording fake for the gateway's market store.

    Every method appends a tuple ``(method_name, *arguments)`` to ``calls`` so
    tests can assert exactly how the gateway queried the store. Methods return
    small canned payloads. Cache getters always miss (return ``None``), and
    cache deletes report one removed row.
    """

    def __init__(self):
        # Ordered log of (method_name, *args) tuples, one per store call.
        self.calls = []

    def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
        self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
        return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}]

    def get_news_items(self, symbol, *, start_date=None, end_date=None, limit=100):
        self.calls.append(("get_news_items", symbol, start_date, end_date, limit))
        return [
            {
                "id": "news-1",
                "ticker": symbol,
                "date": end_date,
                "trade_date": end_date,
                "title": "Title",
                "summary": "Summary",
                "source": "polygon",
            }
        ]

    def get_news_items_enriched(self, symbol, *, start_date=None, end_date=None, trade_date=None, limit=100):
        self.calls.append(("get_news_items_enriched", symbol, start_date, end_date, trade_date, limit))
        # A trade-date lookup takes precedence over the range end for the row dates.
        target_date = trade_date or end_date
        return [
            {
                "id": "news-1",
                "ticker": symbol,
                "date": target_date,
                "trade_date": target_date,
                "title": "Title",
                "summary": "Summary",
                "source": "polygon",
                "sentiment": "negative",
                "relevance": "high",
                "key_discussion": "Key discussion",
            }
        ]

    def get_news_by_ids_enriched(self, symbol, article_ids):
        """Return one enriched row per requested id (empty list for no ids).

        ``article_ids`` may be any iterable. It is materialised once, so the
        recorded call and the returned rows agree even for one-shot iterators.
        """
        ids = list(article_ids)
        self.calls.append(("get_news_by_ids_enriched", symbol, ids))
        return [
            {"id": article_id, "ticker": symbol, "date": "2026-03-16", "sentiment": "negative"}
            for article_id in ids
        ]

    def get_news_categories_enriched(self, symbol, *, start_date=None, end_date=None, limit=200):
        self.calls.append(("get_news_categories_enriched", symbol, start_date, end_date, limit))
        return {"macro": {"label": "宏观", "count": 1, "article_ids": ["news-1"], "positive_ids": [], "negative_ids": ["news-1"], "neutral_ids": []}}

    def get_story_cache(self, symbol, *, as_of_date):
        self.calls.append(("get_story_cache", symbol, as_of_date))
        # Always a cache miss.
        return None

    def upsert_story_cache(self, symbol, *, as_of_date, content, source="local"):
        # ``content`` is intentionally not recorded; only the call shape matters.
        self.calls.append(("upsert_story_cache", symbol, as_of_date, source))

    def delete_story_cache(self, symbol, *, as_of_date=None):
        self.calls.append(("delete_story_cache", symbol, as_of_date))
        return 1

    def get_similar_day_cache(self, symbol, *, target_date):
        self.calls.append(("get_similar_day_cache", symbol, target_date))
        # Always a cache miss.
        return None

    def upsert_similar_day_cache(self, symbol, *, target_date, payload, source="local"):
        # ``payload`` is intentionally not recorded; only the call shape matters.
        self.calls.append(("upsert_similar_day_cache", symbol, target_date, source))

    def delete_similar_day_cache(self, symbol, *, target_date=None):
        self.calls.append(("delete_similar_day_cache", symbol, target_date))
        return 1

    def get_ohlc(self, symbol, start_date, end_date):
        self.calls.append(("get_ohlc", symbol, start_date, end_date))
        # Two bars: one dated start_date and one dated end_date.
        return [
            {"date": start_date, "open": 100, "high": 105, "low": 99, "close": 103},
            {"date": end_date, "open": 103, "high": 108, "low": 102, "close": 107},
        ]
def make_gateway(market_store=None):
    """Build a ``Gateway`` wired to lightweight test doubles.

    Args:
        market_store: Store exposed as ``storage_service.market_store``. When
            omitted (``None``), a fresh ``FakeMarketStore`` is created.

    Returns:
        A ``Gateway`` configured with ``{"mode": "live"}``. Its state sync is a
        ``DummyStateSync`` with the default current date.
    """
    if market_store is None:
        # Use an explicit ``is None`` check rather than ``or``, so a
        # caller-supplied store is never swapped out just because it happens
        # to be falsy (e.g. it defines ``__len__``).
        market_store = FakeMarketStore()
    storage = SimpleNamespace(market_store=market_store)
    pipeline = SimpleNamespace(state_sync=None)
    market_service = SimpleNamespace()
    state_sync = DummyStateSync()
    return Gateway(
        market_service=market_service,
        storage_service=storage,
        pipeline=pipeline,
        state_sync=state_sync,
        config={"mode": "live"},
    )
@pytest.mark.asyncio
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
    """Timeline requests pass the ticker to the store as its ``symbol`` argument.

    The asserted window (2026-02-14 .. 2026-03-16) is a 30-day lookback ending
    on the ``DummyStateSync`` default date.
    """
    store = FakeMarketStore()
    gw = make_gateway(store)
    ws = DummyWebSocket()

    request = {"ticker": "AAPL", "lookback_days": 30}
    await gw._handle_get_stock_news_timeline(ws, request)

    expected_calls = [("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")]
    assert store.calls == expected_calls

    reply = ws.messages[-1]
    assert reply["type"] == "stock_news_timeline_loaded"
    assert reply["ticker"] == "AAPL"
@pytest.mark.asyncio
async def test_handle_get_stock_news_categories_uses_market_store_symbol_argument():
    """Category requests query enriched items and then categories for the ticker.

    The asserted window (2026-02-14 .. 2026-03-16) is a 30-day lookback ending
    on the ``DummyStateSync`` default date. Both store calls use a limit of 200.
    """
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()
    await gateway._handle_get_stock_news_categories(
        websocket,
        {"ticker": "AAPL", "lookback_days": 30},
    )
    assert market_store.calls == [
        ("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
        ("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200),
    ]
    assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
    # FakeMarketStore reports exactly one article in the "macro" category.
    assert websocket.messages[-1]["categories"]["macro"]["count"] == 1
@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch):
    """Range explanations are built from the enriched news rows of the requested window."""
    store = FakeMarketStore()
    gw = make_gateway(store)
    ws = DummyWebSocket()

    def _stub_explanation(*, ticker, start_date, end_date, news_rows):
        # Echo the inputs so the assertion below sees exactly what the handler passed.
        return dict(
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            news_count=len(news_rows),
        )

    monkeypatch.setattr(gateway_module, "build_range_explanation", _stub_explanation)

    window = {"ticker": "AAPL", "start_date": "2026-03-10", "end_date": "2026-03-16"}
    await gw._handle_get_stock_range_explain(ws, window)

    assert store.calls == [
        ("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100),
    ]

    expected_result = {
        "ticker": "AAPL",
        "start_date": "2026-03-10",
        "end_date": "2026-03-16",
        "news_count": 1,
    }
    assert ws.messages[-1] == {
        "type": "stock_range_explain_loaded",
        "ticker": "AAPL",
        "result": expected_result,
    }
@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch):
    """When ``article_ids`` are supplied, news rows are fetched by id, not by date range."""
    store = FakeMarketStore()
    gw = make_gateway(store)
    ws = DummyWebSocket()

    def _count_rows(**kwargs):
        return {"news_count": len(kwargs["news_rows"])}

    monkeypatch.setattr(gateway_module, "build_range_explanation", _count_rows)

    request = {
        "ticker": "AAPL",
        "start_date": "2026-03-10",
        "end_date": "2026-03-16",
        "article_ids": ["news-99"],
    }
    await gw._handle_get_stock_range_explain(ws, request)

    assert store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-99"])]
    result = ws.messages[-1]["result"]
    assert result["news_count"] == 1
@pytest.mark.asyncio
async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
    """Single-day news lookups query by ``trade_date`` and leave the range bounds unset."""
    store = FakeMarketStore()
    gw = make_gateway(store)
    ws = DummyWebSocket()

    request = {"ticker": "AAPL", "date": "2026-03-16", "limit": 10}
    await gw._handle_get_stock_news_for_date(ws, request)

    expected = ("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
    assert store.calls == [expected]

    reply = ws.messages[-1]
    assert reply["type"] == "stock_news_for_date_loaded"
    assert reply["date"] == "2026-03-16"
@pytest.mark.asyncio
async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
    """A story request replies with ``stock_story_loaded`` and story text mentioning the ticker."""
    store = FakeMarketStore()
    gw = make_gateway(store)
    ws = DummyWebSocket()

    def _fake_enrich(*args, **kwargs):
        # Stub out enrichment so the test does not depend on the real pipeline.
        return {"symbol": "AAPL", "analyzed": 3}

    monkeypatch.setattr(gateway_module, "enrich_news_for_symbol", _fake_enrich)

    request = {"ticker": "AAPL", "as_of_date": "2026-03-16"}
    await gw._handle_get_stock_story(ws, request)

    reply = ws.messages[-1]
    assert reply["type"] == "stock_story_loaded"
    assert reply["ticker"] == "AAPL"
    assert "AAPL Story" in reply["story"]
@pytest.mark.asyncio
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
    """A similar-days request replies with ``stock_similar_days_loaded`` and a list of items."""
    store = FakeMarketStore()
    gw = make_gateway(store)
    ws = DummyWebSocket()

    def _fake_enrich(*args, **kwargs):
        # Stub out enrichment so the test does not depend on the real pipeline.
        return {"symbol": "AAPL", "analyzed": 3}

    monkeypatch.setattr(gateway_module, "enrich_news_for_symbol", _fake_enrich)

    request = {"ticker": "AAPL", "date": "2026-03-16", "top_k": 5}
    await gw._handle_get_stock_similar_days(ws, request)

    reply = ws.messages[-1]
    assert reply["type"] == "stock_similar_days_loaded"
    assert reply["ticker"] == "AAPL"
    assert isinstance(reply["items"], list)
@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
    """A forced enrich with rebuild flags clears the story and similar-day caches."""
    store = FakeMarketStore()
    gw = make_gateway(store)
    ws = DummyWebSocket()

    def _fake_enrich(*args, **kwargs):
        return {"symbol": "AAPL", "analyzed": 2, "queued_count": 2}

    monkeypatch.setattr(gateway_module, "enrich_news_for_symbol", _fake_enrich)

    request = {
        "ticker": "AAPL",
        "start_date": "2026-03-10",
        "end_date": "2026-03-16",
        "force": True,
        "rebuild_story": True,
        "rebuild_similar_days": True,
        "story_date": "2026-03-16",
        "target_date": "2026-03-16",
    }
    await gw._handle_run_stock_enrich(ws, request)

    for expected_call in (
        ("delete_story_cache", "AAPL", "2026-03-16"),
        ("delete_similar_day_cache", "AAPL", "2026-03-16"),
    ):
        assert expected_call in store.calls

    reply = ws.messages[-1]
    assert reply["type"] == "stock_enrich_completed"
    assert reply["stats"]["analyzed"] == 2
@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypatch):
    """With LLM enrichment disabled, ``only_local_to_llm`` produces an error in the reply."""
    gw = make_gateway(FakeMarketStore())
    ws = DummyWebSocket()

    monkeypatch.setattr(gateway_module, "llm_enrichment_enabled", lambda: False)

    request = {
        "ticker": "AAPL",
        "start_date": "2026-03-10",
        "end_date": "2026-03-16",
        "only_local_to_llm": True,
    }
    await gw._handle_run_stock_enrich(ws, request)

    reply = ws.messages[-1]
    assert reply["type"] == "stock_enrich_completed"
    assert "requires EXPLAIN_ENRICH_USE_LLM=true" in reply["error"]
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
    """Scheduling a watchlist refresh hands ``_refresh_market_store_for_watchlist`` to asyncio."""
    gateway = make_gateway()
    captured = {}

    class DummyTask:
        """Stand-in for ``asyncio.Task`` that always reports itself as still running."""

        def done(self):
            return False

        def cancel(self):
            captured["cancelled"] = True

    def fake_create_task(coro, **_kwargs):
        # ``asyncio.create_task`` also accepts keyword-only ``name``/``context``.
        # Accept and ignore them so the stub keeps working however the gateway calls it.
        captured["coro_name"] = coro.cr_code.co_name
        # Close the coroutine so it is not reported as "never awaited".
        coro.close()
        return DummyTask()

    monkeypatch.setattr(gateway_module.asyncio, "create_task", fake_create_task)
    gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])
    assert captured["coro_name"] == "_refresh_market_store_for_watchlist"
@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
    """A watchlist refresh emits a "syncing" message, then a per-symbol summary message."""
    gw = make_gateway()

    def _fake_ingest(symbols, mode="incremental"):
        # Report fixed counts for every requested symbol.
        return [{"symbol": sym, "prices": 3, "news": 4, "aligned": 4} for sym in symbols]

    monkeypatch.setattr(gateway_module, "ingest_symbols", _fake_ingest)

    await gw._refresh_market_store_for_watchlist(["AAPL", "MSFT"])

    messages = gw.state_sync.system_messages
    # First message announces the sync; the second summarises the results.
    assert messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
    summary = messages[1]
    assert "自选股市场数据已同步:" in summary
    assert "AAPL prices=3 news=4" in summary