Add explain analysis workflow and UI

2026-03-16 22:28:41 +08:00
parent 3a5558b576
commit 1f5ee3698e
49 changed files with 8888 additions and 1476 deletions

View File

@@ -12,7 +12,7 @@ import os
import shutil
import subprocess
import sys
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional
from zoneinfo import ZoneInfo
@@ -21,18 +21,27 @@ import typer
from rich.console import Console
from rich.panel import Panel
from rich.prompt import Confirm
from rich.table import Table
from dotenv import load_dotenv
from backend.agents.prompt_loader import PromptLoader
from backend.agents.workspace_manager import WorkspaceManager
from backend.data.market_ingest import ingest_symbols
from backend.data.market_store import MarketStore
from backend.enrich.llm_enricher import get_explain_model_info, llm_enrichment_enabled
from backend.enrich.news_enricher import enrich_symbols
app = typer.Typer(
name="evotraders",
help="EvoTraders: A self-evolving multi-agent trading system",
add_completion=False,
)
ingest_app = typer.Typer(help="Ingest Polygon market data into the research warehouse.")
app.add_typer(ingest_app, name="ingest")
console = Console()
_prompt_loader = PromptLoader()
load_dotenv()
def get_project_root() -> Path:
@@ -204,6 +213,189 @@ def initialize_workspace(config_name: str) -> Path:
return workspace_manager.get_run_dir(config_name)
def _resolve_symbols(raw_tickers: Optional[str], config_name: Optional[str] = None) -> list[str]:
"""Resolve symbols from explicit input or runtime bootstrap config."""
if raw_tickers and raw_tickers.strip():
return [
item.strip().upper()
for item in raw_tickers.split(",")
if item.strip()
]
workspace_manager = WorkspaceManager(project_root=get_project_root())
bootstrap_path = workspace_manager.get_run_dir(config_name or "default") / "BOOTSTRAP.md"
if bootstrap_path.exists():
content = bootstrap_path.read_text(encoding="utf-8")
for line in content.splitlines():
if line.strip().startswith("tickers:"):
raw = line.split(":", 1)[1]
return [
item.strip().upper()
for item in raw.split(",")
if item.strip()
]
return []
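A minimal sketch of the fallback contract this helper assumes: BOOTSTRAP.md carries a `tickers:` line, and explicit input short-circuits the file read (file contents and calls below are illustrative, not part of this diff).

# BOOTSTRAP.md (hypothetical excerpt)
#   tickers: aapl, msft, nvda

_resolve_symbols("tsla, amd")       # -> ["TSLA", "AMD"]  (explicit input wins)
_resolve_symbols(None, "default")   # -> ["AAPL", "MSFT", "NVDA"]  (read from BOOTSTRAP.md)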
def _filter_problematic_report_rows(rows: list[dict]) -> list[dict]:
"""Keep tickers with incomplete coverage or without any LLM-enriched rows."""
return [
row
for row in rows
if float(row.get("coverage_pct") or 0.0) < 100.0
or int(row.get("llm_count") or 0) == 0
]
def auto_update_market_store(
config_name: str,
*,
end_date: Optional[str] = None,
) -> None:
"""Refresh the long-lived Polygon market store for the active watchlist."""
api_key = os.getenv("POLYGON_API_KEY", "").strip()
if not api_key:
console.print(
"[dim]Skipping Polygon market store update: POLYGON_API_KEY not set[/dim]",
)
return
symbols = _resolve_symbols(None, config_name)
if not symbols:
console.print(
f"[dim]Skipping Polygon market store update: no tickers found for config '{config_name}'[/dim]",
)
return
target_end = end_date or datetime.now().date().isoformat()
console.print(
f"[cyan]Updating Polygon market store for {', '.join(symbols)} -> {target_end}[/cyan]",
)
try:
results = ingest_symbols(
symbols,
mode="incremental",
end_date=target_end,
)
except Exception as exc:
console.print(
f"[yellow]Polygon market store update failed, continuing startup: {exc}[/yellow]",
)
return
for result in results:
console.print(
"[green]"
f"{result['symbol']}"
"[/green] "
f"prices={result['prices']} news={result['news']} aligned={result['aligned']}"
)
def auto_prepare_backtest_market_store(
config_name: str,
*,
start_date: str,
end_date: str,
) -> None:
"""Ensure the market store has the requested backtest window for the active watchlist."""
api_key = os.getenv("POLYGON_API_KEY", "").strip()
if not api_key:
console.print(
"[dim]Skipping Polygon backtest preload: POLYGON_API_KEY not set[/dim]",
)
return
symbols = _resolve_symbols(None, config_name)
if not symbols:
console.print(
f"[dim]Skipping Polygon backtest preload: no tickers found for config '{config_name}'[/dim]",
)
return
console.print(
f"[cyan]Preparing Polygon market store for backtest {start_date} -> {end_date} "
f"({', '.join(symbols)})[/cyan]",
)
try:
results = ingest_symbols(
symbols,
mode="full",
start_date=start_date,
end_date=end_date,
)
except Exception as exc:
console.print(
f"[yellow]Polygon backtest preload failed, continuing startup: {exc}[/yellow]",
)
return
for result in results:
console.print(
"[green]"
f"{result['symbol']}"
"[/green] "
f"prices={result['prices']} news={result['news']} aligned={result['aligned']}"
)
def auto_enrich_market_store(
config_name: str,
*,
end_date: Optional[str] = None,
lookback_days: int = 120,
force: bool = False,
) -> None:
"""Refresh explain-oriented enriched news for the active watchlist."""
symbols = _resolve_symbols(None, config_name)
if not symbols:
console.print(
f"[dim]Skipping explain enrich: no tickers found for config '{config_name}'[/dim]",
)
return
target_end = end_date or datetime.now().date().isoformat()
try:
end_dt = datetime.strptime(target_end, "%Y-%m-%d")
except ValueError:
console.print(
f"[yellow]Skipping explain enrich: invalid end date {target_end}[/yellow]",
)
return
start_date = (end_dt - timedelta(days=max(1, lookback_days))).date().isoformat()
console.print(
f"[cyan]Refreshing explain enrich for {', '.join(symbols)} -> {target_end}[/cyan]",
)
store = MarketStore()
try:
results = enrich_symbols(
store,
symbols,
start_date=start_date,
end_date=target_end,
limit=300,
skip_existing=not force,
)
except Exception as exc:
console.print(
f"[yellow]Explain enrich failed, continuing startup: {exc}[/yellow]",
)
return
for result in results:
console.print(
"[green]"
f"{result['symbol']}"
"[/green] "
f"news={result['news_count']} queued={result['queued_count']} analyzed={result['analyzed']} "
f"skipped={result['skipped_existing_count']} deduped={result['deduped_count']} "
f"llm={result['llm_count']} local={result['local_count']}"
)
@app.command("init-workspace")
def init_workspace(
config_name: str = typer.Option(
@@ -223,6 +415,213 @@ def init_workspace(
)
@ingest_app.command("full")
def ingest_full(
tickers: Optional[str] = typer.Option(
None,
"--tickers",
"-t",
help="Comma-separated tickers to ingest",
),
start: Optional[str] = typer.Option(
None,
"--start",
help="Start date for full ingestion (YYYY-MM-DD)",
),
end: Optional[str] = typer.Option(
None,
"--end",
help="End date for ingestion (YYYY-MM-DD)",
),
config_name: str = typer.Option(
"default",
"--config-name",
"-c",
help="Fallback config to read tickers from BOOTSTRAP.md",
),
):
"""Run full Polygon ingestion for the specified symbols."""
symbols = _resolve_symbols(tickers, config_name)
if not symbols:
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
raise typer.Exit(1)
console.print(f"[cyan]Starting full Polygon ingest for {', '.join(symbols)}[/cyan]")
results = ingest_symbols(symbols, mode="full", start_date=start, end_date=end)
for result in results:
console.print(
f"[green]{result['symbol']}[/green] prices={result['prices']} news={result['news']} aligned={result['aligned']}"
)
@ingest_app.command("update")
def ingest_update(
tickers: Optional[str] = typer.Option(
None,
"--tickers",
"-t",
help="Comma-separated tickers to update",
),
end: Optional[str] = typer.Option(
None,
"--end",
help="Optional end date override (YYYY-MM-DD)",
),
config_name: str = typer.Option(
"default",
"--config-name",
"-c",
help="Fallback config to read tickers from BOOTSTRAP.md",
),
):
"""Run incremental Polygon ingestion using stored watermarks."""
symbols = _resolve_symbols(tickers, config_name)
if not symbols:
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
raise typer.Exit(1)
console.print(f"[cyan]Starting incremental Polygon ingest for {', '.join(symbols)}[/cyan]")
results = ingest_symbols(symbols, mode="incremental", end_date=end)
for result in results:
console.print(
f"[green]{result['symbol']}[/green] prices={result['prices']} news={result['news']} aligned={result['aligned']}"
)
@ingest_app.command("enrich")
def ingest_enrich(
tickers: Optional[str] = typer.Option(
None,
"--tickers",
"-t",
help="Comma-separated tickers to enrich",
),
start: Optional[str] = typer.Option(
None,
"--start",
help="Optional start date for enrichment window (YYYY-MM-DD)",
),
end: Optional[str] = typer.Option(
None,
"--end",
help="Optional end date for enrichment window (YYYY-MM-DD)",
),
limit: int = typer.Option(
300,
"--limit",
help="Maximum raw news rows per ticker to analyze",
),
force: bool = typer.Option(
False,
"--force",
help="Re-analyze already enriched news instead of only missing rows",
),
config_name: str = typer.Option(
"default",
"--config-name",
"-c",
help="Fallback config to read tickers from BOOTSTRAP.md",
),
):
"""Run explain-oriented news enrichment for symbols already in the market store."""
symbols = _resolve_symbols(tickers, config_name)
if not symbols:
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
raise typer.Exit(1)
console.print(f"[cyan]Starting explain enrich for {', '.join(symbols)}[/cyan]")
store = MarketStore()
results = enrich_symbols(
store,
symbols,
start_date=start,
end_date=end,
limit=max(10, limit),
skip_existing=not force,
)
for result in results:
console.print(
f"[green]{result['symbol']}[/green] "
f"news={result['news_count']} queued={result['queued_count']} analyzed={result['analyzed']} "
f"skipped={result['skipped_existing_count']} deduped={result['deduped_count']} "
f"llm={result['llm_count']} local={result['local_count']}"
)
@ingest_app.command("report")
def ingest_report(
tickers: Optional[str] = typer.Option(
None,
"--tickers",
"-t",
help="Optional comma-separated tickers to report",
),
start: Optional[str] = typer.Option(
None,
"--start",
help="Optional start date for report window (YYYY-MM-DD)",
),
end: Optional[str] = typer.Option(
None,
"--end",
help="Optional end date for report window (YYYY-MM-DD)",
),
config_name: str = typer.Option(
"default",
"--config-name",
"-c",
help="Fallback config to read tickers from BOOTSTRAP.md",
),
only_problematic: bool = typer.Option(
False,
"--only-problematic",
help="Only show tickers with incomplete coverage or no LLM-enriched news",
),
):
"""Show explain enrichment coverage and freshness per ticker."""
symbols = _resolve_symbols(tickers, config_name)
store = MarketStore()
report_rows = store.get_enrich_report(
symbols=symbols or None,
start_date=start,
end_date=end,
)
if only_problematic:
report_rows = _filter_problematic_report_rows(report_rows)
if not report_rows:
if only_problematic:
console.print("[green]No problematic enrich report rows found for the requested scope[/green]")
else:
console.print("[yellow]No enrich report rows found for the requested scope[/yellow]")
raise typer.Exit(0)
model_info = get_explain_model_info()
model_label = model_info["label"] if llm_enrichment_enabled() else "disabled"
table = Table(title="Explain Enrichment Report")
table.add_column("Ticker", style="cyan")
table.add_column("Raw News", justify="right")
table.add_column("Analyzed", justify="right")
table.add_column("Coverage", justify="right")
table.add_column("LLM", justify="right")
table.add_column("Local", justify="right")
table.add_column("Latest Trade Date")
table.add_column("Latest Analysis")
table.caption = f"Explain LLM: {model_label}"
for row in report_rows:
table.add_row(
row["symbol"],
str(row["raw_news_count"]),
str(row["analyzed_news_count"]),
f'{row["coverage_pct"]:.1f}%',
str(row["llm_count"]),
str(row["local_count"]),
str(row["latest_trade_date"] or "-"),
str(row["latest_analysis_at"] or "-"),
)
console.print(table)
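Taken together, the `ingest` subcommands above form the warehouse lifecycle. A hypothetical shell session, assuming the Typer app is exposed as an `evotraders` entry point (the console-script name is not defined in this diff):

evotraders ingest full --tickers AAPL,MSFT --start 2024-01-01 --end 2024-06-30
evotraders ingest update -t AAPL
evotraders ingest enrich -t AAPL --limit 300 --force
evotraders ingest report --only-problematic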
@app.command()
def backtest(
start: Optional[str] = typer.Option(
@@ -331,6 +730,16 @@ def backtest(
# Run data updater
run_data_updater(project_root)
auto_prepare_backtest_market_store(
config_name,
start_date=start,
end_date=end,
)
auto_enrich_market_store(
config_name,
end_date=end,
force=False,
)
# Build command using backend.main
cmd = [
@@ -514,6 +923,15 @@ def live(
# Data update (if not mock mode)
if not mock:
run_data_updater(project_root)
auto_update_market_store(
config_name,
end_date=nyse_now.date().isoformat(),
)
auto_enrich_market_store(
config_name,
end_date=nyse_now.date().isoformat(),
force=False,
)
else:
console.print(
"\n[dim]Mock mode enabled - skipping data update[/dim]\n",

View File

@@ -47,6 +47,10 @@ class StateSync:
"""Set current simulation date for backtest-compatible timestamps"""
self._simulation_date = date
def clear_simulation_date(self):
"""Disable backtest timestamp simulation and use wall-clock time."""
self._simulation_date = None
def _get_timestamp_ms(self) -> int:
"""
Get timestamp in milliseconds.
@@ -97,9 +101,21 @@ class StateSync:
if not self._enabled:
return
# Ensure timestamp exists. Prefer explicit millisecond timestamps so
# frontend displays local wall time correctly instead of date-only UTC.
if "timestamp" not in event:
ts_ms = event.get("ts")
ts_ms = event.get("ts")
if ts_ms is not None:
try:
event["timestamp"] = datetime.fromtimestamp(
float(ts_ms) / 1000.0,
).isoformat()
except (TypeError, ValueError, OSError):
if self._simulation_date:
event["timestamp"] = f"{self._simulation_date}"
else:
event["timestamp"] = datetime.now().isoformat()
elif self._simulation_date:
event["timestamp"] = f"{self._simulation_date}"
else:
event["timestamp"] = datetime.now().isoformat()
@@ -238,9 +254,12 @@ class StateSync:
"""Called at start of trading cycle"""
self._state["current_date"] = date
self._state["status"] = "running"
if self._state.get("server_mode") == "backtest":
self.set_simulation_date(
date,
) # Set for backtest-compatible timestamps
else:
self.clear_simulation_date()
await self.emit(
{

View File

@@ -7,6 +7,7 @@ from datetime import datetime
from typing import Callable, Dict, List, Optional
import pandas as pd
from backend.data.market_store import MarketStore
from backend.data.provider_utils import normalize_symbol
from backend.data.provider_router import get_provider_router
@@ -26,6 +27,7 @@ class HistoricalPriceManager:
self.close_prices = {}
self.running = False
self._router = get_provider_router()
self._market_store = MarketStore()
def subscribe(
self,
@@ -58,21 +60,48 @@ class HistoricalPriceManager:
logger.warning(f"Failed to load CSV for {symbol}: {e}")
return None
def _load_from_market_db(
self,
symbol: str,
start_date: str,
end_date: str,
) -> Optional[pd.DataFrame]:
"""Load price data from the long-lived market research database."""
try:
rows = self._market_store.get_ohlc(symbol, start_date, end_date)
if not rows:
return None
df = pd.DataFrame(rows)
if df.empty or "date" not in df.columns:
return None
df["Date"] = pd.to_datetime(df["date"])
df.set_index("Date", inplace=True)
df.sort_index(inplace=True)
return df
except Exception as e:
logger.warning(f"Failed to load market DB data for {symbol}: {e}")
return None
def preload_data(self, start_date: str, end_date: str):
"""Preload historical data from local CSV files."""
"""Preload historical data from market DB first, then local CSV."""
logger.info(f"Preloading data: {start_date} to {end_date}")
for symbol in self.subscribed_symbols:
if symbol in self._price_cache:
continue
df = self._load_from_market_db(symbol, start_date, end_date)
if df is not None and not df.empty:
self._price_cache[symbol] = df
logger.info(f"Loaded {symbol} from market DB: {len(df)} records")
continue
df = self._load_from_csv(symbol)
if df is not None and not df.empty:
self._price_cache[symbol] = df
logger.info(f"Loaded {symbol} from CSV: {len(df)} records")
else:
logger.warning(f"No CSV data for {symbol}")
logger.warning(f"No market DB or CSV data for {symbol}")
def set_date(self, date: str):
"""Set current trading date and update prices"""

View File

@@ -0,0 +1,149 @@
# -*- coding: utf-8 -*-
"""Ingest Polygon market data into the long-lived research warehouse."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Iterable
from backend.data.market_store import MarketStore
from backend.data.news_alignment import align_news_for_symbol
from backend.data.polygon_client import (
fetch_news,
fetch_ohlc,
fetch_ticker_details,
)
from backend.data.provider_utils import normalize_symbol
def _today_utc() -> str:
return datetime.now(timezone.utc).date().isoformat()
def _default_start(years: int = 2) -> str:
return (datetime.now(timezone.utc).date() - timedelta(days=years * 366)).isoformat()
def ingest_ticker_history(
symbol: str,
*,
start_date: str | None = None,
end_date: str | None = None,
store: MarketStore | None = None,
) -> dict:
"""Fetch and persist Polygon OHLC + news for a ticker."""
ticker = normalize_symbol(symbol)
start = start_date or _default_start()
end = end_date or _today_utc()
market_store = store or MarketStore()
details = fetch_ticker_details(ticker)
market_store.upsert_ticker(
symbol=ticker,
name=details.get("name"),
sector=details.get("sic_description"),
is_active=bool(details.get("active", True)),
)
ohlc_rows = fetch_ohlc(ticker, start, end)
news_rows = fetch_news(ticker, start, end)
price_count = market_store.upsert_ohlc(ticker, ohlc_rows, source="polygon")
news_count = market_store.upsert_news(ticker, news_rows, source="polygon")
aligned_count = align_news_for_symbol(market_store, ticker)
market_store.update_fetch_watermark(symbol=ticker, price_date=end, news_date=end)
return {
"symbol": ticker,
"start_date": start,
"end_date": end,
"prices": price_count,
"news": news_count,
"aligned": aligned_count,
}
def update_ticker_incremental(
symbol: str,
*,
end_date: str | None = None,
store: MarketStore | None = None,
) -> dict:
"""Incrementally fetch OHLC + news since the last watermark."""
ticker = normalize_symbol(symbol)
market_store = store or MarketStore()
watermarks = market_store.get_ticker_watermarks(ticker)
end = end_date or _today_utc()
start_prices = (
(datetime.fromisoformat(watermarks["last_price_fetch"]) + timedelta(days=1)).date().isoformat()
if watermarks.get("last_price_fetch")
else _default_start()
)
start_news = (
(datetime.fromisoformat(watermarks["last_news_fetch"]) + timedelta(days=1)).date().isoformat()
if watermarks.get("last_news_fetch")
else _default_start()
)
details = fetch_ticker_details(ticker)
market_store.upsert_ticker(
symbol=ticker,
name=details.get("name"),
sector=details.get("sic_description"),
is_active=bool(details.get("active", True)),
)
ohlc_rows = [] if start_prices > end else fetch_ohlc(ticker, start_prices, end)
news_rows = [] if start_news > end else fetch_news(ticker, start_news, end)
price_count = market_store.upsert_ohlc(ticker, ohlc_rows, source="polygon") if ohlc_rows else 0
news_count = market_store.upsert_news(ticker, news_rows, source="polygon") if news_rows else 0
aligned_count = align_news_for_symbol(market_store, ticker)
market_store.update_fetch_watermark(
symbol=ticker,
price_date=end if ohlc_rows or watermarks.get("last_price_fetch") else None,
news_date=end if news_rows or watermarks.get("last_news_fetch") else None,
)
return {
"symbol": ticker,
"start_price_date": start_prices,
"start_news_date": start_news,
"end_date": end,
"prices": price_count,
"news": news_count,
"aligned": aligned_count,
}
def ingest_symbols(
symbols: Iterable[str],
*,
mode: str = "incremental",
start_date: str | None = None,
end_date: str | None = None,
store: MarketStore | None = None,
) -> list[dict]:
"""Fetch Polygon data for a list of tickers."""
market_store = store or MarketStore()
results = []
for symbol in symbols:
ticker = normalize_symbol(symbol)
if not ticker:
continue
if mode == "full":
results.append(
ingest_ticker_history(
ticker,
start_date=start_date,
end_date=end_date,
store=market_store,
)
)
else:
results.append(
update_ticker_incremental(
ticker,
end_date=end_date,
store=market_store,
)
)
return results
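A hypothetical programmatic call into this module (requires POLYGON_API_KEY in the environment; symbols and window are illustrative):

from backend.data.market_ingest import ingest_symbols

results = ingest_symbols(
    ["AAPL", "MSFT"],
    mode="full",
    start_date="2024-01-01",
    end_date="2024-06-30",
)
for result in results:
    print(result["symbol"], result["prices"], result["news"], result["aligned"])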

backend/data/market_store.py (new file, +1074 lines)

File diff suppressed because it is too large.

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
"""Align persisted news to the nearest NYSE trading date."""
from __future__ import annotations
from datetime import time
import pandas as pd
import pandas_market_calendars as mcal
from backend.data.market_store import MarketStore
NYSE_CALENDAR = mcal.get_calendar("NYSE")
def _next_trading_day(date_str: str) -> str:
start = pd.Timestamp(date_str).tz_localize(None)
sessions = NYSE_CALENDAR.valid_days(
start_date=(start - pd.Timedelta(days=1)).strftime("%Y-%m-%d"),
end_date=(start + pd.Timedelta(days=10)).strftime("%Y-%m-%d"),
)
future = [
pd.Timestamp(day).tz_localize(None).strftime("%Y-%m-%d")
for day in sessions
if pd.Timestamp(day).tz_localize(None) >= start
]
return future[0] if future else date_str
def resolve_trade_date(published_utc: str | None) -> str | None:
"""Map a published timestamp to an NYSE trade date."""
if not published_utc:
return None
timestamp = pd.to_datetime(published_utc, utc=True, errors="coerce")
if pd.isna(timestamp):
return None
nyse_time = timestamp.tz_convert("America/New_York")
candidate = nyse_time.date().isoformat()
valid_days = NYSE_CALENDAR.valid_days(start_date=candidate, end_date=candidate)
if len(valid_days) == 0:
return _next_trading_day(candidate)
if nyse_time.time() >= time(16, 0):
return _next_trading_day((nyse_time + pd.Timedelta(days=1)).date().isoformat())
return candidate
def align_news_for_symbol(store: MarketStore, symbol: str, *, limit: int = 5000) -> int:
"""Fill missing trade_date values for one ticker."""
pending = store.get_news_without_trade_date(symbol, limit=limit)
updates = []
for row in pending:
trade_date = resolve_trade_date(row.get("published_utc"))
if trade_date:
updates.append(
{
"news_id": row["news_id"],
"symbol": row["symbol"],
"trade_date": trade_date,
}
)
if not updates:
return 0
return store.set_trade_dates(updates)
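Two illustrative mappings the resolver above produces, checked against the 2024 NYSE calendar (dates assumed for the example):

from backend.data.news_alignment import resolve_trade_date

# Mid-session on a trading day keeps the same NYSE date.
print(resolve_trade_date("2024-06-07T14:30:00Z"))  # 10:30 ET Friday -> "2024-06-07"
# After the 16:00 ET close on Friday, news rolls to the next session.
print(resolve_trade_date("2024-06-07T21:30:00Z"))  # 17:30 ET Friday -> "2024-06-10"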

View File

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""Polygon client used for long-lived market research ingestion."""
from __future__ import annotations
import os
import time
from datetime import datetime, timezone
from typing import Any, Optional
import requests
BASE = "https://api.polygon.io"
def _headers() -> dict[str, str]:
api_key = os.getenv("POLYGON_API_KEY", "").strip()
if not api_key:
raise ValueError("Missing required API key: POLYGON_API_KEY")
return {"Authorization": f"Bearer {api_key}"}
def http_get(
url: str,
params: Optional[dict[str, Any]] = None,
*,
max_retries: int = 8,
backoff: float = 2.0,
) -> requests.Response:
"""HTTP GET with exponential backoff and 429 handling."""
for attempt in range(max_retries):
try:
response = requests.get(
url,
params=params or {},
headers=_headers(),
timeout=30,
)
except requests.RequestException:
if attempt == max_retries - 1:
raise
time.sleep((backoff**attempt) + 0.5)
continue
if response.status_code == 429:
if attempt == max_retries - 1:
response.raise_for_status()
retry_after = response.headers.get("Retry-After")
wait = (
float(retry_after)
if retry_after and retry_after.isdigit()
else min((backoff**attempt) + 1.0, 60.0)
)
time.sleep(wait)
continue
if 500 <= response.status_code < 600:
if attempt == max_retries - 1:
response.raise_for_status()
time.sleep(min((backoff**attempt) + 1.0, 60.0))
continue
response.raise_for_status()
return response
raise RuntimeError("Unreachable")
def fetch_ticker_details(symbol: str) -> dict[str, Any]:
"""Fetch company metadata from Polygon."""
response = http_get(f"{BASE}/v3/reference/tickers/{symbol}")
return response.json().get("results", {}) or {}
def fetch_ohlc(symbol: str, start_date: str, end_date: str) -> list[dict[str, Any]]:
"""Fetch daily OHLC data from Polygon."""
response = http_get(
f"{BASE}/v2/aggs/ticker/{symbol}/range/1/day/{start_date}/{end_date}",
params={"adjusted": "true", "sort": "asc", "limit": 50000},
)
results = response.json().get("results") or []
rows: list[dict[str, Any]] = []
for item in results:
rows.append(
{
"date": datetime.fromtimestamp(
int(item["t"]) / 1000,
tz=timezone.utc,
).date().isoformat(),
"open": item.get("o"),
"high": item.get("h"),
"low": item.get("l"),
"close": item.get("c"),
"volume": item.get("v"),
"vwap": item.get("vw"),
"transactions": item.get("n"),
}
)
return rows
def fetch_news(
symbol: str,
start_date: str,
end_date: str,
*,
per_page: int = 50,
page_sleep: float = 1.2,
max_pages: Optional[int] = None,
) -> list[dict[str, Any]]:
"""Fetch all Polygon news for a ticker, with pagination."""
url = f"{BASE}/v2/reference/news"
params = {
"ticker": symbol,
"published_utc.gte": start_date,
"published_utc.lte": end_date,
"limit": per_page,
"order": "asc",
}
next_url: Optional[str] = None
pages = 0
all_articles: list[dict[str, Any]] = []
seen_ids: set[str] = set()
while True:
response = http_get(next_url or url, params=None if next_url else params)
data = response.json()
results = data.get("results") or []
if not results:
break
for item in results:
article_id = item.get("id")
if article_id and article_id in seen_ids:
continue
all_articles.append(
{
"id": article_id,
"publisher": (item.get("publisher") or {}).get("name"),
"title": item.get("title"),
"author": item.get("author"),
"published_utc": item.get("published_utc"),
"amp_url": item.get("amp_url"),
"article_url": item.get("article_url"),
"tickers": item.get("tickers"),
"description": item.get("description"),
"insights": item.get("insights"),
}
)
if article_id:
seen_ids.add(article_id)
next_url = data.get("next_url")
pages += 1
if max_pages is not None and pages >= max_pages:
break
if not next_url:
break
time.sleep(page_sleep)
return all_articles
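A hypothetical smoke test of the client (needs a live POLYGON_API_KEY; ticker and window are illustrative):

from backend.data.polygon_client import fetch_news, fetch_ohlc, fetch_ticker_details

details = fetch_ticker_details("AAPL")
bars = fetch_ohlc("AAPL", "2024-01-02", "2024-01-31")      # ~21 daily rows
articles = fetch_news("AAPL", "2024-01-02", "2024-01-31", max_pages=2)
print(details.get("name"), len(bars), len(articles))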

View File

@@ -0,0 +1,2 @@
"""News enrichment utilities for explain-oriented market research."""

View File

@@ -0,0 +1,296 @@
# -*- coding: utf-8 -*-
"""Optional AgentScope-backed news enrichment with safe local fallback."""
from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from pydantic import BaseModel, Field
from backend.config.env_config import canonicalize_model_provider, get_env_bool, get_env_str
from backend.llm.models import create_model
class EnrichedNewsItem(BaseModel):
"""Structured output schema for one enriched article."""
id: str = Field(description="The source article id")
relevance: str = Field(description="One of high, medium, low")
sentiment: str = Field(description="One of positive, negative, neutral")
key_discussion: str = Field(description="Concise core discussion")
summary: str = Field(description="Concise factual summary")
reason_growth: str = Field(description="Growth-oriented reason if present")
reason_decrease: str = Field(description="Downside-oriented reason if present")
class EnrichedNewsBatch(BaseModel):
"""Structured output schema for a batch of enriched articles."""
items: list[EnrichedNewsItem]
class RangeAnalysisPayload(BaseModel):
"""Structured output schema for range explanation text."""
summary: str = Field(description="Concise Chinese range summary for the selected window")
trend_analysis: str = Field(description="Concise Chinese trend explanation for the selected window")
bullish_factors: list[str] = Field(description="Top bullish factors in Chinese")
bearish_factors: list[str] = Field(description="Top bearish factors in Chinese")
def get_explain_model_info() -> dict[str, str]:
"""Resolve provider/model used by explain enrichment."""
provider = canonicalize_model_provider(
get_env_str("EXPLAIN_ENRICH_MODEL_PROVIDER")
or get_env_str("MODEL_PROVIDER", "OPENAI"),
)
model_name = get_env_str("EXPLAIN_ENRICH_MODEL_NAME") or get_env_str(
"MODEL_NAME",
"gpt-4o-mini",
)
return {
"provider": provider,
"model_name": model_name,
"label": f"{provider}:{model_name}",
}
def _normalize_enrichment_payload(payload: Any) -> dict[str, Any] | None:
if isinstance(payload, BaseModel):
payload = payload.model_dump()
if not isinstance(payload, dict):
return None
return {
"relevance": str(payload.get("relevance") or "").strip().lower() or None,
"sentiment": str(payload.get("sentiment") or "").strip().lower() or None,
"key_discussion": str(payload.get("key_discussion") or "").strip() or None,
"summary": str(payload.get("summary") or "").strip() or None,
"reason_growth": str(payload.get("reason_growth") or "").strip() or None,
"reason_decrease": str(payload.get("reason_decrease") or "").strip() or None,
"raw_json": payload,
}
def _run_async(coro: Any) -> Any:
"""Run an async AgentScope model call from sync code, even inside a running loop."""
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(asyncio.run, coro)
return future.result()
def _get_explain_model():
"""Create an AgentScope model for explain enrichment."""
model_info = get_explain_model_info()
return create_model(
model_name=model_info["model_name"],
provider=model_info["provider"],
stream=False,
generate_kwargs={"temperature": 0.1},
)
def llm_enrichment_enabled() -> bool:
"""Return whether AgentScope-backed LLM enrichment should be attempted."""
if not get_env_bool("EXPLAIN_ENRICH_USE_LLM", False):
return False
provider = get_explain_model_info()["provider"]
provider_key_map = {
"OPENAI": "OPENAI_API_KEY",
"ANTHROPIC": "ANTHROPIC_API_KEY",
"DASHSCOPE": "DASHSCOPE_API_KEY",
"ALIBABA": "DASHSCOPE_API_KEY",
"GEMINI": "GOOGLE_API_KEY",
"GOOGLE": "GOOGLE_API_KEY",
"DEEPSEEK": "DEEPSEEK_API_KEY",
"GROQ": "GROQ_API_KEY",
"OPENROUTER": "OPENROUTER_API_KEY",
}
env_key = provider_key_map.get(provider)
return bool(get_env_str(env_key)) if env_key else provider == "OLLAMA"
def llm_range_analysis_enabled() -> bool:
"""Return whether LLM range analysis should be attempted."""
raw_value = get_env_str("EXPLAIN_RANGE_USE_LLM")
if raw_value is not None and str(raw_value).strip() != "":
return get_env_bool("EXPLAIN_RANGE_USE_LLM", False) and llm_enrichment_enabled()
return llm_enrichment_enabled()
def analyze_news_row_with_llm(row: dict[str, Any]) -> dict[str, Any] | None:
"""Generate explain-oriented structured analysis for one article."""
if not llm_enrichment_enabled():
return None
model = _get_explain_model()
title = str(row.get("title") or "").strip()
summary = str(row.get("summary") or "").strip()
messages = [
{
"role": "system",
"content": (
"You produce concise structured financial news analysis. "
"Use only the requested fields and keep content factual."
),
},
{
"role": "user",
"content": (
"Analyze this stock-news article for an explain UI.\n"
"Rules:\n"
"- relevance must be one of: high, medium, low\n"
"- sentiment must be one of: positive, negative, neutral\n"
"- keep each text field concise and factual\n"
f"- article id: {str(row.get('id') or '').strip()}\n"
f"Title: {title}\n"
f"Summary: {summary}\n"
),
},
]
try:
response = _run_async(model(messages=messages, structured_model=EnrichedNewsItem))
except Exception:
return None
payload = _normalize_enrichment_payload(getattr(response, "metadata", None))
if payload:
model_info = get_explain_model_info()
payload.setdefault("raw_json", {})
payload["raw_json"]["model_provider"] = model_info["provider"]
payload["raw_json"]["model_name"] = model_info["model_name"]
payload["raw_json"]["model_label"] = model_info["label"]
return payload
def analyze_news_rows_with_llm(rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
"""Generate structured analysis for multiple articles in one request."""
if not llm_enrichment_enabled() or not rows:
return {}
payload_rows = [
{
"id": str(row.get("id") or "").strip(),
"title": str(row.get("title") or "").strip(),
"summary": str(row.get("summary") or "").strip(),
}
for row in rows
if str(row.get("id") or "").strip()
]
if not payload_rows:
return {}
model = _get_explain_model()
messages = [
{
"role": "system",
"content": (
"You produce concise structured financial news analysis in JSON. "
"Preserve ids exactly and do not invent extra items."
),
},
{
"role": "user",
"content": (
"Analyze these stock-news articles for an explain UI.\n"
"For each item return: id, relevance, sentiment, key_discussion, summary, "
"reason_growth, reason_decrease.\n"
"Rules:\n"
"- relevance must be one of: high, medium, low\n"
"- sentiment must be one of: positive, negative, neutral\n"
"- keep all text concise and factual\n"
f"Articles: {payload_rows}"
),
},
]
try:
response = _run_async(
model(messages=messages, structured_model=EnrichedNewsBatch),
)
except Exception:
return {}
metadata = getattr(response, "metadata", None)
if isinstance(metadata, BaseModel):
metadata = metadata.model_dump()
items = metadata.get("items") if isinstance(metadata, dict) else None
if not isinstance(items, list):
return {}
results: dict[str, dict[str, Any]] = {}
model_info = get_explain_model_info()
for item in items:
normalized = _normalize_enrichment_payload(item)
item_data = item.model_dump() if isinstance(item, BaseModel) else item
news_id = str(item_data.get("id") or "").strip() if isinstance(item_data, dict) else ""
if normalized and news_id:
normalized.setdefault("raw_json", {})
normalized["raw_json"]["model_provider"] = model_info["provider"]
normalized["raw_json"]["model_name"] = model_info["model_name"]
normalized["raw_json"]["model_label"] = model_info["label"]
results[news_id] = normalized
return results
def analyze_range_with_llm(payload: dict[str, Any]) -> dict[str, Any] | None:
"""Generate explain-oriented range summary and factor refinement."""
if not llm_range_analysis_enabled():
return None
model = _get_explain_model()
messages = [
{
"role": "system",
"content": (
"You write concise Chinese stock range analysis for an explain UI. "
"Use only the supplied facts. Keep the tone factual and analyst-like."
),
},
{
"role": "user",
"content": (
"请基于给定事实生成区间分析。\n"
"输出字段summary, trend_analysis, bullish_factors, bearish_factors。\n"
"要求:\n"
"- 全部使用简体中文\n"
"- summary 1到2句概括区间走势、新闻密度和主导主题\n"
"- trend_analysis 1句解释区间内部阶段变化\n"
"- bullish_factors 和 bearish_factors 各返回最多3条短句\n"
"- 不要编造未提供的信息\n"
f"事实数据: {payload}"
),
},
]
try:
response = _run_async(
model(messages=messages, structured_model=RangeAnalysisPayload),
)
except Exception:
return None
metadata = getattr(response, "metadata", None)
if isinstance(metadata, BaseModel):
metadata = metadata.model_dump()
if not isinstance(metadata, dict):
return None
return {
"summary": str(metadata.get("summary") or "").strip() or None,
"trend_analysis": str(metadata.get("trend_analysis") or "").strip() or None,
"bullish_factors": [
str(item).strip()
for item in list(metadata.get("bullish_factors") or [])
if str(item).strip()
][:3],
"bearish_factors": [
str(item).strip()
for item in list(metadata.get("bearish_factors") or [])
if str(item).strip()
][:3],
"model_provider": get_explain_model_info()["provider"],
"model_name": get_explain_model_info()["model_name"],
"model_label": get_explain_model_info()["label"],
}
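A sketch of the environment wiring these helpers read, with assumed values; EXPLAIN_ENRICH_USE_LLM plus the provider's API key gate the LLM path, and the EXPLAIN_* overrides fall back to MODEL_PROVIDER/MODEL_NAME:

# .env (illustrative values)
EXPLAIN_ENRICH_USE_LLM=true
EXPLAIN_ENRICH_MODEL_PROVIDER=OPENAI
EXPLAIN_ENRICH_MODEL_NAME=gpt-4o-mini
OPENAI_API_KEY=sk-...          # required for the OPENAI provider path
EXPLAIN_RANGE_USE_LLM=true     # optional; defaults to the enrichment toggle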

View File

@@ -0,0 +1,362 @@
# -*- coding: utf-8 -*-
"""Lightweight news enrichment for explain-oriented market analysis."""
from __future__ import annotations
import hashlib
from typing import Any
from backend.config.env_config import get_env_int
from backend.enrich.llm_enricher import (
analyze_news_row_with_llm,
analyze_news_rows_with_llm,
llm_enrichment_enabled,
)
from backend.data.market_store import MarketStore
POSITIVE_KEYWORDS = (
"beat", "surge", "gain", "growth", "record", "upgrade", "strong",
"partnership", "approved", "launch", "expands", "profit",
)
NEGATIVE_KEYWORDS = (
"miss", "drop", "fall", "cut", "downgrade", "weak", "warning",
"delay", "lawsuit", "probe", "tariff", "decline", "layoff",
)
HIGH_RELEVANCE_KEYWORDS = (
"earnings", "guidance", "profit", "revenue", "ceo", "fda", "tariff",
"regulation", "acquisition", "buyback", "forecast", "launch",
)
def _dedupe_key(row: dict[str, Any]) -> str:
trade_date = str(row.get("trade_date") or row.get("date") or "")[:10]
title = str(row.get("title") or "").strip().lower()
summary = str(row.get("summary") or "").strip().lower()[:160]
raw = f"{trade_date}::{title}::{summary}"
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
def _chunk_rows(rows: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]:
chunk_size = max(1, int(size))
return [rows[index:index + chunk_size] for index in range(0, len(rows), chunk_size)]
def classify_news_row(row: dict[str, Any]) -> dict[str, Any]:
"""Return a lightweight explain-oriented analysis for one article."""
llm_result = analyze_news_row_with_llm(row)
if isinstance(llm_result, dict):
merged = dict(llm_result)
merged.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
merged.setdefault("raw_json", row)
merged["analysis_source"] = "llm"
return merged
title = str(row.get("title") or "").strip()
summary = str(row.get("summary") or "").strip()
text = f"{title} {summary}".lower()
positive_hits = [keyword for keyword in POSITIVE_KEYWORDS if keyword in text]
negative_hits = [keyword for keyword in NEGATIVE_KEYWORDS if keyword in text]
relevance_hits = [keyword for keyword in HIGH_RELEVANCE_KEYWORDS if keyword in text]
if len(positive_hits) > len(negative_hits):
sentiment = "positive"
elif len(negative_hits) > len(positive_hits):
sentiment = "negative"
else:
sentiment = "neutral"
relevance = "high" if relevance_hits else "medium" if title else "low"
summary_text = summary or title
key_discussion = ""
if relevance_hits:
key_discussion = f"核心主题集中在 {', '.join(relevance_hits[:3])}"
elif summary_text:
key_discussion = summary_text[:160]
reason_growth = ""
reason_decrease = ""
if sentiment == "positive":
reason_growth = summary_text[:200]
elif sentiment == "negative":
reason_decrease = summary_text[:200]
return {
"relevance": relevance,
"sentiment": sentiment,
"key_discussion": key_discussion,
"summary": summary_text[:280],
"reason_growth": reason_growth,
"reason_decrease": reason_decrease,
"analysis_source": "local",
"raw_json": row,
}
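With LLM enrichment disabled, the keyword fallback above behaves like this (illustrative article, not from the dataset):

row = {
    "id": "n1",
    "title": "Acme beats earnings expectations, raises guidance",
    "summary": "",
}
analysis = classify_news_row(row)
print(analysis["sentiment"], analysis["relevance"], analysis["analysis_source"])
# -> positive high local  ("beat" is a positive hit; "earnings"/"guidance" mark high relevance)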
def attach_forward_returns(
*,
news_rows: list[dict[str, Any]],
ohlc_rows: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Attach forward-return labels to each analyzed row."""
if not ohlc_rows:
return news_rows
closes_by_date = {
str(row.get("date")): float(row.get("close"))
for row in ohlc_rows
if row.get("date") is not None and row.get("close") is not None
}
ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date") is not None]
date_index = {date: idx for idx, date in enumerate(ordered_dates)}
horizons = {
"ret_t0": 0,
"ret_t1": 1,
"ret_t3": 3,
"ret_t5": 5,
"ret_t10": 10,
}
enriched: list[dict[str, Any]] = []
for row in news_rows:
trade_date = str(row.get("trade_date") or "")[:10]
base_close = closes_by_date.get(trade_date)
if not trade_date or base_close in (None, 0):
enriched.append(row)
continue
next_row = dict(row)
base_index = date_index.get(trade_date)
if base_index is None:
enriched.append(next_row)
continue
for field, offset in horizons.items():
target_index = base_index + offset
if target_index >= len(ordered_dates):
next_row[field] = None
continue
target_close = closes_by_date.get(ordered_dates[target_index])
next_row[field] = (
(float(target_close) - float(base_close)) / float(base_close)
if target_close not in (None, 0)
else None
)
enriched.append(next_row)
return enriched
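A worked example of the forward-return labels on a toy two-day series (values assumed):

ohlc = [
    {"date": "2024-01-02", "close": 100.0},
    {"date": "2024-01-03", "close": 103.0},
]
news = [{"news_id": "n1", "trade_date": "2024-01-02"}]
labeled = attach_forward_returns(news_rows=news, ohlc_rows=ohlc)[0]
print(labeled["ret_t0"], labeled["ret_t1"], labeled["ret_t3"])
# -> 0.0 0.03 None  (t3/t5/t10 run past the end of the series and stay None)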
def build_analysis_rows(
*,
symbol: str,
news_rows: list[dict[str, Any]],
ohlc_rows: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], dict[str, int]]:
"""Transform raw news rows into market_store news_analysis payloads plus stats."""
llm_results: dict[str, dict[str, Any]] = {}
if llm_enrichment_enabled():
batch_size = get_env_int("EXPLAIN_ENRICH_BATCH_SIZE", 8)
for chunk in _chunk_rows(news_rows, batch_size):
llm_results.update(analyze_news_rows_with_llm(chunk))
staged_rows: list[dict[str, Any]] = []
seen_dedupe_keys: set[str] = set()
deduped_count = 0
llm_count = 0
local_count = 0
for row in news_rows:
news_id = str(row.get("id") or "").strip()
if not news_id:
continue
dedupe_key = _dedupe_key(row)
if dedupe_key in seen_dedupe_keys:
deduped_count += 1
continue
seen_dedupe_keys.add(dedupe_key)
batch_result = llm_results.get(news_id)
if isinstance(batch_result, dict):
analysis = dict(batch_result)
analysis.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
analysis.setdefault("raw_json", row)
analysis["analysis_source"] = "llm"
llm_count += 1
else:
analysis = classify_news_row(row)
if analysis.get("analysis_source") == "llm":
llm_count += 1
else:
local_count += 1
staged_rows.append(
{
"news_id": news_id,
"trade_date": str(row.get("trade_date") or "")[:10] or None,
**analysis,
}
)
return (
attach_forward_returns(news_rows=staged_rows, ohlc_rows=ohlc_rows),
{
"deduped_count": deduped_count,
"llm_count": llm_count,
"local_count": local_count,
},
)
def enrich_news_for_symbol(
store: MarketStore,
symbol: str,
*,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 200,
analysis_source: str = "local",
skip_existing: bool = True,
only_reanalyze_local: bool = False,
) -> dict[str, Any]:
"""Read raw market news, compute explain fields, and persist them."""
normalized_symbol = str(symbol or "").strip().upper()
if not normalized_symbol:
return {"symbol": "", "analyzed": 0}
news_rows = store.get_news_items(
normalized_symbol,
start_date=start_date,
end_date=end_date,
limit=limit,
)
total_news_count = len(news_rows)
skipped_existing_count = 0
analyzed_sources: dict[str, str] = {}
skipped_missing_analysis_count = 0
skipped_non_local_count = 0
if news_rows and only_reanalyze_local:
analyzed_sources = store.get_analyzed_news_sources(
normalized_symbol,
start_date=start_date,
end_date=end_date,
)
skipped_missing_analysis_count = sum(
1
for row in news_rows
if str(row.get("id") or "").strip() not in analyzed_sources
)
skipped_non_local_count = sum(
1
for row in news_rows
if str(row.get("id") or "").strip() in analyzed_sources
and analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
)
skipped_existing_count = sum(
1
for row in news_rows
if str(row.get("id") or "").strip() not in analyzed_sources
or analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
)
news_rows = [
row for row in news_rows
if analyzed_sources.get(str(row.get("id") or "").strip()) == "local"
]
elif skip_existing and news_rows:
analyzed_ids = store.get_analyzed_news_ids(
normalized_symbol,
start_date=start_date,
end_date=end_date,
)
skipped_existing_count = sum(
1
for row in news_rows
if str(row.get("id") or "").strip() in analyzed_ids
)
news_rows = [
row for row in news_rows
if str(row.get("id") or "").strip() not in analyzed_ids
]
ohlc_start = start_date or (news_rows[-1]["trade_date"] if news_rows and news_rows[-1].get("trade_date") else None)
ohlc_end = end_date or (news_rows[0]["trade_date"] if news_rows and news_rows[0].get("trade_date") else None)
ohlc_rows = (
store.get_ohlc(normalized_symbol, ohlc_start, ohlc_end)
if ohlc_start and ohlc_end
else []
)
analysis_rows, stats = build_analysis_rows(
symbol=normalized_symbol,
news_rows=news_rows,
ohlc_rows=ohlc_rows,
)
analyzed = store.upsert_news_analysis(
normalized_symbol,
analysis_rows,
analysis_source=analysis_source,
)
upgraded_dates = sorted(
{
str(row.get("trade_date") or "")[:10]
for row in analysis_rows
if str(row.get("analysis_source") or "").strip().lower() == "llm"
and str(row.get("trade_date") or "").strip()
}
)
analysis_source_by_id = {
str(analyzed_row.get("news_id") or "").strip(): str(analyzed_row.get("analysis_source") or "").strip().lower()
for analyzed_row in analysis_rows
}
remaining_local_titles = [
str(row.get("title") or row.get("id") or "").strip()
for row in news_rows
if analysis_source_by_id.get(str(row.get("id") or "").strip()) == "local"
][:5]
return {
"symbol": normalized_symbol,
"analyzed": analyzed,
"news_count": total_news_count,
"queued_count": len(news_rows),
"skipped_existing_count": skipped_existing_count,
"deduped_count": stats["deduped_count"],
"llm_count": stats["llm_count"],
"local_count": stats["local_count"],
"only_reanalyze_local": only_reanalyze_local,
"upgraded_local_to_llm_count": (
stats["llm_count"]
if only_reanalyze_local
else 0
),
"execution_summary": {
"upgraded_dates": upgraded_dates[:5],
"remaining_local_titles": remaining_local_titles,
"skipped_missing_analysis_count": skipped_missing_analysis_count,
"skipped_non_local_count": skipped_non_local_count,
},
}
def enrich_symbols(
store: MarketStore,
symbols: list[str],
*,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 200,
analysis_source: str = "local",
skip_existing: bool = True,
only_reanalyze_local: bool = False,
) -> list[dict[str, Any]]:
"""Batch enrich multiple symbols for explain-oriented news analysis."""
results = []
for symbol in symbols:
normalized_symbol = str(symbol or "").strip().upper()
if not normalized_symbol:
continue
results.append(
enrich_news_for_symbol(
store,
normalized_symbol,
start_date=start_date,
end_date=end_date,
limit=limit,
analysis_source=analysis_source,
skip_existing=skip_existing,
only_reanalyze_local=only_reanalyze_local,
)
)
return results
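A hypothetical end-to-end call, assuming the symbol was ingested first so raw news already sits in the store:

from backend.data.market_store import MarketStore
from backend.enrich.news_enricher import enrich_symbols

store = MarketStore()
results = enrich_symbols(
    store,
    ["AAPL"],
    start_date="2024-01-01",
    end_date="2024-03-31",
    limit=200,
)
for result in results:
    print(result["symbol"], result["analyzed"], result["llm_count"], result["local_count"])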

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""Explain-oriented services for stock narratives and news research."""

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""Rule-based news categorization for explain UI."""
from __future__ import annotations
from typing import Any, Dict, Iterable
CATEGORY_KEYWORDS = {
"market": [
"market", "stock", "rally", "sell-off", "selloff", "trading",
"wall street", "s&p", "nasdaq", "dow", "index", "bull", "bear",
"correction", "volatility",
],
"policy": [
"regulation", "fed", "federal reserve", "tariff", "sanction",
"interest rate", "policy", "government", "congress", "sec",
"trade war", "ban", "legislation", "tax",
],
"earnings": [
"earnings", "revenue", "profit", "quarter", "eps", "guidance",
"forecast", "income", "sales", "beat", "miss", "outlook",
"financial results",
],
"product_tech": [
"product", "ai", "chip", "cloud", "launch", "patent",
"technology", "innovation", "release", "platform", "model",
"software", "hardware", "gpu", "autonomous",
],
"competition": [
"competitor", "rival", "market share", "overtake", "compete",
"competition", "vs", "versus", "battle", "challenge",
],
"management": [
"ceo", "executive", "resign", "layoff", "restructure",
"management", "leadership", "appoint", "hire", "board",
"chairman",
],
}
def categorize_news_rows(rows: Iterable[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
"""Bucket news rows by keyword categories."""
categories: Dict[str, Dict[str, Any]] = {
key: {
"label": key,
"count": 0,
"article_ids": [],
}
for key in CATEGORY_KEYWORDS
}
for row in rows:
text = " ".join(
[
str(row.get("title") or ""),
str(row.get("summary") or ""),
str(row.get("related") or ""),
str(row.get("category") or ""),
]
).lower()
article_id = row.get("id")
for category, keywords in CATEGORY_KEYWORDS.items():
if any(keyword in text for keyword in keywords):
categories[category]["count"] += 1
if article_id:
categories[category]["article_ids"].append(article_id)
return categories
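An illustrative bucketing run; one article can land in several categories because every keyword list is checked independently:

rows = [{"id": "n1", "title": "Fed signals interest rate cut ahead of earnings season", "summary": ""}]
buckets = categorize_news_rows(rows)
print(buckets["policy"]["count"], buckets["earnings"]["count"], buckets["market"]["count"])
# -> 1 1 0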

View File

@@ -0,0 +1,214 @@
# -*- coding: utf-8 -*-
"""Local range explanation built from price and persisted news."""
from __future__ import annotations
from typing import Any, Dict
from backend.enrich.llm_enricher import analyze_range_with_llm
from backend.explain.category_engine import categorize_news_rows
from backend.tools.data_tools import get_prices
def _rank_event_score(row: Dict[str, Any]) -> float:
relevance = str(row.get("relevance") or "").strip().lower()
relevance_score = {"high": 3.0, "relevant": 3.0, "medium": 2.0, "low": 1.0}.get(
relevance,
0.5,
)
impact_score = abs(float(row.get("ret_t0") or 0.0)) * 100
return relevance_score + impact_score
def summarize_bullish_factors(
news_rows: list[Dict[str, Any]],
*,
limit: int = 5,
) -> list[str]:
factors = []
for row in news_rows:
if str(row.get("sentiment") or "").strip().lower() != "positive":
continue
candidate = row.get("reason_growth") or row.get("key_discussion") or row.get("summary") or row.get("title")
if candidate:
factors.append(str(candidate).strip())
seen = set()
output = []
for factor in factors:
if factor in seen:
continue
seen.add(factor)
output.append(factor[:200])
if len(output) >= limit:
break
return output
def summarize_bearish_factors(
news_rows: list[Dict[str, Any]],
*,
limit: int = 5,
) -> list[str]:
factors = []
for row in news_rows:
if str(row.get("sentiment") or "").strip().lower() != "negative":
continue
candidate = row.get("reason_decrease") or row.get("key_discussion") or row.get("summary") or row.get("title")
if candidate:
factors.append(str(candidate).strip())
seen = set()
output = []
for factor in factors:
if factor in seen:
continue
seen.add(factor)
output.append(factor[:200])
if len(output) >= limit:
break
return output
def build_trend_analysis(prices: list[Any]) -> str:
if len(prices) < 2:
return "区间样本较短,暂不具备足够趋势信息。"
if len(prices) < 3:
open_price = float(prices[0].open)
close_price = float(prices[-1].close)
change = ((close_price - open_price) / open_price) * 100 if open_price else 0.0
return f"短区间内价格变动 {change:+.2f}%,趋势信息有限。"
mid = len(prices) // 2
first_open = float(prices[0].open)
first_close = float(prices[mid].close)
second_open = float(prices[mid].open)
second_close = float(prices[-1].close)
first_half = ((first_close - first_open) / first_open) * 100 if first_open else 0.0
second_half = ((second_close - second_open) / second_open) * 100 if second_open else 0.0
return (
f"前半段{'上涨' if first_half >= 0 else '下跌'} {abs(first_half):.2f}%,"
f"后半段{'上涨' if second_half >= 0 else '下跌'} {abs(second_half):.2f}%,"
"说明价格驱动在区间内部出现了阶段性切换。"
)
def build_range_explanation(
*,
ticker: str,
start_date: str,
end_date: str,
news_rows: list[Dict[str, Any]],
) -> Dict[str, Any]:
"""Explain a price range with local price and news heuristics."""
prices = get_prices(ticker, start_date, end_date)
if not prices:
return {
"symbol": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "No OHLC data for this range",
}
open_price = float(prices[0].open)
close_price = float(prices[-1].close)
high_price = max(float(price.high) for price in prices)
low_price = min(float(price.low) for price in prices)
total_volume = sum(int(price.volume) for price in prices)
price_change_pct = (
((close_price - open_price) / open_price) * 100 if open_price else 0.0
)
categories = categorize_news_rows(news_rows)
news_count = len(news_rows)
dominant_categories = sorted(
(
{"category": key, "count": value["count"]}
for key, value in categories.items()
if value["count"] > 0
),
key=lambda item: item["count"],
reverse=True,
)
direction = "上涨" if price_change_pct > 0 else "下跌" if price_change_pct < 0 else "横盘"
category_text = (
f"主要主题集中在 {', '.join(item['category'] for item in dominant_categories[:3])}"
if dominant_categories
else "区间内未识别出明显的主题聚类。"
)
summary = (
f"{ticker}在{start_date}至{end_date} 区间内{direction} {abs(price_change_pct):.2f}%,"
f"区间覆盖 {len(prices)} 个交易日,关联新闻 {news_count} 条。{category_text}"
)
bullish_factors = summarize_bullish_factors(news_rows)
bearish_factors = summarize_bearish_factors(news_rows)
trend_analysis = build_trend_analysis(prices)
llm_source = "local"
range_payload = {
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"price_change_pct": round(price_change_pct, 2),
"trading_days": len(prices),
"news_count": news_count,
"dominant_categories": dominant_categories[:5],
"bullish_factors": bullish_factors[:3],
"bearish_factors": bearish_factors[:3],
"trend_analysis": trend_analysis,
"top_news": [
{
"date": row.get("trade_date") or str(row.get("date") or "")[:10],
"title": row.get("title") or "",
"summary": row.get("summary") or "",
"sentiment": row.get("sentiment") or "",
"relevance": row.get("relevance") or "",
"ret_t0": row.get("ret_t0"),
}
for row in sorted(news_rows, key=_rank_event_score, reverse=True)[:5]
],
}
llm_analysis = analyze_range_with_llm(range_payload)
if isinstance(llm_analysis, dict):
summary = llm_analysis.get("summary") or summary
trend_analysis = llm_analysis.get("trend_analysis") or trend_analysis
bullish_factors = llm_analysis.get("bullish_factors") or bullish_factors
bearish_factors = llm_analysis.get("bearish_factors") or bearish_factors
llm_source = "llm"
key_events = [
{
"date": row.get("trade_date") or str(row.get("date") or "")[:10],
"title": row.get("title") or "Untitled news",
"summary": row.get("summary") or "",
"category": row.get("category") or "",
"id": row.get("id"),
"sentiment": row.get("sentiment"),
"ret_t0": row.get("ret_t0"),
}
for row in sorted(news_rows, key=_rank_event_score, reverse=True)[:8]
]
return {
"symbol": ticker,
"start_date": start_date,
"end_date": end_date,
"price_change_pct": round(price_change_pct, 2),
"open_price": open_price,
"close_price": close_price,
"high_price": high_price,
"low_price": low_price,
"total_volume": total_volume,
"trading_days": len(prices),
"news_count": news_count,
"dominant_categories": dominant_categories[:5],
"analysis": {
"summary": summary,
"key_events": key_events,
"bullish_factors": bullish_factors,
"bearish_factors": bearish_factors,
"trend_analysis": trend_analysis,
"analysis_source": llm_source,
"analysis_model_label": llm_analysis.get("model_label") if isinstance(llm_analysis, dict) else None,
},
}
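A hypothetical call from the explain service layer; enriched rows would normally come from the market store, so the empty list below is only a placeholder:

enriched_rows: list[dict] = []  # assumed: rows carrying sentiment/relevance/ret_t0 fields
explanation = build_range_explanation(
    ticker="AAPL",
    start_date="2024-05-01",
    end_date="2024-05-31",
    news_rows=enriched_rows,
)
print(explanation.get("error") or explanation["analysis"]["summary"])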

View File

@@ -0,0 +1,202 @@
# -*- coding: utf-8 -*-
"""Same-ticker historical similar day search for explain view."""
from __future__ import annotations
from math import sqrt
from typing import Any
from backend.data.market_store import MarketStore
def _safe_float(value: Any, default: float = 0.0) -> float:
try:
parsed = float(value)
except (TypeError, ValueError):
return default
return parsed
def build_daily_feature_rows(
*,
symbol: str,
ohlc_rows: list[dict[str, Any]],
news_rows: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Aggregate price/news context into daily feature rows."""
price_by_date = {str(row.get("date")): row for row in ohlc_rows if row.get("date")}
ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date")]
news_by_date: dict[str, list[dict[str, Any]]] = {}
for row in news_rows:
trade_date = str(row.get("trade_date") or "")[:10] or str(row.get("date") or "")[:10]
if not trade_date:
continue
news_by_date.setdefault(trade_date, []).append(row)
features: list[dict[str, Any]] = []
previous_close: float | None = None
for idx, date in enumerate(ordered_dates):
price_row = price_by_date[date]
close_price = _safe_float(price_row.get("close"))
open_price = _safe_float(price_row.get("open"), close_price)
day_news = news_by_date.get(date, [])
positive_count = sum(1 for item in day_news if str(item.get("sentiment") or "").lower() == "positive")
negative_count = sum(1 for item in day_news if str(item.get("sentiment") or "").lower() == "negative")
high_relevance_count = sum(
1 for item in day_news if str(item.get("relevance") or "").lower() in {"high", "relevant"}
)
ret_1d = (
((close_price - previous_close) / previous_close)
if previous_close not in (None, 0)
else 0.0
)
intraday_ret = ((close_price - open_price) / open_price) if open_price else 0.0
sentiment_score = (
(positive_count - negative_count) / max(len(day_news), 1)
if day_news
else 0.0
)
future_t1 = None
future_t3 = None
if idx + 1 < len(ordered_dates) and close_price:
next_close = _safe_float(price_by_date[ordered_dates[idx + 1]].get("close"))
future_t1 = ((next_close - close_price) / close_price) if next_close else None
if idx + 3 < len(ordered_dates) and close_price:
next_close = _safe_float(price_by_date[ordered_dates[idx + 3]].get("close"))
future_t3 = ((next_close - close_price) / close_price) if next_close else None
features.append(
{
"date": date,
"symbol": symbol,
"n_articles": len(day_news),
"positive_count": positive_count,
"negative_count": negative_count,
"high_relevance_count": high_relevance_count,
"sentiment_score": sentiment_score,
"ret_1d": ret_1d,
"intraday_ret": intraday_ret,
"close": close_price,
"ret_t1_after": future_t1,
"ret_t3_after": future_t3,
"news": [
{
"title": row.get("title") or "",
"sentiment": row.get("sentiment") or "neutral",
}
for row in day_news[:3]
],
}
)
previous_close = close_price
return features
def compute_similarity_scores(
target_vector: list[float],
candidate_vectors: list[tuple[str, list[float], dict[str, Any]]],
) -> list[dict[str, Any]]:
"""Return sorted similarity matches based on normalized Euclidean distance."""
if not candidate_vectors:
return []
dimensions = len(target_vector)
ranges = []
for dimension in range(dimensions):
values = [vector[1][dimension] for vector in candidate_vectors] + [target_vector[dimension]]
min_value = min(values)
max_value = max(values)
ranges.append(max(max_value - min_value, 1e-9))
scored = []
for date, vector, payload in candidate_vectors:
distance = sqrt(
sum(
((target_vector[i] - vector[i]) / ranges[i]) ** 2
for i in range(dimensions)
)
)
similarity = 1.0 / (1.0 + distance)
scored.append(
{
"date": date,
"score": round(similarity, 4),
**payload,
}
)
return sorted(scored, key=lambda item: item["score"], reverse=True)
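A toy check of the scorer above, with two candidates over two feature dimensions (values assumed):

target = [1.0, 0.0]
candidates = [
    ("2024-01-02", [1.0, 0.0], {"note": "identical day"}),
    ("2024-01-03", [0.0, 1.0], {"note": "opposite day"}),
]
ranked = compute_similarity_scores(target, candidates)
print(ranked[0]["date"], ranked[0]["score"])  # -> 2024-01-02 1.0    (zero distance)
print(ranked[1]["date"], ranked[1]["score"])  # -> 2024-01-03 0.4142 (1 / (1 + sqrt(2)))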
def find_similar_days(
store: MarketStore,
*,
symbol: str,
target_date: str,
top_k: int = 10,
) -> dict[str, Any]:
"""Find same-ticker historical days most similar to a target day."""
cached = store.get_similar_day_cache(symbol, target_date=target_date)
if cached and cached.get("payload"):
return cached["payload"]
ohlc_rows = store.get_ohlc(symbol, "1900-01-01", target_date)
news_rows = store.get_news_items_enriched(symbol, end_date=target_date, limit=500)
daily_rows = build_daily_feature_rows(symbol=symbol, ohlc_rows=ohlc_rows, news_rows=news_rows)
feature_map = {row["date"]: row for row in daily_rows}
target_row = feature_map.get(target_date)
if not target_row:
return {
"symbol": symbol,
"target_date": target_date,
"items": [],
"error": "No feature row for target date",
}
vector_keys = [
"sentiment_score",
"n_articles",
"positive_count",
"negative_count",
"high_relevance_count",
"ret_1d",
"intraday_ret",
]
target_vector = [_safe_float(target_row.get(key)) for key in vector_keys]
candidates = []
for row in daily_rows:
date = row["date"]
if date == target_date:
continue
payload = {
"n_articles": row["n_articles"],
"sentiment_score": round(row["sentiment_score"], 4),
"ret_1d": round(row["ret_1d"] * 100, 2),
"intraday_ret": round(row["intraday_ret"] * 100, 2),
"ret_t1_after": round(row["ret_t1_after"] * 100, 2) if row["ret_t1_after"] is not None else None,
"ret_t3_after": round(row["ret_t3_after"] * 100, 2) if row["ret_t3_after"] is not None else None,
"top_reasons": [item["title"] for item in row["news"][:2] if item.get("title")],
"news": row["news"],
}
candidates.append(
(
date,
[_safe_float(row.get(key)) for key in vector_keys],
payload,
)
)
items = compute_similarity_scores(target_vector, candidates)[: max(1, min(int(top_k), 20))]
result = {
"symbol": symbol,
"target_date": target_date,
"target_features": {
"sentiment_score": round(target_row["sentiment_score"], 4),
"n_articles": target_row["n_articles"],
"ret_1d": round(target_row["ret_1d"] * 100, 2),
"intraday_ret": round(target_row["intraday_ret"] * 100, 2),
"high_relevance_count": target_row["high_relevance_count"],
},
"items": items,
}
store.upsert_similar_day_cache(symbol, target_date=target_date, payload=result, source="local")
return result
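# Usage sketch (assumes a MarketStore already populated with AAPL prices and
# enriched news): rank the most similar historical days; repeat calls for the
# same target date are served from the similar-day cache.
def _demo_find_similar_days(store: MarketStore) -> None:
    result = find_similar_days(store, symbol="AAPL", target_date="2026-03-16", top_k=5)
    for item in result["items"]:
        print(item["date"], item["score"], item.get("ret_t1_after"))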

View File

@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""Stock story generation for explain view."""
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any
from backend.data.market_store import MarketStore
def build_stock_story(
*,
symbol: str,
as_of_date: str,
price_rows: list[dict[str, Any]],
news_rows: list[dict[str, Any]],
) -> str:
"""Build a compact markdown story from enriched news and recent price action."""
lines = [f"## {symbol} Story", f"As of `{as_of_date}`"]
if not price_rows:
lines.append("")
lines.append("No OHLC data available for story generation.")
return "\n".join(lines)
open_price = float(price_rows[0].get("open") or price_rows[0].get("close") or 0.0)
close_price = float(price_rows[-1].get("close") or 0.0)
price_change = ((close_price - open_price) / open_price) * 100 if open_price else 0.0
high_price = max(float(row.get("high") or row.get("close") or 0.0) for row in price_rows)
low_price = min(float(row.get("low") or row.get("close") or 0.0) for row in price_rows)
lines.append("")
lines.append(
f"The stock moved {'up' if price_change >= 0 else 'down'} "
f"{abs(price_change):.2f}% over the recent window, trading between "
f"${low_price:.2f} and ${high_price:.2f}."
)
positive = [row for row in news_rows if str(row.get("sentiment") or "").lower() == "positive"]
negative = [row for row in news_rows if str(row.get("sentiment") or "").lower() == "negative"]
lines.append("")
lines.append(
f"Recent coverage included {len(news_rows)} relevant articles "
f"({len(positive)} positive / {len(negative)} negative)."
)
if news_rows:
lines.append("")
lines.append("### Key Moments")
ranked_rows = sorted(
news_rows,
key=lambda row: (
0 if str(row.get("relevance") or "").lower() in {"high", "relevant"} else 1,
-abs(float(row.get("ret_t0") or 0.0)),
),
)
for row in ranked_rows[:5]:
trade_date = row.get("trade_date") or str(row.get("date") or "")[:10]
title = row.get("title") or "Untitled"
key_discussion = row.get("key_discussion") or row.get("summary") or ""
sentiment = str(row.get("sentiment") or "neutral").lower()
lines.append(
f"- `{trade_date}` [{sentiment}] {title}: {str(key_discussion).strip()[:220]}"
)
if positive:
lines.append("")
lines.append("### Bullish Threads")
for row in positive[:3]:
reason = row.get("reason_growth") or row.get("key_discussion") or row.get("summary") or row.get("title")
lines.append(f"- {str(reason).strip()[:220]}")
if negative:
lines.append("")
lines.append("### Bearish Threads")
for row in negative[:3]:
reason = row.get("reason_decrease") or row.get("key_discussion") or row.get("summary") or row.get("title")
lines.append(f"- {str(reason).strip()[:220]}")
return "\n".join(lines)
def get_or_create_stock_story(
store: MarketStore,
*,
symbol: str,
as_of_date: str,
) -> dict[str, Any]:
"""Return cached story or build a new one from recent market context."""
cached = store.get_story_cache(symbol, as_of_date=as_of_date)
if cached:
return {
"symbol": symbol,
"as_of_date": as_of_date,
"story": cached.get("content") or "",
"source": cached.get("source") or "cache",
}
    start_date = None
    if len(as_of_date) >= 10:
        try:
            target_date = datetime.strptime(as_of_date[:10], "%Y-%m-%d").date()
        except ValueError:
            target_date = None
        if target_date is not None:
            start_date = (target_date - timedelta(days=29)).isoformat()
price_rows = (
store.get_ohlc(symbol, start_date, as_of_date)
if start_date
else []
)
news_rows = store.get_news_items_enriched(
symbol,
start_date=start_date,
end_date=as_of_date,
limit=40,
)
story = build_stock_story(
symbol=symbol,
as_of_date=as_of_date,
price_rows=price_rows,
news_rows=news_rows,
)
store.upsert_story_cache(symbol, as_of_date=as_of_date, content=story, source="local")
return {
"symbol": symbol,
"as_of_date": as_of_date,
"story": story,
"source": "local",
}

View File

@@ -17,6 +17,12 @@ from backend.config.bootstrap_config import (
update_bootstrap_values_for_run,
)
from backend.data.provider_utils import normalize_symbol
from backend.data.market_ingest import ingest_symbols
from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.explain.range_explainer import build_range_explanation
from backend.explain.similarity_service import find_similar_days
from backend.explain.story_service import get_or_create_stock_story
from backend.utils.msg_adapter import FrontendAdapter
from backend.utils.terminal_dashboard import get_dashboard
from backend.core.pipeline import TradingPipeline
@@ -25,6 +31,7 @@ from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.data.provider_router import get_provider_router
from backend.tools.data_tools import get_prices
from backend.tools.data_tools import get_company_news
logger = logging.getLogger(__name__)
@@ -65,6 +72,7 @@ class Gateway:
self._backtest_end_date: Optional[str] = None
self._dashboard = get_dashboard()
self._market_status_task: Optional[asyncio.Task] = None
self._watchlist_ingest_task: Optional[asyncio.Task] = None
# Session tracking for live returns
self._session_start_portfolio_value: Optional[float] = None
@@ -182,6 +190,17 @@ class Gateway:
def state(self) -> Dict[str, Any]:
return self.state_sync.state
@staticmethod
    def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool:
        """Return True when no row carries any enrichment fields yet."""
        if not rows:
return True
return all(
not row.get("sentiment")
and not row.get("relevance")
and not row.get("key_discussion")
for row in rows
)
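    # Behavior sketch:
    #   _news_rows_need_enrichment([]) -> True                    (nothing cached yet)
    #   _news_rows_need_enrichment([{"title": "t"}]) -> True      (no enrichment fields)
    #   _news_rows_need_enrichment([{"sentiment": "positive"}]) -> False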
async def handle_client(self, websocket: ServerConnection):
"""Handle WebSocket client connection"""
async with self.lock:
@@ -250,6 +269,22 @@ class Gateway:
await self._handle_get_stock_history(websocket, data)
elif msg_type == "get_stock_explain_events":
await self._handle_get_stock_explain_events(websocket, data)
elif msg_type == "get_stock_news":
await self._handle_get_stock_news(websocket, data)
elif msg_type == "get_stock_news_for_date":
await self._handle_get_stock_news_for_date(websocket, data)
elif msg_type == "get_stock_news_timeline":
await self._handle_get_stock_news_timeline(websocket, data)
elif msg_type == "get_stock_news_categories":
await self._handle_get_stock_news_categories(websocket, data)
elif msg_type == "get_stock_range_explain":
await self._handle_get_stock_range_explain(websocket, data)
elif msg_type == "get_stock_story":
await self._handle_get_stock_story(websocket, data)
elif msg_type == "get_stock_similar_days":
await self._handle_get_stock_similar_days(websocket, data)
elif msg_type == "run_stock_enrich":
await self._handle_run_stock_enrich(websocket, data)
except websockets.ConnectionClosed:
pass
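    # Client-side protocol sketch (assumes the gateway listens on
    # ws://localhost:8765, matching the CLI defaults exercised in the tests):
    # every request is a JSON frame with a "type" field and is answered by a
    # corresponding "*_loaded" / "*_completed" frame, e.g.:
    #
    #   import asyncio, json, websockets
    #
    #   async def fetch_similar_days(ticker: str, date: str) -> dict:
    #       async with websockets.connect("ws://localhost:8765") as ws:
    #           await ws.send(json.dumps({
    #               "type": "get_stock_similar_days",
    #               "ticker": ticker,
    #               "date": date,
    #               "top_k": 8,
    #           }))
    #           return json.loads(await ws.recv())  # "stock_similar_days_loaded"
    #
    #   asyncio.run(fetch_similar_days("AAPL", "2026-03-16"))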
@@ -298,20 +333,38 @@ class Gateway:
)
prices = await asyncio.to_thread(
get_prices,
self.storage.market_store.get_ohlc,
ticker,
start_date,
end_date,
)
usage_snapshot = self._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
source = "polygon"
if not prices:
prices = await asyncio.to_thread(
get_prices,
ticker,
start_date,
end_date,
)
usage_snapshot = self._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
if prices:
await asyncio.to_thread(
self.storage.market_store.upsert_ohlc,
ticker,
[price.model_dump() for price in prices],
source=source or "provider",
)
await websocket.send(
json.dumps(
{
"type": "stock_history_loaded",
"ticker": ticker,
"prices": [price.model_dump() for price in prices][-120:],
"prices": [
price if isinstance(price, dict) else price.model_dump()
for price in prices
][-120:],
"source": source,
"start_date": start_date,
"end_date": end_date,
@@ -342,6 +395,636 @@ class Gateway:
),
)
async def _handle_get_stock_news(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(
json.dumps(
{
"type": "stock_news_loaded",
"ticker": "",
"news": [],
"source": None,
"error": "invalid ticker",
},
ensure_ascii=False,
),
)
return
lookback_days = data.get("lookback_days", 30)
limit = data.get("limit", 12)
try:
lookback_days = max(7, min(int(lookback_days), 180))
except (TypeError, ValueError):
lookback_days = 30
try:
limit = max(1, min(int(limit), 30))
except (TypeError, ValueError):
limit = 12
end_date = self.state_sync.state.get("current_date")
if not end_date:
end_date = datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime(
"%Y-%m-%d",
)
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_items_enriched,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
source = "polygon"
if self._news_rows_need_enrichment(news_rows):
news = await asyncio.to_thread(
get_company_news,
ticker,
end_date,
start_date,
limit,
)
if news:
usage_snapshot = self._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("company_news")
await asyncio.to_thread(
self.storage.market_store.upsert_news,
ticker,
[item.model_dump() for item in news],
source=source or "provider",
)
await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=max(limit, 50),
)
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_items_enriched,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
source = source or "market_store"
await websocket.send(
json.dumps(
{
"type": "stock_news_loaded",
"ticker": ticker,
"news": news_rows[-limit:],
"source": source,
"start_date": start_date,
"end_date": end_date,
},
ensure_ascii=False,
default=str,
),
)
async def _handle_get_stock_news_for_date(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
trade_date = str(data.get("date") or "").strip()
if not ticker or not trade_date:
await websocket.send(
json.dumps(
{
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": [],
"error": "ticker and date are required",
},
ensure_ascii=False,
),
)
return
limit = data.get("limit", 20)
try:
limit = max(1, min(int(limit), 50))
except (TypeError, ValueError):
limit = 20
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_items_enriched,
ticker,
trade_date=trade_date,
limit=limit,
)
if self._news_rows_need_enrichment(news_rows):
await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
start_date=trade_date,
end_date=trade_date,
limit=limit,
)
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_items_enriched,
ticker,
trade_date=trade_date,
limit=limit,
)
await websocket.send(
json.dumps(
{
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": news_rows,
"source": "market_store",
},
ensure_ascii=False,
default=str,
),
)
async def _handle_get_stock_news_timeline(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(
json.dumps(
{
"type": "stock_news_timeline_loaded",
"ticker": "",
"timeline": [],
"error": "invalid ticker",
},
ensure_ascii=False,
),
)
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = self.state_sync.state.get("current_date")
if not end_date:
end_date = datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime(
"%Y-%m-%d",
)
timeline = await asyncio.to_thread(
self.storage.market_store.get_news_timeline_enriched,
ticker,
start_date=start_date,
end_date=end_date,
)
if not timeline:
await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=200,
)
timeline = await asyncio.to_thread(
self.storage.market_store.get_news_timeline_enriched,
ticker,
start_date=start_date,
end_date=end_date,
)
await websocket.send(
json.dumps(
{
"type": "stock_news_timeline_loaded",
"ticker": ticker,
"timeline": timeline,
"start_date": start_date,
"end_date": end_date,
},
ensure_ascii=False,
default=str,
),
)
async def _handle_get_stock_news_categories(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(
json.dumps(
{
"type": "stock_news_categories_loaded",
"ticker": "",
"categories": {},
"error": "invalid ticker",
},
ensure_ascii=False,
),
)
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = self.state_sync.state.get("current_date")
if not end_date:
end_date = datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime(
"%Y-%m-%d",
)
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_items_enriched,
ticker,
start_date=start_date,
end_date=end_date,
limit=200,
)
if self._news_rows_need_enrichment(news_rows):
await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=200,
)
categories = await asyncio.to_thread(
self.storage.market_store.get_news_categories_enriched,
ticker,
start_date=start_date,
end_date=end_date,
limit=200,
)
await websocket.send(
json.dumps(
{
"type": "stock_news_categories_loaded",
"ticker": ticker,
"categories": categories,
"start_date": start_date,
"end_date": end_date,
},
ensure_ascii=False,
default=str,
),
)
async def _handle_get_stock_range_explain(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()
end_date = str(data.get("end_date") or "").strip()
if not ticker or not start_date or not end_date:
await websocket.send(
json.dumps(
{
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": {"error": "ticker, start_date, end_date are required"},
},
ensure_ascii=False,
),
)
return
article_ids = data.get("article_ids")
if isinstance(article_ids, list) and article_ids:
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_by_ids_enriched,
ticker,
article_ids,
)
if self._news_rows_need_enrichment(news_rows):
await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=100,
)
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_by_ids_enriched,
ticker,
article_ids,
)
else:
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_items_enriched,
ticker,
start_date=start_date,
end_date=end_date,
limit=100,
)
if not news_rows:
await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=100,
)
news_rows = await asyncio.to_thread(
self.storage.market_store.get_news_items_enriched,
ticker,
start_date=start_date,
end_date=end_date,
limit=100,
)
result = await asyncio.to_thread(
build_range_explanation,
ticker=ticker,
start_date=start_date,
end_date=end_date,
news_rows=news_rows,
)
await websocket.send(
json.dumps(
{
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": result,
},
ensure_ascii=False,
default=str,
),
)
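    # Request sketch: {"type": "get_stock_range_explain", "ticker": "AAPL",
    #   "start_date": "2026-03-10", "end_date": "2026-03-16",
    #   "article_ids": ["news-1"]}
    # article_ids is optional; when omitted, all enriched rows in the date
    # range feed build_range_explanation.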
async def _handle_get_stock_story(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(
json.dumps(
{
"type": "stock_story_loaded",
"ticker": "",
"story": "",
"error": "invalid ticker",
},
ensure_ascii=False,
),
)
return
as_of_date = str(
data.get("as_of_date")
or self.state_sync.state.get("current_date")
or datetime.now().strftime("%Y-%m-%d")
).strip()[:10]
await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
end_date=as_of_date,
limit=80,
)
result = await asyncio.to_thread(
get_or_create_stock_story,
self.storage.market_store,
symbol=ticker,
as_of_date=as_of_date,
)
await websocket.send(
json.dumps(
{
"type": "stock_story_loaded",
"ticker": ticker,
"as_of_date": as_of_date,
"story": result.get("story") or "",
"source": result.get("source") or "local",
},
ensure_ascii=False,
default=str,
),
)
async def _handle_get_stock_similar_days(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
target_date = str(data.get("date") or "").strip()[:10]
if not ticker or not target_date:
await websocket.send(
json.dumps(
{
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
"items": [],
"error": "ticker and date are required",
},
ensure_ascii=False,
),
)
return
top_k = data.get("top_k", 8)
try:
top_k = max(1, min(int(top_k), 20))
except (TypeError, ValueError):
top_k = 8
await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
end_date=target_date,
limit=200,
)
result = await asyncio.to_thread(
find_similar_days,
self.storage.market_store,
symbol=ticker,
target_date=target_date,
top_k=top_k,
)
await websocket.send(
json.dumps(
{
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
**result,
},
ensure_ascii=False,
default=str,
),
)
async def _handle_run_stock_enrich(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()[:10]
end_date = str(data.get("end_date") or "").strip()[:10]
story_date = str(data.get("story_date") or end_date or "").strip()[:10]
target_date = str(data.get("target_date") or "").strip()[:10]
force = bool(data.get("force", False))
rebuild_story = bool(data.get("rebuild_story", True))
rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
only_local_to_llm = bool(data.get("only_local_to_llm", False))
limit = data.get("limit", 200)
try:
limit = max(10, min(int(limit), 500))
except (TypeError, ValueError):
limit = 200
if not ticker or not start_date or not end_date:
await websocket.send(
json.dumps(
{
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "ticker, start_date, end_date are required",
},
ensure_ascii=False,
),
)
return
if only_local_to_llm and not llm_enrichment_enabled():
await websocket.send(
json.dumps(
{
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
},
ensure_ascii=False,
),
)
return
result = await asyncio.to_thread(
enrich_news_for_symbol,
self.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
skip_existing=not force,
only_reanalyze_local=only_local_to_llm,
)
story_status = None
if rebuild_story and story_date:
await asyncio.to_thread(
self.storage.market_store.delete_story_cache,
ticker,
as_of_date=story_date,
)
story_result = await asyncio.to_thread(
get_or_create_stock_story,
self.storage.market_store,
symbol=ticker,
as_of_date=story_date,
)
story_status = {
"as_of_date": story_date,
"source": story_result.get("source") or "local",
}
similar_status = None
if rebuild_similar_days and target_date:
await asyncio.to_thread(
self.storage.market_store.delete_similar_day_cache,
ticker,
target_date=target_date,
)
similar_result = await asyncio.to_thread(
find_similar_days,
self.storage.market_store,
symbol=ticker,
target_date=target_date,
top_k=8,
)
similar_status = {
"target_date": target_date,
"count": len(similar_result.get("items") or []),
"error": similar_result.get("error"),
}
await websocket.send(
json.dumps(
{
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"story_date": story_date or None,
"target_date": target_date or None,
"force": force,
"only_local_to_llm": only_local_to_llm,
"stats": result,
"story_status": story_status,
"similar_status": similar_status,
},
ensure_ascii=False,
default=str,
),
)
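    # Request sketch: {"type": "run_stock_enrich", "ticker": "AAPL",
    #   "start_date": "2026-03-10", "end_date": "2026-03-16", "force": true,
    #   "story_date": "2026-03-16", "target_date": "2026-03-16"}
    # The reply is a single "stock_enrich_completed" frame carrying the
    # enrichment stats plus optional story/similar-day rebuild statuses.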
async def _handle_start_backtest(self, data: Dict[str, Any]):
if not self.is_backtest:
return
@@ -410,6 +1093,7 @@ class Gateway:
},
)
await self._handle_reload_runtime_assets()
self._schedule_watchlist_market_store_refresh(tickers)
@staticmethod
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
@@ -538,6 +1222,48 @@ class Gateway:
trades=trades,
)
def _schedule_watchlist_market_store_refresh(
self,
tickers: List[str],
) -> None:
"""Kick off a non-blocking Polygon refresh for the updated watchlist."""
if not tickers:
return
if self._watchlist_ingest_task and not self._watchlist_ingest_task.done():
self._watchlist_ingest_task.cancel()
self._watchlist_ingest_task = asyncio.create_task(
self._refresh_market_store_for_watchlist(tickers),
)
async def _refresh_market_store_for_watchlist(
self,
tickers: List[str],
) -> None:
"""Refresh the long-lived market store after a watchlist update."""
try:
await self.state_sync.on_system_message(
f"正在同步自选股市场数据: {', '.join(tickers)}",
)
results = await asyncio.to_thread(
ingest_symbols,
tickers,
mode="incremental",
)
summary = ", ".join(
f"{item['symbol']} prices={item['prices']} news={item['news']}"
for item in results
)
await self.state_sync.on_system_message(
f"自选股市场数据已同步: {summary}",
)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning("Watchlist market store refresh failed: %s", exc)
await self.state_sync.on_system_message(
f"自选股市场数据同步失败: {exc}",
)
async def broadcast(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients"""
if not self.connected_clients:
@@ -896,4 +1622,6 @@ class Gateway:
self._backtest_task.cancel()
if self._market_status_task:
self._market_status_task.cancel()
if self._watchlist_ingest_task:
self._watchlist_ingest_task.cancel()
self._dashboard.stop()

View File

@@ -65,6 +65,18 @@ class MarketService:
self._session_start_values: Optional[Dict[str, float]] = None
self._session_start_timestamp: Optional[int] = None
def get_live_quote_provider(self) -> Optional[str]:
"""Return the active live quote provider for UI/debugging."""
if self.backtest_mode:
return "backtest"
if self.mock_mode:
return "mock"
if self._price_manager and hasattr(self._price_manager, "provider"):
provider = getattr(self._price_manager, "provider", None)
if isinstance(provider, str) and provider.strip():
return provider.strip().lower()
return None
@property
def mode_name(self) -> str:
if self.backtest_mode:
@@ -532,6 +544,7 @@ class MarketService:
"status": MarketStatus.OPEN,
"status_text": "Backtest Mode",
"is_trading_day": True,
"live_quote_provider": self.get_live_quote_provider(),
}
now = self._now_nyse()
@@ -544,6 +557,7 @@ class MarketService:
"status": MarketStatus.CLOSED,
"status_text": "Market Closed (Non-trading Day)",
"is_trading_day": False,
"live_quote_provider": self.get_live_quote_provider(),
}
market_open, market_close = self._get_market_hours(today)
@@ -553,6 +567,7 @@ class MarketService:
"status": MarketStatus.CLOSED,
"status_text": "Market Closed",
"is_trading_day": is_trading,
"live_quote_provider": self.get_live_quote_provider(),
}
# Determine status based on current time
@@ -563,6 +578,7 @@ class MarketService:
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
"live_quote_provider": self.get_live_quote_provider(),
}
elif now > market_close:
return {
@@ -571,6 +587,7 @@ class MarketService:
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
"live_quote_provider": self.get_live_quote_provider(),
}
else:
return {
@@ -579,6 +596,7 @@ class MarketService:
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
"live_quote_provider": self.get_live_quote_provider(),
}
async def check_and_broadcast_market_status(self):

View File

@@ -0,0 +1,280 @@
# -*- coding: utf-8 -*-
"""Query-oriented storage for explain/research data."""
from __future__ import annotations
import json
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable
from backend.data.schema import CompanyNews
SCHEMA = """
CREATE TABLE IF NOT EXISTS news_items (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
published_at TEXT,
trade_date TEXT,
source TEXT,
title TEXT NOT NULL,
summary TEXT,
url TEXT,
related TEXT,
category TEXT,
raw_json TEXT NOT NULL,
ingest_run_date TEXT,
created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_news_items_ticker_date
ON news_items (ticker, trade_date DESC, published_at DESC);
"""
def _json_dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
def _resolve_news_id(ticker: str, item: CompanyNews, fallback_index: int) -> str:
base = item.url or item.title or f"{ticker}-{fallback_index}"
return f"{ticker}:{base}"
def _resolve_trade_date(date_value: str | None) -> str | None:
if not date_value:
return None
normalized = str(date_value).strip()
if not normalized:
return None
if "T" in normalized:
return normalized.split("T", 1)[0]
if " " in normalized:
return normalized.split(" ", 1)[0]
return normalized[:10]
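# Normalization sketch: every accepted timestamp shape collapses to YYYY-MM-DD.
#   _resolve_trade_date("2026-03-16T12:30:00Z") -> "2026-03-16"
#   _resolve_trade_date("2026-03-16 12:30:00")  -> "2026-03-16"
#   _resolve_trade_date("2026-03-16")           -> "2026-03-16"
#   _resolve_trade_date("")                     -> None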
class ResearchDb:
"""Small SQLite helper for explain-oriented news storage."""
def __init__(self, db_path: Path):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
return conn
def _init_db(self):
with self._connect() as conn:
conn.executescript(SCHEMA)
def upsert_news_items(
self,
*,
ticker: str,
items: Iterable[CompanyNews],
ingest_run_date: str | None = None,
) -> list[dict[str, Any]]:
"""Persist provider news and return normalized rows."""
normalized_rows: list[dict[str, Any]] = []
timestamp = datetime.utcnow().isoformat(timespec="seconds")
symbol = str(ticker or "").strip().upper()
if not symbol:
return normalized_rows
with self._connect() as conn:
for index, item in enumerate(items):
news_id = _resolve_news_id(symbol, item, index)
trade_date = _resolve_trade_date(item.date)
payload = item.model_dump()
row = {
"id": news_id,
"ticker": symbol,
"published_at": item.date,
"trade_date": trade_date,
"source": item.source,
"title": item.title,
"summary": item.summary,
"url": item.url,
"related": item.related,
"category": item.category,
"raw_json": _json_dumps(payload),
"ingest_run_date": ingest_run_date,
"created_at": timestamp,
}
conn.execute(
"""
INSERT INTO news_items
(id, ticker, published_at, trade_date, source, title, summary, url,
related, category, raw_json, ingest_run_date, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
ticker = excluded.ticker,
published_at = excluded.published_at,
trade_date = excluded.trade_date,
source = excluded.source,
title = excluded.title,
summary = excluded.summary,
url = excluded.url,
related = excluded.related,
category = excluded.category,
raw_json = excluded.raw_json,
ingest_run_date = excluded.ingest_run_date
""",
(
row["id"],
row["ticker"],
row["published_at"],
row["trade_date"],
row["source"],
row["title"],
row["summary"],
row["url"],
row["related"],
row["category"],
row["raw_json"],
row["ingest_run_date"],
row["created_at"],
),
)
normalized_rows.append(row)
return normalized_rows
def get_news_items(
self,
*,
ticker: str,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 20,
) -> list[dict[str, Any]]:
"""Return normalized news rows for explain UI."""
symbol = str(ticker or "").strip().upper()
if not symbol:
return []
sql = """
SELECT id, ticker, published_at, trade_date, source, title, summary,
url, related, category
FROM news_items
WHERE ticker = ?
"""
params: list[Any] = [symbol]
if start_date:
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
params.append(start_date)
if end_date:
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
params.append(end_date)
sql += " ORDER BY COALESCE(published_at, trade_date) DESC LIMIT ?"
params.append(max(1, int(limit)))
with self._connect() as conn:
rows = conn.execute(sql, params).fetchall()
return [
{
"id": row["id"],
"ticker": row["ticker"],
"date": row["published_at"] or row["trade_date"],
"trade_date": row["trade_date"],
"source": row["source"],
"title": row["title"],
"summary": row["summary"],
"url": row["url"],
"related": row["related"],
"category": row["category"],
}
for row in rows
]
def get_news_timeline(
self,
*,
ticker: str,
start_date: str | None = None,
end_date: str | None = None,
) -> list[dict[str, Any]]:
"""Aggregate news counts per trade date for chart markers."""
symbol = str(ticker or "").strip().upper()
if not symbol:
return []
sql = """
SELECT COALESCE(trade_date, substr(published_at, 1, 10)) AS date,
COUNT(*) AS count,
COUNT(DISTINCT source) AS source_count,
MAX(title) AS top_title
FROM news_items
WHERE ticker = ?
"""
params: list[Any] = [symbol]
if start_date:
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
params.append(start_date)
if end_date:
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
params.append(end_date)
sql += """
GROUP BY COALESCE(trade_date, substr(published_at, 1, 10))
ORDER BY date ASC
"""
with self._connect() as conn:
rows = conn.execute(sql, params).fetchall()
return [
{
"date": row["date"],
"count": int(row["count"] or 0),
"source_count": int(row["source_count"] or 0),
"top_title": row["top_title"] or "",
}
for row in rows
if row["date"]
]
def get_news_by_ids(
self,
*,
ticker: str,
article_ids: Iterable[str],
) -> list[dict[str, Any]]:
"""Return selected persisted news items."""
symbol = str(ticker or "").strip().upper()
ids = [str(article_id).strip() for article_id in article_ids if str(article_id).strip()]
if not symbol or not ids:
return []
placeholders = ",".join("?" for _ in ids)
sql = f"""
SELECT id, ticker, published_at, trade_date, source, title, summary,
url, related, category
FROM news_items
WHERE ticker = ? AND id IN ({placeholders})
ORDER BY COALESCE(published_at, trade_date) DESC
"""
with self._connect() as conn:
rows = conn.execute(sql, [symbol, *ids]).fetchall()
return [
{
"id": row["id"],
"ticker": row["ticker"],
"date": row["published_at"] or row["trade_date"],
"trade_date": row["trade_date"],
"source": row["source"],
"title": row["title"],
"summary": row["summary"],
"url": row["url"],
"related": row["related"],
"category": row["category"],
}
for row in rows
]
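# Round-trip sketch (hypothetical data; the CompanyNews field names below are
# inferred from the attribute access above and may not match the real schema
# exactly):
def _demo_research_db(db_dir: Path) -> list[dict[str, Any]]:
    db = ResearchDb(db_dir / "research.db")
    db.upsert_news_items(
        ticker="AAPL",
        items=[
            CompanyNews(
                title="Apple earnings beat",
                date="2026-03-16T12:00:00Z",
                source="polygon",
                summary="Revenue topped expectations",
                url="https://example.com/apple-earnings",
            )
        ],
        ingest_run_date="2026-03-16",
    )
    # Upserts are keyed on the derived id (url/title), so re-ingesting the
    # same article updates the row instead of duplicating it.
    return db.get_news_items(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-31")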

View File

@@ -10,6 +10,8 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from backend.data.market_store import MarketStore
from .research_db import ResearchDb
from .runtime_db import RuntimeDb
logger = logging.getLogger(__name__)
@@ -64,6 +66,8 @@ class StorageService:
self.state_dir.mkdir(parents=True, exist_ok=True)
self.server_state_file = self.state_dir / "server_state.json"
self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
self.research_db = ResearchDb(self.state_dir / "research.db")
self.market_store = MarketStore()
# Feed history (for agent messages)
self.max_feed_history = 200

backend/tests/test_cli.py Normal file
View File

@@ -0,0 +1,236 @@
# -*- coding: utf-8 -*-
from pathlib import Path
from backend import cli
def test_live_runs_incremental_market_store_update_before_start(monkeypatch, tmp_path):
project_root = tmp_path
(project_root / ".env").write_text("FINNHUB_API_KEY=test\n", encoding="utf-8")
calls = []
monkeypatch.setattr(cli, "get_project_root", lambda: project_root)
monkeypatch.setattr(cli, "handle_history_cleanup", lambda config_name, auto_clean=False: None)
monkeypatch.setattr(cli, "run_data_updater", lambda project_root: calls.append(("run_data_updater", project_root)))
monkeypatch.setattr(
cli,
"auto_update_market_store",
lambda config_name, end_date=None: calls.append(("auto_update_market_store", config_name, end_date)),
)
monkeypatch.setattr(
cli,
"auto_enrich_market_store",
lambda config_name, end_date=None, lookback_days=120, force=False: calls.append(
("auto_enrich_market_store", config_name, end_date, lookback_days, force)
),
)
monkeypatch.setattr(cli.os, "chdir", lambda path: calls.append(("chdir", Path(path))))
def fake_run(cmd, check=True, **kwargs):
calls.append(("subprocess.run", cmd, check))
return 0
monkeypatch.setattr(cli.subprocess, "run", fake_run)
cli.live(
mock=False,
config_name="smoke_fullstack",
host="0.0.0.0",
port=8765,
trigger_time="now",
poll_interval=10,
clean=False,
enable_memory=False,
)
assert any(item[0] == "run_data_updater" for item in calls)
assert any(
item[0] == "auto_update_market_store" and item[1] == "smoke_fullstack"
for item in calls
)
assert any(
item[0] == "auto_enrich_market_store" and item[1] == "smoke_fullstack"
for item in calls
)
run_call = next(item for item in calls if item[0] == "subprocess.run")
assert run_call[1][:6] == [
cli.sys.executable,
"-u",
"-m",
"backend.main",
"--mode",
"live",
]
def test_backtest_runs_full_market_store_prepare_before_start(monkeypatch, tmp_path):
project_root = tmp_path
calls = []
monkeypatch.setattr(cli, "get_project_root", lambda: project_root)
monkeypatch.setattr(cli, "handle_history_cleanup", lambda config_name, auto_clean=False: None)
monkeypatch.setattr(cli, "run_data_updater", lambda project_root: calls.append(("run_data_updater", project_root)))
monkeypatch.setattr(
cli,
"auto_prepare_backtest_market_store",
lambda config_name, start_date, end_date: calls.append(
("auto_prepare_backtest_market_store", config_name, start_date, end_date)
),
)
monkeypatch.setattr(
cli,
"auto_enrich_market_store",
lambda config_name, end_date=None, lookback_days=120, force=False: calls.append(
("auto_enrich_market_store", config_name, end_date, lookback_days, force)
),
)
monkeypatch.setattr(cli.os, "chdir", lambda path: calls.append(("chdir", Path(path))))
def fake_run(cmd, check=True, **kwargs):
calls.append(("subprocess.run", cmd, check))
return 0
monkeypatch.setattr(cli.subprocess, "run", fake_run)
cli.backtest(
start="2026-03-01",
end="2026-03-10",
config_name="smoke_fullstack",
host="0.0.0.0",
port=8765,
poll_interval=10,
clean=False,
enable_memory=False,
)
assert any(item[0] == "run_data_updater" for item in calls)
assert any(
item[0] == "auto_prepare_backtest_market_store"
and item[1:] == ("smoke_fullstack", "2026-03-01", "2026-03-10")
for item in calls
)
assert any(
item[0] == "auto_enrich_market_store"
and item[1] == "smoke_fullstack"
and item[2] == "2026-03-10"
for item in calls
)
run_call = next(item for item in calls if item[0] == "subprocess.run")
assert run_call[1][:6] == [
cli.sys.executable,
"-u",
"-m",
"backend.main",
"--mode",
"backtest",
]
def test_ingest_enrich_runs_batch_enrichment(monkeypatch):
calls = []
monkeypatch.setattr(cli, "_resolve_symbols", lambda raw_tickers, config_name=None: ["AAPL", "MSFT"])
class DummyStore:
pass
monkeypatch.setattr(cli, "MarketStore", lambda: DummyStore())
monkeypatch.setattr(
cli,
"enrich_symbols",
lambda store, symbols, start_date=None, end_date=None, limit=200, analysis_source="local", skip_existing=True: calls.append(
("enrich_symbols", symbols, start_date, end_date, limit, analysis_source, skip_existing)
) or [
{
"symbol": symbol,
"news_count": 3,
"queued_count": 3,
"analyzed": 3,
"skipped_existing_count": 0,
"deduped_count": 0,
"llm_count": 0,
"local_count": 3,
}
for symbol in symbols
],
)
cli.ingest_enrich(
tickers=None,
start="2026-03-01",
end="2026-03-10",
limit=150,
force=False,
config_name="smoke_fullstack",
)
assert calls == [
("enrich_symbols", ["AAPL", "MSFT"], "2026-03-01", "2026-03-10", 150, "local", True)
]
def test_ingest_report_reads_market_store_report(monkeypatch):
calls = []
printed = []
monkeypatch.setattr(cli, "_resolve_symbols", lambda raw_tickers, config_name=None: ["AAPL"])
class DummyStore:
def get_enrich_report(self, symbols=None, start_date=None, end_date=None):
calls.append(("get_enrich_report", symbols, start_date, end_date))
return [
{
"symbol": "AAPL",
"raw_news_count": 10,
"analyzed_news_count": 8,
"coverage_pct": 80.0,
"llm_count": 5,
"local_count": 3,
"latest_trade_date": "2026-03-16",
"latest_analysis_at": "2026-03-16T09:00:00",
}
]
monkeypatch.setattr(cli, "MarketStore", lambda: DummyStore())
monkeypatch.setattr(cli, "get_explain_model_info", lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"})
monkeypatch.setattr(cli, "llm_enrichment_enabled", lambda: True)
monkeypatch.setattr(cli.console, "print", lambda value: printed.append(value))
cli.ingest_report(
tickers=None,
start="2026-03-01",
end="2026-03-16",
config_name="smoke_fullstack",
only_problematic=False,
)
assert calls == [
("get_enrich_report", ["AAPL"], "2026-03-01", "2026-03-16")
]
assert printed
assert getattr(printed[0], "caption", "") == "Explain LLM: DASHSCOPE:qwen-max"
def test_filter_problematic_report_rows_keeps_low_coverage_and_no_llm():
rows = [
{
"symbol": "AAPL",
"coverage_pct": 100.0,
"llm_count": 2,
},
{
"symbol": "MSFT",
"coverage_pct": 80.0,
"llm_count": 1,
},
{
"symbol": "NVDA",
"coverage_pct": 100.0,
"llm_count": 0,
},
]
filtered = cli._filter_problematic_report_rows(rows)
assert [row["symbol"] for row in filtered] == ["MSFT", "NVDA"]

View File

@@ -0,0 +1,384 @@
# -*- coding: utf-8 -*-
import json
from types import SimpleNamespace
import pytest
from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module
class DummyWebSocket:
def __init__(self):
self.messages = []
async def send(self, payload: str):
self.messages.append(json.loads(payload))
class DummyStateSync:
def __init__(self, current_date="2026-03-16"):
self.state = {"current_date": current_date}
self.system_messages = []
def set_broadcast_fn(self, _fn):
return None
def update_state(self, *_args, **_kwargs):
return None
async def on_system_message(self, message):
self.system_messages.append(message)
class FakeMarketStore:
def __init__(self):
self.calls = []
def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}]
def get_news_items(self, symbol, *, start_date=None, end_date=None, limit=100):
self.calls.append(("get_news_items", symbol, start_date, end_date, limit))
return [
{
"id": "news-1",
"ticker": symbol,
"date": end_date,
"trade_date": end_date,
"title": "Title",
"summary": "Summary",
"source": "polygon",
}
]
def get_news_items_enriched(self, symbol, *, start_date=None, end_date=None, trade_date=None, limit=100):
self.calls.append(("get_news_items_enriched", symbol, start_date, end_date, trade_date, limit))
target_date = trade_date or end_date
return [
{
"id": "news-1",
"ticker": symbol,
"date": target_date,
"trade_date": target_date,
"title": "Title",
"summary": "Summary",
"source": "polygon",
"sentiment": "negative",
"relevance": "high",
"key_discussion": "Key discussion",
}
]
def get_news_by_ids_enriched(self, symbol, article_ids):
self.calls.append(("get_news_by_ids_enriched", symbol, list(article_ids)))
return [{"id": article_ids[0], "ticker": symbol, "date": "2026-03-16", "sentiment": "negative"}]
def get_news_categories_enriched(self, symbol, *, start_date=None, end_date=None, limit=200):
self.calls.append(("get_news_categories_enriched", symbol, start_date, end_date, limit))
return {"macro": {"label": "宏观", "count": 1, "article_ids": ["news-1"], "positive_ids": [], "negative_ids": ["news-1"], "neutral_ids": []}}
def get_story_cache(self, symbol, *, as_of_date):
self.calls.append(("get_story_cache", symbol, as_of_date))
return None
def upsert_story_cache(self, symbol, *, as_of_date, content, source="local"):
self.calls.append(("upsert_story_cache", symbol, as_of_date, source))
def delete_story_cache(self, symbol, *, as_of_date=None):
self.calls.append(("delete_story_cache", symbol, as_of_date))
return 1
def get_similar_day_cache(self, symbol, *, target_date):
self.calls.append(("get_similar_day_cache", symbol, target_date))
return None
def upsert_similar_day_cache(self, symbol, *, target_date, payload, source="local"):
self.calls.append(("upsert_similar_day_cache", symbol, target_date, source))
def delete_similar_day_cache(self, symbol, *, target_date=None):
self.calls.append(("delete_similar_day_cache", symbol, target_date))
return 1
def get_ohlc(self, symbol, start_date, end_date):
self.calls.append(("get_ohlc", symbol, start_date, end_date))
return [
{"date": start_date, "open": 100, "high": 105, "low": 99, "close": 103},
{"date": end_date, "open": 103, "high": 108, "low": 102, "close": 107},
]
def make_gateway(market_store=None):
storage = SimpleNamespace(market_store=market_store or FakeMarketStore())
pipeline = SimpleNamespace(state_sync=None)
market_service = SimpleNamespace()
state_sync = DummyStateSync()
return Gateway(
market_service=market_service,
storage_service=storage,
pipeline=pipeline,
state_sync=state_sync,
config={"mode": "live"},
)
@pytest.mark.asyncio
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
await gateway._handle_get_stock_news_timeline(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == [
("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")
]
assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
@pytest.mark.asyncio
async def test_handle_get_stock_news_categories_uses_market_store_symbol_argument(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
await gateway._handle_get_stock_news_categories(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == [
("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200)
]
assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
assert websocket.messages[-1]["categories"]["macro"]["count"] == 1
@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
def fake_build_range_explanation(*, ticker, start_date, end_date, news_rows):
return {
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"news_count": len(news_rows),
}
monkeypatch.setattr(
gateway_module,
"build_range_explanation",
fake_build_range_explanation,
)
await gateway._handle_get_stock_range_explain(
websocket,
{"ticker": "AAPL", "start_date": "2026-03-10", "end_date": "2026-03-16"},
)
assert market_store.calls == [
("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100)
]
assert websocket.messages[-1] == {
"type": "stock_range_explain_loaded",
"ticker": "AAPL",
"result": {
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"news_count": 1,
},
}
@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
"build_range_explanation",
lambda **kwargs: {"news_count": len(kwargs["news_rows"])},
)
await gateway._handle_get_stock_range_explain(
websocket,
{
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"article_ids": ["news-99"],
},
)
assert market_store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-99"])]
assert websocket.messages[-1]["result"]["news_count"] == 1
@pytest.mark.asyncio
async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
await gateway._handle_get_stock_news_for_date(
websocket,
{"ticker": "AAPL", "date": "2026-03-16", "limit": 10},
)
assert market_store.calls == [
("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
]
assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded"
assert websocket.messages[-1]["date"] == "2026-03-16"
@pytest.mark.asyncio
async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
)
await gateway._handle_get_stock_story(
websocket,
{"ticker": "AAPL", "as_of_date": "2026-03-16"},
)
assert websocket.messages[-1]["type"] == "stock_story_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
assert "AAPL Story" in websocket.messages[-1]["story"]
@pytest.mark.asyncio
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
)
await gateway._handle_get_stock_similar_days(
websocket,
{"ticker": "AAPL", "date": "2026-03-16", "top_k": 5},
)
assert websocket.messages[-1]["type"] == "stock_similar_days_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
assert isinstance(websocket.messages[-1]["items"], list)
@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
)
await gateway._handle_run_stock_enrich(
websocket,
{
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"force": True,
"rebuild_story": True,
"rebuild_similar_days": True,
"story_date": "2026-03-16",
"target_date": "2026-03-16",
},
)
assert ("delete_story_cache", "AAPL", "2026-03-16") in market_store.calls
assert ("delete_similar_day_cache", "AAPL", "2026-03-16") in market_store.calls
assert websocket.messages[-1]["type"] == "stock_enrich_completed"
assert websocket.messages[-1]["stats"]["analyzed"] == 2
@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypatch):
gateway = make_gateway(FakeMarketStore())
websocket = DummyWebSocket()
monkeypatch.setattr(gateway_module, "llm_enrichment_enabled", lambda: False)
await gateway._handle_run_stock_enrich(
websocket,
{
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"only_local_to_llm": True,
},
)
assert websocket.messages[-1]["type"] == "stock_enrich_completed"
assert "requires EXPLAIN_ENRICH_USE_LLM=true" in websocket.messages[-1]["error"]
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
gateway = make_gateway()
captured = {}
class DummyTask:
def done(self):
return False
def cancel(self):
captured["cancelled"] = True
def fake_create_task(coro):
captured["coro_name"] = coro.cr_code.co_name
coro.close()
return DummyTask()
monkeypatch.setattr(gateway_module.asyncio, "create_task", fake_create_task)
gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])
assert captured["coro_name"] == "_refresh_market_store_for_watchlist"
@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
gateway = make_gateway()
monkeypatch.setattr(
gateway_module,
"ingest_symbols",
lambda symbols, mode="incremental": [
{"symbol": symbol, "prices": 3, "news": 4, "aligned": 4}
for symbol in symbols
],
)
await gateway._refresh_market_store_for_watchlist(["AAPL", "MSFT"])
    assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"  # "Syncing watchlist market data: AAPL, MSFT"
    assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]  # "Watchlist market data synced:"
assert "AAPL prices=3 news=4" in gateway.state_sync.system_messages[1]

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
from unittest.mock import patch
import pandas as pd
from backend.data.historical_price_manager import HistoricalPriceManager
def test_preload_data_prefers_market_db():
manager = HistoricalPriceManager()
manager.subscribe(["AAPL"])
market_rows = [
{
"symbol": "AAPL",
"date": "2026-03-09",
"open": 100.0,
"high": 103.0,
"low": 99.0,
"close": 102.0,
"volume": 10_000,
"vwap": 101.0,
"transactions": 500,
"source": "polygon",
}
]
with (
patch.object(manager._market_store, "get_ohlc", return_value=market_rows),
patch.object(manager._router, "load_local_price_frame") as load_csv,
):
manager.preload_data("2026-03-01", "2026-03-10")
load_csv.assert_not_called()
assert "AAPL" in manager._price_cache
assert float(manager._price_cache["AAPL"].iloc[0]["close"]) == 102.0
def test_preload_data_falls_back_to_csv():
manager = HistoricalPriceManager()
manager.subscribe(["MSFT"])
csv_df = pd.DataFrame(
{
"time": ["2026-03-09"],
"open": [200.0],
"high": [205.0],
"low": [198.0],
"close": [204.0],
"volume": [20_000],
}
)
csv_df["time"] = pd.to_datetime(csv_df["time"])
csv_df["Date"] = csv_df["time"]
csv_df.set_index("Date", inplace=True)
with (
patch.object(manager._market_store, "get_ohlc", return_value=[]),
patch.object(manager._router, "load_local_price_frame", return_value=csv_df) as load_csv,
):
manager.preload_data("2026-03-01", "2026-03-10")
load_csv.assert_called_once_with("MSFT")
assert "MSFT" in manager._price_cache
assert float(manager._price_cache["MSFT"].iloc[0]["close"]) == 204.0

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
from backend.enrich import llm_enricher
class DummyResponse:
def __init__(self, metadata):
self.metadata = metadata
class DummyModel:
def __init__(self, metadata):
self.metadata = metadata
self.calls = []
async def __call__(self, messages, structured_model=None, **kwargs):
self.calls.append(
{
"messages": messages,
"structured_model": structured_model,
"kwargs": kwargs,
}
)
return DummyResponse(self.metadata)
def test_analyze_news_row_with_llm_uses_agentscope_model(monkeypatch):
model = DummyModel(
{
"id": "news-1",
"relevance": "high",
"sentiment": "positive",
"key_discussion": "Demand remains resilient",
"summary": "Structured summary",
"reason_growth": "Orders improved",
"reason_decrease": "",
}
)
monkeypatch.setattr(llm_enricher, "llm_enrichment_enabled", lambda: True)
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
monkeypatch.setattr(
llm_enricher,
"get_explain_model_info",
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
)
result = llm_enricher.analyze_news_row_with_llm(
{
"id": "news-1",
"title": "Apple expands AI features",
"summary": "New devices and software updates were announced.",
}
)
assert result["sentiment"] == "positive"
assert result["summary"] == "Structured summary"
assert result["raw_json"]["model_label"] == "DASHSCOPE:qwen-max"
assert model.calls
assert model.calls[0]["structured_model"] is llm_enricher.EnrichedNewsItem
def test_analyze_news_rows_with_llm_uses_agentscope_structured_batch(monkeypatch):
model = DummyModel(
{
"items": [
{
"id": "news-1",
"relevance": "high",
"sentiment": "negative",
"key_discussion": "Margin pressure",
"summary": "Batch summary",
"reason_growth": "",
"reason_decrease": "Costs rose",
}
]
}
)
monkeypatch.setattr(llm_enricher, "llm_enrichment_enabled", lambda: True)
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
monkeypatch.setattr(
llm_enricher,
"get_explain_model_info",
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
)
result = llm_enricher.analyze_news_rows_with_llm(
[
{
"id": "news-1",
"title": "Apple margins pressured",
"summary": "Costs increased this quarter.",
}
]
)
assert result["news-1"]["sentiment"] == "negative"
assert result["news-1"]["reason_decrease"] == "Costs rose"
assert result["news-1"]["raw_json"]["model_label"] == "DASHSCOPE:qwen-max"
assert model.calls
assert model.calls[0]["structured_model"] is llm_enricher.EnrichedNewsBatch
def test_analyze_range_with_llm_uses_agentscope_structured_output(monkeypatch):
model = DummyModel(
{
"summary": "该股在区间内震荡下行,相关新闻主要集中在盈利预期和供应链扰动。",
"trend_analysis": "前半段受利空新闻压制,后半段跌幅收敛。",
"bullish_factors": ["估值消化后出现部分承接"],
"bearish_factors": ["盈利预期下修", "供应链扰动持续"],
}
)
monkeypatch.setattr(llm_enricher, "llm_range_analysis_enabled", lambda: True)
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
monkeypatch.setattr(
llm_enricher,
"get_explain_model_info",
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
)
result = llm_enricher.analyze_range_with_llm(
{
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"price_change_pct": -3.42,
}
)
assert result["summary"].startswith("该股在区间内震荡下行")
assert result["model_label"] == "DASHSCOPE:qwen-max"
assert result["bearish_factors"] == ["盈利预期下修", "供应链扰动持续"]
assert model.calls
assert model.calls[0]["structured_model"] is llm_enricher.RangeAnalysisPayload

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from pathlib import Path
from backend.data.market_store import MarketStore
def test_get_enrich_report_summarizes_coverage(tmp_path: Path):
store = MarketStore(tmp_path / "market_research.db")
store.upsert_news(
"AAPL",
[
{
"id": "news-1",
"published_utc": "2026-03-10T12:00:00Z",
"title": "Apple earnings beat",
"summary": "Revenue topped expectations",
"tickers": ["AAPL"],
},
{
"id": "news-2",
"published_utc": "2026-03-11T12:00:00Z",
"title": "Apple supply chain warning",
"summary": "Outlook softened",
"tickers": ["AAPL"],
},
],
)
store.set_trade_dates(
[
{"news_id": "news-1", "symbol": "AAPL", "trade_date": "2026-03-10"},
{"news_id": "news-2", "symbol": "AAPL", "trade_date": "2026-03-11"},
]
)
store.upsert_news_analysis(
"AAPL",
[
{
"news_id": "news-1",
"trade_date": "2026-03-10",
"summary": "LLM enriched",
"analysis_source": "llm",
}
],
analysis_source="llm",
)
rows = store.get_enrich_report(["AAPL"])
assert len(rows) == 1
assert rows[0]["symbol"] == "AAPL"
assert rows[0]["raw_news_count"] == 2
assert rows[0]["analyzed_news_count"] == 1
assert rows[0]["coverage_pct"] == 50.0
assert rows[0]["llm_count"] == 1

View File

@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
from backend.enrich import news_enricher
def test_classify_news_row_falls_back_to_local_rules(monkeypatch):
monkeypatch.setattr(news_enricher, "analyze_news_row_with_llm", lambda row: None)
result = news_enricher.classify_news_row(
{
"title": "Apple shares drop after weak guidance",
"summary": "Investors reacted negatively to softer-than-expected outlook.",
}
)
assert result["analysis_source"] == "local"
assert result["sentiment"] == "negative"
assert result["summary"]


def test_classify_news_row_prefers_llm_when_available(monkeypatch):
monkeypatch.setattr(
news_enricher,
"analyze_news_row_with_llm",
lambda row: {
"relevance": "high",
"sentiment": "positive",
"key_discussion": "Demand resilience",
"summary": "LLM summary",
"reason_growth": "Orders remain strong",
"reason_decrease": "",
"raw_json": {"provider": "llm"},
},
)
result = news_enricher.classify_news_row(
{
"title": "Apple expands AI features",
"summary": "New devices and software updates were announced.",
}
)
assert result["analysis_source"] == "llm"
assert result["sentiment"] == "positive"
assert result["summary"] == "LLM summary"
def test_build_analysis_rows_prefers_batch_llm_and_dedupes(monkeypatch):
monkeypatch.setattr(news_enricher, "llm_enrichment_enabled", lambda: True)
monkeypatch.setattr(news_enricher, "get_env_int", lambda key, default=0: 8)
monkeypatch.setattr(
news_enricher,
"analyze_news_rows_with_llm",
lambda rows: {
"news-1": {
"relevance": "high",
"sentiment": "positive",
"key_discussion": "Batch result",
"summary": "Batch summary",
"reason_growth": "Growth",
"reason_decrease": "",
"raw_json": {"provider": "batch"},
}
},
)
monkeypatch.setattr(news_enricher, "analyze_news_row_with_llm", lambda row: None)
rows, stats = news_enricher.build_analysis_rows(
symbol="AAPL",
news_rows=[
{"id": "news-1", "trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
{"id": "news-2", "trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
],
ohlc_rows=[],
)
assert len(rows) == 1
assert rows[0]["analysis_source"] == "llm"
assert rows[0]["summary"] == "Batch summary"
assert stats["deduped_count"] == 1
assert stats["llm_count"] == 1
def test_enrich_news_for_symbol_skips_existing(monkeypatch):
class DummyStore:
def get_news_items(self, symbol, start_date=None, end_date=None, limit=200):
return [
{"id": "news-1", "trade_date": "2026-03-10", "title": "One", "summary": "One"},
{"id": "news-2", "trade_date": "2026-03-11", "title": "Two", "summary": "Two"},
]
def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
return {"news-1"}
def get_ohlc(self, symbol, start_date, end_date):
return []
def upsert_news_analysis(self, symbol, rows, analysis_source="local"):
self.rows = rows
return len(rows)
monkeypatch.setattr(
news_enricher,
"build_analysis_rows",
lambda symbol, news_rows, ohlc_rows: (
[
{
"news_id": row["id"],
"trade_date": row["trade_date"],
"summary": row["summary"],
"analysis_source": "local",
}
for row in news_rows
],
{"deduped_count": 0, "llm_count": 0, "local_count": len(news_rows)},
),
)
store = DummyStore()
result = news_enricher.enrich_news_for_symbol(store, "AAPL")
assert result["news_count"] == 2
assert result["queued_count"] == 1
assert result["skipped_existing_count"] == 1
assert len(store.rows) == 1
assert store.rows[0]["news_id"] == "news-2"
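

# The skip behaviour above is assumed to be a set difference over news ids:
# anything already in get_analyzed_news_ids() never reaches
# build_analysis_rows, so here only news-2 would survive. Minimal sketch with
# a hypothetical helper:
def _sketch_pending_rows(news_rows: list[dict], analyzed_ids: set[str]) -> list[dict]:
    return [row for row in news_rows if row["id"] not in analyzed_ids]
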
def test_enrich_news_for_symbol_only_reanalyzes_local(monkeypatch):
class DummyStore:
def get_news_items(self, symbol, start_date=None, end_date=None, limit=200):
return [
{"id": "news-1", "trade_date": "2026-03-10", "title": "One", "summary": "One"},
{"id": "news-2", "trade_date": "2026-03-11", "title": "Two", "summary": "Two"},
{"id": "news-3", "trade_date": "2026-03-12", "title": "Three", "summary": "Three"},
]
def get_analyzed_news_sources(self, symbol, start_date=None, end_date=None):
return {"news-1": "local", "news-2": "llm"}
def get_ohlc(self, symbol, start_date, end_date):
return []
def upsert_news_analysis(self, symbol, rows, analysis_source="local"):
self.rows = rows
return len(rows)
monkeypatch.setattr(
news_enricher,
"build_analysis_rows",
lambda symbol, news_rows, ohlc_rows: (
[
{
"news_id": row["id"],
"trade_date": row["trade_date"],
"summary": row["summary"],
"analysis_source": "llm" if row["id"] == "news-1" else "local",
}
for row in news_rows
],
{"deduped_count": 0, "llm_count": 1, "local_count": 0},
),
)
store = DummyStore()
result = news_enricher.enrich_news_for_symbol(
store,
"AAPL",
only_reanalyze_local=True,
)
assert result["news_count"] == 3
assert result["queued_count"] == 1
assert result["skipped_existing_count"] == 2
assert result["only_reanalyze_local"] is True
assert result["upgraded_local_to_llm_count"] == 1
assert result["execution_summary"]["upgraded_dates"] == ["2026-03-10"]
assert result["execution_summary"]["remaining_local_titles"] == []
assert result["execution_summary"]["skipped_missing_analysis_count"] == 1
assert result["execution_summary"]["skipped_non_local_count"] == 1
assert [row["news_id"] for row in store.rows] == ["news-1"]

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from types import SimpleNamespace
from backend.explain import range_explainer


def test_build_range_explanation_prefers_llm_text_when_available(monkeypatch):
monkeypatch.setattr(
range_explainer,
"get_prices",
lambda ticker, start_date, end_date: [
SimpleNamespace(open=100, close=98, high=102, low=97, volume=1000),
SimpleNamespace(open=98, close=96, high=99, low=95, volume=1100),
SimpleNamespace(open=96, close=97, high=98, low=94, volume=1200),
],
)
monkeypatch.setattr(
range_explainer,
"analyze_range_with_llm",
lambda payload: {
"summary": "区间内整体偏弱,主题集中在盈利预期和供应链风险。",
"trend_analysis": "前半段快速下探,后半段出现修复。",
"bullish_factors": ["回调后出现承接"],
"bearish_factors": ["盈利预期承压"],
"model_label": "DASHSCOPE:qwen-max",
},
)
result = range_explainer.build_range_explanation(
ticker="AAPL",
start_date="2026-03-10",
end_date="2026-03-16",
news_rows=[
{
"id": "news-1",
"trade_date": "2026-03-10",
"title": "Apple margin pressure concerns grow",
"summary": "Investors focused on weaker margin outlook.",
"sentiment": "negative",
"relevance": "high",
"ret_t0": -0.02,
"reason_decrease": "盈利预期承压",
"category": "earnings",
}
],
)
assert result["analysis"]["summary"] == "区间内整体偏弱,主题集中在盈利预期和供应链风险。"
assert result["analysis"]["trend_analysis"] == "前半段快速下探,后半段出现修复。"
assert result["analysis"]["bullish_factors"] == ["回调后出现承接"]
assert result["analysis"]["analysis_source"] == "llm"
assert result["analysis"]["analysis_model_label"] == "DASHSCOPE:qwen-max"
assert result["news_count"] == 1