363 lines
12 KiB
Python
363 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Lightweight news enrichment for explain-oriented market analysis."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
from typing import Any
|
|
|
|
from backend.config.env_config import get_env_int
|
|
from backend.enrich.llm_enricher import (
|
|
analyze_news_row_with_llm,
|
|
analyze_news_rows_with_llm,
|
|
llm_enrichment_enabled,
|
|
)
|
|
from backend.data.market_store import MarketStore
|
|
|
|
|
|
# Keyword vocabularies for the local (non-LLM) fallback in classify_news_row,
# which looks for them in an article's lowercased "title summary" text.
# Keep every entry lowercase so it can match that text.

# Vocabulary that counts toward a "positive" sentiment verdict.
POSITIVE_KEYWORDS = (
    "beat",
    "surge",
    "gain",
    "growth",
    "record",
    "upgrade",
    "strong",
    "partnership",
    "approved",
    "launch",
    "expands",
    "profit",
)

# Vocabulary that counts toward a "negative" sentiment verdict.
NEGATIVE_KEYWORDS = (
    "miss",
    "drop",
    "fall",
    "cut",
    "downgrade",
    "weak",
    "warning",
    "delay",
    "lawsuit",
    "probe",
    "tariff",
    "decline",
    "layoff",
)

# Topics whose presence marks an article as "high" relevance.
HIGH_RELEVANCE_KEYWORDS = (
    "earnings",
    "guidance",
    "profit",
    "revenue",
    "ceo",
    "fda",
    "tariff",
    "regulation",
    "acquisition",
    "buyback",
    "forecast",
    "launch",
)
|
|
|
|
|
|
def _dedupe_key(row: dict[str, Any]) -> str:
    """Return a content fingerprint used to collapse duplicate articles.

    The fingerprint is the SHA-1 hex digest of
    ``"<day>::<title>::<summary>"``, where ``day`` is the first 10 characters
    of ``trade_date`` (falling back to ``date``), ``title`` is the stripped,
    lowercased title and ``summary`` is the first 160 characters of the
    stripped, lowercased summary. SHA-1 is used only as a compact key here,
    not for any security purpose.
    """

    def _normalize(value: Any) -> str:
        # Missing/None/empty values all collapse to "".
        return str(value or "").strip().lower()

    day = str(row.get("trade_date") or row.get("date") or "")[:10]
    parts = (day, _normalize(row.get("title")), _normalize(row.get("summary"))[:160])
    fingerprint_source = "::".join(parts)
    return hashlib.sha1(fingerprint_source.encode("utf-8")).hexdigest()
|
|
|
|
|
|
def _chunk_rows(rows: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]:
    """Split *rows* into consecutive slices of at most *size* items.

    ``size`` is coerced with ``int()`` and clamped to at least 1, so a zero or
    negative batch size produces one-row chunks rather than an error. An
    empty input yields an empty list.
    """
    step = max(1, int(size))
    chunks: list[list[dict[str, Any]]] = []
    cursor = 0
    while cursor < len(rows):
        chunks.append(rows[cursor:cursor + step])
        cursor += step
    return chunks
