Add explain analysis workflow and UI
This commit is contained in:
362
backend/enrich/news_enricher.py
Normal file
362
backend/enrich/news_enricher.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Lightweight news enrichment for explain-oriented market analysis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
from backend.config.env_config import get_env_int
|
||||
from backend.enrich.llm_enricher import (
|
||||
analyze_news_row_with_llm,
|
||||
analyze_news_rows_with_llm,
|
||||
llm_enrichment_enabled,
|
||||
)
|
||||
from backend.data.market_store import MarketStore
|
||||
|
||||
|
||||
POSITIVE_KEYWORDS = (
|
||||
"beat", "surge", "gain", "growth", "record", "upgrade", "strong",
|
||||
"partnership", "approved", "launch", "expands", "profit",
|
||||
)
|
||||
NEGATIVE_KEYWORDS = (
|
||||
"miss", "drop", "fall", "cut", "downgrade", "weak", "warning",
|
||||
"delay", "lawsuit", "probe", "tariff", "decline", "layoff",
|
||||
)
|
||||
HIGH_RELEVANCE_KEYWORDS = (
|
||||
"earnings", "guidance", "profit", "revenue", "ceo", "fda", "tariff",
|
||||
"regulation", "acquisition", "buyback", "forecast", "launch",
|
||||
)
|
||||
|
||||
|
||||
def _dedupe_key(row: dict[str, Any]) -> str:
|
||||
trade_date = str(row.get("trade_date") or row.get("date") or "")[:10]
|
||||
title = str(row.get("title") or "").strip().lower()
|
||||
summary = str(row.get("summary") or "").strip().lower()[:160]
|
||||
raw = f"{trade_date}::{title}::{summary}"
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _chunk_rows(rows: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]:
|
||||
chunk_size = max(1, int(size))
|
||||
return [rows[index:index + chunk_size] for index in range(0, len(rows), chunk_size)]
|
||||
|
||||
|
||||
def classify_news_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return a lightweight explain-oriented analysis for one article."""
|
||||
llm_result = analyze_news_row_with_llm(row)
|
||||
if isinstance(llm_result, dict):
|
||||
merged = dict(llm_result)
|
||||
merged.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
|
||||
merged.setdefault("raw_json", row)
|
||||
merged["analysis_source"] = "llm"
|
||||
return merged
|
||||
|
||||
title = str(row.get("title") or "").strip()
|
||||
summary = str(row.get("summary") or "").strip()
|
||||
text = f"{title} {summary}".lower()
|
||||
|
||||
positive_hits = [keyword for keyword in POSITIVE_KEYWORDS if keyword in text]
|
||||
negative_hits = [keyword for keyword in NEGATIVE_KEYWORDS if keyword in text]
|
||||
relevance_hits = [keyword for keyword in HIGH_RELEVANCE_KEYWORDS if keyword in text]
|
||||
|
||||
if len(positive_hits) > len(negative_hits):
|
||||
sentiment = "positive"
|
||||
elif len(negative_hits) > len(positive_hits):
|
||||
sentiment = "negative"
|
||||
else:
|
||||
sentiment = "neutral"
|
||||
|
||||
relevance = "high" if relevance_hits else "medium" if title else "low"
|
||||
summary_text = summary or title
|
||||
key_discussion = ""
|
||||
if relevance_hits:
|
||||
key_discussion = f"核心主题集中在 {', '.join(relevance_hits[:3])}"
|
||||
elif summary_text:
|
||||
key_discussion = summary_text[:160]
|
||||
|
||||
reason_growth = ""
|
||||
reason_decrease = ""
|
||||
if sentiment == "positive":
|
||||
reason_growth = summary_text[:200]
|
||||
elif sentiment == "negative":
|
||||
reason_decrease = summary_text[:200]
|
||||
|
||||
return {
|
||||
"relevance": relevance,
|
||||
"sentiment": sentiment,
|
||||
"key_discussion": key_discussion,
|
||||
"summary": summary_text[:280],
|
||||
"reason_growth": reason_growth,
|
||||
"reason_decrease": reason_decrease,
|
||||
"analysis_source": "local",
|
||||
"raw_json": row,
|
||||
}
|
||||
|
||||
|
||||
def attach_forward_returns(
|
||||
*,
|
||||
news_rows: list[dict[str, Any]],
|
||||
ohlc_rows: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Attach forward-return labels to each analyzed row."""
|
||||
if not ohlc_rows:
|
||||
return news_rows
|
||||
|
||||
closes_by_date = {
|
||||
str(row.get("date")): float(row.get("close"))
|
||||
for row in ohlc_rows
|
||||
if row.get("date") is not None and row.get("close") is not None
|
||||
}
|
||||
ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date") is not None]
|
||||
date_index = {date: idx for idx, date in enumerate(ordered_dates)}
|
||||
|
||||
horizons = {
|
||||
"ret_t0": 0,
|
||||
"ret_t1": 1,
|
||||
"ret_t3": 3,
|
||||
"ret_t5": 5,
|
||||
"ret_t10": 10,
|
||||
}
|
||||
|
||||
enriched: list[dict[str, Any]] = []
|
||||
for row in news_rows:
|
||||
trade_date = str(row.get("trade_date") or "")[:10]
|
||||
base_close = closes_by_date.get(trade_date)
|
||||
if not trade_date or base_close in (None, 0):
|
||||
enriched.append(row)
|
||||
continue
|
||||
|
||||
next_row = dict(row)
|
||||
base_index = date_index.get(trade_date)
|
||||
if base_index is None:
|
||||
enriched.append(next_row)
|
||||
continue
|
||||
|
||||
for field, offset in horizons.items():
|
||||
target_index = base_index + offset
|
||||
if target_index >= len(ordered_dates):
|
||||
next_row[field] = None
|
||||
continue
|
||||
target_close = closes_by_date.get(ordered_dates[target_index])
|
||||
next_row[field] = (
|
||||
(float(target_close) - float(base_close)) / float(base_close)
|
||||
if target_close not in (None, 0)
|
||||
else None
|
||||
)
|
||||
enriched.append(next_row)
|
||||
return enriched
|
||||
|
||||
|
||||
def build_analysis_rows(
|
||||
*,
|
||||
symbol: str,
|
||||
news_rows: list[dict[str, Any]],
|
||||
ohlc_rows: list[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, int]]:
|
||||
"""Transform raw news rows into market_store news_analysis payloads plus stats."""
|
||||
llm_results: dict[str, dict[str, Any]] = {}
|
||||
if llm_enrichment_enabled():
|
||||
batch_size = get_env_int("EXPLAIN_ENRICH_BATCH_SIZE", 8)
|
||||
for chunk in _chunk_rows(news_rows, batch_size):
|
||||
llm_results.update(analyze_news_rows_with_llm(chunk))
|
||||
|
||||
staged_rows: list[dict[str, Any]] = []
|
||||
seen_dedupe_keys: set[str] = set()
|
||||
deduped_count = 0
|
||||
llm_count = 0
|
||||
local_count = 0
|
||||
for row in news_rows:
|
||||
news_id = str(row.get("id") or "").strip()
|
||||
if not news_id:
|
||||
continue
|
||||
dedupe_key = _dedupe_key(row)
|
||||
if dedupe_key in seen_dedupe_keys:
|
||||
deduped_count += 1
|
||||
continue
|
||||
seen_dedupe_keys.add(dedupe_key)
|
||||
batch_result = llm_results.get(news_id)
|
||||
if isinstance(batch_result, dict):
|
||||
analysis = dict(batch_result)
|
||||
analysis.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
|
||||
analysis.setdefault("raw_json", row)
|
||||
analysis["analysis_source"] = "llm"
|
||||
llm_count += 1
|
||||
else:
|
||||
analysis = classify_news_row(row)
|
||||
if analysis.get("analysis_source") == "llm":
|
||||
llm_count += 1
|
||||
else:
|
||||
local_count += 1
|
||||
staged_rows.append(
|
||||
{
|
||||
"news_id": news_id,
|
||||
"trade_date": str(row.get("trade_date") or "")[:10] or None,
|
||||
**analysis,
|
||||
}
|
||||
)
|
||||
return (
|
||||
attach_forward_returns(news_rows=staged_rows, ohlc_rows=ohlc_rows),
|
||||
{
|
||||
"deduped_count": deduped_count,
|
||||
"llm_count": llm_count,
|
||||
"local_count": local_count,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def enrich_news_for_symbol(
|
||||
store: MarketStore,
|
||||
symbol: str,
|
||||
*,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 200,
|
||||
analysis_source: str = "local",
|
||||
skip_existing: bool = True,
|
||||
only_reanalyze_local: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Read raw market news, compute explain fields, and persist them."""
|
||||
normalized_symbol = str(symbol or "").strip().upper()
|
||||
if not normalized_symbol:
|
||||
return {"symbol": "", "analyzed": 0}
|
||||
|
||||
news_rows = store.get_news_items(
|
||||
normalized_symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
total_news_count = len(news_rows)
|
||||
skipped_existing_count = 0
|
||||
analyzed_sources: dict[str, str] = {}
|
||||
skipped_missing_analysis_count = 0
|
||||
skipped_non_local_count = 0
|
||||
if news_rows and only_reanalyze_local:
|
||||
analyzed_sources = store.get_analyzed_news_sources(
|
||||
normalized_symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
skipped_missing_analysis_count = sum(
|
||||
1
|
||||
for row in news_rows
|
||||
if str(row.get("id") or "").strip() not in analyzed_sources
|
||||
)
|
||||
skipped_non_local_count = sum(
|
||||
1
|
||||
for row in news_rows
|
||||
if str(row.get("id") or "").strip() in analyzed_sources
|
||||
and analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
|
||||
)
|
||||
skipped_existing_count = sum(
|
||||
1
|
||||
for row in news_rows
|
||||
if str(row.get("id") or "").strip() not in analyzed_sources
|
||||
or analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
|
||||
)
|
||||
news_rows = [
|
||||
row for row in news_rows
|
||||
if analyzed_sources.get(str(row.get("id") or "").strip()) == "local"
|
||||
]
|
||||
elif skip_existing and news_rows:
|
||||
analyzed_ids = store.get_analyzed_news_ids(
|
||||
normalized_symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
skipped_existing_count = sum(
|
||||
1
|
||||
for row in news_rows
|
||||
if str(row.get("id") or "").strip() in analyzed_ids
|
||||
)
|
||||
news_rows = [
|
||||
row for row in news_rows
|
||||
if str(row.get("id") or "").strip() not in analyzed_ids
|
||||
]
|
||||
ohlc_start = start_date or (news_rows[-1]["trade_date"] if news_rows and news_rows[-1].get("trade_date") else None)
|
||||
ohlc_end = end_date or (news_rows[0]["trade_date"] if news_rows and news_rows[0].get("trade_date") else None)
|
||||
ohlc_rows = (
|
||||
store.get_ohlc(normalized_symbol, ohlc_start, ohlc_end)
|
||||
if ohlc_start and ohlc_end
|
||||
else []
|
||||
)
|
||||
analysis_rows, stats = build_analysis_rows(
|
||||
symbol=normalized_symbol,
|
||||
news_rows=news_rows,
|
||||
ohlc_rows=ohlc_rows,
|
||||
)
|
||||
analyzed = store.upsert_news_analysis(
|
||||
normalized_symbol,
|
||||
analysis_rows,
|
||||
analysis_source=analysis_source,
|
||||
)
|
||||
upgraded_dates = sorted(
|
||||
{
|
||||
str(row.get("trade_date") or "")[:10]
|
||||
for row in analysis_rows
|
||||
if str(row.get("analysis_source") or "").strip().lower() == "llm"
|
||||
and str(row.get("trade_date") or "").strip()
|
||||
}
|
||||
)
|
||||
remaining_local_titles = [
|
||||
str(row.get("title") or row.get("news_id") or "").strip()
|
||||
for row in news_rows
|
||||
for analyzed_row in analysis_rows
|
||||
if str(analyzed_row.get("news_id") or "").strip() == str(row.get("id") or "").strip()
|
||||
and str(analyzed_row.get("analysis_source") or "").strip().lower() == "local"
|
||||
][:5]
|
||||
return {
|
||||
"symbol": normalized_symbol,
|
||||
"analyzed": analyzed,
|
||||
"news_count": total_news_count,
|
||||
"queued_count": len(news_rows),
|
||||
"skipped_existing_count": skipped_existing_count,
|
||||
"deduped_count": stats["deduped_count"],
|
||||
"llm_count": stats["llm_count"],
|
||||
"local_count": stats["local_count"],
|
||||
"only_reanalyze_local": only_reanalyze_local,
|
||||
"upgraded_local_to_llm_count": (
|
||||
stats["llm_count"]
|
||||
if only_reanalyze_local
|
||||
else 0
|
||||
),
|
||||
"execution_summary": {
|
||||
"upgraded_dates": upgraded_dates[:5],
|
||||
"remaining_local_titles": remaining_local_titles,
|
||||
"skipped_missing_analysis_count": skipped_missing_analysis_count,
|
||||
"skipped_non_local_count": skipped_non_local_count,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def enrich_symbols(
|
||||
store: MarketStore,
|
||||
symbols: list[str],
|
||||
*,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 200,
|
||||
analysis_source: str = "local",
|
||||
skip_existing: bool = True,
|
||||
only_reanalyze_local: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Batch enrich multiple symbols for explain-oriented news analysis."""
|
||||
results = []
|
||||
for symbol in symbols:
|
||||
normalized_symbol = str(symbol or "").strip().upper()
|
||||
if not normalized_symbol:
|
||||
continue
|
||||
results.append(
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
normalized_symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
analysis_source=analysis_source,
|
||||
skip_existing=skip_existing,
|
||||
only_reanalyze_local=only_reanalyze_local,
|
||||
)
|
||||
)
|
||||
return results
|
||||
Reference in New Issue
Block a user