Files
evotraders/backend/core/pipeline_runner.py
cillin 06a23c32a4 refactor: Fix code quality issues identified in analysis
1. Rename factory.py's EvoAgent data class to AgentConfig
   - Avoids naming conflict with base/evo_agent.py's EvoAgent

2. Export pipeline_runner functions in backend/core/__init__.py
   - Add create_agents, create_long_term_memory, stop_gateway

3. Consolidate PromptLoader to singleton pattern
   - Add get_prompt_loader() singleton function
   - Update all usages to use singleton

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-20 01:07:53 +08:00

490 lines
18 KiB
Python

# -*- coding: utf-8 -*-
"""
Pipeline Runner - Independent trading pipeline execution
This module provides functions to start/stop trading pipelines
that can be called from the REST API.
"""
from __future__ import annotations
import asyncio
import os
from contextlib import AsyncExitStack
from pathlib import Path
from typing import Any, Dict, Optional, Callable
from backend.agents import AnalystAgent, PMAgent, RiskAgent
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
from backend.agents.prompt_loader import get_prompt_loader
from backend.agents.workspace_manager import WorkspaceManager
from backend.config.constants import ANALYST_TYPES
from backend.core.pipeline import TradingPipeline
from backend.core.scheduler import BacktestScheduler, Scheduler
from backend.llm.models import get_agent_formatter, get_agent_model
from backend.runtime.manager import (
TradingRuntimeManager,
set_global_runtime_manager,
clear_global_runtime_manager,
set_shutdown_event,
clear_shutdown_event,
is_shutdown_requested,
)
from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.services.gateway import Gateway
from backend.utils.settlement import SettlementCoordinator
# Shared prompt loader, fetched once at import time via get_prompt_loader().
_prompt_loader = get_prompt_loader()
# Global gateway reference for cleanup
# (written by _set_gateway(); read and reset to None by stop_gateway()).
_gateway_instance: Optional[Gateway] = None
def _set_gateway(gateway: Optional[Gateway]) -> None:
    """Remember ``gateway`` as the module-wide gateway reference.

    Args:
        gateway: The gateway to store, or ``None`` to clear the reference.
    """
    global _gateway_instance
    _gateway_instance = gateway
def stop_gateway() -> None:
    """Stop the running gateway, if any, and drop the global reference.

    This is a best-effort cleanup helper: an exception raised by the
    gateway's ``stop()`` is logged together with its traceback and then
    suppressed, so callers performing shutdown are never interrupted.
    The module-level reference is reset to ``None`` whether or not
    ``stop()`` succeeded, which makes a repeated call a no-op.
    """
    global _gateway_instance
    if _gateway_instance is None:
        return
    try:
        _gateway_instance.stop()
    except Exception as exc:
        import logging

        # ``exception`` records the traceback as well as the message, so a
        # failing shutdown can actually be diagnosed from the logs.
        logging.getLogger(__name__).exception(
            "Error stopping gateway: %s", exc
        )
    finally:
        _gateway_instance = None
def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
    """Build a ``ReMeTaskLongTermMemory`` for one agent, if possible.

    Args:
        agent_name: Used as both the memory's agent name and user name.
        run_id: Accepted for signature compatibility; not used in the body.
        run_dir: Run directory; the vector store is placed in its
            ``memory`` subdirectory.

    Returns:
        The configured memory object, or ``None`` when the agentscope
        classes cannot be imported or the ``MEMORY_API_KEY`` environment
        variable is unset or empty.
    """
    try:
        from agentscope.memory import ReMeTaskLongTermMemory
        from agentscope.model import DashScopeChatModel
        from agentscope.embedding import DashScopeTextEmbedding
    except ImportError:
        return None

    memory_api_key = os.getenv("MEMORY_API_KEY")
    if not memory_api_key:
        return None

    # Local vector store rooted under the run directory.
    vector_store_options = {
        "vector_store.default.backend": "local",
        "vector_store.default.params.store_dir": str(run_dir / "memory"),
    }

    chat_model = DashScopeChatModel(
        model_name=os.getenv("MEMORY_MODEL_NAME", "qwen3-max"),
        api_key=memory_api_key,
        stream=False,
    )
    embedder = DashScopeTextEmbedding(
        model_name=os.getenv("MEMORY_EMBEDDING_MODEL", "text-embedding-v4"),
        api_key=memory_api_key,
        dimensions=1024,
    )

    return ReMeTaskLongTermMemory(
        agent_name=agent_name,
        user_name=agent_name,
        model=chat_model,
        embedding_model=embedder,
        **vector_store_options,
    )
