Add explain analysis workflow and UI
This commit is contained in:
214
backend/explain/range_explainer.py
Normal file
214
backend/explain/range_explainer.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Local range explanation built from price and persisted news."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from backend.enrich.llm_enricher import analyze_range_with_llm
|
||||
from backend.explain.category_engine import categorize_news_rows
|
||||
from backend.tools.data_tools import get_prices
|
||||
|
||||
|
||||
def _rank_event_score(row: Dict[str, Any]) -> float:
|
||||
relevance = str(row.get("relevance") or "").strip().lower()
|
||||
relevance_score = {"high": 3.0, "relevant": 3.0, "medium": 2.0, "low": 1.0}.get(
|
||||
relevance,
|
||||
0.5,
|
||||
)
|
||||
impact_score = abs(float(row.get("ret_t0") or 0.0)) * 100
|
||||
return relevance_score + impact_score
|
||||
|
||||
|
||||
def summarize_bullish_factors(
    news_rows: list[Dict[str, Any]],
    *,
    limit: int = 5,
) -> list[str]:
    """Collect up to ``limit`` distinct bullish talking points.

    Only rows whose sentiment is "positive" (case-insensitive) contribute.
    For each such row the first non-empty field among ``reason_growth``,
    ``key_discussion``, ``summary`` and ``title`` is used. Duplicates are
    dropped (first occurrence wins) and each factor is truncated to 200
    characters.
    """
    picked: list[str] = []
    for row in news_rows:
        sentiment = str(row.get("sentiment") or "").strip().lower()
        if sentiment != "positive":
            continue
        text = (
            row.get("reason_growth")
            or row.get("key_discussion")
            or row.get("summary")
            or row.get("title")
        )
        if text:
            picked.append(str(text).strip())
    # dict.fromkeys is an order-preserving dedup; cap count, then length.
    deduped = list(dict.fromkeys(picked))
    return [factor[:200] for factor in deduped[:limit]]
|
||||
|
||||
|
||||
def summarize_bearish_factors(
    news_rows: list[Dict[str, Any]],
    *,
    limit: int = 5,
) -> list[str]:
    """Collect up to ``limit`` distinct bearish talking points.

    Only rows whose sentiment is "negative" (case-insensitive) contribute.
    For each such row the first non-empty field among ``reason_decrease``,
    ``key_discussion``, ``summary`` and ``title`` is used. Duplicates are
    dropped (first occurrence wins) and each factor is truncated to 200
    characters.
    """
    picked: list[str] = []
    for row in news_rows:
        sentiment = str(row.get("sentiment") or "").strip().lower()
        if sentiment != "negative":
            continue
        text = (
            row.get("reason_decrease")
            or row.get("key_discussion")
            or row.get("summary")
            or row.get("title")
        )
        if text:
            picked.append(str(text).strip())
    # dict.fromkeys is an order-preserving dedup; cap count, then length.
    deduped = list(dict.fromkeys(picked))
    return [factor[:200] for factor in deduped[:limit]]
|
||||
|
||||
|
||||
def build_trend_analysis(prices: list[Any]) -> str:
    """Describe the in-range trend by comparing first- and second-half moves.

    ``prices`` are OHLC bars exposing ``open``/``close`` attributes. With
    fewer than two bars no trend can be stated; with exactly two only the
    overall change is reported; otherwise the range is split at the midpoint
    and each half's open-to-close change is narrated.
    """
    count = len(prices)
    if count < 2:
        return "区间样本较短,暂不具备足够趋势信息。"
    if count < 3:
        start = float(prices[0].open)
        end = float(prices[-1].close)
        pct = ((end - start) / start) * 100 if start else 0.0
        return f"短区间内价格变动 {pct:+.2f}%,趋势信息有限。"

    mid = count // 2
    # The midpoint bar contributes its close to the first half and its open
    # to the second half; a zero open guards against division by zero.
    spans = (
        (float(prices[0].open), float(prices[mid].close)),
        (float(prices[mid].open), float(prices[-1].close)),
    )
    first_half, second_half = (
        ((close_px - open_px) / open_px) * 100 if open_px else 0.0
        for open_px, close_px in spans
    )
    return (
        f"前半段{'上涨' if first_half >= 0 else '下跌'} {abs(first_half):.2f}%,"
        f"后半段{'上涨' if second_half >= 0 else '下跌'} {abs(second_half):.2f}%,"
        "说明价格驱动在区间内部出现了阶段性切换。"
    )
