Add explain analysis workflow and UI

This commit is contained in:
2026-03-16 22:28:41 +08:00
parent 3a5558b576
commit 1f5ee3698e
49 changed files with 8888 additions and 1476 deletions

View File

@@ -17,6 +17,12 @@ from backend.config.bootstrap_config import (
update_bootstrap_values_for_run,
)
from backend.data.provider_utils import normalize_symbol
from backend.data.market_ingest import ingest_symbols
from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.explain.range_explainer import build_range_explanation
from backend.explain.similarity_service import find_similar_days
from backend.explain.story_service import get_or_create_stock_story
from backend.utils.msg_adapter import FrontendAdapter
from backend.utils.terminal_dashboard import get_dashboard
from backend.core.pipeline import TradingPipeline
@@ -25,6 +31,7 @@ from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.data.provider_router import get_provider_router
from backend.tools.data_tools import get_prices
from backend.tools.data_tools import get_company_news
logger = logging.getLogger(__name__)
@@ -65,6 +72,7 @@ class Gateway:
self._backtest_end_date: Optional[str] = None
self._dashboard = get_dashboard()
self._market_status_task: Optional[asyncio.Task] = None
self._watchlist_ingest_task: Optional[asyncio.Task] = None
# Session tracking for live returns
self._session_start_portfolio_value: Optional[float] = None
@@ -182,6 +190,17 @@ class Gateway:
def state(self) -> Dict[str, Any]:
return self.state_sync.state
@staticmethod
def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool:
if not rows:
return True
return all(
not row.get("sentiment")
and not row.get("relevance")
and not row.get("key_discussion")
for row in rows
)
async def handle_client(self, websocket: ServerConnection):
"""Handle WebSocket client connection"""
async with self.lock:
@@ -250,6 +269,22 @@ class Gateway:
await self._handle_get_stock_history(websocket, data)
elif msg_type == "get_stock_explain_events":
await self._handle_get_stock_explain_events(websocket, data)
elif msg_type == "get_stock_news":
await self._handle_get_stock_news(websocket, data)
elif msg_type == "get_stock_news_for_date":
await self._handle_get_stock_news_for_date(websocket, data)
elif msg_type == "get_stock_news_timeline":
await self._handle_get_stock_news_timeline(websocket, data)
elif msg_type == "get_stock_news_categories":
await self._handle_get_stock_news_categories(websocket, data)
elif msg_type == "get_stock_range_explain":
await self._handle_get_stock_range_explain(websocket, data)
elif msg_type == "get_stock_story":
await self._handle_get_stock_story(websocket, data)
elif msg_type == "get_stock_similar_days":
await self._handle_get_stock_similar_days(websocket, data)
elif msg_type == "run_stock_enrich":
await self._handle_run_stock_enrich(websocket, data)
except websockets.ConnectionClosed:
pass
@@ -298,20 +333,38 @@ class Gateway:
)
prices = await asyncio.to_thread(
get_prices,
self.storage.market_store.get_ohlc,
ticker,
start_date,
end_date,
)
usage_snapshot = self._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
source = "polygon"
if not prices:
prices = await asyncio.to_thread(
get_prices,
ticker,
start_date,
end_date,
)
usage_snapshot = self._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
if prices:
await asyncio.to_thread(
self.storage.market_store.upsert_ohlc,
ticker,
[price.model_dump() for price in prices],
source=source or "provider",
)
await websocket.send(
json.dumps(
{
"type": "stock_history_loaded",
"ticker": ticker,
"prices": [price.model_dump() for price in prices][-120:],
"prices": [
price if isinstance(price, dict) else price.model_dump()
for price in prices
][-120:],
"source": source,
"start_date": start_date,
"end_date": end_date,
@@ -342,6 +395,636 @@ class Gateway:
),
)
async def _handle_get_stock_news(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
):
    """Reply to ``get_stock_news`` with enriched news for a lookback window.

    Flow: clamp the requested window/limit, read enriched rows from the
    market store, and — only when every stored row lacks enrichment —
    fetch fresh provider news, persist it, run the enricher, and re-read.
    Responds with a ``stock_news_loaded`` message either way.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        # Invalid symbol: answer with an error payload instead of silence.
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_loaded",
                    "ticker": "",
                    "news": [],
                    "source": None,
                    "error": "invalid ticker",
                },
                ensure_ascii=False,
            ),
        )
        return
    # Clamp client-supplied knobs to sane bounds; fall back to defaults
    # when the values are not integer-coercible.
    lookback_days = data.get("lookback_days", 30)
    limit = data.get("limit", 12)
    try:
        lookback_days = max(7, min(int(lookback_days), 180))
    except (TypeError, ValueError):
        lookback_days = 30
    try:
        limit = max(1, min(int(limit), 30))
    except (TypeError, ValueError):
        limit = 12
    # Anchor the window on the simulation/current date when available.
    end_date = self.state_sync.state.get("current_date")
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    try:
        end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError:
        # Unparseable state date: fall back to "now" and normalize the string.
        end_dt = datetime.now()
        end_date = end_dt.strftime("%Y-%m-%d")
    start_date = (end_dt - timedelta(days=lookback_days)).strftime(
        "%Y-%m-%d",
    )
    news_rows = await asyncio.to_thread(
        self.storage.market_store.get_news_items_enriched,
        ticker,
        start_date=start_date,
        end_date=end_date,
        limit=limit,
    )
    source = "polygon"
    if self._news_rows_need_enrichment(news_rows):
        # Stored rows are unenriched (or absent): pull fresh provider news.
        news = await asyncio.to_thread(
            get_company_news,
            ticker,
            end_date,
            start_date,
            limit,
        )
        if news:
            # Record which provider actually served the news, then persist.
            usage_snapshot = self._provider_router.get_usage_snapshot()
            source = usage_snapshot.get("last_success", {}).get("company_news")
            await asyncio.to_thread(
                self.storage.market_store.upsert_news,
                ticker,
                [item.model_dump() for item in news],
                source=source or "provider",
            )
        # Enrich whatever is stored for the window (fetch a bit more than
        # the reply limit so the cache warms up), then re-read.
        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=max(limit, 50),
        )
        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
        )
        # Provider snapshot may report None; make the reply source explicit.
        source = source or "market_store"
    await websocket.send(
        json.dumps(
            {
                "type": "stock_news_loaded",
                "ticker": ticker,
                "news": news_rows[-limit:],
                "source": source,
                "start_date": start_date,
                "end_date": end_date,
            },
            ensure_ascii=False,
            default=str,
        ),
    )
async def _handle_get_stock_news_for_date(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
):
    """Serve enriched news for one ticker/trade-date pair.

    Reads stored enriched rows for the requested date; when every row is
    missing enrichment output, runs a single-day enrichment pass and
    re-reads before replying with ``stock_news_for_date_loaded``.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    trade_date = str(data.get("date") or "").strip()
    if not (ticker and trade_date):
        # Missing inputs: reject without touching storage.
        error_payload = {
            "type": "stock_news_for_date_loaded",
            "ticker": ticker,
            "date": trade_date,
            "news": [],
            "error": "ticker and date are required",
        }
        await websocket.send(json.dumps(error_payload, ensure_ascii=False))
        return
    raw_limit = data.get("limit", 20)
    try:
        limit = min(max(int(raw_limit), 1), 50)
    except (TypeError, ValueError):
        limit = 20

    def _read_rows():
        # Blocking store read; always executed via asyncio.to_thread.
        return self.storage.market_store.get_news_items_enriched(
            ticker,
            trade_date=trade_date,
            limit=limit,
        )

    news_rows = await asyncio.to_thread(_read_rows)
    if self._news_rows_need_enrichment(news_rows):
        # One-day enrichment pass, then pick up the now-enriched rows.
        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            start_date=trade_date,
            end_date=trade_date,
            limit=limit,
        )
        news_rows = await asyncio.to_thread(_read_rows)
    await websocket.send(
        json.dumps(
            {
                "type": "stock_news_for_date_loaded",
                "ticker": ticker,
                "date": trade_date,
                "news": news_rows,
                "source": "market_store",
            },
            ensure_ascii=False,
            default=str,
        ),
    )
