373 lines
14 KiB
Python
373 lines
14 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Cycle and monitoring helpers extracted from the main Gateway module."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any
|
|
|
|
from backend.data.market_ingest import ingest_symbols, refresh_news_for_symbols
|
|
from backend.domains import trading as trading_domain
|
|
from backend.utils.msg_adapter import FrontendAdapter
|
|
|
|
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
    """Start a background refresh of the market store for an updated watchlist.

    Returns immediately. When *tickers* is empty nothing happens. Otherwise any
    previous refresh task that has not finished yet is cancelled, and a new
    ``refresh_market_store_for_watchlist`` task is stored on
    ``gateway._watchlist_ingest_task``. Must be called while an event loop is
    running (``asyncio.create_task`` requires one).
    """
    if tickers:
        previous_task = gateway._watchlist_ingest_task
        if previous_task and not previous_task.done():
            previous_task.cancel()
        gateway._watchlist_ingest_task = asyncio.create_task(
            refresh_market_store_for_watchlist(gateway, tickers),
        )
|
|
|
|
|
|
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
    """Incrementally ingest market data for the watchlist and report progress.

    Sends a "syncing" system message, runs
    ``ingest_symbols(tickers, mode="incremental")`` in a worker thread, then
    sends a system message summarising the per-symbol price/news counts it
    returned.

    Cancellation is re-raised. Any other error is logged as a warning and
    reported to clients as a system message instead of being raised.
    """
    try:
        await gateway.state_sync.on_system_message(
            f"正在同步自选股市场数据: {', '.join(tickers)}",
        )
        ingested = await asyncio.to_thread(ingest_symbols, tickers, mode="incremental")
        per_symbol = [
            f"{entry['symbol']} prices={entry['prices']} news={entry['news']}"
            for entry in ingested
        ]
        summary_text = ", ".join(per_symbol)
        await gateway.state_sync.on_system_message(f"自选股市场数据已同步: {summary_text}")
    except asyncio.CancelledError:
        # Let task cancellation (e.g. a newer watchlist refresh) propagate.
        raise
    except Exception as err:
        logger.warning("Watchlist market store refresh failed: %s", err)
        await gateway.state_sync.on_system_message(f"自选股市场数据同步失败: {err}")
|
|
|
|
|
|
async def market_status_monitor(gateway: Any) -> None:
    """Poll market status roughly every 60 seconds and keep the live session in sync.

    Each iteration does the following:
    - asks the market service to check and broadcast status changes;
    - starts a live session when the status is ``"open"`` and none is active,
      recording the session-start portfolio value (the dashboard summary's
      ``totalAssetValue``, defaulting to ``storage.initial_cash``);
    - ends the active session when the status is not ``"open"``;
    - pushes live returns while a session is active.

    Errors are logged and the loop continues after the same 60-second pause.
    A cancellation raised inside the main body ends the loop normally.
    """
    while True:
        try:
            await gateway.market_service.check_and_broadcast_market_status()

            is_open = gateway.market_service.get_market_status()["status"] == "open"
            storage = gateway.storage

            if is_open and not storage.is_live_session_active:
                storage.start_live_session()
                snapshot = storage.build_dashboard_snapshot_from_state(gateway.state_sync.state)
                start_summary = snapshot.get("summary") or {}
                gateway._session_start_portfolio_value = start_summary.get(
                    "totalAssetValue",
                    storage.initial_cash,
                )
                logger.info(
                    "Session start portfolio: $%s",
                    f"{gateway._session_start_portfolio_value:,.2f}",
                )
            elif not is_open and storage.is_live_session_active:
                storage.end_live_session()
                gateway._session_start_portfolio_value = None

            if storage.is_live_session_active:
                await update_and_broadcast_live_returns(gateway)

            await asyncio.sleep(60)
        except asyncio.CancelledError:
            break
        except Exception as exc:
            logger.error("Market status monitor error: %s", exc)
            await asyncio.sleep(60)
