feat: Add evaluation hooks, skill adaptation and team pipeline config

- Add EvaluationHook for post-execution agent evaluation
- Add SkillAdaptationHook for dynamic skill adaptation
- Add team/ directory with team coordination logic
- Add TEAM_PIPELINE.yaml for smoke_fullstack pipeline config
- Update RuntimeView, TraderView and RuntimeSettingsPanel UI
- Add runtimeApi and websocket services
- Add runtime_state.json to smoke_fullstack state

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-19 18:52:12 +08:00
parent f4a2b7f3af
commit 4b5ac86b83
87 changed files with 5042 additions and 744 deletions

View File

@@ -10,6 +10,8 @@ import json
import logging
import os
import re
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional
from agentscope.message import Msg
@@ -21,6 +23,26 @@ from backend.core.state_sync import StateSync
from backend.utils.trade_executor import PortfolioTradeExecutor
from backend.runtime.manager import TradingRuntimeManager
from backend.runtime.session import TradingSessionKey
from backend.agents.team_pipeline_config import (
resolve_active_analysts,
update_active_analysts,
)
from backend.agents import AnalystAgent
from backend.agents.toolkit_factory import create_agent_toolkit
from backend.agents.workspace_manager import WorkspaceManager
from backend.agents.prompt_loader import PromptLoader
from backend.llm.models import get_agent_formatter, get_agent_model
from backend.config.constants import ANALYST_TYPES
# Team infrastructure imports (graceful import - may not exist yet)
try:
from backend.agents.team.team_coordinator import TeamCoordinator
from backend.agents.team.msg_hub import MsgHub as TeamMsgHub
TEAM_COORD_AVAILABLE = True
except ImportError:
TEAM_COORD_AVAILABLE = False
TeamCoordinator = None
TeamMsgHub = None
logger = logging.getLogger(__name__)
@@ -77,6 +99,13 @@ class TradingPipeline:
self.agent_factory = agent_factory
self.runtime_manager = runtime_manager
self._session_key: Optional[str] = None
self._dynamic_analysts: Dict[str, Any] = {}
if hasattr(self.pm, "set_team_controller"):
self.pm.set_team_controller(
create_agent_callback=self._create_runtime_analyst,
remove_agent_callback=self._remove_runtime_analyst,
)
async def run_cycle(
self,
@@ -115,16 +144,17 @@ class TradingPipeline:
_log(f"Starting cycle {date} - {len(tickers)} tickers")
session_key = TradingSessionKey(date=date).key()
self._session_key = session_key
active_analysts = self._get_active_analysts()
if self.runtime_manager:
self.runtime_manager.set_session_key(session_key)
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date})
self._runtime_batch_status(self.analysts, "analysis_in_progress")
self._runtime_batch_status(active_analysts, "analysis_in_progress")
# Phase 0: Clear short-term memory to avoid cross-day context pollution
_log("Phase 0: Clearing memory")
await self._clear_all_agent_memory()
participants = self.analysts + [self.risk_manager, self.pm]
participants = self._all_analysts() + [self.risk_manager, self.pm]
# Single MsgHub for entire cycle - no nesting
async with MsgHub(
@@ -135,9 +165,13 @@ class TradingPipeline:
"system",
),
):
# Phase 1.1: Analysts
_log("Phase 1.1: Analyst analysis")
analyst_results = await self._run_analysts_with_sync(tickers, date)
# Phase 1.1: Analysts (parallel execution with TeamCoordinator)
_log("Phase 1.1: Analyst analysis (parallel)")
analyst_results = await self._run_analysts_parallel(
tickers,
date,
active_analysts=active_analysts,
)
# Phase 1.2: Risk Manager
_log("Phase 1.2: Risk assessment")
@@ -164,6 +198,7 @@ class TradingPipeline:
final_predictions = await self._collect_final_predictions(
tickers,
date,
active_analysts=active_analysts,
)
# Record final predictions for leaderboard ranking
@@ -212,7 +247,7 @@ class TradingPipeline:
if close_prices and self.settlement_coordinator:
_log("Phase 5: Daily review and generate memories")
self._runtime_batch_status(
[self.risk_manager] + self.analysts + [self.pm],
[self.risk_manager] + self._all_analysts() + [self.pm],
"settlement",
)
@@ -246,13 +281,13 @@ class TradingPipeline:
conference_summary=self.conference_summary,
)
self._runtime_batch_status(
[self.risk_manager] + self.analysts + [self.pm],
[self.risk_manager] + self._all_analysts() + [self.pm],
"reflection",
)
_log(f"Cycle complete: {date}")
self._runtime_batch_status(
self.analysts + [self.risk_manager, self.pm],
self._all_analysts() + [self.risk_manager, self.pm],
"idle",
)
self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date})
@@ -288,7 +323,7 @@ class TradingPipeline:
},
)
for analyst in self.analysts:
for analyst in self._all_analysts():
analyst.reload_runtime_assets(
active_skill_dirs=active_skill_map.get(analyst.name, []),
)
@@ -302,7 +337,7 @@ class TradingPipeline:
return {
"config_name": config_name,
"reloaded_agents": [agent.name for agent in self.analysts]
"reloaded_agents": [agent.name for agent in self._all_analysts()]
+ ["risk_manager", "portfolio_manager"],
"active_skills": {
agent_id: [path.name for path in paths]
@@ -313,7 +348,7 @@ class TradingPipeline:
async def _clear_all_agent_memory(self):
"""Clear short-term memory for all agents"""
for analyst in self.analysts:
for analyst in self._all_analysts():
await analyst.memory.clear()
await self.risk_manager.memory.clear()
@@ -395,7 +430,7 @@ class TradingPipeline:
trajectories = {}
# Capture analyst trajectories
for analyst in self.analysts:
for analyst in self._all_analysts():
try:
msgs = await analyst.memory.get_memory()
if msgs:
@@ -605,7 +640,7 @@ class TradingPipeline:
)
# Record for analysts
for analyst in self.analysts:
for analyst in self._all_analysts():
if (
hasattr(analyst, "long_term_memory")
and analyst.long_term_memory is not None
@@ -724,67 +759,82 @@ class TradingPipeline:
date=date,
)
# Run discussion cycles (no new MsgHub - use parent's)
for cycle in range(self.max_comm_cycles):
# Conference participants: analysts + PM
conference_participants = self._get_active_analysts() + [self.pm]
# Use TeamMsgHub for conference if available
if TEAM_COORD_AVAILABLE and TeamMsgHub is not None:
_log(
"Phase 2.1: Conference discussion - "
f"Conference {cycle + 1}/{self.max_comm_cycles}",
f"Phase 2.1: Conference using TeamMsgHub with "
f"{len(conference_participants)} participants"
)
conference_hub = TeamMsgHub(participants=conference_participants)
else:
_log("Phase 2.1: Conference using standard MsgHub context")
conference_hub = None
if self.state_sync:
await self.state_sync.on_conference_cycle_start(
cycle=cycle + 1,
total_cycles=self.max_comm_cycles,
# Run discussion cycles
async with conference_hub if conference_hub else nullcontext(None):
for cycle in range(self.max_comm_cycles):
_log(
"Phase 2.1: Conference discussion - "
f"Conference {cycle + 1}/{self.max_comm_cycles}",
)
# PM sets agenda or asks questions
pm_prompt = self._build_pm_discussion_prompt(
cycle=cycle,
tickers=tickers,
date=date,
prices=prices,
analyst_results=analyst_results,
risk_assessment=risk_assessment,
)
if self.state_sync:
await self.state_sync.on_conference_cycle_start(
cycle=cycle + 1,
total_cycles=self.max_comm_cycles,
)
pm_msg = Msg(name="system", content=pm_prompt, role="user")
pm_response = await self.pm.reply(pm_msg)
if self.state_sync:
pm_content = self._extract_text_content(pm_response.content)
await self.state_sync.on_conference_message(
agent_id="portfolio_manager",
content=pm_content,
)
# Analysts share perspectives
for analyst in self.analysts:
analyst_prompt = self._build_analyst_discussion_prompt(
# PM sets agenda or asks questions
pm_prompt = self._build_pm_discussion_prompt(
cycle=cycle,
tickers=tickers,
date=date,
prices=prices,
analyst_results=analyst_results,
risk_assessment=risk_assessment,
)
analyst_msg = Msg(
name="system",
content=analyst_prompt,
role="user",
)
analyst_response = await analyst.reply(analyst_msg)
pm_msg = Msg(name="system", content=pm_prompt, role="user")
pm_response = await self.pm.reply(pm_msg)
if self.state_sync:
analyst_content = self._extract_text_content(
analyst_response.content,
)
pm_content = self._extract_text_content(pm_response.content)
await self.state_sync.on_conference_message(
agent_id=analyst.name,
content=analyst_content,
agent_id="portfolio_manager",
content=pm_content,
)
if self.state_sync:
await self.state_sync.on_conference_cycle_end(
cycle=cycle + 1,
)
# Analysts share perspectives (supports per-round active team updates)
for analyst in self._get_active_analysts():
analyst_prompt = self._build_analyst_discussion_prompt(
cycle=cycle,
tickers=tickers,
date=date,
)
analyst_msg = Msg(
name="system",
content=analyst_prompt,
role="user",
)
analyst_response = await analyst.reply(analyst_msg)
if self.state_sync:
analyst_content = self._extract_text_content(
analyst_response.content,
)
await self.state_sync.on_conference_message(
agent_id=analyst.name,
content=analyst_content,
)
if self.state_sync:
await self.state_sync.on_conference_cycle_end(
cycle=cycle + 1,
)
# Generate conference summary by PM
_log(
@@ -885,6 +935,7 @@ class TradingPipeline:
self,
tickers: List[str],
date: str,
active_analysts: Optional[List[Any]] = None,
) -> List[Dict[str, Any]]:
"""
Collect final predictions from all analysts as simple text responses.
@@ -892,14 +943,15 @@ class TradingPipeline:
"""
_log(
"Phase 2.2: Analysts generate final structured predictions\n"
f" Starting _collect_final_predictions for {len(self.analysts)} analysts",
f" Starting _collect_final_predictions for {len(active_analysts or self.analysts)} analysts",
)
final_predictions = []
for i, analyst in enumerate(self.analysts):
analysts = active_analysts or self.analysts
for i, analyst in enumerate(analysts):
_log(
"Phase 2.2: Analysts generate final structured predictions\n"
f" Collecting prediction from analyst {i+1}/{len(self.analysts)}: {analyst.name}",
f" Collecting prediction from analyst {i+1}/{len(analysts)}: {analyst.name}",
)
prompt = (
@@ -995,11 +1047,13 @@ class TradingPipeline:
self,
tickers: List[str],
date: str,
active_analysts: Optional[List[Any]] = None,
) -> List[Dict[str, Any]]:
"""Run all analysts with real-time sync after each completion"""
results = []
analysts = active_analysts or self.analysts
for analyst in self.analysts:
for analyst in analysts:
content = (
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
f"Provide investment signals with confidence scores and reasoning."
@@ -1029,15 +1083,107 @@ class TradingPipeline:
return results
async def _run_analysts_parallel(
self,
tickers: List[str],
date: str,
active_analysts: Optional[List[Any]] = None,
) -> List[Dict[str, Any]]:
"""Run all analysts in parallel using TeamCoordinator.
This method replaces the sequential analyst loop with parallel execution
using the TeamCoordinator for orchestration.
Args:
tickers: List of stock tickers to analyze
date: Trading date
active_analysts: Optional list of analysts to run
Returns:
List of analyst result dictionaries
"""
analysts = active_analysts or self.analysts
if not analysts:
return []
if not TEAM_COORD_AVAILABLE:
_log("TeamCoordinator not available, falling back to sequential execution")
return await self._run_analysts_with_sync(
tickers=tickers,
date=date,
active_analysts=active_analysts,
)
_log(
f"Phase 1.1: Running {len(analysts)} analysts in parallel "
f"[{', '.join(a.name for a in analysts)}]"
)
# Build the analyst prompt
content = (
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
f"Provide investment signals with confidence scores and reasoning."
)
# Create coordinator for parallel execution
coordinator = TeamCoordinator(
participants=analysts,
task_content=content,
)
# Run analysts in parallel via TeamCoordinator
results = await coordinator.run_phase(
"analyst_analysis",
metadata={"tickers": tickers, "date": date},
)
# Process results and sync
processed_results = []
for i, (analyst, result) in enumerate(zip(analysts, results)):
if result is not None:
extracted = self._extract_result_from_msg(result)
processed_results.append(extracted)
# Sync retrieved memory
await self._sync_memory_if_retrieved(analyst)
# Broadcast agent result via StateSync
if self.state_sync:
text_content = self._extract_text_content(result.content)
await self.state_sync.on_agent_complete(
agent_id=analyst.name,
content=text_content,
)
else:
logger.warning(
"Analyst %s returned no result",
analyst.name,
)
processed_results.append({
"agent": analyst.name,
"content": "",
"success": False,
})
_log(
f"Phase 1.1: Parallel analyst execution complete "
f"({len(processed_results)}/{len(analysts)} successful)"
)
return processed_results
async def _run_analysts(
self,
tickers: List[str],
date: str,
active_analysts: Optional[List[Any]] = None,
) -> List[Dict[str, Any]]:
"""Run all analysts (without sync, for backward compatibility)"""
results = []
analysts = active_analysts or self.analysts
for analyst in self.analysts:
for analyst in analysts:
content = (
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
f"Provide investment signals with confidence scores and reasoning."
@@ -1461,6 +1607,83 @@ class TradingPipeline:
for agent in agents:
self._runtime_update_status(agent, status)
def _all_analysts(self) -> List[Any]:
"""Return static analysts plus runtime-created analysts."""
return list(self.analysts) + list(self._dynamic_analysts.values())
def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str:
"""Create one runtime analyst instance."""
if analyst_type not in ANALYST_TYPES:
return (
f"Unknown analyst_type '{analyst_type}'. "
f"Available: {', '.join(ANALYST_TYPES.keys())}"
)
if agent_id in {agent.name for agent in self._all_analysts()}:
return f"Analyst '{agent_id}' already exists."
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
personas = PromptLoader().load_yaml_config("analyst", "personas")
persona = personas.get(analyst_type, {})
WorkspaceManager(project_root=project_root).ensure_agent_assets(
config_name=config_name,
agent_id=agent_id,
role_seed=persona.get("description", "").strip(),
style_seed="\n".join(f"- {item}" for item in persona.get("focus", [])),
policy_seed=(
"State a clear signal, confidence, and the conditions "
"that would invalidate the thesis."
),
)
agent = AnalystAgent(
analyst_type=analyst_type,
toolkit=create_agent_toolkit(
agent_id=agent_id,
config_name=config_name,
active_skill_dirs=[],
),
model=get_agent_model(analyst_type),
formatter=get_agent_formatter(analyst_type),
agent_id=agent_id,
config={"config_name": config_name},
)
self._dynamic_analysts[agent_id] = agent
update_active_analysts(
project_root=project_root,
config_name=config_name,
available_analysts=[item.name for item in self._all_analysts()],
add=[agent_id],
)
return f"Created runtime analyst '{agent_id}' ({analyst_type})."
def _remove_runtime_analyst(self, agent_id: str) -> str:
"""Remove one runtime-created analyst instance."""
if agent_id not in self._dynamic_analysts:
return f"Runtime analyst '{agent_id}' not found."
self._dynamic_analysts.pop(agent_id, None)
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
update_active_analysts(
project_root=project_root,
config_name=config_name,
available_analysts=[item.name for item in self._all_analysts()],
remove=[agent_id],
)
return f"Removed runtime analyst '{agent_id}'."
def _get_active_analysts(self) -> List[Any]:
"""Resolve active analyst participants from run-scoped team pipeline config."""
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
analyst_map = {agent.name: agent for agent in self._all_analysts()}
active_ids = resolve_active_analysts(
project_root=project_root,
config_name=config_name,
available_analysts=list(analyst_map.keys()),
)
return [analyst_map[agent_id] for agent_id in active_ids if agent_id in analyst_map]
def _runtime_log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> None:
if not self.runtime_manager:
return

View File

@@ -61,7 +61,7 @@ def stop_gateway() -> None:
_gateway_instance = None
async def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
"""Create ReMeTaskLongTermMemory for an agent."""
try:
from agentscope.memory import ReMeTaskLongTermMemory
@@ -206,6 +206,13 @@ async def run_pipeline(
"""
Run the trading pipeline with the given configuration.
Service Startup Order:
Phase 1: WebSocket Server - Frontend can connect
Phase 2: Market Service - Price data starts flowing
Phase 3: Agent Runtime - Create all agents
Phase 4: Pipeline & Scheduler - Trading logic ready
Phase 5: Gateway Fully Operational - All systems running
Args:
run_id: Unique run identifier (timestamp)
run_dir: Run directory path
@@ -219,7 +226,9 @@ async def run_pipeline(
# Set global shutdown event
set_shutdown_event(stop_event)
logger.info(f"[Pipeline {run_id}] Starting...")
logger.info(f"[Pipeline {run_id}] ======================================")
logger.info(f"[Pipeline {run_id}] Starting with 5-phase initialization...")
logger.info(f"[Pipeline {run_id}] ======================================")
try:
# Extract config values
@@ -230,15 +239,21 @@ async def run_pipeline(
schedule_mode = bootstrap.get("schedule_mode", "daily")
trigger_time = bootstrap.get("trigger_time", "09:30")
interval_minutes = int(bootstrap.get("interval_minutes", 60))
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0))
mode = bootstrap.get("mode", "live")
start_date = bootstrap.get("start_date")
end_date = bootstrap.get("end_date")
enable_memory = bootstrap.get("enable_memory", False)
enable_mock = bootstrap.get("enable_mock", False)
is_backtest = mode == "backtest"
is_mock = mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
# ======================================================================
# PHASE 0: Initialize runtime manager
# ======================================================================
logger.info("[Phase 0/5] Initializing runtime manager...")
# Get or create runtime manager
from backend.api.runtime import runtime_manager
if runtime_manager is None:
@@ -255,16 +270,11 @@ async def run_pipeline(
from backend.api.runtime import register_runtime_manager
register_runtime_manager(runtime_manager)
# Create market service
market_service = MarketService(
tickers=tickers,
poll_interval=10,
mock_mode=is_mock and not is_backtest,
backtest_mode=is_backtest,
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
backtest_start_date=start_date if is_backtest else None,
backtest_end_date=end_date if is_backtest else None,
)
# ======================================================================
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
# These will be started by Gateway in the correct order
# ======================================================================
logger.info("[Phase 1-2/5] Creating infrastructure services...")
# Create storage service
storage_service = StorageService(
@@ -278,7 +288,22 @@ async def run_pipeline(
else:
storage_service.update_leaderboard_model_info()
# Create agents and pipeline
# Create market service (data source)
market_service = MarketService(
tickers=tickers,
poll_interval=10,
mock_mode=is_mock and not is_backtest,
backtest_mode=is_backtest,
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
backtest_start_date=start_date if is_backtest else None,
backtest_end_date=end_date if is_backtest else None,
)
# ======================================================================
# PHASE 3: Create Agent Runtime
# ======================================================================
logger.info("[Phase 3/5] Creating agent runtime...")
analysts, risk_manager, pm, long_term_memories = create_agents(
run_id=run_id,
run_dir=run_dir,
@@ -303,6 +328,11 @@ async def run_pipeline(
initial_capital=initial_cash,
)
# ======================================================================
# PHASE 4: Create Pipeline & Scheduler
# ======================================================================
logger.info("[Phase 4/5] Creating pipeline and scheduler...")
# Create pipeline
pipeline = TradingPipeline(
analysts=analysts,
@@ -336,6 +366,7 @@ async def run_pipeline(
mode=schedule_mode,
trigger_time=trigger_time,
interval_minutes=interval_minutes,
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
config={"config_name": run_id},
)
@@ -344,7 +375,15 @@ async def run_pipeline(
scheduler_callback = scheduler_callback_fn
# Create Gateway for WebSocket connections (after pipeline and scheduler are ready)
# ======================================================================
# PHASE 5: Start Gateway (WebSocket → Market → Scheduler)
# Gateway.start() will handle the final startup sequence:
# - WebSocket Server first (frontend can connect)
# - Market Service second (price data flows)
# - Scheduler last (trading begins)
# ======================================================================
logger.info("[Phase 5/5] Starting Gateway (WebSocket → Market → Scheduler)...")
gateway = Gateway(
market_service=market_service,
storage_service=storage_service,
@@ -359,6 +398,7 @@ async def run_pipeline(
"schedule_mode": schedule_mode,
"interval_minutes": interval_minutes,
"trigger_time": trigger_time,
"heartbeat_interval": heartbeat_interval,
"initial_cash": initial_cash,
"margin_requirement": margin_requirement,
"max_comm_cycles": max_comm_cycles,
@@ -374,13 +414,17 @@ async def run_pipeline(
for memory in long_term_memories:
await stack.enter_async_context(memory)
# Start Gateway in background task
# Start Gateway - this will execute the 4-phase startup:
# Phase 1: WebSocket Server (frontend can connect immediately)
# Phase 2: Market Service (price updates start flowing)
# Phase 3: Market Status Monitor
# Phase 4: Scheduler (trading cycles begin)
gateway_task = asyncio.create_task(
gateway.start(host="0.0.0.0", port=8765)
)
logger.info("[Pipeline] Gateway started on ws://localhost:8765")
logger.info("[Pipeline] Gateway startup initiated on ws://localhost:8765")
# Give Gateway a moment to start
# Wait for Gateway to fully initialize all phases
await asyncio.sleep(0.5)
# Define the trading cycle callback

View File

@@ -4,7 +4,7 @@ Scheduler - Market-aware trigger system for trading cycles
"""
import asyncio
import logging
from datetime import datetime, timedelta
from datetime import datetime, time, timedelta
from typing import Any, Callable, Optional
from zoneinfo import ZoneInfo
@@ -28,17 +28,21 @@ class Scheduler:
mode: str = "daily",
trigger_time: Optional[str] = None,
interval_minutes: Optional[int] = None,
heartbeat_interval: Optional[int] = None,
config: Optional[dict] = None,
):
self.mode = mode
self.trigger_time = trigger_time or "09:30" # NYSE timezone
self.trigger_now = self.trigger_time == "now"
self.interval_minutes = interval_minutes or 60
self.heartbeat_interval = heartbeat_interval # e.g. 3600 = 1 hour
self.config = config or {}
self.running = False
self._task: Optional[asyncio.Task] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._callback: Optional[Callable] = None
self._heartbeat_callback: Optional[Callable] = None
def _now_nyse(self) -> datetime:
"""Get current time in NYSE timezone"""
@@ -53,6 +57,15 @@ class Scheduler:
)
return len(valid_days) > 0
def _is_trading_hours(self, now: datetime) -> bool:
"""Check if current time is within NYSE trading hours (9:30-16:00 ET)."""
market_time = now.time()
return time(9, 30) <= market_time <= time(16, 0)
def set_heartbeat_callback(self, callback: Callable) -> None:
"""Register callback for heartbeat triggers."""
self._heartbeat_callback = callback
def _next_trading_day(self, from_date: datetime) -> datetime:
"""Find the next trading day from given date"""
check_date = from_date
@@ -72,6 +85,13 @@ class Scheduler:
self._callback = callback
self._schedule_task()
# Start heartbeat loop if configured
if self.heartbeat_interval and self._heartbeat_callback:
self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop())
logger.info(
f"Heartbeat loop started: interval={self.heartbeat_interval}s",
)
logger.info(
f"Scheduler started: mode={self.mode}, timezone=America/New_York",
)
@@ -132,6 +152,30 @@ class Scheduler:
return changed
async def _run_heartbeat_loop(self):
"""Run heartbeat checks on a separate interval during trading hours."""
while self.running:
now = self._now_nyse()
if self._is_trading_day(now) and self._is_trading_hours(now):
if self._heartbeat_callback:
try:
current_date = now.strftime("%Y-%m-%d")
logger.debug(
f"[Heartbeat] Triggering heartbeat check for {current_date}",
)
await self._heartbeat_callback(date=current_date)
except Exception as e:
logger.error(
f"[Heartbeat] Callback failed: {e}",
exc_info=True,
)
else:
logger.warning(
"[Heartbeat] Callback not set, skipping heartbeat",
)
await asyncio.sleep(self.heartbeat_interval)
async def _run_daily(self, callback: Callable):
"""Run once per trading day at specified time (NYSE timezone)"""
first_run = True
@@ -206,6 +250,9 @@ class Scheduler:
if self._task:
self._task.cancel()
self._task = None
if self._heartbeat_task:
self._heartbeat_task.cancel()
self._heartbeat_task = None
logger.info("Scheduler stopped")