def create_agents(
    run_id: str,
    run_dir: Path,
    initial_cash: float,
    margin_requirement: float,
    enable_long_term_memory: bool = False,
):
    """Instantiate every agent used by a run.

    Args:
        run_id: Run identifier; used as the ``config_name`` everywhere.
        run_dir: Run directory, forwarded to long-term memory creation.
        initial_cash: Starting cash handed to the portfolio manager.
        margin_requirement: Margin requirement handed to the portfolio
            manager.
        enable_long_term_memory: When true, a long-term memory is requested
            for each agent via ``create_long_term_memory``.

    Returns:
        A 4-tuple ``(analysts, risk_manager, portfolio_manager,
        long_term_memories)``; the last item holds only the memories that
        were actually created (truthy results).
    """
    created_memories = []

    def _memory_for(agent_id: str):
        """Return the agent's long-term memory (or None) and track it."""
        if not enable_long_term_memory:
            return None
        memory = create_long_term_memory(agent_id, run_id, run_dir)
        if memory:
            created_memories.append(memory)
        return memory

    # Initialize workspace manager and assets
    workspace = WorkspaceManager()
    workspace.initialize_default_assets(
        config_name=run_id,
        agent_ids=[*ANALYST_TYPES.keys(), "risk_manager", "portfolio_manager"],
        analyst_personas=_prompt_loader.load_yaml_config("analyst", "personas"),
    )

    # Resolve the active skills per agent from the profile defaults.
    agent_profiles = load_agent_profiles()
    skills = SkillsManager()
    skill_defaults = {}
    for profile_id, profile in agent_profiles.items():
        skill_defaults[profile_id] = profile.get("skills", [])
    active_skills = skills.prepare_active_skills(
        config_name=run_id,
        agent_defaults=skill_defaults,
    )

    def _build_analyst(kind):
        """Construct one analyst: model, formatter, toolkit, then memory."""
        analyst_model = get_agent_model(kind)
        analyst_formatter = get_agent_formatter(kind)
        analyst_toolkit = create_agent_toolkit(
            kind,
            run_id,
            active_skill_dirs=active_skills.get(kind, []),
        )
        analyst_memory = _memory_for(kind)
        return AnalystAgent(
            analyst_type=kind,
            toolkit=analyst_toolkit,
            model=analyst_model,
            formatter=analyst_formatter,
            agent_id=kind,
            config={"config_name": run_id},
            long_term_memory=analyst_memory,
        )

    # Create analyst agents
    analysts = [_build_analyst(kind) for kind in ANALYST_TYPES]

    # Create risk manager
    risk_memory = _memory_for("risk_manager")
    risk_model = get_agent_model("risk_manager")
    risk_formatter = get_agent_formatter("risk_manager")
    risk_toolkit = create_agent_toolkit(
        "risk_manager",
        run_id,
        active_skill_dirs=active_skills.get("risk_manager", []),
    )
    risk_manager = RiskAgent(
        model=risk_model,
        formatter=risk_formatter,
        name="risk_manager",
        config={"config_name": run_id},
        long_term_memory=risk_memory,
        toolkit=risk_toolkit,
    )

    # Create portfolio manager; it receives the toolkit factory itself
    # rather than a ready-made toolkit.
    pm_memory = _memory_for("portfolio_manager")
    pm_model = get_agent_model("portfolio_manager")
    pm_formatter = get_agent_formatter("portfolio_manager")
    portfolio_manager = PMAgent(
        name="portfolio_manager",
        model=pm_model,
        formatter=pm_formatter,
        initial_cash=initial_cash,
        margin_requirement=margin_requirement,
        config={"config_name": run_id},
        long_term_memory=pm_memory,
        toolkit_factory=create_agent_toolkit,
        toolkit_factory_kwargs={
            "active_skill_dirs": active_skills.get("portfolio_manager", []),
        },
    )

    return analysts, risk_manager, portfolio_manager, created_memories