async def _handle_get_stock_news_timeline(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
):
    """Serve an enriched news timeline over a lookback window.

    Clamps ``lookback_days`` to [7, 365], anchors the window on the
    current state date (falling back to today), and lazily triggers an
    enrichment pass when no timeline rows exist yet. Replies with a
    ``stock_news_timeline_loaded`` message.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        error_payload = {
            "type": "stock_news_timeline_loaded",
            "ticker": "",
            "timeline": [],
            "error": "invalid ticker",
        }
        await websocket.send(json.dumps(error_payload, ensure_ascii=False))
        return
    raw_lookback = data.get("lookback_days", 90)
    try:
        lookback_days = min(max(int(raw_lookback), 7), 365)
    except (TypeError, ValueError):
        lookback_days = 90
    end_date = self.state_sync.state.get("current_date")
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    try:
        end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError:
        # Unparseable state date: normalize to "now".
        end_dt = datetime.now()
        end_date = end_dt.strftime("%Y-%m-%d")
    start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")

    def _read_timeline():
        # Blocking store read; always executed via asyncio.to_thread.
        return self.storage.market_store.get_news_timeline_enriched(
            ticker,
            start_date=start_date,
            end_date=end_date,
        )

    timeline = await asyncio.to_thread(_read_timeline)
    if not timeline:
        # Nothing cached yet: enrich the window, then re-read.
        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        )
        timeline = await asyncio.to_thread(_read_timeline)
    await websocket.send(
        json.dumps(
            {
                "type": "stock_news_timeline_loaded",
                "ticker": ticker,
                "timeline": timeline,
                "start_date": start_date,
                "end_date": end_date,
            },
            ensure_ascii=False,
            default=str,
        ),
    )
async def _handle_get_stock_news_categories(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
):
    """Serve category aggregates of enriched news over a lookback window.

    Checks whether the stored rows for the window are enriched; if not,
    runs an enrichment pass first, then queries the category aggregation
    and replies with ``stock_news_categories_loaded``.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        error_payload = {
            "type": "stock_news_categories_loaded",
            "ticker": "",
            "categories": {},
            "error": "invalid ticker",
        }
        await websocket.send(json.dumps(error_payload, ensure_ascii=False))
        return
    raw_lookback = data.get("lookback_days", 90)
    try:
        lookback_days = min(max(int(raw_lookback), 7), 365)
    except (TypeError, ValueError):
        lookback_days = 90
    end_date = self.state_sync.state.get("current_date")
    if not end_date:
        end_date = datetime.now().strftime("%Y-%m-%d")
    try:
        end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError:
        # Unparseable state date: normalize to "now".
        end_dt = datetime.now()
        end_date = end_dt.strftime("%Y-%m-%d")
    start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
    news_rows = await asyncio.to_thread(
        self.storage.market_store.get_news_items_enriched,
        ticker,
        start_date=start_date,
        end_date=end_date,
        limit=200,
    )
    if self._news_rows_need_enrichment(news_rows):
        # Warm the enrichment cache before aggregating categories.
        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        )
    categories = await asyncio.to_thread(
        self.storage.market_store.get_news_categories_enriched,
        ticker,
        start_date=start_date,
        end_date=end_date,
        limit=200,
    )
    await websocket.send(
        json.dumps(
            {
                "type": "stock_news_categories_loaded",
                "ticker": ticker,
                "categories": categories,
                "start_date": start_date,
                "end_date": end_date,
            },
            ensure_ascii=False,
            default=str,
        ),
    )
