Add explain analysis workflow and UI

This commit is contained in:

backend/cli.py (420 changed lines)
@@ -12,7 +12,7 @@ import os
import shutil
import subprocess
import sys
-from datetime import datetime
+from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from zoneinfo import ZoneInfo

@@ -21,18 +21,27 @@ import typer
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Confirm
from rich.table import Table
from dotenv import load_dotenv

from backend.agents.prompt_loader import PromptLoader
from backend.agents.workspace_manager import WorkspaceManager
from backend.data.market_ingest import ingest_symbols
from backend.data.market_store import MarketStore
from backend.enrich.llm_enricher import get_explain_model_info, llm_enrichment_enabled
from backend.enrich.news_enricher import enrich_symbols

app = typer.Typer(
    name="evotraders",
    help="EvoTraders: A self-evolving multi-agent trading system",
    add_completion=False,
)
ingest_app = typer.Typer(help="Ingest Polygon market data into the research warehouse.")
app.add_typer(ingest_app, name="ingest")

console = Console()
_prompt_loader = PromptLoader()
load_dotenv()


def get_project_root() -> Path:

@@ -204,6 +213,189 @@ def initialize_workspace(config_name: str) -> Path:
    return workspace_manager.get_run_dir(config_name)


def _resolve_symbols(raw_tickers: Optional[str], config_name: Optional[str] = None) -> list[str]:
    """Resolve symbols from explicit input or runtime bootstrap config."""
    if raw_tickers and raw_tickers.strip():
        return [
            item.strip().upper()
            for item in raw_tickers.split(",")
            if item.strip()
        ]

    workspace_manager = WorkspaceManager(project_root=get_project_root())
    bootstrap_path = workspace_manager.get_run_dir(config_name or "default") / "BOOTSTRAP.md"
    if bootstrap_path.exists():
        content = bootstrap_path.read_text(encoding="utf-8")
        for line in content.splitlines():
            if line.strip().startswith("tickers:"):
                raw = line.split(":", 1)[1]
                return [
                    item.strip().upper()
                    for item in raw.split(",")
                    if item.strip()
                ]
    return []
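
Note: the fallback branch above reads a `tickers:` line from the run's BOOTSTRAP.md. A minimal sketch of that contract (the file fragment and values are illustrative, not taken from this commit):

    # Hypothetical BOOTSTRAP.md fragment the fallback parses:
    #
    #   tickers: aapl, msft, nvda
    #
    # The parsing is plain split-and-uppercase, equivalent to:
    raw = " aapl, msft, nvda"
    symbols = [item.strip().upper() for item in raw.split(",") if item.strip()]
    assert symbols == ["AAPL", "MSFT", "NVDA"]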

def _filter_problematic_report_rows(rows: list[dict]) -> list[dict]:
    """Keep tickers with incomplete coverage or without any LLM-enriched rows."""
    return [
        row
        for row in rows
        if float(row.get("coverage_pct") or 0.0) < 100.0
        or int(row.get("llm_count") or 0) == 0
    ]

def auto_update_market_store(
    config_name: str,
    *,
    end_date: Optional[str] = None,
) -> None:
    """Refresh the long-lived Polygon market store for the active watchlist."""
    api_key = os.getenv("POLYGON_API_KEY", "").strip()
    if not api_key:
        console.print(
            "[dim]Skipping Polygon market store update: POLYGON_API_KEY not set[/dim]",
        )
        return

    symbols = _resolve_symbols(None, config_name)
    if not symbols:
        console.print(
            f"[dim]Skipping Polygon market store update: no tickers found for config '{config_name}'[/dim]",
        )
        return

    target_end = end_date or datetime.now().date().isoformat()
    console.print(
        f"[cyan]Updating Polygon market store for {', '.join(symbols)} -> {target_end}[/cyan]",
    )

    try:
        results = ingest_symbols(
            symbols,
            mode="incremental",
            end_date=target_end,
        )
    except Exception as exc:
        console.print(
            f"[yellow]Polygon market store update failed, continuing startup: {exc}[/yellow]",
        )
        return

    for result in results:
        console.print(
            "[green]"
            f"{result['symbol']}"
            "[/green] "
            f"prices={result['prices']} news={result['news']} aligned={result['aligned']}"
        )

def auto_prepare_backtest_market_store(
    config_name: str,
    *,
    start_date: str,
    end_date: str,
) -> None:
    """Ensure the market store has the requested backtest window for the active watchlist."""
    api_key = os.getenv("POLYGON_API_KEY", "").strip()
    if not api_key:
        console.print(
            "[dim]Skipping Polygon backtest preload: POLYGON_API_KEY not set[/dim]",
        )
        return

    symbols = _resolve_symbols(None, config_name)
    if not symbols:
        console.print(
            f"[dim]Skipping Polygon backtest preload: no tickers found for config '{config_name}'[/dim]",
        )
        return

    console.print(
        f"[cyan]Preparing Polygon market store for backtest {start_date} -> {end_date} "
        f"({', '.join(symbols)})[/cyan]",
    )

    try:
        results = ingest_symbols(
            symbols,
            mode="full",
            start_date=start_date,
            end_date=end_date,
        )
    except Exception as exc:
        console.print(
            f"[yellow]Polygon backtest preload failed, continuing startup: {exc}[/yellow]",
        )
        return

    for result in results:
        console.print(
            "[green]"
            f"{result['symbol']}"
            "[/green] "
            f"prices={result['prices']} news={result['news']} aligned={result['aligned']}"
        )

def auto_enrich_market_store(
    config_name: str,
    *,
    end_date: Optional[str] = None,
    lookback_days: int = 120,
    force: bool = False,
) -> None:
    """Refresh explain-oriented enriched news for the active watchlist."""
    symbols = _resolve_symbols(None, config_name)
    if not symbols:
        console.print(
            f"[dim]Skipping explain enrich: no tickers found for config '{config_name}'[/dim]",
        )
        return

    target_end = end_date or datetime.now().date().isoformat()
    try:
        end_dt = datetime.strptime(target_end, "%Y-%m-%d")
    except ValueError:
        console.print(
            f"[yellow]Skipping explain enrich: invalid end date {target_end}[/yellow]",
        )
        return

    start_date = (end_dt - timedelta(days=max(1, lookback_days))).date().isoformat()
    console.print(
        f"[cyan]Refreshing explain enrich for {', '.join(symbols)} -> {target_end}[/cyan]",
    )
    store = MarketStore()
    try:
        results = enrich_symbols(
            store,
            symbols,
            start_date=start_date,
            end_date=target_end,
            limit=300,
            skip_existing=not force,
        )
    except Exception as exc:
        console.print(
            f"[yellow]Explain enrich failed, continuing startup: {exc}[/yellow]",
        )
        return

    for result in results:
        console.print(
            "[green]"
            f"{result['symbol']}"
            "[/green] "
            f"news={result['news_count']} queued={result['queued_count']} analyzed={result['analyzed']} "
            f"skipped={result['skipped_existing_count']} deduped={result['deduped_count']} "
            f"llm={result['llm_count']} local={result['local_count']}"
        )
@app.command("init-workspace")
|
||||
def init_workspace(
|
||||
config_name: str = typer.Option(
|
||||
@@ -223,6 +415,213 @@ def init_workspace(
|
||||
)
|
||||
|
||||
|
||||
@ingest_app.command("full")
|
||||
def ingest_full(
|
||||
tickers: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--tickers",
|
||||
"-t",
|
||||
help="Comma-separated tickers to ingest",
|
||||
),
|
||||
start: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--start",
|
||||
help="Start date for full ingestion (YYYY-MM-DD)",
|
||||
),
|
||||
end: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--end",
|
||||
help="End date for ingestion (YYYY-MM-DD)",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||
),
|
||||
):
|
||||
"""Run full Polygon ingestion for the specified symbols."""
|
||||
symbols = _resolve_symbols(tickers, config_name)
|
||||
if not symbols:
|
||||
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[cyan]Starting full Polygon ingest for {', '.join(symbols)}[/cyan]")
|
||||
results = ingest_symbols(symbols, mode="full", start_date=start, end_date=end)
|
||||
for result in results:
|
||||
console.print(
|
||||
f"[green]{result['symbol']}[/green] prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||
)
|
||||
|
||||
|
||||
@ingest_app.command("update")
|
||||
def ingest_update(
|
||||
tickers: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--tickers",
|
||||
"-t",
|
||||
help="Comma-separated tickers to update",
|
||||
),
|
||||
end: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--end",
|
||||
help="Optional end date override (YYYY-MM-DD)",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||
),
|
||||
):
|
||||
"""Run incremental Polygon ingestion using stored watermarks."""
|
||||
symbols = _resolve_symbols(tickers, config_name)
|
||||
if not symbols:
|
||||
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[cyan]Starting incremental Polygon ingest for {', '.join(symbols)}[/cyan]")
|
||||
results = ingest_symbols(symbols, mode="incremental", end_date=end)
|
||||
for result in results:
|
||||
console.print(
|
||||
f"[green]{result['symbol']}[/green] prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||
)
|
||||
|
||||
|
||||
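
Note: because `ingest_app` is mounted on the main Typer app, these commands run as `evotraders ingest full` / `evotraders ingest update`. A minimal invocation sketch using Typer's test runner (tickers illustrative; a valid POLYGON_API_KEY and network access are assumed):

    from typer.testing import CliRunner

    from backend.cli import app

    runner = CliRunner()
    result = runner.invoke(app, ["ingest", "update", "--tickers", "AAPL,MSFT"])
    print(result.output)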
@ingest_app.command("enrich")
|
||||
def ingest_enrich(
|
||||
tickers: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--tickers",
|
||||
"-t",
|
||||
help="Comma-separated tickers to enrich",
|
||||
),
|
||||
start: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--start",
|
||||
help="Optional start date for enrichment window (YYYY-MM-DD)",
|
||||
),
|
||||
end: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--end",
|
||||
help="Optional end date for enrichment window (YYYY-MM-DD)",
|
||||
),
|
||||
limit: int = typer.Option(
|
||||
300,
|
||||
"--limit",
|
||||
help="Maximum raw news rows per ticker to analyze",
|
||||
),
|
||||
force: bool = typer.Option(
|
||||
False,
|
||||
"--force",
|
||||
help="Re-analyze already enriched news instead of only missing rows",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||
),
|
||||
):
|
||||
"""Run explain-oriented news enrichment for symbols already in the market store."""
|
||||
symbols = _resolve_symbols(tickers, config_name)
|
||||
if not symbols:
|
||||
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[cyan]Starting explain enrich for {', '.join(symbols)}[/cyan]")
|
||||
store = MarketStore()
|
||||
results = enrich_symbols(
|
||||
store,
|
||||
symbols,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
limit=max(10, limit),
|
||||
skip_existing=not force,
|
||||
)
|
||||
for result in results:
|
||||
console.print(
|
||||
f"[green]{result['symbol']}[/green] "
|
||||
f"news={result['news_count']} queued={result['queued_count']} analyzed={result['analyzed']} "
|
||||
f"skipped={result['skipped_existing_count']} deduped={result['deduped_count']} "
|
||||
f"llm={result['llm_count']} local={result['local_count']}"
|
||||
)
|
||||
|
||||
|
||||
@ingest_app.command("report")
|
||||
def ingest_report(
|
||||
tickers: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--tickers",
|
||||
"-t",
|
||||
help="Optional comma-separated tickers to report",
|
||||
),
|
||||
start: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--start",
|
||||
help="Optional start date for report window (YYYY-MM-DD)",
|
||||
),
|
||||
end: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--end",
|
||||
help="Optional end date for report window (YYYY-MM-DD)",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||
),
|
||||
only_problematic: bool = typer.Option(
|
||||
False,
|
||||
"--only-problematic",
|
||||
help="Only show tickers with incomplete coverage or no LLM-enriched news",
|
||||
),
|
||||
):
|
||||
"""Show explain enrichment coverage and freshness per ticker."""
|
||||
symbols = _resolve_symbols(tickers, config_name)
|
||||
store = MarketStore()
|
||||
report_rows = store.get_enrich_report(
|
||||
symbols=symbols or None,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
)
|
||||
if only_problematic:
|
||||
report_rows = _filter_problematic_report_rows(report_rows)
|
||||
if not report_rows:
|
||||
if only_problematic:
|
||||
console.print("[green]No problematic enrich report rows found for the requested scope[/green]")
|
||||
else:
|
||||
console.print("[yellow]No enrich report rows found for the requested scope[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
model_info = get_explain_model_info()
|
||||
model_label = model_info["label"] if llm_enrichment_enabled() else "disabled"
|
||||
table = Table(title="Explain Enrichment Report")
|
||||
table.add_column("Ticker", style="cyan")
|
||||
table.add_column("Raw News", justify="right")
|
||||
table.add_column("Analyzed", justify="right")
|
||||
table.add_column("Coverage", justify="right")
|
||||
table.add_column("LLM", justify="right")
|
||||
table.add_column("Local", justify="right")
|
||||
table.add_column("Latest Trade Date")
|
||||
table.add_column("Latest Analysis")
|
||||
table.caption = f"Explain LLM: {model_label}"
|
||||
|
||||
for row in report_rows:
|
||||
table.add_row(
|
||||
row["symbol"],
|
||||
str(row["raw_news_count"]),
|
||||
str(row["analyzed_news_count"]),
|
||||
f'{row["coverage_pct"]:.1f}%',
|
||||
str(row["llm_count"]),
|
||||
str(row["local_count"]),
|
||||
str(row["latest_trade_date"] or "-"),
|
||||
str(row["latest_analysis_at"] or "-"),
|
||||
)
|
||||
console.print(table)
|
||||
|
||||
|
||||

@app.command()
def backtest(
    start: Optional[str] = typer.Option(

@@ -331,6 +730,16 @@ def backtest(

    # Run data updater
    run_data_updater(project_root)
    auto_prepare_backtest_market_store(
        config_name,
        start_date=start,
        end_date=end,
    )
    auto_enrich_market_store(
        config_name,
        end_date=end,
        force=False,
    )

    # Build command using backend.main
    cmd = [

@@ -514,6 +923,15 @@ def live(
    # Data update (if not mock mode)
    if not mock:
        run_data_updater(project_root)
        auto_update_market_store(
            config_name,
            end_date=nyse_now.date().isoformat(),
        )
        auto_enrich_market_store(
            config_name,
            end_date=nyse_now.date().isoformat(),
            force=False,
        )
    else:
        console.print(
            "\n[dim]Mock mode enabled - skipping data update[/dim]\n",

@@ -47,6 +47,10 @@ class StateSync:
        """Set current simulation date for backtest-compatible timestamps"""
        self._simulation_date = date

    def clear_simulation_date(self):
        """Disable backtest timestamp simulation and use wall-clock time."""
        self._simulation_date = None

    def _get_timestamp_ms(self) -> int:
        """
        Get timestamp in milliseconds.

@@ -97,9 +101,21 @@ class StateSync:
        if not self._enabled:
            return

-        # Ensure timestamp exists (use simulation date if in backtest mode)
+        # Ensure timestamp exists. Prefer explicit millisecond timestamps so
+        # frontend displays local wall time correctly instead of date-only UTC.
        if "timestamp" not in event:
-            if self._simulation_date:
+            ts_ms = event.get("ts")
+            if ts_ms is not None:
+                try:
+                    event["timestamp"] = datetime.fromtimestamp(
+                        float(ts_ms) / 1000.0,
+                    ).isoformat()
+                except (TypeError, ValueError, OSError):
+                    if self._simulation_date:
+                        event["timestamp"] = f"{self._simulation_date}"
+                    else:
+                        event["timestamp"] = datetime.now().isoformat()
+            elif self._simulation_date:
                event["timestamp"] = f"{self._simulation_date}"
            else:
                event["timestamp"] = datetime.now().isoformat()
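
Note: the millisecond `ts` branch matters because the date-only fallback (`f"{self._simulation_date}"`) renders as midnight UTC in the frontend. A worked sketch of the conversion the new branch performs (value illustrative; the local-time result depends on the machine's zone):

    from datetime import datetime

    event = {"ts": 1718037000000}  # milliseconds since the epoch (2024-06-10 16:30 UTC)
    timestamp = datetime.fromtimestamp(float(event["ts"]) / 1000.0).isoformat()
    print(timestamp)  # e.g. "2024-06-10T12:30:00" on a US/Eastern machine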

@@ -238,9 +254,12 @@
        """Called at start of trading cycle"""
        self._state["current_date"] = date
        self._state["status"] = "running"
-        self.set_simulation_date(
-            date,
-        )  # Set for backtest-compatible timestamps
+        if self._state.get("server_mode") == "backtest":
+            self.set_simulation_date(
+                date,
+            )  # Set for backtest-compatible timestamps
+        else:
+            self.clear_simulation_date()

        await self.emit(
            {

@@ -7,6 +7,7 @@ from datetime import datetime
from typing import Callable, Dict, List, Optional

import pandas as pd
+from backend.data.market_store import MarketStore
from backend.data.provider_utils import normalize_symbol
from backend.data.provider_router import get_provider_router

@@ -26,6 +27,7 @@ class HistoricalPriceManager:
        self.close_prices = {}
        self.running = False
        self._router = get_provider_router()
+        self._market_store = MarketStore()

    def subscribe(
        self,

@@ -58,21 +60,48 @@
            logger.warning(f"Failed to load CSV for {symbol}: {e}")
            return None

    def _load_from_market_db(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
    ) -> Optional[pd.DataFrame]:
        """Load price data from the long-lived market research database."""
        try:
            rows = self._market_store.get_ohlc(symbol, start_date, end_date)
            if not rows:
                return None
            df = pd.DataFrame(rows)
            if df.empty or "date" not in df.columns:
                return None
            df["Date"] = pd.to_datetime(df["date"])
            df.set_index("Date", inplace=True)
            df.sort_index(inplace=True)
            return df
        except Exception as e:
            logger.warning(f"Failed to load market DB data for {symbol}: {e}")
            return None

    def preload_data(self, start_date: str, end_date: str):
-        """Preload historical data from local CSV files."""
+        """Preload historical data from market DB first, then local CSV."""
        logger.info(f"Preloading data: {start_date} to {end_date}")

        for symbol in self.subscribed_symbols:
            if symbol in self._price_cache:
                continue

-            # Load from local CSV file directly
+            df = self._load_from_market_db(symbol, start_date, end_date)
+            if df is not None and not df.empty:
+                self._price_cache[symbol] = df
+                logger.info(f"Loaded {symbol} from market DB: {len(df)} records")
+                continue
+
            df = self._load_from_csv(symbol)
            if df is not None and not df.empty:
                self._price_cache[symbol] = df
                logger.info(f"Loaded {symbol} from CSV: {len(df)} records")
            else:
-                logger.warning(f"No CSV data for {symbol}")
+                logger.warning(f"No market DB or CSV data for {symbol}")

    def set_date(self, date: str):
        """Set current trading date and update prices"""
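
Note: the market-DB path assumes `MarketStore.get_ohlc` returns date-keyed rows, and the DataFrame shaping mirrors the CSV loader so consumers see one index format. A standalone sketch of that shaping (sample rows illustrative):

    import pandas as pd

    rows = [
        {"date": "2024-06-10", "close": 197.1},
        {"date": "2024-06-07", "close": 196.9},
    ]
    df = pd.DataFrame(rows)
    df["Date"] = pd.to_datetime(df["date"])
    df.set_index("Date", inplace=True)
    df.sort_index(inplace=True)  # guarantees ascending dates even if rows arrive unsorted
    print(df.index[0], "->", df.index[-1])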

backend/data/market_ingest.py (new file, 149 lines)

@@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
"""Ingest Polygon market data into the long-lived research warehouse."""

from __future__ import annotations

from datetime import datetime, timedelta, timezone
from typing import Iterable

from backend.data.market_store import MarketStore
from backend.data.news_alignment import align_news_for_symbol
from backend.data.polygon_client import (
    fetch_news,
    fetch_ohlc,
    fetch_ticker_details,
)
from backend.data.provider_utils import normalize_symbol


def _today_utc() -> str:
    return datetime.now(timezone.utc).date().isoformat()


def _default_start(years: int = 2) -> str:
    return (datetime.now(timezone.utc).date() - timedelta(days=years * 366)).isoformat()

def ingest_ticker_history(
    symbol: str,
    *,
    start_date: str | None = None,
    end_date: str | None = None,
    store: MarketStore | None = None,
) -> dict:
    """Fetch and persist Polygon OHLC + news for a ticker."""
    ticker = normalize_symbol(symbol)
    start = start_date or _default_start()
    end = end_date or _today_utc()
    market_store = store or MarketStore()

    details = fetch_ticker_details(ticker)
    market_store.upsert_ticker(
        symbol=ticker,
        name=details.get("name"),
        sector=details.get("sic_description"),
        is_active=bool(details.get("active", True)),
    )

    ohlc_rows = fetch_ohlc(ticker, start, end)
    news_rows = fetch_news(ticker, start, end)
    price_count = market_store.upsert_ohlc(ticker, ohlc_rows, source="polygon")
    news_count = market_store.upsert_news(ticker, news_rows, source="polygon")
    aligned_count = align_news_for_symbol(market_store, ticker)
    market_store.update_fetch_watermark(symbol=ticker, price_date=end, news_date=end)

    return {
        "symbol": ticker,
        "start_date": start,
        "end_date": end,
        "prices": price_count,
        "news": news_count,
        "aligned": aligned_count,
    }

def update_ticker_incremental(
    symbol: str,
    *,
    end_date: str | None = None,
    store: MarketStore | None = None,
) -> dict:
    """Incrementally fetch OHLC + news since the last watermark."""
    ticker = normalize_symbol(symbol)
    market_store = store or MarketStore()
    watermarks = market_store.get_ticker_watermarks(ticker)
    end = end_date or _today_utc()
    start_prices = (
        (datetime.fromisoformat(watermarks["last_price_fetch"]) + timedelta(days=1)).date().isoformat()
        if watermarks.get("last_price_fetch")
        else _default_start()
    )
    start_news = (
        (datetime.fromisoformat(watermarks["last_news_fetch"]) + timedelta(days=1)).date().isoformat()
        if watermarks.get("last_news_fetch")
        else _default_start()
    )

    details = fetch_ticker_details(ticker)
    market_store.upsert_ticker(
        symbol=ticker,
        name=details.get("name"),
        sector=details.get("sic_description"),
        is_active=bool(details.get("active", True)),
    )

    ohlc_rows = [] if start_prices > end else fetch_ohlc(ticker, start_prices, end)
    news_rows = [] if start_news > end else fetch_news(ticker, start_news, end)
    price_count = market_store.upsert_ohlc(ticker, ohlc_rows, source="polygon") if ohlc_rows else 0
    news_count = market_store.upsert_news(ticker, news_rows, source="polygon") if news_rows else 0
    aligned_count = align_news_for_symbol(market_store, ticker)
    market_store.update_fetch_watermark(
        symbol=ticker,
        price_date=end if ohlc_rows or watermarks.get("last_price_fetch") else None,
        news_date=end if news_rows or watermarks.get("last_news_fetch") else None,
    )

    return {
        "symbol": ticker,
        "start_price_date": start_prices,
        "start_news_date": start_news,
        "end_date": end,
        "prices": price_count,
        "news": news_count,
        "aligned": aligned_count,
    }
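
Note: a worked example of the watermark arithmetic above. The next incremental window starts the day after the stored watermark, and an inverted window is skipped rather than queried:

    from datetime import datetime, timedelta

    watermark = "2024-06-07"  # illustrative stored last_price_fetch
    start_prices = (datetime.fromisoformat(watermark) + timedelta(days=1)).date().isoformat()
    assert start_prices == "2024-06-08"

    end = "2024-06-05"  # an end date behind the watermark...
    assert start_prices > end  # ...so fetch_ohlc is skipped and counts stay at 0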

def ingest_symbols(
    symbols: Iterable[str],
    *,
    mode: str = "incremental",
    start_date: str | None = None,
    end_date: str | None = None,
    store: MarketStore | None = None,
) -> list[dict]:
    """Fetch Polygon data for a list of tickers."""
    market_store = store or MarketStore()
    results = []
    for symbol in symbols:
        ticker = normalize_symbol(symbol)
        if not ticker:
            continue
        if mode == "full":
            results.append(
                ingest_ticker_history(
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    store=market_store,
                )
            )
        else:
            results.append(
                update_ticker_incremental(
                    ticker,
                    end_date=end_date,
                    store=market_store,
                )
            )
    return results

backend/data/market_store.py (new file, 1074 lines): diff suppressed because it is too large

backend/data/news_alignment.py (new file, 64 lines)

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
"""Align persisted news to the nearest NYSE trading date."""

from __future__ import annotations

from datetime import time

import pandas as pd
import pandas_market_calendars as mcal

from backend.data.market_store import MarketStore


NYSE_CALENDAR = mcal.get_calendar("NYSE")

def _next_trading_day(date_str: str) -> str:
    start = pd.Timestamp(date_str).tz_localize(None)
    sessions = NYSE_CALENDAR.valid_days(
        start_date=(start - pd.Timedelta(days=1)).strftime("%Y-%m-%d"),
        end_date=(start + pd.Timedelta(days=10)).strftime("%Y-%m-%d"),
    )
    future = [
        pd.Timestamp(day).tz_localize(None).strftime("%Y-%m-%d")
        for day in sessions
        if pd.Timestamp(day).tz_localize(None) >= start
    ]
    return future[0] if future else date_str

def resolve_trade_date(published_utc: str | None) -> str | None:
    """Map a published timestamp to an NYSE trade date."""
    if not published_utc:
        return None
    timestamp = pd.to_datetime(published_utc, utc=True, errors="coerce")
    if pd.isna(timestamp):
        return None
    nyse_time = timestamp.tz_convert("America/New_York")
    candidate = nyse_time.date().isoformat()
    valid_days = NYSE_CALENDAR.valid_days(start_date=candidate, end_date=candidate)
    if len(valid_days) == 0:
        return _next_trading_day(candidate)
    if nyse_time.time() >= time(16, 0):
        return _next_trading_day((nyse_time + pd.Timedelta(days=1)).date().isoformat())
    return candidate
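
Note: the rule above maps intra-session publications to that trading day and rolls anything at or after the 16:00 New York close (plus weekends and holidays) to the next session. A worked sketch, assuming 2024-06-07 is a Friday session:

    from backend.data.news_alignment import resolve_trade_date

    # 20:30 UTC is 16:30 in New York (after the close); Saturday is skipped:
    print(resolve_trade_date("2024-06-07T20:30:00Z"))  # -> "2024-06-10" (Monday)

    # 15:00 UTC is 11:00 in New York (mid-session), so Friday keeps the article:
    print(resolve_trade_date("2024-06-07T15:00:00Z"))  # -> "2024-06-07"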

def align_news_for_symbol(store: MarketStore, symbol: str, *, limit: int = 5000) -> int:
    """Fill missing trade_date values for one ticker."""
    pending = store.get_news_without_trade_date(symbol, limit=limit)
    updates = []
    for row in pending:
        trade_date = resolve_trade_date(row.get("published_utc"))
        if trade_date:
            updates.append(
                {
                    "news_id": row["news_id"],
                    "symbol": row["symbol"],
                    "trade_date": trade_date,
                }
            )
    if not updates:
        return 0
    return store.set_trade_dates(updates)

backend/data/polygon_client.py (new file, 161 lines)

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""Polygon client used for long-lived market research ingestion."""

from __future__ import annotations

import os
import time
from datetime import datetime, timezone
from typing import Any, Optional

import requests


BASE = "https://api.polygon.io"


def _headers() -> dict[str, str]:
    api_key = os.getenv("POLYGON_API_KEY", "").strip()
    if not api_key:
        raise ValueError("Missing required API key: POLYGON_API_KEY")
    return {"Authorization": f"Bearer {api_key}"}

def http_get(
    url: str,
    params: Optional[dict[str, Any]] = None,
    *,
    max_retries: int = 8,
    backoff: float = 2.0,
) -> requests.Response:
    """HTTP GET with exponential backoff and 429 handling."""
    for attempt in range(max_retries):
        try:
            response = requests.get(
                url,
                params=params or {},
                headers=_headers(),
                timeout=30,
            )
        except requests.RequestException:
            time.sleep((backoff**attempt) + 0.5)
            if attempt == max_retries - 1:
                raise
            continue

        if response.status_code == 429:
            retry_after = response.headers.get("Retry-After")
            wait = (
                float(retry_after)
                if retry_after and retry_after.isdigit()
                else min((backoff**attempt) + 1.0, 60.0)
            )
            time.sleep(wait)
            if attempt == max_retries - 1:
                response.raise_for_status()
            continue

        if 500 <= response.status_code < 600:
            time.sleep(min((backoff**attempt) + 1.0, 60.0))
            if attempt == max_retries - 1:
                response.raise_for_status()
            continue

        response.raise_for_status()
        return response
    raise RuntimeError("Unreachable")
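
Note: with the default `backoff=2.0`, the 429 path without a usable Retry-After header waits roughly min(2^attempt + 1, 60) seconds, i.e. 2, 3, 5, 9, 17, 33, 60, 60 seconds across the eight attempts. A quick check of that schedule:

    backoff = 2.0
    waits = [min((backoff**attempt) + 1.0, 60.0) for attempt in range(8)]
    print(waits)  # [2.0, 3.0, 5.0, 9.0, 17.0, 33.0, 60.0, 60.0]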

def fetch_ticker_details(symbol: str) -> dict[str, Any]:
    """Fetch company metadata from Polygon."""
    response = http_get(f"{BASE}/v3/reference/tickers/{symbol}")
    return response.json().get("results", {}) or {}

def fetch_ohlc(symbol: str, start_date: str, end_date: str) -> list[dict[str, Any]]:
    """Fetch daily OHLC data from Polygon."""
    response = http_get(
        f"{BASE}/v2/aggs/ticker/{symbol}/range/1/day/{start_date}/{end_date}",
        params={"adjusted": "true", "sort": "asc", "limit": 50000},
    )
    results = response.json().get("results") or []
    rows: list[dict[str, Any]] = []
    for item in results:
        rows.append(
            {
                "date": datetime.fromtimestamp(
                    int(item["t"]) / 1000,
                    tz=timezone.utc,
                ).date().isoformat(),
                "open": item.get("o"),
                "high": item.get("h"),
                "low": item.get("l"),
                "close": item.get("c"),
                "volume": item.get("v"),
                "vwap": item.get("vw"),
                "transactions": item.get("n"),
            }
        )
    return rows

def fetch_news(
    symbol: str,
    start_date: str,
    end_date: str,
    *,
    per_page: int = 50,
    page_sleep: float = 1.2,
    max_pages: Optional[int] = None,
) -> list[dict[str, Any]]:
    """Fetch all Polygon news for a ticker, with pagination."""
    url = f"{BASE}/v2/reference/news"
    params = {
        "ticker": symbol,
        "published_utc.gte": start_date,
        "published_utc.lte": end_date,
        "limit": per_page,
        "order": "asc",
    }
    next_url: Optional[str] = None
    pages = 0
    all_articles: list[dict[str, Any]] = []
    seen_ids: set[str] = set()

    while True:
        response = http_get(next_url or url, params=None if next_url else params)
        data = response.json()
        results = data.get("results") or []
        if not results:
            break

        for item in results:
            article_id = item.get("id")
            if article_id and article_id in seen_ids:
                continue
            all_articles.append(
                {
                    "id": article_id,
                    "publisher": (item.get("publisher") or {}).get("name"),
                    "title": item.get("title"),
                    "author": item.get("author"),
                    "published_utc": item.get("published_utc"),
                    "amp_url": item.get("amp_url"),
                    "article_url": item.get("article_url"),
                    "tickers": item.get("tickers"),
                    "description": item.get("description"),
                    "insights": item.get("insights"),
                }
            )
            if article_id:
                seen_ids.add(article_id)

        next_url = data.get("next_url")
        pages += 1
        if max_pages is not None and pages >= max_pages:
            break
        if not next_url:
            break
        time.sleep(page_sleep)

    return all_articles

backend/enrich/__init__.py (new file, 2 lines)

@@ -0,0 +1,2 @@
"""News enrichment utilities for explain-oriented market research."""

backend/enrich/llm_enricher.py (new file, 296 lines)

@@ -0,0 +1,296 @@
# -*- coding: utf-8 -*-
"""Optional AgentScope-backed news enrichment with safe local fallback."""

from __future__ import annotations

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any

from pydantic import BaseModel, Field

from backend.config.env_config import canonicalize_model_provider, get_env_bool, get_env_str
from backend.llm.models import create_model

class EnrichedNewsItem(BaseModel):
    """Structured output schema for one enriched article."""

    id: str = Field(description="The source article id")
    relevance: str = Field(description="One of high, medium, low")
    sentiment: str = Field(description="One of positive, negative, neutral")
    key_discussion: str = Field(description="Concise core discussion")
    summary: str = Field(description="Concise factual summary")
    reason_growth: str = Field(description="Growth-oriented reason if present")
    reason_decrease: str = Field(description="Downside-oriented reason if present")


class EnrichedNewsBatch(BaseModel):
    """Structured output schema for a batch of enriched articles."""

    items: list[EnrichedNewsItem]


class RangeAnalysisPayload(BaseModel):
    """Structured output schema for range explanation text."""

    summary: str = Field(description="Concise Chinese range summary for the selected window")
    trend_analysis: str = Field(description="Concise Chinese trend explanation for the selected window")
    bullish_factors: list[str] = Field(description="Top bullish factors in Chinese")
    bearish_factors: list[str] = Field(description="Top bearish factors in Chinese")

def get_explain_model_info() -> dict[str, str]:
    """Resolve provider/model used by explain enrichment."""
    provider = canonicalize_model_provider(
        get_env_str("EXPLAIN_ENRICH_MODEL_PROVIDER")
        or get_env_str("MODEL_PROVIDER", "OPENAI"),
    )
    model_name = get_env_str("EXPLAIN_ENRICH_MODEL_NAME") or get_env_str(
        "MODEL_NAME",
        "gpt-4o-mini",
    )
    return {
        "provider": provider,
        "model_name": model_name,
        "label": f"{provider}:{model_name}",
    }

def _normalize_enrichment_payload(payload: Any) -> dict[str, Any] | None:
    if isinstance(payload, BaseModel):
        payload = payload.model_dump()
    if not isinstance(payload, dict):
        return None
    return {
        "relevance": str(payload.get("relevance") or "").strip().lower() or None,
        "sentiment": str(payload.get("sentiment") or "").strip().lower() or None,
        "key_discussion": str(payload.get("key_discussion") or "").strip() or None,
        "summary": str(payload.get("summary") or "").strip() or None,
        "reason_growth": str(payload.get("reason_growth") or "").strip() or None,
        "reason_decrease": str(payload.get("reason_decrease") or "").strip() or None,
        "raw_json": payload,
    }

def _run_async(coro: Any) -> Any:
    """Run an async AgentScope model call from sync code, even inside a running loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)

    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(asyncio.run, coro)
        return future.result()
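
Note: the thread-pool detour exists because `asyncio.run` raises RuntimeError on a thread that already has a running loop (for example inside the API server); a fresh worker thread gets its own loop. A minimal sketch of both call paths (the coroutine is illustrative):

    import asyncio

    async def _demo() -> str:
        await asyncio.sleep(0)
        return "ok"

    # Plain sync caller: no running loop, so asyncio.run is used directly.
    print(_run_async(_demo()))  # -> "ok"

    async def _caller() -> str:
        # Async caller: a loop is already running, so the coroutine is
        # executed on a one-off worker thread with its own loop.
        return _run_async(_demo())

    print(asyncio.run(_caller()))  # -> "ok"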

def _get_explain_model():
    """Create an AgentScope model for explain enrichment."""
    model_info = get_explain_model_info()
    return create_model(
        model_name=model_info["model_name"],
        provider=model_info["provider"],
        stream=False,
        generate_kwargs={"temperature": 0.1},
    )

def llm_enrichment_enabled() -> bool:
    """Return whether AgentScope-backed LLM enrichment should be attempted."""
    if not get_env_bool("EXPLAIN_ENRICH_USE_LLM", False):
        return False
    provider = get_explain_model_info()["provider"]
    provider_key_map = {
        "OPENAI": "OPENAI_API_KEY",
        "ANTHROPIC": "ANTHROPIC_API_KEY",
        "DASHSCOPE": "DASHSCOPE_API_KEY",
        "ALIBABA": "DASHSCOPE_API_KEY",
        "GEMINI": "GOOGLE_API_KEY",
        "GOOGLE": "GOOGLE_API_KEY",
        "DEEPSEEK": "DEEPSEEK_API_KEY",
        "GROQ": "GROQ_API_KEY",
        "OPENROUTER": "OPENROUTER_API_KEY",
    }
    env_key = provider_key_map.get(provider)
    return bool(get_env_str(env_key)) if env_key else provider == "OLLAMA"
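
Note: enrichment therefore stays off unless explicitly opted in and the selected provider's key is present (Ollama is the keyless exception). An illustrative configuration, expressed the way the checks above read it (values are examples, not defaults shipped by this commit):

    import os

    os.environ["EXPLAIN_ENRICH_USE_LLM"] = "true"
    os.environ["EXPLAIN_ENRICH_MODEL_PROVIDER"] = "OPENAI"   # falls back to MODEL_PROVIDER
    os.environ["EXPLAIN_ENRICH_MODEL_NAME"] = "gpt-4o-mini"  # falls back to MODEL_NAME
    os.environ["OPENAI_API_KEY"] = "sk-..."                  # placeholder key

    print(llm_enrichment_enabled())  # -> True with the values above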

def llm_range_analysis_enabled() -> bool:
    """Return whether LLM range analysis should be attempted."""
    raw_value = get_env_str("EXPLAIN_RANGE_USE_LLM")
    if raw_value is not None and str(raw_value).strip() != "":
        return get_env_bool("EXPLAIN_RANGE_USE_LLM", False) and llm_enrichment_enabled()
    return llm_enrichment_enabled()

def analyze_news_row_with_llm(row: dict[str, Any]) -> dict[str, Any] | None:
    """Generate explain-oriented structured analysis for one article."""
    if not llm_enrichment_enabled():
        return None

    model = _get_explain_model()
    title = str(row.get("title") or "").strip()
    summary = str(row.get("summary") or "").strip()
    messages = [
        {
            "role": "system",
            "content": (
                "You produce concise structured financial news analysis. "
                "Use only the requested fields and keep content factual."
            ),
        },
        {
            "role": "user",
            "content": (
                "Analyze this stock-news article for an explain UI.\n"
                "Rules:\n"
                "- relevance must be one of: high, medium, low\n"
                "- sentiment must be one of: positive, negative, neutral\n"
                "- keep each text field concise and factual\n"
                f"- article id: {str(row.get('id') or '').strip()}\n"
                f"Title: {title}\n"
                f"Summary: {summary}\n"
            ),
        },
    ]
    try:
        response = _run_async(model(messages=messages, structured_model=EnrichedNewsItem))
    except Exception:
        return None

    payload = _normalize_enrichment_payload(getattr(response, "metadata", None))
    if payload:
        model_info = get_explain_model_info()
        payload.setdefault("raw_json", {})
        payload["raw_json"]["model_provider"] = model_info["provider"]
        payload["raw_json"]["model_name"] = model_info["model_name"]
        payload["raw_json"]["model_label"] = model_info["label"]
    return payload

def analyze_news_rows_with_llm(rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Generate structured analysis for multiple articles in one request."""
    if not llm_enrichment_enabled() or not rows:
        return {}

    payload_rows = [
        {
            "id": str(row.get("id") or "").strip(),
            "title": str(row.get("title") or "").strip(),
            "summary": str(row.get("summary") or "").strip(),
        }
        for row in rows
        if str(row.get("id") or "").strip()
    ]
    if not payload_rows:
        return {}

    model = _get_explain_model()
    messages = [
        {
            "role": "system",
            "content": (
                "You produce concise structured financial news analysis in JSON. "
                "Preserve ids exactly and do not invent extra items."
            ),
        },
        {
            "role": "user",
            "content": (
                "Analyze these stock-news articles for an explain UI.\n"
                "For each item return: id, relevance, sentiment, key_discussion, summary, "
                "reason_growth, reason_decrease.\n"
                "Rules:\n"
                "- relevance must be one of: high, medium, low\n"
                "- sentiment must be one of: positive, negative, neutral\n"
                "- keep all text concise and factual\n"
                f"Articles: {payload_rows}"
            ),
        },
    ]
    try:
        response = _run_async(
            model(messages=messages, structured_model=EnrichedNewsBatch),
        )
    except Exception:
        return {}

    metadata = getattr(response, "metadata", None)
    if isinstance(metadata, BaseModel):
        metadata = metadata.model_dump()
    items = metadata.get("items") if isinstance(metadata, dict) else None
    if not isinstance(items, list):
        return {}

    model_info = get_explain_model_info()
    results: dict[str, dict[str, Any]] = {}
    for item in items:
        normalized = _normalize_enrichment_payload(item)
        item_data = item.model_dump() if isinstance(item, BaseModel) else item
        news_id = str(item_data.get("id") or "").strip() if isinstance(item_data, dict) else ""
        if normalized and news_id:
            normalized.setdefault("raw_json", {})
            normalized["raw_json"]["model_provider"] = model_info["provider"]
            normalized["raw_json"]["model_name"] = model_info["model_name"]
            normalized["raw_json"]["model_label"] = model_info["label"]
            results[news_id] = normalized
    return results

def analyze_range_with_llm(payload: dict[str, Any]) -> dict[str, Any] | None:
    """Generate explain-oriented range summary and factor refinement."""
    if not llm_range_analysis_enabled():
        return None

    model = _get_explain_model()
    messages = [
        {
            "role": "system",
            "content": (
                "You write concise Chinese stock range analysis for an explain UI. "
                "Use only the supplied facts. Keep the tone factual and analyst-like."
            ),
        },
        {
            "role": "user",
            # The Simplified-Chinese prompt below asks for: a range analysis built
            # only from the supplied facts, with output fields summary,
            # trend_analysis, bullish_factors, bearish_factors; summary in 1-2
            # sentences covering trend, news density, and dominant themes;
            # trend_analysis in 1 sentence on phase changes within the range; at
            # most 3 short items per factor list; and no invented information.
            "content": (
                "请基于给定事实生成区间分析。\n"
                "输出字段:summary, trend_analysis, bullish_factors, bearish_factors。\n"
                "要求:\n"
                "- 全部使用简体中文\n"
                "- summary 1到2句,概括区间走势、新闻密度和主导主题\n"
                "- trend_analysis 1句,解释区间内部阶段变化\n"
                "- bullish_factors 和 bearish_factors 各返回最多3条短句\n"
                "- 不要编造未提供的信息\n"
                f"事实数据: {payload}"
            ),
        },
    ]
    try:
        response = _run_async(
            model(messages=messages, structured_model=RangeAnalysisPayload),
        )
    except Exception:
        return None

    metadata = getattr(response, "metadata", None)
    if isinstance(metadata, BaseModel):
        metadata = metadata.model_dump()
    if not isinstance(metadata, dict):
        return None

    model_info = get_explain_model_info()
    return {
        "summary": str(metadata.get("summary") or "").strip() or None,
        "trend_analysis": str(metadata.get("trend_analysis") or "").strip() or None,
        "bullish_factors": [
            str(item).strip()
            for item in list(metadata.get("bullish_factors") or [])
            if str(item).strip()
        ][:3],
        "bearish_factors": [
            str(item).strip()
            for item in list(metadata.get("bearish_factors") or [])
            if str(item).strip()
        ][:3],
        "model_provider": model_info["provider"],
        "model_name": model_info["model_name"],
        "model_label": model_info["label"],
    }

backend/enrich/news_enricher.py (new file, 362 lines)

@@ -0,0 +1,362 @@
# -*- coding: utf-8 -*-
"""Lightweight news enrichment for explain-oriented market analysis."""

from __future__ import annotations

import hashlib
from typing import Any

from backend.config.env_config import get_env_int
from backend.enrich.llm_enricher import (
    analyze_news_row_with_llm,
    analyze_news_rows_with_llm,
    llm_enrichment_enabled,
)
from backend.data.market_store import MarketStore


POSITIVE_KEYWORDS = (
    "beat", "surge", "gain", "growth", "record", "upgrade", "strong",
    "partnership", "approved", "launch", "expands", "profit",
)
NEGATIVE_KEYWORDS = (
    "miss", "drop", "fall", "cut", "downgrade", "weak", "warning",
    "delay", "lawsuit", "probe", "tariff", "decline", "layoff",
)
HIGH_RELEVANCE_KEYWORDS = (
    "earnings", "guidance", "profit", "revenue", "ceo", "fda", "tariff",
    "regulation", "acquisition", "buyback", "forecast", "launch",
)

def _dedupe_key(row: dict[str, Any]) -> str:
    trade_date = str(row.get("trade_date") or row.get("date") or "")[:10]
    title = str(row.get("title") or "").strip().lower()
    summary = str(row.get("summary") or "").strip().lower()[:160]
    raw = f"{trade_date}::{title}::{summary}"
    return hashlib.sha1(raw.encode("utf-8")).hexdigest()
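
Note: the dedupe key deliberately ignores the article id, so two wire copies of one story on the same trade date with matching title and summary prefix collapse to a single analysis row. A worked check mirroring `_dedupe_key` (sample rows illustrative):

    import hashlib

    def demo_key(trade_date: str, title: str, summary: str) -> str:
        # Same recipe: date + lowercased title + first 160 lowercased summary chars.
        raw = f"{trade_date}::{title.strip().lower()}::{summary.strip().lower()[:160]}"
        return hashlib.sha1(raw.encode("utf-8")).hexdigest()

    a = demo_key("2024-06-10", "Acme beats earnings", "Strong quarter.")
    b = demo_key("2024-06-10", "ACME Beats Earnings", "Strong quarter.")
    assert a == b  # case-only duplicates hash identically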

def _chunk_rows(rows: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]:
    chunk_size = max(1, int(size))
    return [rows[index:index + chunk_size] for index in range(0, len(rows), chunk_size)]

def classify_news_row(row: dict[str, Any]) -> dict[str, Any]:
    """Return a lightweight explain-oriented analysis for one article."""
    llm_result = analyze_news_row_with_llm(row)
    if isinstance(llm_result, dict):
        merged = dict(llm_result)
        merged.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
        merged.setdefault("raw_json", row)
        merged["analysis_source"] = "llm"
        return merged

    title = str(row.get("title") or "").strip()
    summary = str(row.get("summary") or "").strip()
    text = f"{title} {summary}".lower()

    positive_hits = [keyword for keyword in POSITIVE_KEYWORDS if keyword in text]
    negative_hits = [keyword for keyword in NEGATIVE_KEYWORDS if keyword in text]
    relevance_hits = [keyword for keyword in HIGH_RELEVANCE_KEYWORDS if keyword in text]

    if len(positive_hits) > len(negative_hits):
        sentiment = "positive"
    elif len(negative_hits) > len(positive_hits):
        sentiment = "negative"
    else:
        sentiment = "neutral"

    relevance = "high" if relevance_hits else "medium" if title else "low"
    summary_text = summary or title
    key_discussion = ""
    if relevance_hits:
        # Chinese UI string: "core themes center on <keywords>"
        key_discussion = f"核心主题集中在 {', '.join(relevance_hits[:3])}"
    elif summary_text:
        key_discussion = summary_text[:160]

    reason_growth = ""
    reason_decrease = ""
    if sentiment == "positive":
        reason_growth = summary_text[:200]
    elif sentiment == "negative":
        reason_decrease = summary_text[:200]

    return {
        "relevance": relevance,
        "sentiment": sentiment,
        "key_discussion": key_discussion,
        "summary": summary_text[:280],
        "reason_growth": reason_growth,
        "reason_decrease": reason_decrease,
        "analysis_source": "local",
        "raw_json": row,
    }

def attach_forward_returns(
    *,
    news_rows: list[dict[str, Any]],
    ohlc_rows: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Attach forward-return labels to each analyzed row."""
    if not ohlc_rows:
        return news_rows

    closes_by_date = {
        str(row.get("date")): float(row.get("close"))
        for row in ohlc_rows
        if row.get("date") is not None and row.get("close") is not None
    }
    ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date") is not None]
    date_index = {date: idx for idx, date in enumerate(ordered_dates)}

    horizons = {
        "ret_t0": 0,
        "ret_t1": 1,
        "ret_t3": 3,
        "ret_t5": 5,
        "ret_t10": 10,
    }

    enriched: list[dict[str, Any]] = []
    for row in news_rows:
        trade_date = str(row.get("trade_date") or "")[:10]
        base_close = closes_by_date.get(trade_date)
        if not trade_date or base_close in (None, 0):
            enriched.append(row)
            continue

        next_row = dict(row)
        base_index = date_index.get(trade_date)
        if base_index is None:
            enriched.append(next_row)
            continue

        for field, offset in horizons.items():
            target_index = base_index + offset
            if target_index >= len(ordered_dates):
                next_row[field] = None
                continue
            target_close = closes_by_date.get(ordered_dates[target_index])
            next_row[field] = (
                (float(target_close) - float(base_close)) / float(base_close)
                if target_close not in (None, 0)
                else None
            )
        enriched.append(next_row)
    return enriched
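
Note: the horizons are trading-day offsets into the OHLC series, not calendar days, and each label is a simple fractional move off the trade-date close. A worked example: closes of 100.0 on the trade date and 103.0 five sessions later yield ret_t5 = 0.03.

    base_close, target_close = 100.0, 103.0
    ret_t5 = (target_close - base_close) / base_close
    assert abs(ret_t5 - 0.03) < 1e-12  # a +3% move over five trading sessions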

def build_analysis_rows(
    *,
    symbol: str,
    news_rows: list[dict[str, Any]],
    ohlc_rows: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], dict[str, int]]:
    """Transform raw news rows into market_store news_analysis payloads plus stats."""
    llm_results: dict[str, dict[str, Any]] = {}
    if llm_enrichment_enabled():
        batch_size = get_env_int("EXPLAIN_ENRICH_BATCH_SIZE", 8)
        for chunk in _chunk_rows(news_rows, batch_size):
            llm_results.update(analyze_news_rows_with_llm(chunk))

    staged_rows: list[dict[str, Any]] = []
    seen_dedupe_keys: set[str] = set()
    deduped_count = 0
    llm_count = 0
    local_count = 0
    for row in news_rows:
        news_id = str(row.get("id") or "").strip()
        if not news_id:
            continue
        dedupe_key = _dedupe_key(row)
        if dedupe_key in seen_dedupe_keys:
            deduped_count += 1
            continue
        seen_dedupe_keys.add(dedupe_key)
        batch_result = llm_results.get(news_id)
        if isinstance(batch_result, dict):
            analysis = dict(batch_result)
            analysis.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
            analysis.setdefault("raw_json", row)
            analysis["analysis_source"] = "llm"
            llm_count += 1
        else:
            analysis = classify_news_row(row)
            if analysis.get("analysis_source") == "llm":
                llm_count += 1
            else:
                local_count += 1
        staged_rows.append(
            {
                "news_id": news_id,
                "trade_date": str(row.get("trade_date") or "")[:10] or None,
                **analysis,
            }
        )
    return (
        attach_forward_returns(news_rows=staged_rows, ohlc_rows=ohlc_rows),
        {
            "deduped_count": deduped_count,
            "llm_count": llm_count,
            "local_count": local_count,
        },
    )

def enrich_news_for_symbol(
    store: MarketStore,
    symbol: str,
    *,
    start_date: str | None = None,
    end_date: str | None = None,
    limit: int = 200,
    analysis_source: str = "local",
    skip_existing: bool = True,
    only_reanalyze_local: bool = False,
) -> dict[str, Any]:
    """Read raw market news, compute explain fields, and persist them."""
    normalized_symbol = str(symbol or "").strip().upper()
    if not normalized_symbol:
        return {"symbol": "", "analyzed": 0}

    news_rows = store.get_news_items(
        normalized_symbol,
        start_date=start_date,
        end_date=end_date,
        limit=limit,
    )
    total_news_count = len(news_rows)
    skipped_existing_count = 0
    analyzed_sources: dict[str, str] = {}
    skipped_missing_analysis_count = 0
    skipped_non_local_count = 0
    if news_rows and only_reanalyze_local:
        analyzed_sources = store.get_analyzed_news_sources(
            normalized_symbol,
            start_date=start_date,
            end_date=end_date,
        )
        skipped_missing_analysis_count = sum(
            1
            for row in news_rows
            if str(row.get("id") or "").strip() not in analyzed_sources
        )
        skipped_non_local_count = sum(
            1
            for row in news_rows
            if str(row.get("id") or "").strip() in analyzed_sources
            and analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
        )
        skipped_existing_count = sum(
            1
            for row in news_rows
            if str(row.get("id") or "").strip() not in analyzed_sources
            or analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
        )
        news_rows = [
            row for row in news_rows
            if analyzed_sources.get(str(row.get("id") or "").strip()) == "local"
        ]
    elif skip_existing and news_rows:
        analyzed_ids = store.get_analyzed_news_ids(
            normalized_symbol,
            start_date=start_date,
            end_date=end_date,
        )
        skipped_existing_count = sum(
            1
            for row in news_rows
            if str(row.get("id") or "").strip() in analyzed_ids
        )
        news_rows = [
            row for row in news_rows
            if str(row.get("id") or "").strip() not in analyzed_ids
        ]
    ohlc_start = start_date or (
        news_rows[-1]["trade_date"] if news_rows and news_rows[-1].get("trade_date") else None
    )
    ohlc_end = end_date or (
        news_rows[0]["trade_date"] if news_rows and news_rows[0].get("trade_date") else None
    )
    ohlc_rows = (
        store.get_ohlc(normalized_symbol, ohlc_start, ohlc_end)
        if ohlc_start and ohlc_end
        else []
    )
    analysis_rows, stats = build_analysis_rows(
        symbol=normalized_symbol,
        news_rows=news_rows,
        ohlc_rows=ohlc_rows,
    )
    analyzed = store.upsert_news_analysis(
        normalized_symbol,
        analysis_rows,
        analysis_source=analysis_source,
    )
    upgraded_dates = sorted(
        {
            str(row.get("trade_date") or "")[:10]
            for row in analysis_rows
            if str(row.get("analysis_source") or "").strip().lower() == "llm"
            and str(row.get("trade_date") or "").strip()
        }
    )
    remaining_local_titles = [
        str(row.get("title") or row.get("news_id") or "").strip()
        for row in news_rows
        for analyzed_row in analysis_rows
        if str(analyzed_row.get("news_id") or "").strip() == str(row.get("id") or "").strip()
        and str(analyzed_row.get("analysis_source") or "").strip().lower() == "local"
    ][:5]
    return {
        "symbol": normalized_symbol,
        "analyzed": analyzed,
        "news_count": total_news_count,
        "queued_count": len(news_rows),
        "skipped_existing_count": skipped_existing_count,
        "deduped_count": stats["deduped_count"],
        "llm_count": stats["llm_count"],
        "local_count": stats["local_count"],
        "only_reanalyze_local": only_reanalyze_local,
        "upgraded_local_to_llm_count": (
            stats["llm_count"]
            if only_reanalyze_local
            else 0
        ),
        "execution_summary": {
            "upgraded_dates": upgraded_dates[:5],
            "remaining_local_titles": remaining_local_titles,
            "skipped_missing_analysis_count": skipped_missing_analysis_count,
            "skipped_non_local_count": skipped_non_local_count,
        },
    }

def enrich_symbols(
    store: MarketStore,
    symbols: list[str],
    *,
    start_date: str | None = None,
    end_date: str | None = None,
    limit: int = 200,
    analysis_source: str = "local",
    skip_existing: bool = True,
    only_reanalyze_local: bool = False,
) -> list[dict[str, Any]]:
    """Batch enrich multiple symbols for explain-oriented news analysis."""
    results = []
    for symbol in symbols:
        normalized_symbol = str(symbol or "").strip().upper()
        if not normalized_symbol:
            continue
        results.append(
            enrich_news_for_symbol(
                store,
                normalized_symbol,
                start_date=start_date,
                end_date=end_date,
                limit=limit,
                analysis_source=analysis_source,
                skip_existing=skip_existing,
                only_reanalyze_local=only_reanalyze_local,
            )
        )
    return results

backend/explain/__init__.py (new file, 2 lines)

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""Explain-oriented services for stock narratives and news research."""
69
backend/explain/category_engine.py
Normal file
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""Rule-based news categorization for explain UI."""

from __future__ import annotations

from typing import Any, Dict, Iterable


CATEGORY_KEYWORDS = {
    "market": [
        "market", "stock", "rally", "sell-off", "selloff", "trading",
        "wall street", "s&p", "nasdaq", "dow", "index", "bull", "bear",
        "correction", "volatility",
    ],
    "policy": [
        "regulation", "fed", "federal reserve", "tariff", "sanction",
        "interest rate", "policy", "government", "congress", "sec",
        "trade war", "ban", "legislation", "tax",
    ],
    "earnings": [
        "earnings", "revenue", "profit", "quarter", "eps", "guidance",
        "forecast", "income", "sales", "beat", "miss", "outlook",
        "financial results",
    ],
    "product_tech": [
        "product", "ai", "chip", "cloud", "launch", "patent",
        "technology", "innovation", "release", "platform", "model",
        "software", "hardware", "gpu", "autonomous",
    ],
    "competition": [
        "competitor", "rival", "market share", "overtake", "compete",
        "competition", "vs", "versus", "battle", "challenge",
    ],
    "management": [
        "ceo", "executive", "resign", "layoff", "restructure",
        "management", "leadership", "appoint", "hire", "board",
        "chairman",
    ],
}


def categorize_news_rows(rows: Iterable[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Bucket news rows by keyword categories."""
    categories: Dict[str, Dict[str, Any]] = {
        key: {
            "label": key,
            "count": 0,
            "article_ids": [],
        }
        for key in CATEGORY_KEYWORDS
    }

    for row in rows:
        text = " ".join(
            [
                str(row.get("title") or ""),
                str(row.get("summary") or ""),
                str(row.get("related") or ""),
                str(row.get("category") or ""),
            ]
        ).lower()
        article_id = row.get("id")
        for category, keywords in CATEGORY_KEYWORDS.items():
            if any(keyword in text for keyword in keywords):
                categories[category]["count"] += 1
                if article_id:
                    categories[category]["article_ids"].append(article_id)

    return categories
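A quick sanity check of the keyword buckets — a minimal sketch; the two sample rows are invented and only exercise the matching above:

from backend.explain.category_engine import categorize_news_rows

rows = [
    {"id": "n1", "title": "Fed signals rate cut", "summary": "", "related": "", "category": ""},
    {"id": "n2", "title": "Q3 earnings beat guidance", "summary": "", "related": "", "category": ""},
]
buckets = categorize_news_rows(rows)
# "fed" matches the policy list; "earnings"/"guidance"/"beat" match earnings.
print({key: value["count"] for key, value in buckets.items() if value["count"]})
# -> {'policy': 1, 'earnings': 1}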
214
backend/explain/range_explainer.py
Normal file
@@ -0,0 +1,214 @@
# -*- coding: utf-8 -*-
"""Local range explanation built from price and persisted news."""

from __future__ import annotations

from typing import Any, Dict

from backend.enrich.llm_enricher import analyze_range_with_llm
from backend.explain.category_engine import categorize_news_rows
from backend.tools.data_tools import get_prices


def _rank_event_score(row: Dict[str, Any]) -> float:
    relevance = str(row.get("relevance") or "").strip().lower()
    relevance_score = {"high": 3.0, "relevant": 3.0, "medium": 2.0, "low": 1.0}.get(
        relevance,
        0.5,
    )
    impact_score = abs(float(row.get("ret_t0") or 0.0)) * 100
    return relevance_score + impact_score


def summarize_bullish_factors(
    news_rows: list[Dict[str, Any]],
    *,
    limit: int = 5,
) -> list[str]:
    factors = []
    for row in news_rows:
        if str(row.get("sentiment") or "").strip().lower() != "positive":
            continue
        candidate = row.get("reason_growth") or row.get("key_discussion") or row.get("summary") or row.get("title")
        if candidate:
            factors.append(str(candidate).strip())
    seen = set()
    output = []
    for factor in factors:
        if factor in seen:
            continue
        seen.add(factor)
        output.append(factor[:200])
        if len(output) >= limit:
            break
    return output


def summarize_bearish_factors(
    news_rows: list[Dict[str, Any]],
    *,
    limit: int = 5,
) -> list[str]:
    factors = []
    for row in news_rows:
        if str(row.get("sentiment") or "").strip().lower() != "negative":
            continue
        candidate = row.get("reason_decrease") or row.get("key_discussion") or row.get("summary") or row.get("title")
        if candidate:
            factors.append(str(candidate).strip())
    seen = set()
    output = []
    for factor in factors:
        if factor in seen:
            continue
        seen.add(factor)
        output.append(factor[:200])
        if len(output) >= limit:
            break
    return output


def build_trend_analysis(prices: list[Any]) -> str:
    if len(prices) < 2:
        return "The range sample is too short to derive meaningful trend information."
    if len(prices) < 3:
        open_price = float(prices[0].open)
        close_price = float(prices[-1].close)
        change = ((close_price - open_price) / open_price) * 100 if open_price else 0.0
        return f"Price moved {change:+.2f}% over this short range; trend information is limited."

    mid = len(prices) // 2
    first_open = float(prices[0].open)
    first_close = float(prices[mid].close)
    second_open = float(prices[mid].open)
    second_close = float(prices[-1].close)
    first_half = ((first_close - first_open) / first_open) * 100 if first_open else 0.0
    second_half = ((second_close - second_open) / second_open) * 100 if second_open else 0.0
    return (
        f"The first half {'rose' if first_half >= 0 else 'fell'} {abs(first_half):.2f}% "
        f"and the second half {'rose' if second_half >= 0 else 'fell'} {abs(second_half):.2f}%, "
        "indicating the price driver shifted partway through the range."
    )


def build_range_explanation(
    *,
    ticker: str,
    start_date: str,
    end_date: str,
    news_rows: list[Dict[str, Any]],
) -> Dict[str, Any]:
    """Explain a price range with local price and news heuristics."""
    prices = get_prices(ticker, start_date, end_date)
    if not prices:
        return {
            "symbol": ticker,
            "start_date": start_date,
            "end_date": end_date,
            "error": "No OHLC data for this range",
        }

    open_price = float(prices[0].open)
    close_price = float(prices[-1].close)
    high_price = max(float(price.high) for price in prices)
    low_price = min(float(price.low) for price in prices)
    total_volume = sum(int(price.volume) for price in prices)
    price_change_pct = (
        ((close_price - open_price) / open_price) * 100 if open_price else 0.0
    )

    categories = categorize_news_rows(news_rows)
    news_count = len(news_rows)
    dominant_categories = sorted(
        (
            {"category": key, "count": value["count"]}
            for key, value in categories.items()
            if value["count"] > 0
        ),
        key=lambda item: item["count"],
        reverse=True,
    )

    direction = "rose" if price_change_pct > 0 else "fell" if price_change_pct < 0 else "traded sideways"
    category_text = (
        f"Coverage was dominated by {', '.join(item['category'] for item in dominant_categories[:3])}."
        if dominant_categories
        else "No clear topic clusters were identified within the range."
    )
    summary = (
        f"{ticker} {direction} {abs(price_change_pct):.2f}% between {start_date} and {end_date}, "
        f"covering {len(prices)} trading days with {news_count} related news items. {category_text}"
    )

    bullish_factors = summarize_bullish_factors(news_rows)
    bearish_factors = summarize_bearish_factors(news_rows)
    trend_analysis = build_trend_analysis(prices)
    llm_source = "local"

    range_payload = {
        "ticker": ticker,
        "start_date": start_date,
        "end_date": end_date,
        "price_change_pct": round(price_change_pct, 2),
        "trading_days": len(prices),
        "news_count": news_count,
        "dominant_categories": dominant_categories[:5],
        "bullish_factors": bullish_factors[:3],
        "bearish_factors": bearish_factors[:3],
        "trend_analysis": trend_analysis,
        "top_news": [
            {
                "date": row.get("trade_date") or str(row.get("date") or "")[:10],
                "title": row.get("title") or "",
                "summary": row.get("summary") or "",
                "sentiment": row.get("sentiment") or "",
                "relevance": row.get("relevance") or "",
                "ret_t0": row.get("ret_t0"),
            }
            for row in sorted(news_rows, key=_rank_event_score, reverse=True)[:5]
        ],
    }
    llm_analysis = analyze_range_with_llm(range_payload)
    if isinstance(llm_analysis, dict):
        summary = llm_analysis.get("summary") or summary
        trend_analysis = llm_analysis.get("trend_analysis") or trend_analysis
        bullish_factors = llm_analysis.get("bullish_factors") or bullish_factors
        bearish_factors = llm_analysis.get("bearish_factors") or bearish_factors
        llm_source = "llm"

    key_events = [
        {
            "date": row.get("trade_date") or str(row.get("date") or "")[:10],
            "title": row.get("title") or "Untitled news",
            "summary": row.get("summary") or "",
            "category": row.get("category") or "",
            "id": row.get("id"),
            "sentiment": row.get("sentiment"),
            "ret_t0": row.get("ret_t0"),
        }
        for row in sorted(news_rows, key=_rank_event_score, reverse=True)[:8]
    ]

    return {
        "symbol": ticker,
        "start_date": start_date,
        "end_date": end_date,
        "price_change_pct": round(price_change_pct, 2),
        "open_price": open_price,
        "close_price": close_price,
        "high_price": high_price,
        "low_price": low_price,
        "total_volume": total_volume,
        "trading_days": len(prices),
        "news_count": news_count,
        "dominant_categories": dominant_categories[:5],
        "analysis": {
            "summary": summary,
            "key_events": key_events,
            "bullish_factors": bullish_factors,
            "bearish_factors": bearish_factors,
            "trend_analysis": trend_analysis,
            "analysis_source": llm_source,
            "analysis_model_label": llm_analysis.get("model_label") if isinstance(llm_analysis, dict) else None,
        },
    }
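For intuition, the ranking heuristic gives relevance an ordinal weight and adds the absolute same-day move in percentage points, so a "high"-relevance article on a 2.5% day scores 3.0 + 2.5 = 5.5. A tiny check (values invented):

from backend.explain.range_explainer import _rank_event_score

row = {"relevance": "high", "ret_t0": 0.025}  # 2.5% same-day return
assert abs(_rank_event_score(row) - 5.5) < 1e-9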
202
backend/explain/similarity_service.py
Normal file
@@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
"""Same-ticker historical similar day search for explain view."""

from __future__ import annotations

from math import sqrt
from typing import Any

from backend.data.market_store import MarketStore


def _safe_float(value: Any, default: float = 0.0) -> float:
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        return default
    return parsed


def build_daily_feature_rows(
    *,
    symbol: str,
    ohlc_rows: list[dict[str, Any]],
    news_rows: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Aggregate price/news context into daily feature rows."""
    price_by_date = {str(row.get("date")): row for row in ohlc_rows if row.get("date")}
    ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date")]

    news_by_date: dict[str, list[dict[str, Any]]] = {}
    for row in news_rows:
        trade_date = str(row.get("trade_date") or "")[:10] or str(row.get("date") or "")[:10]
        if not trade_date:
            continue
        news_by_date.setdefault(trade_date, []).append(row)

    features: list[dict[str, Any]] = []
    previous_close: float | None = None
    for idx, date in enumerate(ordered_dates):
        price_row = price_by_date[date]
        close_price = _safe_float(price_row.get("close"))
        open_price = _safe_float(price_row.get("open"), close_price)
        day_news = news_by_date.get(date, [])
        positive_count = sum(1 for item in day_news if str(item.get("sentiment") or "").lower() == "positive")
        negative_count = sum(1 for item in day_news if str(item.get("sentiment") or "").lower() == "negative")
        high_relevance_count = sum(
            1 for item in day_news if str(item.get("relevance") or "").lower() in {"high", "relevant"}
        )
        ret_1d = (
            ((close_price - previous_close) / previous_close)
            if previous_close not in (None, 0)
            else 0.0
        )
        intraday_ret = ((close_price - open_price) / open_price) if open_price else 0.0
        sentiment_score = (
            (positive_count - negative_count) / max(len(day_news), 1)
            if day_news
            else 0.0
        )
        future_t1 = None
        future_t3 = None
        if idx + 1 < len(ordered_dates) and close_price:
            next_close = _safe_float(price_by_date[ordered_dates[idx + 1]].get("close"))
            future_t1 = ((next_close - close_price) / close_price) if next_close else None
        if idx + 3 < len(ordered_dates) and close_price:
            next_close = _safe_float(price_by_date[ordered_dates[idx + 3]].get("close"))
            future_t3 = ((next_close - close_price) / close_price) if next_close else None

        features.append(
            {
                "date": date,
                "symbol": symbol,
                "n_articles": len(day_news),
                "positive_count": positive_count,
                "negative_count": negative_count,
                "high_relevance_count": high_relevance_count,
                "sentiment_score": sentiment_score,
                "ret_1d": ret_1d,
                "intraday_ret": intraday_ret,
                "close": close_price,
                "ret_t1_after": future_t1,
                "ret_t3_after": future_t3,
                "news": [
                    {
                        "title": row.get("title") or "",
                        "sentiment": row.get("sentiment") or "neutral",
                    }
                    for row in day_news[:3]
                ],
            }
        )
        previous_close = close_price
    return features


def compute_similarity_scores(
    target_vector: list[float],
    candidate_vectors: list[tuple[str, list[float], dict[str, Any]]],
) -> list[dict[str, Any]]:
    """Return sorted similarity matches based on normalized Euclidean distance."""
    if not candidate_vectors:
        return []
    dimensions = len(target_vector)
    ranges = []
    for dimension in range(dimensions):
        values = [vector[1][dimension] for vector in candidate_vectors] + [target_vector[dimension]]
        min_value = min(values)
        max_value = max(values)
        ranges.append(max(max_value - min_value, 1e-9))

    scored = []
    for date, vector, payload in candidate_vectors:
        distance = sqrt(
            sum(
                ((target_vector[i] - vector[i]) / ranges[i]) ** 2
                for i in range(dimensions)
            )
        )
        similarity = 1.0 / (1.0 + distance)
        scored.append(
            {
                "date": date,
                "score": round(similarity, 4),
                **payload,
            }
        )
    return sorted(scored, key=lambda item: item["score"], reverse=True)


def find_similar_days(
    store: MarketStore,
    *,
    symbol: str,
    target_date: str,
    top_k: int = 10,
) -> dict[str, Any]:
    """Find same-ticker historical days most similar to a target day."""
    cached = store.get_similar_day_cache(symbol, target_date=target_date)
    if cached and cached.get("payload"):
        return cached["payload"]

    ohlc_rows = store.get_ohlc(symbol, "1900-01-01", target_date)
    news_rows = store.get_news_items_enriched(symbol, end_date=target_date, limit=500)
    daily_rows = build_daily_feature_rows(symbol=symbol, ohlc_rows=ohlc_rows, news_rows=news_rows)
    feature_map = {row["date"]: row for row in daily_rows}
    target_row = feature_map.get(target_date)
    if not target_row:
        return {
            "symbol": symbol,
            "target_date": target_date,
            "items": [],
            "error": "No feature row for target date",
        }

    vector_keys = [
        "sentiment_score",
        "n_articles",
        "positive_count",
        "negative_count",
        "high_relevance_count",
        "ret_1d",
        "intraday_ret",
    ]
    target_vector = [_safe_float(target_row.get(key)) for key in vector_keys]
    candidates = []
    for row in daily_rows:
        date = row["date"]
        if date == target_date:
            continue
        payload = {
            "n_articles": row["n_articles"],
            "sentiment_score": round(row["sentiment_score"], 4),
            "ret_1d": round(row["ret_1d"] * 100, 2),
            "intraday_ret": round(row["intraday_ret"] * 100, 2),
            "ret_t1_after": round(row["ret_t1_after"] * 100, 2) if row["ret_t1_after"] is not None else None,
            "ret_t3_after": round(row["ret_t3_after"] * 100, 2) if row["ret_t3_after"] is not None else None,
            "top_reasons": [item["title"] for item in row["news"][:2] if item.get("title")],
            "news": row["news"],
        }
        candidates.append(
            (
                date,
                [_safe_float(row.get(key)) for key in vector_keys],
                payload,
            )
        )

    items = compute_similarity_scores(target_vector, candidates)[: max(1, min(int(top_k), 20))]
    result = {
        "symbol": symbol,
        "target_date": target_date,
        "target_features": {
            "sentiment_score": round(target_row["sentiment_score"], 4),
            "n_articles": target_row["n_articles"],
            "ret_1d": round(target_row["ret_1d"] * 100, 2),
            "intraday_ret": round(target_row["intraday_ret"] * 100, 2),
            "high_relevance_count": target_row["high_relevance_count"],
        },
        "items": items,
    }
    store.upsert_similar_day_cache(symbol, target_date=target_date, payload=result, source="local")
    return result
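A worked example of the scorer: each dimension is divided by its observed min-max range before the Euclidean distance is taken, so no single feature (such as raw article counts) dominates, and the score is squashed into (0, 1] via 1 / (1 + distance). The vectors below are invented:

from backend.explain.similarity_service import compute_similarity_scores

target = [0.5, 3.0]  # e.g. (sentiment_score, n_articles)
candidates = [
    ("2024-01-02", [0.5, 3.0], {"note": "identical features"}),
    ("2024-01-03", [-0.5, 1.0], {"note": "opposite day"}),
]
matches = compute_similarity_scores(target, candidates)
# distance 0 -> score 1.0; the second candidate lands around 0.41
assert matches[0]["date"] == "2024-01-02" and matches[0]["score"] == 1.0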
127
backend/explain/story_service.py
Normal file
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""Stock story generation for explain view."""

from __future__ import annotations

from datetime import datetime, timedelta
from typing import Any

from backend.data.market_store import MarketStore


def build_stock_story(
    *,
    symbol: str,
    as_of_date: str,
    price_rows: list[dict[str, Any]],
    news_rows: list[dict[str, Any]],
) -> str:
    """Build a compact markdown story from enriched news and recent price action."""
    lines = [f"## {symbol} Story", f"As of `{as_of_date}`"]
    if not price_rows:
        lines.append("")
        lines.append("No OHLC data available for story generation.")
        return "\n".join(lines)

    open_price = float(price_rows[0].get("open") or price_rows[0].get("close") or 0.0)
    close_price = float(price_rows[-1].get("close") or 0.0)
    price_change = ((close_price - open_price) / open_price) * 100 if open_price else 0.0
    high_price = max(float(row.get("high") or row.get("close") or 0.0) for row in price_rows)
    low_price = min(float(row.get("low") or row.get("close") or 0.0) for row in price_rows)

    lines.append("")
    lines.append(
        f"The stock moved {'up' if price_change >= 0 else 'down'} "
        f"{abs(price_change):.2f}% over the recent window, trading between "
        f"${low_price:.2f} and ${high_price:.2f}."
    )

    positive = [row for row in news_rows if str(row.get("sentiment") or "").lower() == "positive"]
    negative = [row for row in news_rows if str(row.get("sentiment") or "").lower() == "negative"]
    lines.append("")
    lines.append(
        f"Recent coverage included {len(news_rows)} relevant articles "
        f"({len(positive)} positive / {len(negative)} negative)."
    )

    if news_rows:
        lines.append("")
        lines.append("### Key Moments")
        ranked_rows = sorted(
            news_rows,
            key=lambda row: (
                0 if str(row.get("relevance") or "").lower() in {"high", "relevant"} else 1,
                -abs(float(row.get("ret_t0") or 0.0)),
            ),
        )
        for row in ranked_rows[:5]:
            trade_date = row.get("trade_date") or str(row.get("date") or "")[:10]
            title = row.get("title") or "Untitled"
            key_discussion = row.get("key_discussion") or row.get("summary") or ""
            sentiment = str(row.get("sentiment") or "neutral").lower()
            lines.append(
                f"- `{trade_date}` [{sentiment}] {title}: {str(key_discussion).strip()[:220]}"
            )

    if positive:
        lines.append("")
        lines.append("### Bullish Threads")
        for row in positive[:3]:
            reason = row.get("reason_growth") or row.get("key_discussion") or row.get("summary") or row.get("title")
            lines.append(f"- {str(reason).strip()[:220]}")

    if negative:
        lines.append("")
        lines.append("### Bearish Threads")
        for row in negative[:3]:
            reason = row.get("reason_decrease") or row.get("key_discussion") or row.get("summary") or row.get("title")
            lines.append(f"- {str(reason).strip()[:220]}")

    return "\n".join(lines)


def get_or_create_stock_story(
    store: MarketStore,
    *,
    symbol: str,
    as_of_date: str,
) -> dict[str, Any]:
    """Return cached story or build a new one from recent market context."""
    cached = store.get_story_cache(symbol, as_of_date=as_of_date)
    if cached:
        return {
            "symbol": symbol,
            "as_of_date": as_of_date,
            "story": cached.get("content") or "",
            "source": cached.get("source") or "cache",
        }

    start_date = None
    if len(as_of_date) >= 10:
        target_date = datetime.strptime(as_of_date[:10], "%Y-%m-%d").date()
        start_date = (target_date - timedelta(days=29)).isoformat()

    price_rows = (
        store.get_ohlc(symbol, start_date, as_of_date)
        if start_date
        else []
    )
    news_rows = store.get_news_items_enriched(
        symbol,
        start_date=start_date,
        end_date=as_of_date,
        limit=40,
    )
    story = build_stock_story(
        symbol=symbol,
        as_of_date=as_of_date,
        price_rows=price_rows,
        news_rows=news_rows,
    )
    store.upsert_story_cache(symbol, as_of_date=as_of_date, content=story, source="local")
    return {
        "symbol": symbol,
        "as_of_date": as_of_date,
        "story": story,
        "source": "local",
    }
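A typical call site — a sketch assuming the default MarketStore; the symbol and date are illustrative. The first call builds and caches the markdown story; later calls for the same (symbol, as_of_date) return the cached copy:

from backend.data.market_store import MarketStore
from backend.explain.story_service import get_or_create_stock_story

store = MarketStore()
payload = get_or_create_stock_story(store, symbol="AAPL", as_of_date="2024-03-28")
print(payload["source"])        # "local" when freshly built
print(payload["story"][:200])   # markdown beginning with "## AAPL Story"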
@@ -17,6 +17,12 @@ from backend.config.bootstrap_config import (
    update_bootstrap_values_for_run,
)
from backend.data.provider_utils import normalize_symbol
from backend.data.market_ingest import ingest_symbols
from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.explain.range_explainer import build_range_explanation
from backend.explain.similarity_service import find_similar_days
from backend.explain.story_service import get_or_create_stock_story
from backend.utils.msg_adapter import FrontendAdapter
from backend.utils.terminal_dashboard import get_dashboard
from backend.core.pipeline import TradingPipeline
@@ -25,6 +31,7 @@ from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.data.provider_router import get_provider_router
from backend.tools.data_tools import get_prices
from backend.tools.data_tools import get_company_news

logger = logging.getLogger(__name__)

@@ -65,6 +72,7 @@ class Gateway:
        self._backtest_end_date: Optional[str] = None
        self._dashboard = get_dashboard()
        self._market_status_task: Optional[asyncio.Task] = None
        self._watchlist_ingest_task: Optional[asyncio.Task] = None

        # Session tracking for live returns
        self._session_start_portfolio_value: Optional[float] = None
@@ -182,6 +190,17 @@ class Gateway:
    def state(self) -> Dict[str, Any]:
        return self.state_sync.state

    @staticmethod
    def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool:
        if not rows:
            return True
        return all(
            not row.get("sentiment")
            and not row.get("relevance")
            and not row.get("key_discussion")
            for row in rows
        )

    async def handle_client(self, websocket: ServerConnection):
        """Handle WebSocket client connection"""
        async with self.lock:
@@ -250,6 +269,22 @@ class Gateway:
                    await self._handle_get_stock_history(websocket, data)
                elif msg_type == "get_stock_explain_events":
                    await self._handle_get_stock_explain_events(websocket, data)
                elif msg_type == "get_stock_news":
                    await self._handle_get_stock_news(websocket, data)
                elif msg_type == "get_stock_news_for_date":
                    await self._handle_get_stock_news_for_date(websocket, data)
                elif msg_type == "get_stock_news_timeline":
                    await self._handle_get_stock_news_timeline(websocket, data)
                elif msg_type == "get_stock_news_categories":
                    await self._handle_get_stock_news_categories(websocket, data)
                elif msg_type == "get_stock_range_explain":
                    await self._handle_get_stock_range_explain(websocket, data)
                elif msg_type == "get_stock_story":
                    await self._handle_get_stock_story(websocket, data)
                elif msg_type == "get_stock_similar_days":
                    await self._handle_get_stock_similar_days(websocket, data)
                elif msg_type == "run_stock_enrich":
                    await self._handle_run_stock_enrich(websocket, data)

        except websockets.ConnectionClosed:
            pass
@@ -298,20 +333,38 @@ class Gateway:
         )

         prices = await asyncio.to_thread(
-            get_prices,
+            self.storage.market_store.get_ohlc,
             ticker,
             start_date,
             end_date,
         )
-        usage_snapshot = self._provider_router.get_usage_snapshot()
-        source = usage_snapshot.get("last_success", {}).get("prices")
+        source = "polygon"
+        if not prices:
+            prices = await asyncio.to_thread(
+                get_prices,
+                ticker,
+                start_date,
+                end_date,
+            )
+            usage_snapshot = self._provider_router.get_usage_snapshot()
+            source = usage_snapshot.get("last_success", {}).get("prices")
+            if prices:
+                await asyncio.to_thread(
+                    self.storage.market_store.upsert_ohlc,
+                    ticker,
+                    [price.model_dump() for price in prices],
+                    source=source or "provider",
+                )

         await websocket.send(
             json.dumps(
                 {
                     "type": "stock_history_loaded",
                     "ticker": ticker,
-                    "prices": [price.model_dump() for price in prices][-120:],
+                    "prices": [
+                        price if isinstance(price, dict) else price.model_dump()
+                        for price in prices
+                    ][-120:],
                     "source": source,
                     "start_date": start_date,
                     "end_date": end_date,
@@ -342,6 +395,636 @@ class Gateway:
                ),
            )

    async def _handle_get_stock_news(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_loaded",
                        "ticker": "",
                        "news": [],
                        "source": None,
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        lookback_days = data.get("lookback_days", 30)
        limit = data.get("limit", 12)
        try:
            lookback_days = max(7, min(int(lookback_days), 180))
        except (TypeError, ValueError):
            lookback_days = 30
        try:
            limit = max(1, min(int(limit), 30))
        except (TypeError, ValueError):
            limit = 12

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )

        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
        )
        source = "polygon"
        if self._news_rows_need_enrichment(news_rows):
            news = await asyncio.to_thread(
                get_company_news,
                ticker,
                end_date,
                start_date,
                limit,
            )
            if news:
                usage_snapshot = self._provider_router.get_usage_snapshot()
                source = usage_snapshot.get("last_success", {}).get("company_news")
                await asyncio.to_thread(
                    self.storage.market_store.upsert_news,
                    ticker,
                    [item.model_dump() for item in news],
                    source=source or "provider",
                )
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=max(limit, 50),
            )
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=limit,
            )
            source = source or "market_store"

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_loaded",
                    "ticker": ticker,
                    "news": news_rows[-limit:],
                    "source": source,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_news_for_date(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        trade_date = str(data.get("date") or "").strip()
        if not ticker or not trade_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_for_date_loaded",
                        "ticker": ticker,
                        "date": trade_date,
                        "news": [],
                        "error": "ticker and date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        limit = data.get("limit", 20)
        try:
            limit = max(1, min(int(limit), 50))
        except (TypeError, ValueError):
            limit = 20

        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            trade_date=trade_date,
            limit=limit,
        )
        if self._news_rows_need_enrichment(news_rows):
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=trade_date,
                end_date=trade_date,
                limit=limit,
            )
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                trade_date=trade_date,
                limit=limit,
            )

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_for_date_loaded",
                    "ticker": ticker,
                    "date": trade_date,
                    "news": news_rows,
                    "source": "market_store",
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_news_timeline(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_timeline_loaded",
                        "ticker": "",
                        "timeline": [],
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        lookback_days = data.get("lookback_days", 90)
        try:
            lookback_days = max(7, min(int(lookback_days), 365))
        except (TypeError, ValueError):
            lookback_days = 90

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )
        timeline = await asyncio.to_thread(
            self.storage.market_store.get_news_timeline_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
        )
        if not timeline:
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=200,
            )
            timeline = await asyncio.to_thread(
                self.storage.market_store.get_news_timeline_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
            )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_timeline_loaded",
                    "ticker": ticker,
                    "timeline": timeline,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_news_categories(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_categories_loaded",
                        "ticker": "",
                        "categories": {},
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        lookback_days = data.get("lookback_days", 90)
        try:
            lookback_days = max(7, min(int(lookback_days), 365))
        except (TypeError, ValueError):
            lookback_days = 90

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )
        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        )
        if self._news_rows_need_enrichment(news_rows):
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=200,
            )
        categories = await asyncio.to_thread(
            self.storage.market_store.get_news_categories_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_categories_loaded",
                    "ticker": ticker,
                    "categories": categories,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_range_explain(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        start_date = str(data.get("start_date") or "").strip()
        end_date = str(data.get("end_date") or "").strip()
        if not ticker or not start_date or not end_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_range_explain_loaded",
                        "ticker": ticker,
                        "result": {"error": "ticker, start_date, end_date are required"},
                    },
                    ensure_ascii=False,
                ),
            )
            return

        article_ids = data.get("article_ids")
        if isinstance(article_ids, list) and article_ids:
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_by_ids_enriched,
                ticker,
                article_ids,
            )
            if self._news_rows_need_enrichment(news_rows):
                await asyncio.to_thread(
                    enrich_news_for_symbol,
                    self.storage.market_store,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )
                news_rows = await asyncio.to_thread(
                    self.storage.market_store.get_news_by_ids_enriched,
                    ticker,
                    article_ids,
                )
        else:
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=100,
            )
            if not news_rows:
                await asyncio.to_thread(
                    enrich_news_for_symbol,
                    self.storage.market_store,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )
                news_rows = await asyncio.to_thread(
                    self.storage.market_store.get_news_items_enriched,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )

        result = await asyncio.to_thread(
            build_range_explanation,
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            news_rows=news_rows,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_range_explain_loaded",
                    "ticker": ticker,
                    "result": result,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_story(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_story_loaded",
                        "ticker": "",
                        "story": "",
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        as_of_date = str(
            data.get("as_of_date")
            or self.state_sync.state.get("current_date")
            or datetime.now().strftime("%Y-%m-%d")
        ).strip()[:10]
        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            end_date=as_of_date,
            limit=80,
        )
        result = await asyncio.to_thread(
            get_or_create_stock_story,
            self.storage.market_store,
            symbol=ticker,
            as_of_date=as_of_date,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_story_loaded",
                    "ticker": ticker,
                    "as_of_date": as_of_date,
                    "story": result.get("story") or "",
                    "source": result.get("source") or "local",
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_get_stock_similar_days(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        target_date = str(data.get("date") or "").strip()[:10]
        if not ticker or not target_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_similar_days_loaded",
                        "ticker": ticker,
                        "date": target_date,
                        "items": [],
                        "error": "ticker and date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        top_k = data.get("top_k", 8)
        try:
            top_k = max(1, min(int(top_k), 20))
        except (TypeError, ValueError):
            top_k = 8

        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            end_date=target_date,
            limit=200,
        )
        result = await asyncio.to_thread(
            find_similar_days,
            self.storage.market_store,
            symbol=ticker,
            target_date=target_date,
            top_k=top_k,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_similar_days_loaded",
                    "ticker": ticker,
                    "date": target_date,
                    **result,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_run_stock_enrich(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        ticker = normalize_symbol(data.get("ticker", ""))
        start_date = str(data.get("start_date") or "").strip()[:10]
        end_date = str(data.get("end_date") or "").strip()[:10]
        story_date = str(data.get("story_date") or end_date or "").strip()[:10]
        target_date = str(data.get("target_date") or "").strip()[:10]
        force = bool(data.get("force", False))
        rebuild_story = bool(data.get("rebuild_story", True))
        rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
        only_local_to_llm = bool(data.get("only_local_to_llm", False))
        limit = data.get("limit", 200)

        try:
            limit = max(10, min(int(limit), 500))
        except (TypeError, ValueError):
            limit = 200

        if not ticker or not start_date or not end_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_enrich_completed",
                        "ticker": ticker,
                        "start_date": start_date,
                        "end_date": end_date,
                        "error": "ticker, start_date, end_date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        if only_local_to_llm and not llm_enrichment_enabled():
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_enrich_completed",
                        "ticker": ticker,
                        "start_date": start_date,
                        "end_date": end_date,
                        "error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        result = await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            skip_existing=not force,
            only_reanalyze_local=only_local_to_llm,
        )

        story_status = None
        if rebuild_story and story_date:
            await asyncio.to_thread(
                self.storage.market_store.delete_story_cache,
                ticker,
                as_of_date=story_date,
            )
            story_result = await asyncio.to_thread(
                get_or_create_stock_story,
                self.storage.market_store,
                symbol=ticker,
                as_of_date=story_date,
            )
            story_status = {
                "as_of_date": story_date,
                "source": story_result.get("source") or "local",
            }

        similar_status = None
        if rebuild_similar_days and target_date:
            await asyncio.to_thread(
                self.storage.market_store.delete_similar_day_cache,
                ticker,
                target_date=target_date,
            )
            similar_result = await asyncio.to_thread(
                find_similar_days,
                self.storage.market_store,
                symbol=ticker,
                target_date=target_date,
                top_k=8,
            )
            similar_status = {
                "target_date": target_date,
                "count": len(similar_result.get("items") or []),
                "error": similar_result.get("error"),
            }

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_enrich_completed",
                    "ticker": ticker,
                    "start_date": start_date,
                    "end_date": end_date,
                    "story_date": story_date or None,
                    "target_date": target_date or None,
                    "force": force,
                    "only_local_to_llm": only_local_to_llm,
                    "stats": result,
                    "story_status": story_status,
                    "similar_status": similar_status,
                },
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_start_backtest(self, data: Dict[str, Any]):
        if not self.is_backtest:
            return
@@ -410,6 +1093,7 @@ class Gateway:
            },
        )
        await self._handle_reload_runtime_assets()
        self._schedule_watchlist_market_store_refresh(tickers)

    @staticmethod
    def _normalize_watchlist(raw_tickers: Any) -> List[str]:
@@ -538,6 +1222,48 @@ class Gateway:
            trades=trades,
        )

    def _schedule_watchlist_market_store_refresh(
        self,
        tickers: List[str],
    ) -> None:
        """Kick off a non-blocking Polygon refresh for the updated watchlist."""
        if not tickers:
            return
        if self._watchlist_ingest_task and not self._watchlist_ingest_task.done():
            self._watchlist_ingest_task.cancel()
        self._watchlist_ingest_task = asyncio.create_task(
            self._refresh_market_store_for_watchlist(tickers),
        )

    async def _refresh_market_store_for_watchlist(
        self,
        tickers: List[str],
    ) -> None:
        """Refresh the long-lived market store after a watchlist update."""
        try:
            await self.state_sync.on_system_message(
                f"Syncing watchlist market data: {', '.join(tickers)}",
            )
            results = await asyncio.to_thread(
                ingest_symbols,
                tickers,
                mode="incremental",
            )
            summary = ", ".join(
                f"{item['symbol']} prices={item['prices']} news={item['news']}"
                for item in results
            )
            await self.state_sync.on_system_message(
                f"Watchlist market data synced: {summary}",
            )
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.warning("Watchlist market store refresh failed: %s", exc)
            await self.state_sync.on_system_message(
                f"Watchlist market data sync failed: {exc}",
            )

    async def broadcast(self, message: Dict[str, Any]):
        """Broadcast message to all connected clients"""
        if not self.connected_clients:
@@ -896,4 +1622,6 @@ class Gateway:
            self._backtest_task.cancel()
        if self._market_status_task:
            self._market_status_task.cancel()
        if self._watchlist_ingest_task:
            self._watchlist_ingest_task.cancel()
        self._dashboard.stop()
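The request/response shapes above can be exercised from a plain WebSocket client. A single-shot sketch, assuming the gateway listens on ws://localhost:8765 and that the next frame received is the reply (a real client should dispatch on the "type" field, since the gateway also broadcasts unrelated messages):

import asyncio
import json

import websockets


async def main():
    async with websockets.connect("ws://localhost:8765") as ws:
        await ws.send(json.dumps({
            "type": "get_stock_range_explain",
            "ticker": "NVDA",           # illustrative
            "start_date": "2024-03-01",
            "end_date": "2024-03-28",
        }))
        reply = json.loads(await ws.recv())
        # Expected: type == "stock_range_explain_loaded" with result.analysis
        # carrying summary / key_events / bullish_factors / bearish_factors.
        print(reply.get("type"), reply.get("result", {}).get("analysis", {}).get("summary"))


asyncio.run(main())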
@@ -65,6 +65,18 @@ class MarketService:
        self._session_start_values: Optional[Dict[str, float]] = None
        self._session_start_timestamp: Optional[int] = None

    def get_live_quote_provider(self) -> Optional[str]:
        """Return the active live quote provider for UI/debugging."""
        if self.backtest_mode:
            return "backtest"
        if self.mock_mode:
            return "mock"
        if self._price_manager and hasattr(self._price_manager, "provider"):
            provider = getattr(self._price_manager, "provider", None)
            if isinstance(provider, str) and provider.strip():
                return provider.strip().lower()
        return None

    @property
    def mode_name(self) -> str:
        if self.backtest_mode:
@@ -532,6 +544,7 @@ class MarketService:
                "status": MarketStatus.OPEN,
                "status_text": "Backtest Mode",
                "is_trading_day": True,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        now = self._now_nyse()
@@ -544,6 +557,7 @@ class MarketService:
                "status": MarketStatus.CLOSED,
                "status_text": "Market Closed (Non-trading Day)",
                "is_trading_day": False,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        market_open, market_close = self._get_market_hours(today)
@@ -553,6 +567,7 @@ class MarketService:
                "status": MarketStatus.CLOSED,
                "status_text": "Market Closed",
                "is_trading_day": is_trading,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        # Determine status based on current time
@@ -563,6 +578,7 @@ class MarketService:
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }
        elif now > market_close:
            return {
@@ -571,6 +587,7 @@ class MarketService:
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }
        else:
            return {
@@ -579,6 +596,7 @@ class MarketService:
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }

    async def check_and_broadcast_market_status(self):
280
backend/services/research_db.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Query-oriented storage for explain/research data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from backend.data.schema import CompanyNews
|
||||
|
||||
|
||||
SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS news_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
published_at TEXT,
|
||||
trade_date TEXT,
|
||||
source TEXT,
|
||||
title TEXT NOT NULL,
|
||||
summary TEXT,
|
||||
url TEXT,
|
||||
related TEXT,
|
||||
category TEXT,
|
||||
raw_json TEXT NOT NULL,
|
||||
ingest_run_date TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_news_items_ticker_date
|
||||
ON news_items (ticker, trade_date DESC, published_at DESC);
|
||||
"""
|
||||
|
||||
|
||||
def _json_dumps(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _resolve_news_id(ticker: str, item: CompanyNews, fallback_index: int) -> str:
|
||||
base = item.url or item.title or f"{ticker}-{fallback_index}"
|
||||
return f"{ticker}:{base}"
|
||||
|
||||
|
||||
def _resolve_trade_date(date_value: str | None) -> str | None:
|
||||
if not date_value:
|
||||
return None
|
||||
normalized = str(date_value).strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if "T" in normalized:
|
||||
return normalized.split("T", 1)[0]
|
||||
if " " in normalized:
|
||||
return normalized.split(" ", 1)[0]
|
||||
return normalized[:10]
|
||||
|
||||
|
||||
class ResearchDb:
|
||||
"""Small SQLite helper for explain-oriented news storage."""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
return conn
|
||||
|
||||
def _init_db(self):
|
||||
with self._connect() as conn:
|
||||
conn.executescript(SCHEMA)
|
||||
|
||||
def upsert_news_items(
|
||||
self,
|
||||
*,
|
||||
ticker: str,
|
||||
items: Iterable[CompanyNews],
|
||||
ingest_run_date: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Persist provider news and return normalized rows."""
|
||||
normalized_rows: list[dict[str, Any]] = []
|
||||
timestamp = datetime.utcnow().isoformat(timespec="seconds")
|
||||
symbol = str(ticker or "").strip().upper()
|
||||
if not symbol:
|
||||
return normalized_rows
|
||||
|
||||
with self._connect() as conn:
|
||||
for index, item in enumerate(items):
|
||||
news_id = _resolve_news_id(symbol, item, index)
|
||||
trade_date = _resolve_trade_date(item.date)
|
||||
payload = item.model_dump()
|
||||
row = {
|
||||
"id": news_id,
|
||||
"ticker": symbol,
|
||||
"published_at": item.date,
|
||||
"trade_date": trade_date,
|
||||
"source": item.source,
|
||||
"title": item.title,
|
||||
"summary": item.summary,
|
||||
"url": item.url,
|
||||
"related": item.related,
|
||||
"category": item.category,
|
||||
"raw_json": _json_dumps(payload),
|
||||
"ingest_run_date": ingest_run_date,
|
||||
"created_at": timestamp,
|
||||
}
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO news_items
|
||||
(id, ticker, published_at, trade_date, source, title, summary, url,
|
||||
related, category, raw_json, ingest_run_date, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
ticker = excluded.ticker,
|
||||
published_at = excluded.published_at,
|
||||
trade_date = excluded.trade_date,
|
||||
source = excluded.source,
|
||||
title = excluded.title,
|
||||
summary = excluded.summary,
|
||||
url = excluded.url,
|
||||
related = excluded.related,
|
||||
category = excluded.category,
|
||||
raw_json = excluded.raw_json,
|
||||
ingest_run_date = excluded.ingest_run_date
|
||||
""",
|
||||
(
|
||||
row["id"],
|
||||
row["ticker"],
|
||||
row["published_at"],
|
||||
row["trade_date"],
|
||||
row["source"],
|
||||
row["title"],
|
||||
row["summary"],
|
||||
row["url"],
|
||||
row["related"],
|
||||
row["category"],
|
||||
row["raw_json"],
|
||||
row["ingest_run_date"],
|
||||
row["created_at"],
|
||||
),
|
||||
)
|
||||
normalized_rows.append(row)
|
||||
return normalized_rows
|
||||
|
||||
def get_news_items(
|
||||
self,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return normalized news rows for explain UI."""
|
||||
symbol = str(ticker or "").strip().upper()
|
||||
if not symbol:
|
||||
return []
|
||||
|
||||
sql = """
|
||||
SELECT id, ticker, published_at, trade_date, source, title, summary,
|
||||
url, related, category
|
||||
FROM news_items
|
||||
WHERE ticker = ?
|
||||
"""
|
||||
params: list[Any] = [symbol]
|
||||
if start_date:
|
||||
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
|
||||
params.append(start_date)
|
||||
if end_date:
|
||||
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
|
||||
params.append(end_date)
|
||||
sql += " ORDER BY COALESCE(published_at, trade_date) DESC LIMIT ?"
|
||||
params.append(max(1, int(limit)))
|
||||
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": row["id"],
|
||||
"ticker": row["ticker"],
|
||||
"date": row["published_at"] or row["trade_date"],
|
||||
"trade_date": row["trade_date"],
|
||||
"source": row["source"],
|
||||
"title": row["title"],
|
||||
"summary": row["summary"],
|
||||
"url": row["url"],
|
||||
"related": row["related"],
|
||||
"category": row["category"],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
def get_news_timeline(
|
||||
self,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Aggregate news counts per trade date for chart markers."""
|
||||
symbol = str(ticker or "").strip().upper()
|
||||
if not symbol:
|
||||
return []
|
||||
|
||||
sql = """
|
||||
SELECT COALESCE(trade_date, substr(published_at, 1, 10)) AS date,
|
||||
COUNT(*) AS count,
|
||||
COUNT(DISTINCT source) AS source_count,
|
||||
MAX(title) AS top_title
|
||||
FROM news_items
|
||||
WHERE ticker = ?
|
||||
"""
|
||||
params: list[Any] = [symbol]
|
||||
if start_date:
|
||||
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
|
||||
params.append(start_date)
|
||||
if end_date:
|
||||
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
|
||||
params.append(end_date)
|
||||
sql += """
|
||||
GROUP BY COALESCE(trade_date, substr(published_at, 1, 10))
|
||||
ORDER BY date ASC
|
||||
"""
|
||||
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"date": row["date"],
|
||||
"count": int(row["count"] or 0),
|
||||
"source_count": int(row["source_count"] or 0),
|
||||
"top_title": row["top_title"] or "",
|
||||
}
|
||||
for row in rows
|
||||
if row["date"]
|
||||
]
|
||||
|
||||

    def get_news_by_ids(
        self,
        *,
        ticker: str,
        article_ids: Iterable[str],
    ) -> list[dict[str, Any]]:
        """Return selected persisted news items."""
        symbol = str(ticker or "").strip().upper()
        ids = [str(article_id).strip() for article_id in article_ids if str(article_id).strip()]
        if not symbol or not ids:
            return []

        placeholders = ",".join("?" for _ in ids)
        sql = f"""
            SELECT id, ticker, published_at, trade_date, source, title, summary,
                   url, related, category
            FROM news_items
            WHERE ticker = ? AND id IN ({placeholders})
            ORDER BY COALESCE(published_at, trade_date) DESC
        """
        with self._connect() as conn:
            rows = conn.execute(sql, [symbol, *ids]).fetchall()

        return [
            {
                "id": row["id"],
                "ticker": row["ticker"],
                "date": row["published_at"] or row["trade_date"],
                "trade_date": row["trade_date"],
                "source": row["source"],
                "title": row["title"],
                "summary": row["summary"],
                "url": row["url"],
                "related": row["related"],
                "category": row["category"],
            }
            for row in rows
        ]
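
All three read helpers above normalize dates the same way, via COALESCE(trade_date, substr(published_at, 1, 10)), so their results can be filtered and joined consistently. A minimal usage sketch, assuming the no-argument MarketStore constructor used elsewhere in this commit and an already-ingested news_items table:

    from backend.data.market_store import MarketStore

    store = MarketStore()

    # Newest-first rows for the explain UI
    items = store.get_news_items(
        ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16", limit=20
    )

    # Per-day counts that back the chart markers
    timeline = store.get_news_timeline(
        ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16"
    )

    # Re-fetch a user selection by id
    picked = store.get_news_by_ids(
        ticker="AAPL", article_ids=[row["id"] for row in items[:3]]
    )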
@@ -10,6 +10,8 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from backend.data.market_store import MarketStore
from .research_db import ResearchDb
from .runtime_db import RuntimeDb

logger = logging.getLogger(__name__)
@@ -64,6 +66,8 @@ class StorageService:
        self.state_dir.mkdir(parents=True, exist_ok=True)
        self.server_state_file = self.state_dir / "server_state.json"
        self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
        self.research_db = ResearchDb(self.state_dir / "research.db")
        self.market_store = MarketStore()

        # Feed history (for agent messages)
        self.max_feed_history = 200
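
StorageService now owns a long-lived MarketStore handle alongside the runtime and research databases, so downstream handlers reach the research warehouse through a single object. A sketch of the dependency this wiring creates (StorageServiceSketch is a hypothetical stand-in; the real constructor arguments are not shown in this hunk):

    from backend.data.market_store import MarketStore

    class StorageServiceSketch:
        """Stand-in mirroring only the attribute added above."""
        def __init__(self):
            self.market_store = MarketStore()

    storage = StorageServiceSketch()
    rows = storage.market_store.get_news_items(ticker="AAPL", limit=5)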
236
backend/tests/test_cli.py
Normal file
@@ -0,0 +1,236 @@
# -*- coding: utf-8 -*-
from pathlib import Path

from backend import cli


def test_live_runs_incremental_market_store_update_before_start(monkeypatch, tmp_path):
    project_root = tmp_path
    (project_root / ".env").write_text("FINNHUB_API_KEY=test\n", encoding="utf-8")

    calls = []

    monkeypatch.setattr(cli, "get_project_root", lambda: project_root)
    monkeypatch.setattr(cli, "handle_history_cleanup", lambda config_name, auto_clean=False: None)
    monkeypatch.setattr(cli, "run_data_updater", lambda project_root: calls.append(("run_data_updater", project_root)))
    monkeypatch.setattr(
        cli,
        "auto_update_market_store",
        lambda config_name, end_date=None: calls.append(("auto_update_market_store", config_name, end_date)),
    )
    monkeypatch.setattr(
        cli,
        "auto_enrich_market_store",
        lambda config_name, end_date=None, lookback_days=120, force=False: calls.append(
            ("auto_enrich_market_store", config_name, end_date, lookback_days, force)
        ),
    )
    monkeypatch.setattr(cli.os, "chdir", lambda path: calls.append(("chdir", Path(path))))

    def fake_run(cmd, check=True, **kwargs):
        calls.append(("subprocess.run", cmd, check))
        return 0

    monkeypatch.setattr(cli.subprocess, "run", fake_run)

    cli.live(
        mock=False,
        config_name="smoke_fullstack",
        host="0.0.0.0",
        port=8765,
        trigger_time="now",
        poll_interval=10,
        clean=False,
        enable_memory=False,
    )

    assert any(item[0] == "run_data_updater" for item in calls)
    assert any(
        item[0] == "auto_update_market_store" and item[1] == "smoke_fullstack"
        for item in calls
    )
    assert any(
        item[0] == "auto_enrich_market_store" and item[1] == "smoke_fullstack"
        for item in calls
    )
    run_call = next(item for item in calls if item[0] == "subprocess.run")
    assert run_call[1][:6] == [
        cli.sys.executable,
        "-u",
        "-m",
        "backend.main",
        "--mode",
        "live",
    ]


def test_backtest_runs_full_market_store_prepare_before_start(monkeypatch, tmp_path):
    project_root = tmp_path
    calls = []

    monkeypatch.setattr(cli, "get_project_root", lambda: project_root)
    monkeypatch.setattr(cli, "handle_history_cleanup", lambda config_name, auto_clean=False: None)
    monkeypatch.setattr(cli, "run_data_updater", lambda project_root: calls.append(("run_data_updater", project_root)))
    monkeypatch.setattr(
        cli,
        "auto_prepare_backtest_market_store",
        lambda config_name, start_date, end_date: calls.append(
            ("auto_prepare_backtest_market_store", config_name, start_date, end_date)
        ),
    )
    monkeypatch.setattr(
        cli,
        "auto_enrich_market_store",
        lambda config_name, end_date=None, lookback_days=120, force=False: calls.append(
            ("auto_enrich_market_store", config_name, end_date, lookback_days, force)
        ),
    )
    monkeypatch.setattr(cli.os, "chdir", lambda path: calls.append(("chdir", Path(path))))

    def fake_run(cmd, check=True, **kwargs):
        calls.append(("subprocess.run", cmd, check))
        return 0

    monkeypatch.setattr(cli.subprocess, "run", fake_run)

    cli.backtest(
        start="2026-03-01",
        end="2026-03-10",
        config_name="smoke_fullstack",
        host="0.0.0.0",
        port=8765,
        poll_interval=10,
        clean=False,
        enable_memory=False,
    )

    assert any(item[0] == "run_data_updater" for item in calls)
    assert any(
        item[0] == "auto_prepare_backtest_market_store"
        and item[1:] == ("smoke_fullstack", "2026-03-01", "2026-03-10")
        for item in calls
    )
    assert any(
        item[0] == "auto_enrich_market_store"
        and item[1] == "smoke_fullstack"
        and item[2] == "2026-03-10"
        for item in calls
    )
    run_call = next(item for item in calls if item[0] == "subprocess.run")
    assert run_call[1][:6] == [
        cli.sys.executable,
        "-u",
        "-m",
        "backend.main",
        "--mode",
        "backtest",
    ]


def test_ingest_enrich_runs_batch_enrichment(monkeypatch):
    calls = []

    monkeypatch.setattr(cli, "_resolve_symbols", lambda raw_tickers, config_name=None: ["AAPL", "MSFT"])

    class DummyStore:
        pass

    monkeypatch.setattr(cli, "MarketStore", lambda: DummyStore())
    monkeypatch.setattr(
        cli,
        "enrich_symbols",
        lambda store, symbols, start_date=None, end_date=None, limit=200, analysis_source="local", skip_existing=True: calls.append(
            ("enrich_symbols", symbols, start_date, end_date, limit, analysis_source, skip_existing)
        ) or [
            {
                "symbol": symbol,
                "news_count": 3,
                "queued_count": 3,
                "analyzed": 3,
                "skipped_existing_count": 0,
                "deduped_count": 0,
                "llm_count": 0,
                "local_count": 3,
            }
            for symbol in symbols
        ],
    )

    cli.ingest_enrich(
        tickers=None,
        start="2026-03-01",
        end="2026-03-10",
        limit=150,
        force=False,
        config_name="smoke_fullstack",
    )

    assert calls == [
        ("enrich_symbols", ["AAPL", "MSFT"], "2026-03-01", "2026-03-10", 150, "local", True)
    ]


def test_ingest_report_reads_market_store_report(monkeypatch):
    calls = []
    printed = []

    monkeypatch.setattr(cli, "_resolve_symbols", lambda raw_tickers, config_name=None: ["AAPL"])

    class DummyStore:
        def get_enrich_report(self, symbols=None, start_date=None, end_date=None):
            calls.append(("get_enrich_report", symbols, start_date, end_date))
            return [
                {
                    "symbol": "AAPL",
                    "raw_news_count": 10,
                    "analyzed_news_count": 8,
                    "coverage_pct": 80.0,
                    "llm_count": 5,
                    "local_count": 3,
                    "latest_trade_date": "2026-03-16",
                    "latest_analysis_at": "2026-03-16T09:00:00",
                }
            ]

    monkeypatch.setattr(cli, "MarketStore", lambda: DummyStore())
    monkeypatch.setattr(cli, "get_explain_model_info", lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"})
    monkeypatch.setattr(cli, "llm_enrichment_enabled", lambda: True)
    monkeypatch.setattr(cli.console, "print", lambda value: printed.append(value))

    cli.ingest_report(
        tickers=None,
        start="2026-03-01",
        end="2026-03-16",
        config_name="smoke_fullstack",
        only_problematic=False,
    )

    assert calls == [
        ("get_enrich_report", ["AAPL"], "2026-03-01", "2026-03-16")
    ]
    assert printed
    assert getattr(printed[0], "caption", "") == "Explain LLM: DASHSCOPE:qwen-max"


def test_filter_problematic_report_rows_keeps_low_coverage_and_no_llm():
    rows = [
        {
            "symbol": "AAPL",
            "coverage_pct": 100.0,
            "llm_count": 2,
        },
        {
            "symbol": "MSFT",
            "coverage_pct": 80.0,
            "llm_count": 1,
        },
        {
            "symbol": "NVDA",
            "coverage_pct": 100.0,
            "llm_count": 0,
        },
    ]

    filtered = cli._filter_problematic_report_rows(rows)

    assert [row["symbol"] for row in filtered] == ["MSFT", "NVDA"]
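
These CLI tests stub every collaborator at module level, so they run with no network access or real subprocesses. One idiom worth noting: the enrich_symbols stub both records the call and returns a value from a single lambda, which works because list.append returns None. A standalone illustration:

    # list.append returns None, so `calls.append(x) or result` records the
    # call and still evaluates to `result`.
    calls = []
    stub = lambda symbol: calls.append(("seen", symbol)) or {"symbol": symbol}

    assert stub("AAPL") == {"symbol": "AAPL"}
    assert calls == [("seen", "AAPL")]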
384
backend/tests/test_gateway_explain_handlers.py
Normal file
@@ -0,0 +1,384 @@
# -*- coding: utf-8 -*-
import json
from types import SimpleNamespace

import pytest

from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module


class DummyWebSocket:
    def __init__(self):
        self.messages = []

    async def send(self, payload: str):
        self.messages.append(json.loads(payload))


class DummyStateSync:
    def __init__(self, current_date="2026-03-16"):
        self.state = {"current_date": current_date}
        self.system_messages = []

    def set_broadcast_fn(self, _fn):
        return None

    def update_state(self, *_args, **_kwargs):
        return None

    async def on_system_message(self, message):
        self.system_messages.append(message)


class FakeMarketStore:
    def __init__(self):
        self.calls = []

    def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
        self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
        return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}]

    def get_news_items(self, symbol, *, start_date=None, end_date=None, limit=100):
        self.calls.append(("get_news_items", symbol, start_date, end_date, limit))
        return [
            {
                "id": "news-1",
                "ticker": symbol,
                "date": end_date,
                "trade_date": end_date,
                "title": "Title",
                "summary": "Summary",
                "source": "polygon",
            }
        ]

    def get_news_items_enriched(self, symbol, *, start_date=None, end_date=None, trade_date=None, limit=100):
        self.calls.append(("get_news_items_enriched", symbol, start_date, end_date, trade_date, limit))
        target_date = trade_date or end_date
        return [
            {
                "id": "news-1",
                "ticker": symbol,
                "date": target_date,
                "trade_date": target_date,
                "title": "Title",
                "summary": "Summary",
                "source": "polygon",
                "sentiment": "negative",
                "relevance": "high",
                "key_discussion": "Key discussion",
            }
        ]

    def get_news_by_ids_enriched(self, symbol, article_ids):
        self.calls.append(("get_news_by_ids_enriched", symbol, list(article_ids)))
        return [{"id": article_ids[0], "ticker": symbol, "date": "2026-03-16", "sentiment": "negative"}]

    def get_news_categories_enriched(self, symbol, *, start_date=None, end_date=None, limit=200):
        self.calls.append(("get_news_categories_enriched", symbol, start_date, end_date, limit))
        # "宏观" is the backend's zh-CN display label for the "macro" category.
        return {"macro": {"label": "宏观", "count": 1, "article_ids": ["news-1"], "positive_ids": [], "negative_ids": ["news-1"], "neutral_ids": []}}

    def get_story_cache(self, symbol, *, as_of_date):
        self.calls.append(("get_story_cache", symbol, as_of_date))
        return None

    def upsert_story_cache(self, symbol, *, as_of_date, content, source="local"):
        self.calls.append(("upsert_story_cache", symbol, as_of_date, source))

    def delete_story_cache(self, symbol, *, as_of_date=None):
        self.calls.append(("delete_story_cache", symbol, as_of_date))
        return 1

    def get_similar_day_cache(self, symbol, *, target_date):
        self.calls.append(("get_similar_day_cache", symbol, target_date))
        return None

    def upsert_similar_day_cache(self, symbol, *, target_date, payload, source="local"):
        self.calls.append(("upsert_similar_day_cache", symbol, target_date, source))

    def delete_similar_day_cache(self, symbol, *, target_date=None):
        self.calls.append(("delete_similar_day_cache", symbol, target_date))
        return 1

    def get_ohlc(self, symbol, start_date, end_date):
        self.calls.append(("get_ohlc", symbol, start_date, end_date))
        return [
            {"date": start_date, "open": 100, "high": 105, "low": 99, "close": 103},
            {"date": end_date, "open": 103, "high": 108, "low": 102, "close": 107},
        ]


def make_gateway(market_store=None):
    storage = SimpleNamespace(market_store=market_store or FakeMarketStore())
    pipeline = SimpleNamespace(state_sync=None)
    market_service = SimpleNamespace()
    state_sync = DummyStateSync()
    return Gateway(
        market_service=market_service,
        storage_service=storage,
        pipeline=pipeline,
        state_sync=state_sync,
        config={"mode": "live"},
    )


@pytest.mark.asyncio
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    await gateway._handle_get_stock_news_timeline(
        websocket,
        {"ticker": "AAPL", "lookback_days": 30},
    )

    assert market_store.calls == [
        ("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")
    ]
    assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded"
    assert websocket.messages[-1]["ticker"] == "AAPL"


@pytest.mark.asyncio
async def test_handle_get_stock_news_categories_uses_market_store_symbol_argument(monkeypatch):
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    await gateway._handle_get_stock_news_categories(
        websocket,
        {"ticker": "AAPL", "lookback_days": 30},
    )

    assert market_store.calls == [
        ("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
        ("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200),
    ]
    assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
    assert websocket.messages[-1]["categories"]["macro"]["count"] == 1


@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch):
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    def fake_build_range_explanation(*, ticker, start_date, end_date, news_rows):
        return {
            "ticker": ticker,
            "start_date": start_date,
            "end_date": end_date,
            "news_count": len(news_rows),
        }

    monkeypatch.setattr(
        gateway_module,
        "build_range_explanation",
        fake_build_range_explanation,
    )

    await gateway._handle_get_stock_range_explain(
        websocket,
        {"ticker": "AAPL", "start_date": "2026-03-10", "end_date": "2026-03-16"},
    )

    assert market_store.calls == [
        ("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100)
    ]
    assert websocket.messages[-1] == {
        "type": "stock_range_explain_loaded",
        "ticker": "AAPL",
        "result": {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "news_count": 1,
        },
    }


@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch):
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    monkeypatch.setattr(
        gateway_module,
        "build_range_explanation",
        lambda **kwargs: {"news_count": len(kwargs["news_rows"])},
    )

    await gateway._handle_get_stock_range_explain(
        websocket,
        {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "article_ids": ["news-99"],
        },
    )

    assert market_store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-99"])]
    assert websocket.messages[-1]["result"]["news_count"] == 1


@pytest.mark.asyncio
async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    await gateway._handle_get_stock_news_for_date(
        websocket,
        {"ticker": "AAPL", "date": "2026-03-16", "limit": 10},
    )

    assert market_store.calls == [
        ("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
    ]
    assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded"
    assert websocket.messages[-1]["date"] == "2026-03-16"


@pytest.mark.asyncio
async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    monkeypatch.setattr(
        gateway_module,
        "enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
    )

    await gateway._handle_get_stock_story(
        websocket,
        {"ticker": "AAPL", "as_of_date": "2026-03-16"},
    )

    assert websocket.messages[-1]["type"] == "stock_story_loaded"
    assert websocket.messages[-1]["ticker"] == "AAPL"
    assert "AAPL Story" in websocket.messages[-1]["story"]


@pytest.mark.asyncio
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    monkeypatch.setattr(
        gateway_module,
        "enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
    )

    await gateway._handle_get_stock_similar_days(
        websocket,
        {"ticker": "AAPL", "date": "2026-03-16", "top_k": 5},
    )

    assert websocket.messages[-1]["type"] == "stock_similar_days_loaded"
    assert websocket.messages[-1]["ticker"] == "AAPL"
    assert isinstance(websocket.messages[-1]["items"], list)


@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    monkeypatch.setattr(
        gateway_module,
        "enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
    )

    await gateway._handle_run_stock_enrich(
        websocket,
        {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "force": True,
            "rebuild_story": True,
            "rebuild_similar_days": True,
            "story_date": "2026-03-16",
            "target_date": "2026-03-16",
        },
    )

    assert ("delete_story_cache", "AAPL", "2026-03-16") in market_store.calls
    assert ("delete_similar_day_cache", "AAPL", "2026-03-16") in market_store.calls
    assert websocket.messages[-1]["type"] == "stock_enrich_completed"
    assert websocket.messages[-1]["stats"]["analyzed"] == 2


@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypatch):
    gateway = make_gateway(FakeMarketStore())
    websocket = DummyWebSocket()

    monkeypatch.setattr(gateway_module, "llm_enrichment_enabled", lambda: False)

    await gateway._handle_run_stock_enrich(
        websocket,
        {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "only_local_to_llm": True,
        },
    )

    assert websocket.messages[-1]["type"] == "stock_enrich_completed"
    assert "requires EXPLAIN_ENRICH_USE_LLM=true" in websocket.messages[-1]["error"]


def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
    gateway = make_gateway()
    captured = {}

    class DummyTask:
        def done(self):
            return False

        def cancel(self):
            captured["cancelled"] = True

    def fake_create_task(coro):
        captured["coro_name"] = coro.cr_code.co_name
        coro.close()
        return DummyTask()

    monkeypatch.setattr(gateway_module.asyncio, "create_task", fake_create_task)

    gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])

    assert captured["coro_name"] == "_refresh_market_store_for_watchlist"


@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
    gateway = make_gateway()

    monkeypatch.setattr(
        gateway_module,
        "ingest_symbols",
        lambda symbols, mode="incremental": [
            {"symbol": symbol, "prices": 3, "news": 4, "aligned": 4}
            for symbol in symbols
        ],
    )

    await gateway._refresh_market_store_for_watchlist(["AAPL", "MSFT"])

    # The zh-CN system messages read roughly: "Syncing watchlist market data:
    # AAPL, MSFT" and "Watchlist market data synced: ...".
    assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
    assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]
    assert "AAPL prices=3 news=4" in gateway.state_sync.system_messages[1]
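
The handler tests above await Gateway coroutines directly under pytest-asyncio. Outside pytest, the same doubles can drive a handler with plain asyncio; a sketch, assuming the test module is importable on PYTHONPATH:

    import asyncio
    import json

    from backend.tests.test_gateway_explain_handlers import DummyWebSocket, make_gateway

    async def main():
        gateway = make_gateway()
        ws = DummyWebSocket()
        await gateway._handle_get_stock_news_timeline(
            ws, {"ticker": "AAPL", "lookback_days": 30}
        )
        print(json.dumps(ws.messages[-1], indent=2))

    asyncio.run(main())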
65
backend/tests/test_historical_price_manager.py
Normal file
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
from unittest.mock import patch

import pandas as pd

from backend.data.historical_price_manager import HistoricalPriceManager


def test_preload_data_prefers_market_db():
    manager = HistoricalPriceManager()
    manager.subscribe(["AAPL"])

    market_rows = [
        {
            "symbol": "AAPL",
            "date": "2026-03-09",
            "open": 100.0,
            "high": 103.0,
            "low": 99.0,
            "close": 102.0,
            "volume": 10_000,
            "vwap": 101.0,
            "transactions": 500,
            "source": "polygon",
        }
    ]

    with (
        patch.object(manager._market_store, "get_ohlc", return_value=market_rows),
        patch.object(manager._router, "load_local_price_frame") as load_csv,
    ):
        manager.preload_data("2026-03-01", "2026-03-10")

    load_csv.assert_not_called()
    assert "AAPL" in manager._price_cache
    assert float(manager._price_cache["AAPL"].iloc[0]["close"]) == 102.0


def test_preload_data_falls_back_to_csv():
    manager = HistoricalPriceManager()
    manager.subscribe(["MSFT"])

    csv_df = pd.DataFrame(
        {
            "time": ["2026-03-09"],
            "open": [200.0],
            "high": [205.0],
            "low": [198.0],
            "close": [204.0],
            "volume": [20_000],
        }
    )
    csv_df["time"] = pd.to_datetime(csv_df["time"])
    csv_df["Date"] = csv_df["time"]
    csv_df.set_index("Date", inplace=True)

    with (
        patch.object(manager._market_store, "get_ohlc", return_value=[]),
        patch.object(manager._router, "load_local_price_frame", return_value=csv_df) as load_csv,
    ):
        manager.preload_data("2026-03-01", "2026-03-10")

    load_csv.assert_called_once_with("MSFT")
    assert "MSFT" in manager._price_cache
    assert float(manager._price_cache["MSFT"].iloc[0]["close"]) == 204.0
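
Taken together, the two tests pin down a strict precedence: the Polygon-backed market store is consulted first, and the per-symbol CSV loader runs only when the store returns nothing. The rule as a generic sketch (the real preload_data additionally builds and caches a pandas frame):

    def load_prices(symbol, market_store, csv_router, start, end):
        # Warehouse hit: the CSV loader must never run.
        rows = market_store.get_ohlc(symbol, start, end)
        if rows:
            return rows
        # Empty warehouse: fall back to the local CSV frame.
        return csv_router.load_local_price_frame(symbol)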
133
backend/tests/test_llm_enricher.py
Normal file
@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-

from backend.enrich import llm_enricher


class DummyResponse:
    def __init__(self, metadata):
        self.metadata = metadata


class DummyModel:
    def __init__(self, metadata):
        self.metadata = metadata
        self.calls = []

    async def __call__(self, messages, structured_model=None, **kwargs):
        self.calls.append(
            {
                "messages": messages,
                "structured_model": structured_model,
                "kwargs": kwargs,
            }
        )
        return DummyResponse(self.metadata)


def test_analyze_news_row_with_llm_uses_agentscope_model(monkeypatch):
    model = DummyModel(
        {
            "id": "news-1",
            "relevance": "high",
            "sentiment": "positive",
            "key_discussion": "Demand remains resilient",
            "summary": "Structured summary",
            "reason_growth": "Orders improved",
            "reason_decrease": "",
        }
    )
    monkeypatch.setattr(llm_enricher, "llm_enrichment_enabled", lambda: True)
    monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
    monkeypatch.setattr(
        llm_enricher,
        "get_explain_model_info",
        lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
    )

    result = llm_enricher.analyze_news_row_with_llm(
        {
            "id": "news-1",
            "title": "Apple expands AI features",
            "summary": "New devices and software updates were announced.",
        }
    )

    assert result["sentiment"] == "positive"
    assert result["summary"] == "Structured summary"
    assert result["raw_json"]["model_label"] == "DASHSCOPE:qwen-max"
    assert model.calls
    assert model.calls[0]["structured_model"] is llm_enricher.EnrichedNewsItem


def test_analyze_news_rows_with_llm_uses_agentscope_structured_batch(monkeypatch):
    model = DummyModel(
        {
            "items": [
                {
                    "id": "news-1",
                    "relevance": "high",
                    "sentiment": "negative",
                    "key_discussion": "Margin pressure",
                    "summary": "Batch summary",
                    "reason_growth": "",
                    "reason_decrease": "Costs rose",
                }
            ]
        }
    )
    monkeypatch.setattr(llm_enricher, "llm_enrichment_enabled", lambda: True)
    monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
    monkeypatch.setattr(
        llm_enricher,
        "get_explain_model_info",
        lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
    )

    result = llm_enricher.analyze_news_rows_with_llm(
        [
            {
                "id": "news-1",
                "title": "Apple margins pressured",
                "summary": "Costs increased this quarter.",
            }
        ]
    )

    assert result["news-1"]["sentiment"] == "negative"
    assert result["news-1"]["reason_decrease"] == "Costs rose"
    assert result["news-1"]["raw_json"]["model_label"] == "DASHSCOPE:qwen-max"
    assert model.calls
    assert model.calls[0]["structured_model"] is llm_enricher.EnrichedNewsBatch


def test_analyze_range_with_llm_uses_agentscope_structured_output(monkeypatch):
    # The zh-CN strings mirror the model's Chinese output, roughly: "the stock
    # drifted lower over the range; news centered on earnings expectations and
    # supply-chain disruption".
    model = DummyModel(
        {
            "summary": "该股在区间内震荡下行,相关新闻主要集中在盈利预期和供应链扰动。",
            "trend_analysis": "前半段受利空新闻压制,后半段跌幅收敛。",
            "bullish_factors": ["估值消化后出现部分承接"],
            "bearish_factors": ["盈利预期下修", "供应链扰动持续"],
        }
    )
    monkeypatch.setattr(llm_enricher, "llm_range_analysis_enabled", lambda: True)
    monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
    monkeypatch.setattr(
        llm_enricher,
        "get_explain_model_info",
        lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
    )

    result = llm_enricher.analyze_range_with_llm(
        {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "price_change_pct": -3.42,
        }
    )

    assert result["summary"].startswith("该股在区间内震荡下行")
    assert result["model_label"] == "DASHSCOPE:qwen-max"
    assert result["bearish_factors"] == ["盈利预期下修", "供应链扰动持续"]
    assert model.calls
    assert model.calls[0]["structured_model"] is llm_enricher.RangeAnalysisPayload
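
All three paths hand the model a structured_model class and read the parsed payload off response.metadata. The metadata keys exercised above imply a schema roughly like the following sketch (the real EnrichedNewsItem lives in backend.enrich.llm_enricher and may differ in field types and validators):

    from pydantic import BaseModel

    class EnrichedNewsItemSketch(BaseModel):
        # Hedged stand-in inferred from the test fixtures above.
        id: str
        relevance: str        # e.g. "high" / "low"
        sentiment: str        # "positive" / "negative" / "neutral"
        key_discussion: str
        summary: str
        reason_growth: str = ""
        reason_decrease: str = ""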
54
backend/tests/test_market_store_report.py
Normal file
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-

from pathlib import Path

from backend.data.market_store import MarketStore


def test_get_enrich_report_summarizes_coverage(tmp_path: Path):
    store = MarketStore(tmp_path / "market_research.db")
    store.upsert_news(
        "AAPL",
        [
            {
                "id": "news-1",
                "published_utc": "2026-03-10T12:00:00Z",
                "title": "Apple earnings beat",
                "summary": "Revenue topped expectations",
                "tickers": ["AAPL"],
            },
            {
                "id": "news-2",
                "published_utc": "2026-03-11T12:00:00Z",
                "title": "Apple supply chain warning",
                "summary": "Outlook softened",
                "tickers": ["AAPL"],
            },
        ],
    )
    store.set_trade_dates(
        [
            {"news_id": "news-1", "symbol": "AAPL", "trade_date": "2026-03-10"},
            {"news_id": "news-2", "symbol": "AAPL", "trade_date": "2026-03-11"},
        ]
    )
    store.upsert_news_analysis(
        "AAPL",
        [
            {
                "news_id": "news-1",
                "trade_date": "2026-03-10",
                "summary": "LLM enriched",
                "analysis_source": "llm",
            }
        ],
        analysis_source="llm",
    )

    rows = store.get_enrich_report(["AAPL"])
    assert len(rows) == 1
    assert rows[0]["symbol"] == "AAPL"
    assert rows[0]["raw_news_count"] == 2
    assert rows[0]["analyzed_news_count"] == 1
    assert rows[0]["coverage_pct"] == 50.0
    assert rows[0]["llm_count"] == 1
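
The coverage figure follows directly from the fixture: two raw news rows, one analyzed, hence 50.0. The arithmetic the report presumably applies (the real implementation may round differently):

    raw_news_count, analyzed_news_count = 2, 1
    coverage_pct = round(100.0 * analyzed_news_count / max(raw_news_count, 1), 1)
    assert coverage_pct == 50.0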
174
backend/tests/test_news_enricher.py
Normal file
@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-

from backend.enrich import news_enricher


def test_classify_news_row_falls_back_to_local_rules(monkeypatch):
    monkeypatch.setattr(news_enricher, "analyze_news_row_with_llm", lambda row: None)
    result = news_enricher.classify_news_row(
        {
            "title": "Apple shares drop after weak guidance",
            "summary": "Investors reacted negatively to softer-than-expected outlook.",
        }
    )
    assert result["analysis_source"] == "local"
    assert result["sentiment"] == "negative"
    assert result["summary"]


def test_classify_news_row_prefers_llm_when_available(monkeypatch):
    monkeypatch.setattr(
        news_enricher,
        "analyze_news_row_with_llm",
        lambda row: {
            "relevance": "high",
            "sentiment": "positive",
            "key_discussion": "Demand resilience",
            "summary": "LLM summary",
            "reason_growth": "Orders remain strong",
            "reason_decrease": "",
            "raw_json": {"provider": "llm"},
        },
    )
    result = news_enricher.classify_news_row(
        {
            "title": "Apple expands AI features",
            "summary": "New devices and software updates were announced.",
        }
    )
    assert result["analysis_source"] == "llm"
    assert result["sentiment"] == "positive"
    assert result["summary"] == "LLM summary"


def test_build_analysis_rows_prefers_batch_llm_and_dedupes(monkeypatch):
    monkeypatch.setattr(news_enricher, "llm_enrichment_enabled", lambda: True)
    monkeypatch.setattr(news_enricher, "get_env_int", lambda key, default=0: 8)
    monkeypatch.setattr(
        news_enricher,
        "analyze_news_rows_with_llm",
        lambda rows: {
            "news-1": {
                "relevance": "high",
                "sentiment": "positive",
                "key_discussion": "Batch result",
                "summary": "Batch summary",
                "reason_growth": "Growth",
                "reason_decrease": "",
                "raw_json": {"provider": "batch"},
            }
        },
    )
    monkeypatch.setattr(news_enricher, "analyze_news_row_with_llm", lambda row: None)
    rows, stats = news_enricher.build_analysis_rows(
        symbol="AAPL",
        news_rows=[
            {"id": "news-1", "trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
            {"id": "news-2", "trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
        ],
        ohlc_rows=[],
    )
    assert len(rows) == 1
    assert rows[0]["analysis_source"] == "llm"
    assert rows[0]["summary"] == "Batch summary"
    assert stats["deduped_count"] == 1
    assert stats["llm_count"] == 1


def test_enrich_news_for_symbol_skips_existing(monkeypatch):
    class DummyStore:
        def get_news_items(self, symbol, start_date=None, end_date=None, limit=200):
            return [
                {"id": "news-1", "trade_date": "2026-03-10", "title": "One", "summary": "One"},
                {"id": "news-2", "trade_date": "2026-03-11", "title": "Two", "summary": "Two"},
            ]

        def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
            return {"news-1"}

        def get_ohlc(self, symbol, start_date, end_date):
            return []

        def upsert_news_analysis(self, symbol, rows, analysis_source="local"):
            self.rows = rows
            return len(rows)

    monkeypatch.setattr(
        news_enricher,
        "build_analysis_rows",
        lambda symbol, news_rows, ohlc_rows: (
            [
                {
                    "news_id": row["id"],
                    "trade_date": row["trade_date"],
                    "summary": row["summary"],
                    "analysis_source": "local",
                }
                for row in news_rows
            ],
            {"deduped_count": 0, "llm_count": 0, "local_count": len(news_rows)},
        ),
    )
    store = DummyStore()
    result = news_enricher.enrich_news_for_symbol(store, "AAPL")
    assert result["news_count"] == 2
    assert result["queued_count"] == 1
    assert result["skipped_existing_count"] == 1
    assert len(store.rows) == 1
    assert store.rows[0]["news_id"] == "news-2"


def test_enrich_news_for_symbol_only_reanalyzes_local(monkeypatch):
    class DummyStore:
        def get_news_items(self, symbol, start_date=None, end_date=None, limit=200):
            return [
                {"id": "news-1", "trade_date": "2026-03-10", "title": "One", "summary": "One"},
                {"id": "news-2", "trade_date": "2026-03-11", "title": "Two", "summary": "Two"},
                {"id": "news-3", "trade_date": "2026-03-12", "title": "Three", "summary": "Three"},
            ]

        def get_analyzed_news_sources(self, symbol, start_date=None, end_date=None):
            return {"news-1": "local", "news-2": "llm"}

        def get_ohlc(self, symbol, start_date, end_date):
            return []

        def upsert_news_analysis(self, symbol, rows, analysis_source="local"):
            self.rows = rows
            return len(rows)

    monkeypatch.setattr(
        news_enricher,
        "build_analysis_rows",
        lambda symbol, news_rows, ohlc_rows: (
            [
                {
                    "news_id": row["id"],
                    "trade_date": row["trade_date"],
                    "summary": row["summary"],
                    "analysis_source": "llm" if row["id"] == "news-1" else "local",
                }
                for row in news_rows
            ],
            {"deduped_count": 0, "llm_count": 1, "local_count": 0},
        ),
    )

    store = DummyStore()
    result = news_enricher.enrich_news_for_symbol(
        store,
        "AAPL",
        only_reanalyze_local=True,
    )

    assert result["news_count"] == 3
    assert result["queued_count"] == 1
    assert result["skipped_existing_count"] == 2
    assert result["only_reanalyze_local"] is True
    assert result["upgraded_local_to_llm_count"] == 1
    assert result["execution_summary"]["upgraded_dates"] == ["2026-03-10"]
    assert result["execution_summary"]["remaining_local_titles"] == []
    assert result["execution_summary"]["skipped_missing_analysis_count"] == 1
    assert result["execution_summary"]["skipped_non_local_count"] == 1
    assert [row["news_id"] for row in store.rows] == ["news-1"]
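
The batch test above asserts that two rows sharing a trade date, title, and summary collapse into one analysis row with deduped_count == 1. A sketch of that behavior (the real build_analysis_rows key may normalize the text first):

    def dedupe(news_rows):
        # Collapse rows that repeat the same (trade_date, title, summary).
        seen, kept, dropped = set(), [], 0
        for row in news_rows:
            key = (row["trade_date"], row["title"], row["summary"])
            if key in seen:
                dropped += 1
                continue
            seen.add(key)
            kept.append(row)
        return kept, dropped

    kept, dropped = dedupe(
        [
            {"trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
            {"trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
        ]
    )
    assert len(kept) == 1 and dropped == 1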
54
backend/tests/test_range_explainer.py
Normal file
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-

from types import SimpleNamespace

from backend.explain import range_explainer


def test_build_range_explanation_prefers_llm_text_when_available(monkeypatch):
    monkeypatch.setattr(
        range_explainer,
        "get_prices",
        lambda ticker, start_date, end_date: [
            SimpleNamespace(open=100, close=98, high=102, low=97, volume=1000),
            SimpleNamespace(open=98, close=96, high=99, low=95, volume=1100),
            SimpleNamespace(open=96, close=97, high=98, low=94, volume=1200),
        ],
    )
    # The zh-CN strings mirror the LLM's Chinese output, roughly: "broadly weak
    # over the range; themes centered on earnings expectations and supply-chain
    # risk".
    monkeypatch.setattr(
        range_explainer,
        "analyze_range_with_llm",
        lambda payload: {
            "summary": "区间内整体偏弱,主题集中在盈利预期和供应链风险。",
            "trend_analysis": "前半段快速下探,后半段出现修复。",
            "bullish_factors": ["回调后出现承接"],
            "bearish_factors": ["盈利预期承压"],
            "model_label": "DASHSCOPE:qwen-max",
        },
    )

    result = range_explainer.build_range_explanation(
        ticker="AAPL",
        start_date="2026-03-10",
        end_date="2026-03-16",
        news_rows=[
            {
                "id": "news-1",
                "trade_date": "2026-03-10",
                "title": "Apple margin pressure concerns grow",
                "summary": "Investors focused on weaker margin outlook.",
                "sentiment": "negative",
                "relevance": "high",
                "ret_t0": -0.02,
                "reason_decrease": "盈利预期承压",
                "category": "earnings",
            }
        ],
    )

    assert result["analysis"]["summary"] == "区间内整体偏弱,主题集中在盈利预期和供应链风险。"
    assert result["analysis"]["trend_analysis"] == "前半段快速下探,后半段出现修复。"
    assert result["analysis"]["bullish_factors"] == ["回调后出现承接"]
    assert result["analysis"]["analysis_source"] == "llm"
    assert result["analysis"]["analysis_model_label"] == "DASHSCOPE:qwen-max"
    assert result["news_count"] == 1
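
The stubbed bars let the explainer derive its own price statistics: first open 100, last close 97, a 3% decline across the range. A sketch of that computation (assumption: first open to last close; the real explainer may measure close-to-close instead):

    bars = [(100, 98), (98, 96), (96, 97)]  # (open, close) per stubbed day
    change_pct = round(100.0 * (bars[-1][1] - bars[0][0]) / bars[0][0], 2)
    assert change_pct == -3.0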