Files
evotraders/backend/enrich/llm_enricher.py
2026-03-30 17:46:44 +08:00

302 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""Optional AgentScope-backed news enrichment with safe local fallback."""
from __future__ import annotations
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from pydantic import BaseModel, Field
from backend.config.env_config import canonicalize_model_provider, get_env_bool, get_env_str
from backend.llm.models import create_model
# Module-level logger, named after this module via ``__name__``; the
# enrichment helpers below use it to report failed LLM calls.
logger = logging.getLogger(__name__)
class EnrichedNewsItem(BaseModel):
    """Structured output schema for one enriched article."""

    # NOTE: the class docstring and every ``description`` below are kept
    # verbatim. Pydantic copies them into the model's JSON schema, so they are
    # effectively runtime text rather than plain documentation.
    # All seven fields are required strings; the "one of ..." constraints are
    # expressed only in the descriptions and are not enforced by the types.

    # Identifier of the article this analysis belongs to.
    id: str = Field(
        description="The source article id",
    )
    # Categorical labels (lower-cased later by the normalizer in this module).
    relevance: str = Field(
        description="One of high, medium, low",
    )
    sentiment: str = Field(
        description="One of positive, negative, neutral",
    )
    # Free-text analysis fields.
    key_discussion: str = Field(
        description="Concise core discussion",
    )
    summary: str = Field(
        description="Concise factual summary",
    )
    reason_growth: str = Field(
        description="Growth-oriented reason if present",
    )
    reason_decrease: str = Field(
        description="Downside-oriented reason if present",
    )
class EnrichedNewsBatch(BaseModel):
    """Structured output schema for a batch of enriched articles."""

    # The docstring is kept verbatim because pydantic exposes it as the
    # schema description.
    # Required list with one ``EnrichedNewsItem`` entry per analyzed article.
    items: list[EnrichedNewsItem]
class RangeAnalysisPayload(BaseModel):
    """Structured output schema for range explanation text."""

    # NOTE: docstring and ``description`` strings are kept verbatim because
    # pydantic places them in the generated JSON schema.

    # Narrative fields for the selected window.
    summary: str = Field(
        description="Concise Chinese range summary for the selected window",
    )
    trend_analysis: str = Field(
        description="Concise Chinese trend explanation for the selected window",
    )
    # Factor lists; no length limit is enforced by the schema itself.
    bullish_factors: list[str] = Field(
        description="Top bullish factors in Chinese",
    )
    bearish_factors: list[str] = Field(
        description="Top bearish factors in Chinese",
    )
def get_explain_model_info() -> dict[str, str]:
    """Work out which provider and model the explain enrichment should use.

    The explain-specific variables ``EXPLAIN_ENRICH_MODEL_PROVIDER`` and
    ``EXPLAIN_ENRICH_MODEL_NAME`` win when they yield a truthy value;
    otherwise ``MODEL_PROVIDER`` (default ``"OPENAI"``) and ``MODEL_NAME``
    (default ``"gpt-4o-mini"``) are used.

    Returns:
        A dict with the keys ``provider`` (passed through
        ``canonicalize_model_provider``), ``model_name`` and ``label``
        (``"<provider>:<model_name>"``).
    """
    raw_provider = get_env_str("EXPLAIN_ENRICH_MODEL_PROVIDER")
    if not raw_provider:
        raw_provider = get_env_str("MODEL_PROVIDER", "OPENAI")
    provider = canonicalize_model_provider(raw_provider)

    model_name = get_env_str("EXPLAIN_ENRICH_MODEL_NAME")
    if not model_name:
        model_name = get_env_str("MODEL_NAME", "gpt-4o-mini")

    info = {"provider": provider, "model_name": model_name}
    info["label"] = f"{provider}:{model_name}"
    return info
def _normalize_enrichment_payload(payload: Any) -> dict[str, Any] | None:
    """Convert a model result into the flat enrichment dict used by callers.

    Args:
        payload: Either a pydantic model (dumped with ``model_dump``) or a
            plain dict. Anything else is rejected.

    Returns:
        ``None`` when ``payload`` is neither a ``BaseModel`` nor a dict.
        Otherwise a dict whose text fields are stripped strings (``relevance``
        and ``sentiment`` additionally lower-cased), with empty or missing
        values replaced by ``None``, plus ``raw_json`` holding the dict form
        of the payload itself (the same object, not a copy).
    """
    data = payload.model_dump() if isinstance(payload, BaseModel) else payload
    if not isinstance(data, dict):
        return None

    def _text(key: str) -> str | None:
        # Falsy values collapse to "" first, so the final ``or None`` maps
        # both missing and whitespace-only fields to None.
        return str(data.get(key) or "").strip() or None

    def _label(key: str) -> str | None:
        return str(data.get(key) or "").strip().lower() or None

    label_keys = ("relevance", "sentiment")
    text_keys = ("key_discussion", "summary", "reason_growth", "reason_decrease")

    normalized: dict[str, Any] = {key: _label(key) for key in label_keys}
    for key in text_keys:
        normalized[key] = _text(key)
    normalized["raw_json"] = data
    return normalized
def _run_async(coro: Any) -> Any:
"""Run an async AgentScope model call from sync code, even inside a running loop."""
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(coro)
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(asyncio.run, coro)
return future.result()
def _get_explain_model():
    """Build the model object used for explain enrichment calls.

    The provider and model name come from ``get_explain_model_info``; the
    model is created non-streaming with a generation temperature of 0.1.
    """
    info = get_explain_model_info()
    model_options = {
        "model_name": info["model_name"],
        "provider": info["provider"],
        "stream": False,
        "generate_kwargs": {"temperature": 0.1},
    }
    return create_model(**model_options)
def llm_enrichment_enabled() -> bool:
    """Tell whether LLM-backed enrichment should be attempted.

    Returns:
        ``False`` unless ``EXPLAIN_ENRICH_USE_LLM`` is enabled. When it is
        enabled, providers listed in the table below are enabled only if
        their API-key variable has a truthy value; any provider not in the
        table is enabled only when it is ``"OLLAMA"``.
    """
    if not get_env_bool("EXPLAIN_ENRICH_USE_LLM", False):
        return False

    # Canonical provider name -> environment variable holding its API key.
    credential_env_by_provider = {
        "OPENAI": "OPENAI_API_KEY",
        "ANTHROPIC": "ANTHROPIC_API_KEY",
        "DASHSCOPE": "DASHSCOPE_API_KEY",
        "ALIBABA": "DASHSCOPE_API_KEY",
        "GEMINI": "GOOGLE_API_KEY",
        "GOOGLE": "GOOGLE_API_KEY",
        "DEEPSEEK": "DEEPSEEK_API_KEY",
        "GROQ": "GROQ_API_KEY",
        "OPENROUTER": "OPENROUTER_API_KEY",
    }
    active_provider = get_explain_model_info()["provider"]
    credential_env = credential_env_by_provider.get(active_provider)

    if not credential_env:
        # No key variable is registered for this provider; only OLLAMA is
        # accepted without one.
        return active_provider == "OLLAMA"
    return bool(get_env_str(credential_env))
def llm_range_analysis_enabled() -> bool:
    """Tell whether LLM-backed range analysis should be attempted.

    ``EXPLAIN_RANGE_USE_LLM`` acts as an optional extra switch: when it is set
    to a non-blank value it must evaluate to true *and* general LLM enrichment
    must be enabled. When it is unset or blank, the result simply follows
    ``llm_enrichment_enabled``.
    """
    override = get_env_str("EXPLAIN_RANGE_USE_LLM")
    has_override = override is not None and str(override).strip() != ""

    if has_override:
        range_flag = get_env_bool("EXPLAIN_RANGE_USE_LLM", False)
        if not range_flag:
            # Explicitly switched off for range analysis.
            return range_flag
    return llm_enrichment_enabled()
def analyze_news_row_with_llm(row: dict[str, Any]) -> dict[str, Any] | None:
    """Generate explain-oriented structured analysis for one article.

    Args:
        row: Article mapping. Only the ``id``, ``title`` and ``summary`` keys
            are read; missing or falsy values are treated as empty strings.

    Returns:
        The dict produced by ``_normalize_enrichment_payload`` with
        ``model_provider``, ``model_name`` and ``model_label`` added under
        ``raw_json``. ``None`` is returned when LLM enrichment is disabled,
        when the model call raises, or when the response metadata cannot be
        normalized.

    Note:
        The model is created before the ``try`` block, so an exception raised
        while building it propagates to the caller instead of being logged.
    """
    if not llm_enrichment_enabled():
        return None

    model = _get_explain_model()
    article_id = str(row.get("id") or "").strip()
    title = str(row.get("title") or "").strip()
    summary = str(row.get("summary") or "").strip()
    messages = [
        {
            "role": "system",
            "content": (
                "You produce concise structured financial news analysis. "
                "Use only the requested fields and keep content factual."
            ),
        },
        {
            "role": "user",
            "content": (
                "Analyze this stock-news article for an explain UI.\n"
                "Rules:\n"
                "- relevance must be one of: high, medium, low\n"
                "- sentiment must be one of: positive, negative, neutral\n"
                "- keep each text field concise and factual\n"
                f"- article id: {article_id}\n"
                f"Title: {title}\n"
                f"Summary: {summary}\n"
            ),
        },
    ]

    try:
        response = _run_async(
            model(messages=messages, structured_model=EnrichedNewsItem),
        )
    except Exception as exc:
        # Best-effort enrichment: report the failure and let the caller fall
        # back. %-style arguments defer formatting to the logging framework.
        logger.warning("LLM enrichment failed: %s", exc)
        return None

    payload = _normalize_enrichment_payload(getattr(response, "metadata", None))
    if not payload:
        return payload

    # Resolve the model identity once instead of re-reading it per field.
    model_info = get_explain_model_info()
    # Annotate a copy so the response's own metadata dict is not mutated.
    raw_json = dict(payload.get("raw_json") or {})
    raw_json["model_provider"] = model_info["provider"]
    raw_json["model_name"] = model_info["model_name"]
    raw_json["model_label"] = model_info["label"]
    payload["raw_json"] = raw_json
    return payload
def analyze_news_rows_with_llm(rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Generate structured analysis for multiple articles in one request.

    Args:
        rows: Article mappings. Only ``id``, ``title`` and ``summary`` are
            read. Rows whose ``id`` is missing or blank are left out of the
            request.

    Returns:
        A dict keyed by article id (as returned by the model) whose values are
        normalized enrichment dicts with ``model_provider``, ``model_name``
        and ``model_label`` added under ``raw_json``. An empty dict is
        returned when enrichment is disabled, no usable rows were supplied,
        the model call raises, or the response carries no ``items`` list.
        If the model returns the same id twice, the later item wins.

    Note:
        The model is created before the ``try`` block, so an exception raised
        while building it propagates to the caller.
    """
    if not llm_enrichment_enabled() or not rows:
        return {}

    payload_rows = []
    for row in rows:
        row_id = str(row.get("id") or "").strip()
        if not row_id:
            continue
        payload_rows.append(
            {
                "id": row_id,
                "title": str(row.get("title") or "").strip(),
                "summary": str(row.get("summary") or "").strip(),
            },
        )
    if not payload_rows:
        return {}

    model = _get_explain_model()
    messages = [
        {
            "role": "system",
            "content": (
                "You produce concise structured financial news analysis in JSON. "
                "Preserve ids exactly and do not invent extra items."
            ),
        },
        {
            "role": "user",
            "content": (
                "Analyze these stock-news articles for an explain UI.\n"
                "For each item return: id, relevance, sentiment, key_discussion, summary, "
                "reason_growth, reason_decrease.\n"
                "Rules:\n"
                "- relevance must be one of: high, medium, low\n"
                "- sentiment must be one of: positive, negative, neutral\n"
                "- keep all text concise and factual\n"
                f"Articles: {payload_rows}"
            ),
        },
    ]

    try:
        response = _run_async(
            model(messages=messages, structured_model=EnrichedNewsBatch),
        )
    except Exception as exc:
        # Still best-effort, but no longer silent: the single-article and
        # range helpers log their failures, so do the same here.
        logger.warning("LLM batch enrichment failed: %s", exc)
        return {}

    metadata = getattr(response, "metadata", None)
    if isinstance(metadata, BaseModel):
        metadata = metadata.model_dump()
    items = metadata.get("items") if isinstance(metadata, dict) else None
    if not isinstance(items, list):
        return {}

    # Loop-invariant: resolve the model identity once for the whole batch.
    model_info = get_explain_model_info()
    results: dict[str, dict[str, Any]] = {}
    for item in items:
        normalized = _normalize_enrichment_payload(item)
        if not normalized:
            # Item was neither a dict nor a pydantic model.
            continue
        # ``raw_json`` is the dict form of the item, so the id can be read
        # from it without dumping the item a second time.
        item_fields = normalized.get("raw_json") or {}
        news_id = str(item_fields.get("id") or "").strip()
        if not news_id:
            continue
        # Annotate a copy so the response's own item dict is not mutated.
        raw_json = dict(item_fields)
        raw_json["model_provider"] = model_info["provider"]
        raw_json["model_name"] = model_info["model_name"]
        raw_json["model_label"] = model_info["label"]
        normalized["raw_json"] = raw_json
        results[news_id] = normalized
    return results
