Initial commit of integrated agent system
This commit is contained in:
35
backend/core/__init__.py
Normal file
35
backend/core/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Core pipeline and orchestration logic.
|
||||
|
||||
Keep ``pipeline_runner`` behind lazy wrappers so importing ``backend.core`` does
|
||||
not immediately pull in the gateway runtime graph.
|
||||
"""
|
||||
|
||||
from .pipeline import TradingPipeline
|
||||
from .state_sync import StateSync
|
||||
|
||||
|
||||
def create_agents(*args, **kwargs):
|
||||
from .pipeline_runner import create_agents as _create_agents
|
||||
|
||||
return _create_agents(*args, **kwargs)
|
||||
|
||||
|
||||
def create_long_term_memory(*args, **kwargs):
|
||||
from .pipeline_runner import create_long_term_memory as _create_long_term_memory
|
||||
|
||||
return _create_long_term_memory(*args, **kwargs)
|
||||
|
||||
|
||||
def stop_gateway(*args, **kwargs):
|
||||
from .pipeline_runner import stop_gateway as _stop_gateway
|
||||
|
||||
return _stop_gateway(*args, **kwargs)
|
||||
|
||||
__all__ = [
|
||||
"TradingPipeline",
|
||||
"StateSync",
|
||||
"create_agents",
|
||||
"create_long_term_memory",
|
||||
"stop_gateway",
|
||||
]
|
||||
1684
backend/core/pipeline.py
Normal file
1684
backend/core/pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
481
backend/core/pipeline_runner.py
Normal file
481
backend/core/pipeline_runner.py
Normal file
@@ -0,0 +1,481 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Pipeline Runner - Independent trading pipeline execution
|
||||
|
||||
This module provides functions to start/stop trading pipelines
|
||||
that can be called from the REST API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.runtime.manager import (
|
||||
TradingRuntimeManager,
|
||||
set_global_runtime_manager,
|
||||
clear_global_runtime_manager,
|
||||
set_shutdown_event,
|
||||
clear_shutdown_event,
|
||||
is_shutdown_requested,
|
||||
)
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
# Global gateway reference for cleanup
|
||||
_gateway_instance: Optional[Gateway] = None
|
||||
|
||||
|
||||
def _set_gateway(gateway: Optional[Gateway]) -> None:
|
||||
"""Set global gateway reference."""
|
||||
global _gateway_instance
|
||||
_gateway_instance = gateway
|
||||
|
||||
|
||||
def stop_gateway() -> None:
|
||||
"""Stop the running gateway if exists."""
|
||||
global _gateway_instance
|
||||
if _gateway_instance is not None:
|
||||
try:
|
||||
_gateway_instance.stop()
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).error(f"Error stopping gateway: {e}")
|
||||
finally:
|
||||
_gateway_instance = None
|
||||
|
||||
|
||||
def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
|
||||
"""Create ReMeTaskLongTermMemory for an agent."""
|
||||
try:
|
||||
from agentscope.memory import ReMeTaskLongTermMemory
|
||||
from agentscope.model import DashScopeChatModel
|
||||
from agentscope.embedding import DashScopeTextEmbedding
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
api_key = os.getenv("MEMORY_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
memory_dir = str(run_dir / "memory")
|
||||
|
||||
return ReMeTaskLongTermMemory(
|
||||
agent_name=agent_name,
|
||||
user_name=agent_name,
|
||||
model=DashScopeChatModel(
|
||||
model_name=os.getenv("MEMORY_MODEL_NAME", "qwen3-max"),
|
||||
api_key=api_key,
|
||||
stream=False,
|
||||
),
|
||||
embedding_model=DashScopeTextEmbedding(
|
||||
model_name=os.getenv("MEMORY_EMBEDDING_MODEL", "text-embedding-v4"),
|
||||
api_key=api_key,
|
||||
dimensions=1024,
|
||||
),
|
||||
**{
|
||||
"vector_store.default.backend": "local",
|
||||
"vector_store.default.params.store_dir": memory_dir,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_agents(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
initial_cash: float,
|
||||
margin_requirement: float,
|
||||
enable_long_term_memory: bool = False,
|
||||
):
|
||||
"""Create all agents for the system."""
|
||||
analysts = []
|
||||
long_term_memories = []
|
||||
|
||||
# Initialize workspace manager and assets
|
||||
workspace_manager = WorkspaceManager()
|
||||
workspace_manager.initialize_default_assets(
|
||||
config_name=run_id,
|
||||
agent_ids=list(ANALYST_TYPES.keys()) + ["risk_manager", "portfolio_manager"],
|
||||
analyst_personas=_prompt_loader.load_yaml_config("analyst", "personas"),
|
||||
)
|
||||
|
||||
profiles = load_agent_profiles()
|
||||
skills_manager = SkillsManager()
|
||||
active_skill_map = skills_manager.prepare_active_skills(
|
||||
config_name=run_id,
|
||||
agent_defaults={
|
||||
agent_id: profile.get("skills", [])
|
||||
for agent_id, profile in profiles.items()
|
||||
},
|
||||
)
|
||||
|
||||
# Create analyst agents
|
||||
for analyst_type in ANALYST_TYPES:
|
||||
model = get_agent_model(analyst_type)
|
||||
formatter = get_agent_formatter(analyst_type)
|
||||
toolkit = create_agent_toolkit(
|
||||
analyst_type,
|
||||
run_id,
|
||||
active_skill_dirs=active_skill_map.get(analyst_type, []),
|
||||
)
|
||||
|
||||
long_term_memory = None
|
||||
if enable_long_term_memory:
|
||||
long_term_memory = create_long_term_memory(analyst_type, run_id, run_dir)
|
||||
if long_term_memory:
|
||||
long_term_memories.append(long_term_memory)
|
||||
|
||||
analyst = AnalystAgent(
|
||||
analyst_type=analyst_type,
|
||||
toolkit=toolkit,
|
||||
model=model,
|
||||
formatter=formatter,
|
||||
agent_id=analyst_type,
|
||||
config={"config_name": run_id},
|
||||
long_term_memory=long_term_memory,
|
||||
)
|
||||
analysts.append(analyst)
|
||||
|
||||
# Create risk manager
|
||||
risk_long_term_memory = None
|
||||
if enable_long_term_memory:
|
||||
risk_long_term_memory = create_long_term_memory("risk_manager", run_id, run_dir)
|
||||
if risk_long_term_memory:
|
||||
long_term_memories.append(risk_long_term_memory)
|
||||
|
||||
risk_manager = RiskAgent(
|
||||
model=get_agent_model("risk_manager"),
|
||||
formatter=get_agent_formatter("risk_manager"),
|
||||
name="risk_manager",
|
||||
config={"config_name": run_id},
|
||||
long_term_memory=risk_long_term_memory,
|
||||
toolkit=create_agent_toolkit(
|
||||
"risk_manager",
|
||||
run_id,
|
||||
active_skill_dirs=active_skill_map.get("risk_manager", []),
|
||||
),
|
||||
)
|
||||
|
||||
# Create portfolio manager
|
||||
pm_long_term_memory = None
|
||||
if enable_long_term_memory:
|
||||
pm_long_term_memory = create_long_term_memory("portfolio_manager", run_id, run_dir)
|
||||
if pm_long_term_memory:
|
||||
long_term_memories.append(pm_long_term_memory)
|
||||
|
||||
portfolio_manager = PMAgent(
|
||||
name="portfolio_manager",
|
||||
model=get_agent_model("portfolio_manager"),
|
||||
formatter=get_agent_formatter("portfolio_manager"),
|
||||
initial_cash=initial_cash,
|
||||
margin_requirement=margin_requirement,
|
||||
config={"config_name": run_id},
|
||||
long_term_memory=pm_long_term_memory,
|
||||
toolkit_factory=create_agent_toolkit,
|
||||
toolkit_factory_kwargs={
|
||||
"active_skill_dirs": active_skill_map.get("portfolio_manager", []),
|
||||
},
|
||||
)
|
||||
|
||||
return analysts, risk_manager, portfolio_manager, long_term_memories
|
||||
|
||||
|
||||
async def run_pipeline(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
bootstrap: Dict[str, Any],
|
||||
stop_event: asyncio.Event,
|
||||
message_callback: Optional[Callable[[str, Any], None]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the trading pipeline with the given configuration.
|
||||
|
||||
Service Startup Order:
|
||||
Phase 1: WebSocket Server - Frontend can connect
|
||||
Phase 2: Market Service - Price data starts flowing
|
||||
Phase 3: Agent Runtime - Create all agents
|
||||
Phase 4: Pipeline & Scheduler - Trading logic ready
|
||||
Phase 5: Gateway Fully Operational - All systems running
|
||||
|
||||
Args:
|
||||
run_id: Unique run identifier (timestamp)
|
||||
run_dir: Run directory path
|
||||
bootstrap: Bootstrap configuration
|
||||
stop_event: Event to signal pipeline stop
|
||||
message_callback: Optional callback for sending messages to clients
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set global shutdown event
|
||||
set_shutdown_event(stop_event)
|
||||
|
||||
logger.info(f"[Pipeline {run_id}] ======================================")
|
||||
logger.info(f"[Pipeline {run_id}] Starting with 5-phase initialization...")
|
||||
logger.info(f"[Pipeline {run_id}] ======================================")
|
||||
|
||||
try:
|
||||
# Extract config values
|
||||
tickers = bootstrap.get("tickers", ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA", "META", "TSLA", "AMD", "NFLX", "AVGO", "PLTR", "COIN"])
|
||||
initial_cash = float(bootstrap.get("initial_cash", 100000.0))
|
||||
margin_requirement = float(bootstrap.get("margin_requirement", 0.0))
|
||||
max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2))
|
||||
schedule_mode = bootstrap.get("schedule_mode", "daily")
|
||||
trigger_time = bootstrap.get("trigger_time", "09:30")
|
||||
interval_minutes = int(bootstrap.get("interval_minutes", 60))
|
||||
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0))
|
||||
mode = bootstrap.get("mode", "live")
|
||||
start_date = bootstrap.get("start_date")
|
||||
end_date = bootstrap.get("end_date")
|
||||
enable_memory = bootstrap.get("enable_memory", False)
|
||||
|
||||
is_backtest = mode == "backtest"
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 0: Initialize runtime manager
|
||||
# ======================================================================
|
||||
logger.info("[Phase 0/5] Initializing runtime manager...")
|
||||
|
||||
from backend.api.runtime import runtime_manager
|
||||
|
||||
if runtime_manager is None:
|
||||
runtime_manager = TradingRuntimeManager(
|
||||
config_name=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
runtime_manager.prepare_run()
|
||||
|
||||
set_global_runtime_manager(runtime_manager)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
|
||||
# These will be started by Gateway in the correct order
|
||||
# ======================================================================
|
||||
logger.info("[Phase 1-2/5] Creating infrastructure services...")
|
||||
|
||||
# Create storage service
|
||||
storage_service = StorageService(
|
||||
dashboard_dir=run_dir / "team_dashboard",
|
||||
initial_cash=initial_cash,
|
||||
config_name=run_id,
|
||||
)
|
||||
|
||||
if not storage_service.files["summary"].exists():
|
||||
storage_service.initialize_empty_dashboard()
|
||||
else:
|
||||
storage_service.update_leaderboard_model_info()
|
||||
|
||||
# Create market service (data source)
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=10,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
|
||||
backtest_start_date=start_date if is_backtest else None,
|
||||
backtest_end_date=end_date if is_backtest else None,
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 3: Create Agent Runtime
|
||||
# ======================================================================
|
||||
logger.info("[Phase 3/5] Creating agent runtime...")
|
||||
|
||||
analysts, risk_manager, pm, long_term_memories = create_agents(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
initial_cash=initial_cash,
|
||||
margin_requirement=margin_requirement,
|
||||
enable_long_term_memory=enable_memory,
|
||||
)
|
||||
|
||||
# Register agents with runtime manager
|
||||
for agent in analysts + [risk_manager, pm]:
|
||||
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||
if agent_id:
|
||||
runtime_manager.register_agent(agent_id)
|
||||
|
||||
# Load portfolio state
|
||||
portfolio_state = storage_service.load_portfolio_state()
|
||||
pm.load_portfolio_state(portfolio_state)
|
||||
|
||||
# Create settlement coordinator
|
||||
settlement_coordinator = SettlementCoordinator(
|
||||
storage=storage_service,
|
||||
initial_capital=initial_cash,
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 4: Create Pipeline & Scheduler
|
||||
# ======================================================================
|
||||
logger.info("[Phase 4/5] Creating pipeline and scheduler...")
|
||||
|
||||
# Create pipeline
|
||||
pipeline = TradingPipeline(
|
||||
analysts=analysts,
|
||||
risk_manager=risk_manager,
|
||||
portfolio_manager=pm,
|
||||
settlement_coordinator=settlement_coordinator,
|
||||
max_comm_cycles=max_comm_cycles,
|
||||
runtime_manager=runtime_manager,
|
||||
)
|
||||
|
||||
# Create scheduler
|
||||
scheduler_callback = None
|
||||
live_scheduler = None
|
||||
|
||||
if is_backtest:
|
||||
backtest_scheduler = BacktestScheduler(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
trading_calendar="NYSE",
|
||||
delay_between_days=0.5,
|
||||
)
|
||||
trading_dates = backtest_scheduler.get_trading_dates()
|
||||
|
||||
async def scheduler_callback_fn(callback):
|
||||
await backtest_scheduler.start(callback)
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
else:
|
||||
# Live mode
|
||||
live_scheduler = Scheduler(
|
||||
mode=schedule_mode,
|
||||
trigger_time=trigger_time,
|
||||
interval_minutes=interval_minutes,
|
||||
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
|
||||
config={"config_name": run_id},
|
||||
)
|
||||
|
||||
async def scheduler_callback_fn(callback):
|
||||
await live_scheduler.start(callback)
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 5: Start Gateway (WebSocket → Market → Scheduler)
|
||||
# Gateway.start() will handle the final startup sequence:
|
||||
# - WebSocket Server first (frontend can connect)
|
||||
# - Market Service second (price data flows)
|
||||
# - Scheduler last (trading begins)
|
||||
# ======================================================================
|
||||
logger.info("[Phase 5/5] Starting Gateway (WebSocket → Market → Scheduler)...")
|
||||
|
||||
gateway = Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage_service,
|
||||
pipeline=pipeline,
|
||||
scheduler_callback=scheduler_callback,
|
||||
config={
|
||||
"mode": mode,
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": run_id,
|
||||
"schedule_mode": schedule_mode,
|
||||
"interval_minutes": interval_minutes,
|
||||
"trigger_time": trigger_time,
|
||||
"heartbeat_interval": heartbeat_interval,
|
||||
"initial_cash": initial_cash,
|
||||
"margin_requirement": margin_requirement,
|
||||
"max_comm_cycles": max_comm_cycles,
|
||||
"enable_memory": enable_memory,
|
||||
},
|
||||
scheduler=live_scheduler,
|
||||
)
|
||||
_set_gateway(gateway)
|
||||
|
||||
# Start pipeline execution
|
||||
async with AsyncExitStack() as stack:
|
||||
# Enter long-term memory contexts
|
||||
for memory in long_term_memories:
|
||||
await stack.enter_async_context(memory)
|
||||
|
||||
# Start Gateway - this will execute the 4-phase startup:
|
||||
# Phase 1: WebSocket Server (frontend can connect immediately)
|
||||
# Phase 2: Market Service (price updates start flowing)
|
||||
# Phase 3: Market Status Monitor
|
||||
# Phase 4: Scheduler (trading cycles begin)
|
||||
gateway_task = asyncio.create_task(
|
||||
gateway.start(host="0.0.0.0", port=8765)
|
||||
)
|
||||
logger.info("[Pipeline] Gateway startup initiated on ws://localhost:8765")
|
||||
|
||||
# Wait for Gateway to fully initialize all phases
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Define the trading cycle callback
|
||||
async def trading_cycle(session_key: str) -> None:
|
||||
"""Execute one trading cycle."""
|
||||
if is_shutdown_requested():
|
||||
return
|
||||
|
||||
runtime_manager.set_session_key(session_key)
|
||||
runtime_manager.log_event("cycle:start", {"session": session_key})
|
||||
|
||||
try:
|
||||
# Fetch market data
|
||||
market_data = await market_service.get_all_data()
|
||||
|
||||
# Run pipeline
|
||||
await pipeline.run_cycle(
|
||||
session_key=session_key,
|
||||
market_data=market_data,
|
||||
)
|
||||
|
||||
runtime_manager.log_event("cycle:complete", {"session": session_key})
|
||||
|
||||
except Exception as e:
|
||||
runtime_manager.log_event("cycle:error", {"error": str(e)})
|
||||
raise
|
||||
|
||||
# Start scheduler
|
||||
if scheduler_callback:
|
||||
await scheduler_callback(trading_cycle)
|
||||
|
||||
# Wait for stop signal
|
||||
while not stop_event.is_set():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Cancel gateway task
|
||||
if not gateway_task.done():
|
||||
gateway_task.cancel()
|
||||
try:
|
||||
await gateway_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation gracefully
|
||||
raise
|
||||
finally:
|
||||
# Cleanup
|
||||
logger.info("[Pipeline] Cleaning up...")
|
||||
|
||||
# Stop Gateway
|
||||
try:
|
||||
stop_gateway()
|
||||
logger.info("[Pipeline] Gateway stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"[Pipeline] Error stopping gateway: {e}")
|
||||
|
||||
clear_shutdown_event()
|
||||
clear_global_runtime_manager()
|
||||
from backend.api.runtime import unregister_runtime_manager
|
||||
unregister_runtime_manager()
|
||||
logger.info("[Pipeline] Cleanup complete")
|
||||
362
backend/core/scheduler.py
Normal file
362
backend/core/scheduler.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Scheduler - Market-aware trigger system for trading cycles
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Any, Callable, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NYSE timezone for US stock trading
|
||||
NYSE_TZ = ZoneInfo("America/New_York")
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""
|
||||
Market-aware scheduler for live trading.
|
||||
Uses NYSE timezone and trading calendar.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mode: str = "daily",
|
||||
trigger_time: Optional[str] = None,
|
||||
interval_minutes: Optional[int] = None,
|
||||
heartbeat_interval: Optional[int] = None,
|
||||
config: Optional[dict] = None,
|
||||
):
|
||||
self.mode = mode
|
||||
self.trigger_time = trigger_time or "09:30" # NYSE timezone
|
||||
self.trigger_now = self.trigger_time == "now"
|
||||
self.interval_minutes = interval_minutes or 60
|
||||
self.heartbeat_interval = heartbeat_interval # e.g. 3600 = 1 hour
|
||||
self.config = config or {}
|
||||
|
||||
self.running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._callback: Optional[Callable] = None
|
||||
self._heartbeat_callback: Optional[Callable] = None
|
||||
|
||||
def _now_nyse(self) -> datetime:
|
||||
"""Get current time in NYSE timezone"""
|
||||
return datetime.now(NYSE_TZ)
|
||||
|
||||
def _is_trading_day(self, date: datetime) -> bool:
|
||||
"""Check if date is a NYSE trading day"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
valid_days = NYSE_CALENDAR.valid_days(
|
||||
start_date=date_str,
|
||||
end_date=date_str,
|
||||
)
|
||||
return len(valid_days) > 0
|
||||
|
||||
def _is_trading_hours(self, now: datetime) -> bool:
|
||||
"""Check if current time is within NYSE trading hours (9:30-16:00 ET)."""
|
||||
market_time = now.time()
|
||||
return time(9, 30) <= market_time <= time(16, 0)
|
||||
|
||||
def set_heartbeat_callback(self, callback: Callable) -> None:
|
||||
"""Register callback for heartbeat triggers."""
|
||||
self._heartbeat_callback = callback
|
||||
|
||||
def _next_trading_day(self, from_date: datetime) -> datetime:
|
||||
"""Find the next trading day from given date"""
|
||||
check_date = from_date
|
||||
for _ in range(10): # Max 10 days ahead (handles holidays)
|
||||
if self._is_trading_day(check_date):
|
||||
return check_date
|
||||
check_date += timedelta(days=1)
|
||||
return check_date
|
||||
|
||||
async def start(self, callback: Callable):
|
||||
"""Start scheduler"""
|
||||
if self.running:
|
||||
logger.warning("Scheduler already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._callback = callback
|
||||
self._schedule_task()
|
||||
|
||||
# Start heartbeat loop if configured
|
||||
if self.heartbeat_interval and self._heartbeat_callback:
|
||||
self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop())
|
||||
logger.info(
|
||||
f"Heartbeat loop started: interval={self.heartbeat_interval}s",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduler started: mode={self.mode}, timezone=America/New_York",
|
||||
)
|
||||
|
||||
def _schedule_task(self):
|
||||
"""Create the active scheduler task for the current mode."""
|
||||
if not self._callback:
|
||||
raise ValueError("Scheduler callback is not set")
|
||||
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
|
||||
if self.mode == "daily":
|
||||
self._task = asyncio.create_task(self._run_daily(self._callback))
|
||||
elif self.mode == "intraday":
|
||||
self._task = asyncio.create_task(
|
||||
self._run_intraday(self._callback),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown scheduler mode: {self.mode}")
|
||||
|
||||
def reconfigure(
|
||||
self,
|
||||
*,
|
||||
mode: Optional[str] = None,
|
||||
trigger_time: Optional[str] = None,
|
||||
interval_minutes: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Update scheduler parameters in-place and restart its timing loop."""
|
||||
changed = False
|
||||
|
||||
if mode and mode != self.mode:
|
||||
self.mode = mode
|
||||
changed = True
|
||||
|
||||
if trigger_time and trigger_time != self.trigger_time:
|
||||
self.trigger_time = trigger_time
|
||||
self.trigger_now = self.trigger_time == "now"
|
||||
changed = True
|
||||
|
||||
if (
|
||||
interval_minutes is not None
|
||||
and interval_minutes > 0
|
||||
and interval_minutes != self.interval_minutes
|
||||
):
|
||||
self.interval_minutes = interval_minutes
|
||||
changed = True
|
||||
|
||||
if changed and self.running and self._callback:
|
||||
self._schedule_task()
|
||||
logger.info(
|
||||
"Scheduler reconfigured: mode=%s, trigger_time=%s, interval_minutes=%s",
|
||||
self.mode,
|
||||
self.trigger_time,
|
||||
self.interval_minutes,
|
||||
)
|
||||
|
||||
return changed
|
||||
|
||||
async def _run_heartbeat_loop(self):
|
||||
"""Run heartbeat checks on a separate interval during trading hours."""
|
||||
while self.running:
|
||||
now = self._now_nyse()
|
||||
if self._is_trading_day(now) and self._is_trading_hours(now):
|
||||
if self._heartbeat_callback:
|
||||
try:
|
||||
current_date = now.strftime("%Y-%m-%d")
|
||||
logger.debug(
|
||||
f"[Heartbeat] Triggering heartbeat check for {current_date}",
|
||||
)
|
||||
await self._heartbeat_callback(date=current_date)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Heartbeat] Callback failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[Heartbeat] Callback not set, skipping heartbeat",
|
||||
)
|
||||
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
|
||||
async def _run_daily(self, callback: Callable):
|
||||
"""Run once per trading day at specified time (NYSE timezone)"""
|
||||
first_run = True
|
||||
|
||||
while self.running:
|
||||
now = self._now_nyse()
|
||||
|
||||
# Handle "now" trigger - run immediately on first iteration
|
||||
if self.trigger_now and first_run:
|
||||
first_run = False
|
||||
current_date = now.strftime("%Y-%m-%d")
|
||||
logger.info(f"Triggering immediately for {current_date}")
|
||||
await callback(date=current_date)
|
||||
# After immediate run, stop (one-shot mode)
|
||||
self.running = False
|
||||
break
|
||||
|
||||
target_time = datetime.strptime(self.trigger_time, "%H:%M").time()
|
||||
|
||||
# Calculate next trigger datetime
|
||||
if now.time() < target_time:
|
||||
next_run = now.replace(
|
||||
hour=target_time.hour,
|
||||
minute=target_time.minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
else:
|
||||
next_run = (now + timedelta(days=1)).replace(
|
||||
hour=target_time.hour,
|
||||
minute=target_time.minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
|
||||
# Skip to next trading day
|
||||
next_run = self._next_trading_day(next_run)
|
||||
next_run = next_run.replace(
|
||||
hour=target_time.hour,
|
||||
minute=target_time.minute,
|
||||
second=0,
|
||||
microsecond=0,
|
||||
)
|
||||
|
||||
wait_seconds = (next_run - now).total_seconds()
|
||||
logger.info(
|
||||
f"Next trigger: {next_run.strftime('%Y-%m-%d %H:%M %Z')} "
|
||||
f"(in {wait_seconds/3600:.1f} hours)",
|
||||
)
|
||||
|
||||
await asyncio.sleep(wait_seconds)
|
||||
|
||||
current_date = self._now_nyse().strftime("%Y-%m-%d")
|
||||
logger.info(f"Triggering daily cycle for {current_date}")
|
||||
await callback(date=current_date)
|
||||
|
||||
async def _run_intraday(self, callback: Callable):
|
||||
"""Run every N minutes (for future use)"""
|
||||
while self.running:
|
||||
now = self._now_nyse()
|
||||
current_date = now.strftime("%Y-%m-%d")
|
||||
|
||||
if self._is_trading_day(now):
|
||||
logger.info(f"Triggering intraday cycle for {current_date}")
|
||||
await callback(date=current_date)
|
||||
|
||||
await asyncio.sleep(self.interval_minutes * 60)
|
||||
|
||||
def stop(self):
|
||||
"""Stop scheduler"""
|
||||
self.running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
logger.info("Scheduler stopped")
|
||||
|
||||
|
||||
class BacktestScheduler:
|
||||
"""Backtest Scheduler - Runs through historical trading dates"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
trading_calendar: Optional[Any] = None,
|
||||
delay_between_days: float = 0.1,
|
||||
):
|
||||
self.start_date = start_date
|
||||
self.end_date = end_date
|
||||
self.trading_calendar = trading_calendar
|
||||
self.delay_between_days = delay_between_days
|
||||
|
||||
self.running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._dates: list = []
|
||||
|
||||
def get_trading_dates(self) -> list:
|
||||
"""Get list of trading dates in the backtest period"""
|
||||
import pandas as pd
|
||||
|
||||
start = pd.to_datetime(self.start_date)
|
||||
end = pd.to_datetime(self.end_date)
|
||||
|
||||
if self.trading_calendar:
|
||||
calendar = mcal.get_calendar(self.trading_calendar)
|
||||
trading_dates = calendar.valid_days(
|
||||
start_date=self.start_date,
|
||||
end_date=self.end_date,
|
||||
)
|
||||
dates = [d.strftime("%Y-%m-%d") for d in trading_dates]
|
||||
else:
|
||||
all_dates = pd.date_range(start=start, end=end, freq="D")
|
||||
dates = [
|
||||
d.strftime("%Y-%m-%d") for d in all_dates if d.weekday() < 5
|
||||
]
|
||||
|
||||
self._dates = dates
|
||||
return dates
|
||||
|
||||
async def start(self, callback: Callable):
|
||||
"""Start async backtest scheduler"""
|
||||
if self.running:
|
||||
logger.warning("Backtest scheduler already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
dates = self.get_trading_dates()
|
||||
|
||||
logger.info(
|
||||
f"Starting backtest: {self.start_date} to {self.end_date} "
|
||||
f"({len(dates)} trading days)",
|
||||
)
|
||||
|
||||
self._task = asyncio.create_task(self._run_async(callback, dates))
|
||||
|
||||
async def _run_async(self, callback: Callable, dates: list):
|
||||
"""Run backtest asynchronously"""
|
||||
for i, date in enumerate(dates, 1):
|
||||
if not self.running:
|
||||
break
|
||||
|
||||
logger.info(f"[{i}/{len(dates)}] Processing {date}")
|
||||
await callback(date=date)
|
||||
|
||||
if self.delay_between_days > 0:
|
||||
await asyncio.sleep(self.delay_between_days)
|
||||
|
||||
logger.info("Backtest complete")
|
||||
self.running = False
|
||||
|
||||
def run(self, callback: Callable, **kwargs):
|
||||
"""Run backtest synchronously through all trading dates"""
|
||||
dates = self.get_trading_dates()
|
||||
results = []
|
||||
|
||||
logger.info(
|
||||
f"Starting backtest: {self.start_date} to {self.end_date} "
|
||||
f"({len(dates)} trading days)",
|
||||
)
|
||||
|
||||
for i, date in enumerate(dates, 1):
|
||||
logger.info(f"[{i}/{len(dates)}] Processing {date}")
|
||||
result = callback(date=date, **kwargs)
|
||||
results.append({"date": date, "result": result})
|
||||
|
||||
logger.info("Backtest complete")
|
||||
return results
|
||||
|
||||
def stop(self):
|
||||
"""Stop backtest scheduler"""
|
||||
self.running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
logger.info("Backtest scheduler stopped")
|
||||
|
||||
def get_total_days(self) -> int:
|
||||
"""Get total number of trading days"""
|
||||
if not self._dates:
|
||||
self.get_trading_dates()
|
||||
return len(self._dates)
|
||||
510
backend/core/state_sync.py
Normal file
510
backend/core/state_sync.py
Normal file
@@ -0,0 +1,510 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
StateSync - Centralized state synchronization between agents and frontend
|
||||
Handles real-time updates, persistence, and replay support
|
||||
"""
|
||||
# pylint: disable=R0904
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from ..services.storage import StorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateSync:
|
||||
"""
|
||||
Central event dispatcher for agent-frontend synchronization
|
||||
|
||||
Responsibilities:
|
||||
1. Receive events from agents/pipeline
|
||||
2. Persist to storage (feed_history)
|
||||
3. Broadcast to frontend via WebSocket
|
||||
4. Support replay from saved state
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: StorageService,
|
||||
broadcast_fn: Optional[Callable] = None,
|
||||
):
|
||||
"""
|
||||
Initialize StateSync
|
||||
|
||||
Args:
|
||||
storage: Storage service for persistence
|
||||
broadcast_fn: Async broadcast function - async def broadcast(event: dict) # noqa: E501
|
||||
"""
|
||||
self.storage = storage
|
||||
self._broadcast_fn = broadcast_fn
|
||||
self._state: Dict[str, Any] = {}
|
||||
self._enabled = True
|
||||
self._simulation_date: Optional[str] = None # For backtest timestamps
|
||||
|
||||
def set_simulation_date(self, date: str):
|
||||
"""Set current simulation date for backtest-compatible timestamps"""
|
||||
self._simulation_date = date
|
||||
|
||||
def clear_simulation_date(self):
|
||||
"""Disable backtest timestamp simulation and use wall-clock time."""
|
||||
self._simulation_date = None
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""
|
||||
Get timestamp in milliseconds.
|
||||
Uses simulation date if set (backtest mode), otherwise current time.
|
||||
"""
|
||||
if self._simulation_date:
|
||||
# Parse date and use market close time (16:00) for backtest
|
||||
dt = datetime.strptime(
|
||||
f"{self._simulation_date}",
|
||||
"%Y-%m-%d",
|
||||
)
|
||||
return int(dt.timestamp() * 1000)
|
||||
return int(datetime.now().timestamp() * 1000)
|
||||
|
||||
def load_state(self):
|
||||
"""Load server state from storage"""
|
||||
self._state = self.storage.load_server_state()
|
||||
self.storage.update_server_state_from_dashboard(self._state)
|
||||
logger.info(
|
||||
f"StateSync loaded: {len(self._state.get('feed_history', []))} feeds", # noqa: E501
|
||||
)
|
||||
|
||||
def save_state(self):
|
||||
"""Save current state to storage"""
|
||||
self.storage.save_server_state(self._state)
|
||||
|
||||
@property
|
||||
def state(self) -> Dict[str, Any]:
|
||||
"""Get current state"""
|
||||
return self._state
|
||||
|
||||
def set_broadcast_fn(self, fn: Callable):
|
||||
"""Set broadcast function (supports late binding)"""
|
||||
self._broadcast_fn = fn
|
||||
|
||||
def update_state(self, key: str, value: Any):
|
||||
"""Update a state field"""
|
||||
self._state[key] = value
|
||||
|
||||
async def emit(self, event: Dict[str, Any], persist: bool = True):
|
||||
"""
|
||||
Emit an event - persists and broadcasts
|
||||
|
||||
Args:
|
||||
event: Event dictionary, must contain "type"
|
||||
persist: Whether to persist to feed_history
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
# Ensure timestamp exists. Prefer explicit millisecond timestamps so
|
||||
# frontend displays local wall time correctly instead of date-only UTC.
|
||||
if "timestamp" not in event:
|
||||
ts_ms = event.get("ts")
|
||||
if ts_ms is not None:
|
||||
try:
|
||||
event["timestamp"] = datetime.fromtimestamp(
|
||||
float(ts_ms) / 1000.0,
|
||||
).isoformat()
|
||||
except (TypeError, ValueError, OSError):
|
||||
if self._simulation_date:
|
||||
event["timestamp"] = f"{self._simulation_date}"
|
||||
else:
|
||||
event["timestamp"] = datetime.now().isoformat()
|
||||
elif self._simulation_date:
|
||||
event["timestamp"] = f"{self._simulation_date}"
|
||||
else:
|
||||
event["timestamp"] = datetime.now().isoformat()
|
||||
|
||||
# Persist to feed_history
|
||||
if persist:
|
||||
self.storage.add_feed_message(self._state, event)
|
||||
self.save_state()
|
||||
|
||||
# Broadcast to frontend
|
||||
if self._broadcast_fn:
|
||||
await self._broadcast_fn(event)
|
||||
|
||||
# ========== Agent Events ==========
|
||||
|
||||
async def on_agent_complete(
|
||||
self,
|
||||
agent_id: str,
|
||||
content: str,
|
||||
**extra,
|
||||
):
|
||||
"""
|
||||
Called when an agent finishes its reply
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier (e.g., "fundamentals_analyst")
|
||||
content: Agent's output content
|
||||
**extra: Additional fields to include
|
||||
"""
|
||||
ts_ms = self._get_timestamp_ms()
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
"type": "agent_message",
|
||||
"agentId": agent_id,
|
||||
"content": content,
|
||||
"ts": ts_ms,
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Agent complete: {agent_id}")
|
||||
|
||||
async def on_memory_retrieved(
|
||||
self,
|
||||
agent_id: str,
|
||||
content: str,
|
||||
):
|
||||
"""
|
||||
Called when long-term memory is retrieved for an agent
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
content: Retrieved memory content
|
||||
"""
|
||||
ts_ms = self._get_timestamp_ms()
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
"type": "memory",
|
||||
"agentId": agent_id,
|
||||
"content": content,
|
||||
"ts": ts_ms,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Memory retrieved for: {agent_id}")
|
||||
|
||||
# ========== Conference Events ==========
|
||||
|
||||
async def on_conference_start(self, title: str, date: str):
|
||||
"""Called when conference discussion starts"""
|
||||
ts_ms = self._get_timestamp_ms()
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
"type": "conference_start",
|
||||
"title": title,
|
||||
"date": date,
|
||||
"ts": ts_ms,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Conference started: {title}")
|
||||
|
||||
async def on_conference_cycle_start(self, cycle: int, total_cycles: int):
|
||||
"""Called when a conference cycle starts"""
|
||||
await self.emit(
|
||||
{
|
||||
"type": "conference_cycle_start",
|
||||
"cycle": cycle,
|
||||
"totalCycles": total_cycles,
|
||||
},
|
||||
persist=False,
|
||||
)
|
||||
|
||||
async def on_conference_message(self, agent_id: str, content: str):
|
||||
"""Called when an agent speaks during conference"""
|
||||
ts_ms = self._get_timestamp_ms()
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
"type": "conference_message",
|
||||
"agentId": agent_id,
|
||||
"content": content,
|
||||
"ts": ts_ms,
|
||||
},
|
||||
)
|
||||
|
||||
async def on_conference_cycle_end(self, cycle: int):
|
||||
"""Called when a conference cycle ends"""
|
||||
await self.emit(
|
||||
{
|
||||
"type": "conference_cycle_end",
|
||||
"cycle": cycle,
|
||||
},
|
||||
persist=False,
|
||||
)
|
||||
|
||||
async def on_conference_end(self):
|
||||
"""Called when conference discussion ends"""
|
||||
ts_ms = self._get_timestamp_ms()
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
"type": "conference_end",
|
||||
"ts": ts_ms,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("Conference ended")
|
||||
|
||||
# ========== Cycle Events ==========
|
||||
|
||||
async def on_cycle_start(self, date: str):
|
||||
"""Called at start of trading cycle"""
|
||||
self._state["current_date"] = date
|
||||
self._state["status"] = "running"
|
||||
if self._state.get("server_mode") == "backtest":
|
||||
self.set_simulation_date(
|
||||
date,
|
||||
) # Set for backtest-compatible timestamps
|
||||
else:
|
||||
self.clear_simulation_date()
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
"type": "day_start",
|
||||
"date": date,
|
||||
"progress": self._calculate_progress(),
|
||||
},
|
||||
)
|
||||
# await self.emit(
|
||||
# {
|
||||
# "type": "system",
|
||||
# "content": f"Starting trading analysis for {date}",
|
||||
# },
|
||||
# )
|
||||
|
||||
async def on_cycle_end(self, date: str, portfolio_summary: Dict = None):
|
||||
"""Called at end of trading cycle"""
|
||||
# Update completed count
|
||||
self._state["trading_days_completed"] = (
|
||||
self._state.get("trading_days_completed", 0) + 1
|
||||
)
|
||||
|
||||
# Broadcast team_summary if available
|
||||
if portfolio_summary:
|
||||
summary_data = {
|
||||
"type": "team_summary",
|
||||
"balance": portfolio_summary.get(
|
||||
"balance",
|
||||
portfolio_summary.get("total_value", 0),
|
||||
),
|
||||
"pnlPct": portfolio_summary.get(
|
||||
"pnlPct",
|
||||
portfolio_summary.get("pnl_percent", 0),
|
||||
),
|
||||
"equity": portfolio_summary.get("equity", []),
|
||||
"baseline": portfolio_summary.get("baseline", []),
|
||||
"baseline_vw": portfolio_summary.get("baseline_vw", []),
|
||||
"momentum": portfolio_summary.get("momentum", []),
|
||||
}
|
||||
|
||||
# Include live returns if available
|
||||
if portfolio_summary.get("equity_return"):
|
||||
summary_data["equity_return"] = portfolio_summary[
|
||||
"equity_return"
|
||||
]
|
||||
if portfolio_summary.get("baseline_return"):
|
||||
summary_data["baseline_return"] = portfolio_summary[
|
||||
"baseline_return"
|
||||
]
|
||||
if portfolio_summary.get("baseline_vw_return"):
|
||||
summary_data["baseline_vw_return"] = portfolio_summary[
|
||||
"baseline_vw_return"
|
||||
]
|
||||
if portfolio_summary.get("momentum_return"):
|
||||
summary_data["momentum_return"] = portfolio_summary[
|
||||
"momentum_return"
|
||||
]
|
||||
|
||||
if "portfolio" not in self._state:
|
||||
self._state["portfolio"] = {}
|
||||
|
||||
self._state["portfolio"].update(
|
||||
{
|
||||
"total_value": summary_data["balance"],
|
||||
"pnl_percent": summary_data["pnlPct"],
|
||||
"equity": summary_data["equity"],
|
||||
"baseline": summary_data["baseline"],
|
||||
"baseline_vw": summary_data["baseline_vw"],
|
||||
"momentum": summary_data["momentum"],
|
||||
},
|
||||
)
|
||||
|
||||
if summary_data.get("equity_return"):
|
||||
self._state["portfolio"]["equity_return"] = summary_data[
|
||||
"equity_return"
|
||||
]
|
||||
if summary_data.get("baseline_return"):
|
||||
self._state["portfolio"]["baseline_return"] = summary_data[
|
||||
"baseline_return"
|
||||
]
|
||||
if summary_data.get("baseline_vw_return"):
|
||||
self._state["portfolio"]["baseline_vw_return"] = summary_data[
|
||||
"baseline_vw_return"
|
||||
]
|
||||
if summary_data.get("momentum_return"):
|
||||
self._state["portfolio"]["momentum_return"] = summary_data[
|
||||
"momentum_return"
|
||||
]
|
||||
|
||||
await self.emit(summary_data, persist=True)
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
"type": "day_complete",
|
||||
"date": date,
|
||||
"progress": self._calculate_progress(),
|
||||
},
|
||||
)
|
||||
|
||||
self.save_state()
|
||||
|
||||
# ========== Portfolio Events ==========
|
||||
|
||||
async def on_holdings_update(self, holdings: List[Dict]):
|
||||
"""Called when holdings change"""
|
||||
self._state["holdings"] = holdings
|
||||
await self.emit(
|
||||
{
|
||||
"type": "team_holdings",
|
||||
"data": holdings,
|
||||
},
|
||||
persist=False,
|
||||
) # Holdings change frequently, don't store all in feed_history
|
||||
|
||||
async def on_trades_executed(self, trades: List[Dict]):
|
||||
"""Called when trades are executed"""
|
||||
# Update state with new trades
|
||||
existing = self._state.get("trades", [])
|
||||
self._state["trades"] = trades + existing
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
"type": "team_trades",
|
||||
"mode": "incremental",
|
||||
"data": trades,
|
||||
},
|
||||
persist=False,
|
||||
)
|
||||
|
||||
async def on_stats_update(self, stats: Dict):
|
||||
"""Called when stats are updated"""
|
||||
self._state["stats"] = stats
|
||||
await self.emit(
|
||||
{
|
||||
"type": "team_stats",
|
||||
"data": stats,
|
||||
},
|
||||
persist=False,
|
||||
)
|
||||
|
||||
async def on_leaderboard_update(self, leaderboard: List[Dict]):
|
||||
"""Called when leaderboard is updated"""
|
||||
self._state["leaderboard"] = leaderboard
|
||||
await self.emit(
|
||||
{
|
||||
"type": "team_leaderboard",
|
||||
"data": leaderboard,
|
||||
},
|
||||
persist=False,
|
||||
)
|
||||
|
||||
# ========== System Events ==========
|
||||
|
||||
async def on_system_message(self, content: str):
|
||||
"""Emit a system message"""
|
||||
await self.emit(
|
||||
{
|
||||
"type": "system",
|
||||
"content": content,
|
||||
},
|
||||
)
|
||||
|
||||
# ========== Replay Support ==========
|
||||
|
||||
async def replay_feed_history(self, delay_ms: int = 100):
|
||||
"""
|
||||
Replay events from feed_history
|
||||
|
||||
Useful for: frontend reconnection or restoring from saved state
|
||||
"""
|
||||
feed_history = self.storage.runtime_db.get_recent_feed_events(
|
||||
limit=self.storage.max_feed_history,
|
||||
) or self._state.get("feed_history", [])
|
||||
|
||||
# feed_history is newest-first, need to reverse for chronological replay # noqa: E501
|
||||
for event in reversed(feed_history):
|
||||
if self._broadcast_fn:
|
||||
await self._broadcast_fn(event)
|
||||
await asyncio.sleep(delay_ms / 1000)
|
||||
|
||||
logger.info(f"Replayed {len(feed_history)} events")
|
||||
|
||||
def get_initial_state_payload(
|
||||
self,
|
||||
include_dashboard: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build initial state payload for new client connections
|
||||
|
||||
Args:
|
||||
include_dashboard: Whether to load dashboard files
|
||||
|
||||
Returns:
|
||||
Dictionary suitable for sending to frontend
|
||||
"""
|
||||
feed_history = self.storage.runtime_db.get_recent_feed_events(
|
||||
limit=self.storage.max_feed_history,
|
||||
) or self._state.get("feed_history", [])
|
||||
last_day_history = self.storage.runtime_db.get_last_day_feed_events(
|
||||
current_date=self._state.get("current_date"),
|
||||
limit=self.storage.max_feed_history,
|
||||
) or self._state.get("last_day_history", [])
|
||||
|
||||
payload = {
|
||||
"server_mode": self._state.get("server_mode", "live"),
|
||||
"is_backtest": self._state.get("is_backtest", False),
|
||||
"tickers": self._state.get("tickers"),
|
||||
"runtime_config": self._state.get("runtime_config"),
|
||||
"feed_history": feed_history,
|
||||
"last_day_history": last_day_history,
|
||||
"current_date": self._state.get("current_date"),
|
||||
"trading_days_total": self._state.get("trading_days_total", 0),
|
||||
"trading_days_completed": self._state.get(
|
||||
"trading_days_completed",
|
||||
0,
|
||||
),
|
||||
"holdings": self._state.get("holdings", []),
|
||||
"trades": self._state.get("trades", []),
|
||||
"stats": self._state.get("stats", {}),
|
||||
"leaderboard": self._state.get("leaderboard", []),
|
||||
"portfolio": self._state.get("portfolio", {}),
|
||||
"realtime_prices": self._state.get("realtime_prices", {}),
|
||||
"data_sources": self._state.get("data_sources", {}),
|
||||
"price_history": self._state.get("price_history", {}),
|
||||
}
|
||||
|
||||
if include_dashboard:
|
||||
dashboard_snapshot = self.storage.build_dashboard_snapshot_from_state(self._state)
|
||||
payload["dashboard"] = {
|
||||
"summary": dashboard_snapshot.get("summary"),
|
||||
"holdings": dashboard_snapshot.get("holdings"),
|
||||
"stats": dashboard_snapshot.get("stats"),
|
||||
"trades": dashboard_snapshot.get("trades"),
|
||||
"leaderboard": dashboard_snapshot.get("leaderboard"),
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
def _calculate_progress(self) -> float:
|
||||
"""Calculate backtest progress percentage"""
|
||||
total = self._state.get("trading_days_total", 0)
|
||||
completed = self._state.get("trading_days_completed", 0)
|
||||
return completed / total if total > 0 else 0.0
|
||||
|
||||
def set_backtest_dates(self, dates: List[str]):
|
||||
"""Set total trading days for backtest progress tracking"""
|
||||
self._state["trading_days_total"] = len(dates)
|
||||
self._state["trading_days_completed"] = 0
|
||||
Reference in New Issue
Block a user