feat: Add evaluation hooks, skill adaptation and team pipeline config
- Add EvaluationHook for post-execution agent evaluation
- Add SkillAdaptationHook for dynamic skill adaptation
- Add team/ directory with team coordination logic
- Add TEAM_PIPELINE.yaml for smoke_fullstack pipeline config
- Update RuntimeView, TraderView and RuntimeSettingsPanel UI
- Add runtimeApi and websocket services
- Add runtime_state.json to smoke_fullstack state

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
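For context on the team pipeline config referenced throughout the diff: the pipeline resolves and persists the active analyst roster through backend.agents.team_pipeline_config, whose module body is added elsewhere in this commit and not shown below. A minimal sketch of how those two helpers might behave, assuming TEAM_PIPELINE.yaml is a run-scoped file holding an active_analysts list (the file location and field name are assumptions, not taken from this commit; only the keyword signatures match the call sites in the diff):

# Hypothetical sketch only - the real module ships elsewhere in this commit.
from pathlib import Path
from typing import List, Optional

import yaml  # assumes PyYAML is available


def resolve_active_analysts(
    project_root: Path,
    config_name: str,
    available_analysts: List[str],
) -> List[str]:
    """Return analyst ids enabled for this run, falling back to every analyst."""
    config_path = project_root / "runs" / config_name / "TEAM_PIPELINE.yaml"
    if not config_path.exists():
        return list(available_analysts)
    data = yaml.safe_load(config_path.read_text()) or {}
    active = data.get("active_analysts") or available_analysts
    # Only keep ids that actually exist in this run.
    return [name for name in active if name in available_analysts]


def update_active_analysts(
    project_root: Path,
    config_name: str,
    available_analysts: List[str],
    add: Optional[List[str]] = None,
    remove: Optional[List[str]] = None,
) -> List[str]:
    """Persist roster changes back to TEAM_PIPELINE.yaml and return the result."""
    active = resolve_active_analysts(project_root, config_name, available_analysts)
    active = [name for name in active if name not in set(remove or [])]
    active += [name for name in (add or []) if name not in active]
    config_path = project_root / "runs" / config_name / "TEAM_PIPELINE.yaml"
    config_path.parent.mkdir(parents=True, exist_ok=True)
    config_path.write_text(yaml.safe_dump({"active_analysts": active}))
    return active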
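The parallel analyst phase below fans one task out to every active analyst through the new TeamCoordinator, which this commit adds under backend/agents/team/ but whose body is not part of this diff. A rough sketch of the orchestration it is expected to perform, assuming run_phase simply gathers each participant's reply concurrently (class internals here are guesses; only the constructor and run_phase call shapes come from the diff):

# Hypothetical sketch - not the actual backend.agents.team.team_coordinator module.
import asyncio
from typing import Any, Dict, List, Optional

from agentscope.message import Msg


class TeamCoordinator:
    def __init__(self, participants: List[Any], task_content: str):
        self.participants = participants
        self.task_content = task_content

    async def run_phase(
        self,
        phase_name: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Any]:
        """Send the same task to every participant and gather replies in order."""
        # phase_name and metadata are accepted for parity with the call site.
        task = Msg(name="system", content=self.task_content, role="user")
        replies = await asyncio.gather(
            *(agent.reply(task) for agent in self.participants),
            # Keep one failing analyst from aborting the whole phase.
            return_exceptions=True,
        )
        # Results stay aligned with participants; failures surface as None,
        # which matches how the caller in the diff handles missing results.
        return [r if not isinstance(r, Exception) else None for r in replies]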
@@ -10,6 +10,8 @@ import json
import logging
import os
import re
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional

from agentscope.message import Msg
@@ -21,6 +23,26 @@ from backend.core.state_sync import StateSync
from backend.utils.trade_executor import PortfolioTradeExecutor
from backend.runtime.manager import TradingRuntimeManager
from backend.runtime.session import TradingSessionKey
from backend.agents.team_pipeline_config import (
resolve_active_analysts,
update_active_analysts,
)
from backend.agents import AnalystAgent
from backend.agents.toolkit_factory import create_agent_toolkit
from backend.agents.workspace_manager import WorkspaceManager
from backend.agents.prompt_loader import PromptLoader
from backend.llm.models import get_agent_formatter, get_agent_model
from backend.config.constants import ANALYST_TYPES

# Team infrastructure imports (graceful import - may not exist yet)
try:
from backend.agents.team.team_coordinator import TeamCoordinator
from backend.agents.team.msg_hub import MsgHub as TeamMsgHub
TEAM_COORD_AVAILABLE = True
except ImportError:
TEAM_COORD_AVAILABLE = False
TeamCoordinator = None
TeamMsgHub = None


logger = logging.getLogger(__name__)
@@ -77,6 +99,13 @@ class TradingPipeline:
self.agent_factory = agent_factory
self.runtime_manager = runtime_manager
self._session_key: Optional[str] = None
self._dynamic_analysts: Dict[str, Any] = {}

if hasattr(self.pm, "set_team_controller"):
self.pm.set_team_controller(
create_agent_callback=self._create_runtime_analyst,
remove_agent_callback=self._remove_runtime_analyst,
)

async def run_cycle(
self,
@@ -115,16 +144,17 @@ class TradingPipeline:
_log(f"Starting cycle {date} - {len(tickers)} tickers")
session_key = TradingSessionKey(date=date).key()
self._session_key = session_key
active_analysts = self._get_active_analysts()
if self.runtime_manager:
self.runtime_manager.set_session_key(session_key)
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date})
self._runtime_batch_status(self.analysts, "analysis_in_progress")
self._runtime_batch_status(active_analysts, "analysis_in_progress")

# Phase 0: Clear short-term memory to avoid cross-day context pollution
_log("Phase 0: Clearing memory")
await self._clear_all_agent_memory()

participants = self.analysts + [self.risk_manager, self.pm]
participants = self._all_analysts() + [self.risk_manager, self.pm]

# Single MsgHub for entire cycle - no nesting
async with MsgHub(
@@ -135,9 +165,13 @@ class TradingPipeline:
"system",
),
):
# Phase 1.1: Analysts
_log("Phase 1.1: Analyst analysis")
analyst_results = await self._run_analysts_with_sync(tickers, date)
# Phase 1.1: Analysts (parallel execution with TeamCoordinator)
_log("Phase 1.1: Analyst analysis (parallel)")
analyst_results = await self._run_analysts_parallel(
tickers,
date,
active_analysts=active_analysts,
)

# Phase 1.2: Risk Manager
_log("Phase 1.2: Risk assessment")
@@ -164,6 +198,7 @@ class TradingPipeline:
final_predictions = await self._collect_final_predictions(
tickers,
date,
active_analysts=active_analysts,
)

# Record final predictions for leaderboard ranking
@@ -212,7 +247,7 @@ class TradingPipeline:
if close_prices and self.settlement_coordinator:
_log("Phase 5: Daily review and generate memories")
self._runtime_batch_status(
[self.risk_manager] + self.analysts + [self.pm],
[self.risk_manager] + self._all_analysts() + [self.pm],
"settlement",
)

@@ -246,13 +281,13 @@ class TradingPipeline:
conference_summary=self.conference_summary,
)
self._runtime_batch_status(
[self.risk_manager] + self.analysts + [self.pm],
[self.risk_manager] + self._all_analysts() + [self.pm],
"reflection",
)

_log(f"Cycle complete: {date}")
self._runtime_batch_status(
self.analysts + [self.risk_manager, self.pm],
self._all_analysts() + [self.risk_manager, self.pm],
"idle",
)
self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date})
@@ -288,7 +323,7 @@ class TradingPipeline:
},
)

for analyst in self.analysts:
for analyst in self._all_analysts():
analyst.reload_runtime_assets(
active_skill_dirs=active_skill_map.get(analyst.name, []),
)
@@ -302,7 +337,7 @@ class TradingPipeline:

return {
"config_name": config_name,
"reloaded_agents": [agent.name for agent in self.analysts]
"reloaded_agents": [agent.name for agent in self._all_analysts()]
+ ["risk_manager", "portfolio_manager"],
"active_skills": {
agent_id: [path.name for path in paths]
@@ -313,7 +348,7 @@ class TradingPipeline:

async def _clear_all_agent_memory(self):
"""Clear short-term memory for all agents"""
for analyst in self.analysts:
for analyst in self._all_analysts():
await analyst.memory.clear()

await self.risk_manager.memory.clear()
@@ -395,7 +430,7 @@ class TradingPipeline:
trajectories = {}

# Capture analyst trajectories
for analyst in self.analysts:
for analyst in self._all_analysts():
try:
msgs = await analyst.memory.get_memory()
if msgs:
@@ -605,7 +640,7 @@ class TradingPipeline:
)

# Record for analysts
for analyst in self.analysts:
for analyst in self._all_analysts():
if (
hasattr(analyst, "long_term_memory")
and analyst.long_term_memory is not None
@@ -724,67 +759,82 @@ class TradingPipeline:
date=date,
)

# Run discussion cycles (no new MsgHub - use parent's)
for cycle in range(self.max_comm_cycles):
# Conference participants: analysts + PM
conference_participants = self._get_active_analysts() + [self.pm]

# Use TeamMsgHub for conference if available
if TEAM_COORD_AVAILABLE and TeamMsgHub is not None:
_log(
"Phase 2.1: Conference discussion - "
f"Conference {cycle + 1}/{self.max_comm_cycles}",
f"Phase 2.1: Conference using TeamMsgHub with "
f"{len(conference_participants)} participants"
)
conference_hub = TeamMsgHub(participants=conference_participants)
else:
_log("Phase 2.1: Conference using standard MsgHub context")
conference_hub = None

if self.state_sync:
await self.state_sync.on_conference_cycle_start(
cycle=cycle + 1,
total_cycles=self.max_comm_cycles,
# Run discussion cycles
async with conference_hub if conference_hub else nullcontext(None):
for cycle in range(self.max_comm_cycles):
_log(
"Phase 2.1: Conference discussion - "
f"Conference {cycle + 1}/{self.max_comm_cycles}",
)

# PM sets agenda or asks questions
pm_prompt = self._build_pm_discussion_prompt(
cycle=cycle,
tickers=tickers,
date=date,
prices=prices,
analyst_results=analyst_results,
risk_assessment=risk_assessment,
)
if self.state_sync:
await self.state_sync.on_conference_cycle_start(
cycle=cycle + 1,
total_cycles=self.max_comm_cycles,
)

pm_msg = Msg(name="system", content=pm_prompt, role="user")
pm_response = await self.pm.reply(pm_msg)

if self.state_sync:
pm_content = self._extract_text_content(pm_response.content)
await self.state_sync.on_conference_message(
agent_id="portfolio_manager",
content=pm_content,
)

# Analysts share perspectives
for analyst in self.analysts:
analyst_prompt = self._build_analyst_discussion_prompt(
# PM sets agenda or asks questions
pm_prompt = self._build_pm_discussion_prompt(
cycle=cycle,
tickers=tickers,
date=date,
prices=prices,
analyst_results=analyst_results,
risk_assessment=risk_assessment,
)

analyst_msg = Msg(
name="system",
content=analyst_prompt,
role="user",
)
analyst_response = await analyst.reply(analyst_msg)
pm_msg = Msg(name="system", content=pm_prompt, role="user")
pm_response = await self.pm.reply(pm_msg)

if self.state_sync:
analyst_content = self._extract_text_content(
analyst_response.content,
)
pm_content = self._extract_text_content(pm_response.content)
await self.state_sync.on_conference_message(
agent_id=analyst.name,
content=analyst_content,
agent_id="portfolio_manager",
content=pm_content,
)

if self.state_sync:
await self.state_sync.on_conference_cycle_end(
cycle=cycle + 1,
)
# Analysts share perspectives (supports per-round active team updates)
for analyst in self._get_active_analysts():
analyst_prompt = self._build_analyst_discussion_prompt(
cycle=cycle,
tickers=tickers,
date=date,
)

analyst_msg = Msg(
name="system",
content=analyst_prompt,
role="user",
)
analyst_response = await analyst.reply(analyst_msg)

if self.state_sync:
analyst_content = self._extract_text_content(
analyst_response.content,
)
await self.state_sync.on_conference_message(
agent_id=analyst.name,
content=analyst_content,
)

if self.state_sync:
await self.state_sync.on_conference_cycle_end(
cycle=cycle + 1,
)

# Generate conference summary by PM
_log(
@@ -885,6 +935,7 @@ class TradingPipeline:
self,
tickers: List[str],
date: str,
active_analysts: Optional[List[Any]] = None,
) -> List[Dict[str, Any]]:
"""
Collect final predictions from all analysts as simple text responses.
@@ -892,14 +943,15 @@ class TradingPipeline:
"""
_log(
"Phase 2.2: Analysts generate final structured predictions\n"
f" Starting _collect_final_predictions for {len(self.analysts)} analysts",
f" Starting _collect_final_predictions for {len(active_analysts or self.analysts)} analysts",
)
final_predictions = []

for i, analyst in enumerate(self.analysts):
analysts = active_analysts or self.analysts
for i, analyst in enumerate(analysts):
_log(
"Phase 2.2: Analysts generate final structured predictions\n"
f" Collecting prediction from analyst {i+1}/{len(self.analysts)}: {analyst.name}",
f" Collecting prediction from analyst {i+1}/{len(analysts)}: {analyst.name}",
)

prompt = (
@@ -995,11 +1047,13 @@ class TradingPipeline:
self,
tickers: List[str],
date: str,
active_analysts: Optional[List[Any]] = None,
) -> List[Dict[str, Any]]:
"""Run all analysts with real-time sync after each completion"""
results = []
analysts = active_analysts or self.analysts

for analyst in self.analysts:
for analyst in analysts:
content = (
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
f"Provide investment signals with confidence scores and reasoning."
@@ -1029,15 +1083,107 @@ class TradingPipeline:

return results

async def _run_analysts_parallel(
self,
tickers: List[str],
date: str,
active_analysts: Optional[List[Any]] = None,
) -> List[Dict[str, Any]]:
"""Run all analysts in parallel using TeamCoordinator.

This method replaces the sequential analyst loop with parallel execution
using the TeamCoordinator for orchestration.

Args:
tickers: List of stock tickers to analyze
date: Trading date
active_analysts: Optional list of analysts to run

Returns:
List of analyst result dictionaries
"""
analysts = active_analysts or self.analysts

if not analysts:
return []

if not TEAM_COORD_AVAILABLE:
_log("TeamCoordinator not available, falling back to sequential execution")
return await self._run_analysts_with_sync(
tickers=tickers,
date=date,
active_analysts=active_analysts,
)

_log(
f"Phase 1.1: Running {len(analysts)} analysts in parallel "
f"[{', '.join(a.name for a in analysts)}]"
)

# Build the analyst prompt
content = (
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
f"Provide investment signals with confidence scores and reasoning."
)

# Create coordinator for parallel execution
coordinator = TeamCoordinator(
participants=analysts,
task_content=content,
)

# Run analysts in parallel via TeamCoordinator
results = await coordinator.run_phase(
"analyst_analysis",
metadata={"tickers": tickers, "date": date},
)

# Process results and sync
processed_results = []
for i, (analyst, result) in enumerate(zip(analysts, results)):
if result is not None:
extracted = self._extract_result_from_msg(result)
processed_results.append(extracted)

# Sync retrieved memory
await self._sync_memory_if_retrieved(analyst)

# Broadcast agent result via StateSync
if self.state_sync:
text_content = self._extract_text_content(result.content)
await self.state_sync.on_agent_complete(
agent_id=analyst.name,
content=text_content,
)
else:
logger.warning(
"Analyst %s returned no result",
analyst.name,
)
processed_results.append({
"agent": analyst.name,
"content": "",
"success": False,
})

_log(
f"Phase 1.1: Parallel analyst execution complete "
f"({len(processed_results)}/{len(analysts)} successful)"
)

return processed_results

async def _run_analysts(
self,
tickers: List[str],
date: str,
active_analysts: Optional[List[Any]] = None,
) -> List[Dict[str, Any]]:
"""Run all analysts (without sync, for backward compatibility)"""
results = []
analysts = active_analysts or self.analysts

for analyst in self.analysts:
for analyst in analysts:
content = (
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
f"Provide investment signals with confidence scores and reasoning."
@@ -1461,6 +1607,83 @@ class TradingPipeline:
for agent in agents:
self._runtime_update_status(agent, status)

def _all_analysts(self) -> List[Any]:
"""Return static analysts plus runtime-created analysts."""
return list(self.analysts) + list(self._dynamic_analysts.values())

def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str:
"""Create one runtime analyst instance."""
if analyst_type not in ANALYST_TYPES:
return (
f"Unknown analyst_type '{analyst_type}'. "
f"Available: {', '.join(ANALYST_TYPES.keys())}"
)
if agent_id in {agent.name for agent in self._all_analysts()}:
return f"Analyst '{agent_id}' already exists."

config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
personas = PromptLoader().load_yaml_config("analyst", "personas")
persona = personas.get(analyst_type, {})
WorkspaceManager(project_root=project_root).ensure_agent_assets(
config_name=config_name,
agent_id=agent_id,
role_seed=persona.get("description", "").strip(),
style_seed="\n".join(f"- {item}" for item in persona.get("focus", [])),
policy_seed=(
"State a clear signal, confidence, and the conditions "
"that would invalidate the thesis."
),
)

agent = AnalystAgent(
analyst_type=analyst_type,
toolkit=create_agent_toolkit(
agent_id=agent_id,
config_name=config_name,
active_skill_dirs=[],
),
model=get_agent_model(analyst_type),
formatter=get_agent_formatter(analyst_type),
agent_id=agent_id,
config={"config_name": config_name},
)
self._dynamic_analysts[agent_id] = agent
update_active_analysts(
project_root=project_root,
config_name=config_name,
available_analysts=[item.name for item in self._all_analysts()],
add=[agent_id],
)
return f"Created runtime analyst '{agent_id}' ({analyst_type})."

def _remove_runtime_analyst(self, agent_id: str) -> str:
"""Remove one runtime-created analyst instance."""
if agent_id not in self._dynamic_analysts:
return f"Runtime analyst '{agent_id}' not found."
self._dynamic_analysts.pop(agent_id, None)
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
update_active_analysts(
project_root=project_root,
config_name=config_name,
available_analysts=[item.name for item in self._all_analysts()],
remove=[agent_id],
)
return f"Removed runtime analyst '{agent_id}'."

def _get_active_analysts(self) -> List[Any]:
"""Resolve active analyst participants from run-scoped team pipeline config."""
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
analyst_map = {agent.name: agent for agent in self._all_analysts()}
active_ids = resolve_active_analysts(
project_root=project_root,
config_name=config_name,
available_analysts=list(analyst_map.keys()),
)
return [analyst_map[agent_id] for agent_id in active_ids if agent_id in analyst_map]

def _runtime_log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> None:
if not self.runtime_manager:
return

@@ -61,7 +61,7 @@ def stop_gateway() -> None:
_gateway_instance = None


async def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
"""Create ReMeTaskLongTermMemory for an agent."""
try:
from agentscope.memory import ReMeTaskLongTermMemory
@@ -206,6 +213,13 @@ async def run_pipeline(
"""
Run the trading pipeline with the given configuration.

Service Startup Order:
Phase 1: WebSocket Server - Frontend can connect
Phase 2: Market Service - Price data starts flowing
Phase 3: Agent Runtime - Create all agents
Phase 4: Pipeline & Scheduler - Trading logic ready
Phase 5: Gateway Fully Operational - All systems running

Args:
run_id: Unique run identifier (timestamp)
run_dir: Run directory path
@@ -219,7 +226,9 @@ async def run_pipeline(
# Set global shutdown event
set_shutdown_event(stop_event)

logger.info(f"[Pipeline {run_id}] Starting...")
logger.info(f"[Pipeline {run_id}] ======================================")
logger.info(f"[Pipeline {run_id}] Starting with 5-phase initialization...")
logger.info(f"[Pipeline {run_id}] ======================================")

try:
# Extract config values
@@ -230,15 +239,21 @@ async def run_pipeline(
schedule_mode = bootstrap.get("schedule_mode", "daily")
trigger_time = bootstrap.get("trigger_time", "09:30")
interval_minutes = int(bootstrap.get("interval_minutes", 60))
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0))
mode = bootstrap.get("mode", "live")
start_date = bootstrap.get("start_date")
end_date = bootstrap.get("end_date")
enable_memory = bootstrap.get("enable_memory", False)
enable_mock = bootstrap.get("enable_mock", False)

is_backtest = mode == "backtest"
is_mock = mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")

# ======================================================================
# PHASE 0: Initialize runtime manager
# ======================================================================
logger.info("[Phase 0/5] Initializing runtime manager...")

# Get or create runtime manager
from backend.api.runtime import runtime_manager

if runtime_manager is None:
@@ -255,16 +270,11 @@ async def run_pipeline(
from backend.api.runtime import register_runtime_manager
register_runtime_manager(runtime_manager)

# Create market service
market_service = MarketService(
tickers=tickers,
poll_interval=10,
mock_mode=is_mock and not is_backtest,
backtest_mode=is_backtest,
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
backtest_start_date=start_date if is_backtest else None,
backtest_end_date=end_date if is_backtest else None,
)
# ======================================================================
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
# These will be started by Gateway in the correct order
# ======================================================================
logger.info("[Phase 1-2/5] Creating infrastructure services...")

# Create storage service
storage_service = StorageService(
@@ -278,7 +288,22 @@ async def run_pipeline(
else:
storage_service.update_leaderboard_model_info()

# Create agents and pipeline
# Create market service (data source)
market_service = MarketService(
tickers=tickers,
poll_interval=10,
mock_mode=is_mock and not is_backtest,
backtest_mode=is_backtest,
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
backtest_start_date=start_date if is_backtest else None,
backtest_end_date=end_date if is_backtest else None,
)

# ======================================================================
# PHASE 3: Create Agent Runtime
# ======================================================================
logger.info("[Phase 3/5] Creating agent runtime...")

analysts, risk_manager, pm, long_term_memories = create_agents(
run_id=run_id,
run_dir=run_dir,
@@ -303,6 +328,11 @@ async def run_pipeline(
initial_capital=initial_cash,
)

# ======================================================================
# PHASE 4: Create Pipeline & Scheduler
# ======================================================================
logger.info("[Phase 4/5] Creating pipeline and scheduler...")

# Create pipeline
pipeline = TradingPipeline(
analysts=analysts,
@@ -336,6 +366,7 @@ async def run_pipeline(
mode=schedule_mode,
trigger_time=trigger_time,
interval_minutes=interval_minutes,
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
config={"config_name": run_id},
)

@@ -344,7 +375,15 @@ async def run_pipeline(

scheduler_callback = scheduler_callback_fn

# Create Gateway for WebSocket connections (after pipeline and scheduler are ready)
# ======================================================================
# PHASE 5: Start Gateway (WebSocket → Market → Scheduler)
# Gateway.start() will handle the final startup sequence:
# - WebSocket Server first (frontend can connect)
# - Market Service second (price data flows)
# - Scheduler last (trading begins)
# ======================================================================
logger.info("[Phase 5/5] Starting Gateway (WebSocket → Market → Scheduler)...")

gateway = Gateway(
market_service=market_service,
storage_service=storage_service,
@@ -359,6 +398,7 @@ async def run_pipeline(
"schedule_mode": schedule_mode,
"interval_minutes": interval_minutes,
"trigger_time": trigger_time,
"heartbeat_interval": heartbeat_interval,
"initial_cash": initial_cash,
"margin_requirement": margin_requirement,
"max_comm_cycles": max_comm_cycles,
@@ -374,13 +414,17 @@ async def run_pipeline(
for memory in long_term_memories:
await stack.enter_async_context(memory)

# Start Gateway in background task
# Start Gateway - this will execute the 4-phase startup:
# Phase 1: WebSocket Server (frontend can connect immediately)
# Phase 2: Market Service (price updates start flowing)
# Phase 3: Market Status Monitor
# Phase 4: Scheduler (trading cycles begin)
gateway_task = asyncio.create_task(
gateway.start(host="0.0.0.0", port=8765)
)
logger.info("[Pipeline] Gateway started on ws://localhost:8765")
logger.info("[Pipeline] Gateway startup initiated on ws://localhost:8765")

# Give Gateway a moment to start
# Wait for Gateway to fully initialize all phases
await asyncio.sleep(0.5)

# Define the trading cycle callback

@@ -4,7 +4,7 @@ Scheduler - Market-aware trigger system for trading cycles
"""
import asyncio
import logging
from datetime import datetime, timedelta
from datetime import datetime, time, timedelta
from typing import Any, Callable, Optional
from zoneinfo import ZoneInfo

@@ -28,17 +28,21 @@ class Scheduler:
mode: str = "daily",
trigger_time: Optional[str] = None,
interval_minutes: Optional[int] = None,
heartbeat_interval: Optional[int] = None,
config: Optional[dict] = None,
):
self.mode = mode
self.trigger_time = trigger_time or "09:30" # NYSE timezone
self.trigger_now = self.trigger_time == "now"
self.interval_minutes = interval_minutes or 60
self.heartbeat_interval = heartbeat_interval # e.g. 3600 = 1 hour
self.config = config or {}

self.running = False
self._task: Optional[asyncio.Task] = None
self._heartbeat_task: Optional[asyncio.Task] = None
self._callback: Optional[Callable] = None
self._heartbeat_callback: Optional[Callable] = None

def _now_nyse(self) -> datetime:
"""Get current time in NYSE timezone"""
@@ -53,6 +57,15 @@ class Scheduler:
)
return len(valid_days) > 0

def _is_trading_hours(self, now: datetime) -> bool:
"""Check if current time is within NYSE trading hours (9:30-16:00 ET)."""
market_time = now.time()
return time(9, 30) <= market_time <= time(16, 0)

def set_heartbeat_callback(self, callback: Callable) -> None:
"""Register callback for heartbeat triggers."""
self._heartbeat_callback = callback

def _next_trading_day(self, from_date: datetime) -> datetime:
"""Find the next trading day from given date"""
check_date = from_date
@@ -72,6 +85,13 @@ class Scheduler:
self._callback = callback
self._schedule_task()

# Start heartbeat loop if configured
if self.heartbeat_interval and self._heartbeat_callback:
self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop())
logger.info(
f"Heartbeat loop started: interval={self.heartbeat_interval}s",
)

logger.info(
f"Scheduler started: mode={self.mode}, timezone=America/New_York",
)
@@ -132,6 +152,30 @@ class Scheduler:

return changed

async def _run_heartbeat_loop(self):
"""Run heartbeat checks on a separate interval during trading hours."""
while self.running:
now = self._now_nyse()
if self._is_trading_day(now) and self._is_trading_hours(now):
if self._heartbeat_callback:
try:
current_date = now.strftime("%Y-%m-%d")
logger.debug(
f"[Heartbeat] Triggering heartbeat check for {current_date}",
)
await self._heartbeat_callback(date=current_date)
except Exception as e:
logger.error(
f"[Heartbeat] Callback failed: {e}",
exc_info=True,
)
else:
logger.warning(
"[Heartbeat] Callback not set, skipping heartbeat",
)

await asyncio.sleep(self.heartbeat_interval)

async def _run_daily(self, callback: Callable):
"""Run once per trading day at specified time (NYSE timezone)"""
first_run = True
@@ -206,6 +250,9 @@ class Scheduler:
if self._task:
self._task.cancel()
self._task = None
if self._heartbeat_task:
self._heartbeat_task.cancel()
self._heartbeat_task = None
logger.info("Scheduler stopped")