feat: add runtime dynamic team controls

This commit is contained in:
2026-04-03 13:48:31 +08:00
parent dc0b250adc
commit ecfbd87244
16 changed files with 2146 additions and 147 deletions

View File

@@ -33,6 +33,8 @@ from backend.agents.workspace_manager import WorkspaceManager
from backend.agents.prompt_loader import get_prompt_loader
from backend.llm.models import get_agent_formatter, get_agent_model
from backend.config.constants import ANALYST_TYPES
from backend.agents.dynamic_team_types import AnalystConfig
from backend.tools.dynamic_team_tools import DynamicTeamController, set_controller
def _resolve_evo_agent_ids() -> set[str]:
@@ -84,6 +86,9 @@ def _log(msg: str) -> None:
logger.info(msg)
from backend.core.apo import PolicyOptimizer
class TradingPipeline:
"""
Trading Pipeline - Orchestrates the complete trading cycle
@@ -127,7 +132,21 @@ class TradingPipeline:
self.runtime_manager = runtime_manager
self._session_key: Optional[str] = None
self._dynamic_analysts: Dict[str, Any] = {}
self._dynamic_analyst_configs: Dict[str, AnalystConfig] = {}
# Initialize APO (Autonomous Policy Optimizer)
config_name = workspace_id or (runtime_manager.config_name if runtime_manager else "default")
self.apo = PolicyOptimizer(config_name=config_name)
# Initialize dynamic team controller and inject into PM
self._team_controller = DynamicTeamController(
create_callback=self._create_runtime_analyst,
remove_callback=self._remove_runtime_analyst,
get_analysts_callback=self._all_analysts,
)
set_controller(self._team_controller)
# Backward compatibility: also set individual callbacks if PM expects them
if hasattr(self.pm, "set_team_controller"):
self.pm.set_team_controller(
create_agent_callback=self._create_runtime_analyst,
@@ -150,23 +169,7 @@ class TradingPipeline:
execute_decisions: bool = True,
) -> Dict[str, Any]:
"""
Run one complete trading cycle
Args:
tickers: List of stock tickers
date: Trading date (YYYY-MM-DD)
prices: Open prices {ticker: price} (for backtest)
close_prices: Close prices for settlement (for backtest)
market_caps: Optional market caps for baseline calculation
get_open_prices_fn: Async callback to wait for open prices (live mode)
get_close_prices_fn: Async callback to wait for close prices (live mode)
For live mode:
- Analysis runs immediately
- Execution waits for market open via get_open_prices_fn
- Settlement waits for market close via get_close_prices_fn
Each agent's result is broadcast immediately via StateSync.
Run one complete trading cycle with checkpointing support.
"""
_log(f"Starting cycle {date} - {len(tickers)} tickers")
session_key = TradingSessionKey(date=date).key()
@@ -176,14 +179,45 @@ class TradingPipeline:
agents=active_analysts + [self.risk_manager, self.pm],
session_key=session_key,
)
# Load checkpoint if exists
checkpoint = self._load_checkpoint(session_key)
checkpoint_data = checkpoint.get("data", {}) if checkpoint else {}
last_phase = checkpoint.get("phase") if checkpoint else None
if checkpoint:
_log(f"Resuming from checkpoint: {last_phase}")
# Restore state from checkpoint
analyst_results = checkpoint_data.get("analyst_results", [])
risk_assessment = checkpoint_data.get("risk_assessment", {})
self.conference_summary = checkpoint_data.get("conference_summary")
final_predictions = checkpoint_data.get("final_predictions", [])
pm_result = checkpoint_data.get("pm_result", {})
execution_result = checkpoint_data.get("execution_result", {})
settlement_result = checkpoint_data.get("settlement_result")
# Prefer prices passed by the caller; fall back to checkpointed prices when absent
if not prices:
prices = checkpoint_data.get("prices")
if not close_prices:
close_prices = checkpoint_data.get("close_prices")
else:
analyst_results = []
risk_assessment = {}
self.conference_summary = None
final_predictions = []
pm_result = {}
execution_result = {}
settlement_result = None
if self.runtime_manager:
self.runtime_manager.set_session_key(session_key)
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date})
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date, "resumed": checkpoint is not None})
self._runtime_batch_status(active_analysts, "analysis_in_progress")
# Phase 0: Clear short-term memory to avoid cross-day context pollution
_log("Phase 0: Clearing memory")
await self._clear_all_agent_memory()
# Phase 0: Clear memory (only if not resuming or if resuming from very start)
if not last_phase:
_log("Phase 0: Clearing memory")
await self._clear_all_agent_memory()
participants = self._all_analysts() + [self.risk_manager, self.pm]
@@ -196,125 +230,219 @@ class TradingPipeline:
"system",
),
):
# Phase 1.1: Analysts (parallel execution with TeamCoordinator)
_log("Phase 1.1: Analyst analysis (parallel)")
analyst_results = await self._run_analysts_parallel(
tickers,
date,
active_analysts=active_analysts,
)
# Phase 1.1: Analysts
if not last_phase or last_phase == "cleared":
_log("Phase 1.1: Analyst analysis (parallel)")
analyst_results = await self._run_analysts_parallel(
tickers,
date,
active_analysts=active_analysts,
)
self._save_checkpoint(session_key, "analysis", {
"analyst_results": analyst_results,
"prices": prices,
"close_prices": close_prices
})
last_phase = "analysis"
# Phase 1.2: Risk Manager
_log("Phase 1.2: Risk assessment")
self._runtime_update_status(self.risk_manager, "risk_assessment")
risk_assessment = await self._run_risk_manager_with_sync(
tickers,
date,
prices,
)
if last_phase == "analysis":
_log("Phase 1.2: Risk assessment")
self._runtime_update_status(self.risk_manager, "risk_assessment")
risk_assessment = await self._run_risk_manager_with_sync(
tickers,
date,
prices,
)
self._save_checkpoint(session_key, "risk_assessment", {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
"prices": prices,
"close_prices": close_prices
})
last_phase = "risk_assessment"
# Phase 2.1: Conference discussion (within same MsgHub)
_log("Phase 2.1: Conference discussion")
conference_summary = await self._run_conference_cycles(
tickers=tickers,
date=date,
prices=prices,
analyst_results=analyst_results,
risk_assessment=risk_assessment,
)
self.conference_summary = conference_summary
# Phase 2.1: Conference discussion
if last_phase == "risk_assessment":
_log("Phase 2.1: Conference discussion")
conference_summary = await self._run_conference_cycles(
tickers=tickers,
date=date,
prices=prices,
analyst_results=analyst_results,
risk_assessment=risk_assessment,
)
self.conference_summary = conference_summary
self._save_checkpoint(session_key, "conference", {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
"conference_summary": conference_summary,
"prices": prices,
"close_prices": close_prices
})
last_phase = "conference"
# Phase 2.2: Analysts generate final structured predictions
_log("Phase 2.2: Analysts generate final structured predictions")
final_predictions = await self._collect_final_predictions(
tickers,
date,
active_analysts=active_analysts,
)
if last_phase == "conference":
_log("Phase 2.2: Analysts generate final structured predictions")
final_predictions = await self._collect_final_predictions(
tickers,
date,
active_analysts=active_analysts,
)
self._save_checkpoint(session_key, "predictions", {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
"conference_summary": conference_summary,
"final_predictions": final_predictions,
"prices": prices,
"close_prices": close_prices
})
last_phase = "predictions"
# Record final predictions for leaderboard ranking
if self.settlement_coordinator:
# Record final predictions
if last_phase == "predictions" and self.settlement_coordinator:
self.settlement_coordinator.record_analyst_predictions(
final_predictions,
)
# Live mode: wait for market open before execution
if get_open_prices_fn:
# Live mode: wait for market open
if not prices and get_open_prices_fn:
_log("Waiting for market open...")
prices = await get_open_prices_fn()
_log(f"Got open prices: {prices}")
# Update prices in checkpoint if we just got them
self._save_checkpoint(session_key, "predictions", {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
"conference_summary": conference_summary,
"final_predictions": final_predictions,
"prices": prices,
"close_prices": close_prices
})
# Phase 3: PM makes decisions
_log("Phase 3.1: PM makes decisions")
self._runtime_update_status(self.pm, "decision_phase")
pm_result = await self._run_pm_with_sync(
tickers,
date,
prices,
analyst_results,
risk_assessment,
)
if last_phase == "predictions":
_log("Phase 3.1: PM makes decisions")
self._runtime_update_status(self.pm, "decision_phase")
pm_result = await self._run_pm_with_sync(
tickers,
date,
prices,
analyst_results,
risk_assessment,
)
self._save_checkpoint(session_key, "decisions", {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
"conference_summary": conference_summary,
"final_predictions": final_predictions,
"pm_result": pm_result,
"prices": prices,
"close_prices": close_prices
})
last_phase = "decisions"
decisions = pm_result.get("decisions", {})
execution_result = {
"executed_trades": [],
"portfolio": self.pm.get_portfolio_state(),
}
if execute_decisions:
_log("Phase 4: Executing trades")
self._runtime_update_status(self.pm, "executing")
execution_result = self._execute_decisions(decisions, prices, date)
else:
_log("Phase 4: Skipping trade execution")
# Outside MsgHub for execution and settlement
decisions = pm_result.get("decisions", {}) if pm_result else {}
if not execution_result:
execution_result = {
"executed_trades": [],
"portfolio": self.pm.get_portfolio_state(),
}
# Live mode: wait for market close before settlement
if get_close_prices_fn:
if last_phase == "decisions":
if execute_decisions:
_log("Phase 4: Executing trades")
self._runtime_update_status(self.pm, "executing")
execution_result = self._execute_decisions(decisions, prices, date)
else:
_log("Phase 4: Skipping trade execution")
self._save_checkpoint(session_key, "execution", {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
"conference_summary": conference_summary,
"final_predictions": final_predictions,
"pm_result": pm_result,
"execution_result": execution_result,
"prices": prices,
"close_prices": close_prices
})
last_phase = "execution"
# Live mode: wait for market close
if not close_prices and get_close_prices_fn:
_log("Waiting for market close")
close_prices = await get_close_prices_fn()
_log(f"Got close prices: {close_prices}")
# Update close_prices in checkpoint
self._save_checkpoint(session_key, "execution", {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
"conference_summary": conference_summary,
"final_predictions": final_predictions,
"pm_result": pm_result,
"execution_result": execution_result,
"prices": prices,
"close_prices": close_prices
})
# Phase 5: Settlement - run after close prices available
settlement_result = None
if close_prices and self.settlement_coordinator:
_log("Phase 5: Daily review and generate memories")
self._runtime_batch_status(
[self.risk_manager] + self._all_analysts() + [self.pm],
"settlement",
)
# Phase 5: Settlement
if last_phase == "execution":
if close_prices and self.settlement_coordinator:
_log("Phase 5: Daily review and generate memories")
self._runtime_batch_status(
[self.risk_manager] + self._all_analysts() + [self.pm],
"settlement",
)
agent_trajectories = await self._capture_agent_trajectories()
agent_trajectories = await self._capture_agent_trajectories()
if market_caps is None:
market_caps = {ticker: 1e9 for ticker in tickers}
if market_caps is None:
market_caps = {ticker: 1e9 for ticker in tickers}
settlement_result = (
self.settlement_coordinator.run_daily_settlement(
settlement_result = (
self.settlement_coordinator.run_daily_settlement(
date=date,
tickers=tickers,
open_prices=prices,
close_prices=close_prices,
market_caps=market_caps,
agent_portfolio=execution_result.get("portfolio", {}),
analyst_results=analyst_results,
pm_decisions=decisions,
)
)
await self._run_reflection(
date=date,
tickers=tickers,
agent_trajectories=agent_trajectories,
analyst_results=analyst_results,
decisions=decisions,
executed_trades=execution_result.get("executed_trades", []),
open_prices=prices,
close_prices=close_prices,
market_caps=market_caps,
agent_portfolio=execution_result.get("portfolio", {}),
analyst_results=analyst_results,
pm_decisions=decisions,
settlement_result=settlement_result,
conference_summary=self.conference_summary,
)
)
await self._run_reflection(
date=date,
agent_trajectories=agent_trajectories,
analyst_results=analyst_results,
decisions=decisions,
executed_trades=execution_result.get("executed_trades", []),
open_prices=prices,
close_prices=close_prices,
settlement_result=settlement_result,
conference_summary=self.conference_summary,
)
self._runtime_batch_status(
[self.risk_manager] + self._all_analysts() + [self.pm],
"reflection",
)
self._runtime_batch_status(
[self.risk_manager] + self._all_analysts() + [self.pm],
"reflection",
)
self._save_checkpoint(session_key, "settlement", {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
"conference_summary": conference_summary,
"final_predictions": final_predictions,
"pm_result": pm_result,
"execution_result": execution_result,
"settlement_result": settlement_result,
"prices": prices,
"close_prices": close_prices
})
last_phase = "settlement"
_log(f"Cycle complete: {date}")
self._runtime_batch_status(
@@ -323,6 +451,11 @@ class TradingPipeline:
)
self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date})
# Optional: Clean up checkpoint after successful completion
# path = self._get_checkpoint_path(session_key)
# if path and path.exists():
# path.unlink()
return {
"analyst_results": analyst_results,
"risk_assessment": risk_assessment,
@@ -385,6 +518,44 @@ class TradingPipeline:
await self.risk_manager.memory.clear()
await self.pm.memory.clear()
def _get_checkpoint_path(self, session_key: str) -> Optional[Path]:
"""Get the path to the pipeline checkpoint file."""
if not self.runtime_manager or not self.runtime_manager.run_dir:
return None
checkpoint_dir = self.runtime_manager.run_dir / "state" / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)
return checkpoint_dir / f"pipeline_{session_key}.json"
def _save_checkpoint(self, session_key: str, phase: str, data: Dict[str, Any]) -> None:
"""Save the current pipeline state to a checkpoint file."""
path = self._get_checkpoint_path(session_key)
if not path:
return
checkpoint = {
"session_key": session_key,
"phase": phase,
"timestamp": datetime.now().isoformat(),
"data": data
}
try:
path.write_text(json.dumps(checkpoint, ensure_ascii=False, indent=2, default=str), encoding="utf-8")
_log(f"Checkpoint saved: {phase} for {session_key}")
except Exception as e:
logger.error(f"Failed to save checkpoint: {e}")
def _load_checkpoint(self, session_key: str) -> Optional[Dict[str, Any]]:
"""Load the pipeline state from a checkpoint file."""
path = self._get_checkpoint_path(session_key)
if not path or not path.exists():
return None
try:
return json.loads(path.read_text(encoding="utf-8"))
except Exception as e:
logger.error(f"Failed to load checkpoint: {e}")
return None
async def _sync_memory_if_retrieved(self, agent: Any) -> None:
"""
Check agent's short-term memory for retrieved long-term memory and sync to frontend.
@@ -585,6 +756,25 @@ class TradingPipeline:
content=reflection_content,
)
# Phase 6: APO (Autonomous Policy Optimization)
# If the day was a loss, let APO suggest and apply policy updates.
if hasattr(self, "apo") and self.apo:
_log(f"Phase 6: APO - Running autonomous policy optimization for {date}")
try:
apo_result = await self.apo.run_optimization(
date=date,
reflection_content=reflection_content,
settlement_result=settlement_result or {"portfolio_value": 100000.0 + total_pnl},
analyst_results=analyst_results,
decisions=decisions
)
if apo_result.get("status") == "completed":
_log(f"APO: Successfully applied {len(apo_result.get('optimizations', []))} policy updates.")
# Reload assets for next cycle to ensure they are picked up
self.reload_runtime_assets()
except Exception as e:
logger.error(f"APO: Optimization failed: {e}")
def _build_reflection_content(
self,
date: str,
@@ -1562,28 +1752,74 @@ class TradingPipeline:
"""Return static analysts plus runtime-created analysts."""
return list(self.analysts) + list(self._dynamic_analysts.values())
def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str:
"""Create one runtime analyst instance."""
if analyst_type not in ANALYST_TYPES:
def _create_runtime_analyst(
self,
agent_id: str,
analyst_type: str,
custom_config: Optional[AnalystConfig] = None,
) -> str:
"""Create one runtime analyst instance.
Args:
agent_id: Unique identifier for the new analyst
analyst_type: Type of analyst (e.g., "technical_analyst")
custom_config: Optional custom configuration for the analyst,
including persona, soul_md, agents_md, etc.
Returns:
Success or error message
"""
# Validate analyst_type or custom_config
if analyst_type not in ANALYST_TYPES and not custom_config:
return (
f"Unknown analyst_type '{analyst_type}'. "
f"Available: {', '.join(ANALYST_TYPES.keys())}"
f"Available: {', '.join(ANALYST_TYPES.keys())}. "
f"Or provide custom_config to create a custom analyst."
)
if agent_id in {agent.name for agent in self._all_analysts()}:
return f"Analyst '{agent_id}' already exists."
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
persona = personas.get(analyst_type, {})
# Get persona: use custom_config if provided, else load from personas.yaml
if custom_config and custom_config.persona:
persona = {
"name": custom_config.persona.name,
"focus": custom_config.persona.focus,
"description": custom_config.persona.description,
}
else:
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
persona = personas.get(analyst_type, {})
workspace_manager = WorkspaceManager(project_root=project_root)
# Build file contents: use custom if provided, else generate from persona
file_contents = {}
if custom_config:
if custom_config.soul_md:
file_contents["SOUL.md"] = custom_config.soul_md
if custom_config.agents_md:
file_contents["AGENTS.md"] = custom_config.agents_md
if custom_config.profile_md:
file_contents["PROFILE.md"] = custom_config.profile_md
if custom_config.bootstrap_md:
file_contents["BOOTSTRAP.md"] = custom_config.bootstrap_md
# Fill in any missing files with defaults
if not file_contents or len(file_contents) < 4:
default_files = workspace_manager.build_default_agent_files(
agent_id=agent_id,
persona=persona,
)
for key, value in default_files.items():
if key not in file_contents:
file_contents[key] = value
workspace_manager.ensure_agent_assets(
config_name=config_name,
agent_id=agent_id,
file_contents=workspace_manager.build_default_agent_files(
agent_id=agent_id,
persona=persona,
),
file_contents=file_contents,
)
# Create EvoAgent with workspace-driven configuration
@@ -1594,11 +1830,23 @@ class TradingPipeline:
agent_id,
)
agent_config = load_agent_workspace_config(workspace_dir / "agent.yaml")
# Support model override from custom_config
if custom_config and custom_config.model_name:
# Import create_model for custom model creation
from backend.llm.models import create_model
# Use specified model name, default to openai provider
model = create_model(
model_name=custom_config.model_name,
model_provider=custom_config.memory_config.get("model_provider", "openai") if custom_config.memory_config else "openai"
)
else:
model = get_agent_model(analyst_type)
agent = EvoAgent(
agent_id=agent_id,
config_name=config_name,
workspace_dir=workspace_dir,
model=get_agent_model(analyst_type),
model=model,
formatter=get_agent_formatter(analyst_type),
prompt_files=agent_config.prompt_files,
)
@@ -1611,6 +1859,11 @@ class TradingPipeline:
# Keep workspace_id for backward compatibility
setattr(agent, "workspace_id", config_name)
self._dynamic_analysts[agent_id] = agent
# Store custom config for future reference (e.g., cloning)
if custom_config:
self._dynamic_analyst_configs[agent_id] = custom_config
update_active_analysts(
project_root=project_root,
config_name=config_name,
@@ -1624,6 +1877,8 @@ class TradingPipeline:
if agent_id not in self._dynamic_analysts:
return f"Runtime analyst '{agent_id}' not found."
self._dynamic_analysts.pop(agent_id, None)
# Also remove stored config if exists
self._dynamic_analyst_configs.pop(agent_id, None)
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
update_active_analysts(