async def _handle_get_stock_range_explain(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
):
    """Build and return an explanation for a ticker's price range.

    Selects the news evidence either by explicit ``article_ids`` (when the
    client pins specific articles) or by the date window, lazily enriching
    when the selected rows are missing enrichment output, then delegates
    to ``build_range_explanation`` and replies with
    ``stock_range_explain_loaded``.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    start_date = str(data.get("start_date") or "").strip()
    end_date = str(data.get("end_date") or "").strip()
    if not ticker or not start_date or not end_date:
        # All three inputs are mandatory; surface the error in `result`.
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_range_explain_loaded",
                    "ticker": ticker,
                    "result": {"error": "ticker, start_date, end_date are required"},
                },
                ensure_ascii=False,
            ),
        )
        return
    article_ids = data.get("article_ids")
    if isinstance(article_ids, list) and article_ids:
        # Client pinned specific articles: resolve exactly those IDs.
        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_by_ids_enriched,
            ticker,
            article_ids,
        )
        if self._news_rows_need_enrichment(news_rows):
            # Enrich over the window (IDs alone can't drive the enricher),
            # then re-resolve the pinned IDs.
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=100,
            )
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_by_ids_enriched,
                ticker,
                article_ids,
            )
    else:
        # No pinned IDs: use everything enriched inside the window.
        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=100,
        )
        if not news_rows:
            # Nothing stored yet: enrich the window, then re-read.
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=100,
            )
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=100,
            )
    result = await asyncio.to_thread(
        build_range_explanation,
        ticker=ticker,
        start_date=start_date,
        end_date=end_date,
        news_rows=news_rows,
    )
    await websocket.send(
        json.dumps(
            {
                "type": "stock_range_explain_loaded",
                "ticker": ticker,
                "result": result,
            },
            ensure_ascii=False,
            default=str,
        ),
    )
async def _handle_get_stock_story(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
):
    """Serve (and lazily build) the narrative "story" for a ticker.

    Enriches recent news up to the as-of date, then fetches or creates the
    cached story and replies with ``stock_story_loaded``.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        error_payload = {
            "type": "stock_story_loaded",
            "ticker": "",
            "story": "",
            "error": "invalid ticker",
        }
        await websocket.send(json.dumps(error_payload, ensure_ascii=False))
        return
    # Prefer the client-supplied date, then the state date, then today;
    # normalize to a bare YYYY-MM-DD prefix.
    raw_date = (
        data.get("as_of_date")
        or self.state_sync.state.get("current_date")
        or datetime.now().strftime("%Y-%m-%d")
    )
    as_of_date = str(raw_date).strip()[:10]
    # Make sure recent news is enriched before the story is generated.
    await asyncio.to_thread(
        enrich_news_for_symbol,
        self.storage.market_store,
        ticker,
        end_date=as_of_date,
        limit=80,
    )
    result = await asyncio.to_thread(
        get_or_create_stock_story,
        self.storage.market_store,
        symbol=ticker,
        as_of_date=as_of_date,
    )
    response = {
        "type": "stock_story_loaded",
        "ticker": ticker,
        "as_of_date": as_of_date,
        "story": result.get("story") or "",
        "source": result.get("source") or "local",
    }
    await websocket.send(json.dumps(response, ensure_ascii=False, default=str))
