- Add backend/core/pipeline_runner.py with full pipeline execution logic - Integrate main.py pipeline startup into REST API - Add comprehensive logging and error handling - Support mock/live/backtest modes via API Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
400 lines
12 KiB
Python
400 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Runtime API routes exposing the latest trading run state."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
|
from pydantic import BaseModel, Field
|
|
|
|
from backend.runtime.agent_runtime import AgentRuntimeState
|
|
from backend.runtime.context import TradingRunContext
|
|
from backend.runtime.manager import TradingRuntimeManager, get_global_runtime_manager
|
|
|
|
# Router for every endpoint in this module; all routes are served under /api/runtime.
router = APIRouter(prefix="/api/runtime", tags=["runtime"])

# Manager of the active run. While this is None, the read endpoints fall back
# to the newest runtime_state.json snapshot found on disk.
runtime_manager: Optional[TradingRuntimeManager] = None
# Third ancestor directory of this file; run directories live in PROJECT_ROOT / "runs".
# NOTE(review): presumably the repository root - confirm against the package layout.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Global task reference for running pipeline
_running_task: Optional[asyncio.Task] = None
# Event handed to the pipeline task and set by _stop_current_runtime() to signal a stop.
_stop_event: Optional[asyncio.Event] = None
|
|
|
|
|
|
class RunContextResponse(BaseModel):
    # Response body of GET /api/runtime/context.
    # All three fields are copied from the "context" entry of the runtime payload.
    config_name: str
    run_dir: str
    bootstrap_values: Dict[str, Any]
|
|
|
|
|
|
class RuntimeAgentState(BaseModel):
    # Runtime status of a single agent, as returned by the /agents endpoints.
    agent_id: str
    status: str
    # May be absent or null in the payload; defaults to None.
    last_session: Optional[str] = None
    # Timestamp carried as a string (ISO-8601 when built by _to_state_response).
    last_updated: str
|
|
|
|
|
|
class RuntimeAgentsResponse(BaseModel):
    # Response body of GET /api/runtime/agents: one entry per agent in the payload.
    agents: List[RuntimeAgentState]
|
|
|
|
|
|
class RuntimeEvent(BaseModel):
    # One runtime event entry, as returned by GET /api/runtime/events.
    timestamp: str
    event: str
    details: Dict[str, Any]
    # Explicit default so an event without a "session" key still validates.
    # Under pydantic v2 an Optional field with no default is required; this
    # also matches RuntimeAgentState.last_session.
    session: Optional[str] = None
|
|
|
|
|
|
class RuntimeEventsResponse(BaseModel):
    # Response body of GET /api/runtime/events: the events listed in the payload.
    events: List[RuntimeEvent]
|
|
|
|
|
|
class LaunchConfig(BaseModel):
    """Configuration for launching a new trading task."""
    # Request body of POST /start and POST /restart. start_runtime copies every
    # field into the bootstrap dict that is handed to the runtime manager and
    # the pipeline. The Field descriptions are runtime strings and stay as-is;
    # the comments below give their English meaning.

    # Ticker universe (stock pool); empty by default.
    tickers: List[str] = Field(default_factory=list, description="股票池")
    # Scheduling mode: "daily" or "interval" per the description (not validated here).
    schedule_mode: str = Field(default="daily", description="调度模式: daily, interval")
    # Interval length in minutes; must be >= 1.
    interval_minutes: int = Field(default=60, ge=1, description="间隔分钟数")
    # Trigger time as HH:MM (free-form string, format not validated here).
    trigger_time: str = Field(default="09:30", description="触发时间 HH:MM")
    # Maximum number of communication (consultation) rounds; must be >= 1.
    max_comm_cycles: int = Field(default=2, ge=1, description="最大会商轮数")
    # Starting cash; must be > 0.
    initial_cash: float = Field(default=100000.0, gt=0, description="初始资金")
    # Margin requirement; must be >= 0.
    margin_requirement: float = Field(default=0.0, ge=0, description="保证金要求")
    # Whether long-term memory is enabled.
    enable_memory: bool = Field(default=False, description="是否启用长期记忆")
    # Run mode: "live" or "backtest" per the description (not validated here).
    mode: str = Field(default="live", description="运行模式: live, backtest")
    # Backtest start date, YYYY-MM-DD (optional).
    start_date: Optional[str] = Field(default=None, description="回测开始日期 YYYY-MM-DD")
    # Backtest end date, YYYY-MM-DD (optional).
    end_date: Optional[str] = Field(default=None, description="回测结束日期 YYYY-MM-DD")
|
|
|
|
|
|
class LaunchResponse(BaseModel):
    # Response body of POST /api/runtime/start.
    # Timestamp-based identifier generated for the new run.
    run_id: str
    # start_runtime always reports "started".
    status: str
    # String form of the run directory path.
    run_dir: str
    # Human-readable summary of the outcome.
    message: str
