# -*- coding: utf-8 -*-
"""Query-oriented storage for explain/research data."""
from __future__ import annotations

import json
import sqlite3
from contextlib import closing
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable

from shared.schema import CompanyNews

SCHEMA = """
CREATE TABLE IF NOT EXISTS news_items (
    id TEXT PRIMARY KEY,
    ticker TEXT NOT NULL,
    published_at TEXT,
    trade_date TEXT,
    source TEXT,
    title TEXT NOT NULL,
    summary TEXT,
    url TEXT,
    related TEXT,
    category TEXT,
    raw_json TEXT NOT NULL,
    ingest_run_date TEXT,
    created_at TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_news_items_ticker_date
    ON news_items (ticker, trade_date DESC, published_at DESC);
"""

# Date used for range filtering and grouping: the stored trade_date when
# present, otherwise the first ten characters of published_at.
_EFFECTIVE_DATE_SQL = "COALESCE(trade_date, substr(published_at, 1, 10))"

# Columns returned by the read queries (raw_json and bookkeeping columns are
# intentionally left out).
_NEWS_COLUMNS_SQL = (
    "id, ticker, published_at, trade_date, source, title, summary, url, "
    "related, category"
)

# created_at is deliberately absent from the UPDATE list, so a re-ingested
# article keeps the timestamp of its first insert.
_UPSERT_NEWS_SQL = """
INSERT INTO news_items (
    id, ticker, published_at, trade_date, source, title, summary, url,
    related, category, raw_json, ingest_run_date, created_at
) VALUES (
    :id, :ticker, :published_at, :trade_date, :source, :title, :summary, :url,
    :related, :category, :raw_json, :ingest_run_date, :created_at
)
ON CONFLICT(id) DO UPDATE SET
    ticker = excluded.ticker,
    published_at = excluded.published_at,
    trade_date = excluded.trade_date,
    source = excluded.source,
    title = excluded.title,
    summary = excluded.summary,
    url = excluded.url,
    related = excluded.related,
    category = excluded.category,
    raw_json = excluded.raw_json,
    ingest_run_date = excluded.ingest_run_date
"""


def _json_dumps(value: Any) -> str:
    """Serialize *value* to JSON with sorted keys and non-ASCII kept as-is.

    Objects the encoder cannot handle natively are converted with ``str``.
    """
    return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)


def _resolve_news_id(ticker: str, item: CompanyNews, fallback_index: int) -> str:
    """Build the primary key for a news item.

    The key is ``"<ticker>:<base>"`` where *base* is the item's URL, else its
    title, else ``"<ticker>-<fallback_index>"``.
    """
    # NOTE(review): the index fallback depends on the item's position in the
    # batch, so ids built from it are not stable across ingest runs.
    base = item.url or item.title or f"{ticker}-{fallback_index}"
    return f"{ticker}:{base}"


def _resolve_trade_date(date_value: str | None) -> str | None:
    """Return the leading date portion of a timestamp-like value.

    Everything before the first ``"T"`` (or, failing that, the first space) is
    kept; a value with neither is cut to its first ten characters. Empty or
    whitespace-only input yields ``None``.
    """
    if not date_value:
        return None
    normalized = str(date_value).strip()
    if not normalized:
        return None
    if "T" in normalized:
        return normalized.split("T", 1)[0]
    if " " in normalized:
        return normalized.split(" ", 1)[0]
    return normalized[:10]


def _normalize_ticker(ticker: str | None) -> str:
    """Return *ticker* stripped and upper-cased (``""`` for falsy input)."""
    return str(ticker or "").strip().upper()


def _date_range_filter(
    start_date: str | None,
    end_date: str | None,
) -> tuple[str, list[Any]]:
    """Return the ``AND ...`` SQL suffix and parameters for a date window.

    Both bounds are inclusive and optional; a falsy bound adds no condition.
    """
    clause = ""
    params: list[Any] = []
    if start_date:
        clause += f" AND {_EFFECTIVE_DATE_SQL} >= ?"
        params.append(start_date)
    if end_date:
        clause += f" AND {_EFFECTIVE_DATE_SQL} <= ?"
        params.append(end_date)
    return clause, params


def _row_to_news(row: sqlite3.Row) -> dict[str, Any]:
    """Convert a ``news_items`` row into the dict shape returned to callers.

    ``date`` is ``published_at`` when set, otherwise ``trade_date``.
    """
    return {
        "id": row["id"],
        "ticker": row["ticker"],
        "date": row["published_at"] or row["trade_date"],
        "trade_date": row["trade_date"],
        "source": row["source"],
        "title": row["title"],
        "summary": row["summary"],
        "url": row["url"],
        "related": row["related"],
        "category": row["category"],
    }