async def _handle_get_stock_similar_days(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
):
    """Find historical days similar to a target date for a ticker.

    Clamps ``top_k`` to [1, 20], enriches news up to the target date, then
    runs the similarity search and replies with
    ``stock_similar_days_loaded`` (the search result dict is merged into
    the payload).
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    target_date = str(data.get("date") or "").strip()[:10]
    if not (ticker and target_date):
        error_payload = {
            "type": "stock_similar_days_loaded",
            "ticker": ticker,
            "date": target_date,
            "items": [],
            "error": "ticker and date are required",
        }
        await websocket.send(json.dumps(error_payload, ensure_ascii=False))
        return
    raw_top_k = data.get("top_k", 8)
    try:
        top_k = min(max(int(raw_top_k), 1), 20)
    except (TypeError, ValueError):
        top_k = 8
    # Similarity scoring relies on enrichment; warm it up first.
    await asyncio.to_thread(
        enrich_news_for_symbol,
        self.storage.market_store,
        ticker,
        end_date=target_date,
        limit=200,
    )
    result = await asyncio.to_thread(
        find_similar_days,
        self.storage.market_store,
        symbol=ticker,
        target_date=target_date,
        top_k=top_k,
    )
    payload = {
        "type": "stock_similar_days_loaded",
        "ticker": ticker,
        "date": target_date,
    }
    payload.update(result)
    await websocket.send(json.dumps(payload, ensure_ascii=False, default=str))
async def _handle_run_stock_enrich(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
):
    """Run an on-demand enrichment pipeline for a ticker and date range.

    Optionally rebuilds the cached story (``rebuild_story`` + ``story_date``)
    and the similar-days cache (``rebuild_similar_days`` + ``target_date``)
    after enriching. ``force`` re-enriches rows that already have output;
    ``only_local_to_llm`` re-analyzes locally-enriched rows with the LLM
    and requires LLM enrichment to be enabled. Replies with a single
    ``stock_enrich_completed`` message.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    # Dates are normalized to a bare YYYY-MM-DD prefix.
    start_date = str(data.get("start_date") or "").strip()[:10]
    end_date = str(data.get("end_date") or "").strip()[:10]
    # Story rebuild defaults to the range end when no explicit date given.
    story_date = str(data.get("story_date") or end_date or "").strip()[:10]
    target_date = str(data.get("target_date") or "").strip()[:10]
    force = bool(data.get("force", False))
    rebuild_story = bool(data.get("rebuild_story", True))
    rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
    only_local_to_llm = bool(data.get("only_local_to_llm", False))
    limit = data.get("limit", 200)
    try:
        # Clamp to [10, 500]; fall back to the default on bad input.
        limit = max(10, min(int(limit), 500))
    except (TypeError, ValueError):
        limit = 200
    if not ticker or not start_date or not end_date:
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_enrich_completed",
                    "ticker": ticker,
                    "start_date": start_date,
                    "end_date": end_date,
                    "error": "ticker, start_date, end_date are required",
                },
                ensure_ascii=False,
            ),
        )
        return
    if only_local_to_llm and not llm_enrichment_enabled():
        # LLM re-analysis was requested but no LLM path is configured.
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_enrich_completed",
                    "ticker": ticker,
                    "start_date": start_date,
                    "end_date": end_date,
                    "error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
                },
                ensure_ascii=False,
            ),
        )
        return
    # Main enrichment pass (blocking; run off the event loop).
    result = await asyncio.to_thread(
        enrich_news_for_symbol,
        self.storage.market_store,
        ticker,
        start_date=start_date,
        end_date=end_date,
        limit=limit,
        skip_existing=not force,
        only_reanalyze_local=only_local_to_llm,
    )
    story_status = None
    if rebuild_story and story_date:
        # Drop the cached story so it is regenerated from fresh enrichment.
        await asyncio.to_thread(
            self.storage.market_store.delete_story_cache,
            ticker,
            as_of_date=story_date,
        )
        story_result = await asyncio.to_thread(
            get_or_create_stock_story,
            self.storage.market_store,
            symbol=ticker,
            as_of_date=story_date,
        )
        story_status = {
            "as_of_date": story_date,
            "source": story_result.get("source") or "local",
        }
    similar_status = None
    if rebuild_similar_days and target_date:
        # Same cache-bust-then-rebuild dance for the similar-days index.
        await asyncio.to_thread(
            self.storage.market_store.delete_similar_day_cache,
            ticker,
            target_date=target_date,
        )
        similar_result = await asyncio.to_thread(
            find_similar_days,
            self.storage.market_store,
            symbol=ticker,
            target_date=target_date,
            top_k=8,
        )
        similar_status = {
            "target_date": target_date,
            "count": len(similar_result.get("items") or []),
            "error": similar_result.get("error"),
        }
    await websocket.send(
        json.dumps(
            {
                "type": "stock_enrich_completed",
                "ticker": ticker,
                "start_date": start_date,
                "end_date": end_date,
                "story_date": story_date or None,
                "target_date": target_date or None,
                "force": force,
                "only_local_to_llm": only_local_to_llm,
                "stats": result,
                "story_status": story_status,
                "similar_status": similar_status,
            },
            ensure_ascii=False,
            default=str,
        ),
    )
