feat: add runtime dynamic team controls
This commit is contained in:
197
backend/core/apo.py
Normal file
197
backend/core/apo.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Autonomous Policy Optimizer (APO)
|
||||
Automatically tunes agent policies based on performance feedback.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
from backend.llm.models import get_agent_model, get_agent_formatter
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PolicyOptimizer:
    """
    PolicyOptimizer analyzes trading performance and automatically updates
    agent workspace files (POLICY.md, AGENTS.md) to improve future results.

    It acts as a meta-agent: after a losing day it asks a high-capability
    LLM to propose a concrete POLICY.md amendment for each agent involved,
    then appends the amendment to that agent's workspace file.
    """

    def __init__(
        self,
        config_name: str,
        project_root: Optional[Path] = None,
        initial_capital: float = 100000.0,
    ):
        """Create an optimizer bound to one workspace configuration.

        Args:
            config_name: Workspace configuration name used to locate agent files.
            project_root: Optional project root forwarded to WorkspaceManager.
            initial_capital: Starting portfolio value used to derive daily P&L
                (defaults to the 100k previously hard-coded in run_optimization).
        """
        self.config_name = config_name
        self.workspace_manager = WorkspaceManager(project_root=project_root)
        self.initial_capital = initial_capital
        # Use a high-capability model for the optimizer (meta-agent)
        self.model = get_agent_model("portfolio_manager")
        self.formatter = get_agent_formatter("portfolio_manager")

    async def run_optimization(
        self,
        date: str,
        reflection_content: str,
        settlement_result: Dict[str, Any],
        analyst_results: List[Dict[str, Any]],
        decisions: Dict[str, Dict],
    ) -> Dict[str, Any]:
        """
        Run the optimization loop if performance indicates a need for change.

        Args:
            date: Trading date (YYYY-MM-DD).
            reflection_content: Daily reflection text fed into the LLM prompt.
            settlement_result: Settlement summary; "portfolio_value" is read
                to compute daily P&L against the initial capital.
            analyst_results: Per-analyst result dicts (each with an "agent" key).
            decisions: PM decisions for the day (passed through to the LLM step).

        Returns:
            {"status": "skipped", "reason": ...} on a non-negative day,
            otherwise {"status": "completed", ...} listing applied updates.
        """
        total_pnl = settlement_result.get("portfolio_value", 0) - self.initial_capital
        # NOTE(review): trigger is a single losing day; a more robust trigger
        # (e.g. 3 consecutive losses) could reduce policy churn.
        if total_pnl >= 0:
            logger.info(f"APO: Positive P&L (${total_pnl:,.2f}) for {date}, skipping optimization.")
            return {"status": "skipped", "reason": "positive_pnl"}

        logger.info(f"APO: Negative P&L (${total_pnl:,.2f}) detected for {date}. Starting optimization...")

        # 1. Identify underperforming agents or logic
        # 2. Generate policy updates
        # 3. Apply updates
        optimizations = []

        # Focus on agents that gave high confidence but wrong direction
        underperformers = self._identify_underperformers(settlement_result, analyst_results)

        for agent_id in underperformers:
            update = await self._generate_policy_update(
                agent_id,
                date,
                reflection_content,
                settlement_result,
                analyst_results,
                decisions,
            )
            if update:
                self._apply_update(agent_id, update)
                optimizations.append({
                    "agent_id": agent_id,
                    "file": update.get("file", "POLICY.md"),
                    "change": update.get("change", ""),
                })

        return {
            "status": "completed",
            "date": date,
            "total_pnl": total_pnl,
            "optimizations": optimizations,
        }

    def _identify_underperformers(
        self,
        settlement_result: Dict[str, Any],
        analyst_results: List[Dict[str, Any]],
    ) -> List[str]:
        """Identify which agents might need policy adjustments.

        Current heuristic: on a losing day every analyst that produced a
        result is a candidate, plus the portfolio manager and risk manager
        since they sit in the decision path for every trade.

        Returns:
            Sorted, de-duplicated agent ids (sorted for deterministic
            iteration order across runs).
        """
        underperformers = {
            result.get("agent")
            for result in analyst_results
            if result.get("agent")
        }
        # Also include PM and Risk Manager as they are critical
        underperformers.update({"portfolio_manager", "risk_manager"})
        return sorted(underperformers)

    @staticmethod
    def _extract_json_block(text: str) -> str:
        """Strip markdown code fences from an LLM response.

        Handles both ```json-tagged and plain ``` fences; returns the
        stripped text unchanged when no fence is present.
        """
        if "```json" in text:
            return text.split("```json")[1].split("```")[0].strip()
        if "```" in text:
            # Fenced block without a language tag.
            return text.split("```")[1].split("```")[0].strip()
        return text.strip()

    async def _generate_policy_update(
        self,
        agent_id: str,
        date: str,
        reflection_content: str,
        settlement_result: Dict[str, Any],
        analyst_results: List[Dict[str, Any]],
        decisions: Dict[str, Dict],
    ) -> Optional[Dict[str, str]]:
        """Use LLM to generate a specific policy update for an agent.

        Returns the parsed update dict ({"reasoning", "file", "change"}),
        or None when the model response cannot be parsed as JSON.
        """
        # Load current policy; a missing file is not an error, the prompt
        # simply states that no policy exists yet.
        try:
            current_policy = self.workspace_manager.load_agent_file(
                config_name=self.config_name,
                agent_id=agent_id,
                filename="POLICY.md"
            )
        except Exception:
            current_policy = "No existing policy found."

        prompt = f"""
As an Expert Meta-Optimizer for a multi-agent trading system, your task is to update the operational POLICY for an agent named '{agent_id}' based on recent performance failures.

[Current Context]
Date: {date}
Daily Reflection:
{reflection_content}

[Agent's Current POLICY.md]
{current_policy}

[Task]
Analyze why the system failed (loss occurred). Identify what '{agent_id}' could have done differently or what new constraint/heuristic should be added to its policy to prevent similar mistakes in the future.

Provide a specific, concise addition or modification to the POLICY.md file.
The output MUST be a JSON object with:
1. "reasoning": Brief explanation of why this change is needed.
2. "file": Always "POLICY.md".
3. "change": The EXACT markdown text to APPEND or REPLACE in the file. Keep it in Chinese as the system uses Chinese prompts.

Output ONLY the JSON object.
"""
        msg = Msg(name="system", content=prompt, role="user")
        response = await self.model.reply(msg)

        content = response.content
        if isinstance(content, list):
            # Multi-part responses: use the text of the first part.
            content = content[0].get("text", "")
        if not isinstance(content, str):
            # Defensive: some model backends may return non-string payloads.
            content = "" if content is None else str(content)

        # Clean JSON if wrapped in markdown fences
        content = self._extract_json_block(content)

        try:
            return json.loads(content)
        except Exception as e:
            logger.error(f"APO: Failed to parse optimization response for {agent_id}: {e}")
            return None

    def _apply_update(self, agent_id: str, update: Dict[str, str]) -> None:
        """Apply the suggested update to the agent's workspace.

        Appends the change under a dated "### APO Update" heading, skipping
        changes already present in the target file. Failures are logged,
        never raised, so a bad update cannot abort the optimization loop.
        """
        filename = update.get("file", "POLICY.md")
        change = update.get("change", "")

        if not change:
            return

        try:
            current_content = self.workspace_manager.load_agent_file(
                config_name=self.config_name,
                agent_id=agent_id,
                filename=filename
            )

            # Check if change is already there to avoid duplicates
            if change.strip() in current_content:
                logger.info(f"APO: Change already present in {agent_id}/{filename}")
                return

            new_content = (
                current_content
                + "\n\n### APO Update ("
                + datetime.now().strftime("%Y-%m-%d")
                + ")\n"
                + change
            )

            self.workspace_manager.update_agent_file(
                config_name=self.config_name,
                agent_id=agent_id,
                filename=filename,
                content=new_content
            )
            logger.info(f"APO: Updated {agent_id}/{filename} with new heuristics.")
        except Exception as e:
            logger.error(f"APO: Failed to apply update to {agent_id}/{filename}: {e}")