def _clean_range_factors(values: Any, limit: int = 3) -> list[str]:
    """Return at most ``limit`` non-empty, stripped strings from ``values``."""
    cleaned = []
    for value in list(values or []):
        text = str(value).strip()
        if text:
            cleaned.append(text)
    # Blank entries are dropped before the cut, so they do not use up slots.
    return cleaned[:limit]


def analyze_range_with_llm(payload: dict[str, Any]) -> dict[str, Any] | None:
    """Generate explain-oriented range summary and factor refinement.

    Args:
        payload: Facts describing the selected window. It is interpolated
            into the prompt via its string representation and is not
            inspected otherwise.

    Returns:
        A dict with ``summary`` and ``trend_analysis`` (stripped strings, or
        ``None`` when empty), ``bullish_factors`` and ``bearish_factors`` (at
        most three non-empty strings each) and the ``model_provider``,
        ``model_name`` and ``model_label`` that produced them. ``None`` is
        returned when range analysis is disabled, when the model call raises,
        or when the response metadata is not a dict or pydantic model.

    Note:
        The model is created before the ``try`` block, so an exception raised
        while building it propagates to the caller.
    """
    if not llm_range_analysis_enabled():
        return None

    model = _get_explain_model()
    messages = [
        {
            "role": "system",
            "content": (
                "You write concise Chinese stock range analysis for an explain UI. "
                "Use only the supplied facts. Keep the tone factual and analyst-like."
            ),
        },
        {
            "role": "user",
            "content": (
                "请基于给定事实生成区间分析。\n"
                "输出字段summary, trend_analysis, bullish_factors, bearish_factors。\n"
                "要求:\n"
                "- 全部使用简体中文\n"
                "- summary 1到2句概括区间走势、新闻密度和主导主题\n"
                "- trend_analysis 1句解释区间内部阶段变化\n"
                "- bullish_factors 和 bearish_factors 各返回最多3条短句\n"
                "- 不要编造未提供的信息\n"
                f"事实数据: {payload}"
            ),
        },
    ]

    try:
        response = _run_async(
            model(messages=messages, structured_model=RangeAnalysisPayload),
        )
    except Exception as exc:
        # Distinct wording so range failures can be told apart from article
        # enrichment failures in the logs.
        logger.warning("LLM range analysis failed: %s", exc)
        return None

    metadata = getattr(response, "metadata", None)
    if isinstance(metadata, BaseModel):
        metadata = metadata.model_dump()
    if not isinstance(metadata, dict):
        return None

    # Resolve the model identity once instead of once per output key.
    model_info = get_explain_model_info()
    return {
        "summary": str(metadata.get("summary") or "").strip() or None,
        "trend_analysis": str(metadata.get("trend_analysis") or "").strip() or None,
        "bullish_factors": _clean_range_factors(metadata.get("bullish_factors")),
        "bearish_factors": _clean_range_factors(metadata.get("bearish_factors")),
        "model_provider": model_info["provider"],
        "model_name": model_info["model_name"],
        "model_label": model_info["label"],
    }