|
|
|
|
|
|
async def update_and_broadcast_live_returns(gateway: Any) -> None:
    """Record a live-returns point for the active session and broadcast it.

    Does nothing when no live session is active, or when the market service
    has no positive price yet. Otherwise:
    - takes the latest ``"v"`` value of each tracked history (equity,
      equal-weight baseline, value-weighted baseline, momentum) from the
      persisted internal state;
    - passes those values to ``storage.update_live_returns``;
    - if that produces a point, broadcasts a ``team_summary`` message with the
      session returns from ``storage.get_live_returns()``.
    """
    if not gateway.storage.is_live_session_active:
        return

    prices = gateway.market_service.get_all_prices()
    # Treat missing quotes (None values) as "no price" instead of letting
    # ``None > 0`` raise TypeError and abort the caller's iteration.
    if not prices or not any((price or 0) > 0 for price in prices.values()):
        return

    state = gateway.storage.load_internal_state()
    point = gateway.storage.update_live_returns(
        current_equity=_latest_history_value(state.get("equity_history", [])),
        current_baseline=_latest_history_value(state.get("baseline_history", [])),
        current_baseline_vw=_latest_history_value(state.get("baseline_vw_history", [])),
        current_momentum=_latest_history_value(state.get("momentum_history", [])),
    )
    if not point:
        return

    live_returns = gateway.storage.get_live_returns()
    await gateway.broadcast(
        {
            "type": "team_summary",
            "equity_return": live_returns["equity_return"],
            "baseline_return": live_returns["baseline_return"],
            "baseline_vw_return": live_returns["baseline_vw_return"],
            "momentum_return": live_returns["momentum_return"],
        },
    )


def _latest_history_value(history: list[dict[str, Any]] | None) -> Any:
    """Return the ``"v"`` field of the last entry in *history*, or None if it is empty."""
    return history[-1]["v"] if history else None
|
|
|
|
|
|
async def on_strategy_trigger(gateway: Any, date: str) -> None:
    """Handle a trading-cycle trigger for *date*.

    Only one cycle runs at a time. If ``gateway._cycle_lock`` is already held,
    the trigger is skipped: a warning is logged and a system message is sent.
    Otherwise the lock is held for the whole cycle. The cycle is dispatched to
    ``run_backtest_cycle`` or ``run_live_cycle`` depending on
    ``gateway.is_backtest``, with the configured ``"tickers"`` (default: []).
    """
    cycle_lock = gateway._cycle_lock
    if cycle_lock.locked():
        logger.warning("Trading cycle already running, skipping trigger for %s", date)
        await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}")
        return

    async with cycle_lock:
        logger.info("Strategy triggered for %s", date)
        tickers = gateway.config.get("tickers", [])
        cycle_runner = run_backtest_cycle if gateway.is_backtest else run_live_cycle
        await cycle_runner(gateway, date, tickers)
|
|
|
|
|
|
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
    """Run the lightweight heartbeat pass over every analyst in the pipeline.

    No heartbeat work is wired up here yet: each analyst returned by
    ``gateway.pipeline._all_analysts()`` is only logged at DEBUG level as
    skipped. Per-analyst errors are logged and do not stop the loop.
    """
    logger.info("[Heartbeat] Running heartbeat check for %s", date)

    for member in gateway.pipeline._all_analysts():
        try:
            logger.debug("[Heartbeat] No heartbeat configured for %s, skipping", member.name)
        except Exception as err:
            logger.error("[Heartbeat] %s failed: %s", member.name, err, exc_info=True)
|
|
|
|
|
|
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
    """Replay one historical trading day through the pipeline.

    Steps, in order:
    - point the market service at *date* and emit market open;
    - announce the cycle start;
    - run the pipeline with that day's open prices, close prices and
      per-ticker market caps;
    - emit market close;
    - save the results, broadcast portfolio updates and finalize the cycle,
      all using the close prices.
    """
    gateway.market_service.set_backtest_date(date)
    await gateway.market_service.emit_market_open()
    await gateway.state_sync.on_cycle_start(date)

    open_prices = gateway.market_service.get_open_prices()
    day_close_prices = gateway.market_service.get_close_prices()
    caps = await get_market_caps(gateway, tickers, date)

    cycle_result = await gateway.pipeline.run_cycle(
        tickers=tickers,
        date=date,
        prices=open_prices,
        close_prices=day_close_prices,
        market_caps=caps,
    )

    await gateway.market_service.emit_market_close()

    save_cycle_results(
        gateway,
        cycle_result,
        date,
        day_close_prices,
        cycle_result.get("settlement_result"),
    )
    await broadcast_portfolio_updates(gateway, cycle_result, day_close_prices)
    await finalize_cycle(gateway, date)