|
||||
@@ -33,6 +33,8 @@ from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.agents.dynamic_team_types import AnalystConfig
|
||||
from backend.tools.dynamic_team_tools import DynamicTeamController, set_controller
|
||||
|
||||
|
||||
def _resolve_evo_agent_ids() -> set[str]:
|
||||
@@ -84,6 +86,9 @@ def _log(msg: str) -> None:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
from backend.core.apo import PolicyOptimizer
|
||||
|
||||
|
||||
class TradingPipeline:
|
||||
"""
|
||||
Trading Pipeline - Orchestrates the complete trading cycle
|
||||
@@ -127,7 +132,21 @@ class TradingPipeline:
|
||||
self.runtime_manager = runtime_manager
|
||||
self._session_key: Optional[str] = None
|
||||
self._dynamic_analysts: Dict[str, Any] = {}
|
||||
self._dynamic_analyst_configs: Dict[str, AnalystConfig] = {}
|
||||
|
||||
# Initialize APO (Autonomous Policy Optimizer)
|
||||
config_name = workspace_id or (runtime_manager.config_name if runtime_manager else "default")
|
||||
self.apo = PolicyOptimizer(config_name=config_name)
|
||||
|
||||
# Initialize dynamic team controller and inject into PM
|
||||
self._team_controller = DynamicTeamController(
|
||||
create_callback=self._create_runtime_analyst,
|
||||
remove_callback=self._remove_runtime_analyst,
|
||||
get_analysts_callback=self._all_analysts,
|
||||
)
|
||||
set_controller(self._team_controller)
|
||||
|
||||
# Backward compatibility: also set individual callbacks if PM expects them
|
||||
if hasattr(self.pm, "set_team_controller"):
|
||||
self.pm.set_team_controller(
|
||||
create_agent_callback=self._create_runtime_analyst,
|
||||
@@ -150,23 +169,7 @@ class TradingPipeline:
|
||||
execute_decisions: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run one complete trading cycle
|
||||
|
||||
Args:
|
||||
tickers: List of stock tickers
|
||||
date: Trading date (YYYY-MM-DD)
|
||||
prices: Open prices {ticker: price} (for backtest)
|
||||
close_prices: Close prices for settlement (for backtest)
|
||||
market_caps: Optional market caps for baseline calculation
|
||||
get_open_prices_fn: Async callback to wait for open prices (live mode)
|
||||
get_close_prices_fn: Async callback to wait for close prices (live mode)
|
||||
|
||||
For live mode:
|
||||
- Analysis runs immediately
|
||||
- Execution waits for market open via get_open_prices_fn
|
||||
- Settlement waits for market close via get_close_prices_fn
|
||||
|
||||
Each agent's result is broadcast immediately via StateSync.
|
||||
Run one complete trading cycle with checkpointing support.
|
||||
"""
|
||||
_log(f"Starting cycle {date} - {len(tickers)} tickers")
|
||||
session_key = TradingSessionKey(date=date).key()
|
||||
@@ -176,14 +179,45 @@ class TradingPipeline:
|
||||
agents=active_analysts + [self.risk_manager, self.pm],
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
# Load checkpoint if exists
|
||||
checkpoint = self._load_checkpoint(session_key)
|
||||
checkpoint_data = checkpoint.get("data", {}) if checkpoint else {}
|
||||
last_phase = checkpoint.get("phase") if checkpoint else None
|
||||
|
||||
if checkpoint:
|
||||
_log(f"Resuming from checkpoint: {last_phase}")
|
||||
# Restore state from checkpoint
|
||||
analyst_results = checkpoint_data.get("analyst_results", [])
|
||||
risk_assessment = checkpoint_data.get("risk_assessment", {})
|
||||
self.conference_summary = checkpoint_data.get("conference_summary")
|
||||
final_predictions = checkpoint_data.get("final_predictions", [])
|
||||
pm_result = checkpoint_data.get("pm_result", {})
|
||||
execution_result = checkpoint_data.get("execution_result", {})
|
||||
settlement_result = checkpoint_data.get("settlement_result")
|
||||
# Prefer caller-supplied prices; fall back to values held in the checkpoint
|
||||
if not prices:
|
||||
prices = checkpoint_data.get("prices")
|
||||
if not close_prices:
|
||||
close_prices = checkpoint_data.get("close_prices")
|
||||
else:
|
||||
analyst_results = []
|
||||
risk_assessment = {}
|
||||
self.conference_summary = None
|
||||
final_predictions = []
|
||||
pm_result = {}
|
||||
execution_result = {}
|
||||
settlement_result = None
|
||||
|
||||
if self.runtime_manager:
|
||||
self.runtime_manager.set_session_key(session_key)
|
||||
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date})
|
||||
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date, "resumed": checkpoint is not None})
|
||||
self._runtime_batch_status(active_analysts, "analysis_in_progress")
|
||||
|
||||
# Phase 0: Clear short-term memory to avoid cross-day context pollution
|
||||
_log("Phase 0: Clearing memory")
|
||||
await self._clear_all_agent_memory()
|
||||
# Phase 0: Clear memory (only if not resuming or if resuming from very start)
|
||||
if not last_phase:
|
||||
_log("Phase 0: Clearing memory")
|
||||
await self._clear_all_agent_memory()
|
||||
|
||||
participants = self._all_analysts() + [self.risk_manager, self.pm]
|
||||
|
||||
@@ -196,125 +230,219 @@ class TradingPipeline:
|
||||
"system",
|
||||
),
|
||||
):
|
||||
# Phase 1.1: Analysts (parallel execution with TeamCoordinator)
|
||||
_log("Phase 1.1: Analyst analysis (parallel)")
|
||||
analyst_results = await self._run_analysts_parallel(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
# Phase 1.1: Analysts
|
||||
if not last_phase or last_phase == "cleared":
|
||||
_log("Phase 1.1: Analyst analysis (parallel)")
|
||||
analyst_results = await self._run_analysts_parallel(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
self._save_checkpoint(session_key, "analysis", {
|
||||
"analyst_results": analyst_results,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "analysis"
|
||||
|
||||
# Phase 1.2: Risk Manager
|
||||
_log("Phase 1.2: Risk assessment")
|
||||
self._runtime_update_status(self.risk_manager, "risk_assessment")
|
||||
risk_assessment = await self._run_risk_manager_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
prices,
|
||||
)
|
||||
if last_phase == "analysis":
|
||||
_log("Phase 1.2: Risk assessment")
|
||||
self._runtime_update_status(self.risk_manager, "risk_assessment")
|
||||
risk_assessment = await self._run_risk_manager_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
prices,
|
||||
)
|
||||
self._save_checkpoint(session_key, "risk_assessment", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "risk_assessment"
|
||||
|
||||
# Phase 2.1: Conference discussion (within same MsgHub)
|
||||
_log("Phase 2.1: Conference discussion")
|
||||
conference_summary = await self._run_conference_cycles(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
analyst_results=analyst_results,
|
||||
risk_assessment=risk_assessment,
|
||||
)
|
||||
self.conference_summary = conference_summary
|
||||
# Phase 2.1: Conference discussion
|
||||
if last_phase == "risk_assessment":
|
||||
_log("Phase 2.1: Conference discussion")
|
||||
conference_summary = await self._run_conference_cycles(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
analyst_results=analyst_results,
|
||||
risk_assessment=risk_assessment,
|
||||
)
|
||||
self.conference_summary = conference_summary
|
||||
self._save_checkpoint(session_key, "conference", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "conference"
|
||||
|
||||
# Phase 2.2: Analysts generate final structured predictions
|
||||
_log("Phase 2.2: Analysts generate final structured predictions")
|
||||
final_predictions = await self._collect_final_predictions(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
if last_phase == "conference":
|
||||
_log("Phase 2.2: Analysts generate final structured predictions")
|
||||
final_predictions = await self._collect_final_predictions(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
self._save_checkpoint(session_key, "predictions", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "predictions"
|
||||
|
||||
# Record final predictions for leaderboard ranking
|
||||
if self.settlement_coordinator:
|
||||
# Record final predictions
|
||||
if last_phase == "predictions" and self.settlement_coordinator:
|
||||
self.settlement_coordinator.record_analyst_predictions(
|
||||
final_predictions,
|
||||
)
|
||||
|
||||
# Live mode: wait for market open before execution
|
||||
if get_open_prices_fn:
|
||||
# Live mode: wait for market open
|
||||
if not prices and get_open_prices_fn:
|
||||
_log("Waiting for market open...")
|
||||
prices = await get_open_prices_fn()
|
||||
_log(f"Got open prices: {prices}")
|
||||
# Update prices in checkpoint if we just got them
|
||||
self._save_checkpoint(session_key, "predictions", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
|
||||
# Phase 3: PM makes decisions
|
||||
_log("Phase 3.1: PM makes decisions")
|
||||
self._runtime_update_status(self.pm, "decision_phase")
|
||||
pm_result = await self._run_pm_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
prices,
|
||||
analyst_results,
|
||||
risk_assessment,
|
||||
)
|
||||
if last_phase == "predictions":
|
||||
_log("Phase 3.1: PM makes decisions")
|
||||
self._runtime_update_status(self.pm, "decision_phase")
|
||||
pm_result = await self._run_pm_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
prices,
|
||||
analyst_results,
|
||||
risk_assessment,
|
||||
)
|
||||
self._save_checkpoint(session_key, "decisions", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"pm_result": pm_result,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "decisions"
|
||||
|
||||
decisions = pm_result.get("decisions", {})
|
||||
execution_result = {
|
||||
"executed_trades": [],
|
||||
"portfolio": self.pm.get_portfolio_state(),
|
||||
}
|
||||
if execute_decisions:
|
||||
_log("Phase 4: Executing trades")
|
||||
self._runtime_update_status(self.pm, "executing")
|
||||
execution_result = self._execute_decisions(decisions, prices, date)
|
||||
else:
|
||||
_log("Phase 4: Skipping trade execution")
|
||||
# Outside MsgHub for execution and settlement
|
||||
decisions = pm_result.get("decisions", {}) if pm_result else {}
|
||||
if not execution_result:
|
||||
execution_result = {
|
||||
"executed_trades": [],
|
||||
"portfolio": self.pm.get_portfolio_state(),
|
||||
}
|
||||
|
||||
# Live mode: wait for market close before settlement
|
||||
if get_close_prices_fn:
|
||||
if last_phase == "decisions":
|
||||
if execute_decisions:
|
||||
_log("Phase 4: Executing trades")
|
||||
self._runtime_update_status(self.pm, "executing")
|
||||
execution_result = self._execute_decisions(decisions, prices, date)
|
||||
else:
|
||||
_log("Phase 4: Skipping trade execution")
|
||||
|
||||
self._save_checkpoint(session_key, "execution", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"pm_result": pm_result,
|
||||
"execution_result": execution_result,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "execution"
|
||||
|
||||
# Live mode: wait for market close
|
||||
if not close_prices and get_close_prices_fn:
|
||||
_log("Waiting for market close")
|
||||
close_prices = await get_close_prices_fn()
|
||||
_log(f"Got close prices: {close_prices}")
|
||||
# Update close_prices in checkpoint
|
||||
self._save_checkpoint(session_key, "execution", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"pm_result": pm_result,
|
||||
"execution_result": execution_result,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
|
||||
# Phase 5: Settlement - run after close prices available
|
||||
settlement_result = None
|
||||
if close_prices and self.settlement_coordinator:
|
||||
_log("Phase 5: Daily review and generate memories")
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"settlement",
|
||||
)
|
||||
# Phase 5: Settlement
|
||||
if last_phase == "execution":
|
||||
if close_prices and self.settlement_coordinator:
|
||||
_log("Phase 5: Daily review and generate memories")
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"settlement",
|
||||
)
|
||||
|
||||
agent_trajectories = await self._capture_agent_trajectories()
|
||||
agent_trajectories = await self._capture_agent_trajectories()
|
||||
|
||||
if market_caps is None:
|
||||
market_caps = {ticker: 1e9 for ticker in tickers}
|
||||
if market_caps is None:
|
||||
market_caps = {ticker: 1e9 for ticker in tickers}
|
||||
|
||||
settlement_result = (
|
||||
self.settlement_coordinator.run_daily_settlement(
|
||||
settlement_result = (
|
||||
self.settlement_coordinator.run_daily_settlement(
|
||||
date=date,
|
||||
tickers=tickers,
|
||||
open_prices=prices,
|
||||
close_prices=close_prices,
|
||||
market_caps=market_caps,
|
||||
agent_portfolio=execution_result.get("portfolio", {}),
|
||||
analyst_results=analyst_results,
|
||||
pm_decisions=decisions,
|
||||
)
|
||||
)
|
||||
|
||||
await self._run_reflection(
|
||||
date=date,
|
||||
tickers=tickers,
|
||||
agent_trajectories=agent_trajectories,
|
||||
analyst_results=analyst_results,
|
||||
decisions=decisions,
|
||||
executed_trades=execution_result.get("executed_trades", []),
|
||||
open_prices=prices,
|
||||
close_prices=close_prices,
|
||||
market_caps=market_caps,
|
||||
agent_portfolio=execution_result.get("portfolio", {}),
|
||||
analyst_results=analyst_results,
|
||||
pm_decisions=decisions,
|
||||
settlement_result=settlement_result,
|
||||
conference_summary=self.conference_summary,
|
||||
)
|
||||
)
|
||||
|
||||
await self._run_reflection(
|
||||
date=date,
|
||||
agent_trajectories=agent_trajectories,
|
||||
analyst_results=analyst_results,
|
||||
decisions=decisions,
|
||||
executed_trades=execution_result.get("executed_trades", []),
|
||||
open_prices=prices,
|
||||
close_prices=close_prices,
|
||||
settlement_result=settlement_result,
|
||||
conference_summary=self.conference_summary,
|
||||
)
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"reflection",
|
||||
)
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"reflection",
|
||||
)
|
||||
|
||||
self._save_checkpoint(session_key, "settlement", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"pm_result": pm_result,
|
||||
"execution_result": execution_result,
|
||||
"settlement_result": settlement_result,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "settlement"
|
||||
|
||||
_log(f"Cycle complete: {date}")
|
||||
self._runtime_batch_status(
|
||||
@@ -323,6 +451,11 @@ class TradingPipeline:
|
||||
)
|
||||
self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date})
|
||||
|
||||
# Optional: Clean up checkpoint after successful completion
|
||||
# path = self._get_checkpoint_path(session_key)
|
||||
# if path and path.exists():
|
||||
# path.unlink()
|
||||
|
||||
return {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
@@ -385,6 +518,44 @@ class TradingPipeline:
|
||||
await self.risk_manager.memory.clear()
|
||||
await self.pm.memory.clear()
|
||||
|
||||
def _get_checkpoint_path(self, session_key: str) -> Optional[Path]:
    """Resolve the checkpoint file path for one trading session.

    Returns ``None`` when no runtime manager / run directory is configured,
    i.e. checkpointing is disabled for this pipeline. Otherwise ensures the
    checkpoint directory exists before handing out the path.
    """
    manager = self.runtime_manager
    if not manager or not manager.run_dir:
        return None
    # Checkpoints live under <run_dir>/state/checkpoints.
    target_dir = manager.run_dir / "state" / "checkpoints"
    target_dir.mkdir(parents=True, exist_ok=True)
    return target_dir / f"pipeline_{session_key}.json"
|
||||
|
||||
def _save_checkpoint(self, session_key: str, phase: str, data: Dict[str, Any]) -> None:
    """Save the current pipeline state to a checkpoint file.

    Args:
        session_key: Identifier for the trading session (one per date).
        phase: Name of the last completed pipeline phase (e.g. "analysis").
        data: Phase outputs to persist; read back by _load_checkpoint when
            a cycle is resumed.

    Silently no-ops when checkpointing is disabled (no runtime manager /
    run directory). Write failures are logged, never raised, so a broken
    checkpoint store cannot abort the trading cycle.
    """
    path = self._get_checkpoint_path(session_key)
    if not path:
        return

    checkpoint = {
        "session_key": session_key,
        "phase": phase,
        # Wall-clock timestamp for auditing/debugging resumed runs.
        "timestamp": datetime.now().isoformat(),
        "data": data
    }
    try:
        # default=str: checkpoint data may hold non-JSON-serializable values;
        # stringifying them is a deliberate lossy fallback here.
        path.write_text(json.dumps(checkpoint, ensure_ascii=False, indent=2, default=str), encoding="utf-8")
        _log(f"Checkpoint saved: {phase} for {session_key}")
    except Exception as e:
        logger.error(f"Failed to save checkpoint: {e}")
|
||||
|
||||
def _load_checkpoint(self, session_key: str) -> Optional[Dict[str, Any]]:
    """Load the pipeline state from a checkpoint file.

    Returns the parsed checkpoint dict, or ``None`` when checkpointing is
    disabled, no checkpoint exists for this session, or the file cannot be
    read/parsed (the failure is logged, not raised).
    """
    checkpoint_file = self._get_checkpoint_path(session_key)
    if checkpoint_file is None or not checkpoint_file.exists():
        return None

    try:
        raw_text = checkpoint_file.read_text(encoding="utf-8")
        return json.loads(raw_text)
    except Exception as e:
        logger.error(f"Failed to load checkpoint: {e}")
        return None
|
||||
|
||||
async def _sync_memory_if_retrieved(self, agent: Any) -> None:
|
||||
"""
|
||||
Check agent's short-term memory for retrieved long-term memory and sync to frontend.
|
||||
@@ -585,6 +756,25 @@ class TradingPipeline:
|
||||
content=reflection_content,
|
||||
)
|
||||
|
||||
# Phase 6: APO (Autonomous Policy Optimization)
|
||||
# If the day was a loss, let APO suggest and apply policy updates.
|
||||
if hasattr(self, "apo") and self.apo:
|
||||
_log(f"Phase 6: APO - Running autonomous policy optimization for {date}")
|
||||
try:
|
||||
apo_result = await self.apo.run_optimization(
|
||||
date=date,
|
||||
reflection_content=reflection_content,
|
||||
settlement_result=settlement_result or {"portfolio_value": 100000.0 + total_pnl},
|
||||
analyst_results=analyst_results,
|
||||
decisions=decisions
|
||||
)
|
||||
if apo_result.get("status") == "completed":
|
||||
_log(f"APO: Successfully applied {len(apo_result.get('optimizations', []))} policy updates.")
|
||||
# Reload assets for next cycle to ensure they are picked up
|
||||
self.reload_runtime_assets()
|
||||
except Exception as e:
|
||||
logger.error(f"APO: Optimization failed: {e}")
|
||||
|
||||
def _build_reflection_content(
|
||||
self,
|
||||
date: str,
|
||||
@@ -1562,28 +1752,74 @@ class TradingPipeline:
|
||||
"""Return static analysts plus runtime-created analysts."""
|
||||
return list(self.analysts) + list(self._dynamic_analysts.values())
|
||||
|
||||
def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str:
|
||||
"""Create one runtime analyst instance."""
|
||||
if analyst_type not in ANALYST_TYPES:
|
||||
def _create_runtime_analyst(
|
||||
self,
|
||||
agent_id: str,
|
||||
analyst_type: str,
|
||||
custom_config: Optional[AnalystConfig] = None,
|
||||
) -> str:
|
||||
"""Create one runtime analyst instance.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for the new analyst
|
||||
analyst_type: Type of analyst (e.g., "technical_analyst")
|
||||
custom_config: Optional custom configuration for the analyst,
|
||||
including persona, soul_md, agents_md, etc.
|
||||
|
||||
Returns:
|
||||
Success or error message
|
||||
"""
|
||||
# Validate analyst_type or custom_config
|
||||
if analyst_type not in ANALYST_TYPES and not custom_config:
|
||||
return (
|
||||
f"Unknown analyst_type '{analyst_type}'. "
|
||||
f"Available: {', '.join(ANALYST_TYPES.keys())}"
|
||||
f"Available: {', '.join(ANALYST_TYPES.keys())}. "
|
||||
f"Or provide custom_config to create a custom analyst."
|
||||
)
|
||||
if agent_id in {agent.name for agent in self._all_analysts()}:
|
||||
return f"Analyst '{agent_id}' already exists."
|
||||
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
|
||||
persona = personas.get(analyst_type, {})
|
||||
|
||||
# Get persona: use custom_config if provided, else load from personas.yaml
|
||||
if custom_config and custom_config.persona:
|
||||
persona = {
|
||||
"name": custom_config.persona.name,
|
||||
"focus": custom_config.persona.focus,
|
||||
"description": custom_config.persona.description,
|
||||
}
|
||||
else:
|
||||
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
|
||||
persona = personas.get(analyst_type, {})
|
||||
workspace_manager = WorkspaceManager(project_root=project_root)
|
||||
|
||||
# Build file contents: use custom if provided, else generate from persona
|
||||
file_contents = {}
|
||||
if custom_config:
|
||||
if custom_config.soul_md:
|
||||
file_contents["SOUL.md"] = custom_config.soul_md
|
||||
if custom_config.agents_md:
|
||||
file_contents["AGENTS.md"] = custom_config.agents_md
|
||||
if custom_config.profile_md:
|
||||
file_contents["PROFILE.md"] = custom_config.profile_md
|
||||
if custom_config.bootstrap_md:
|
||||
file_contents["BOOTSTRAP.md"] = custom_config.bootstrap_md
|
||||
|
||||
# Fill in any missing files with defaults
|
||||
if not file_contents or len(file_contents) < 4:
|
||||
default_files = workspace_manager.build_default_agent_files(
|
||||
agent_id=agent_id,
|
||||
persona=persona,
|
||||
)
|
||||
for key, value in default_files.items():
|
||||
if key not in file_contents:
|
||||
file_contents[key] = value
|
||||
|
||||
workspace_manager.ensure_agent_assets(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
file_contents=workspace_manager.build_default_agent_files(
|
||||
agent_id=agent_id,
|
||||
persona=persona,
|
||||
),
|
||||
file_contents=file_contents,
|
||||
)
|
||||
|
||||
# Create EvoAgent with workspace-driven configuration
|
||||
@@ -1594,11 +1830,23 @@ class TradingPipeline:
|
||||
agent_id,
|
||||
)
|
||||
agent_config = load_agent_workspace_config(workspace_dir / "agent.yaml")
|
||||
# Support model override from custom_config
|
||||
if custom_config and custom_config.model_name:
|
||||
# Import create_model for custom model creation
|
||||
from backend.llm.models import create_model
|
||||
# Use specified model name, default to openai provider
|
||||
model = create_model(
|
||||
model_name=custom_config.model_name,
|
||||
model_provider=custom_config.memory_config.get("model_provider", "openai") if custom_config.memory_config else "openai"
|
||||
)
|
||||
else:
|
||||
model = get_agent_model(analyst_type)
|
||||
|
||||
agent = EvoAgent(
|
||||
agent_id=agent_id,
|
||||
config_name=config_name,
|
||||
workspace_dir=workspace_dir,
|
||||
model=get_agent_model(analyst_type),
|
||||
model=model,
|
||||
formatter=get_agent_formatter(analyst_type),
|
||||
prompt_files=agent_config.prompt_files,
|
||||
)
|
||||
@@ -1611,6 +1859,11 @@ class TradingPipeline:
|
||||
# Keep workspace_id for backward compatibility
|
||||
setattr(agent, "workspace_id", config_name)
|
||||
self._dynamic_analysts[agent_id] = agent
|
||||
|
||||
# Store custom config for future reference (e.g., cloning)
|
||||
if custom_config:
|
||||
self._dynamic_analyst_configs[agent_id] = custom_config
|
||||
|
||||
update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
@@ -1624,6 +1877,8 @@ class TradingPipeline:
|
||||
if agent_id not in self._dynamic_analysts:
|
||||
return f"Runtime analyst '{agent_id}' not found."
|
||||
self._dynamic_analysts.pop(agent_id, None)
|
||||
# Also remove stored config if exists
|
||||
self._dynamic_analyst_configs.pop(agent_id, None)
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
update_active_analysts(
|
||||
|
||||
@@ -17,6 +17,14 @@ NYSE_TZ = ZoneInfo("America/New_York")
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
def normalize_schedule_mode(mode: str | None) -> str:
|
||||
"""Normalize schedule mode to the current public vocabulary."""
|
||||
value = str(mode or "daily").strip().lower()
|
||||
if value == "intraday":
|
||||
return "interval"
|
||||
return value or "daily"
|
||||
|
||||
|
||||
class Scheduler:
|
||||
"""
|
||||
Market-aware scheduler for live trading.
|
||||
@@ -31,7 +39,7 @@ class Scheduler:
|
||||
heartbeat_interval: Optional[int] = None,
|
||||
config: Optional[dict] = None,
|
||||
):
|
||||
self.mode = mode
|
||||
self.mode = normalize_schedule_mode(mode)
|
||||
self.trigger_time = trigger_time or "09:30" # NYSE timezone
|
||||
self.trigger_now = self.trigger_time == "now"
|
||||
self.interval_minutes = interval_minutes or 60
|
||||
@@ -107,7 +115,7 @@ class Scheduler:
|
||||
|
||||
if self.mode == "daily":
|
||||
self._task = asyncio.create_task(self._run_daily(self._callback))
|
||||
elif self.mode == "intraday":
|
||||
elif self.mode == "interval":
|
||||
self._task = asyncio.create_task(
|
||||
self._run_intraday(self._callback),
|
||||
)
|
||||
@@ -124,8 +132,13 @@ class Scheduler:
|
||||
"""Update scheduler parameters in-place and restart its timing loop."""
|
||||
changed = False
|
||||
|
||||
if mode and mode != self.mode:
|
||||
self.mode = mode
|
||||
if mode:
|
||||
normalized_mode = normalize_schedule_mode(mode)
|
||||
else:
|
||||
normalized_mode = None
|
||||
|
||||
if normalized_mode and normalized_mode != self.mode:
|
||||
self.mode = normalized_mode
|
||||
changed = True
|
||||
|
||||
if trigger_time and trigger_time != self.trigger_time:
|
||||
@@ -233,13 +246,13 @@ class Scheduler:
|
||||
await callback(date=current_date)
|
||||
|
||||
async def _run_intraday(self, callback: Callable):
|
||||
"""Run every N minutes (for future use)"""
|
||||
"""Run every N minutes in interval mode."""
|
||||
while self.running:
|
||||
now = self._now_nyse()
|
||||
current_date = now.strftime("%Y-%m-%d")
|
||||
|
||||
if self._is_trading_day(now):
|
||||
logger.info(f"Triggering intraday cycle for {current_date}")
|
||||
logger.info(f"Triggering interval cycle for {current_date}")
|
||||
await callback(date=current_date)
|
||||
|
||||
await asyncio.sleep(self.interval_minutes * 60)
|
||||
|
||||
Reference in New Issue
Block a user