|
|
|
|
|
|
def classify_news_row(row: dict[str, Any]) -> dict[str, Any]:
    """Return a lightweight explain-oriented analysis for one article.

    The single-row LLM analyzer is tried first. When it returns a dict, a copy
    of that dict is returned with ``summary`` (first 280 chars of the
    article summary or title) and ``raw_json`` (the input row) filled in if
    absent, and ``analysis_source`` forced to ``"llm"``.

    Otherwise a local keyword heuristic over the title and summary is used,
    returning ``relevance`` ("high"/"medium"/"low"), ``sentiment``
    ("positive"/"negative"/"neutral"), ``key_discussion``, ``summary``,
    ``reason_growth``, ``reason_decrease``, ``analysis_source="local"`` and
    ``raw_json``.
    """
    llm_result = analyze_news_row_with_llm(row)
    if isinstance(llm_result, dict):
        merged = dict(llm_result)
        merged.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
        merged.setdefault("raw_json", row)
        merged["analysis_source"] = "llm"
        return merged

    title = str(row.get("title") or "").strip()
    summary = str(row.get("summary") or "").strip()
    text = f"{title} {summary}".lower()

    # Keywords must start a word. A plain substring test counted "gain"
    # inside "against"/"again" and "cut" inside "executive", skewing the
    # sentiment verdict; inflections such as "gains"/"cuts"/"layoffs" still match.
    positive_hits = _keyword_hits(text, POSITIVE_KEYWORDS)
    negative_hits = _keyword_hits(text, NEGATIVE_KEYWORDS)
    relevance_hits = _keyword_hits(text, HIGH_RELEVANCE_KEYWORDS)

    # Simple majority vote; ties (including no hits at all) are neutral.
    if len(positive_hits) > len(negative_hits):
        sentiment = "positive"
    elif len(negative_hits) > len(positive_hits):
        sentiment = "negative"
    else:
        sentiment = "neutral"

    if relevance_hits:
        relevance = "high"
    elif title:
        relevance = "medium"
    else:
        relevance = "low"

    summary_text = summary or title
    key_discussion = ""
    if relevance_hits:
        # Runtime text is Chinese: "core topics centre on <up to 3 topics>".
        key_discussion = f"核心主题集中在 {', '.join(relevance_hits[:3])}"
    elif summary_text:
        key_discussion = summary_text[:160]

    # Only the field matching the sentiment direction carries an explanation.
    reason_growth = summary_text[:200] if sentiment == "positive" else ""
    reason_decrease = summary_text[:200] if sentiment == "negative" else ""

    return {
        "relevance": relevance,
        "sentiment": sentiment,
        "key_discussion": key_discussion,
        "summary": summary_text[:280],
        "reason_growth": reason_growth,
        "reason_decrease": reason_decrease,
        "analysis_source": "local",
        "raw_json": row,
    }


def _keyword_hits(text: str, keywords: tuple[str, ...]) -> list[str]:
    """Return the *keywords* that begin a word in *text*, in keyword order."""
    import re

    # ``\b`` only on the left: reject mid-word matches ("against" for "gain")
    # while still accepting suffixed forms ("gains", "surged").
    return [keyword for keyword in keywords if re.search(r"\b" + re.escape(keyword), text)]