|
|
|
|
|
|
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
    """Run one live trading cycle.

    The cycle is keyed on ``market_service.get_live_trading_date()``; the
    triggering *date* is only logged. Steps, in order:
    - best-effort news refresh for *tickers*;
    - cycle-start notification;
    - market-cap lookup;
    - the pipeline run.

    With ``schedule_mode == "intraday"`` the pipeline runs against the current
    price snapshot and executes decisions only while the market status is
    ``"open"``. Otherwise (default ``"daily"``) it receives the market
    service's ``wait_for_open_prices`` / ``wait_for_close_prices`` callbacks.
    Results are then saved, broadcast and the cycle is finalized.
    """
    trading_date = gateway.market_service.get_live_trading_date()
    logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)

    await _refresh_live_news(gateway, tickers, trading_date)

    await gateway.state_sync.on_cycle_start(trading_date)

    market_caps = await get_market_caps(gateway, tickers, trading_date)
    schedule_mode = gateway.config.get("schedule_mode", "daily")
    market_status = gateway.market_service.get_market_status()
    current_prices = gateway.market_service.get_all_prices()

    if schedule_mode != "intraday":
        result = await gateway.pipeline.run_cycle(
            tickers=tickers,
            date=trading_date,
            market_caps=market_caps,
            get_open_prices_fn=gateway.market_service.wait_for_open_prices,
            get_close_prices_fn=gateway.market_service.wait_for_close_prices,
        )
        # Daily mode re-reads prices after the pipeline has waited for the close.
        close_prices = gateway.market_service.get_all_prices()
    else:
        trading_open = market_status.get("status") == "open"
        notice = (
            "定时任务触发:当前处于交易时段,本轮将执行交易决策"
            if trading_open
            else "定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易"
        )
        await gateway.state_sync.on_system_message(notice)
        result = await gateway.pipeline.run_cycle(
            tickers=tickers,
            date=trading_date,
            prices=current_prices,
            market_caps=market_caps,
            execute_decisions=trading_open,
        )
        # Intraday mode values the portfolio at the snapshot taken before the run.
        close_prices = current_prices

    save_cycle_results(
        gateway,
        result,
        trading_date,
        close_prices,
        result.get("settlement_result"),
    )
    await broadcast_portfolio_updates(gateway, result, close_prices)
    await finalize_cycle(gateway, trading_date)


async def _refresh_live_news(gateway: Any, tickers: list[str], trading_date: Any) -> None:
    """Best-effort news refresh before a live cycle; any failure is logged, never raised."""
    try:
        refreshed = await asyncio.to_thread(
            refresh_news_for_symbols,
            tickers,
            end_date=trading_date,
            store=gateway.storage.market_store,
        )
        details = ", ".join(f"{row['symbol']} news={row['news']}" for row in refreshed)
        logger.info("News refresh complete: %s", details or "no symbols")
    except Exception as exc:
        logger.warning("Live cycle news refresh failed: %s", exc)
|
|
|
|
|
|
async def finalize_cycle(gateway: Any, date: str) -> None:
    """Publish the end-of-cycle portfolio summary and, if present, the leaderboard.

    The summary comes from a dashboard snapshot built from the current sync
    state. While a live session is active, ``storage.get_live_returns()`` is
    merged into it before ``on_cycle_end`` is sent. The leaderboard update is
    only pushed when the snapshot's leaderboard is non-empty.
    """
    storage = gateway.storage
    snapshot = storage.build_dashboard_snapshot_from_state(gateway.state_sync.state)

    # NOTE(review): this merges into the snapshot's own summary dict; copy it
    # first if storage ever starts caching snapshots.
    portfolio_summary = snapshot.get("summary") or {}
    if storage.is_live_session_active:
        portfolio_summary.update(storage.get_live_returns())

    await gateway.state_sync.on_cycle_end(date, portfolio_summary=portfolio_summary)

    board = snapshot.get("leaderboard") or []
    if board:
        await gateway.state_sync.on_leaderboard_update(board)
