diff --git a/backend/api/runtime.py b/backend/api/runtime.py index def6fad..57dae78 100644 --- a/backend/api/runtime.py +++ b/backend/api/runtime.py @@ -375,21 +375,25 @@ async def _run_pipeline( bootstrap: Dict[str, Any], stop_event: asyncio.Event ) -> None: - """Background task to run the trading pipeline. + """Background task to run the trading pipeline.""" + import logging + logger = logging.getLogger(__name__) + + from backend.core.pipeline_runner import run_pipeline - This is a placeholder - actual implementation will integrate with main.py - """ try: - # TODO: Integrate with main.py pipeline execution - # This should call the actual pipeline startup logic from main.py - - # For now, just wait until stop event is set - while not stop_event.is_set(): - await asyncio.sleep(1) - + logger.info(f"Starting pipeline for run_id: {run_id}") + await run_pipeline( + run_id=run_id, + run_dir=run_dir, + bootstrap=bootstrap, + stop_event=stop_event, + ) + logger.info(f"Pipeline completed for run_id: {run_id}") except asyncio.CancelledError: - # Handle cancellation gracefully + logger.info(f"Pipeline cancelled for run_id: {run_id}") + raise + except Exception as e: + logger.exception(f"Pipeline failed for run_id: {run_id}: {e}") + # Re-raise to allow proper cleanup raise - finally: - # Cleanup - pass diff --git a/backend/core/pipeline_runner.py b/backend/core/pipeline_runner.py new file mode 100644 index 0000000..ffed2b7 --- /dev/null +++ b/backend/core/pipeline_runner.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- +""" +Pipeline Runner - Independent trading pipeline execution + +This module provides functions to start/stop trading pipelines +that can be called from the REST API. +""" + +from __future__ import annotations + +import asyncio +import os +from contextlib import AsyncExitStack +from pathlib import Path +from typing import Any, Dict, Optional, Callable + +from backend.agents import AnalystAgent, PMAgent, RiskAgent +from backend.agents.skills_manager import SkillsManager +from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles +from backend.agents.prompt_loader import PromptLoader +from backend.agents.workspace_manager import WorkspaceManager +from backend.config.constants import ANALYST_TYPES +from backend.core.pipeline import TradingPipeline +from backend.core.scheduler import BacktestScheduler, Scheduler +from backend.llm.models import get_agent_formatter, get_agent_model +from backend.runtime.manager import ( + TradingRuntimeManager, + set_global_runtime_manager, + clear_global_runtime_manager, + set_shutdown_event, + clear_shutdown_event, + is_shutdown_requested, +) +from backend.services.market import MarketService +from backend.services.storage import StorageService +from backend.utils.settlement import SettlementCoordinator + +_prompt_loader = PromptLoader() + + +async def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path): + """Create ReMeTaskLongTermMemory for an agent.""" + try: + from agentscope.memory import ReMeTaskLongTermMemory + from agentscope.model import DashScopeChatModel + from agentscope.embedding import DashScopeTextEmbedding + except ImportError: + return None + + api_key = os.getenv("MEMORY_API_KEY") + if not api_key: + return None + + memory_dir = str(run_dir / "memory") + + return ReMeTaskLongTermMemory( + agent_name=agent_name, + user_name=agent_name, + model=DashScopeChatModel( + model_name=os.getenv("MEMORY_MODEL_NAME", "qwen3-max"), + api_key=api_key, + stream=False, + ), + embedding_model=DashScopeTextEmbedding( + model_name=os.getenv("MEMORY_EMBEDDING_MODEL", "text-embedding-v4"), + api_key=api_key, + dimensions=1024, + ), + **{ + "vector_store.default.backend": "local", + "vector_store.default.params.store_dir": memory_dir, + }, + ) + + +def create_agents( + run_id: str, + run_dir: Path, + initial_cash: float, + margin_requirement: float, + enable_long_term_memory: bool = False, +): + """Create all agents for the system.""" + analysts = [] + long_term_memories = [] + + # Initialize workspace manager and assets + workspace_manager = WorkspaceManager() + workspace_manager.initialize_default_assets( + config_name=run_id, + agent_ids=list(ANALYST_TYPES.keys()) + ["risk_manager", "portfolio_manager"], + analyst_personas=_prompt_loader.load_yaml_config("analyst", "personas"), + ) + + profiles = load_agent_profiles() + skills_manager = SkillsManager() + active_skill_map = skills_manager.prepare_active_skills( + config_name=run_id, + agent_defaults={ + agent_id: profile.get("skills", []) + for agent_id, profile in profiles.items() + }, + ) + + # Create analyst agents + for analyst_type in ANALYST_TYPES: + model = get_agent_model(analyst_type) + formatter = get_agent_formatter(analyst_type) + toolkit = create_agent_toolkit( + analyst_type, + run_id, + active_skill_dirs=active_skill_map.get(analyst_type, []), + ) + + long_term_memory = None + if enable_long_term_memory: + long_term_memory = create_long_term_memory(analyst_type, run_id, run_dir) + if long_term_memory: + long_term_memories.append(long_term_memory) + + analyst = AnalystAgent( + analyst_type=analyst_type, + toolkit=toolkit, + model=model, + formatter=formatter, + agent_id=analyst_type, + config={"config_name": run_id}, + long_term_memory=long_term_memory, + ) + analysts.append(analyst) + + # Create risk manager + risk_long_term_memory = None + if enable_long_term_memory: + risk_long_term_memory = create_long_term_memory("risk_manager", run_id, run_dir) + if risk_long_term_memory: + long_term_memories.append(risk_long_term_memory) + + risk_manager = RiskAgent( + model=get_agent_model("risk_manager"), + formatter=get_agent_formatter("risk_manager"), + name="risk_manager", + config={"config_name": run_id}, + long_term_memory=risk_long_term_memory, + toolkit=create_agent_toolkit( + "risk_manager", + run_id, + active_skill_dirs=active_skill_map.get("risk_manager", []), + ), + ) + + # Create portfolio manager + pm_long_term_memory = None + if enable_long_term_memory: + pm_long_term_memory = create_long_term_memory("portfolio_manager", run_id, run_dir) + if pm_long_term_memory: + long_term_memories.append(pm_long_term_memory) + + portfolio_manager = PMAgent( + name="portfolio_manager", + model=get_agent_model("portfolio_manager"), + formatter=get_agent_formatter("portfolio_manager"), + initial_cash=initial_cash, + margin_requirement=margin_requirement, + config={"config_name": run_id}, + long_term_memory=pm_long_term_memory, + toolkit_factory=create_agent_toolkit, + toolkit_factory_kwargs={ + "active_skill_dirs": active_skill_map.get("portfolio_manager", []), + }, + ) + + return analysts, risk_manager, portfolio_manager, long_term_memories + + +async def run_pipeline( + run_id: str, + run_dir: Path, + bootstrap: Dict[str, Any], + stop_event: asyncio.Event, + message_callback: Optional[Callable[[str, Any], None]] = None, +) -> None: + """ + Run the trading pipeline with the given configuration. + + Args: + run_id: Unique run identifier (timestamp) + run_dir: Run directory path + bootstrap: Bootstrap configuration + stop_event: Event to signal pipeline stop + message_callback: Optional callback for sending messages to clients + """ + import logging + logger = logging.getLogger(__name__) + + # Set global shutdown event + set_shutdown_event(stop_event) + + logger.info(f"[Pipeline {run_id}] Starting...") + + try: + # Extract config values + tickers = bootstrap.get("tickers", ["AAPL", "MSFT"]) + initial_cash = float(bootstrap.get("initial_cash", 100000.0)) + margin_requirement = float(bootstrap.get("margin_requirement", 0.0)) + max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2)) + schedule_mode = bootstrap.get("schedule_mode", "daily") + trigger_time = bootstrap.get("trigger_time", "09:30") + interval_minutes = int(bootstrap.get("interval_minutes", 60)) + mode = bootstrap.get("mode", "live") + start_date = bootstrap.get("start_date") + end_date = bootstrap.get("end_date") + enable_memory = bootstrap.get("enable_memory", False) + + is_backtest = mode == "backtest" + is_mock = mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true") + + # Get or create runtime manager + from backend.api.runtime import runtime_manager + + if runtime_manager is None: + runtime_manager = TradingRuntimeManager( + config_name=run_id, + run_dir=run_dir, + bootstrap=bootstrap, + ) + runtime_manager.prepare_run() + + set_global_runtime_manager(runtime_manager) + + # Register runtime manager with API + from backend.api.runtime import register_runtime_manager + register_runtime_manager(runtime_manager) + + # Create market service + market_service = MarketService( + tickers=tickers, + poll_interval=10, + mock_mode=is_mock and not is_backtest, + backtest_mode=is_backtest, + api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None, + backtest_start_date=start_date if is_backtest else None, + backtest_end_date=end_date if is_backtest else None, + ) + + # Create storage service + storage_service = StorageService( + dashboard_dir=run_dir / "team_dashboard", + initial_cash=initial_cash, + config_name=run_id, + ) + + if not storage_service.files["summary"].exists(): + storage_service.initialize_empty_dashboard() + else: + storage_service.update_leaderboard_model_info() + + # Create agents and pipeline + analysts, risk_manager, pm, long_term_memories = create_agents( + run_id=run_id, + run_dir=run_dir, + initial_cash=initial_cash, + margin_requirement=margin_requirement, + enable_long_term_memory=enable_memory, + ) + + # Register agents with runtime manager + for agent in analysts + [risk_manager, pm]: + agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None) + if agent_id: + runtime_manager.register_agent(agent_id) + + # Load portfolio state + portfolio_state = storage_service.load_portfolio_state() + pm.load_portfolio_state(portfolio_state) + + # Create settlement coordinator + settlement_coordinator = SettlementCoordinator( + storage=storage_service, + initial_capital=initial_cash, + ) + + # Create pipeline + pipeline = TradingPipeline( + analysts=analysts, + risk_manager=risk_manager, + portfolio_manager=pm, + settlement_coordinator=settlement_coordinator, + max_comm_cycles=max_comm_cycles, + runtime_manager=runtime_manager, + ) + + # Create scheduler + scheduler_callback = None + live_scheduler = None + + if is_backtest: + backtest_scheduler = BacktestScheduler( + start_date=start_date, + end_date=end_date, + trading_calendar="NYSE", + delay_between_days=0.5, + ) + trading_dates = backtest_scheduler.get_trading_dates() + + async def scheduler_callback_fn(callback): + await backtest_scheduler.start(callback) + + scheduler_callback = scheduler_callback_fn + else: + # Live mode + live_scheduler = Scheduler( + mode=schedule_mode, + trigger_time=trigger_time, + interval_minutes=interval_minutes, + config={"config_name": run_id}, + ) + + async def scheduler_callback_fn(callback): + await live_scheduler.start(callback) + + scheduler_callback = scheduler_callback_fn + + # Start pipeline execution + async with AsyncExitStack() as stack: + # Enter long-term memory contexts + for memory in long_term_memories: + await stack.enter_async_context(memory) + + # Define the trading cycle callback + async def trading_cycle(session_key: str) -> None: + """Execute one trading cycle.""" + if is_shutdown_requested(): + return + + runtime_manager.set_session_key(session_key) + runtime_manager.log_event("cycle:start", {"session": session_key}) + + try: + # Fetch market data + market_data = await market_service.get_all_data() + + # Run pipeline + await pipeline.run_cycle( + session_key=session_key, + market_data=market_data, + ) + + runtime_manager.log_event("cycle:complete", {"session": session_key}) + + except Exception as e: + runtime_manager.log_event("cycle:error", {"error": str(e)}) + raise + + # Start scheduler + if scheduler_callback: + await scheduler_callback(trading_cycle) + + # Wait for stop signal + while not stop_event.is_set(): + await asyncio.sleep(1) + + except asyncio.CancelledError: + # Handle cancellation gracefully + raise + finally: + # Cleanup + clear_shutdown_event() + clear_global_runtime_manager() + from backend.api.runtime import unregister_runtime_manager + unregister_runtime_manager()