|
||||
|
||||
|
||||
def build_range_explanation(
    *,
    ticker: str,
    start_date: str,
    end_date: str,
    news_rows: list[Dict[str, Any]],
) -> Dict[str, Any]:
    """Explain a price range with local price and news heuristics.

    Fetches OHLC bars for ``[start_date, end_date]``, aggregates range stats,
    clusters the persisted ``news_rows`` into categories, derives bullish and
    bearish factors plus a trend narrative locally, then hands a compact
    payload to the LLM enricher; when the enricher returns a dict its fields
    override the local summary/trend/factors and the result is labeled "llm".

    Returns an error payload (``{"error": ...}``) when no OHLC data exists
    for the range.
    """
    prices = get_prices(ticker, start_date, end_date)
    if not prices:
        return {
            "symbol": ticker,
            "start_date": start_date,
            "end_date": end_date,
            "error": "No OHLC data for this range",
        }

    open_price = float(prices[0].open)
    close_price = float(prices[-1].close)
    high_price = max(float(price.high) for price in prices)
    low_price = min(float(price.low) for price in prices)
    total_volume = sum(int(price.volume) for price in prices)
    # Guard against a zero open price to avoid division by zero.
    price_change_pct = (
        ((close_price - open_price) / open_price) * 100 if open_price else 0.0
    )

    categories = categorize_news_rows(news_rows)
    news_count = len(news_rows)
    dominant_categories = sorted(
        (
            {"category": key, "count": value["count"]}
            for key, value in categories.items()
            if value["count"] > 0
        ),
        key=lambda item: item["count"],
        reverse=True,
    )

    direction = "上涨" if price_change_pct > 0 else "下跌" if price_change_pct < 0 else "横盘"
    category_text = (
        f"主要主题集中在 {', '.join(item['category'] for item in dominant_categories[:3])}。"
        if dominant_categories
        else "区间内未识别出明显的主题聚类。"
    )
    summary = (
        f"{ticker} 在 {start_date} 至 {end_date} 区间内{direction} {abs(price_change_pct):.2f}%,"
        f"区间覆盖 {len(prices)} 个交易日,关联新闻 {news_count} 条。{category_text}"
    )

    bullish_factors = summarize_bullish_factors(news_rows)
    bearish_factors = summarize_bearish_factors(news_rows)
    trend_analysis = build_trend_analysis(prices)
    llm_source = "local"

    # Rank once; both `top_news` (the LLM payload) and `key_events` reuse
    # this order instead of re-sorting the rows a second time.
    ranked_rows = sorted(news_rows, key=_rank_event_score, reverse=True)

    def _event_date(row: Dict[str, Any]) -> str:
        # Prefer the aligned trade date; fall back to the raw date's
        # YYYY-MM-DD prefix.
        return row.get("trade_date") or str(row.get("date") or "")[:10]

    range_payload = {
        "ticker": ticker,
        "start_date": start_date,
        "end_date": end_date,
        "price_change_pct": round(price_change_pct, 2),
        "trading_days": len(prices),
        "news_count": news_count,
        "dominant_categories": dominant_categories[:5],
        "bullish_factors": bullish_factors[:3],
        "bearish_factors": bearish_factors[:3],
        "trend_analysis": trend_analysis,
        "top_news": [
            {
                "date": _event_date(row),
                "title": row.get("title") or "",
                "summary": row.get("summary") or "",
                "sentiment": row.get("sentiment") or "",
                "relevance": row.get("relevance") or "",
                "ret_t0": row.get("ret_t0"),
            }
            for row in ranked_rows[:5]
        ],
    }
    llm_analysis = analyze_range_with_llm(range_payload)
    if isinstance(llm_analysis, dict):
        # Any falsy LLM field keeps the locally computed fallback.
        summary = llm_analysis.get("summary") or summary
        trend_analysis = llm_analysis.get("trend_analysis") or trend_analysis
        bullish_factors = llm_analysis.get("bullish_factors") or bullish_factors
        bearish_factors = llm_analysis.get("bearish_factors") or bearish_factors
        llm_source = "llm"

    key_events = [
        {
            "date": _event_date(row),
            "title": row.get("title") or "Untitled news",
            "summary": row.get("summary") or "",
            "category": row.get("category") or "",
            "id": row.get("id"),
            "sentiment": row.get("sentiment"),
            "ret_t0": row.get("ret_t0"),
        }
        for row in ranked_rows[:8]
    ]

    return {
        "symbol": ticker,
        "start_date": start_date,
        "end_date": end_date,
        "price_change_pct": round(price_change_pct, 2),
        "open_price": open_price,
        "close_price": close_price,
        "high_price": high_price,
        "low_price": low_price,
        "total_volume": total_volume,
        "trading_days": len(prices),
        "news_count": news_count,
        "dominant_categories": dominant_categories[:5],
        "analysis": {
            "summary": summary,
            "key_events": key_events,
            "bullish_factors": bullish_factors,
            "bearish_factors": bearish_factors,
            "trend_analysis": trend_analysis,
            "analysis_source": llm_source,
            "analysis_model_label": llm_analysis.get("model_label") if isinstance(llm_analysis, dict) else None,
        },
    }
|
||||
Reference in New Issue
Block a user