|
|
|
|
|
|
def attach_forward_returns(
    *,
    news_rows: list[dict[str, Any]],
    ohlc_rows: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Attach forward-return labels to each analyzed row.

    For each news row whose ``trade_date`` (first 10 characters) has a
    non-zero close in *ohlc_rows*, a copy of the row is emitted with
    ``ret_t0``, ``ret_t1``, ``ret_t3``, ``ret_t5`` and ``ret_t10`` set to
    ``(close[t+k] - close[t]) / close[t]``, where ``k`` counts OHLC bars
    (trading days), not calendar days. A horizon is ``None`` when it runs
    past the available bars or its close is missing/zero.

    Rows without a usable base close are passed through unchanged (same
    object, no ``ret_*`` keys). If *ohlc_rows* is empty, *news_rows* itself
    is returned.

    Args:
        news_rows: analysis rows carrying a ``trade_date``.
        ohlc_rows: price bars with ``date`` and ``close``. They may arrive in
            any order and their dates may carry a time suffix; dates are
            assumed to be ISO ``YYYY-MM-DD`` so that string order is
            chronological.
    """
    if not ohlc_rows:
        return news_rows

    # Normalize bar dates exactly like trade_date ([:10]) so keys such as
    # "2024-01-02 00:00:00" still line up with "2024-01-02".
    closes_by_date: dict[str, float] = {}
    for bar in ohlc_rows:
        if bar.get("date") is not None and bar.get("close") is not None:
            closes_by_date[str(bar.get("date"))[:10]] = float(bar.get("close"))

    # Ascending, duplicate-free trading calendar: offset k then always means
    # "k bars later" regardless of the order the store returned them in.
    ordered_dates = sorted(
        {str(bar.get("date"))[:10] for bar in ohlc_rows if bar.get("date") is not None}
    )
    date_index = {day: idx for idx, day in enumerate(ordered_dates)}

    # NOTE(review): ret_t0 compares the base close with itself and is
    # therefore always 0.0; confirm whether a same-day return was intended.
    horizons = {
        "ret_t0": 0,
        "ret_t1": 1,
        "ret_t3": 3,
        "ret_t5": 5,
        "ret_t10": 10,
    }

    enriched: list[dict[str, Any]] = []
    for row in news_rows:
        trade_date = str(row.get("trade_date") or "")[:10]
        base_close = closes_by_date.get(trade_date)
        base_index = date_index.get(trade_date)
        if not trade_date or base_close in (None, 0) or base_index is None:
            enriched.append(row)
            continue

        labelled = dict(row)
        for field, offset in horizons.items():
            target_index = base_index + offset
            target_close = (
                closes_by_date.get(ordered_dates[target_index])
                if target_index < len(ordered_dates)
                else None
            )
            labelled[field] = (
                (target_close - base_close) / base_close
                if target_close not in (None, 0)
                else None
            )
        enriched.append(labelled)
    return enriched
|
|
|
|
|
|
def build_analysis_rows(
    *,
    symbol: str,
    news_rows: list[dict[str, Any]],
    ohlc_rows: list[dict[str, Any]],
) -> tuple[list[dict[str, Any]], dict[str, int]]:
    """Transform raw news rows into market_store news_analysis payloads plus stats.

    Pipeline:
      1. Drop rows without an ``id`` and collapse content duplicates (same
         ``_dedupe_key``), keeping the first occurrence.
      2. If LLM enrichment is enabled, send only the surviving rows to the
         batch LLM analyzer in chunks of ``EXPLAIN_ENRICH_BATCH_SIZE``
         (default 8), so duplicates no longer cost LLM calls.
      3. Rows without a batch result fall back to ``classify_news_row``.
      4. Attach forward-return labels derived from *ohlc_rows*.

    ``symbol`` is accepted for interface compatibility but not used here.

    Returns:
        ``(rows, stats)`` where each row has ``news_id``, ``trade_date`` and
        the analysis fields, and ``stats`` holds ``deduped_count``,
        ``llm_count`` and ``local_count``.
    """
    unique_rows, deduped_count = _unique_news_rows(news_rows)
    llm_results = _batch_llm_results(unique_rows) if unique_rows else {}

    staged_rows: list[dict[str, Any]] = []
    llm_count = 0
    local_count = 0
    for news_id, row in unique_rows:
        batch_result = llm_results.get(news_id)
        if isinstance(batch_result, dict):
            analysis = dict(batch_result)
            analysis.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
            analysis.setdefault("raw_json", row)
            analysis["analysis_source"] = "llm"
        else:
            analysis = classify_news_row(row)

        if analysis.get("analysis_source") == "llm":
            llm_count += 1
        else:
            local_count += 1

        staged_rows.append(
            {
                "news_id": news_id,
                "trade_date": str(row.get("trade_date") or "")[:10] or None,
                **analysis,
            }
        )

    return (
        attach_forward_returns(news_rows=staged_rows, ohlc_rows=ohlc_rows),
        {
            "deduped_count": deduped_count,
            "llm_count": llm_count,
            "local_count": local_count,
        },
    )


def _unique_news_rows(news_rows: list[dict[str, Any]]) -> tuple[list[tuple[str, dict[str, Any]]], int]:
    """Return ``(news_id, row)`` pairs for id-bearing, first-seen rows plus the duplicate count."""
    unique: list[tuple[str, dict[str, Any]]] = []
    seen_keys: set[str] = set()
    duplicates = 0
    for row in news_rows:
        news_id = str(row.get("id") or "").strip()
        if not news_id:
            # Rows without an id cannot be persisted; they are not counted as duplicates.
            continue
        key = _dedupe_key(row)
        if key in seen_keys:
            duplicates += 1
            continue
        seen_keys.add(key)
        unique.append((news_id, row))
    return unique, duplicates


def _batch_llm_results(unique_rows: list[tuple[str, dict[str, Any]]]) -> dict[str, dict[str, Any]]:
    """Run the batch LLM analyzer over *unique_rows*; return its merged results ({} when disabled)."""
    results: dict[str, dict[str, Any]] = {}
    if not llm_enrichment_enabled():
        return results
    rows = [row for _, row in unique_rows]
    batch_size = get_env_int("EXPLAIN_ENRICH_BATCH_SIZE", 8)
    for chunk in _chunk_rows(rows, batch_size):
        chunk_results = analyze_news_rows_with_llm(chunk)
        # Mirror classify_news_row's isinstance(dict) check: a non-dict result
        # (e.g. None) leaves those rows to the per-row fallback instead of
        # crashing dict.update().
        if isinstance(chunk_results, dict):
            results.update(chunk_results)
    return results
|
|
|
|
|
|
def enrich_news_for_symbol(
    store: MarketStore,
    symbol: str,
    *,
    start_date: str | None = None,
    end_date: str | None = None,
    limit: int = 200,
    analysis_source: str = "local",
    skip_existing: bool = True,
    only_reanalyze_local: bool = False,
) -> dict[str, Any]:
    """Read raw market news, compute explain fields, and persist them.

    Flow:
      * Load up to *limit* news items via ``store.get_news_items``.
      * With *only_reanalyze_local*, keep only rows whose stored analysis
        source is ``"local"`` (per ``store.get_analyzed_news_sources``).
        Otherwise, with *skip_existing*, drop rows whose id is already in
        ``store.get_analyzed_news_ids``.
      * Fetch OHLC bars for the queued rows and build analysis rows with
        ``build_analysis_rows``.
      * Persist them with ``store.upsert_news_analysis`` (forwarding
        *analysis_source*).

    Returns ``{"symbol": "", "analyzed": 0}`` for a blank symbol. Otherwise it
    returns a stats dict with the symbol, upsert result, counts and an
    ``execution_summary``.
    """
    normalized_symbol = str(symbol or "").strip().upper()
    if not normalized_symbol:
        return {"symbol": "", "analyzed": 0}

    news_rows = store.get_news_items(
        normalized_symbol,
        start_date=start_date,
        end_date=end_date,
        limit=limit,
    )
    total_news_count = len(news_rows)
    skipped_existing_count = 0
    skipped_missing_analysis_count = 0
    skipped_non_local_count = 0

    if news_rows and only_reanalyze_local:
        analyzed_sources = store.get_analyzed_news_sources(
            normalized_symbol,
            start_date=start_date,
            end_date=end_date,
        )
        local_rows: list[dict[str, Any]] = []
        for row in news_rows:
            news_id = _news_row_id(row)
            if news_id not in analyzed_sources:
                skipped_missing_analysis_count += 1
            elif analyzed_sources.get(news_id) != "local":
                skipped_non_local_count += 1
            else:
                local_rows.append(row)
        # Everything not re-queued counts as "existing" for reporting.
        skipped_existing_count = skipped_missing_analysis_count + skipped_non_local_count
        news_rows = local_rows
    elif skip_existing and news_rows:
        analyzed_ids = store.get_analyzed_news_ids(
            normalized_symbol,
            start_date=start_date,
            end_date=end_date,
        )
        fresh_rows = [row for row in news_rows if _news_row_id(row) not in analyzed_ids]
        skipped_existing_count = len(news_rows) - len(fresh_rows)
        news_rows = fresh_rows

    ohlc_start, ohlc_end = _ohlc_window(news_rows, start_date, end_date)
    # With nothing queued, forward returns are never computed, so skip the fetch.
    ohlc_rows = (
        store.get_ohlc(normalized_symbol, ohlc_start, ohlc_end)
        if news_rows and ohlc_start and ohlc_end
        else []
    )
    analysis_rows, stats = build_analysis_rows(
        symbol=normalized_symbol,
        news_rows=news_rows,
        ohlc_rows=ohlc_rows,
    )
    analyzed = store.upsert_news_analysis(
        normalized_symbol,
        analysis_rows,
        analysis_source=analysis_source,
    )

    upgraded_dates = sorted(
        {
            str(row.get("trade_date") or "")[:10]
            for row in analysis_rows
            if str(row.get("analysis_source") or "").strip().lower() == "llm"
            and str(row.get("trade_date") or "").strip()
        }
    )
    # Set lookup instead of the former rows x analyses nested scan.
    local_ids = {
        str(item.get("news_id") or "").strip()
        for item in analysis_rows
        if str(item.get("analysis_source") or "").strip().lower() == "local"
    }
    remaining_local_titles = [
        # Raw news rows carry their identifier in "id"; use it as the last fallback label.
        str(row.get("title") or row.get("news_id") or row.get("id") or "").strip()
        for row in news_rows
        if _news_row_id(row) in local_ids
    ][:5]

    return {
        "symbol": normalized_symbol,
        "analyzed": analyzed,
        "news_count": total_news_count,
        "queued_count": len(news_rows),
        "skipped_existing_count": skipped_existing_count,
        "deduped_count": stats["deduped_count"],
        "llm_count": stats["llm_count"],
        "local_count": stats["local_count"],
        "only_reanalyze_local": only_reanalyze_local,
        "upgraded_local_to_llm_count": (
            stats["llm_count"]
            if only_reanalyze_local
            else 0
        ),
        "execution_summary": {
            "upgraded_dates": upgraded_dates[:5],
            "remaining_local_titles": remaining_local_titles,
            "skipped_missing_analysis_count": skipped_missing_analysis_count,
            "skipped_non_local_count": skipped_non_local_count,
        },
    }


def _news_row_id(row: dict[str, Any]) -> str:
    """Return the stripped ``id`` of a raw news row ("" when missing)."""
    return str(row.get("id") or "").strip()


def _ohlc_window(
    news_rows: list[dict[str, Any]],
    start_date: str | None,
    end_date: str | None,
    *,
    buffer_days: int = 30,
) -> tuple[str | None, str | None]:
    """Return the ``(start, end)`` OHLC range needed to label *news_rows*.

    Explicit bounds win; otherwise the earliest/latest news ``trade_date``
    (first 10 characters) is used, so the order of *news_rows* does not
    matter. The end is pushed ``buffer_days`` calendar days forward so that
    forward returns up to t+10 trading days exist for the newest articles.
    A non-ISO end bound is kept as-is.
    """
    from datetime import date, timedelta

    trade_dates = [str(row.get("trade_date"))[:10] for row in news_rows if row.get("trade_date")]
    start = start_date or (min(trade_dates) if trade_dates else None)
    end = end_date or (max(trade_dates) if trade_dates else None)
    if end:
        try:
            end = (date.fromisoformat(str(end)[:10]) + timedelta(days=buffer_days)).isoformat()
        except ValueError:
            # Deliberate best-effort: an unparseable bound is passed through unchanged.
            pass
    return start, end
|
|
|
|
|
|
def enrich_symbols(
    store: MarketStore,
    symbols: list[str],
    *,
    start_date: str | None = None,
    end_date: str | None = None,
    limit: int = 200,
    analysis_source: str = "local",
    skip_existing: bool = True,
    only_reanalyze_local: bool = False,
) -> list[dict[str, Any]]:
    """Batch enrich multiple symbols for explain-oriented news analysis.

    Each symbol is stripped and upper-cased, and blank/None entries are
    skipped. The rest are handed to ``enrich_news_for_symbol`` with the same
    shared options, one call per symbol in input order.

    Returns:
        One result dict per processed symbol.
    """
    shared_options = {
        "start_date": start_date,
        "end_date": end_date,
        "limit": limit,
        "analysis_source": analysis_source,
        "skip_existing": skip_existing,
        "only_reanalyze_local": only_reanalyze_local,
    }
    cleaned_symbols = (str(raw or "").strip().upper() for raw in symbols)
    return [
        enrich_news_for_symbol(store, ticker, **shared_options)
        for ticker in cleaned_symbols
        if ticker
    ]
|