async def run_pipeline(
run_id: str,
run_dir: Path,
bootstrap: Dict[str, Any],
stop_event: asyncio.Event,
message_callback: Optional[Callable[[str, Any], None]] = None,
) -> None:
"""
Run the trading pipeline with the given configuration.

Registers ``stop_event`` as the global shutdown event, then builds the
runtime manager, storage and market services, agents, settlement
coordinator, pipeline, scheduler and gateway from ``bootstrap``. The
gateway is started as a background task, the scheduler callback is awaited
with the local ``trading_cycle`` coroutine, and the function then polls
``stop_event`` once per second until it is set. The ``finally`` block
stops the gateway, clears the shutdown event and the global runtime
manager, and unregisters the runtime manager from ``backend.api.runtime``.

Service Startup Order:
Phase 1: WebSocket Server - Frontend can connect
Phase 2: Market Service - Price data starts flowing
Phase 3: Agent Runtime - Create all agents
Phase 4: Pipeline & Scheduler - Trading logic ready
Phase 5: Gateway Fully Operational - All systems running
Args:
run_id: Unique run identifier (timestamp)
run_dir: Run directory path
bootstrap: Bootstrap configuration. Keys read here (all optional, with
defaults): tickers, initial_cash, margin_requirement, max_comm_cycles,
schedule_mode, trigger_time, interval_minutes, heartbeat_interval, mode,
start_date, end_date, enable_memory, enable_mock.
stop_event: Event to signal pipeline stop
message_callback: Optional callback for sending messages to clients.
It is not referenced anywhere in this function's body.
Raises:
asyncio.CancelledError: Re-raised unchanged when the coroutine is
cancelled; the cleanup in ``finally`` still runs.
"""
import logging
logger = logging.getLogger(__name__)
# Set global shutdown event
set_shutdown_event(stop_event)
logger.info(f"[Pipeline {run_id}] ======================================")
logger.info(f"[Pipeline {run_id}] Starting with 5-phase initialization...")
logger.info(f"[Pipeline {run_id}] ======================================")
try:
# Extract config values
tickers = bootstrap.get("tickers", ["AAPL", "MSFT"])
initial_cash = float(bootstrap.get("initial_cash", 100000.0))
margin_requirement = float(bootstrap.get("margin_requirement", 0.0))
max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2))
schedule_mode = bootstrap.get("schedule_mode", "daily")
trigger_time = bootstrap.get("trigger_time", "09:30")
interval_minutes = int(bootstrap.get("interval_minutes", 60))
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0))
mode = bootstrap.get("mode", "live")
start_date = bootstrap.get("start_date")
end_date = bootstrap.get("end_date")
enable_memory = bootstrap.get("enable_memory", False)
enable_mock = bootstrap.get("enable_mock", False)
is_backtest = mode == "backtest"
# Mock mode is on when explicitly enabled, when mode == "mock", or, outside
# backtests, when the MOCK_MODE environment variable equals "true"
# (case-insensitive).
is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
# ======================================================================
# PHASE 0: Initialize runtime manager
# ======================================================================
logger.info("[Phase 0/5] Initializing runtime manager...")
# Reuse the manager already held by the API module, if there is one;
# otherwise create a new one for this run.
from backend.api.runtime import runtime_manager
if runtime_manager is None:
runtime_manager = TradingRuntimeManager(
config_name=run_id,
run_dir=run_dir,
bootstrap=bootstrap,
)
runtime_manager.prepare_run()
set_global_runtime_manager(runtime_manager)
# Register runtime manager with API
from backend.api.runtime import register_runtime_manager
register_runtime_manager(runtime_manager)
# ======================================================================
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
# These will be started by Gateway in the correct order
# ======================================================================
logger.info("[Phase 1-2/5] Creating infrastructure services...")
# Create storage service
storage_service = StorageService(
dashboard_dir=run_dir / "team_dashboard",
initial_cash=initial_cash,
config_name=run_id,
)
# A missing summary file means no dashboard exists yet for this run
# directory: create an empty one. Otherwise only refresh the leaderboard
# model info.
if not storage_service.files["summary"].exists():
storage_service.initialize_empty_dashboard()
else:
storage_service.update_leaderboard_model_info()
# Create market service (data source)
# mock_mode is only passed as true outside backtests, and the Finnhub API
# key is only supplied when neither mock nor backtest mode is active.
market_service = MarketService(
tickers=tickers,
poll_interval=10,
mock_mode=is_mock and not is_backtest,
backtest_mode=is_backtest,
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
backtest_start_date=start_date if is_backtest else None,
backtest_end_date=end_date if is_backtest else None,
)
# ======================================================================
# PHASE 3: Create Agent Runtime
# ======================================================================
logger.info("[Phase 3/5] Creating agent runtime...")
analysts, risk_manager, pm, long_term_memories = create_agents(
run_id=run_id,
run_dir=run_dir,
initial_cash=initial_cash,
margin_requirement=margin_requirement,
enable_long_term_memory=enable_memory,
)
# Register agents with runtime manager
# The id is taken from ``agent_id`` when present and truthy, falling back
# to ``name``; agents exposing neither are skipped.
for agent in analysts + [risk_manager, pm]:
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
if agent_id:
runtime_manager.register_agent(agent_id)
# Load portfolio state
portfolio_state = storage_service.load_portfolio_state()
pm.load_portfolio_state(portfolio_state)
# Create settlement coordinator
settlement_coordinator = SettlementCoordinator(
storage=storage_service,
initial_capital=initial_cash,
)
# ======================================================================
# PHASE 4: Create Pipeline & Scheduler
# ======================================================================
logger.info("[Phase 4/5] Creating pipeline and scheduler...")
# Create pipeline
pipeline = TradingPipeline(
analysts=analysts,
risk_manager=risk_manager,
portfolio_manager=pm,
settlement_coordinator=settlement_coordinator,
max_comm_cycles=max_comm_cycles,
runtime_manager=runtime_manager,
)
# Create scheduler
# scheduler_callback wraps whichever scheduler is chosen behind a common
# "await start(callback)" shape; live_scheduler stays None in backtest mode.
scheduler_callback = None
live_scheduler = None
if is_backtest:
backtest_scheduler = BacktestScheduler(
start_date=start_date,
end_date=end_date,
trading_calendar="NYSE",
delay_between_days=0.5,
)
# NOTE(review): trading_dates is not read again in this function; confirm
# whether this call is still needed (e.g. for validation side effects).
trading_dates = backtest_scheduler.get_trading_dates()
async def scheduler_callback_fn(callback):
await backtest_scheduler.start(callback)
scheduler_callback = scheduler_callback_fn
else:
# Live mode
# A non-positive heartbeat_interval is passed on as None.
live_scheduler = Scheduler(
mode=schedule_mode,
trigger_time=trigger_time,
interval_minutes=interval_minutes,
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
config={"config_name": run_id},
)
async def scheduler_callback_fn(callback):
await live_scheduler.start(callback)
scheduler_callback = scheduler_callback_fn
# ======================================================================
# PHASE 5: Start Gateway (WebSocket → Market → Scheduler)
# Gateway.start() will handle the final startup sequence:
# - WebSocket Server first (frontend can connect)
# - Market Service second (price data flows)
# - Scheduler last (trading begins)
# ======================================================================
logger.info("[Phase 5/5] Starting Gateway (WebSocket → Market → Scheduler)...")
gateway = Gateway(
market_service=market_service,
storage_service=storage_service,
pipeline=pipeline,
scheduler_callback=scheduler_callback,
config={
"mode": mode,
"mock_mode": is_mock,
"backtest_mode": is_backtest,
"tickers": tickers,
"config_name": run_id,
"schedule_mode": schedule_mode,
"interval_minutes": interval_minutes,
"trigger_time": trigger_time,
"heartbeat_interval": heartbeat_interval,
"initial_cash": initial_cash,
"margin_requirement": margin_requirement,
"max_comm_cycles": max_comm_cycles,
"enable_memory": enable_memory,
},
scheduler=live_scheduler,
)
# Publish the gateway so stop_gateway() can reach it during cleanup.
_set_gateway(gateway)
# Start pipeline execution
async with AsyncExitStack() as stack:
# Enter long-term memory contexts
for memory in long_term_memories:
await stack.enter_async_context(memory)
# Start Gateway - this will execute the 4-phase startup:
# Phase 1: WebSocket Server (frontend can connect immediately)
# Phase 2: Market Service (price updates start flowing)
# Phase 3: Market Status Monitor
# Phase 4: Scheduler (trading cycles begin)
gateway_task = asyncio.create_task(
gateway.start(host="0.0.0.0", port=8765)
)
logger.info("[Pipeline] Gateway startup initiated on ws://localhost:8765")
# Wait for Gateway to fully initialize all phases
# NOTE(review): this is a fixed 0.5 s delay; nothing here checks that the
# gateway actually finished (or failed) starting.
await asyncio.sleep(0.5)
# Define the trading cycle callback
async def trading_cycle(session_key: str) -> None:
"""Execute one trading cycle."""
# Skip the cycle entirely once shutdown has been requested.
if is_shutdown_requested():
return
runtime_manager.set_session_key(session_key)
runtime_manager.log_event("cycle:start", {"session": session_key})
try:
# Fetch market data
market_data = await market_service.get_all_data()
# Run pipeline
await pipeline.run_cycle(
session_key=session_key,
market_data=market_data,
)
runtime_manager.log_event("cycle:complete", {"session": session_key})
except Exception as e:
# Record the failure, then let it propagate to the scheduler.
runtime_manager.log_event("cycle:error", {"error": str(e)})
raise
# Start scheduler
# NOTE(review): the same scheduler_callback was also handed to Gateway
# above; confirm the scheduler is not started twice.
if scheduler_callback:
await scheduler_callback(trading_cycle)
# Wait for stop signal
# Polls once per second rather than awaiting the event directly.
while not stop_event.is_set():
await asyncio.sleep(1)
# Cancel gateway task
# NOTE(review): this is only reached on the normal stop path. If an
# exception or cancellation happens earlier, the task is not cancelled here
# and shutdown relies on stop_gateway() in the finally block.
if not gateway_task.done():
gateway_task.cancel()
try:
await gateway_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError:
# Handle cancellation gracefully
# Re-raised unchanged so the caller observes the cancellation; the finally
# block below still performs cleanup.
raise
finally:
# Cleanup
logger.info("[Pipeline] Cleaning up...")
# Stop Gateway
try:
stop_gateway()
logger.info("[Pipeline] Gateway stopped")
except Exception as e:
logger.error(f"[Pipeline] Error stopping gateway: {e}")
# Undo the global registrations made at the start of the run.
clear_shutdown_event()
clear_global_runtime_manager()
from backend.api.runtime import unregister_runtime_manager
unregister_runtime_manager()
logger.info("[Pipeline] Cleanup complete")