# -*- coding: utf-8 -*-
"""Rule-based news categorization for explain UI."""

from __future__ import annotations

import re
from functools import lru_cache
from typing import Any, Dict, Iterable, Optional, Tuple

CATEGORY_KEYWORDS = {
    "market": [
        "market",
        "stock",
        "rally",
        "sell-off",
        "selloff",
        "trading",
        "wall street",
        "s&p",
        "nasdaq",
        "dow",
        "index",
        "bull",
        "bear",
        "correction",
        "volatility",
    ],
    "policy": [
        "regulation",
        "fed",
        "federal reserve",
        "tariff",
        "sanction",
        "interest rate",
        "policy",
        "government",
        "congress",
        "sec",
        "trade war",
        "ban",
        "legislation",
        "tax",
    ],
    "earnings": [
        "earnings",
        "revenue",
        "profit",
        "quarter",
        "eps",
        "guidance",
        "forecast",
        "income",
        "sales",
        "beat",
        "miss",
        "outlook",
        "financial results",
    ],
    "product_tech": [
        "product",
        "ai",
        "chip",
        "cloud",
        "launch",
        "patent",
        "technology",
        "innovation",
        "release",
        "platform",
        "model",
        "software",
        "hardware",
        "gpu",
        "autonomous",
    ],
    "competition": [
        "competitor",
        "rival",
        "market share",
        "overtake",
        "compete",
        "competition",
        "vs",
        "versus",
        "battle",
        "challenge",
    ],
    "management": [
        "ceo",
        "executive",
        "resign",
        "layoff",
        "restructure",
        "management",
        "leadership",
        "appoint",
        "hire",
        "board",
        "chairman",
    ],
}

# Row fields whose text is searched for keywords, in concatenation order.
_TEXT_FIELDS = ("title", "summary", "related", "category")

# Keywords of at most this many characters ("ai", "sec", "ban", "dow", ...)
# must match as a whole word (optionally pluralised). As bare substrings they
# hit unrelated words such as "said", "second", "bank" and "window".
_WHOLE_WORD_MAX_LEN = 3


def _keyword_regex(keyword: str) -> str:
    """Return the regex fragment that matches one (lower-cased) keyword."""
    escaped = re.escape(keyword)
    if len(keyword) <= _WHOLE_WORD_MAX_LEN:
        # Short keyword: whole word only, tolerating a plural "s"/"es".
        return escaped + r"(?:e?s)?(?!\w)"
    # Longer keyword: prefix match, so inflections such as "stocks",
    # "tariffs" or "resignation" still count.
    return escaped


@lru_cache(maxsize=None)
def _compile_keywords(keywords: Tuple[str, ...]) -> Optional["re.Pattern[str]"]:
    """Compile a category's keywords into one pattern (``None`` if empty).

    Cached on the keyword tuple so the pattern is built once per distinct
    keyword list while runtime edits to ``CATEGORY_KEYWORDS`` are still
    picked up on the next call.
    """
    fragments = [_keyword_regex(kw.lower()) for kw in keywords if kw]
    if not fragments:
        return None
    # Every keyword must begin at a word boundary ("ai" must not match "said").
    return re.compile(r"(?<!\w)(?:" + "|".join(fragments) + ")")


def categorize_news_rows(rows: Iterable[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Bucket news rows by keyword categories.

    For each row, the ``title``, ``summary``, ``related`` and ``category``
    fields are joined, lower-cased and searched for every category's
    keywords. A row may land in several categories, or in none.

    Matching is word-aware rather than raw substring containment: a keyword
    must start on a word boundary, and keywords of three characters or fewer
    must also end on one (an optional plural ``s``/``es`` is allowed).

    Args:
        rows: Mapping-like news rows. ``None`` is treated as no rows.

    Returns:
        A dict keyed by category name. Each value holds ``label`` (the
        category name), ``count`` (number of matching rows) and
        ``article_ids`` (the ``id`` of each matching row that has one).
    """
    categories: Dict[str, Dict[str, Any]] = {
        key: {
            "label": key,
            "count": 0,
            "article_ids": [],
        }
        for key in CATEGORY_KEYWORDS
    }
    if rows is None:
        return categories

    matchers = {
        category: _compile_keywords(tuple(keywords))
        for category, keywords in CATEGORY_KEYWORDS.items()
    }

    for row in rows:
        text = " ".join(str(row.get(field) or "") for field in _TEXT_FIELDS).lower()
        article_id = row.get("id")
        # Only a missing or blank id is skipped; 0 is a legitimate id.
        has_id = article_id is not None and article_id != ""

        for category, matcher in matchers.items():
            if matcher is None or matcher.search(text) is None:
                continue
            bucket = categories[category]
            bucket["count"] += 1
            if has_id:
                bucket["article_ids"].append(article_id)

    return categories