512 lines
17 KiB
Python
512 lines
17 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
StateSync - Centralized state synchronization between agents and frontend
|
|
Handles real-time updates, persistence, and replay support
|
|
"""
|
|
# pylint: disable=R0904
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
from ..services.storage import StorageService
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class StateSync:
    """
    Central event dispatcher for agent-frontend synchronization

    Responsibilities:
        1. Receive events from agents/pipeline
        2. Persist to storage (feed_history)
        3. Broadcast to frontend via WebSocket
        4. Support replay from saved state

    All ``on_*`` coroutines build an event dict and hand it to ``emit``,
    which stamps, optionally persists, and broadcasts it.
    """

    # Optional keys that on_cycle_end copies from the portfolio summary into
    # both the emitted event and the stored portfolio, only when truthy.
    _RETURN_KEYS = (
        "equity_return",
        "baseline_return",
        "baseline_vw_return",
        "momentum_return",
    )

    # Keys copied from the storage dashboard snapshot into the payload.
    _DASHBOARD_SECTIONS = (
        "summary",
        "holdings",
        "stats",
        "trades",
        "leaderboard",
    )

    def __init__(
        self,
        storage: "StorageService",
        broadcast_fn: Optional[Callable] = None,
    ):
        """
        Initialize StateSync

        Args:
            storage: Storage service for persistence
            broadcast_fn: Async broadcast function -
                ``async def broadcast(event: dict)``. May be bound later
                through ``set_broadcast_fn``.
        """
        self.storage = storage
        self._broadcast_fn = broadcast_fn
        self._state: Dict[str, Any] = {}
        self._enabled = True
        # For backtest timestamps; None means wall-clock time is used.
        self._simulation_date: Optional[str] = None

    def set_simulation_date(self, date: str):
        """Set current simulation date for backtest-compatible timestamps"""
        self._simulation_date = date

    def clear_simulation_date(self):
        """Disable backtest timestamp simulation and use wall-clock time."""
        self._simulation_date = None

    def _get_timestamp_ms(self) -> int:
        """
        Get timestamp in milliseconds.

        Uses simulation date if set (backtest mode), otherwise current time.

        Raises:
            ValueError: If the simulation date is not in YYYY-MM-DD form.
        """
        if self._simulation_date:
            # The date is parsed without a time component, so the result is
            # local midnight of the simulated day.
            # NOTE(review): an earlier comment here mentioned market close
            # (16:00); confirm whether midnight is the intended value.
            simulated = datetime.strptime(
                str(self._simulation_date),
                "%Y-%m-%d",
            )
            return int(simulated.timestamp() * 1000)
        return int(datetime.now().timestamp() * 1000)

    def _resolve_timestamp(self, ts_ms: Any) -> str:
        """
        Build the "timestamp" string for an event.

        Prefers the event's own millisecond value; when it is missing or
        unusable, falls back to the simulation date (backtest) and finally
        to the current wall-clock time.

        Args:
            ts_ms: Value of the event's "ts" field, or None when absent.

        Returns:
            ISO-8601 local datetime string, or the bare simulation date.
        """
        if ts_ms is not None:
            try:
                return datetime.fromtimestamp(
                    float(ts_ms) / 1000.0,
                ).isoformat()
            except (TypeError, ValueError, OverflowError, OSError):
                # OverflowError covers non-finite / out-of-range values,
                # which fromtimestamp() reports separately from ValueError.
                pass
        if self._simulation_date:
            return str(self._simulation_date)
        return datetime.now().isoformat()

    def _recent_feed_events(self) -> List[Dict[str, Any]]:
        """
        Fetch recent feed events for replay / initial payload.

        Asks the storage runtime DB first and falls back to the in-memory
        ``feed_history`` when that returns nothing.
        """
        return self.storage.runtime_db.get_recent_feed_events(
            limit=self.storage.max_feed_history,
        ) or self._state.get("feed_history", [])

    def load_state(self):
        """
        Load server state from storage

        The loaded state is also handed to
        ``storage.update_server_state_from_dashboard`` (presumably to merge
        dashboard data into it - confirm against StorageService).
        """
        self._state = self.storage.load_server_state()
        self.storage.update_server_state_from_dashboard(self._state)
        logger.info(
            "StateSync loaded: %d feeds",
            len(self._state.get("feed_history", [])),
        )

    def save_state(self):
        """Save current state to storage"""
        self.storage.save_server_state(self._state)

    @property
    def state(self) -> Dict[str, Any]:
        """Get current state (the live dict, not a copy)"""
        return self._state

    def set_broadcast_fn(self, fn: Optional[Callable]):
        """Set broadcast function (supports late binding)"""
        self._broadcast_fn = fn

    def update_state(self, key: str, value: Any):
        """Update a state field"""
        self._state[key] = value

    async def emit(self, event: Dict[str, Any], persist: bool = True):
        """
        Emit an event - persists and broadcasts

        The event dict is modified in place: a "timestamp" entry is added
        when the caller did not supply one.

        Args:
            event: Event dictionary, expected to contain "type"
            persist: Whether to persist to feed_history
        """
        if not self._enabled:
            return

        # Ensure timestamp exists. Prefer explicit millisecond timestamps so
        # frontend displays local wall time correctly instead of date-only UTC.
        if "timestamp" not in event:
            event["timestamp"] = self._resolve_timestamp(event.get("ts"))

        # Persist to feed_history
        if persist:
            self.storage.add_feed_message(self._state, event)
            self.save_state()

        # Broadcast to frontend
        if self._broadcast_fn:
            await self._broadcast_fn(event)

    # ========== Agent Events ==========

    async def on_agent_complete(
        self,
        agent_id: str,
        content: str,
        **extra,
    ):
        """
        Called when an agent finishes its reply

        Args:
            agent_id: Agent identifier (e.g., "fundamentals_analyst")
            content: Agent's output content
            **extra: Additional fields to include; they are merged last and
                therefore override the standard fields on a name clash
        """
        await self.emit(
            {
                "type": "agent_message",
                "agentId": agent_id,
                "content": content,
                "ts": self._get_timestamp_ms(),
                **extra,
            },
        )
        logger.info("Agent complete: %s", agent_id)

    async def on_memory_retrieved(
        self,
        agent_id: str,
        content: str,
    ):
        """
        Called when long-term memory is retrieved for an agent

        Args:
            agent_id: Agent identifier
            content: Retrieved memory content
        """
        await self.emit(
            {
                "type": "memory",
                "agentId": agent_id,
                "content": content,
                "ts": self._get_timestamp_ms(),
            },
        )
        logger.info("Memory retrieved for: %s", agent_id)

    # ========== Conference Events ==========

    async def on_conference_start(self, title: str, date: str):
        """Called when conference discussion starts"""
        await self.emit(
            {
                "type": "conference_start",
                "title": title,
                "date": date,
                "ts": self._get_timestamp_ms(),
            },
        )
        logger.info("Conference started: %s", title)

    async def on_conference_cycle_start(self, cycle: int, total_cycles: int):
        """Called when a conference cycle starts (broadcast only)"""
        await self.emit(
            {
                "type": "conference_cycle_start",
                "cycle": cycle,
                "totalCycles": total_cycles,
            },
            persist=False,
        )

    async def on_conference_message(self, agent_id: str, content: str):
        """Called when an agent speaks during conference"""
        await self.emit(
            {
                "type": "conference_message",
                "agentId": agent_id,
                "content": content,
                "ts": self._get_timestamp_ms(),
            },
        )

    async def on_conference_cycle_end(self, cycle: int):
        """Called when a conference cycle ends (broadcast only)"""
        await self.emit(
            {
                "type": "conference_cycle_end",
                "cycle": cycle,
            },
            persist=False,
        )

    async def on_conference_end(self):
        """Called when conference discussion ends"""
        await self.emit(
            {
                "type": "conference_end",
                "ts": self._get_timestamp_ms(),
            },
        )
        logger.info("Conference ended")

    # ========== Cycle Events ==========

    async def on_cycle_start(self, date: str):
        """
        Called at start of trading cycle

        Records the date, marks the run as "running", switches timestamping
        between simulated (backtest) and wall-clock mode, then emits
        "day_start".
        """
        self._state["current_date"] = date
        self._state["status"] = "running"

        if self._state.get("server_mode") == "backtest":
            # Set for backtest-compatible timestamps
            self.set_simulation_date(date)
        else:
            self.clear_simulation_date()

        await self.emit(
            {
                "type": "day_start",
                "date": date,
                "progress": self._calculate_progress(),
            },
        )

    async def on_cycle_end(
        self,
        date: str,
        portfolio_summary: Optional[Dict] = None,
    ):
        """
        Called at end of trading cycle

        Increments the completed-day counter, optionally emits a
        "team_summary" event (mirroring its values into the stored
        portfolio), emits "day_complete", and saves the state.

        Args:
            date: Date of the cycle that just finished
            portfolio_summary: Optional summary values; when empty or None
                no "team_summary" event is emitted
        """
        # Update completed count
        self._state["trading_days_completed"] = (
            self._state.get("trading_days_completed", 0) + 1
        )

        # Broadcast team_summary if available
        if portfolio_summary:
            summary_data = {
                "type": "team_summary",
                "balance": portfolio_summary.get(
                    "balance",
                    portfolio_summary.get("total_value", 0),
                ),
                "pnlPct": portfolio_summary.get(
                    "pnlPct",
                    portfolio_summary.get("pnl_percent", 0),
                ),
                "equity": portfolio_summary.get("equity", []),
                "baseline": portfolio_summary.get("baseline", []),
                "baseline_vw": portfolio_summary.get("baseline_vw", []),
                "momentum": portfolio_summary.get("momentum", []),
            }

            portfolio = self._state.setdefault("portfolio", {})
            portfolio.update(
                {
                    "total_value": summary_data["balance"],
                    "pnl_percent": summary_data["pnlPct"],
                    "equity": summary_data["equity"],
                    "baseline": summary_data["baseline"],
                    "baseline_vw": summary_data["baseline_vw"],
                    "momentum": summary_data["momentum"],
                },
            )

            # Include live returns if available. Falsy values are skipped,
            # so a previously stored return is left untouched in that case.
            for key in self._RETURN_KEYS:
                value = portfolio_summary.get(key)
                if value:
                    summary_data[key] = value
                    portfolio[key] = value

            await self.emit(summary_data, persist=True)

        await self.emit(
            {
                "type": "day_complete",
                "date": date,
                "progress": self._calculate_progress(),
            },
        )

        # Explicit save: emit() skips persistence when disabled.
        self.save_state()

    # ========== Portfolio Events ==========

    async def on_holdings_update(self, holdings: List[Dict]):
        """Called when holdings change"""
        self._state["holdings"] = holdings
        # Holdings change frequently, don't store all in feed_history
        await self.emit(
            {
                "type": "team_holdings",
                "data": holdings,
            },
            persist=False,
        )

    async def on_trades_executed(self, trades: List[Dict]):
        """
        Called when trades are executed

        New trades are placed ahead of the ones already in state; only the
        new trades are broadcast (mode "incremental").
        """
        # Update state with new trades
        existing = self._state.get("trades", [])
        self._state["trades"] = trades + existing

        await self.emit(
            {
                "type": "team_trades",
                "mode": "incremental",
                "data": trades,
            },
            persist=False,
        )

    async def on_stats_update(self, stats: Dict):
        """Called when stats are updated"""
        self._state["stats"] = stats
        await self.emit(
            {
                "type": "team_stats",
                "data": stats,
            },
            persist=False,
        )

    async def on_leaderboard_update(self, leaderboard: List[Dict]):
        """Called when leaderboard is updated"""
        self._state["leaderboard"] = leaderboard
        await self.emit(
            {
                "type": "team_leaderboard",
                "data": leaderboard,
            },
            persist=False,
        )

    # ========== System Events ==========

    async def on_system_message(self, content: str):
        """Emit a system message"""
        await self.emit(
            {
                "type": "system",
                "content": content,
            },
        )

    # ========== Replay Support ==========

    async def replay_feed_history(self, delay_ms: int = 100):
        """
        Replay events from feed_history

        Useful for: frontend reconnection or restoring from saved state

        Does nothing when no broadcast function is bound. The function
        bound when the replay starts is used for the whole replay.

        Args:
            delay_ms: Pause after each replayed event, in milliseconds
        """
        broadcast = self._broadcast_fn
        if not broadcast:
            # Without a sink there is nothing to replay to; avoid sleeping
            # once per event for no effect.
            return

        feed_history = self._recent_feed_events()

        # feed_history is newest-first, need to reverse for chronological
        # replay
        for event in reversed(feed_history):
            await broadcast(event)
            await asyncio.sleep(delay_ms / 1000)

        logger.info("Replayed %d events", len(feed_history))

    def get_initial_state_payload(
        self,
        include_dashboard: bool = True,
    ) -> Dict[str, Any]:
        """
        Build initial state payload for new client connections

        Args:
            include_dashboard: Whether to attach a "dashboard" section built
                by ``storage.build_dashboard_snapshot_from_state``

        Returns:
            Dictionary suitable for sending to frontend
        """
        state = self._state
        feed_history = self._recent_feed_events()
        last_day_history = self.storage.runtime_db.get_last_day_feed_events(
            current_date=state.get("current_date"),
            limit=self.storage.max_feed_history,
        ) or state.get("last_day_history", [])

        payload = {
            "server_mode": state.get("server_mode", "live"),
            "is_mock_mode": state.get("is_mock_mode", False),
            "is_backtest": state.get("is_backtest", False),
            "tickers": state.get("tickers"),
            "runtime_config": state.get("runtime_config"),
            "feed_history": feed_history,
            "last_day_history": last_day_history,
            "current_date": state.get("current_date"),
            "trading_days_total": state.get("trading_days_total", 0),
            "trading_days_completed": state.get(
                "trading_days_completed",
                0,
            ),
            "holdings": state.get("holdings", []),
            "trades": state.get("trades", []),
            "stats": state.get("stats", {}),
            "leaderboard": state.get("leaderboard", []),
            "portfolio": state.get("portfolio", {}),
            "realtime_prices": state.get("realtime_prices", {}),
            "data_sources": state.get("data_sources", {}),
            "price_history": state.get("price_history", {}),
        }

        if include_dashboard:
            snapshot = self.storage.build_dashboard_snapshot_from_state(
                state,
            )
            payload["dashboard"] = {
                section: snapshot.get(section)
                for section in self._DASHBOARD_SECTIONS
            }

        return payload

    def _calculate_progress(self) -> float:
        """
        Calculate backtest progress as a fraction

        Returns:
            completed / total trading days, or 0.0 when no total is set.
            The value is a ratio (0.5 means half done), not 0-100.
        """
        # "or 0" tolerates counters stored as None.
        total = self._state.get("trading_days_total", 0) or 0
        completed = self._state.get("trading_days_completed", 0) or 0
        return completed / total if total > 0 else 0.0

    def set_backtest_dates(self, dates: List[str]):
        """Set total trading days for backtest progress tracking"""
        self._state["trading_days_total"] = len(dates)
        self._state["trading_days_completed"] = 0
|