|
|
|
|
|
|
class StopResponse(BaseModel):
    # Response body of POST /api/runtime/stop.
    # stop_runtime always reports "stopped" on success.
    status: str
    # Human-readable summary of the outcome.
    message: str
|
|
|
|
|
|
class RestartResponse(BaseModel):
    # Response body of POST /api/runtime/restart.
    # Identifier of the run that was started by the restart.
    run_id: str
    # restart_runtime always reports "restarted".
    status: str
    # Human-readable summary of the outcome.
    message: str
|
|
|
|
|
|
def _generate_run_id() -> str:
    """Build a run identifier from the current local time.

    Returns:
        The current timestamp rendered as ``YYYYMMDD_HHMMSS``.
    """
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
|
|
|
|
|
|
def _get_run_dir(run_id: str) -> Path:
    """Map a run ID to its directory under ``PROJECT_ROOT / "runs"``.

    Args:
        run_id: Identifier of the run.

    Returns:
        The path of the run directory; nothing is created on disk.
    """
    return PROJECT_ROOT.joinpath("runs", run_id)
|
|
|
|
|
|
def _latest_snapshot_path() -> Optional[Path]:
    """Locate the most recently modified runtime snapshot on disk.

    Looks at every ``runs/*/state/runtime_state.json`` below ``PROJECT_ROOT``.

    Returns:
        The snapshot file with the greatest modification time, or ``None``
        when no run has written a snapshot.
    """
    snapshots = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
    # max() keeps the first of several equally-new files, the same element a
    # stable descending sort would put at index 0.
    return max(snapshots, key=lambda snapshot: snapshot.stat().st_mtime, default=None)
|
|
|
|
|
|
def _load_snapshot() -> Dict[str, Any]:
    """Read the newest on-disk runtime snapshot.

    Returns:
        The parsed JSON content of the snapshot file.

    Raises:
        HTTPException: 503 when no snapshot file can be found.
    """
    path = _latest_snapshot_path()
    if path is not None and path.exists():
        return json.loads(path.read_text(encoding="utf-8"))
    raise HTTPException(status_code=503, detail="runtime manager is not initialized")
|
|
|
|
|
|
def _get_runtime_payload() -> Dict[str, Any]:
    """Return the runtime state that the read endpoints serve.

    A registered runtime manager is asked for a fresh snapshot; when none is
    registered, the newest snapshot file on disk is loaded instead.
    """
    manager = runtime_manager
    if manager is None:
        return _load_snapshot()
    return manager.build_snapshot()
|
|
|
|
|
|
def _to_state_response(state: AgentRuntimeState) -> RuntimeAgentState:
    """Convert an in-memory agent state into its API representation.

    Args:
        state: Agent state object to convert.

    Returns:
        A ``RuntimeAgentState`` with the same fields, where ``last_updated``
        is rendered through ``isoformat()``.
    """
    fields = {
        "agent_id": state.agent_id,
        "status": state.status,
        "last_session": state.last_session,
        "last_updated": state.last_updated.isoformat(),
    }
    return RuntimeAgentState(**fields)
|
|
|
|
|
|
@router.get("/context", response_model=RunContextResponse)
async def get_run_context() -> RunContextResponse:
    """Return the most recent run context.

    Raises:
        HTTPException: 404 when the runtime payload has no ``context`` entry.
    """
    context = _get_runtime_payload().get("context")
    if context is None:
        raise HTTPException(status_code=404, detail="run context is not ready")

    wanted = ("config_name", "run_dir", "bootstrap_values")
    return RunContextResponse(**{key: context[key] for key in wanted})
|
|
|
|
|
|
@router.get("/agents", response_model=RuntimeAgentsResponse)
async def list_agent_states() -> RuntimeAgentsResponse:
    """List the runtime state of every agent found in the runtime payload."""
    raw_agents = _get_runtime_payload().get("agents", [])
    return RuntimeAgentsResponse(
        agents=[RuntimeAgentState(**entry) for entry in raw_agents]
    )
|
|
|
|
|
|
@router.get("/events", response_model=RuntimeEventsResponse)
async def list_runtime_events() -> RuntimeEventsResponse:
    """Return the events listed in the runtime payload."""
    raw_events = _get_runtime_payload().get("events", [])
    return RuntimeEventsResponse(
        events=[RuntimeEvent(**entry) for entry in raw_events]
    )
|
|
|
|
|
|
@router.get("/agents/{agent_id}", response_model=RuntimeAgentState)
async def get_agent_state(agent_id: str) -> RuntimeAgentState:
    """Return the runtime state of one agent.

    Args:
        agent_id: Identifier to look up among the payload's agents.

    Raises:
        HTTPException: 404 when no agent in the payload has this ID.
    """
    for entry in _get_runtime_payload().get("agents", []):
        # First match wins.
        if entry["agent_id"] == agent_id:
            return RuntimeAgentState(**entry)
    raise HTTPException(status_code=404, detail=f"agent '{agent_id}' not registered")
