Files
evotraders/backend/explain/range_explainer.py

215 lines
7.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""Local range explanation built from price and persisted news."""
from __future__ import annotations

import math
from typing import Any, Dict

from backend.enrich.llm_enricher import analyze_range_with_llm
from backend.explain.category_engine import categorize_news_rows
from backend.tools.data_tools import get_prices
def _rank_event_score(row: Dict[str, Any]) -> float:
    """Score a news row for ranking: relevance weight plus ``ret_t0`` magnitude.

    Args:
        row: News record. Reads ``relevance`` (a label, compared
            case-insensitively after stripping whitespace) and ``ret_t0``
            (presumably the event-day return as a fraction -- confirm
            against the producer of these rows).

    Returns:
        The relevance weight (3.0 for "high"/"relevant", 2.0 for "medium",
        1.0 for "low", 0.5 for anything else) plus ``abs(ret_t0) * 100``.
        A missing, non-numeric or non-finite ``ret_t0`` contributes 0.
    """
    label = str(row.get("relevance") or "").strip().lower()
    relevance_score = {"high": 3.0, "relevant": 3.0, "medium": 2.0, "low": 1.0}.get(
        label,
        0.5,
    )
    try:
        move = float(row.get("ret_t0") or 0.0)
    except (TypeError, ValueError):
        # A malformed value on one row must not abort ranking of all rows.
        move = 0.0
    if not math.isfinite(move):
        # NaN is truthy, so it survives the ``or 0.0`` guard above; a NaN sort
        # key would make the ordering produced by ``sorted`` unreliable.
        move = 0.0
    return relevance_score + abs(move) * 100
# Maximum number of characters kept from any single factor text.
_MAX_FACTOR_CHARS = 200


def _summarize_factors(
    news_rows: list[Dict[str, Any]],
    *,
    sentiment: str,
    reason_field: str,
    limit: int,
) -> list[str]:
    """Collect up to ``limit`` distinct factor texts from rows of one sentiment.

    Args:
        news_rows: News records to scan, in priority order.
        sentiment: Lower-case sentiment label a row must carry to be used.
        reason_field: Sentiment-specific field tried before the generic
            ``key_discussion`` / ``summary`` / ``title`` fallbacks.
        limit: Maximum number of factors to return; ``<= 0`` yields ``[]``.

    Returns:
        Stripped factor texts, each truncated to ``_MAX_FACTOR_CHARS``
        characters, de-duplicated, in first-seen order.
    """
    output: list[str] = []
    if limit <= 0:
        return output

    seen: set[str] = set()
    text_fields = (reason_field, "key_discussion", "summary", "title")
    for row in news_rows:
        if str(row.get("sentiment") or "").strip().lower() != sentiment:
            continue

        # Take the first field holding non-blank text, so a whitespace-only
        # value falls through to the next candidate instead of producing "".
        text = ""
        for field_name in text_fields:
            text = str(row.get(field_name) or "").strip()
            if text:
                break

        # De-duplicate on the truncated text, i.e. on what is actually emitted.
        text = text[:_MAX_FACTOR_CHARS]
        if not text or text in seen:
            continue
        seen.add(text)
        output.append(text)
        if len(output) >= limit:
            break
    return output


def summarize_bullish_factors(
    news_rows: list[Dict[str, Any]],
    *,
    limit: int = 5,
) -> list[str]:
    """Return up to ``limit`` distinct factor texts from positive-sentiment rows.

    For each row whose ``sentiment`` is "positive" (case-insensitive), the
    first non-blank value among ``reason_growth``, ``key_discussion``,
    ``summary`` and ``title`` is used, truncated to 200 characters.
    """
    return _summarize_factors(
        news_rows,
        sentiment="positive",
        reason_field="reason_growth",
        limit=limit,
    )


def summarize_bearish_factors(
    news_rows: list[Dict[str, Any]],
    *,
    limit: int = 5,
) -> list[str]:
    """Return up to ``limit`` distinct factor texts from negative-sentiment rows.

    For each row whose ``sentiment`` is "negative" (case-insensitive), the
    first non-blank value among ``reason_decrease``, ``key_discussion``,
    ``summary`` and ``title`` is used, truncated to 200 characters.
    """
    return _summarize_factors(
        news_rows,
        sentiment="negative",
        reason_field="reason_decrease",
        limit=limit,
    )
def build_trend_analysis(prices: list[Any]) -> str:
    """Describe how price moved in the first versus the second half of a range.

    Args:
        prices: Ordered price bars; each must expose ``open`` and ``close``
            attributes convertible to ``float``.

    Returns:
        A short Chinese-language sentence. With fewer than 2 bars a fixed
        "not enough data" message is returned; with exactly 2 bars the overall
        open-to-close change is reported; otherwise the percentage move of
        each half is reported, followed by a conclusion that depends on
        whether the two halves moved in the same direction.
    """
    if len(prices) < 2:
        return "区间样本较短,暂不具备足够趋势信息。"

    def pct_change(start: float, end: float) -> float:
        # A zero starting price would divide by zero; report no change instead.
        return ((end - start) / start) * 100 if start else 0.0

    if len(prices) < 3:
        change = pct_change(float(prices[0].open), float(prices[-1].close))
        return f"短区间内价格变动 {change:+.2f}%,趋势信息有限。"

    # Both halves include the middle bar: the first ends at its close and the
    # second starts at its open.
    mid = len(prices) // 2
    first_half = pct_change(float(prices[0].open), float(prices[mid].close))
    second_half = pct_change(float(prices[mid].open), float(prices[-1].close))
    first_up = first_half >= 0
    second_up = second_half >= 0

    # Only claim a switch in price drivers when the direction actually flipped.
    if first_up != second_up:
        conclusion = "说明价格驱动在区间内部出现了阶段性切换。"
    else:
        conclusion = "说明区间内价格走势方向保持一致。"

    return (
        f"前半段{'上涨' if first_up else '下跌'} {abs(first_half):.2f}%,"
        f"后半段{'上涨' if second_up else '下跌'} {abs(second_half):.2f}%,"
        f"{conclusion}"
    )
def _event_date(row: Dict[str, Any]) -> Any:
    """Return the row's ``trade_date``, else the first 10 characters of ``date``."""
    return row.get("trade_date") or str(row.get("date") or "")[:10]


