Initial commit of integrated agent system
This commit is contained in:
301
backend/enrich/llm_enricher.py
Normal file
301
backend/enrich/llm_enricher.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Optional AgentScope-backed news enrichment with safe local fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.config.env_config import canonicalize_model_provider, get_env_bool, get_env_str
|
||||
from backend.llm.models import create_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnrichedNewsItem(BaseModel):
|
||||
"""Structured output schema for one enriched article."""
|
||||
|
||||
id: str = Field(description="The source article id")
|
||||
relevance: str = Field(description="One of high, medium, low")
|
||||
sentiment: str = Field(description="One of positive, negative, neutral")
|
||||
key_discussion: str = Field(description="Concise core discussion")
|
||||
summary: str = Field(description="Concise factual summary")
|
||||
reason_growth: str = Field(description="Growth-oriented reason if present")
|
||||
reason_decrease: str = Field(description="Downside-oriented reason if present")
|
||||
|
||||
|
||||
class EnrichedNewsBatch(BaseModel):
|
||||
"""Structured output schema for a batch of enriched articles."""
|
||||
|
||||
items: list[EnrichedNewsItem]
|
||||
|
||||
|
||||
class RangeAnalysisPayload(BaseModel):
|
||||
"""Structured output schema for range explanation text."""
|
||||
|
||||
summary: str = Field(description="Concise Chinese range summary for the selected window")
|
||||
trend_analysis: str = Field(description="Concise Chinese trend explanation for the selected window")
|
||||
bullish_factors: list[str] = Field(description="Top bullish factors in Chinese")
|
||||
bearish_factors: list[str] = Field(description="Top bearish factors in Chinese")
|
||||
|
||||
|
||||
def get_explain_model_info() -> dict[str, str]:
|
||||
"""Resolve provider/model used by explain enrichment."""
|
||||
provider = canonicalize_model_provider(
|
||||
get_env_str("EXPLAIN_ENRICH_MODEL_PROVIDER")
|
||||
or get_env_str("MODEL_PROVIDER", "OPENAI"),
|
||||
)
|
||||
model_name = get_env_str("EXPLAIN_ENRICH_MODEL_NAME") or get_env_str(
|
||||
"MODEL_NAME",
|
||||
"gpt-4o-mini",
|
||||
)
|
||||
return {
|
||||
"provider": provider,
|
||||
"model_name": model_name,
|
||||
"label": f"{provider}:{model_name}",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_enrichment_payload(payload: Any) -> dict[str, Any] | None:
|
||||
if isinstance(payload, BaseModel):
|
||||
payload = payload.model_dump()
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
return {
|
||||
"relevance": str(payload.get("relevance") or "").strip().lower() or None,
|
||||
"sentiment": str(payload.get("sentiment") or "").strip().lower() or None,
|
||||
"key_discussion": str(payload.get("key_discussion") or "").strip() or None,
|
||||
"summary": str(payload.get("summary") or "").strip() or None,
|
||||
"reason_growth": str(payload.get("reason_growth") or "").strip() or None,
|
||||
"reason_decrease": str(payload.get("reason_decrease") or "").strip() or None,
|
||||
"raw_json": payload,
|
||||
}
|
||||
|
||||
|
||||
def _run_async(coro: Any) -> Any:
|
||||
"""Run an async AgentScope model call from sync code, even inside a running loop."""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(asyncio.run, coro)
|
||||
return future.result()
|
||||
|
||||
|
||||
def _get_explain_model():
|
||||
"""Create an AgentScope model for explain enrichment."""
|
||||
model_info = get_explain_model_info()
|
||||
return create_model(
|
||||
model_name=model_info["model_name"],
|
||||
provider=model_info["provider"],
|
||||
stream=False,
|
||||
generate_kwargs={"temperature": 0.1},
|
||||
)
|
||||
|
||||
|
||||
def llm_enrichment_enabled() -> bool:
|
||||
"""Return whether AgentScope-backed LLM enrichment should be attempted."""
|
||||
if not get_env_bool("EXPLAIN_ENRICH_USE_LLM", False):
|
||||
return False
|
||||
provider = get_explain_model_info()["provider"]
|
||||
provider_key_map = {
|
||||
"OPENAI": "OPENAI_API_KEY",
|
||||
"ANTHROPIC": "ANTHROPIC_API_KEY",
|
||||
"DASHSCOPE": "DASHSCOPE_API_KEY",
|
||||
"ALIBABA": "DASHSCOPE_API_KEY",
|
||||
"GEMINI": "GOOGLE_API_KEY",
|
||||
"GOOGLE": "GOOGLE_API_KEY",
|
||||
"DEEPSEEK": "DEEPSEEK_API_KEY",
|
||||
"GROQ": "GROQ_API_KEY",
|
||||
"OPENROUTER": "OPENROUTER_API_KEY",
|
||||
}
|
||||
env_key = provider_key_map.get(provider)
|
||||
return bool(get_env_str(env_key)) if env_key else provider == "OLLAMA"
|
||||
|
||||
|
||||
def llm_range_analysis_enabled() -> bool:
|
||||
"""Return whether LLM range analysis should be attempted."""
|
||||
raw_value = get_env_str("EXPLAIN_RANGE_USE_LLM")
|
||||
if raw_value is not None and str(raw_value).strip() != "":
|
||||
return get_env_bool("EXPLAIN_RANGE_USE_LLM", False) and llm_enrichment_enabled()
|
||||
return llm_enrichment_enabled()
|
||||
|
||||
|
||||
def analyze_news_row_with_llm(row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate explain-oriented structured analysis for one article."""
|
||||
if not llm_enrichment_enabled():
|
||||
return None
|
||||
|
||||
model = _get_explain_model()
|
||||
title = str(row.get("title") or "").strip()
|
||||
summary = str(row.get("summary") or "").strip()
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You produce concise structured financial news analysis. "
|
||||
"Use only the requested fields and keep content factual."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Analyze this stock-news article for an explain UI.\n"
|
||||
"Rules:\n"
|
||||
"- relevance must be one of: high, medium, low\n"
|
||||
"- sentiment must be one of: positive, negative, neutral\n"
|
||||
"- keep each text field concise and factual\n"
|
||||
f"- article id: {str(row.get('id') or '').strip()}\n"
|
||||
f"Title: {title}\n"
|
||||
f"Summary: {summary}\n"
|
||||
),
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = _run_async(model(messages=messages, structured_model=EnrichedNewsItem))
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM enrichment failed: {e}")
|
||||
return None
|
||||
|
||||
payload = _normalize_enrichment_payload(getattr(response, "metadata", None))
|
||||
if payload:
|
||||
payload.setdefault("raw_json", {})
|
||||
payload["raw_json"]["model_provider"] = get_explain_model_info()["provider"]
|
||||
payload["raw_json"]["model_name"] = get_explain_model_info()["model_name"]
|
||||
payload["raw_json"]["model_label"] = get_explain_model_info()["label"]
|
||||
return payload
|
||||
|
||||
|
||||
def analyze_news_rows_with_llm(rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
"""Generate structured analysis for multiple articles in one request."""
|
||||
if not llm_enrichment_enabled() or not rows:
|
||||
return {}
|
||||
|
||||
payload_rows = [
|
||||
{
|
||||
"id": str(row.get("id") or "").strip(),
|
||||
"title": str(row.get("title") or "").strip(),
|
||||
"summary": str(row.get("summary") or "").strip(),
|
||||
}
|
||||
for row in rows
|
||||
if str(row.get("id") or "").strip()
|
||||
]
|
||||
if not payload_rows:
|
||||
return {}
|
||||
|
||||
model = _get_explain_model()
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You produce concise structured financial news analysis in JSON. "
|
||||
"Preserve ids exactly and do not invent extra items."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Analyze these stock-news articles for an explain UI.\n"
|
||||
"For each item return: id, relevance, sentiment, key_discussion, summary, "
|
||||
"reason_growth, reason_decrease.\n"
|
||||
"Rules:\n"
|
||||
"- relevance must be one of: high, medium, low\n"
|
||||
"- sentiment must be one of: positive, negative, neutral\n"
|
||||
"- keep all text concise and factual\n"
|
||||
f"Articles: {payload_rows}"
|
||||
),
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = _run_async(
|
||||
model(messages=messages, structured_model=EnrichedNewsBatch),
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
metadata = getattr(response, "metadata", None)
|
||||
if isinstance(metadata, BaseModel):
|
||||
metadata = metadata.model_dump()
|
||||
items = metadata.get("items") if isinstance(metadata, dict) else None
|
||||
if not isinstance(items, list):
|
||||
return {}
|
||||
|
||||
results: dict[str, dict[str, Any]] = {}
|
||||
for item in items:
|
||||
normalized = _normalize_enrichment_payload(item)
|
||||
news_id = str((item.model_dump() if isinstance(item, BaseModel) else item).get("id") or "").strip() if isinstance(item, (dict, BaseModel)) else ""
|
||||
if normalized and news_id:
|
||||
normalized.setdefault("raw_json", {})
|
||||
normalized["raw_json"]["model_provider"] = get_explain_model_info()["provider"]
|
||||
normalized["raw_json"]["model_name"] = get_explain_model_info()["model_name"]
|
||||
normalized["raw_json"]["model_label"] = get_explain_model_info()["label"]
|
||||
results[news_id] = normalized
|
||||
return results
|
||||
|
||||
|
||||
def analyze_range_with_llm(payload: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate explain-oriented range summary and factor refinement."""
|
||||
if not llm_range_analysis_enabled():
|
||||
return None
|
||||
|
||||
model = _get_explain_model()
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You write concise Chinese stock range analysis for an explain UI. "
|
||||
"Use only the supplied facts. Keep the tone factual and analyst-like."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"请基于给定事实生成区间分析。\n"
|
||||
"输出字段:summary, trend_analysis, bullish_factors, bearish_factors。\n"
|
||||
"要求:\n"
|
||||
"- 全部使用简体中文\n"
|
||||
"- summary 1到2句,概括区间走势、新闻密度和主导主题\n"
|
||||
"- trend_analysis 1句,解释区间内部阶段变化\n"
|
||||
"- bullish_factors 和 bearish_factors 各返回最多3条短句\n"
|
||||
"- 不要编造未提供的信息\n"
|
||||
f"事实数据: {payload}"
|
||||
),
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = _run_async(
|
||||
model(messages=messages, structured_model=RangeAnalysisPayload),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM enrichment failed: {e}")
|
||||
return None
|
||||
|
||||
metadata = getattr(response, "metadata", None)
|
||||
if isinstance(metadata, BaseModel):
|
||||
metadata = metadata.model_dump()
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
|
||||
return {
|
||||
"summary": str(metadata.get("summary") or "").strip() or None,
|
||||
"trend_analysis": str(metadata.get("trend_analysis") or "").strip() or None,
|
||||
"bullish_factors": [
|
||||
str(item).strip()
|
||||
for item in list(metadata.get("bullish_factors") or [])
|
||||
if str(item).strip()
|
||||
][:3],
|
||||
"bearish_factors": [
|
||||
str(item).strip()
|
||||
for item in list(metadata.get("bearish_factors") or [])
|
||||
if str(item).strip()
|
||||
][:3],
|
||||
"model_provider": get_explain_model_info()["provider"],
|
||||
"model_name": get_explain_model_info()["model_name"],
|
||||
"model_label": get_explain_model_info()["label"],
|
||||
}
|
||||
Reference in New Issue
Block a user