Add explain analysis workflow and UI
@@ -17,6 +17,12 @@ from backend.config.bootstrap_config import (
    update_bootstrap_values_for_run,
)
from backend.data.provider_utils import normalize_symbol
from backend.data.market_ingest import ingest_symbols
from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.explain.range_explainer import build_range_explanation
from backend.explain.similarity_service import find_similar_days
from backend.explain.story_service import get_or_create_stock_story
from backend.utils.msg_adapter import FrontendAdapter
from backend.utils.terminal_dashboard import get_dashboard
from backend.core.pipeline import TradingPipeline
@@ -25,6 +31,7 @@ from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.data.provider_router import get_provider_router
from backend.tools.data_tools import get_prices
from backend.tools.data_tools import get_company_news

logger = logging.getLogger(__name__)

@@ -65,6 +72,7 @@ class Gateway:
        self._backtest_end_date: Optional[str] = None
        self._dashboard = get_dashboard()
        self._market_status_task: Optional[asyncio.Task] = None
        self._watchlist_ingest_task: Optional[asyncio.Task] = None

        # Session tracking for live returns
        self._session_start_portfolio_value: Optional[float] = None
@@ -182,6 +190,17 @@ class Gateway:
    def state(self) -> Dict[str, Any]:
        return self.state_sync.state

    @staticmethod
    def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool:
        if not rows:
            return True
        return all(
            not row.get("sentiment")
            and not row.get("relevance")
            and not row.get("key_discussion")
            for row in rows
        )
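
Note: the predicate above reports "needs enrichment" only when the batch is empty or when every row is missing all three enrichment fields, so a single partially enriched row suppresses re-enrichment for the whole batch. A quick illustration of the resulting truth table (a standalone sketch; it assumes the Gateway class above is importable):

rows_empty = []
rows_plain = [{"title": "t1"}, {"title": "t2"}]
rows_mixed = [{"title": "t1"}, {"title": "t2", "sentiment": "positive"}]

assert Gateway._news_rows_need_enrichment(rows_empty) is True   # nothing stored yet
assert Gateway._news_rows_need_enrichment(rows_plain) is True   # stored but bare
assert Gateway._news_rows_need_enrichment(rows_mixed) is False  # one enriched row is enough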

    async def handle_client(self, websocket: ServerConnection):
        """Handle WebSocket client connection"""
        async with self.lock:
@@ -250,6 +269,22 @@ class Gateway:
                await self._handle_get_stock_history(websocket, data)
            elif msg_type == "get_stock_explain_events":
                await self._handle_get_stock_explain_events(websocket, data)
            elif msg_type == "get_stock_news":
                await self._handle_get_stock_news(websocket, data)
            elif msg_type == "get_stock_news_for_date":
                await self._handle_get_stock_news_for_date(websocket, data)
            elif msg_type == "get_stock_news_timeline":
                await self._handle_get_stock_news_timeline(websocket, data)
            elif msg_type == "get_stock_news_categories":
                await self._handle_get_stock_news_categories(websocket, data)
            elif msg_type == "get_stock_range_explain":
                await self._handle_get_stock_range_explain(websocket, data)
            elif msg_type == "get_stock_story":
                await self._handle_get_stock_story(websocket, data)
            elif msg_type == "get_stock_similar_days":
                await self._handle_get_stock_similar_days(websocket, data)
            elif msg_type == "run_stock_enrich":
                await self._handle_run_stock_enrich(websocket, data)

        except websockets.ConnectionClosed:
            pass
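
Note: each of the new "get_stock_*" messages is answered by a single typed reply ("stock_news_loaded", "stock_story_loaded", and so on). A minimal client-side round trip might look like this (a sketch: the URL and port are hypothetical, and a real client should be prepared to skip unrelated broadcast messages before the typed reply arrives):

import asyncio
import json

import websockets  # same library the gateway uses server-side

async def demo(url: str = "ws://localhost:8765"):  # hypothetical endpoint
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps({"type": "get_stock_story", "ticker": "AAPL"}))
        while True:
            reply = json.loads(await ws.recv())
            if reply.get("type") == "stock_story_loaded":  # skip broadcasts
                print(reply.get("story", "")[:200])
                break

asyncio.run(demo())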
@@ -298,20 +333,38 @@ class Gateway:
            )

        prices = await asyncio.to_thread(
            self.storage.market_store.get_ohlc,
            ticker,
            start_date,
            end_date,
        )
        source = "polygon"
        if not prices:
            prices = await asyncio.to_thread(
                get_prices,
                ticker,
                start_date,
                end_date,
            )
            usage_snapshot = self._provider_router.get_usage_snapshot()
            source = usage_snapshot.get("last_success", {}).get("prices")
            if prices:
                await asyncio.to_thread(
                    self.storage.market_store.upsert_ohlc,
                    ticker,
                    [price.model_dump() for price in prices],
                    source=source or "provider",
                )

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_history_loaded",
                    "ticker": ticker,
                    "prices": [
                        price if isinstance(price, dict) else price.model_dump()
                        for price in prices
                    ][-120:],
                    "source": source,
                    "start_date": start_date,
                    "end_date": end_date,
@@ -342,6 +395,636 @@ class Gateway:
                },
            ),
        )
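
Note: the history handler now follows a read-through-cache shape: serve OHLC bars from the local market store when possible, and only on a miss fetch from the provider, record which provider answered, and persist the bars for next time. The pattern in isolation (a sketch; read_local, fetch_remote, and persist are hypothetical stand-ins for the market-store and provider calls above):

import asyncio
from typing import Any, Callable, List, Tuple

async def read_through(
    read_local: Callable[[], List[Any]],
    fetch_remote: Callable[[], List[Any]],
    persist: Callable[[List[Any]], None],
) -> Tuple[List[Any], str]:
    rows = await asyncio.to_thread(read_local)
    if rows:
        return rows, "local"  # cache hit: no provider call at all
    rows = await asyncio.to_thread(fetch_remote)
    if rows:
        await asyncio.to_thread(persist, rows)  # warm the store for next time
    return rows, "provider"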

    async def _handle_get_stock_news(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_loaded",
                        "ticker": "",
                        "news": [],
                        "source": None,
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        lookback_days = data.get("lookback_days", 30)
        limit = data.get("limit", 12)
        try:
            lookback_days = max(7, min(int(lookback_days), 180))
        except (TypeError, ValueError):
            lookback_days = 30
        try:
            limit = max(1, min(int(limit), 30))
        except (TypeError, ValueError):
            limit = 12

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )

        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
        )
        source = "polygon"
        if self._news_rows_need_enrichment(news_rows):
            news = await asyncio.to_thread(
                get_company_news,
                ticker,
                end_date,
                start_date,
                limit,
            )
            if news:
                usage_snapshot = self._provider_router.get_usage_snapshot()
                source = usage_snapshot.get("last_success", {}).get("company_news")
                await asyncio.to_thread(
                    self.storage.market_store.upsert_news,
                    ticker,
                    [item.model_dump() for item in news],
                    source=source or "provider",
                )
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=max(limit, 50),
            )
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=limit,
            )
            source = source or "market_store"

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_loaded",
                    "ticker": ticker,
                    "news": news_rows[-limit:],
                    "source": source,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
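
Note: the handlers repeat the same defensive coercion for client-supplied integers (int(), clamp into a range, fall back on garbage). If this keeps spreading, a small helper would remove the duplicated try/except blocks; _clamp_int below is a hypothetical extraction, not part of the commit:

from typing import Any

def _clamp_int(value: Any, lo: int, hi: int, default: int) -> int:
    """Coerce a client-supplied value into [lo, hi]; fall back on bad input."""
    try:
        return max(lo, min(int(value), hi))
    except (TypeError, ValueError):
        return default

# usage, mirroring the handler above:
# lookback_days = _clamp_int(data.get("lookback_days", 30), 7, 180, 30)
# limit = _clamp_int(data.get("limit", 12), 1, 30, 12)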

    async def _handle_get_stock_news_for_date(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        trade_date = str(data.get("date") or "").strip()
        if not ticker or not trade_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_for_date_loaded",
                        "ticker": ticker,
                        "date": trade_date,
                        "news": [],
                        "error": "ticker and date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        limit = data.get("limit", 20)
        try:
            limit = max(1, min(int(limit), 50))
        except (TypeError, ValueError):
            limit = 20

        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            trade_date=trade_date,
            limit=limit,
        )
        if self._news_rows_need_enrichment(news_rows):
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=trade_date,
                end_date=trade_date,
                limit=limit,
            )
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                trade_date=trade_date,
                limit=limit,
            )

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_for_date_loaded",
                    "ticker": ticker,
                    "date": trade_date,
                    "news": news_rows,
                    "source": "market_store",
                },
                ensure_ascii=False,
                default=str,
            ),
        )
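
Note: this handler and the ones around it share a read / enrich / re-read shape: serve enriched rows from the store, and when they still look bare, run enrichment in a worker thread and read back the upgraded rows. The shape in isolation (a sketch of a hypothetical helper; today the handlers inline it):

import asyncio
from typing import Any, Callable, Dict, List

async def read_enriched(
    read_rows: Callable[..., List[Dict[str, Any]]],
    enrich: Callable[..., Any],
    needs_enrichment: Callable[[List[Dict[str, Any]]], bool],
    **query: Any,
) -> List[Dict[str, Any]]:
    rows = await asyncio.to_thread(read_rows, **query)
    if not needs_enrichment(rows):
        return rows  # already enriched: no analysis pass needed
    await asyncio.to_thread(enrich, **query)  # enrich off the event loop
    return await asyncio.to_thread(read_rows, **query)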

    async def _handle_get_stock_news_timeline(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_timeline_loaded",
                        "ticker": "",
                        "timeline": [],
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        lookback_days = data.get("lookback_days", 90)
        try:
            lookback_days = max(7, min(int(lookback_days), 365))
        except (TypeError, ValueError):
            lookback_days = 90

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )
        timeline = await asyncio.to_thread(
            self.storage.market_store.get_news_timeline_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
        )
        if not timeline:
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=200,
            )
            timeline = await asyncio.to_thread(
                self.storage.market_store.get_news_timeline_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
            )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_timeline_loaded",
                    "ticker": ticker,
                    "timeline": timeline,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_news_categories(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_categories_loaded",
                        "ticker": "",
                        "categories": {},
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        lookback_days = data.get("lookback_days", 90)
        try:
            lookback_days = max(7, min(int(lookback_days), 365))
        except (TypeError, ValueError):
            lookback_days = 90

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )
        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        )
        if self._news_rows_need_enrichment(news_rows):
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=200,
            )
        categories = await asyncio.to_thread(
            self.storage.market_store.get_news_categories_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_categories_loaded",
                    "ticker": ticker,
                    "categories": categories,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_range_explain(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        start_date = str(data.get("start_date") or "").strip()
        end_date = str(data.get("end_date") or "").strip()
        if not ticker or not start_date or not end_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_range_explain_loaded",
                        "ticker": ticker,
                        "result": {"error": "ticker, start_date, end_date are required"},
                    },
                    ensure_ascii=False,
                ),
            )
            return

        article_ids = data.get("article_ids")
        if isinstance(article_ids, list) and article_ids:
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_by_ids_enriched,
                ticker,
                article_ids,
            )
            if self._news_rows_need_enrichment(news_rows):
                await asyncio.to_thread(
                    enrich_news_for_symbol,
                    self.storage.market_store,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )
                news_rows = await asyncio.to_thread(
                    self.storage.market_store.get_news_by_ids_enriched,
                    ticker,
                    article_ids,
                )
        else:
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=100,
            )
            if not news_rows:
                await asyncio.to_thread(
                    enrich_news_for_symbol,
                    self.storage.market_store,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )
                news_rows = await asyncio.to_thread(
                    self.storage.market_store.get_news_items_enriched,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )

        result = await asyncio.to_thread(
            build_range_explanation,
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            news_rows=news_rows,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_range_explain_loaded",
                    "ticker": ticker,
                    "result": result,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
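
Note: the range explainer can be pinned to specific articles via article_ids (ids taken from earlier news responses); otherwise it pulls up to 100 enriched rows for the window and hands them to build_range_explanation. An example request/reply pair (ticker, dates, and the id format are illustrative):

request = {
    "type": "get_stock_range_explain",
    "ticker": "NVDA",
    "start_date": "2024-05-01",
    "end_date": "2024-05-15",
    # "article_ids": ["<id from a stock_news_loaded row>"],  # optional pinning
}

# reply:
# {"type": "stock_range_explain_loaded", "ticker": "NVDA", "result": {...}}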

    async def _handle_get_stock_story(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_story_loaded",
                        "ticker": "",
                        "story": "",
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        as_of_date = str(
            data.get("as_of_date")
            or self.state_sync.state.get("current_date")
            or datetime.now().strftime("%Y-%m-%d")
        ).strip()[:10]
        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            end_date=as_of_date,
            limit=80,
        )
        result = await asyncio.to_thread(
            get_or_create_stock_story,
            self.storage.market_store,
            symbol=ticker,
            as_of_date=as_of_date,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_story_loaded",
                    "ticker": ticker,
                    "as_of_date": as_of_date,
                    "story": result.get("story") or "",
                    "source": result.get("source") or "local",
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_similar_days(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        target_date = str(data.get("date") or "").strip()[:10]
        if not ticker or not target_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_similar_days_loaded",
                        "ticker": ticker,
                        "date": target_date,
                        "items": [],
                        "error": "ticker and date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        top_k = data.get("top_k", 8)
        try:
            top_k = max(1, min(int(top_k), 20))
        except (TypeError, ValueError):
            top_k = 8

        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            end_date=target_date,
            limit=200,
        )
        result = await asyncio.to_thread(
            find_similar_days,
            self.storage.market_store,
            symbol=ticker,
            target_date=target_date,
            top_k=top_k,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_similar_days_loaded",
                    "ticker": ticker,
                    "date": target_date,
                    **result,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_run_stock_enrich(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        start_date = str(data.get("start_date") or "").strip()[:10]
        end_date = str(data.get("end_date") or "").strip()[:10]
        story_date = str(data.get("story_date") or end_date or "").strip()[:10]
        target_date = str(data.get("target_date") or "").strip()[:10]
        force = bool(data.get("force", False))
        rebuild_story = bool(data.get("rebuild_story", True))
        rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
        only_local_to_llm = bool(data.get("only_local_to_llm", False))
        limit = data.get("limit", 200)

        try:
            limit = max(10, min(int(limit), 500))
        except (TypeError, ValueError):
            limit = 200

        if not ticker or not start_date or not end_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_enrich_completed",
                        "ticker": ticker,
                        "start_date": start_date,
                        "end_date": end_date,
                        "error": "ticker, start_date, end_date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        if only_local_to_llm and not llm_enrichment_enabled():
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_enrich_completed",
                        "ticker": ticker,
                        "start_date": start_date,
                        "end_date": end_date,
                        "error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        result = await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            skip_existing=not force,
            only_reanalyze_local=only_local_to_llm,
        )

        story_status = None
        if rebuild_story and story_date:
            await asyncio.to_thread(
                self.storage.market_store.delete_story_cache,
                ticker,
                as_of_date=story_date,
            )
            story_result = await asyncio.to_thread(
                get_or_create_stock_story,
                self.storage.market_store,
                symbol=ticker,
                as_of_date=story_date,
            )
            story_status = {
                "as_of_date": story_date,
                "source": story_result.get("source") or "local",
            }

        similar_status = None
        if rebuild_similar_days and target_date:
            await asyncio.to_thread(
                self.storage.market_store.delete_similar_day_cache,
                ticker,
                target_date=target_date,
            )
            similar_result = await asyncio.to_thread(
                find_similar_days,
                self.storage.market_store,
                symbol=ticker,
                target_date=target_date,
                top_k=8,
            )
            similar_status = {
                "target_date": target_date,
                "count": len(similar_result.get("items") or []),
                "error": similar_result.get("error"),
            }

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_enrich_completed",
                    "ticker": ticker,
                    "start_date": start_date,
                    "end_date": end_date,
                    "story_date": story_date or None,
                    "target_date": target_date or None,
                    "force": force,
                    "only_local_to_llm": only_local_to_llm,
                    "stats": result,
                    "story_status": story_status,
                    "similar_status": similar_status,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
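
Note: run_stock_enrich is the manual "re-run the explain workflow" entry point: it re-enriches a window, optionally rebuilds the cached story and similar-days results, and reports per-step status in stock_enrich_completed. A full request payload with the knobs the handler reads (values are illustrative):

request = {
    "type": "run_stock_enrich",
    "ticker": "AAPL",
    "start_date": "2024-01-01",
    "end_date": "2024-03-31",
    "story_date": "2024-03-31",    # defaults to end_date when omitted
    "target_date": "2024-03-28",   # required for the similar-days rebuild
    "force": False,                # True re-enriches rows that already have output
    "rebuild_story": True,
    "rebuild_similar_days": True,
    "only_local_to_llm": False,    # needs EXPLAIN_ENRICH_USE_LLM=true if set
    "limit": 200,                  # clamped to [10, 500]
}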

    async def _handle_start_backtest(self, data: Dict[str, Any]):
        if not self.is_backtest:
            return
@@ -410,6 +1093,7 @@ class Gateway:
            },
        )
        await self._handle_reload_runtime_assets()
        self._schedule_watchlist_market_store_refresh(tickers)

    @staticmethod
    def _normalize_watchlist(raw_tickers: Any) -> List[str]:
@@ -538,6 +1222,48 @@ class Gateway:
            trades=trades,
        )

    def _schedule_watchlist_market_store_refresh(
        self,
        tickers: List[str],
    ) -> None:
        """Kick off a non-blocking Polygon refresh for the updated watchlist."""
        if not tickers:
            return
        if self._watchlist_ingest_task and not self._watchlist_ingest_task.done():
            self._watchlist_ingest_task.cancel()
        self._watchlist_ingest_task = asyncio.create_task(
            self._refresh_market_store_for_watchlist(tickers),
        )

    async def _refresh_market_store_for_watchlist(
        self,
        tickers: List[str],
    ) -> None:
        """Refresh the long-lived market store after a watchlist update."""
        try:
            await self.state_sync.on_system_message(
                f"Syncing watchlist market data: {', '.join(tickers)}",
            )
            results = await asyncio.to_thread(
                ingest_symbols,
                tickers,
                mode="incremental",
            )
            summary = ", ".join(
                f"{item['symbol']} prices={item['prices']} news={item['news']}"
                for item in results
            )
            await self.state_sync.on_system_message(
                f"Watchlist market data synced: {summary}",
            )
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.warning("Watchlist market store refresh failed: %s", exc)
            await self.state_sync.on_system_message(
                f"Watchlist market data sync failed: {exc}",
            )
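
Note: _schedule_watchlist_market_store_refresh cancels any in-flight refresh before starting a new one, so rapid watchlist edits coalesce to the most recent request; the worker re-raises CancelledError so superseded runs die cleanly. The pattern in isolation (a standalone sketch, not the commit's class):

import asyncio
from typing import Coroutine, Optional

class LatestOnly:
    """Keep at most one in-flight task; newer submissions supersede older ones."""

    def __init__(self) -> None:
        self._task: Optional[asyncio.Task] = None

    def submit(self, coro: Coroutine) -> None:
        if self._task and not self._task.done():
            self._task.cancel()  # the superseded run sees CancelledError
        self._task = asyncio.create_task(coro)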

    async def broadcast(self, message: Dict[str, Any]):
        """Broadcast message to all connected clients"""
        if not self.connected_clients:
@@ -896,4 +1622,6 @@ class Gateway:
            self._backtest_task.cancel()
        if self._market_status_task:
            self._market_status_task.cancel()
        if self._watchlist_ingest_task:
            self._watchlist_ingest_task.cancel()
        self._dashboard.stop()

backend/services/market.py
@@ -65,6 +65,18 @@ class MarketService:
        self._session_start_values: Optional[Dict[str, float]] = None
        self._session_start_timestamp: Optional[int] = None

    def get_live_quote_provider(self) -> Optional[str]:
        """Return the active live quote provider for UI/debugging."""
        if self.backtest_mode:
            return "backtest"
        if self.mock_mode:
            return "mock"
        if self._price_manager and hasattr(self._price_manager, "provider"):
            provider = getattr(self._price_manager, "provider", None)
            if isinstance(provider, str) and provider.strip():
                return provider.strip().lower()
        return None

    @property
    def mode_name(self) -> str:
        if self.backtest_mode:
@@ -532,6 +544,7 @@ class MarketService:
                "status": MarketStatus.OPEN,
                "status_text": "Backtest Mode",
                "is_trading_day": True,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        now = self._now_nyse()
@@ -544,6 +557,7 @@ class MarketService:
                "status": MarketStatus.CLOSED,
                "status_text": "Market Closed (Non-trading Day)",
                "is_trading_day": False,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        market_open, market_close = self._get_market_hours(today)
@@ -553,6 +567,7 @@ class MarketService:
                "status": MarketStatus.CLOSED,
                "status_text": "Market Closed",
                "is_trading_day": is_trading,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        # Determine status based on current time
@@ -563,6 +578,7 @@ class MarketService:
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }
        elif now > market_close:
            return {
@@ -571,6 +587,7 @@ class MarketService:
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }
        else:
            return {
@@ -579,6 +596,7 @@ class MarketService:
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }

    async def check_and_broadcast_market_status(self):

backend/services/research_db.py (new file, 280 lines)
@@ -0,0 +1,280 @@
# -*- coding: utf-8 -*-
"""Query-oriented storage for explain/research data."""

from __future__ import annotations

import json
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable

from backend.data.schema import CompanyNews


SCHEMA = """
CREATE TABLE IF NOT EXISTS news_items (
    id TEXT PRIMARY KEY,
    ticker TEXT NOT NULL,
    published_at TEXT,
    trade_date TEXT,
    source TEXT,
    title TEXT NOT NULL,
    summary TEXT,
    url TEXT,
    related TEXT,
    category TEXT,
    raw_json TEXT NOT NULL,
    ingest_run_date TEXT,
    created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_news_items_ticker_date
    ON news_items (ticker, trade_date DESC, published_at DESC);
"""


def _json_dumps(value: Any) -> str:
    return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)


def _resolve_news_id(ticker: str, item: CompanyNews, fallback_index: int) -> str:
    base = item.url or item.title or f"{ticker}-{fallback_index}"
    return f"{ticker}:{base}"


def _resolve_trade_date(date_value: str | None) -> str | None:
    if not date_value:
        return None
    normalized = str(date_value).strip()
    if not normalized:
        return None
    if "T" in normalized:
        return normalized.split("T", 1)[0]
    if " " in normalized:
        return normalized.split(" ", 1)[0]
    return normalized[:10]


class ResearchDb:
    """Small SQLite helper for explain-oriented news storage."""

    def __init__(self, db_path: Path):
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
        return conn

    def _init_db(self):
        with self._connect() as conn:
            conn.executescript(SCHEMA)

    def upsert_news_items(
        self,
        *,
        ticker: str,
        items: Iterable[CompanyNews],
        ingest_run_date: str | None = None,
    ) -> list[dict[str, Any]]:
        """Persist provider news and return normalized rows."""
        normalized_rows: list[dict[str, Any]] = []
        timestamp = datetime.utcnow().isoformat(timespec="seconds")
        symbol = str(ticker or "").strip().upper()
        if not symbol:
            return normalized_rows

        with self._connect() as conn:
            for index, item in enumerate(items):
                news_id = _resolve_news_id(symbol, item, index)
                trade_date = _resolve_trade_date(item.date)
                payload = item.model_dump()
                row = {
                    "id": news_id,
                    "ticker": symbol,
                    "published_at": item.date,
                    "trade_date": trade_date,
                    "source": item.source,
                    "title": item.title,
                    "summary": item.summary,
                    "url": item.url,
                    "related": item.related,
                    "category": item.category,
                    "raw_json": _json_dumps(payload),
                    "ingest_run_date": ingest_run_date,
                    "created_at": timestamp,
                }
                conn.execute(
                    """
                    INSERT INTO news_items
                        (id, ticker, published_at, trade_date, source, title, summary, url,
                         related, category, raw_json, ingest_run_date, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ON CONFLICT(id) DO UPDATE SET
                        ticker = excluded.ticker,
                        published_at = excluded.published_at,
                        trade_date = excluded.trade_date,
                        source = excluded.source,
                        title = excluded.title,
                        summary = excluded.summary,
                        url = excluded.url,
                        related = excluded.related,
                        category = excluded.category,
                        raw_json = excluded.raw_json,
                        ingest_run_date = excluded.ingest_run_date
                    """,
                    (
                        row["id"],
                        row["ticker"],
                        row["published_at"],
                        row["trade_date"],
                        row["source"],
                        row["title"],
                        row["summary"],
                        row["url"],
                        row["related"],
                        row["category"],
                        row["raw_json"],
                        row["ingest_run_date"],
                        row["created_at"],
                    ),
                )
                normalized_rows.append(row)
        return normalized_rows
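
Note: a minimal round trip through ResearchDb. CompanyNews is assumed here to be a pydantic model exposing the fields the upsert reads (date, source, title, summary, url, related, category); the exact constructor arguments and the db path are illustrative:

from pathlib import Path

from backend.data.schema import CompanyNews
from backend.services.research_db import ResearchDb

db = ResearchDb(Path("state/research.db"))  # path is illustrative
item = CompanyNews(
    date="2024-05-01T13:30:00Z",
    source="newswire",
    title="Example headline",
    summary="One-line summary.",
    url="https://example.com/article",
    related="AAPL",
    category="earnings",
)
rows = db.upsert_news_items(ticker="aapl", items=[item], ingest_run_date="2024-05-01")
assert rows[0]["ticker"] == "AAPL"            # symbol is upper-cased
assert rows[0]["trade_date"] == "2024-05-01"  # date part of the ISO timestamp
assert rows[0]["id"] == "AAPL:https://example.com/article"  # url-based id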

    def get_news_items(
        self,
        *,
        ticker: str,
        start_date: str | None = None,
        end_date: str | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        """Return normalized news rows for explain UI."""
        symbol = str(ticker or "").strip().upper()
        if not symbol:
            return []

        sql = """
            SELECT id, ticker, published_at, trade_date, source, title, summary,
                   url, related, category
            FROM news_items
            WHERE ticker = ?
        """
        params: list[Any] = [symbol]
        if start_date:
            sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
            params.append(start_date)
        if end_date:
            sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
            params.append(end_date)
        sql += " ORDER BY COALESCE(published_at, trade_date) DESC LIMIT ?"
        params.append(max(1, int(limit)))

        with self._connect() as conn:
            rows = conn.execute(sql, params).fetchall()

        return [
            {
                "id": row["id"],
                "ticker": row["ticker"],
                "date": row["published_at"] or row["trade_date"],
                "trade_date": row["trade_date"],
                "source": row["source"],
                "title": row["title"],
                "summary": row["summary"],
                "url": row["url"],
                "related": row["related"],
                "category": row["category"],
            }
            for row in rows
        ]

    def get_news_timeline(
        self,
        *,
        ticker: str,
        start_date: str | None = None,
        end_date: str | None = None,
    ) -> list[dict[str, Any]]:
        """Aggregate news counts per trade date for chart markers."""
        symbol = str(ticker or "").strip().upper()
        if not symbol:
            return []

        sql = """
            SELECT COALESCE(trade_date, substr(published_at, 1, 10)) AS date,
                   COUNT(*) AS count,
                   COUNT(DISTINCT source) AS source_count,
                   MAX(title) AS top_title
            FROM news_items
            WHERE ticker = ?
        """
        params: list[Any] = [symbol]
        if start_date:
            sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
            params.append(start_date)
        if end_date:
            sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
            params.append(end_date)
        sql += """
            GROUP BY COALESCE(trade_date, substr(published_at, 1, 10))
            ORDER BY date ASC
        """

        with self._connect() as conn:
            rows = conn.execute(sql, params).fetchall()

        return [
            {
                "date": row["date"],
                "count": int(row["count"] or 0),
                "source_count": int(row["source_count"] or 0),
                "top_title": row["top_title"] or "",
            }
            for row in rows
            if row["date"]
        ]

    def get_news_by_ids(
        self,
        *,
        ticker: str,
        article_ids: Iterable[str],
    ) -> list[dict[str, Any]]:
        """Return selected persisted news items."""
        symbol = str(ticker or "").strip().upper()
        ids = [str(article_id).strip() for article_id in article_ids if str(article_id).strip()]
        if not symbol or not ids:
            return []

        placeholders = ",".join("?" for _ in ids)
        sql = f"""
            SELECT id, ticker, published_at, trade_date, source, title, summary,
                   url, related, category
            FROM news_items
            WHERE ticker = ? AND id IN ({placeholders})
            ORDER BY COALESCE(published_at, trade_date) DESC
        """
        with self._connect() as conn:
            rows = conn.execute(sql, [symbol, *ids]).fetchall()

        return [
            {
                "id": row["id"],
                "ticker": row["ticker"],
                "date": row["published_at"] or row["trade_date"],
                "trade_date": row["trade_date"],
                "source": row["source"],
                "title": row["title"],
                "summary": row["summary"],
                "url": row["url"],
                "related": row["related"],
                "category": row["category"],
            }
            for row in rows
        ]
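
Note: get_news_by_ids assembles the IN clause from "?" placeholders so every id stays parameterized; only code-controlled fragments are interpolated into the SQL string. The idiom in isolation (standalone sketch):

import sqlite3
from typing import List

def fetch_by_ids(conn: sqlite3.Connection, ids: List[str]) -> List[sqlite3.Row]:
    # One "?" per id keeps each user-supplied value bound as a parameter
    # (no SQL injection); only the placeholder list itself is interpolated.
    placeholders = ",".join("?" for _ in ids)
    sql = f"SELECT * FROM news_items WHERE id IN ({placeholders})"
    return conn.execute(sql, ids).fetchall()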

backend/services/storage.py
@@ -10,6 +10,8 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from backend.data.market_store import MarketStore
from .research_db import ResearchDb
from .runtime_db import RuntimeDb

logger = logging.getLogger(__name__)
@@ -64,6 +66,8 @@ class StorageService:
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.server_state_file = self.state_dir / "server_state.json"
        self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
        self.research_db = ResearchDb(self.state_dir / "research.db")
        self.market_store = MarketStore()

        # Feed history (for agent messages)
        self.max_feed_history = 200