|
|
|
|
|
|
def register_runtime_manager(manager: TradingRuntimeManager) -> None:
    """Publish *manager* as the module-level runtime manager.

    Once registered, the read endpoints build their payload from this manager
    instead of the on-disk snapshot.

    Args:
        manager: The runtime manager to expose through the API.
    """
    global runtime_manager
    runtime_manager = manager
|
|
|
|
|
|
def unregister_runtime_manager() -> None:
    """Clear the module-level runtime manager (shutdown and test helper).

    Afterwards the read endpoints fall back to the on-disk snapshot.
    """
    global runtime_manager
    runtime_manager = None
|
|
|
|
|
|
async def _cancel_and_wait(task: asyncio.Task) -> None:
    """Cancel *task* and wait until the cancellation has been delivered."""
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


async def _stop_current_runtime(force: bool = True) -> bool:
    """Stop the current running runtime if one exists.

    The stop event is set first so a cooperative pipeline can wind down, then
    the pipeline task is cancelled (``force=True``) or given 30 seconds to
    finish before being cancelled (``force=False``). Module state (task, stop
    event, registered runtime manager) is cleared in every case.

    Args:
        force: If True, cancel the running task immediately. If False, wait
            up to 30 seconds for it to finish before cancelling.

    Returns:
        True if a pipeline task was still running when the call was made,
        False if there was nothing to stop.
    """
    global _running_task, _stop_event

    task = _running_task
    # Recorded before anything is torn down, so the caller can tell a real
    # stop apart from a no-op.
    was_running = task is not None and not task.done()

    # Signal stop
    if _stop_event is not None:
        _stop_event.set()

    if was_running:
        if force:
            await _cancel_and_wait(task)
        else:
            # Wait for graceful shutdown, then fall back to cancellation.
            try:
                await asyncio.wait_for(task, timeout=30.0)
            except asyncio.TimeoutError:
                await _cancel_and_wait(task)

    _running_task = None
    _stop_event = None

    # Unregister runtime manager
    if runtime_manager is not None:
        unregister_runtime_manager()

    return was_running
|
|
|
|
|
|
@router.post("/start", response_model=LaunchResponse)
async def start_runtime(
    config: LaunchConfig,
    background_tasks: BackgroundTasks,
) -> LaunchResponse:
    """Start a new trading runtime with the given configuration.

    If a runtime is already running, it will be forcefully stopped first.
    Creates a new timestamped run directory.

    Args:
        config: Launch settings; every field is copied into the bootstrap
            dict handed to the runtime manager and the pipeline.
        background_tasks: Accepted for interface compatibility; the pipeline
            is scheduled with ``asyncio.create_task`` instead.

    Returns:
        The new run's ID and directory with status ``"started"``.
    """
    global _running_task, _stop_event

    # 1. Stop current runtime if exists
    await _stop_current_runtime(force=True)

    # 2. Generate run ID and directory
    run_id = _generate_run_id()
    run_dir = _get_run_dir(run_id)

    # 3. Prepare bootstrap config
    bootstrap = {
        "tickers": config.tickers,
        "schedule_mode": config.schedule_mode,
        "interval_minutes": config.interval_minutes,
        "trigger_time": config.trigger_time,
        "max_comm_cycles": config.max_comm_cycles,
        "initial_cash": config.initial_cash,
        "margin_requirement": config.margin_requirement,
        "enable_memory": config.enable_memory,
        "mode": config.mode,
        "start_date": config.start_date,
        "end_date": config.end_date,
    }

    # 4. Create and prepare runtime manager
    manager = TradingRuntimeManager(
        config_name=run_id,
        run_dir=run_dir,
        bootstrap=bootstrap,
    )
    manager.prepare_run()

    # 5. Write BOOTSTRAP.md
    _write_bootstrap_md(run_dir, bootstrap)

    # Register only after preparation succeeded, so a failure above does not
    # leave a half-initialised manager exposed through the read endpoints.
    register_runtime_manager(manager)

    # 6. Start pipeline in background
    _stop_event = asyncio.Event()
    _running_task = asyncio.create_task(
        _run_pipeline(run_id, run_dir, bootstrap, _stop_event)
    )

    return LaunchResponse(
        run_id=run_id,
        status="started",
        run_dir=str(run_dir),
        message=f"Runtime started with run_id: {run_id}",
    )
