feat: add runtime dynamic team controls
This commit is contained in:
@@ -33,6 +33,8 @@ from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.agents.dynamic_team_types import AnalystConfig
|
||||
from backend.tools.dynamic_team_tools import DynamicTeamController, set_controller
|
||||
|
||||
|
||||
def _resolve_evo_agent_ids() -> set[str]:
|
||||
@@ -84,6 +86,9 @@ def _log(msg: str) -> None:
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
from backend.core.apo import PolicyOptimizer
|
||||
|
||||
|
||||
class TradingPipeline:
|
||||
"""
|
||||
Trading Pipeline - Orchestrates the complete trading cycle
|
||||
@@ -127,7 +132,21 @@ class TradingPipeline:
|
||||
self.runtime_manager = runtime_manager
|
||||
self._session_key: Optional[str] = None
|
||||
self._dynamic_analysts: Dict[str, Any] = {}
|
||||
self._dynamic_analyst_configs: Dict[str, AnalystConfig] = {}
|
||||
|
||||
# Initialize APO (Autonomous Policy Optimizer)
|
||||
config_name = workspace_id or (runtime_manager.config_name if runtime_manager else "default")
|
||||
self.apo = PolicyOptimizer(config_name=config_name)
|
||||
|
||||
# Initialize dynamic team controller and inject into PM
|
||||
self._team_controller = DynamicTeamController(
|
||||
create_callback=self._create_runtime_analyst,
|
||||
remove_callback=self._remove_runtime_analyst,
|
||||
get_analysts_callback=self._all_analysts,
|
||||
)
|
||||
set_controller(self._team_controller)
|
||||
|
||||
# Backward compatibility: also set individual callbacks if PM expects them
|
||||
if hasattr(self.pm, "set_team_controller"):
|
||||
self.pm.set_team_controller(
|
||||
create_agent_callback=self._create_runtime_analyst,
|
||||
@@ -150,23 +169,7 @@ class TradingPipeline:
|
||||
execute_decisions: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run one complete trading cycle
|
||||
|
||||
Args:
|
||||
tickers: List of stock tickers
|
||||
date: Trading date (YYYY-MM-DD)
|
||||
prices: Open prices {ticker: price} (for backtest)
|
||||
close_prices: Close prices for settlement (for backtest)
|
||||
market_caps: Optional market caps for baseline calculation
|
||||
get_open_prices_fn: Async callback to wait for open prices (live mode)
|
||||
get_close_prices_fn: Async callback to wait for close prices (live mode)
|
||||
|
||||
For live mode:
|
||||
- Analysis runs immediately
|
||||
- Execution waits for market open via get_open_prices_fn
|
||||
- Settlement waits for market close via get_close_prices_fn
|
||||
|
||||
Each agent's result is broadcast immediately via StateSync.
|
||||
Run one complete trading cycle with checkpointing support.
|
||||
"""
|
||||
_log(f"Starting cycle {date} - {len(tickers)} tickers")
|
||||
session_key = TradingSessionKey(date=date).key()
|
||||
@@ -176,14 +179,45 @@ class TradingPipeline:
|
||||
agents=active_analysts + [self.risk_manager, self.pm],
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
# Load checkpoint if exists
|
||||
checkpoint = self._load_checkpoint(session_key)
|
||||
checkpoint_data = checkpoint.get("data", {}) if checkpoint else {}
|
||||
last_phase = checkpoint.get("phase") if checkpoint else None
|
||||
|
||||
if checkpoint:
|
||||
_log(f"Resuming from checkpoint: {last_phase}")
|
||||
# Restore state from checkpoint
|
||||
analyst_results = checkpoint_data.get("analyst_results", [])
|
||||
risk_assessment = checkpoint_data.get("risk_assessment", {})
|
||||
self.conference_summary = checkpoint_data.get("conference_summary")
|
||||
final_predictions = checkpoint_data.get("final_predictions", [])
|
||||
pm_result = checkpoint_data.get("pm_result", {})
|
||||
execution_result = checkpoint_data.get("execution_result", {})
|
||||
settlement_result = checkpoint_data.get("settlement_result")
|
||||
# Prefer passed prices if not hold in checkpoint
|
||||
if not prices:
|
||||
prices = checkpoint_data.get("prices")
|
||||
if not close_prices:
|
||||
close_prices = checkpoint_data.get("close_prices")
|
||||
else:
|
||||
analyst_results = []
|
||||
risk_assessment = {}
|
||||
self.conference_summary = None
|
||||
final_predictions = []
|
||||
pm_result = {}
|
||||
execution_result = {}
|
||||
settlement_result = None
|
||||
|
||||
if self.runtime_manager:
|
||||
self.runtime_manager.set_session_key(session_key)
|
||||
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date})
|
||||
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date, "resumed": checkpoint is not None})
|
||||
self._runtime_batch_status(active_analysts, "analysis_in_progress")
|
||||
|
||||
# Phase 0: Clear short-term memory to avoid cross-day context pollution
|
||||
_log("Phase 0: Clearing memory")
|
||||
await self._clear_all_agent_memory()
|
||||
# Phase 0: Clear memory (only if not resuming or if resuming from very start)
|
||||
if not last_phase:
|
||||
_log("Phase 0: Clearing memory")
|
||||
await self._clear_all_agent_memory()
|
||||
|
||||
participants = self._all_analysts() + [self.risk_manager, self.pm]
|
||||
|
||||
@@ -196,125 +230,219 @@ class TradingPipeline:
|
||||
"system",
|
||||
),
|
||||
):
|
||||
# Phase 1.1: Analysts (parallel execution with TeamCoordinator)
|
||||
_log("Phase 1.1: Analyst analysis (parallel)")
|
||||
analyst_results = await self._run_analysts_parallel(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
# Phase 1.1: Analysts
|
||||
if not last_phase or last_phase == "cleared":
|
||||
_log("Phase 1.1: Analyst analysis (parallel)")
|
||||
analyst_results = await self._run_analysts_parallel(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
self._save_checkpoint(session_key, "analysis", {
|
||||
"analyst_results": analyst_results,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "analysis"
|
||||
|
||||
# Phase 1.2: Risk Manager
|
||||
_log("Phase 1.2: Risk assessment")
|
||||
self._runtime_update_status(self.risk_manager, "risk_assessment")
|
||||
risk_assessment = await self._run_risk_manager_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
prices,
|
||||
)
|
||||
if last_phase == "analysis":
|
||||
_log("Phase 1.2: Risk assessment")
|
||||
self._runtime_update_status(self.risk_manager, "risk_assessment")
|
||||
risk_assessment = await self._run_risk_manager_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
prices,
|
||||
)
|
||||
self._save_checkpoint(session_key, "risk_assessment", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "risk_assessment"
|
||||
|
||||
# Phase 2.1: Conference discussion (within same MsgHub)
|
||||
_log("Phase 2.1: Conference discussion")
|
||||
conference_summary = await self._run_conference_cycles(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
analyst_results=analyst_results,
|
||||
risk_assessment=risk_assessment,
|
||||
)
|
||||
self.conference_summary = conference_summary
|
||||
# Phase 2.1: Conference discussion
|
||||
if last_phase == "risk_assessment":
|
||||
_log("Phase 2.1: Conference discussion")
|
||||
conference_summary = await self._run_conference_cycles(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
analyst_results=analyst_results,
|
||||
risk_assessment=risk_assessment,
|
||||
)
|
||||
self.conference_summary = conference_summary
|
||||
self._save_checkpoint(session_key, "conference", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "conference"
|
||||
|
||||
# Phase 2.2: Analysts generate final structured predictions
|
||||
_log("Phase 2.2: Analysts generate final structured predictions")
|
||||
final_predictions = await self._collect_final_predictions(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
if last_phase == "conference":
|
||||
_log("Phase 2.2: Analysts generate final structured predictions")
|
||||
final_predictions = await self._collect_final_predictions(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
self._save_checkpoint(session_key, "predictions", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "predictions"
|
||||
|
||||
# Record final predictions for leaderboard ranking
|
||||
if self.settlement_coordinator:
|
||||
# Record final predictions
|
||||
if last_phase == "predictions" and self.settlement_coordinator:
|
||||
self.settlement_coordinator.record_analyst_predictions(
|
||||
final_predictions,
|
||||
)
|
||||
|
||||
# Live mode: wait for market open before execution
|
||||
if get_open_prices_fn:
|
||||
# Live mode: wait for market open
|
||||
if not prices and get_open_prices_fn:
|
||||
_log("Waiting for market open...")
|
||||
prices = await get_open_prices_fn()
|
||||
_log(f"Got open prices: {prices}")
|
||||
# Update prices in checkpoint if we just got them
|
||||
self._save_checkpoint(session_key, "predictions", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
|
||||
# Phase 3: PM makes decisions
|
||||
_log("Phase 3.1: PM makes decisions")
|
||||
self._runtime_update_status(self.pm, "decision_phase")
|
||||
pm_result = await self._run_pm_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
prices,
|
||||
analyst_results,
|
||||
risk_assessment,
|
||||
)
|
||||
if last_phase == "predictions":
|
||||
_log("Phase 3.1: PM makes decisions")
|
||||
self._runtime_update_status(self.pm, "decision_phase")
|
||||
pm_result = await self._run_pm_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
prices,
|
||||
analyst_results,
|
||||
risk_assessment,
|
||||
)
|
||||
self._save_checkpoint(session_key, "decisions", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"pm_result": pm_result,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "decisions"
|
||||
|
||||
decisions = pm_result.get("decisions", {})
|
||||
execution_result = {
|
||||
"executed_trades": [],
|
||||
"portfolio": self.pm.get_portfolio_state(),
|
||||
}
|
||||
if execute_decisions:
|
||||
_log("Phase 4: Executing trades")
|
||||
self._runtime_update_status(self.pm, "executing")
|
||||
execution_result = self._execute_decisions(decisions, prices, date)
|
||||
else:
|
||||
_log("Phase 4: Skipping trade execution")
|
||||
# Outside MsgHub for execution and settlement
|
||||
decisions = pm_result.get("decisions", {}) if pm_result else {}
|
||||
if not execution_result:
|
||||
execution_result = {
|
||||
"executed_trades": [],
|
||||
"portfolio": self.pm.get_portfolio_state(),
|
||||
}
|
||||
|
||||
# Live mode: wait for market close before settlement
|
||||
if get_close_prices_fn:
|
||||
if last_phase == "decisions":
|
||||
if execute_decisions:
|
||||
_log("Phase 4: Executing trades")
|
||||
self._runtime_update_status(self.pm, "executing")
|
||||
execution_result = self._execute_decisions(decisions, prices, date)
|
||||
else:
|
||||
_log("Phase 4: Skipping trade execution")
|
||||
|
||||
self._save_checkpoint(session_key, "execution", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"pm_result": pm_result,
|
||||
"execution_result": execution_result,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "execution"
|
||||
|
||||
# Live mode: wait for market close
|
||||
if not close_prices and get_close_prices_fn:
|
||||
_log("Waiting for market close")
|
||||
close_prices = await get_close_prices_fn()
|
||||
_log(f"Got close prices: {close_prices}")
|
||||
# Update close_prices in checkpoint
|
||||
self._save_checkpoint(session_key, "execution", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"pm_result": pm_result,
|
||||
"execution_result": execution_result,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
|
||||
# Phase 5: Settlement - run after close prices available
|
||||
settlement_result = None
|
||||
if close_prices and self.settlement_coordinator:
|
||||
_log("Phase 5: Daily review and generate memories")
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"settlement",
|
||||
)
|
||||
# Phase 5: Settlement
|
||||
if last_phase == "execution":
|
||||
if close_prices and self.settlement_coordinator:
|
||||
_log("Phase 5: Daily review and generate memories")
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"settlement",
|
||||
)
|
||||
|
||||
agent_trajectories = await self._capture_agent_trajectories()
|
||||
agent_trajectories = await self._capture_agent_trajectories()
|
||||
|
||||
if market_caps is None:
|
||||
market_caps = {ticker: 1e9 for ticker in tickers}
|
||||
if market_caps is None:
|
||||
market_caps = {ticker: 1e9 for ticker in tickers}
|
||||
|
||||
settlement_result = (
|
||||
self.settlement_coordinator.run_daily_settlement(
|
||||
settlement_result = (
|
||||
self.settlement_coordinator.run_daily_settlement(
|
||||
date=date,
|
||||
tickers=tickers,
|
||||
open_prices=prices,
|
||||
close_prices=close_prices,
|
||||
market_caps=market_caps,
|
||||
agent_portfolio=execution_result.get("portfolio", {}),
|
||||
analyst_results=analyst_results,
|
||||
pm_decisions=decisions,
|
||||
)
|
||||
)
|
||||
|
||||
await self._run_reflection(
|
||||
date=date,
|
||||
tickers=tickers,
|
||||
agent_trajectories=agent_trajectories,
|
||||
analyst_results=analyst_results,
|
||||
decisions=decisions,
|
||||
executed_trades=execution_result.get("executed_trades", []),
|
||||
open_prices=prices,
|
||||
close_prices=close_prices,
|
||||
market_caps=market_caps,
|
||||
agent_portfolio=execution_result.get("portfolio", {}),
|
||||
analyst_results=analyst_results,
|
||||
pm_decisions=decisions,
|
||||
settlement_result=settlement_result,
|
||||
conference_summary=self.conference_summary,
|
||||
)
|
||||
)
|
||||
|
||||
await self._run_reflection(
|
||||
date=date,
|
||||
agent_trajectories=agent_trajectories,
|
||||
analyst_results=analyst_results,
|
||||
decisions=decisions,
|
||||
executed_trades=execution_result.get("executed_trades", []),
|
||||
open_prices=prices,
|
||||
close_prices=close_prices,
|
||||
settlement_result=settlement_result,
|
||||
conference_summary=self.conference_summary,
|
||||
)
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"reflection",
|
||||
)
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"reflection",
|
||||
)
|
||||
|
||||
self._save_checkpoint(session_key, "settlement", {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
"conference_summary": conference_summary,
|
||||
"final_predictions": final_predictions,
|
||||
"pm_result": pm_result,
|
||||
"execution_result": execution_result,
|
||||
"settlement_result": settlement_result,
|
||||
"prices": prices,
|
||||
"close_prices": close_prices
|
||||
})
|
||||
last_phase = "settlement"
|
||||
|
||||
_log(f"Cycle complete: {date}")
|
||||
self._runtime_batch_status(
|
||||
@@ -323,6 +451,11 @@ class TradingPipeline:
|
||||
)
|
||||
self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date})
|
||||
|
||||
# Optional: Clean up checkpoint after successful completion
|
||||
# path = self._get_checkpoint_path(session_key)
|
||||
# if path and path.exists():
|
||||
# path.unlink()
|
||||
|
||||
return {
|
||||
"analyst_results": analyst_results,
|
||||
"risk_assessment": risk_assessment,
|
||||
@@ -385,6 +518,44 @@ class TradingPipeline:
|
||||
await self.risk_manager.memory.clear()
|
||||
await self.pm.memory.clear()
|
||||
|
||||
def _get_checkpoint_path(self, session_key: str) -> Optional[Path]:
|
||||
"""Get the path to the pipeline checkpoint file."""
|
||||
if not self.runtime_manager or not self.runtime_manager.run_dir:
|
||||
return None
|
||||
checkpoint_dir = self.runtime_manager.run_dir / "state" / "checkpoints"
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
return checkpoint_dir / f"pipeline_{session_key}.json"
|
||||
|
||||
def _save_checkpoint(self, session_key: str, phase: str, data: Dict[str, Any]) -> None:
|
||||
"""Save the current pipeline state to a checkpoint file."""
|
||||
path = self._get_checkpoint_path(session_key)
|
||||
if not path:
|
||||
return
|
||||
|
||||
checkpoint = {
|
||||
"session_key": session_key,
|
||||
"phase": phase,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"data": data
|
||||
}
|
||||
try:
|
||||
path.write_text(json.dumps(checkpoint, ensure_ascii=False, indent=2, default=str), encoding="utf-8")
|
||||
_log(f"Checkpoint saved: {phase} for {session_key}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save checkpoint: {e}")
|
||||
|
||||
def _load_checkpoint(self, session_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load the pipeline state from a checkpoint file."""
|
||||
path = self._get_checkpoint_path(session_key)
|
||||
if not path or not path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load checkpoint: {e}")
|
||||
return None
|
||||
|
||||
async def _sync_memory_if_retrieved(self, agent: Any) -> None:
|
||||
"""
|
||||
Check agent's short-term memory for retrieved long-term memory and sync to frontend.
|
||||
@@ -585,6 +756,25 @@ class TradingPipeline:
|
||||
content=reflection_content,
|
||||
)
|
||||
|
||||
# Phase 6: APO (Autonomous Policy Optimization)
|
||||
# If the day was a loss, let APO suggest and apply policy updates.
|
||||
if hasattr(self, "apo") and self.apo:
|
||||
_log(f"Phase 6: APO - Running autonomous policy optimization for {date}")
|
||||
try:
|
||||
apo_result = await self.apo.run_optimization(
|
||||
date=date,
|
||||
reflection_content=reflection_content,
|
||||
settlement_result=settlement_result or {"portfolio_value": 100000.0 + total_pnl},
|
||||
analyst_results=analyst_results,
|
||||
decisions=decisions
|
||||
)
|
||||
if apo_result.get("status") == "completed":
|
||||
_log(f"APO: Successfully applied {len(apo_result.get('optimizations', []))} policy updates.")
|
||||
# Reload assets for next cycle to ensure they are picked up
|
||||
self.reload_runtime_assets()
|
||||
except Exception as e:
|
||||
logger.error(f"APO: Optimization failed: {e}")
|
||||
|
||||
def _build_reflection_content(
|
||||
self,
|
||||
date: str,
|
||||
@@ -1562,28 +1752,74 @@ class TradingPipeline:
|
||||
"""Return static analysts plus runtime-created analysts."""
|
||||
return list(self.analysts) + list(self._dynamic_analysts.values())
|
||||
|
||||
def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str:
|
||||
"""Create one runtime analyst instance."""
|
||||
if analyst_type not in ANALYST_TYPES:
|
||||
def _create_runtime_analyst(
|
||||
self,
|
||||
agent_id: str,
|
||||
analyst_type: str,
|
||||
custom_config: Optional[AnalystConfig] = None,
|
||||
) -> str:
|
||||
"""Create one runtime analyst instance.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for the new analyst
|
||||
analyst_type: Type of analyst (e.g., "technical_analyst")
|
||||
custom_config: Optional custom configuration for the analyst,
|
||||
including persona, soul_md, agents_md, etc.
|
||||
|
||||
Returns:
|
||||
Success or error message
|
||||
"""
|
||||
# Validate analyst_type or custom_config
|
||||
if analyst_type not in ANALYST_TYPES and not custom_config:
|
||||
return (
|
||||
f"Unknown analyst_type '{analyst_type}'. "
|
||||
f"Available: {', '.join(ANALYST_TYPES.keys())}"
|
||||
f"Available: {', '.join(ANALYST_TYPES.keys())}. "
|
||||
f"Or provide custom_config to create a custom analyst."
|
||||
)
|
||||
if agent_id in {agent.name for agent in self._all_analysts()}:
|
||||
return f"Analyst '{agent_id}' already exists."
|
||||
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
|
||||
persona = personas.get(analyst_type, {})
|
||||
|
||||
# Get persona: use custom_config if provided, else load from personas.yaml
|
||||
if custom_config and custom_config.persona:
|
||||
persona = {
|
||||
"name": custom_config.persona.name,
|
||||
"focus": custom_config.persona.focus,
|
||||
"description": custom_config.persona.description,
|
||||
}
|
||||
else:
|
||||
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
|
||||
persona = personas.get(analyst_type, {})
|
||||
workspace_manager = WorkspaceManager(project_root=project_root)
|
||||
|
||||
# Build file contents: use custom if provided, else generate from persona
|
||||
file_contents = {}
|
||||
if custom_config:
|
||||
if custom_config.soul_md:
|
||||
file_contents["SOUL.md"] = custom_config.soul_md
|
||||
if custom_config.agents_md:
|
||||
file_contents["AGENTS.md"] = custom_config.agents_md
|
||||
if custom_config.profile_md:
|
||||
file_contents["PROFILE.md"] = custom_config.profile_md
|
||||
if custom_config.bootstrap_md:
|
||||
file_contents["BOOTSTRAP.md"] = custom_config.bootstrap_md
|
||||
|
||||
# Fill in any missing files with defaults
|
||||
if not file_contents or len(file_contents) < 4:
|
||||
default_files = workspace_manager.build_default_agent_files(
|
||||
agent_id=agent_id,
|
||||
persona=persona,
|
||||
)
|
||||
for key, value in default_files.items():
|
||||
if key not in file_contents:
|
||||
file_contents[key] = value
|
||||
|
||||
workspace_manager.ensure_agent_assets(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
file_contents=workspace_manager.build_default_agent_files(
|
||||
agent_id=agent_id,
|
||||
persona=persona,
|
||||
),
|
||||
file_contents=file_contents,
|
||||
)
|
||||
|
||||
# Create EvoAgent with workspace-driven configuration
|
||||
@@ -1594,11 +1830,23 @@ class TradingPipeline:
|
||||
agent_id,
|
||||
)
|
||||
agent_config = load_agent_workspace_config(workspace_dir / "agent.yaml")
|
||||
# Support model override from custom_config
|
||||
if custom_config and custom_config.model_name:
|
||||
# Import create_model for custom model creation
|
||||
from backend.llm.models import create_model
|
||||
# Use specified model name, default to openai provider
|
||||
model = create_model(
|
||||
model_name=custom_config.model_name,
|
||||
model_provider=custom_config.memory_config.get("model_provider", "openai") if custom_config.memory_config else "openai"
|
||||
)
|
||||
else:
|
||||
model = get_agent_model(analyst_type)
|
||||
|
||||
agent = EvoAgent(
|
||||
agent_id=agent_id,
|
||||
config_name=config_name,
|
||||
workspace_dir=workspace_dir,
|
||||
model=get_agent_model(analyst_type),
|
||||
model=model,
|
||||
formatter=get_agent_formatter(analyst_type),
|
||||
prompt_files=agent_config.prompt_files,
|
||||
)
|
||||
@@ -1611,6 +1859,11 @@ class TradingPipeline:
|
||||
# Keep workspace_id for backward compatibility
|
||||
setattr(agent, "workspace_id", config_name)
|
||||
self._dynamic_analysts[agent_id] = agent
|
||||
|
||||
# Store custom config for future reference (e.g., cloning)
|
||||
if custom_config:
|
||||
self._dynamic_analyst_configs[agent_id] = custom_config
|
||||
|
||||
update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
@@ -1624,6 +1877,8 @@ class TradingPipeline:
|
||||
if agent_id not in self._dynamic_analysts:
|
||||
return f"Runtime analyst '{agent_id}' not found."
|
||||
self._dynamic_analysts.pop(agent_id, None)
|
||||
# Also remove stored config if exists
|
||||
self._dynamic_analyst_configs.pop(agent_id, None)
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
update_active_analysts(
|
||||
|
||||
Reference in New Issue
Block a user