# -*- coding: utf-8 -*- """ WebSocket Gateway for frontend communication """ import asyncio import json import logging from datetime import datetime, timedelta from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set import websockets from websockets.asyncio.server import ServerConnection from backend.config.bootstrap_config import ( resolve_runtime_config, update_bootstrap_values_for_run, ) from backend.data.provider_utils import normalize_symbol from backend.data.market_ingest import ingest_symbols from backend.enrich.llm_enricher import llm_enrichment_enabled from backend.enrich.news_enricher import enrich_news_for_symbol from backend.explain.range_explainer import build_range_explanation from backend.explain.similarity_service import find_similar_days from backend.explain.story_service import get_or_create_stock_story from backend.utils.msg_adapter import FrontendAdapter from backend.utils.terminal_dashboard import get_dashboard from backend.core.pipeline import TradingPipeline from backend.core.state_sync import StateSync from backend.services.market import MarketService from backend.services.storage import StorageService from backend.data.provider_router import get_provider_router from backend.tools.data_tools import get_prices from backend.tools.data_tools import get_company_news logger = logging.getLogger(__name__) class Gateway: """WebSocket Gateway for frontend communication""" def __init__( self, market_service: MarketService, storage_service: StorageService, pipeline: TradingPipeline, state_sync: Optional[StateSync] = None, scheduler_callback: Optional[Callable] = None, config: Dict[str, Any] = None, ): self.market_service = market_service self.storage = storage_service self.pipeline = pipeline self.scheduler_callback = scheduler_callback self.config = config or {} self.mode = self.config.get("mode", "live") self.is_backtest = self.mode == "backtest" or self.config.get( "backtest_mode", False, ) self.state_sync = state_sync or StateSync(storage=storage_service) # self.state_sync.set_mode(self.is_backtest) self.state_sync.set_broadcast_fn(self.broadcast) self.pipeline.state_sync = self.state_sync self.connected_clients: Set[ServerConnection] = set() self.lock = asyncio.Lock() self._backtest_task: Optional[asyncio.Task] = None self._backtest_start_date: Optional[str] = None self._backtest_end_date: Optional[str] = None self._dashboard = get_dashboard() self._market_status_task: Optional[asyncio.Task] = None self._watchlist_ingest_task: Optional[asyncio.Task] = None # Session tracking for live returns self._session_start_portfolio_value: Optional[float] = None self._provider_router = get_provider_router() self._loop: Optional[asyncio.AbstractEventLoop] = None self._project_root = Path(__file__).resolve().parents[2] async def start(self, host: str = "0.0.0.0", port: int = 8766): """Start gateway server""" logger.info(f"Starting gateway on {host}:{port}") self._loop = asyncio.get_running_loop() self._provider_router.add_listener(self._on_provider_usage_changed) # Initialize terminal dashboard self._dashboard.set_config( mode=self.mode, config_name=self.config.get("config_name", "default"), host=host, port=port, poll_interval=self.config.get("poll_interval", 10), mock=self.config.get("mock_mode", False), tickers=self.config.get("tickers", []), initial_cash=self.storage.initial_cash, start_date=self._backtest_start_date or "", end_date=self._backtest_end_date or "", data_sources=self._provider_router.get_usage_snapshot(), ) self._dashboard.start() self.state_sync.load_state() self.market_service.set_price_recorder(self.storage.record_price_point) self.state_sync.update_state("status", "running") self.state_sync.update_state("server_mode", self.mode) self.state_sync.update_state("is_backtest", self.is_backtest) self.state_sync.update_state( "is_mock_mode", self.config.get("mock_mode", False), ) self.state_sync.update_state("tickers", self.config.get("tickers", [])) self.state_sync.update_state( "runtime_config", { "tickers": self.config.get("tickers", []), "initial_cash": self.config.get( "initial_cash", self.storage.initial_cash, ), "margin_requirement": self.config.get("margin_requirement"), "max_comm_cycles": self.config.get("max_comm_cycles"), "enable_memory": self.config.get("enable_memory", False), }, ) self.state_sync.update_state( "data_sources", self._provider_router.get_usage_snapshot(), ) # Load and display existing portfolio state if available summary = self.storage.load_file("summary") if summary: holdings = self.storage.load_file("holdings") or [] trades = self.storage.load_file("trades") or [] current_date = self.state_sync.state.get("current_date") self._dashboard.update( date=current_date or "-", status="running", portfolio=summary, holdings=holdings, trades=trades, ) logger.info( "Loaded existing portfolio: $%s", f"{summary.get('totalAssetValue', 0):,.2f}", ) await self.market_service.start(broadcast_func=self.broadcast) if self.scheduler_callback: await self.scheduler_callback(callback=self.on_strategy_trigger) # Start market status monitoring (only for live mode) if not self.is_backtest: self._market_status_task = asyncio.create_task( self._market_status_monitor(), ) async with websockets.serve( self.handle_client, host, port, ping_interval=30, ping_timeout=60, ): logger.info( f"Gateway started: ws://{host}:{port}, mode={self.mode}", ) await asyncio.Future() def _on_provider_usage_changed(self, snapshot: Dict[str, Any]): """Handle provider routing updates from the shared router.""" self.state_sync.update_state("data_sources", snapshot) self._dashboard.update(data_sources=snapshot) if self._loop and self._loop.is_running(): asyncio.run_coroutine_threadsafe( self.broadcast( { "type": "data_sources_update", "data_sources": snapshot, }, ), self._loop, ) @property def state(self) -> Dict[str, Any]: return self.state_sync.state @staticmethod def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool: if not rows: return True return all( not row.get("sentiment") and not row.get("relevance") and not row.get("key_discussion") for row in rows ) async def handle_client(self, websocket: ServerConnection): """Handle WebSocket client connection""" async with self.lock: self.connected_clients.add(websocket) await self._send_initial_state(websocket) await self._handle_client_messages(websocket) async with self.lock: self.connected_clients.discard(websocket) async def _send_initial_state(self, websocket: ServerConnection): state_payload = self.state_sync.get_initial_state_payload( include_dashboard=True, ) state_payload["data_sources"] = ( self._provider_router.get_usage_snapshot() ) # Include market status in initial state state_payload[ "market_status" ] = self.market_service.get_market_status() # Include live returns if session is active if self.storage.is_live_session_active: live_returns = self.storage.get_live_returns() if "portfolio" in state_payload: state_payload["portfolio"].update(live_returns) await websocket.send( json.dumps( {"type": "initial_state", "state": state_payload}, ensure_ascii=False, default=str, ), ) async def _handle_client_messages( self, websocket: ServerConnection, ): try: async for message in websocket: data = json.loads(message) msg_type = data.get("type", "unknown") if msg_type == "ping": await websocket.send( json.dumps( { "type": "pong", "timestamp": datetime.now().isoformat(), }, ensure_ascii=False, ), ) elif msg_type == "get_state": await self._send_initial_state(websocket) elif msg_type == "start_backtest": await self._handle_start_backtest(data) elif msg_type == "reload_runtime_assets": await self._handle_reload_runtime_assets() elif msg_type == "update_watchlist": await self._handle_update_watchlist(websocket, data) elif msg_type == "get_stock_history": await self._handle_get_stock_history(websocket, data) elif msg_type == "get_stock_explain_events": await self._handle_get_stock_explain_events(websocket, data) elif msg_type == "get_stock_news": await self._handle_get_stock_news(websocket, data) elif msg_type == "get_stock_news_for_date": await self._handle_get_stock_news_for_date(websocket, data) elif msg_type == "get_stock_news_timeline": await self._handle_get_stock_news_timeline(websocket, data) elif msg_type == "get_stock_news_categories": await self._handle_get_stock_news_categories(websocket, data) elif msg_type == "get_stock_range_explain": await self._handle_get_stock_range_explain(websocket, data) elif msg_type == "get_stock_story": await self._handle_get_stock_story(websocket, data) elif msg_type == "get_stock_similar_days": await self._handle_get_stock_similar_days(websocket, data) elif msg_type == "run_stock_enrich": await self._handle_run_stock_enrich(websocket, data) except websockets.ConnectionClosed: pass except json.JSONDecodeError: pass async def _handle_get_stock_history( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) if not ticker: await websocket.send( json.dumps( { "type": "stock_history_loaded", "ticker": "", "prices": [], "source": None, "error": "invalid ticker", }, ensure_ascii=False, ), ) return lookback_days = data.get("lookback_days", 90) try: lookback_days = max(7, min(int(lookback_days), 365)) except (TypeError, ValueError): lookback_days = 90 end_date = self.state_sync.state.get("current_date") if not end_date: end_date = datetime.now().strftime("%Y-%m-%d") try: end_dt = datetime.strptime(end_date, "%Y-%m-%d") except ValueError: end_dt = datetime.now() end_date = end_dt.strftime("%Y-%m-%d") start_date = (end_dt - timedelta(days=lookback_days)).strftime( "%Y-%m-%d", ) prices = await asyncio.to_thread( self.storage.market_store.get_ohlc, ticker, start_date, end_date, ) source = "polygon" if not prices: prices = await asyncio.to_thread( get_prices, ticker, start_date, end_date, ) usage_snapshot = self._provider_router.get_usage_snapshot() source = usage_snapshot.get("last_success", {}).get("prices") if prices: await asyncio.to_thread( self.storage.market_store.upsert_ohlc, ticker, [price.model_dump() for price in prices], source=source or "provider", ) await websocket.send( json.dumps( { "type": "stock_history_loaded", "ticker": ticker, "prices": [ price if isinstance(price, dict) else price.model_dump() for price in prices ][-120:], "source": source, "start_date": start_date, "end_date": end_date, }, ensure_ascii=False, default=str, ), ) async def _handle_get_stock_explain_events( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) snapshot = self.storage.runtime_db.get_stock_explain_snapshot(ticker) await websocket.send( json.dumps( { "type": "stock_explain_events_loaded", "ticker": ticker, "events": snapshot.get("events", []), "signals": snapshot.get("signals", []), "trades": snapshot.get("trades", []), }, ensure_ascii=False, default=str, ), ) async def _handle_get_stock_news( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) if not ticker: await websocket.send( json.dumps( { "type": "stock_news_loaded", "ticker": "", "news": [], "source": None, "error": "invalid ticker", }, ensure_ascii=False, ), ) return lookback_days = data.get("lookback_days", 30) limit = data.get("limit", 12) try: lookback_days = max(7, min(int(lookback_days), 180)) except (TypeError, ValueError): lookback_days = 30 try: limit = max(1, min(int(limit), 30)) except (TypeError, ValueError): limit = 12 end_date = self.state_sync.state.get("current_date") if not end_date: end_date = datetime.now().strftime("%Y-%m-%d") try: end_dt = datetime.strptime(end_date, "%Y-%m-%d") except ValueError: end_dt = datetime.now() end_date = end_dt.strftime("%Y-%m-%d") start_date = (end_dt - timedelta(days=lookback_days)).strftime( "%Y-%m-%d", ) news_rows = await asyncio.to_thread( self.storage.market_store.get_news_items_enriched, ticker, start_date=start_date, end_date=end_date, limit=limit, ) source = "polygon" if self._news_rows_need_enrichment(news_rows): news = await asyncio.to_thread( get_company_news, ticker, end_date, start_date, limit, ) if news: usage_snapshot = self._provider_router.get_usage_snapshot() source = usage_snapshot.get("last_success", {}).get("company_news") await asyncio.to_thread( self.storage.market_store.upsert_news, ticker, [item.model_dump() for item in news], source=source or "provider", ) await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, start_date=start_date, end_date=end_date, limit=max(limit, 50), ) news_rows = await asyncio.to_thread( self.storage.market_store.get_news_items_enriched, ticker, start_date=start_date, end_date=end_date, limit=limit, ) source = source or "market_store" await websocket.send( json.dumps( { "type": "stock_news_loaded", "ticker": ticker, "news": news_rows[-limit:], "source": source, "start_date": start_date, "end_date": end_date, }, ensure_ascii=False, default=str, ), ) async def _handle_get_stock_news_for_date( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) trade_date = str(data.get("date") or "").strip() if not ticker or not trade_date: await websocket.send( json.dumps( { "type": "stock_news_for_date_loaded", "ticker": ticker, "date": trade_date, "news": [], "error": "ticker and date are required", }, ensure_ascii=False, ), ) return limit = data.get("limit", 20) try: limit = max(1, min(int(limit), 50)) except (TypeError, ValueError): limit = 20 news_rows = await asyncio.to_thread( self.storage.market_store.get_news_items_enriched, ticker, trade_date=trade_date, limit=limit, ) if self._news_rows_need_enrichment(news_rows): await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, start_date=trade_date, end_date=trade_date, limit=limit, ) news_rows = await asyncio.to_thread( self.storage.market_store.get_news_items_enriched, ticker, trade_date=trade_date, limit=limit, ) await websocket.send( json.dumps( { "type": "stock_news_for_date_loaded", "ticker": ticker, "date": trade_date, "news": news_rows, "source": "market_store", }, ensure_ascii=False, default=str, ), ) async def _handle_get_stock_news_timeline( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) if not ticker: await websocket.send( json.dumps( { "type": "stock_news_timeline_loaded", "ticker": "", "timeline": [], "error": "invalid ticker", }, ensure_ascii=False, ), ) return lookback_days = data.get("lookback_days", 90) try: lookback_days = max(7, min(int(lookback_days), 365)) except (TypeError, ValueError): lookback_days = 90 end_date = self.state_sync.state.get("current_date") if not end_date: end_date = datetime.now().strftime("%Y-%m-%d") try: end_dt = datetime.strptime(end_date, "%Y-%m-%d") except ValueError: end_dt = datetime.now() end_date = end_dt.strftime("%Y-%m-%d") start_date = (end_dt - timedelta(days=lookback_days)).strftime( "%Y-%m-%d", ) timeline = await asyncio.to_thread( self.storage.market_store.get_news_timeline_enriched, ticker, start_date=start_date, end_date=end_date, ) if not timeline: await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, start_date=start_date, end_date=end_date, limit=200, ) timeline = await asyncio.to_thread( self.storage.market_store.get_news_timeline_enriched, ticker, start_date=start_date, end_date=end_date, ) await websocket.send( json.dumps( { "type": "stock_news_timeline_loaded", "ticker": ticker, "timeline": timeline, "start_date": start_date, "end_date": end_date, }, ensure_ascii=False, default=str, ), ) async def _handle_get_stock_news_categories( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) if not ticker: await websocket.send( json.dumps( { "type": "stock_news_categories_loaded", "ticker": "", "categories": {}, "error": "invalid ticker", }, ensure_ascii=False, ), ) return lookback_days = data.get("lookback_days", 90) try: lookback_days = max(7, min(int(lookback_days), 365)) except (TypeError, ValueError): lookback_days = 90 end_date = self.state_sync.state.get("current_date") if not end_date: end_date = datetime.now().strftime("%Y-%m-%d") try: end_dt = datetime.strptime(end_date, "%Y-%m-%d") except ValueError: end_dt = datetime.now() end_date = end_dt.strftime("%Y-%m-%d") start_date = (end_dt - timedelta(days=lookback_days)).strftime( "%Y-%m-%d", ) news_rows = await asyncio.to_thread( self.storage.market_store.get_news_items_enriched, ticker, start_date=start_date, end_date=end_date, limit=200, ) if self._news_rows_need_enrichment(news_rows): await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, start_date=start_date, end_date=end_date, limit=200, ) categories = await asyncio.to_thread( self.storage.market_store.get_news_categories_enriched, ticker, start_date=start_date, end_date=end_date, limit=200, ) await websocket.send( json.dumps( { "type": "stock_news_categories_loaded", "ticker": ticker, "categories": categories, "start_date": start_date, "end_date": end_date, }, ensure_ascii=False, default=str, ), ) async def _handle_get_stock_range_explain( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) start_date = str(data.get("start_date") or "").strip() end_date = str(data.get("end_date") or "").strip() if not ticker or not start_date or not end_date: await websocket.send( json.dumps( { "type": "stock_range_explain_loaded", "ticker": ticker, "result": {"error": "ticker, start_date, end_date are required"}, }, ensure_ascii=False, ), ) return article_ids = data.get("article_ids") if isinstance(article_ids, list) and article_ids: news_rows = await asyncio.to_thread( self.storage.market_store.get_news_by_ids_enriched, ticker, article_ids, ) if self._news_rows_need_enrichment(news_rows): await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, start_date=start_date, end_date=end_date, limit=100, ) news_rows = await asyncio.to_thread( self.storage.market_store.get_news_by_ids_enriched, ticker, article_ids, ) else: news_rows = await asyncio.to_thread( self.storage.market_store.get_news_items_enriched, ticker, start_date=start_date, end_date=end_date, limit=100, ) if not news_rows: await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, start_date=start_date, end_date=end_date, limit=100, ) news_rows = await asyncio.to_thread( self.storage.market_store.get_news_items_enriched, ticker, start_date=start_date, end_date=end_date, limit=100, ) result = await asyncio.to_thread( build_range_explanation, ticker=ticker, start_date=start_date, end_date=end_date, news_rows=news_rows, ) await websocket.send( json.dumps( { "type": "stock_range_explain_loaded", "ticker": ticker, "result": result, }, ensure_ascii=False, default=str, ), ) async def _handle_get_stock_story( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) if not ticker: await websocket.send( json.dumps( { "type": "stock_story_loaded", "ticker": "", "story": "", "error": "invalid ticker", }, ensure_ascii=False, ), ) return as_of_date = str( data.get("as_of_date") or self.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d") ).strip()[:10] await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, end_date=as_of_date, limit=80, ) result = await asyncio.to_thread( get_or_create_stock_story, self.storage.market_store, symbol=ticker, as_of_date=as_of_date, ) await websocket.send( json.dumps( { "type": "stock_story_loaded", "ticker": ticker, "as_of_date": as_of_date, "story": result.get("story") or "", "source": result.get("source") or "local", }, ensure_ascii=False, default=str, ), ) async def _handle_get_stock_similar_days( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) target_date = str(data.get("date") or "").strip()[:10] if not ticker or not target_date: await websocket.send( json.dumps( { "type": "stock_similar_days_loaded", "ticker": ticker, "date": target_date, "items": [], "error": "ticker and date are required", }, ensure_ascii=False, ), ) return top_k = data.get("top_k", 8) try: top_k = max(1, min(int(top_k), 20)) except (TypeError, ValueError): top_k = 8 await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, end_date=target_date, limit=200, ) result = await asyncio.to_thread( find_similar_days, self.storage.market_store, symbol=ticker, target_date=target_date, top_k=top_k, ) await websocket.send( json.dumps( { "type": "stock_similar_days_loaded", "ticker": ticker, "date": target_date, **result, }, ensure_ascii=False, default=str, ), ) async def _handle_run_stock_enrich( self, websocket: ServerConnection, data: Dict[str, Any], ): ticker = normalize_symbol(data.get("ticker", "")) start_date = str(data.get("start_date") or "").strip()[:10] end_date = str(data.get("end_date") or "").strip()[:10] story_date = str(data.get("story_date") or end_date or "").strip()[:10] target_date = str(data.get("target_date") or "").strip()[:10] force = bool(data.get("force", False)) rebuild_story = bool(data.get("rebuild_story", True)) rebuild_similar_days = bool(data.get("rebuild_similar_days", True)) only_local_to_llm = bool(data.get("only_local_to_llm", False)) limit = data.get("limit", 200) try: limit = max(10, min(int(limit), 500)) except (TypeError, ValueError): limit = 200 if not ticker or not start_date or not end_date: await websocket.send( json.dumps( { "type": "stock_enrich_completed", "ticker": ticker, "start_date": start_date, "end_date": end_date, "error": "ticker, start_date, end_date are required", }, ensure_ascii=False, ), ) return if only_local_to_llm and not llm_enrichment_enabled(): await websocket.send( json.dumps( { "type": "stock_enrich_completed", "ticker": ticker, "start_date": start_date, "end_date": end_date, "error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider", }, ensure_ascii=False, ), ) return result = await asyncio.to_thread( enrich_news_for_symbol, self.storage.market_store, ticker, start_date=start_date, end_date=end_date, limit=limit, skip_existing=not force, only_reanalyze_local=only_local_to_llm, ) story_status = None if rebuild_story and story_date: await asyncio.to_thread( self.storage.market_store.delete_story_cache, ticker, as_of_date=story_date, ) story_result = await asyncio.to_thread( get_or_create_stock_story, self.storage.market_store, symbol=ticker, as_of_date=story_date, ) story_status = { "as_of_date": story_date, "source": story_result.get("source") or "local", } similar_status = None if rebuild_similar_days and target_date: await asyncio.to_thread( self.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date, ) similar_result = await asyncio.to_thread( find_similar_days, self.storage.market_store, symbol=ticker, target_date=target_date, top_k=8, ) similar_status = { "target_date": target_date, "count": len(similar_result.get("items") or []), "error": similar_result.get("error"), } await websocket.send( json.dumps( { "type": "stock_enrich_completed", "ticker": ticker, "start_date": start_date, "end_date": end_date, "story_date": story_date or None, "target_date": target_date or None, "force": force, "only_local_to_llm": only_local_to_llm, "stats": result, "story_status": story_status, "similar_status": similar_status, }, ensure_ascii=False, default=str, ), ) async def _handle_start_backtest(self, data: Dict[str, Any]): if not self.is_backtest: return dates = data.get("dates", []) if dates and self._backtest_task is None: task = asyncio.create_task( self._run_backtest_dates(dates), ) task.add_done_callback(self._handle_backtest_exception) self._backtest_task = task async def _handle_reload_runtime_assets(self): """Reload prompt, skills, and safe runtime config without restart.""" config_name = self.config.get("config_name", "default") runtime_config = resolve_runtime_config( project_root=self._project_root, config_name=config_name, enable_memory=self.config.get("enable_memory", False), ) result = self.pipeline.reload_runtime_assets(runtime_config=runtime_config) runtime_updates = self._apply_runtime_config(runtime_config) await self.state_sync.on_system_message( "Runtime assets reloaded.", ) await self.broadcast( { "type": "runtime_assets_reloaded", **result, **runtime_updates, }, ) async def _handle_update_watchlist( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: """Persist a new watchlist to BOOTSTRAP.md and hot-reload it.""" tickers = self._normalize_watchlist(data.get("tickers")) if not tickers: await websocket.send( json.dumps( { "type": "error", "message": "update_watchlist requires at least one valid ticker.", }, ensure_ascii=False, ), ) return config_name = self.config.get("config_name", "default") update_bootstrap_values_for_run( project_root=self._project_root, config_name=config_name, updates={"tickers": tickers}, ) await self.state_sync.on_system_message( f"Watchlist updated: {', '.join(tickers)}", ) await self.broadcast( { "type": "watchlist_updated", "config_name": config_name, "tickers": tickers, }, ) await self._handle_reload_runtime_assets() self._schedule_watchlist_market_store_refresh(tickers) @staticmethod def _normalize_watchlist(raw_tickers: Any) -> List[str]: """Parse watchlist payloads from websocket messages.""" if raw_tickers is None: return [] if isinstance(raw_tickers, str): candidates = raw_tickers.split(",") elif isinstance(raw_tickers, list): candidates = raw_tickers else: candidates = [raw_tickers] tickers: List[str] = [] for candidate in candidates: symbol = normalize_symbol(str(candidate).strip().strip("\"'")) if symbol and symbol not in tickers: tickers.append(symbol) return tickers def _apply_runtime_config( self, runtime_config: Dict[str, Any], ) -> Dict[str, Any]: """Apply runtime config to gateway-owned services and state.""" warnings: List[str] = [] ticker_changes = self.market_service.update_tickers( runtime_config.get("tickers", []), ) self.config["tickers"] = ticker_changes["active"] self.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"]) self.config["max_comm_cycles"] = self.pipeline.max_comm_cycles pm_apply_result = self.pipeline.pm.apply_runtime_portfolio_config( margin_requirement=runtime_config["margin_requirement"], ) self.config["margin_requirement"] = self.pipeline.pm.portfolio.get( "margin_requirement", runtime_config["margin_requirement"], ) requested_initial_cash = float(runtime_config["initial_cash"]) current_initial_cash = float(self.storage.initial_cash) initial_cash_applied = requested_initial_cash == current_initial_cash if not initial_cash_applied: if ( self.storage.can_apply_initial_cash() and self.pipeline.pm.can_apply_initial_cash() ): initial_cash_applied = self.storage.apply_initial_cash( requested_initial_cash, ) if initial_cash_applied: self.pipeline.pm.apply_runtime_portfolio_config( initial_cash=requested_initial_cash, ) self.config["initial_cash"] = self.storage.initial_cash else: warnings.append( "initial_cash changed in BOOTSTRAP.md but was not applied " "because the run already has positions, margin usage, or trades.", ) requested_enable_memory = bool(runtime_config["enable_memory"]) current_enable_memory = bool(self.config.get("enable_memory", False)) if requested_enable_memory != current_enable_memory: warnings.append( "enable_memory changed in BOOTSTRAP.md but still requires a restart " "because long-term memory contexts are created at startup.", ) self._sync_runtime_state() return { "runtime_config_requested": runtime_config, "runtime_config_applied": { "tickers": list(self.config.get("tickers", [])), "initial_cash": self.storage.initial_cash, "margin_requirement": self.config["margin_requirement"], "max_comm_cycles": self.config["max_comm_cycles"], "enable_memory": self.config.get("enable_memory", False), }, "runtime_config_status": { "tickers": True, "initial_cash": initial_cash_applied, "margin_requirement": pm_apply_result["margin_requirement"], "max_comm_cycles": True, "enable_memory": requested_enable_memory == current_enable_memory, }, "ticker_changes": ticker_changes, "runtime_config_warnings": warnings, } def _sync_runtime_state(self) -> None: """Refresh persisted state and dashboard after runtime config changes.""" self.state_sync.update_state("tickers", self.config.get("tickers", [])) self.state_sync.update_state( "runtime_config", { "tickers": self.config.get("tickers", []), "initial_cash": self.storage.initial_cash, "margin_requirement": self.config.get("margin_requirement"), "max_comm_cycles": self.config.get("max_comm_cycles"), "enable_memory": self.config.get("enable_memory", False), }, ) self.storage.update_server_state_from_dashboard(self.state_sync.state) self.state_sync.save_state() self._dashboard.tickers = list(self.config.get("tickers", [])) self._dashboard.initial_cash = self.storage.initial_cash self._dashboard.enable_memory = bool( self.config.get("enable_memory", False), ) summary = self.storage.load_file("summary") or {} holdings = self.storage.load_file("holdings") or [] trades = self.storage.load_file("trades") or [] self._dashboard.update( portfolio=summary, holdings=holdings, trades=trades, ) def _schedule_watchlist_market_store_refresh( self, tickers: List[str], ) -> None: """Kick off a non-blocking Polygon refresh for the updated watchlist.""" if not tickers: return if self._watchlist_ingest_task and not self._watchlist_ingest_task.done(): self._watchlist_ingest_task.cancel() self._watchlist_ingest_task = asyncio.create_task( self._refresh_market_store_for_watchlist(tickers), ) async def _refresh_market_store_for_watchlist( self, tickers: List[str], ) -> None: """Refresh the long-lived market store after a watchlist update.""" try: await self.state_sync.on_system_message( f"正在同步自选股市场数据: {', '.join(tickers)}", ) results = await asyncio.to_thread( ingest_symbols, tickers, mode="incremental", ) summary = ", ".join( f"{item['symbol']} prices={item['prices']} news={item['news']}" for item in results ) await self.state_sync.on_system_message( f"自选股市场数据已同步: {summary}", ) except asyncio.CancelledError: raise except Exception as exc: logger.warning("Watchlist market store refresh failed: %s", exc) await self.state_sync.on_system_message( f"自选股市场数据同步失败: {exc}", ) async def broadcast(self, message: Dict[str, Any]): """Broadcast message to all connected clients""" if not self.connected_clients: return message_json = json.dumps(message, ensure_ascii=False, default=str) async with self.lock: tasks = [ self._send_to_client(client, message_json) for client in self.connected_clients.copy() ] if tasks: await asyncio.gather(*tasks, return_exceptions=True) async def _send_to_client( self, client: ServerConnection, message: str, ): try: await client.send(message) except websockets.ConnectionClosed: async with self.lock: self.connected_clients.discard(client) async def _market_status_monitor(self): """Periodically check and broadcast market status changes""" while True: try: await self.market_service.check_and_broadcast_market_status() # On market open, start live session tracking status = self.market_service.get_market_status() if ( status["status"] == "open" and not self.storage.is_live_session_active ): self.storage.start_live_session() summary = self.storage.load_file("summary") or {} self._session_start_portfolio_value = summary.get( "totalAssetValue", self.storage.initial_cash, ) logger.info( "Session start portfolio: " f"${self._session_start_portfolio_value:,.2f}", ) elif ( status["status"] != "open" and self.storage.is_live_session_active ): self.storage.end_live_session() self._session_start_portfolio_value = None # Update and broadcast live returns if session is active if self.storage.is_live_session_active: await self._update_and_broadcast_live_returns() await asyncio.sleep(60) # Check every minute except asyncio.CancelledError: break except Exception as e: logger.error(f"Market status monitor error: {e}") await asyncio.sleep(60) async def _update_and_broadcast_live_returns(self): """Calculate and broadcast live returns for current session""" if not self.storage.is_live_session_active: return # Get current prices and calculate portfolio value prices = self.market_service.get_all_prices() if not prices or not any(p > 0 for p in prices.values()): return # Load current internal state to get baseline values state = self.storage.load_internal_state() # Get latest values from history (if available) equity_history = state.get("equity_history", []) baseline_history = state.get("baseline_history", []) baseline_vw_history = state.get("baseline_vw_history", []) momentum_history = state.get("momentum_history", []) current_equity = equity_history[-1]["v"] if equity_history else None current_baseline = ( baseline_history[-1]["v"] if baseline_history else None ) current_baseline_vw = ( baseline_vw_history[-1]["v"] if baseline_vw_history else None ) current_momentum = ( momentum_history[-1]["v"] if momentum_history else None ) # Update live returns with current values point = self.storage.update_live_returns( current_equity=current_equity, current_baseline=current_baseline, current_baseline_vw=current_baseline_vw, current_momentum=current_momentum, ) # Broadcast if we have new data if point: live_returns = self.storage.get_live_returns() await self.broadcast( { "type": "team_summary", "equity_return": live_returns["equity_return"], "baseline_return": live_returns["baseline_return"], "baseline_vw_return": live_returns["baseline_vw_return"], "momentum_return": live_returns["momentum_return"], }, ) async def on_strategy_trigger(self, date: str): """Handle trading cycle trigger""" logger.info(f"Strategy triggered for {date}") tickers = self.config.get("tickers", []) if self.is_backtest: await self._run_backtest_cycle(date, tickers) else: await self._run_live_cycle(date, tickers) async def _run_backtest_cycle(self, date: str, tickers: List[str]): """Run backtest cycle with pre-loaded prices""" self.market_service.set_backtest_date(date) await self.market_service.emit_market_open() await self.state_sync.on_cycle_start(date) self._dashboard.update(date=date, status="Analyzing...") prices = self.market_service.get_open_prices() close_prices = self.market_service.get_close_prices() market_caps = self._get_market_caps(tickers, date) result = await self.pipeline.run_cycle( tickers=tickers, date=date, prices=prices, close_prices=close_prices, market_caps=market_caps, ) await self.market_service.emit_market_close() settlement_result = result.get("settlement_result") self._save_cycle_results(result, date, close_prices, settlement_result) await self._broadcast_portfolio_updates(result, close_prices) await self._finalize_cycle(date) async def _run_live_cycle(self, date: str, tickers: List[str]): """ Run live cycle with real market timing. - Analysis runs immediately - Execution waits for market open (or uses current prices if already open) - Settlement waits for market close """ # Get actual trading date (might be next trading day if weekend) trading_date = self.market_service.get_live_trading_date() logger.info( f"Live cycle: triggered={date}, trading_date={trading_date}", ) await self.state_sync.on_cycle_start(trading_date) self._dashboard.update(date=trading_date, status="Analyzing...") market_caps = self._get_market_caps(tickers, trading_date) # Run pipeline with async price callbacks result = await self.pipeline.run_cycle( tickers=tickers, date=trading_date, market_caps=market_caps, get_open_prices_fn=self.market_service.wait_for_open_prices, get_close_prices_fn=self.market_service.wait_for_close_prices, ) close_prices = self.market_service.get_all_prices() settlement_result = result.get("settlement_result") self._save_cycle_results( result, trading_date, close_prices, settlement_result, ) await self._broadcast_portfolio_updates(result, close_prices) await self._finalize_cycle(trading_date) async def _finalize_cycle(self, date: str): """Finalize cycle: broadcast state and update dashboard""" summary = self.storage.load_file("summary") or {} # Include live returns if session is active if self.storage.is_live_session_active: live_returns = self.storage.get_live_returns() summary.update(live_returns) await self.state_sync.on_cycle_end(date, portfolio_summary=summary) holdings = self.storage.load_file("holdings") or [] trades = self.storage.load_file("trades") or [] leaderboard = self.storage.load_file("leaderboard") or [] if leaderboard: await self.state_sync.on_leaderboard_update(leaderboard) self._dashboard.update( date=date, status="Running", portfolio=summary, holdings=holdings, trades=trades, ) def _get_market_caps( self, tickers: List[str], date: str, ) -> Dict[str, float]: """ Get market caps for tickers (stub implementation) Args: tickers: List of tickers date: Trading date Returns: Dict mapping ticker to market cap """ from ..tools.data_tools import get_market_cap market_caps = {} for ticker in tickers: try: market_cap = get_market_cap(ticker, date) if market_cap: market_caps[ticker] = market_cap else: market_caps[ticker] = 1e9 except Exception: market_caps[ticker] = 1e9 return market_caps async def _broadcast_portfolio_updates( self, result: Dict[str, Any], prices: Dict[str, float], ): portfolio = result.get("portfolio", {}) if portfolio: holdings = FrontendAdapter.build_holdings(portfolio, prices) if holdings: await self.state_sync.on_holdings_update(holdings) stats = FrontendAdapter.build_stats(portfolio, prices) if stats: await self.state_sync.on_stats_update(stats) executed_trades = result.get("executed_trades", []) if executed_trades: await self.state_sync.on_trades_executed(executed_trades) def _save_cycle_results( self, result: Dict[str, Any], date: str, prices: Dict[str, float], settlement_result: Optional[Dict[str, Any]] = None, ): portfolio = result.get("portfolio", {}) executed_trades = result.get("executed_trades", []) # Extract baseline values from settlement result baseline_values = None if settlement_result: baseline_values = settlement_result.get("baseline_values") if portfolio: self.storage.update_dashboard_after_cycle( portfolio=portfolio, prices=prices, date=date, executed_trades=executed_trades, baseline_values=baseline_values, ) async def _run_backtest_dates(self, dates: List[str]): self.state_sync.set_backtest_dates(dates) self._dashboard.update(days_total=len(dates), days_completed=0) await self.state_sync.on_system_message( f"Starting backtest - {len(dates)} trading days", ) try: for i, date in enumerate(dates): self._dashboard.update(days_completed=i) await self.on_strategy_trigger(date=date) await asyncio.sleep(0.1) await self.state_sync.on_system_message( f"Backtest complete - {len(dates)} days", ) # Update dashboard with final state summary = self.storage.load_file("summary") or {} self._dashboard.update( status="Complete", portfolio=summary, days_completed=len(dates), ) self._dashboard.stop() self._dashboard.print_final_summary() except Exception as e: error_msg = f"Backtest failed: {type(e).__name__}: {str(e)}" logger.error(error_msg, exc_info=True) await self.state_sync.on_system_message(error_msg) self._dashboard.update(status=f"Failed: {str(e)}") self._dashboard.stop() raise finally: self._backtest_task = None def _handle_backtest_exception(self, task: asyncio.Task): """Handle exceptions from backtest task""" try: task.result() except asyncio.CancelledError: logger.info("Backtest task was cancelled") except Exception as e: logger.error( f"Backtest task failed with exception:{type(e).__name__}:{e}", exc_info=True, ) def set_backtest_dates(self, dates: List[str]): self.state_sync.set_backtest_dates(dates) if dates: self._backtest_start_date = dates[0] self._backtest_end_date = dates[-1] self._dashboard.days_total = len(dates) def stop(self): self.state_sync.save_state() self.market_service.stop() if self._backtest_task: self._backtest_task.cancel() if self._market_status_task: self._market_status_task.cancel() if self._watchlist_ingest_task: self._watchlist_ingest_task.cancel() self._dashboard.stop()