feat: 微服务架构拆分和前后端优化
后端: - 拆分出 agent_service, runtime_service, trading_service, news_service - Gateway 模块化拆分 (gateway_*.py) - 添加 domains/ 领域层 - 新增 control_client, runtime_client - 更新 start-dev.sh 支持 split 服务模式 前端: - 完善 API 服务层 (newsApi, tradingApi) - 更新 vite.config.js - Explain 组件优化 测试: - 添加多个服务 app 测试 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
373
backend/services/gateway_cycle_support.py
Normal file
373
backend/services/gateway_cycle_support.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Cycle and monitoring helpers extracted from the main Gateway module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_ingest import ingest_symbols
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
|
||||
"""Kick off a non-blocking market-store refresh for an updated watchlist."""
|
||||
if not tickers:
|
||||
return
|
||||
if gateway._watchlist_ingest_task and not gateway._watchlist_ingest_task.done():
|
||||
gateway._watchlist_ingest_task.cancel()
|
||||
gateway._watchlist_ingest_task = asyncio.create_task(
|
||||
refresh_market_store_for_watchlist(gateway, tickers),
|
||||
)
|
||||
|
||||
|
||||
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
|
||||
"""Refresh the long-lived market store after a watchlist update."""
|
||||
try:
|
||||
await gateway.state_sync.on_system_message(
|
||||
f"正在同步自选股市场数据: {', '.join(tickers)}",
|
||||
)
|
||||
results = await asyncio.to_thread(
|
||||
ingest_symbols,
|
||||
tickers,
|
||||
mode="incremental",
|
||||
)
|
||||
summary = ", ".join(
|
||||
f"{item['symbol']} prices={item['prices']} news={item['news']}"
|
||||
for item in results
|
||||
)
|
||||
await gateway.state_sync.on_system_message(
|
||||
f"自选股市场数据已同步: {summary}",
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning("Watchlist market store refresh failed: %s", exc)
|
||||
await gateway.state_sync.on_system_message(
|
||||
f"自选股市场数据同步失败: {exc}",
|
||||
)
|
||||
|
||||
|
||||
async def market_status_monitor(gateway: Any) -> None:
|
||||
"""Periodically check and broadcast market status changes."""
|
||||
while True:
|
||||
try:
|
||||
await gateway.market_service.check_and_broadcast_market_status()
|
||||
|
||||
status = gateway.market_service.get_market_status()
|
||||
if status["status"] == "open" and not gateway.storage.is_live_session_active:
|
||||
gateway.storage.start_live_session()
|
||||
summary = gateway.storage.load_file("summary") or {}
|
||||
gateway._session_start_portfolio_value = summary.get(
|
||||
"totalAssetValue",
|
||||
gateway.storage.initial_cash,
|
||||
)
|
||||
logger.info(
|
||||
"Session start portfolio: $%s",
|
||||
f"{gateway._session_start_portfolio_value:,.2f}",
|
||||
)
|
||||
elif status["status"] != "open" and gateway.storage.is_live_session_active:
|
||||
gateway.storage.end_live_session()
|
||||
gateway._session_start_portfolio_value = None
|
||||
|
||||
if gateway.storage.is_live_session_active:
|
||||
await update_and_broadcast_live_returns(gateway)
|
||||
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.error("Market status monitor error: %s", exc)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
|
||||
async def update_and_broadcast_live_returns(gateway: Any) -> None:
|
||||
"""Calculate and broadcast live returns for current session."""
|
||||
if not gateway.storage.is_live_session_active:
|
||||
return
|
||||
|
||||
prices = gateway.market_service.get_all_prices()
|
||||
if not prices or not any(p > 0 for p in prices.values()):
|
||||
return
|
||||
|
||||
state = gateway.storage.load_internal_state()
|
||||
equity_history = state.get("equity_history", [])
|
||||
baseline_history = state.get("baseline_history", [])
|
||||
baseline_vw_history = state.get("baseline_vw_history", [])
|
||||
momentum_history = state.get("momentum_history", [])
|
||||
|
||||
current_equity = equity_history[-1]["v"] if equity_history else None
|
||||
current_baseline = baseline_history[-1]["v"] if baseline_history else None
|
||||
current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None
|
||||
current_momentum = momentum_history[-1]["v"] if momentum_history else None
|
||||
|
||||
point = gateway.storage.update_live_returns(
|
||||
current_equity=current_equity,
|
||||
current_baseline=current_baseline,
|
||||
current_baseline_vw=current_baseline_vw,
|
||||
current_momentum=current_momentum,
|
||||
)
|
||||
if point:
|
||||
live_returns = gateway.storage.get_live_returns()
|
||||
await gateway.broadcast(
|
||||
{
|
||||
"type": "team_summary",
|
||||
"equity_return": live_returns["equity_return"],
|
||||
"baseline_return": live_returns["baseline_return"],
|
||||
"baseline_vw_return": live_returns["baseline_vw_return"],
|
||||
"momentum_return": live_returns["momentum_return"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def on_strategy_trigger(gateway: Any, date: str) -> None:
|
||||
"""Handle trading cycle trigger."""
|
||||
if gateway._cycle_lock.locked():
|
||||
logger.warning("Trading cycle already running, skipping trigger for %s", date)
|
||||
await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}")
|
||||
return
|
||||
|
||||
async with gateway._cycle_lock:
|
||||
logger.info("Strategy triggered for %s", date)
|
||||
tickers = gateway.config.get("tickers", [])
|
||||
if gateway.is_backtest:
|
||||
await run_backtest_cycle(gateway, date, tickers)
|
||||
else:
|
||||
await run_live_cycle(gateway, date, tickers)
|
||||
|
||||
|
||||
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
|
||||
"""Run lightweight heartbeat check for all analysts."""
|
||||
logger.info("[Heartbeat] Running heartbeat check for %s", date)
|
||||
analysts = gateway.pipeline._all_analysts()
|
||||
|
||||
for analyst in analysts:
|
||||
try:
|
||||
ws_id = getattr(analyst, "workspace_id", None)
|
||||
if ws_id:
|
||||
from backend.agents.workspace_manager import get_workspace_dir
|
||||
from pathlib import Path
|
||||
from agentscope.message import Msg
|
||||
|
||||
ws_dir = get_workspace_dir(ws_id)
|
||||
if ws_dir:
|
||||
hb_path = Path(ws_dir) / "HEARTBEAT.md"
|
||||
if hb_path.exists():
|
||||
content = hb_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
hb_task = f"# 定期主动检查\n\n{content}\n\n请执行上述检查并报告结果。"
|
||||
logger.info("[Heartbeat] Running heartbeat for %s", analyst.name)
|
||||
msg = Msg(role="user", content=hb_task, name="system")
|
||||
await analyst.reply([msg])
|
||||
logger.info("[Heartbeat] %s heartbeat complete", analyst.name)
|
||||
continue
|
||||
logger.debug("[Heartbeat] No HEARTBEAT.md for %s, skipping", analyst.name)
|
||||
except Exception as exc:
|
||||
logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True)
|
||||
|
||||
|
||||
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
|
||||
gateway.market_service.set_backtest_date(date)
|
||||
await gateway.market_service.emit_market_open()
|
||||
|
||||
await gateway.state_sync.on_cycle_start(date)
|
||||
gateway._dashboard.update(date=date, status="Analyzing...")
|
||||
|
||||
prices = gateway.market_service.get_open_prices()
|
||||
close_prices = gateway.market_service.get_close_prices()
|
||||
market_caps = await get_market_caps(gateway, tickers, date)
|
||||
|
||||
result = await gateway.pipeline.run_cycle(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
close_prices=close_prices,
|
||||
market_caps=market_caps,
|
||||
)
|
||||
|
||||
await gateway.market_service.emit_market_close()
|
||||
settlement_result = result.get("settlement_result")
|
||||
save_cycle_results(gateway, result, date, close_prices, settlement_result)
|
||||
await broadcast_portfolio_updates(gateway, result, close_prices)
|
||||
await finalize_cycle(gateway, date)
|
||||
|
||||
|
||||
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
|
||||
trading_date = gateway.market_service.get_live_trading_date()
|
||||
logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)
|
||||
|
||||
await gateway.state_sync.on_cycle_start(trading_date)
|
||||
gateway._dashboard.update(date=trading_date, status="Analyzing...")
|
||||
|
||||
market_caps = await get_market_caps(gateway, tickers, trading_date)
|
||||
schedule_mode = gateway.config.get("schedule_mode", "daily")
|
||||
market_status = gateway.market_service.get_market_status()
|
||||
current_prices = gateway.market_service.get_all_prices()
|
||||
|
||||
if schedule_mode == "intraday":
|
||||
execute_decisions = market_status.get("status") == "open"
|
||||
if execute_decisions:
|
||||
await gateway.state_sync.on_system_message("定时任务触发:当前处于交易时段,本轮将执行交易决策")
|
||||
else:
|
||||
await gateway.state_sync.on_system_message("定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易")
|
||||
|
||||
result = await gateway.pipeline.run_cycle(
|
||||
tickers=tickers,
|
||||
date=trading_date,
|
||||
prices=current_prices,
|
||||
market_caps=market_caps,
|
||||
execute_decisions=execute_decisions,
|
||||
)
|
||||
close_prices = current_prices
|
||||
else:
|
||||
result = await gateway.pipeline.run_cycle(
|
||||
tickers=tickers,
|
||||
date=trading_date,
|
||||
market_caps=market_caps,
|
||||
get_open_prices_fn=gateway.market_service.wait_for_open_prices,
|
||||
get_close_prices_fn=gateway.market_service.wait_for_close_prices,
|
||||
)
|
||||
close_prices = gateway.market_service.get_all_prices()
|
||||
|
||||
settlement_result = result.get("settlement_result")
|
||||
save_cycle_results(gateway, result, trading_date, close_prices, settlement_result)
|
||||
await broadcast_portfolio_updates(gateway, result, close_prices)
|
||||
await finalize_cycle(gateway, trading_date)
|
||||
|
||||
|
||||
async def finalize_cycle(gateway: Any, date: str) -> None:
|
||||
summary = gateway.storage.load_file("summary") or {}
|
||||
if gateway.storage.is_live_session_active:
|
||||
summary.update(gateway.storage.get_live_returns())
|
||||
|
||||
await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary)
|
||||
holdings = gateway.storage.load_file("holdings") or []
|
||||
trades = gateway.storage.load_file("trades") or []
|
||||
leaderboard = gateway.storage.load_file("leaderboard") or []
|
||||
if leaderboard:
|
||||
await gateway.state_sync.on_leaderboard_update(leaderboard)
|
||||
gateway._dashboard.update(date=date, status="Running", portfolio=summary, holdings=holdings, trades=trades)
|
||||
|
||||
|
||||
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
|
||||
market_caps: dict[str, float] = {}
|
||||
for ticker in tickers:
|
||||
try:
|
||||
market_cap = None
|
||||
response = await gateway._call_trading_service(
|
||||
f"get_market_cap for {ticker}",
|
||||
lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
|
||||
)
|
||||
if response is not None:
|
||||
market_cap = response.get("market_cap")
|
||||
if market_cap is None:
|
||||
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
|
||||
market_cap = payload.get("market_cap")
|
||||
market_caps[ticker] = market_cap if market_cap else 1e9
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
|
||||
market_caps[ticker] = 1e9
|
||||
return market_caps
|
||||
|
||||
|
||||
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
|
||||
portfolio = result.get("portfolio", {})
|
||||
if portfolio:
|
||||
holdings = FrontendAdapter.build_holdings(portfolio, prices)
|
||||
if holdings:
|
||||
await gateway.state_sync.on_holdings_update(holdings)
|
||||
stats = FrontendAdapter.build_stats(portfolio, prices)
|
||||
if stats:
|
||||
await gateway.state_sync.on_stats_update(stats)
|
||||
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
if executed_trades:
|
||||
await gateway.state_sync.on_trades_executed(executed_trades)
|
||||
|
||||
|
||||
def save_cycle_results(
|
||||
gateway: Any,
|
||||
result: dict[str, Any],
|
||||
date: str,
|
||||
prices: dict[str, float],
|
||||
settlement_result: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
portfolio = result.get("portfolio", {})
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
baseline_values = settlement_result.get("baseline_values") if settlement_result else None
|
||||
if portfolio:
|
||||
gateway.storage.update_dashboard_after_cycle(
|
||||
portfolio=portfolio,
|
||||
prices=prices,
|
||||
date=date,
|
||||
executed_trades=executed_trades,
|
||||
baseline_values=baseline_values,
|
||||
)
|
||||
|
||||
|
||||
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
|
||||
gateway.state_sync.set_backtest_dates(dates)
|
||||
gateway._dashboard.update(days_total=len(dates), days_completed=0)
|
||||
await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
|
||||
try:
|
||||
for i, date in enumerate(dates):
|
||||
gateway._dashboard.update(days_completed=i)
|
||||
await gateway.on_strategy_trigger(date=date)
|
||||
await asyncio.sleep(0.1)
|
||||
await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
|
||||
summary = gateway.storage.load_file("summary") or {}
|
||||
gateway._dashboard.update(status="Complete", portfolio=summary, days_completed=len(dates))
|
||||
gateway._dashboard.stop()
|
||||
gateway._dashboard.print_final_summary()
|
||||
except Exception as exc:
|
||||
error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
asyncio.create_task(gateway.state_sync.on_system_message(error_msg))
|
||||
gateway._dashboard.update(status=f"Failed: {str(exc)}")
|
||||
gateway._dashboard.stop()
|
||||
raise
|
||||
finally:
|
||||
gateway._backtest_task = None
|
||||
|
||||
|
||||
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Backtest task was cancelled")
|
||||
except Exception as exc:
|
||||
logger.error("Backtest task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
|
||||
|
||||
|
||||
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
|
||||
gateway._manual_cycle_task = None
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Manual cycle task was cancelled")
|
||||
except Exception as exc:
|
||||
logger.error("Manual cycle task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
|
||||
|
||||
|
||||
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
|
||||
gateway.state_sync.set_backtest_dates(dates)
|
||||
if dates:
|
||||
gateway._backtest_start_date = dates[0]
|
||||
gateway._backtest_end_date = dates[-1]
|
||||
gateway._dashboard.days_total = len(dates)
|
||||
|
||||
|
||||
def stop_gateway(gateway: Any) -> None:
|
||||
gateway.state_sync.save_state()
|
||||
gateway.market_service.stop()
|
||||
if gateway._backtest_task:
|
||||
gateway._backtest_task.cancel()
|
||||
if gateway._market_status_task:
|
||||
gateway._market_status_task.cancel()
|
||||
if gateway._watchlist_ingest_task:
|
||||
gateway._watchlist_ingest_task.cancel()
|
||||
gateway._dashboard.stop()
|
||||
Reference in New Issue
Block a user