async def _handle_start_backtest(self, data: Dict[str, Any]):
if not self.is_backtest:
return
@@ -410,6 +1093,7 @@ class Gateway:
},
)
await self._handle_reload_runtime_assets()
self._schedule_watchlist_market_store_refresh(tickers)
@staticmethod
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
@@ -538,6 +1222,48 @@ class Gateway:
trades=trades,
)
def _schedule_watchlist_market_store_refresh(
    self,
    tickers: List[str],
) -> None:
    """Launch a background market-store refresh for the new watchlist.

    Cancels any still-running refresh task before scheduling a new one so
    only one watchlist ingest runs at a time; a no-op for an empty list.
    """
    if not tickers:
        return
    pending = self._watchlist_ingest_task
    if pending is not None and not pending.done():
        # Supersede the in-flight refresh with the latest watchlist.
        pending.cancel()
    self._watchlist_ingest_task = asyncio.create_task(
        self._refresh_market_store_for_watchlist(tickers),
    )
async def _refresh_market_store_for_watchlist(
    self,
    tickers: List[str],
) -> None:
    """Incrementally re-ingest market data for the watchlist tickers.

    Emits system messages before and after the ingest. Cancellation is
    re-raised so task bookkeeping works; any other failure is logged and
    surfaced as a system message instead of propagating.
    """
    notify = self.state_sync.on_system_message
    try:
        await notify(
            f"正在同步自选股市场数据: {', '.join(tickers)}",
        )
        results = await asyncio.to_thread(
            ingest_symbols,
            tickers,
            mode="incremental",
        )
        parts = [
            f"{item['symbol']} prices={item['prices']} news={item['news']}"
            for item in results
        ]
        summary = ", ".join(parts)
        await notify(
            f"自选股市场数据已同步: {summary}",
        )
    except asyncio.CancelledError:
        raise
    except Exception as exc:
        logger.warning("Watchlist market store refresh failed: %s", exc)
        await notify(
            f"自选股市场数据同步失败: {exc}",
        )
async def broadcast(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients"""
if not self.connected_clients:
@@ -896,4 +1622,6 @@ class Gateway:
self._backtest_task.cancel()
if self._market_status_task:
self._market_status_task.cancel()
if self._watchlist_ingest_task:
self._watchlist_ingest_task.cancel()
self._dashboard.stop()