|
|
|
|
|
|
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
    """Look up a market cap for every ticker as of *date*.

    Each ticker is first queried through ``gateway._call_trading_service``.
    ``trading_domain.get_market_cap_payload`` is used as a fallback when that
    call raises or yields nothing usable (a ``None`` response or no
    ``"market_cap"``). Tickers still without a truthy market cap get a default
    of ``1e9``. Tickers are looked up sequentially.

    Returns:
        A mapping with one entry per ticker in *tickers*. Per-ticker lookup
        errors are logged and replaced by the default rather than raised.
    """
    default_market_cap = 1e9
    market_caps: dict[str, float] = {}
    for ticker in tickers:
        market_cap = None

        # Primary source: the trading service. A failure here must not skip
        # the local fallback below.
        try:
            response = await gateway._call_trading_service(
                f"get_market_cap for {ticker}",
                # Default-arg binding pins this iteration's ticker into the callable.
                lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
            )
            if response is not None:
                market_cap = response.get("market_cap")
        except Exception as exc:
            logger.warning("Trading service market cap lookup failed for %s: %s", ticker, exc)

        # Fallback source: the local trading domain.
        if market_cap is None:
            try:
                payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
                market_cap = (payload or {}).get("market_cap")
            except Exception as exc:
                logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)

        market_caps[ticker] = market_cap if market_cap else default_market_cap
    return market_caps
|
|
|
|
|
|
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
    """Push post-cycle portfolio state to clients.

    When *result* carries a non-empty ``"portfolio"``, holdings and stats are
    built with ``FrontendAdapter`` using *prices*. Each is sent only if
    non-empty. Executed trades are sent when ``"executed_trades"`` is non-empty.
    """
    if portfolio := result.get("portfolio", {}):
        if holdings := FrontendAdapter.build_holdings(portfolio, prices):
            await gateway.state_sync.on_holdings_update(holdings)
        if stats := FrontendAdapter.build_stats(portfolio, prices):
            await gateway.state_sync.on_stats_update(stats)

    if trades := result.get("executed_trades", []):
        await gateway.state_sync.on_trades_executed(trades)
|
|
|
|
|
|
def save_cycle_results(
    gateway: Any,
    result: dict[str, Any],
    date: str,
    prices: dict[str, float],
    settlement_result: dict[str, Any] | None = None,
) -> None:
    """Persist a cycle's outcome into the dashboard storage.

    Nothing is written when *result* has no (or an empty) ``"portfolio"``.
    Otherwise ``storage.update_dashboard_after_cycle`` receives:
    - the portfolio;
    - *prices* and *date*;
    - the executed trades (default ``[]``);
    - the settlement's ``"baseline_values"``, or ``None`` when
      *settlement_result* is falsy.
    """
    portfolio = result.get("portfolio", {})
    trades = result.get("executed_trades", [])
    baselines = None
    if settlement_result:
        baselines = settlement_result.get("baseline_values")

    if not portfolio:
        return

    gateway.storage.update_dashboard_after_cycle(
        portfolio=portfolio,
        prices=prices,
        date=date,
        executed_trades=trades,
        baseline_values=baselines,
    )