|
|
|
|
|
|
@router.post("/stop", response_model=StopResponse)
async def stop_runtime(force: bool = True) -> StopResponse:
    """Stop the current running runtime.

    Args:
        force: If True, forcefully cancel the running task.

    Raises:
        HTTPException: 404 when ``_stop_current_runtime`` reports that
            nothing was running.
    """
    if await _stop_current_runtime(force=force):
        return StopResponse(
            status="stopped",
            message="Runtime stopped successfully",
        )
    raise HTTPException(status_code=404, detail="No runtime is currently running")
|
|
|
|
|
|
@router.post("/restart", response_model=RestartResponse)
async def restart_runtime(
    config: LaunchConfig,
    background_tasks: BackgroundTasks,
) -> RestartResponse:
    """Restart the runtime with a new configuration.

    Equivalent to stop + start: the current runtime is force-stopped and
    ``start_runtime`` is then invoked with the same arguments.

    Args:
        config: Launch settings for the new run.
        background_tasks: Forwarded unchanged to ``start_runtime``.

    Returns:
        The new run's ID with status ``"restarted"``.
    """
    await _stop_current_runtime(force=True)

    launched = await start_runtime(config, background_tasks)
    new_run_id = launched.run_id

    return RestartResponse(
        run_id=new_run_id,
        status="restarted",
        message=f"Runtime restarted with run_id: {new_run_id}",
    )
|
|
|
|
|
|
@router.get("/current")
async def get_current_runtime():
    """Describe the runtime that is currently executing.

    Returns:
        A dict with the run ID, run directory, ``is_running`` flag and
        bootstrap values of the registered runtime manager.

    Raises:
        HTTPException: 404 when no pipeline task is active or no runtime
            manager is registered.
    """
    task = _running_task
    manager = runtime_manager

    if task is None or task.done() or manager is None:
        raise HTTPException(status_code=404, detail="No runtime is currently running")

    return {
        "run_id": manager.config_name,
        "run_dir": str(manager.run_dir),
        # Always True at this point: the guard above rejected every other case.
        "is_running": True,
        "bootstrap": manager.bootstrap,
    }
|
|
|
|
|
|
def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None:
    """Write the bootstrap configuration to ``BOOTSTRAP.md`` as front matter.

    The file consists of a single block delimited by ``---`` lines. The values
    are serialised as YAML when PyYAML is importable and as JSON otherwise.

    Args:
        run_dir: Run directory; created (with parents) if it does not exist.
        bootstrap: Configuration values; entries whose value is ``None`` are
            left out of the file.
    """
    try:
        import yaml
    except ImportError:
        yaml = None

    bootstrap_path = run_dir / "BOOTSTRAP.md"
    bootstrap_path.parent.mkdir(parents=True, exist_ok=True)

    # Filter out None values
    values = {k: v for k, v in bootstrap.items() if v is not None}

    if yaml is not None:
        front_matter = yaml.safe_dump(values, allow_unicode=True, sort_keys=False)
    else:
        # Fallback to JSON if yaml not available
        front_matter = json.dumps(values, ensure_ascii=False, indent=2)

    # json.dumps emits no trailing newline; without one the closing delimiter
    # would be glued to the last line of the payload ("}---").
    if not front_matter.endswith("\n"):
        front_matter += "\n"

    content = f"---\n{front_matter}---\n"
    bootstrap_path.write_text(content, encoding="utf-8")
|
|
|
|
|
|
async def _run_pipeline(
    run_id: str,
    run_dir: Path,
    bootstrap: Dict[str, Any],
    stop_event: asyncio.Event,
) -> None:
    """Run the trading pipeline as a background task and log its outcome.

    Args:
        run_id: Identifier of the run, used in log messages and forwarded.
        run_dir: Directory of the run, forwarded to the pipeline.
        bootstrap: Bootstrap configuration, forwarded to the pipeline.
        stop_event: Event forwarded to the pipeline as its stop signal.

    Raises:
        asyncio.CancelledError: Re-raised after logging when the task is
            cancelled.
        Exception: Any pipeline failure is logged with its traceback and
            re-raised.
    """
    import logging

    log = logging.getLogger(__name__)

    # Imported when the task starts rather than at module import time.
    from backend.core.pipeline_runner import run_pipeline

    pipeline_kwargs = {
        "run_id": run_id,
        "run_dir": run_dir,
        "bootstrap": bootstrap,
        "stop_event": stop_event,
    }

    try:
        log.info(f"Starting pipeline for run_id: {run_id}")
        await run_pipeline(**pipeline_kwargs)
        log.info(f"Pipeline completed for run_id: {run_id}")
    except asyncio.CancelledError:
        log.info(f"Pipeline cancelled for run_id: {run_id}")
        raise
    except Exception as exc:
        log.exception(f"Pipeline failed for run_id: {run_id}: {exc}")
        # Re-raise so the failure stays visible on the task object.
        raise
|