def build_range_explanation(
    *,
    ticker: str,
    start_date: str,
    end_date: str,
    news_rows: list[Dict[str, Any]],
) -> Dict[str, Any]:
    """Explain a price range with local price and news heuristics.

    Price bars are fetched via ``get_prices``; each bar must expose ``open``,
    ``close``, ``high``, ``low`` and ``volume``. A local summary, trend
    description and bullish/bearish factor lists are built first. They are
    then handed to ``analyze_range_with_llm``; if that returns a dict, any
    non-empty ``summary`` / ``trend_analysis`` / ``bullish_factors`` /
    ``bearish_factors`` it contains replace the local values.

    Args:
        ticker: Symbol to explain.
        start_date: Range start, passed through to ``get_prices``.
        end_date: Range end, passed through to ``get_prices``.
        news_rows: News records associated with the range.

    Returns:
        A dict with price statistics, the dominant news categories and an
        ``analysis`` sub-dict. When no price data is available, a dict
        containing only ``symbol``, ``start_date``, ``end_date`` and ``error``.
    """
    prices = get_prices(ticker, start_date, end_date)
    if not prices:
        return {
            "symbol": ticker,
            "start_date": start_date,
            "end_date": end_date,
            "error": "No OHLC data for this range",
        }

    open_price = float(prices[0].open)
    close_price = float(prices[-1].close)
    high_price = max(float(price.high) for price in prices)
    low_price = min(float(price.low) for price in prices)
    total_volume = sum(int(price.volume) for price in prices)
    # A zero opening price would divide by zero; report a flat range instead.
    price_change_pct = (
        ((close_price - open_price) / open_price) * 100 if open_price else 0.0
    )

    categories = categorize_news_rows(news_rows)
    news_count = len(news_rows)
    dominant_categories = sorted(
        (
            {"category": key, "count": value["count"]}
            for key, value in categories.items()
            if value["count"] > 0
        ),
        key=lambda item: item["count"],
        reverse=True,
    )

    if price_change_pct > 0:
        direction = "上涨"
    elif price_change_pct < 0:
        direction = "下跌"
    else:
        direction = "横盘"

    if dominant_categories:
        top_names = ", ".join(item["category"] for item in dominant_categories[:3])
        category_text = f"主要主题集中在 {top_names}。"
    else:
        category_text = "区间内未识别出明显的主题聚类。"

    # Keep explicit separators between the interpolated fields; without them
    # the ticker and both dates run together into one unreadable token.
    summary = (
        f"{ticker} 在 {start_date} 至 {end_date} 区间内{direction} {abs(price_change_pct):.2f}%,"
        f"区间覆盖 {len(prices)} 个交易日,关联新闻 {news_count} 条。{category_text}"
    )
    bullish_factors = summarize_bullish_factors(news_rows)
    bearish_factors = summarize_bearish_factors(news_rows)
    trend_analysis = build_trend_analysis(prices)

    # Rank once; both the LLM payload and the returned key events use it.
    ranked_rows = sorted(news_rows, key=_rank_event_score, reverse=True)

    range_payload = {
        "ticker": ticker,
        "start_date": start_date,
        "end_date": end_date,
        "price_change_pct": round(price_change_pct, 2),
        "trading_days": len(prices),
        "news_count": news_count,
        "dominant_categories": dominant_categories[:5],
        "bullish_factors": bullish_factors[:3],
        "bearish_factors": bearish_factors[:3],
        "trend_analysis": trend_analysis,
        "top_news": [
            {
                "date": _event_date(row),
                "title": row.get("title") or "",
                "summary": row.get("summary") or "",
                "sentiment": row.get("sentiment") or "",
                "relevance": row.get("relevance") or "",
                "ret_t0": row.get("ret_t0"),
            }
            for row in ranked_rows[:5]
        ],
    }

    llm_source = "local"
    model_label = None
    llm_analysis = analyze_range_with_llm(range_payload)
    if isinstance(llm_analysis, dict):
        # Each LLM field overrides the local heuristic only when non-empty.
        summary = llm_analysis.get("summary") or summary
        trend_analysis = llm_analysis.get("trend_analysis") or trend_analysis
        bullish_factors = llm_analysis.get("bullish_factors") or bullish_factors
        bearish_factors = llm_analysis.get("bearish_factors") or bearish_factors
        model_label = llm_analysis.get("model_label")
        llm_source = "llm"

    key_events = [
        {
            "date": _event_date(row),
            "title": row.get("title") or "Untitled news",
            "summary": row.get("summary") or "",
            "category": row.get("category") or "",
            "id": row.get("id"),
            "sentiment": row.get("sentiment"),
            "ret_t0": row.get("ret_t0"),
        }
        for row in ranked_rows[:8]
    ]

    return {
        "symbol": ticker,
        "start_date": start_date,
        "end_date": end_date,
        "price_change_pct": round(price_change_pct, 2),
        "open_price": open_price,
        "close_price": close_price,
        "high_price": high_price,
        "low_price": low_price,
        "total_volume": total_volume,
        "trading_days": len(prices),
        "news_count": news_count,
        "dominant_categories": dominant_categories[:5],
        "analysis": {
            "summary": summary,
            "key_events": key_events,
            "bullish_factors": bullish_factors,
            "bearish_factors": bearish_factors,
            "trend_analysis": trend_analysis,
            "analysis_source": llm_source,
            "analysis_model_label": model_label,
        },
    }