class ResearchDb:
    """Small SQLite helper for explain-oriented news storage."""

    def __init__(self, db_path: Path):
        """Create the parent directory if needed and ensure the schema exists."""
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection with ``sqlite3.Row`` rows and pragmas applied.

        The caller owns the returned connection and must close it.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA foreign_keys=ON")
        except Exception:
            # Do not leak the handle when a pragma fails.
            conn.close()
            raise
        return conn

    def _init_db(self) -> None:
        """Create the table and index if they do not exist yet."""
        # `with conn` only commits/rolls back; `closing` releases the handle.
        with closing(self._connect()) as conn, conn:
            conn.executescript(SCHEMA)

    def upsert_news_items(
        self,
        *,
        ticker: str,
        items: Iterable[CompanyNews],
        ingest_run_date: str | None = None,
    ) -> list[dict[str, Any]]:
        """Persist provider news and return normalized rows.

        Args:
            ticker: Symbol the items belong to; stripped and upper-cased.
            items: News items to store. Each must expose ``date``, ``source``,
                ``title``, ``summary``, ``url``, ``related``, ``category`` and
                ``model_dump()``.
            ingest_run_date: Optional label stored with every row.

        Returns:
            One dict per input item, in input order, with the values bound to
            the statement. An empty list when *ticker* is blank. For rows that
            already existed, the returned ``created_at`` is this call's
            timestamp while the stored one is left untouched.
        """
        normalized_rows: list[dict[str, Any]] = []
        symbol = _normalize_ticker(ticker)
        if not symbol:
            return normalized_rows

        # Naive UTC text, same format datetime.utcnow() used to produce,
        # without relying on the deprecated call.
        timestamp = (
            datetime.now(timezone.utc)
            .replace(tzinfo=None)
            .isoformat(timespec="seconds")
        )

        # All rows are written in one transaction, then the handle is closed.
        with closing(self._connect()) as conn, conn:
            for index, item in enumerate(items):
                row = {
                    "id": _resolve_news_id(symbol, item, index),
                    "ticker": symbol,
                    "published_at": item.date,
                    "trade_date": _resolve_trade_date(item.date),
                    "source": item.source,
                    "title": item.title,
                    "summary": item.summary,
                    "url": item.url,
                    "related": item.related,
                    "category": item.category,
                    "raw_json": _json_dumps(item.model_dump()),
                    "ingest_run_date": ingest_run_date,
                    "created_at": timestamp,
                }
                conn.execute(_UPSERT_NEWS_SQL, row)
                normalized_rows.append(row)
        return normalized_rows

    def get_news_items(
        self,
        *,
        ticker: str,
        start_date: str | None = None,
        end_date: str | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        """Return normalized news rows for explain UI.

        Args:
            ticker: Symbol to look up; stripped and upper-cased.
            start_date: Optional inclusive lower bound on the effective date.
            end_date: Optional inclusive upper bound on the effective date.
            limit: Maximum number of rows; values below 1 are raised to 1.

        Returns:
            Rows ordered newest first by ``published_at`` (falling back to
            ``trade_date``), or an empty list when *ticker* is blank.
        """
        symbol = _normalize_ticker(ticker)
        if not symbol:
            return []

        date_clause, date_params = _date_range_filter(start_date, end_date)
        sql = (
            f"SELECT {_NEWS_COLUMNS_SQL} FROM news_items WHERE ticker = ?"
            f"{date_clause}"
            " ORDER BY COALESCE(published_at, trade_date) DESC LIMIT ?"
        )
        params: list[Any] = [symbol, *date_params, max(1, int(limit))]

        with closing(self._connect()) as conn, conn:
            rows = conn.execute(sql, params).fetchall()
        return [_row_to_news(row) for row in rows]

    def get_news_timeline(
        self,
        *,
        ticker: str,
        start_date: str | None = None,
        end_date: str | None = None,
    ) -> list[dict[str, Any]]:
        """Aggregate news counts per trade date for chart markers.

        Args:
            ticker: Symbol to look up; stripped and upper-cased.
            start_date: Optional inclusive lower bound on the effective date.
            end_date: Optional inclusive upper bound on the effective date.

        Returns:
            One dict per date in ascending order with ``date``, ``count``,
            ``source_count`` (distinct sources) and ``top_title``. Groups whose
            date is empty are dropped. ``top_title`` is ``MAX(title)``, i.e.
            the greatest title in SQLite's text ordering, not a relevance
            ranking.
        """
        symbol = _normalize_ticker(ticker)
        if not symbol:
            return []

        date_clause, date_params = _date_range_filter(start_date, end_date)
        sql = (
            f"SELECT {_EFFECTIVE_DATE_SQL} AS date,"
            " COUNT(*) AS count,"
            " COUNT(DISTINCT source) AS source_count,"
            " MAX(title) AS top_title"
            " FROM news_items WHERE ticker = ?"
            f"{date_clause}"
            f" GROUP BY {_EFFECTIVE_DATE_SQL}"
            " ORDER BY date ASC"
        )
        params: list[Any] = [symbol, *date_params]

        with closing(self._connect()) as conn, conn:
            rows = conn.execute(sql, params).fetchall()
        return [
            {
                "date": row["date"],
                "count": int(row["count"] or 0),
                "source_count": int(row["source_count"] or 0),
                "top_title": row["top_title"] or "",
            }
            for row in rows
            if row["date"]
        ]

    def get_news_by_ids(
        self,
        *,
        ticker: str,
        article_ids: Iterable[str],
    ) -> list[dict[str, Any]]:
        """Return selected persisted news items.

        Args:
            ticker: Symbol the articles must belong to; stripped and
                upper-cased.
            article_ids: Ids to fetch. Each is stringified and stripped; blank
                entries are ignored.

        Returns:
            Matching rows ordered newest first by ``published_at`` (falling
            back to ``trade_date``). An empty list when *ticker* is blank or
            no usable id was given.
        """
        symbol = _normalize_ticker(ticker)
        cleaned = (str(article_id).strip() for article_id in article_ids)
        # Drop blanks and duplicates: IN (...) matches the same rows either
        # way, and fewer placeholders keeps the statement small.
        ids = list(dict.fromkeys(value for value in cleaned if value))
        if not symbol or not ids:
            return []

        # Only "?" markers are interpolated; the ids themselves are bound.
        placeholders = ",".join("?" for _ in ids)
        sql = (
            f"SELECT {_NEWS_COLUMNS_SQL} FROM news_items"
            f" WHERE ticker = ? AND id IN ({placeholders})"
            " ORDER BY COALESCE(published_at, trade_date) DESC"
        )

        with closing(self._connect()) as conn, conn:
            rows = conn.execute(sql, [symbol, *ids]).fetchall()
        return [_row_to_news(row) for row in rows]