|
|
|
|
|
|
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
    """Run the strategy sequentially over each backtest date.

    Registers *dates* with the state sync and announces the start. It then
    triggers ``gateway.on_strategy_trigger`` for each date in order, with a
    0.1 s pause between days, and announces completion.
    ``gateway._backtest_task`` is cleared on exit, whatever the outcome.

    Raises:
        Re-raises any exception that aborts the run, after logging it and
        reporting it to clients via ``state_sync.on_system_message``.
    """
    gateway.state_sync.set_backtest_dates(dates)
    await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
    try:
        for date in dates:
            await gateway.on_strategy_trigger(date=date)
            # Brief yield to the event loop between trading days.
            await asyncio.sleep(0.1)
        await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
    except Exception as exc:
        error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
        logger.error(error_msg, exc_info=True)
        # Await the notification instead of spawning an unreferenced task, which
        # the event loop may garbage-collect before it runs. A failure to notify
        # must not mask the original error.
        try:
            await gateway.state_sync.on_system_message(error_msg)
        except Exception:
            logger.warning("Failed to report backtest failure to clients", exc_info=True)
        raise
    finally:
        gateway._backtest_task = None
|
|
|
|
|
|
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
    """Done-callback for the backtest task: retrieve and log its outcome.

    Cancellation is logged at INFO. Any other exception is logged at ERROR
    with its traceback. Nothing is re-raised. *gateway* is not used by this
    handler.
    """
    try:
        task.result()
    except asyncio.CancelledError:
        logger.info("Backtest task was cancelled")
    except Exception as failure:
        failure_type = type(failure).__name__
        logger.error(
            "Backtest task failed with exception:%s:%s",
            failure_type,
            failure,
            exc_info=True,
        )
|
|
|
|
|
|
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
    """Done-callback for a manually triggered cycle task.

    Clears ``gateway._manual_cycle_task`` first, then retrieves the task's
    outcome. Cancellation is logged at INFO; any other exception is logged at
    ERROR with its traceback. Nothing is re-raised.
    """
    gateway._manual_cycle_task = None
    try:
        task.result()
    except asyncio.CancelledError:
        logger.info("Manual cycle task was cancelled")
    except Exception as failure:
        failure_type = type(failure).__name__
        logger.error(
            "Manual cycle task failed with exception:%s:%s",
            failure_type,
            failure,
            exc_info=True,
        )
|
|
|
|
|
|
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
    """Register the backtest date range.

    The dates are always forwarded to the state sync. When *dates* is
    non-empty, its first and last entries are also stored as
    ``gateway._backtest_start_date`` / ``gateway._backtest_end_date``.
    """
    gateway.state_sync.set_backtest_dates(dates)
    if not dates:
        return
    gateway._backtest_start_date, gateway._backtest_end_date = dates[0], dates[-1]
|
|
|
|
|
|
# The event loop keeps only weak references to tasks, so scheduled OpenClaw
# disconnects are held here until they finish.
_openclaw_disconnect_tasks: set[asyncio.Task] = set()


def stop_gateway(gateway: Any) -> None:
    """Synchronously shut the gateway down.

    Steps, in order:
    - save the sync state;
    - stop the market service;
    - cancel the backtest, market-status and watchlist-ingest tasks (if set);
    - close the OpenClaw WebSocket, if present.

    Called from inside a running event loop, the disconnect is scheduled as a
    task and not awaited. Otherwise it is run to completion with
    ``asyncio.run``. Disconnect failures are logged, not raised.
    """
    gateway.state_sync.save_state()
    gateway.market_service.stop()

    for attr in ("_backtest_task", "_market_status_task", "_watchlist_ingest_task"):
        task = getattr(gateway, attr)
        if task:
            task.cancel()

    # Close OpenClaw WebSocket connection
    if gateway._openclaw_ws:
        _disconnect_openclaw(gateway._openclaw_ws)


def _disconnect_openclaw(ws: Any) -> None:
    """Close *ws* without blocking a running event loop; failures are logged."""
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    try:
        if loop is not None:
            # Inside the loop: schedule the disconnect instead of blocking on it.
            disconnect_task = loop.create_task(ws.disconnect())
            _openclaw_disconnect_tasks.add(disconnect_task)
            disconnect_task.add_done_callback(_openclaw_disconnect_tasks.discard)
        else:
            asyncio.run(ws.disconnect())
    except Exception:
        logger.warning("Failed to close OpenClaw WebSocket connection", exc_info=True)
|