Initial commit of integrated agent system

This commit is contained in:
cillin
2026-03-30 17:46:44 +08:00
commit 0fa413380c
337 changed files with 75268 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""Services layer for infrastructure components"""

923
backend/services/gateway.py Normal file
View File

@@ -0,0 +1,923 @@
# -*- coding: utf-8 -*-
"""
WebSocket Gateway for frontend communication
"""
import asyncio
import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set
import websockets
from websockets.asyncio.server import ServerConnection
from backend.data.provider_utils import normalize_symbol
from backend.domains import news as news_domain
from backend.llm.models import get_agent_model_info
from backend.core.pipeline import TradingPipeline
from backend.core.state_sync import StateSync
from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.data.provider_router import get_provider_router
from backend.tools.technical_signals import StockTechnicalAnalyzer
from backend.core.scheduler import Scheduler
from backend.services import gateway_admin_handlers
from backend.services import gateway_cycle_support
from backend.services import gateway_openclaw_handlers
from backend.services import gateway_runtime_support
from backend.services import gateway_stock_handlers
from shared.client import NewsServiceClient
from shared.client import TradingServiceClient
from shared.client.openclaw_websocket_client import OpenClawWebSocketClient, DEFAULT_GATEWAY_URL as OPENCLAW_WS_URL
logger = logging.getLogger(__name__)
EDITABLE_AGENT_WORKSPACE_FILES = {
"SOUL.md",
"PROFILE.md",
"AGENTS.md",
"MEMORY.md",
"POLICY.md",
}
class Gateway:
"""WebSocket Gateway for frontend communication"""
def __init__(
self,
market_service: MarketService,
storage_service: StorageService,
pipeline: TradingPipeline,
state_sync: Optional[StateSync] = None,
scheduler_callback: Optional[Callable] = None,
scheduler: Optional[Scheduler] = None,
config: Dict[str, Any] = None,
):
self.market_service = market_service
self.storage = storage_service
self.pipeline = pipeline
self.scheduler_callback = scheduler_callback
self.scheduler = scheduler
self.config = config or {}
self.mode = self.config.get("mode", "live")
self.is_backtest = self.mode == "backtest" or self.config.get(
"backtest_mode",
False,
)
self.state_sync = state_sync or StateSync(storage=storage_service)
# self.state_sync.set_mode(self.is_backtest)
self.state_sync.set_broadcast_fn(self.broadcast)
self.pipeline.state_sync = self.state_sync
self.connected_clients: Set[ServerConnection] = set()
self.lock = asyncio.Lock()
self._cycle_lock = asyncio.Lock()
self._backtest_task: Optional[asyncio.Task] = None
self._manual_cycle_task: Optional[asyncio.Task] = None
self._backtest_start_date: Optional[str] = None
self._backtest_end_date: Optional[str] = None
self._market_status_task: Optional[asyncio.Task] = None
self._watchlist_ingest_task: Optional[asyncio.Task] = None
# Session tracking for live returns
self._session_start_portfolio_value: Optional[float] = None
self._provider_router = get_provider_router()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._project_root = Path(__file__).resolve().parents[2]
self._technical_analyzer = StockTechnicalAnalyzer()
self._openclaw_ws: OpenClawWebSocketClient | None = None
async def start(self, host: str = "0.0.0.0", port: int = 8766):
"""Start gateway server with proper initialization order.
Phase 1: Start WebSocket server first so frontend can connect immediately
Phase 2: Start market data service (pushes data to connected clients)
Phase 3: Start scheduler last (triggers trading cycles)
"""
logger.info(f"Starting gateway on {host}:{port}")
self._loop = asyncio.get_running_loop()
self._provider_router.add_listener(self._on_provider_usage_changed)
self.state_sync.load_state()
self.market_service.set_price_recorder(self.storage.record_price_point)
self.state_sync.update_state("status", "initializing")
self.state_sync.update_state("server_mode", self.mode)
self.state_sync.update_state("is_backtest", self.is_backtest)
self.state_sync.update_state("tickers", self.config.get("tickers", []))
self.state_sync.update_state(
"runtime_config",
{
"tickers": self.config.get("tickers", []),
"schedule_mode": self.config.get("schedule_mode", "daily"),
"interval_minutes": self.config.get("interval_minutes", 60),
"trigger_time": self.config.get("trigger_time", "09:30"),
"initial_cash": self.config.get(
"initial_cash",
self.storage.initial_cash,
),
"margin_requirement": self.config.get("margin_requirement"),
"max_comm_cycles": self.config.get("max_comm_cycles"),
"enable_memory": self.config.get("enable_memory", False),
},
)
self.state_sync.update_state(
"data_sources",
self._provider_router.get_usage_snapshot(),
)
# Load and display existing portfolio state if available
dashboard_snapshot = self.storage.build_dashboard_snapshot_from_state(self.state_sync.state)
summary = dashboard_snapshot.get("summary")
if summary:
logger.info(
"Loaded existing portfolio: $%s",
f"{summary.get('totalAssetValue', 0):,.2f}",
)
# ======================================================================
# PHASE 1: Start WebSocket server first
# This allows frontend to connect immediately and receive status updates
# ======================================================================
logger.info("[Phase 1/4] Starting WebSocket server...")
self.state_sync.update_state("status", "websocket_ready")
# Create server but don't block yet - we'll serve inside the context manager
server = await websockets.serve(
self.handle_client,
host,
port,
ping_interval=30,
ping_timeout=60,
)
logger.info(f"WebSocket server ready: ws://{host}:{port}")
# Give a brief moment for any existing clients to reconnect
await asyncio.sleep(0.1)
# Connect to OpenClaw Gateway (18789) via WebSocket
logger.info("Connecting to OpenClaw Gateway...")
try:
self._openclaw_ws = OpenClawWebSocketClient(
url=OPENCLAW_WS_URL,
client_name="gateway-client",
client_version="1.0.0",
)
await self._openclaw_ws.connect()
logger.info("OpenClaw Gateway WebSocket connected")
except Exception as e:
logger.warning("Failed to connect to OpenClaw Gateway: %s", e)
self._openclaw_ws = None
# ======================================================================
# PHASE 2: Start market data service
# Now frontend is connected, start pushing price updates
# ======================================================================
logger.info("[Phase 2/4] Starting market data service...")
self.state_sync.update_state("status", "market_service_starting")
await self.market_service.start(broadcast_func=self.broadcast)
self.state_sync.update_state("status", "market_service_ready")
logger.info("Market data service ready - price updates active")
# ======================================================================
# PHASE 3: Start market status monitoring
# Monitors market open/close and broadcasts status
# ======================================================================
logger.info("[Phase 3/4] Starting market status monitoring...")
if not self.is_backtest:
self._market_status_task = asyncio.create_task(
self._market_status_monitor(),
)
# ======================================================================
# PHASE 4: Start scheduler last
# Only start trading after everything else is ready
# ======================================================================
logger.info("[Phase 4/4] Starting scheduler...")
self.state_sync.update_state("status", "scheduler_starting")
if self.scheduler:
# Wire up heartbeat callback if heartbeat is configured
heartbeat_interval = self.config.get("heartbeat_interval", 0)
if heartbeat_interval and heartbeat_interval > 0:
self.scheduler.set_heartbeat_callback(self.on_heartbeat_trigger)
logger.info(
f"[Heartbeat] Registered heartbeat callback (interval={heartbeat_interval}s)",
)
await self.scheduler.start(self.on_strategy_trigger)
elif self.scheduler_callback:
await self.scheduler_callback(callback=self.on_strategy_trigger)
self.state_sync.update_state("status", "running")
logger.info(
f"Gateway fully operational: ws://{host}:{port}, mode={self.mode}",
)
# Keep server running
await asyncio.Future()
def _on_provider_usage_changed(self, snapshot: Dict[str, Any]):
"""Handle provider routing updates from the shared router."""
self.state_sync.update_state("data_sources", snapshot)
if self._loop and self._loop.is_running():
asyncio.run_coroutine_threadsafe(
self.broadcast(
{
"type": "data_sources_update",
"data_sources": snapshot,
},
),
self._loop,
)
@property
def state(self) -> Dict[str, Any]:
return self.state_sync.state
@staticmethod
def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool:
return news_domain.news_rows_need_enrichment(rows)
def _news_service_url(self) -> str | None:
"""Return configured news-service base URL, if any."""
candidate = self.config.get("news_service_url") or os.getenv(
"NEWS_SERVICE_URL",
"",
)
value = str(candidate or "").strip()
return value or None
def _trading_service_url(self) -> str | None:
"""Return configured trading-service base URL, if any."""
candidate = self.config.get("trading_service_url") or os.getenv(
"TRADING_SERVICE_URL",
"",
)
value = str(candidate or "").strip()
return value or None
async def _call_news_service(
self,
action: str,
callback: Callable[[NewsServiceClient], Any],
) -> Any | None:
"""Call news-service when configured, otherwise return None."""
service_url = self._news_service_url()
if not service_url:
return None
try:
async with NewsServiceClient(service_url) as client:
return await callback(client)
except Exception as exc:
logger.warning("news-service %s failed: %s", action, exc)
return None
async def _call_trading_service(
self,
action: str,
callback: Callable[[TradingServiceClient], Any],
) -> Any | None:
"""Call trading-service when configured, otherwise return None."""
service_url = self._trading_service_url()
if not service_url:
return None
try:
async with TradingServiceClient(service_url) as client:
return await callback(client)
except Exception as exc:
logger.warning("trading-service %s failed: %s", action, exc)
return None
async def handle_client(self, websocket: ServerConnection):
"""Handle WebSocket client connection"""
async with self.lock:
self.connected_clients.add(websocket)
await self._send_initial_state(websocket)
await self._handle_client_messages(websocket)
async with self.lock:
self.connected_clients.discard(websocket)
async def _send_initial_state(self, websocket: ServerConnection):
try:
logger.info("[Gateway] Sending initial state to client...")
state_payload = self.state_sync.get_initial_state_payload(
include_dashboard=True,
)
state_payload["data_sources"] = (
self._provider_router.get_usage_snapshot()
)
# Include market status in initial state
state_payload[
"market_status"
] = self.market_service.get_market_status()
# Include live returns if session is active
if self.storage.is_live_session_active:
live_returns = self.storage.get_live_returns()
if "portfolio" in state_payload:
state_payload["portfolio"].update(live_returns)
await websocket.send(
json.dumps(
{"type": "initial_state", "state": state_payload},
ensure_ascii=False,
default=str,
),
)
logger.info("[Gateway] Initial state sent successfully")
except Exception as e:
logger.exception(f"[Gateway] Failed to send initial state: {e}")
# Send error response so client knows something went wrong
try:
await websocket.send(
json.dumps(
{"type": "error", "message": "Failed to load initial state"},
ensure_ascii=False,
),
)
except Exception as e:
logger.warning(f"Failed to send error response to client: {e}")
async def _handle_client_messages(
self,
websocket: ServerConnection,
):
try:
async for message in websocket:
data = json.loads(message)
msg_type = data.get("type", "unknown")
if msg_type == "ping":
await websocket.send(
json.dumps(
{
"type": "pong",
"timestamp": datetime.now().isoformat(),
},
ensure_ascii=False,
),
)
elif msg_type == "get_state":
await self._send_initial_state(websocket)
elif msg_type == "start_backtest":
await self._handle_start_backtest(data)
elif msg_type == "trigger_strategy":
await self._handle_manual_trigger(websocket, data)
elif msg_type == "update_runtime_config":
await self._handle_update_runtime_config(websocket, data)
elif msg_type == "reload_runtime_assets":
await self._handle_reload_runtime_assets()
elif msg_type == "update_watchlist":
await self._handle_update_watchlist(websocket, data)
elif msg_type == "get_agent_skills":
await self._handle_get_agent_skills(websocket, data)
elif msg_type == "get_agent_profile":
await self._handle_get_agent_profile(websocket, data)
elif msg_type == "get_skill_detail":
await self._handle_get_skill_detail(websocket, data)
elif msg_type == "create_agent_local_skill":
await self._handle_create_agent_local_skill(websocket, data)
elif msg_type == "update_agent_local_skill":
await self._handle_update_agent_local_skill(websocket, data)
elif msg_type == "delete_agent_local_skill":
await self._handle_delete_agent_local_skill(websocket, data)
elif msg_type == "remove_agent_skill":
await self._handle_remove_agent_skill(websocket, data)
elif msg_type == "update_agent_skill":
await self._handle_update_agent_skill(websocket, data)
elif msg_type == "get_agent_workspace_file":
await self._handle_get_agent_workspace_file(websocket, data)
elif msg_type == "update_agent_workspace_file":
await self._handle_update_agent_workspace_file(websocket, data)
elif msg_type == "get_stock_history":
await self._handle_get_stock_history(websocket, data)
elif msg_type == "get_stock_explain_events":
await self._handle_get_stock_explain_events(websocket, data)
elif msg_type == "get_stock_news":
await self._handle_get_stock_news(websocket, data)
elif msg_type == "get_stock_news_for_date":
await self._handle_get_stock_news_for_date(websocket, data)
elif msg_type == "get_stock_news_timeline":
await self._handle_get_stock_news_timeline(websocket, data)
elif msg_type == "get_stock_news_categories":
await self._handle_get_stock_news_categories(websocket, data)
elif msg_type == "get_stock_range_explain":
await self._handle_get_stock_range_explain(websocket, data)
elif msg_type == "get_stock_insider_trades":
await self._handle_get_stock_insider_trades(websocket, data)
elif msg_type == "get_stock_story":
await self._handle_get_stock_story(websocket, data)
elif msg_type == "get_stock_similar_days":
await self._handle_get_stock_similar_days(websocket, data)
elif msg_type == "get_stock_technical_indicators":
await self._handle_get_stock_technical_indicators(websocket, data)
elif msg_type == "run_stock_enrich":
await self._handle_run_stock_enrich(websocket, data)
elif msg_type == "get_openclaw_status":
await self._handle_get_openclaw_status(websocket, data)
elif msg_type == "get_openclaw_sessions":
await self._handle_get_openclaw_sessions(websocket, data)
elif msg_type == "get_openclaw_session_detail":
await self._handle_get_openclaw_session_detail(websocket, data)
elif msg_type == "get_openclaw_session_history":
await self._handle_get_openclaw_session_history(websocket, data)
elif msg_type == "get_openclaw_cron":
await self._handle_get_openclaw_cron(websocket, data)
elif msg_type == "get_openclaw_approvals":
await self._handle_get_openclaw_approvals(websocket, data)
elif msg_type == "get_openclaw_agents":
await self._handle_get_openclaw_agents(websocket, data)
elif msg_type == "get_openclaw_agents_presence":
await self._handle_get_openclaw_agents_presence(websocket, data)
elif msg_type == "get_openclaw_skills":
await self._handle_get_openclaw_skills(websocket, data)
elif msg_type == "get_openclaw_models":
await self._handle_get_openclaw_models(websocket, data)
elif msg_type == "get_openclaw_hooks":
await gateway_openclaw_handlers.handle_get_openclaw_hooks(self, websocket, data)
elif msg_type == "get_openclaw_plugins":
await gateway_openclaw_handlers.handle_get_openclaw_plugins(self, websocket, data)
elif msg_type == "get_openclaw_secrets_audit":
await gateway_openclaw_handlers.handle_get_openclaw_secrets_audit(self, websocket, data)
elif msg_type == "get_openclaw_security_audit":
await gateway_openclaw_handlers.handle_get_openclaw_security_audit(self, websocket, data)
elif msg_type == "get_openclaw_daemon_status":
await gateway_openclaw_handlers.handle_get_openclaw_daemon_status(self, websocket, data)
elif msg_type == "get_openclaw_pairing":
await gateway_openclaw_handlers.handle_get_openclaw_pairing(self, websocket, data)
elif msg_type == "get_openclaw_qr":
await gateway_openclaw_handlers.handle_get_openclaw_qr(self, websocket, data)
elif msg_type == "get_openclaw_update_status":
await gateway_openclaw_handlers.handle_get_openclaw_update_status(self, websocket, data)
elif msg_type == "get_openclaw_models_aliases":
await gateway_openclaw_handlers.handle_get_openclaw_models_aliases(self, websocket, data)
elif msg_type == "get_openclaw_models_fallbacks":
await gateway_openclaw_handlers.handle_get_openclaw_models_fallbacks(self, websocket, data)
elif msg_type == "get_openclaw_models_image_fallbacks":
await gateway_openclaw_handlers.handle_get_openclaw_models_image_fallbacks(self, websocket, data)
elif msg_type == "get_openclaw_skill_update":
await gateway_openclaw_handlers.handle_get_openclaw_skill_update(self, websocket, data)
elif msg_type == "get_openclaw_workspace_files":
await gateway_openclaw_handlers.handle_get_openclaw_workspace_files(self, websocket, data)
elif msg_type == "get_openclaw_workspace_file":
await gateway_openclaw_handlers.handle_get_openclaw_workspace_file(self, websocket, data)
elif msg_type == "openclaw_resolve_session":
await gateway_openclaw_handlers.handle_openclaw_resolve_session(self, websocket, data)
elif msg_type == "openclaw_create_session":
await gateway_openclaw_handlers.handle_openclaw_create_session(self, websocket, data)
elif msg_type == "openclaw_send_message":
await gateway_openclaw_handlers.handle_openclaw_send_message(self, websocket, data)
elif msg_type == "openclaw_subscribe_session":
await gateway_openclaw_handlers.handle_openclaw_subscribe_session(self, websocket, data)
elif msg_type == "openclaw_unsubscribe_session":
await gateway_openclaw_handlers.handle_openclaw_unsubscribe_session(self, websocket, data)
elif msg_type == "openclaw_reset_session":
await gateway_openclaw_handlers.handle_openclaw_reset_session(self, websocket, data)
elif msg_type == "openclaw_delete_session":
await gateway_openclaw_handlers.handle_openclaw_delete_session(self, websocket, data)
except websockets.ConnectionClosed:
pass
except json.JSONDecodeError:
pass
finally:
subscriber_map = getattr(self, "_openclaw_session_subscribers", None)
if isinstance(subscriber_map, dict):
subscriber_map.pop(websocket, None)
async def _handle_get_stock_history(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_history(self, websocket, data)
async def _handle_get_stock_explain_events(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_explain_events(self, websocket, data)
async def _handle_get_stock_news(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_news(self, websocket, data)
async def _handle_get_stock_news_for_date(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_news_for_date(self, websocket, data)
async def _handle_get_stock_news_timeline(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_news_timeline(self, websocket, data)
async def _handle_get_stock_news_categories(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_news_categories(self, websocket, data)
async def _handle_get_stock_range_explain(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_range_explain(self, websocket, data)
async def _handle_get_stock_insider_trades(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_insider_trades(self, websocket, data)
async def _handle_get_stock_story(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_story(self, websocket, data)
async def _handle_get_stock_similar_days(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_similar_days(self, websocket, data)
async def _handle_get_stock_technical_indicators(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_get_stock_technical_indicators(self, websocket, data)
async def _handle_run_stock_enrich(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
await gateway_stock_handlers.handle_run_stock_enrich(self, websocket, data)
async def _handle_start_backtest(self, data: Dict[str, Any]):
if not self.is_backtest:
return
dates = data.get("dates", [])
if dates and self._backtest_task is None:
task = asyncio.create_task(
self._run_backtest_dates(dates),
)
task.add_done_callback(self._handle_backtest_exception)
self._backtest_task = task
async def _handle_manual_trigger(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
"""Run one live trading cycle on demand."""
if self.is_backtest:
await websocket.send(
json.dumps(
{
"type": "error",
"message": "Manual trigger is only available in live mode.",
},
ensure_ascii=False,
),
)
return
if (
self._cycle_lock.locked()
or (
self._manual_cycle_task is not None
and not self._manual_cycle_task.done()
)
):
await websocket.send(
json.dumps(
{
"type": "error",
"message": "A trading cycle is already running.",
},
ensure_ascii=False,
),
)
await self.state_sync.on_system_message("已有任务在运行,已忽略手动触发")
return
requested_date = data.get("date")
await self.state_sync.on_system_message("收到手动触发请求,准备开始新一轮分析与决策")
task = asyncio.create_task(
self.on_strategy_trigger(
date=requested_date or datetime.now().strftime("%Y-%m-%d"),
),
)
task.add_done_callback(self._handle_manual_cycle_exception)
self._manual_cycle_task = task
async def _handle_reload_runtime_assets(self):
await gateway_admin_handlers.handle_reload_runtime_assets(self)
async def _handle_update_runtime_config(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_update_runtime_config(self, websocket, data)
async def _handle_update_watchlist(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_update_watchlist(self, websocket, data)
async def _handle_get_agent_skills(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_get_agent_skills(self, websocket, data)
async def _handle_get_agent_profile(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_get_agent_profile(self, websocket, data)
async def _handle_get_skill_detail(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_get_skill_detail(self, websocket, data)
async def _handle_create_agent_local_skill(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_create_agent_local_skill(self, websocket, data)
async def _handle_update_agent_local_skill(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_update_agent_local_skill(self, websocket, data)
async def _handle_delete_agent_local_skill(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_delete_agent_local_skill(self, websocket, data)
async def _handle_remove_agent_skill(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_remove_agent_skill(self, websocket, data)
async def _handle_update_agent_skill(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_update_agent_skill(self, websocket, data)
async def _handle_get_agent_workspace_file(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_get_agent_workspace_file(self, websocket, data)
async def _handle_update_agent_workspace_file(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_admin_handlers.handle_update_agent_workspace_file(self, websocket, data)
async def _handle_get_openclaw_status(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_status(self, websocket, data)
async def _handle_get_openclaw_sessions(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_sessions(self, websocket, data)
async def _handle_get_openclaw_session_detail(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_session_detail(self, websocket, data)
async def _handle_get_openclaw_session_history(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_session_history(self, websocket, data)
async def _handle_get_openclaw_cron(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_cron(self, websocket, data)
async def _handle_get_openclaw_approvals(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_approvals(self, websocket, data)
async def _handle_get_openclaw_agents(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_agents(self, websocket, data)
async def _handle_get_openclaw_agents_presence(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_agents_presence(self, websocket, data)
async def _handle_get_openclaw_skills(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_skills(self, websocket, data)
async def _handle_get_openclaw_models(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_models(self, websocket, data)
async def _handle_get_openclaw_workspace_files(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
await gateway_openclaw_handlers.handle_get_openclaw_workspace_files(self, websocket, data)
@staticmethod
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
return gateway_runtime_support.normalize_watchlist(raw_tickers)
@staticmethod
def _normalize_agent_workspace_filename(raw_name: Any) -> Optional[str]:
return gateway_runtime_support.normalize_agent_workspace_filename(
raw_name,
allowlist=EDITABLE_AGENT_WORKSPACE_FILES,
)
def _apply_runtime_config(
self,
runtime_config: Dict[str, Any],
) -> Dict[str, Any]:
return gateway_runtime_support.apply_runtime_config(self, runtime_config)
def _sync_runtime_state(self) -> None:
gateway_runtime_support.sync_runtime_state(self)
def _schedule_watchlist_market_store_refresh(
self,
tickers: List[str],
) -> None:
gateway_cycle_support.schedule_watchlist_market_store_refresh(self, tickers)
async def _refresh_market_store_for_watchlist(
self,
tickers: List[str],
) -> None:
await gateway_cycle_support.refresh_market_store_for_watchlist(self, tickers)
async def broadcast(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients"""
if not self.connected_clients:
return
message_json = json.dumps(message, ensure_ascii=False, default=str)
async with self.lock:
tasks = [
self._send_to_client(client, message_json)
for client in self.connected_clients.copy()
]
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _send_to_client(
self,
client: ServerConnection,
message: str,
):
try:
await client.send(message)
except websockets.ConnectionClosed:
async with self.lock:
self.connected_clients.discard(client)
async def _market_status_monitor(self):
await gateway_cycle_support.market_status_monitor(self)
async def _update_and_broadcast_live_returns(self):
await gateway_cycle_support.update_and_broadcast_live_returns(self)
async def on_strategy_trigger(self, date: str):
await gateway_cycle_support.on_strategy_trigger(self, date)
async def on_heartbeat_trigger(self, date: str):
await gateway_cycle_support.on_heartbeat_trigger(self, date)
async def _run_backtest_cycle(self, date: str, tickers: List[str]):
await gateway_cycle_support.run_backtest_cycle(self, date, tickers)
async def _run_live_cycle(self, date: str, tickers: List[str]):
await gateway_cycle_support.run_live_cycle(self, date, tickers)
async def _finalize_cycle(self, date: str):
await gateway_cycle_support.finalize_cycle(self, date)
async def _get_market_caps(
self,
tickers: List[str],
date: str,
) -> Dict[str, float]:
return await gateway_cycle_support.get_market_caps(self, tickers, date)
async def _broadcast_portfolio_updates(
self,
result: Dict[str, Any],
prices: Dict[str, float],
):
await gateway_cycle_support.broadcast_portfolio_updates(self, result, prices)
def _save_cycle_results(
self,
result: Dict[str, Any],
date: str,
prices: Dict[str, float],
settlement_result: Optional[Dict[str, Any]] = None,
):
gateway_cycle_support.save_cycle_results(
self,
result,
date,
prices,
settlement_result,
)
async def _run_backtest_dates(self, dates: List[str]):
await gateway_cycle_support.run_backtest_dates(self, dates)
def _handle_backtest_exception(self, task: asyncio.Task):
gateway_cycle_support.handle_backtest_exception(self, task)
def _handle_manual_cycle_exception(self, task: asyncio.Task):
gateway_cycle_support.handle_manual_cycle_exception(self, task)
def set_backtest_dates(self, dates: List[str]):
gateway_cycle_support.set_backtest_dates(self, dates)
def stop(self):
gateway_cycle_support.stop_gateway(self)

View File

@@ -0,0 +1,426 @@
# -*- coding: utf-8 -*-
"""Runtime/workspace/skills handlers extracted from the main Gateway module.
Deprecated note:
Agent/workspace/skill read-write operations are being migrated to
agent_service REST endpoints. These websocket handlers remain as a
compatibility fallback and should not be considered the primary control
plane path for frontend reads/writes.
"""
from __future__ import annotations
import json
from datetime import datetime
from typing import Any
from backend.agents.agent_workspace import load_agent_workspace_config
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import load_agent_profiles
from backend.config.bootstrap_config import (
get_bootstrap_config_for_run,
resolve_runtime_config,
update_bootstrap_values_for_run,
)
from backend.data.market_ingest import ingest_symbols
from backend.llm.models import get_agent_model_info
async def handle_reload_runtime_assets(gateway: Any) -> None:
config_name = gateway.config.get("config_name", "default")
runtime_config = resolve_runtime_config(
project_root=gateway._project_root,
config_name=config_name,
enable_memory=gateway.config.get("enable_memory", False),
schedule_mode=gateway.config.get("schedule_mode", "daily"),
interval_minutes=gateway.config.get("interval_minutes", 60),
trigger_time=gateway.config.get("trigger_time", "09:30"),
)
result = gateway.pipeline.reload_runtime_assets(runtime_config=runtime_config)
runtime_updates = gateway._apply_runtime_config(runtime_config)
await gateway.state_sync.on_system_message("Runtime assets reloaded.")
await gateway.broadcast({"type": "runtime_assets_reloaded", **result, **runtime_updates})
async def handle_update_runtime_config(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
updates: dict[str, Any] = {}
schedule_mode = str(data.get("schedule_mode", "")).strip().lower()
if schedule_mode:
if schedule_mode not in {"daily", "intraday"}:
await websocket.send(json.dumps({"type": "error", "message": "schedule_mode must be 'daily' or 'intraday'."}, ensure_ascii=False))
return
updates["schedule_mode"] = schedule_mode
interval_minutes = data.get("interval_minutes")
if interval_minutes is not None:
try:
parsed_interval = int(interval_minutes)
except (TypeError, ValueError):
parsed_interval = 0
if parsed_interval <= 0:
await websocket.send(json.dumps({"type": "error", "message": "interval_minutes must be a positive integer."}, ensure_ascii=False))
return
updates["interval_minutes"] = parsed_interval
trigger_time = data.get("trigger_time")
if trigger_time is not None:
raw_trigger = str(trigger_time).strip()
if raw_trigger and raw_trigger != "now":
try:
datetime.strptime(raw_trigger, "%H:%M")
except ValueError:
await websocket.send(json.dumps({"type": "error", "message": "trigger_time must use HH:MM or 'now'."}, ensure_ascii=False))
return
updates["trigger_time"] = raw_trigger or "09:30"
max_comm_cycles = data.get("max_comm_cycles")
if max_comm_cycles is not None:
try:
parsed_cycles = int(max_comm_cycles)
except (TypeError, ValueError):
parsed_cycles = 0
if parsed_cycles <= 0:
await websocket.send(json.dumps({"type": "error", "message": "max_comm_cycles must be a positive integer."}, ensure_ascii=False))
return
updates["max_comm_cycles"] = parsed_cycles
initial_cash = data.get("initial_cash")
if initial_cash is not None:
try:
parsed_initial_cash = float(initial_cash)
except (TypeError, ValueError):
parsed_initial_cash = 0.0
if parsed_initial_cash <= 0:
await websocket.send(json.dumps({"type": "error", "message": "initial_cash must be a positive number."}, ensure_ascii=False))
return
updates["initial_cash"] = parsed_initial_cash
margin_requirement = data.get("margin_requirement")
if margin_requirement is not None:
try:
parsed_margin_requirement = float(margin_requirement)
except (TypeError, ValueError):
parsed_margin_requirement = -1.0
if parsed_margin_requirement < 0:
await websocket.send(json.dumps({"type": "error", "message": "margin_requirement must be a non-negative number."}, ensure_ascii=False))
return
updates["margin_requirement"] = parsed_margin_requirement
enable_memory = data.get("enable_memory")
if enable_memory is not None:
updates["enable_memory"] = bool(enable_memory)
if not updates:
await websocket.send(json.dumps({"type": "error", "message": "No runtime settings were provided."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=gateway._project_root,
config_name=config_name,
updates=updates,
)
await gateway.state_sync.on_system_message("运行时调度配置已保存,正在热更新")
await handle_reload_runtime_assets(gateway)
async def handle_update_watchlist(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
tickers = gateway._normalize_watchlist(data.get("tickers"))
if not tickers:
await websocket.send(json.dumps({"type": "error", "message": "update_watchlist requires at least one valid ticker."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=gateway._project_root,
config_name=config_name,
updates={"tickers": tickers},
)
await gateway.state_sync.on_system_message(f"Watchlist updated: {', '.join(tickers)}")
await gateway.broadcast({"type": "watchlist_updated", "config_name": config_name, "tickers": tickers})
await handle_reload_runtime_assets(gateway)
gateway._schedule_watchlist_market_store_refresh(tickers)
async def handle_get_agent_skills(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
if not agent_id:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_skills requires agent_id."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=config_name, agent_id=agent_id, default_skills=[]))
enabled = set(agent_config.enabled_skills)
disabled = set(agent_config.disabled_skills)
payload = []
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id):
if item.skill_name in disabled:
status = "disabled"
elif item.skill_name in enabled:
status = "enabled"
elif item.skill_name in resolved_skills:
status = "active"
else:
status = "available"
payload.append({
"skill_name": item.skill_name,
"name": item.name,
"description": item.description,
"version": item.version,
"source": item.source,
"tools": item.tools,
"status": status,
})
await websocket.send(json.dumps({
"type": "agent_skills_loaded",
"config_name": config_name,
"agent_id": agent_id,
"skills": payload,
}, ensure_ascii=False))
async def handle_get_agent_profile(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
if not agent_id:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_profile requires agent_id."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
profiles = load_agent_profiles()
profile = profiles.get(agent_id, {})
bootstrap = get_bootstrap_config_for_run(gateway._project_root, config_name)
override = bootstrap.agent_override(agent_id)
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
if not isinstance(active_tool_groups, list):
active_tool_groups = []
disabled_tool_groups = agent_config.disabled_tool_groups
if disabled_tool_groups:
disabled_set = set(disabled_tool_groups)
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
default_skills = profile.get("skills", [])
if not isinstance(default_skills, list):
default_skills = []
resolved_skills = skills_manager.resolve_agent_skill_names(
config_name=config_name,
agent_id=agent_id,
default_skills=default_skills,
)
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
model_name, model_provider = get_agent_model_info(agent_id)
await websocket.send(json.dumps({
"type": "agent_profile_loaded",
"config_name": config_name,
"agent_id": agent_id,
"profile": {
"model_name": model_name,
"model_provider": model_provider,
"prompt_files": prompt_files,
"default_skills": default_skills,
"resolved_skills": resolved_skills,
"active_tool_groups": active_tool_groups,
"disabled_tool_groups": disabled_tool_groups,
"enabled_skills": agent_config.enabled_skills,
"disabled_skills": agent_config.disabled_skills,
},
}, ensure_ascii=False))
async def handle_get_skill_detail(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "get_skill_detail requires skill_name."}, ensure_ascii=False))
return
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
if agent_id:
config_name = gateway.config.get("config_name", "default")
detail = skills_manager.load_agent_skill_document(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
else:
detail = skills_manager.load_skill_document(skill_name)
except FileNotFoundError:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
return
await websocket.send(json.dumps({
"type": "skill_detail_loaded",
"agent_id": agent_id,
"skill": detail,
}, ensure_ascii=False))
async def handle_create_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "create_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.create_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
except (ValueError, FileExistsError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Created local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_created", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
async def handle_update_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
content = data.get("content")
if not agent_id or not skill_name or not isinstance(content, str):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_local_skill requires agent_id, skill_name, and string content."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.update_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name, content=content)
except (ValueError, FileNotFoundError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Updated local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_updated", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
async def handle_delete_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "delete_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.delete_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
skills_manager.forget_agent_skill_overrides(config_name=config_name, agent_id=agent_id, skill_names=[skill_name])
except (ValueError, FileNotFoundError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Deleted local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_deleted", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_remove_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "remove_agent_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
skill_names = {
item.skill_name
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
if item.source != "local"
}
if skill_name not in skill_names:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown shared skill: {skill_name}"}, ensure_ascii=False))
return
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
await gateway.state_sync.on_system_message(f"Removed shared skill {skill_name} from {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_skill_removed", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_update_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
enabled = data.get("enabled")
if not agent_id or not skill_name or not isinstance(enabled, bool):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
skill_names = {item.skill_name for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)}
if skill_name not in skill_names:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
return
if enabled:
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, enable=[skill_name])
await gateway.state_sync.on_system_message(f"Enabled skill {skill_name} for {agent_id}")
else:
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
await gateway.state_sync.on_system_message(f"Disabled skill {skill_name} for {agent_id}")
await websocket.send(json.dumps({
"type": "agent_skill_updated",
"agent_id": agent_id,
"skill_name": skill_name,
"enabled": enabled,
}, ensure_ascii=False))
await gateway._handle_reload_runtime_assets()
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_get_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
if not agent_id or not filename:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_workspace_file requires agent_id and supported filename."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
asset_dir.mkdir(parents=True, exist_ok=True)
path = asset_dir / filename
content = path.read_text(encoding="utf-8") if path.exists() else ""
await websocket.send(json.dumps({
"type": "agent_workspace_file_loaded",
"config_name": config_name,
"agent_id": agent_id,
"filename": filename,
"content": content,
}, ensure_ascii=False))
async def handle_update_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
content = data.get("content")
if not agent_id or not filename or not isinstance(content, str):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_workspace_file requires agent_id, supported filename, and string content."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
asset_dir.mkdir(parents=True, exist_ok=True)
path = asset_dir / filename
path.write_text(content, encoding="utf-8")
await gateway.state_sync.on_system_message(f"Updated {filename} for {agent_id}")
await websocket.send(json.dumps({"type": "agent_workspace_file_updated", "agent_id": agent_id, "filename": filename}, ensure_ascii=False))
await gateway._handle_reload_runtime_assets()
await handle_get_agent_workspace_file(gateway, websocket, {"agent_id": agent_id, "filename": filename})

View File

@@ -0,0 +1,372 @@
# -*- coding: utf-8 -*-
"""Cycle and monitoring helpers extracted from the main Gateway module."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from backend.data.market_ingest import ingest_symbols, refresh_news_for_symbols
from backend.domains import trading as trading_domain
from backend.utils.msg_adapter import FrontendAdapter
logger = logging.getLogger(__name__)
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
"""Kick off a non-blocking market-store refresh for an updated watchlist."""
if not tickers:
return
if gateway._watchlist_ingest_task and not gateway._watchlist_ingest_task.done():
gateway._watchlist_ingest_task.cancel()
gateway._watchlist_ingest_task = asyncio.create_task(
refresh_market_store_for_watchlist(gateway, tickers),
)
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
"""Refresh the long-lived market store after a watchlist update."""
try:
await gateway.state_sync.on_system_message(
f"正在同步自选股市场数据: {', '.join(tickers)}",
)
results = await asyncio.to_thread(
ingest_symbols,
tickers,
mode="incremental",
)
summary = ", ".join(
f"{item['symbol']} prices={item['prices']} news={item['news']}"
for item in results
)
await gateway.state_sync.on_system_message(
f"自选股市场数据已同步: {summary}",
)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning("Watchlist market store refresh failed: %s", exc)
await gateway.state_sync.on_system_message(
f"自选股市场数据同步失败: {exc}",
)
async def market_status_monitor(gateway: Any) -> None:
"""Periodically check and broadcast market status changes."""
while True:
try:
await gateway.market_service.check_and_broadcast_market_status()
status = gateway.market_service.get_market_status()
if status["status"] == "open" and not gateway.storage.is_live_session_active:
gateway.storage.start_live_session()
summary = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state).get("summary") or {}
gateway._session_start_portfolio_value = summary.get(
"totalAssetValue",
gateway.storage.initial_cash,
)
logger.info(
"Session start portfolio: $%s",
f"{gateway._session_start_portfolio_value:,.2f}",
)
elif status["status"] != "open" and gateway.storage.is_live_session_active:
gateway.storage.end_live_session()
gateway._session_start_portfolio_value = None
if gateway.storage.is_live_session_active:
await update_and_broadcast_live_returns(gateway)
await asyncio.sleep(60)
except asyncio.CancelledError:
break
except Exception as exc:
logger.error("Market status monitor error: %s", exc)
await asyncio.sleep(60)
async def update_and_broadcast_live_returns(gateway: Any) -> None:
"""Calculate and broadcast live returns for current session."""
if not gateway.storage.is_live_session_active:
return
prices = gateway.market_service.get_all_prices()
if not prices or not any(p > 0 for p in prices.values()):
return
state = gateway.storage.load_internal_state()
equity_history = state.get("equity_history", [])
baseline_history = state.get("baseline_history", [])
baseline_vw_history = state.get("baseline_vw_history", [])
momentum_history = state.get("momentum_history", [])
current_equity = equity_history[-1]["v"] if equity_history else None
current_baseline = baseline_history[-1]["v"] if baseline_history else None
current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None
current_momentum = momentum_history[-1]["v"] if momentum_history else None
point = gateway.storage.update_live_returns(
current_equity=current_equity,
current_baseline=current_baseline,
current_baseline_vw=current_baseline_vw,
current_momentum=current_momentum,
)
if point:
live_returns = gateway.storage.get_live_returns()
await gateway.broadcast(
{
"type": "team_summary",
"equity_return": live_returns["equity_return"],
"baseline_return": live_returns["baseline_return"],
"baseline_vw_return": live_returns["baseline_vw_return"],
"momentum_return": live_returns["momentum_return"],
},
)
async def on_strategy_trigger(gateway: Any, date: str) -> None:
"""Handle trading cycle trigger."""
if gateway._cycle_lock.locked():
logger.warning("Trading cycle already running, skipping trigger for %s", date)
await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}")
return
async with gateway._cycle_lock:
logger.info("Strategy triggered for %s", date)
tickers = gateway.config.get("tickers", [])
if gateway.is_backtest:
await run_backtest_cycle(gateway, date, tickers)
else:
await run_live_cycle(gateway, date, tickers)
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
"""Run lightweight heartbeat check for all analysts."""
logger.info("[Heartbeat] Running heartbeat check for %s", date)
analysts = gateway.pipeline._all_analysts()
for analyst in analysts:
try:
logger.debug(
"[Heartbeat] No heartbeat configured for %s, skipping",
analyst.name,
)
except Exception as exc:
logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True)
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
gateway.market_service.set_backtest_date(date)
await gateway.market_service.emit_market_open()
await gateway.state_sync.on_cycle_start(date)
prices = gateway.market_service.get_open_prices()
close_prices = gateway.market_service.get_close_prices()
market_caps = await get_market_caps(gateway, tickers, date)
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=date,
prices=prices,
close_prices=close_prices,
market_caps=market_caps,
)
await gateway.market_service.emit_market_close()
settlement_result = result.get("settlement_result")
save_cycle_results(gateway, result, date, close_prices, settlement_result)
await broadcast_portfolio_updates(gateway, result, close_prices)
await finalize_cycle(gateway, date)
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
trading_date = gateway.market_service.get_live_trading_date()
logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)
try:
news_refresh = await asyncio.to_thread(
refresh_news_for_symbols,
tickers,
end_date=trading_date,
store=gateway.storage.market_store,
)
logger.info(
"News refresh complete: %s",
", ".join(
f"{item['symbol']} news={item['news']}"
for item in news_refresh
) or "no symbols",
)
except Exception as exc:
logger.warning("Live cycle news refresh failed: %s", exc)
await gateway.state_sync.on_cycle_start(trading_date)
market_caps = await get_market_caps(gateway, tickers, trading_date)
schedule_mode = gateway.config.get("schedule_mode", "daily")
market_status = gateway.market_service.get_market_status()
current_prices = gateway.market_service.get_all_prices()
if schedule_mode == "intraday":
execute_decisions = market_status.get("status") == "open"
if execute_decisions:
await gateway.state_sync.on_system_message("定时任务触发:当前处于交易时段,本轮将执行交易决策")
else:
await gateway.state_sync.on_system_message("定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易")
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
prices=current_prices,
market_caps=market_caps,
execute_decisions=execute_decisions,
)
close_prices = current_prices
else:
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
market_caps=market_caps,
get_open_prices_fn=gateway.market_service.wait_for_open_prices,
get_close_prices_fn=gateway.market_service.wait_for_close_prices,
)
close_prices = gateway.market_service.get_all_prices()
settlement_result = result.get("settlement_result")
save_cycle_results(gateway, result, trading_date, close_prices, settlement_result)
await broadcast_portfolio_updates(gateway, result, close_prices)
await finalize_cycle(gateway, trading_date)
async def finalize_cycle(gateway: Any, date: str) -> None:
dashboard_snapshot = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state)
summary = dashboard_snapshot.get("summary") or {}
if gateway.storage.is_live_session_active:
summary.update(gateway.storage.get_live_returns())
await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary)
leaderboard = dashboard_snapshot.get("leaderboard") or []
if leaderboard:
await gateway.state_sync.on_leaderboard_update(leaderboard)
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
market_caps: dict[str, float] = {}
for ticker in tickers:
try:
market_cap = None
response = await gateway._call_trading_service(
f"get_market_cap for {ticker}",
lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
)
if response is not None:
market_cap = response.get("market_cap")
if market_cap is None:
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
market_cap = payload.get("market_cap")
market_caps[ticker] = market_cap if market_cap else 1e9
except Exception as exc:
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
market_caps[ticker] = 1e9
return market_caps
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
portfolio = result.get("portfolio", {})
if portfolio:
holdings = FrontendAdapter.build_holdings(portfolio, prices)
if holdings:
await gateway.state_sync.on_holdings_update(holdings)
stats = FrontendAdapter.build_stats(portfolio, prices)
if stats:
await gateway.state_sync.on_stats_update(stats)
executed_trades = result.get("executed_trades", [])
if executed_trades:
await gateway.state_sync.on_trades_executed(executed_trades)
def save_cycle_results(
gateway: Any,
result: dict[str, Any],
date: str,
prices: dict[str, float],
settlement_result: dict[str, Any] | None = None,
) -> None:
portfolio = result.get("portfolio", {})
executed_trades = result.get("executed_trades", [])
baseline_values = settlement_result.get("baseline_values") if settlement_result else None
if portfolio:
gateway.storage.update_dashboard_after_cycle(
portfolio=portfolio,
prices=prices,
date=date,
executed_trades=executed_trades,
baseline_values=baseline_values,
)
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
gateway.state_sync.set_backtest_dates(dates)
await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
try:
for date in dates:
await gateway.on_strategy_trigger(date=date)
await asyncio.sleep(0.1)
await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
except Exception as exc:
error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
logger.error(error_msg, exc_info=True)
asyncio.create_task(gateway.state_sync.on_system_message(error_msg))
raise
finally:
gateway._backtest_task = None
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
try:
task.result()
except asyncio.CancelledError:
logger.info("Backtest task was cancelled")
except Exception as exc:
logger.error("Backtest task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
gateway._manual_cycle_task = None
try:
task.result()
except asyncio.CancelledError:
logger.info("Manual cycle task was cancelled")
except Exception as exc:
logger.error("Manual cycle task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
gateway.state_sync.set_backtest_dates(dates)
if dates:
gateway._backtest_start_date = dates[0]
gateway._backtest_end_date = dates[-1]
def stop_gateway(gateway: Any) -> None:
gateway.state_sync.save_state()
gateway.market_service.stop()
if gateway._backtest_task:
gateway._backtest_task.cancel()
if gateway._market_status_task:
gateway._market_status_task.cancel()
if gateway._watchlist_ingest_task:
gateway._watchlist_ingest_task.cancel()
# Close OpenClaw WebSocket connection
if gateway._openclaw_ws:
import asyncio
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.create_task(gateway._openclaw_ws.disconnect())
else:
loop.run_until_complete(gateway._openclaw_ws.disconnect())
except Exception:
pass

View File

@@ -0,0 +1,534 @@
# -*- coding: utf-8 -*-
"""OpenClaw WebSocket handlers — gateway calls OpenClaw Gateway via WebSocket."""
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from backend.services.gateway import Gateway
logger = logging.getLogger(__name__)
def _ensure_session_bridge(gateway) -> None:
"""Forward OpenClaw session events into 大时代 frontend websockets."""
if getattr(gateway, "_openclaw_session_bridge_ready", False):
return
async def _forward(event) -> None:
payload = event.payload or {}
session_key = str(payload.get("sessionKey") or payload.get("key") or "").strip()
if not session_key:
return
subscriber_map = getattr(gateway, "_openclaw_session_subscribers", {})
targets = [
ws
for ws, session_keys in list(subscriber_map.items())
if session_key in session_keys
]
if not targets:
return
message = json.dumps(
{
"type": "openclaw_session_event",
"event": event.event,
"session_key": session_key,
"payload": payload,
}
)
stale = []
for ws in targets:
try:
await ws.send(message)
except Exception:
stale.append(ws)
for ws in stale:
try:
subscriber_map.pop(ws, None)
except Exception:
pass
def _handler(event) -> None:
try:
import asyncio
asyncio.create_task(_forward(event))
except Exception as exc:
logger.debug("OpenClaw session bridge skipped event: %s", exc)
client = _get_ws_client(gateway)
client.add_event_handler(_handler)
gateway._openclaw_session_bridge_ready = True
gateway._openclaw_session_bridge_handler = _handler
if not hasattr(gateway, "_openclaw_session_subscribers"):
gateway._openclaw_session_subscribers = {}
def _get_ws_client(gateway) -> "OpenClawWebSocketClient":
"""Get the OpenClaw WebSocket client from gateway."""
from shared.client.openclaw_websocket_client import OpenClawWebSocketClient
client = gateway._openclaw_ws
if client is None:
raise RuntimeError("OpenClaw Gateway not connected")
return client
async def _ws_call(gateway, method: str, params: dict | None = None) -> dict:
"""Call OpenClaw Gateway via WebSocket and return result."""
try:
client = _get_ws_client(gateway)
return await client.call_method(method, params)
except Exception as exc:
logger.warning("OpenClaw Gateway call failed for %s: %s", method, exc)
return {"error": str(exc)[:200]}
async def handle_get_openclaw_status(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "status")
await websocket.send(json.dumps({"type": "openclaw_status_loaded", "data": result}))
async def handle_get_openclaw_sessions(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "sessions.list", {"limit": 50, "includeLastMessage": True})
await websocket.send(json.dumps({"type": "openclaw_sessions_loaded", "data": result}))
async def handle_get_openclaw_session_detail(gateway, websocket, data: dict) -> None:
session_key = data.get("session_key", "")
result = await _ws_call(gateway, "sessions.list", {"limit": 200, "includeLastMessage": True})
session = None
if isinstance(result, dict):
for item in result.get("sessions", []) or []:
if not isinstance(item, dict):
continue
if item.get("key") == session_key or item.get("sessionKey") == session_key:
session = item
break
await websocket.send(json.dumps({
"type": "openclaw_session_detail_loaded",
"data": {"session": session, "error": None if session else f"session '{session_key}' not found"},
"session_key": session_key,
}))
async def handle_get_openclaw_session_history(gateway, websocket, data: dict) -> None:
session_key = data.get("session_key", "")
limit = data.get("limit", 20)
try:
from backend.services.openclaw_cli import OpenClawCliService
result = OpenClawCliService().get_session_history_model(session_key, limit=limit)
payload = {
"session_key": result.session_key,
"session_id": result.session_id,
"history": result.events,
"events": result.events,
"raw_text": result.raw_text,
}
except Exception as exc:
payload = {"error": str(exc)[:200], "history": []}
await websocket.send(json.dumps({
"type": "openclaw_session_history_loaded",
"data": payload,
"session_key": session_key,
}))
async def handle_openclaw_resolve_session(gateway, websocket, data: dict) -> None:
params = {}
agent_id = str(data.get("agent_id") or "").strip()
label = str(data.get("label") or "").strip()
channel = str(data.get("channel") or "").strip()
if agent_id:
params["agentId"] = agent_id
if label:
params["label"] = label
if channel:
params["channel"] = channel
params["includeGlobal"] = bool(data.get("include_global", True))
result = await _ws_call(gateway, "sessions.resolve", params)
await websocket.send(json.dumps({"type": "openclaw_session_resolved", "data": result}))
async def handle_openclaw_create_session(gateway, websocket, data: dict) -> None:
params = {}
agent_id = str(data.get("agent_id") or "").strip()
label = str(data.get("label") or "").strip()
model = str(data.get("model") or "").strip()
initial_message = str(data.get("initial_message") or "").strip()
if agent_id:
params["agentId"] = agent_id
if label:
params["label"] = label
if model:
params["model"] = model
if initial_message:
params["message"] = initial_message
result = await _ws_call(gateway, "sessions.create", params)
await websocket.send(json.dumps({"type": "openclaw_session_created", "data": result}))
async def handle_openclaw_send_message(gateway, websocket, data: dict) -> None:
session_key = str(data.get("session_key") or "").strip()
message = str(data.get("message") or "").strip()
thinking = str(data.get("thinking") or "").strip()
if not session_key or not message:
await websocket.send(
json.dumps(
{
"type": "openclaw_message_sent",
"data": {"error": "session_key and message are required"},
}
)
)
return
params = {"key": session_key, "message": message}
if thinking:
params["thinking"] = thinking
result = await _ws_call(gateway, "sessions.send", params)
await websocket.send(
json.dumps(
{
"type": "openclaw_message_sent",
"data": result,
"session_key": session_key,
}
)
)
async def handle_openclaw_subscribe_session(gateway, websocket, data: dict) -> None:
session_key = str(data.get("session_key") or "").strip()
if not session_key:
await websocket.send(
json.dumps(
{
"type": "openclaw_session_subscribed",
"data": {"error": "session_key is required"},
}
)
)
return
_ensure_session_bridge(gateway)
result = await _ws_call(gateway, "sessions.messages.subscribe", {"key": session_key})
if not isinstance(result, dict) or not result.get("error"):
subscriber_map = getattr(gateway, "_openclaw_session_subscribers", {})
subscriber_map.setdefault(websocket, set()).add(session_key)
gateway._openclaw_session_subscribers = subscriber_map
await websocket.send(
json.dumps(
{
"type": "openclaw_session_subscribed",
"data": result,
"session_key": session_key,
}
)
)
async def handle_openclaw_unsubscribe_session(gateway, websocket, data: dict) -> None:
session_key = str(data.get("session_key") or "").strip()
if not session_key:
await websocket.send(
json.dumps(
{
"type": "openclaw_session_unsubscribed",
"data": {"error": "session_key is required"},
}
)
)
return
result = await _ws_call(gateway, "sessions.messages.unsubscribe", {"key": session_key})
subscriber_map = getattr(gateway, "_openclaw_session_subscribers", {})
session_keys = subscriber_map.get(websocket)
if isinstance(session_keys, set):
session_keys.discard(session_key)
if not session_keys:
subscriber_map.pop(websocket, None)
gateway._openclaw_session_subscribers = subscriber_map
await websocket.send(
json.dumps(
{
"type": "openclaw_session_unsubscribed",
"data": result,
"session_key": session_key,
}
)
)
async def handle_openclaw_reset_session(gateway, websocket, data: dict) -> None:
session_key = str(data.get("session_key") or "").strip()
if not session_key:
await websocket.send(
json.dumps(
{
"type": "openclaw_session_reset",
"data": {"error": "session_key is required"},
}
)
)
return
result = await _ws_call(gateway, "sessions.reset", {"key": session_key})
await websocket.send(
json.dumps(
{
"type": "openclaw_session_reset",
"data": result,
"session_key": session_key,
}
)
)
async def handle_openclaw_delete_session(gateway, websocket, data: dict) -> None:
session_key = str(data.get("session_key") or "").strip()
if not session_key:
await websocket.send(
json.dumps(
{
"type": "openclaw_session_deleted",
"data": {"error": "session_key is required"},
}
)
)
return
result = await _ws_call(gateway, "sessions.delete", {"key": session_key})
await websocket.send(
json.dumps(
{
"type": "openclaw_session_deleted",
"data": result,
"session_key": session_key,
}
)
)
async def handle_get_openclaw_cron(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "cron.list")
await websocket.send(json.dumps({"type": "openclaw_cron_loaded", "data": result}))
async def handle_get_openclaw_approvals(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "exec.approvals.get")
await websocket.send(json.dumps({"type": "openclaw_approvals_loaded", "data": result}))
async def handle_get_openclaw_agents(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "agents.list")
sessions_result = await _ws_call(
gateway,
"sessions.list",
{"limit": 200, "includeLastMessage": True},
)
config_result = await _ws_call(gateway, "config.get")
session_model_by_agent: dict[str, str] = {}
default_session_model: str | None = None
agent_skills_by_id: dict[str, list[str] | None] = {}
default_agent_skills: list[str] | None = None
parsed_config = config_result.get("parsed") if isinstance(config_result, dict) else None
if isinstance(parsed_config, dict):
agents_cfg = parsed_config.get("agents")
if isinstance(agents_cfg, dict):
defaults_cfg = agents_cfg.get("defaults")
if isinstance(defaults_cfg, dict):
default_skills = defaults_cfg.get("skills")
if isinstance(default_skills, list):
default_agent_skills = [
str(skill).strip()
for skill in default_skills
if str(skill).strip()
]
list_cfg = agents_cfg.get("list")
if isinstance(list_cfg, list):
for entry in list_cfg:
if not isinstance(entry, dict):
continue
agent_id = str(entry.get("id") or "").strip()
if not agent_id:
continue
skills = entry.get("skills")
if isinstance(skills, list):
agent_skills_by_id[agent_id] = [
str(skill).strip()
for skill in skills
if str(skill).strip()
]
elif skills == []:
agent_skills_by_id[agent_id] = []
if isinstance(sessions_result, dict) and isinstance(sessions_result.get("sessions"), list):
defaults = sessions_result.get("defaults")
if isinstance(defaults, dict):
value = (
defaults.get("model")
or defaults.get("modelName")
or defaults.get("model_name")
)
if value:
default_session_model = str(value)
for session in sessions_result.get("sessions", []):
if not isinstance(session, dict):
continue
agent_id = str(
session.get("agentId")
or session.get("agent_id")
or ""
).strip()
if not agent_id:
key = str(session.get("key") or session.get("sessionKey") or "").strip()
parts = key.split(":")
if len(parts) >= 3 and parts[0] == "agent":
agent_id = parts[1]
model_value = (
session.get("model")
or session.get("modelName")
or session.get("model_name")
or session.get("resolvedModel")
or session.get("resolved_model")
or session.get("defaultModel")
or session.get("default_model")
)
if agent_id and model_value and agent_id not in session_model_by_agent:
session_model_by_agent[agent_id] = str(model_value)
if isinstance(result, dict) and isinstance(result.get("agents"), list):
normalized_agents = []
for agent in result.get("agents", []):
if not isinstance(agent, dict):
normalized_agents.append(agent)
continue
normalized = dict(agent)
if not normalized.get("model"):
normalized["model"] = (
normalized.get("modelName")
or normalized.get("model_name")
or normalized.get("resolvedModel")
or normalized.get("resolved_model")
or normalized.get("defaultModel")
or normalized.get("default_model")
or session_model_by_agent.get(str(normalized.get("id") or "").strip())
or default_session_model
)
agent_id = str(normalized.get("id") or "").strip()
if "skills" not in normalized:
normalized["skills"] = agent_skills_by_id.get(agent_id, default_agent_skills)
normalized_agents.append(normalized)
result = {**result, "agents": normalized_agents}
await websocket.send(json.dumps({"type": "openclaw_agents_loaded", "data": result}))
async def handle_get_openclaw_agents_presence(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "node.list")
await websocket.send(json.dumps({"type": "openclaw_agents_presence_loaded", "data": result}))
async def handle_get_openclaw_skills(gateway, websocket, data: dict) -> None:
agent_id = str(data.get("agent_id") or "").strip()
params = {"agentId": agent_id} if agent_id else {}
result = await _ws_call(gateway, "skills.status", params)
await websocket.send(json.dumps({"type": "openclaw_skills_loaded", "data": result}))
async def handle_get_openclaw_models(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "models.list")
await websocket.send(json.dumps({"type": "openclaw_models_loaded", "data": result}))
async def handle_get_openclaw_hooks(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "tools.catalog")
await websocket.send(json.dumps({"type": "openclaw_hooks_loaded", "data": result}))
async def handle_get_openclaw_plugins(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "config.get")
await websocket.send(json.dumps({"type": "openclaw_plugins_loaded", "data": result}))
async def handle_get_openclaw_secrets_audit(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "secrets.reload")
await websocket.send(json.dumps({"type": "openclaw_secrets_audit_loaded", "data": result}))
async def handle_get_openclaw_security_audit(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "gateway.identity.get")
await websocket.send(json.dumps({"type": "openclaw_security_audit_loaded", "data": result}))
async def handle_get_openclaw_daemon_status(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "doctor.memory.status")
await websocket.send(json.dumps({"type": "openclaw_daemon_status_loaded", "data": result}))
async def handle_get_openclaw_pairing(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "device.pair.list")
await websocket.send(json.dumps({"type": "openclaw_pairing_loaded", "data": result}))
async def handle_get_openclaw_qr(gateway, websocket, data: dict) -> None:
await websocket.send(json.dumps({"type": "openclaw_qr_loaded", "data": {"error": "QR code not available via WebSocket"}}))
async def handle_get_openclaw_update_status(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "update.run")
await websocket.send(json.dumps({"type": "openclaw_update_status_loaded", "data": result}))
async def handle_get_openclaw_models_aliases(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "models.list")
await websocket.send(json.dumps({"type": "openclaw_models_aliases_loaded", "data": result}))
async def handle_get_openclaw_models_fallbacks(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "models.list")
await websocket.send(json.dumps({"type": "openclaw_models_fallbacks_loaded", "data": result}))
async def handle_get_openclaw_models_image_fallbacks(gateway, websocket, data: dict) -> None:
result = await _ws_call(gateway, "models.list")
await websocket.send(json.dumps({"type": "openclaw_models_image_fallbacks_loaded", "data": result}))
async def handle_get_openclaw_skill_update(gateway, websocket, data: dict) -> None:
slug = data.get("slug")
all_flag = data.get("all", False)
params = {}
if slug is not None:
params["slug"] = slug
if all_flag:
params["all"] = "true"
result = await _ws_call(gateway, "skills.update", params)
await websocket.send(json.dumps({"type": "openclaw_skill_update_loaded", "data": result}))
async def handle_get_openclaw_workspace_files(gateway, websocket, data: dict) -> None:
raw_workspace = data.get("workspace", "")
# Use the workspace param (which is actually the agent.id from frontend) as agent_id
agent_id = raw_workspace or "main"
result = await _ws_call(gateway, "agents.files.list", {"agentId": agent_id})
if isinstance(result, dict):
result["workspace"] = agent_id
await websocket.send(json.dumps({"type": "openclaw_workspace_files_loaded", "data": result}))
async def handle_get_openclaw_workspace_file(gateway, websocket, data: dict) -> None:
agent_id = data.get("agent_id", "main")
file_name = data.get("file_name", "")
if not file_name:
await websocket.send(json.dumps({"type": "openclaw_workspace_file_loaded", "data": {"error": "file_name is required"}}))
return
result = await _ws_call(gateway, "agents.files.get", {"agentId": agent_id, "name": file_name})
await websocket.send(json.dumps({"type": "openclaw_workspace_file_loaded", "data": result}))

View File

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
"""Runtime/state support helpers extracted from the main Gateway module."""
from __future__ import annotations
from typing import Any
from backend.data.provider_utils import normalize_symbol
def normalize_watchlist(raw_tickers: Any) -> list[str]:
"""Parse watchlist payloads from websocket messages."""
if raw_tickers is None:
return []
if isinstance(raw_tickers, str):
candidates = raw_tickers.split(",")
elif isinstance(raw_tickers, list):
candidates = raw_tickers
else:
candidates = [raw_tickers]
tickers: list[str] = []
for candidate in candidates:
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
if symbol and symbol not in tickers:
tickers.append(symbol)
return tickers
def normalize_agent_workspace_filename(
raw_name: Any,
*,
allowlist: set[str],
) -> str | None:
"""Restrict editable workspace files to a safe allowlist."""
filename = str(raw_name or "").strip()
if filename in allowlist:
return filename
return None
def apply_runtime_config(gateway: Any, runtime_config: dict[str, Any]) -> dict[str, Any]:
"""Apply runtime config to gateway-owned services and state."""
warnings: list[str] = []
ticker_changes = gateway.market_service.update_tickers(
runtime_config.get("tickers", []),
)
gateway.config["tickers"] = ticker_changes["active"]
gateway.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
gateway.config["max_comm_cycles"] = gateway.pipeline.max_comm_cycles
gateway.config["schedule_mode"] = runtime_config.get(
"schedule_mode",
gateway.config.get("schedule_mode", "daily"),
)
gateway.config["interval_minutes"] = int(
runtime_config.get(
"interval_minutes",
gateway.config.get("interval_minutes", 60),
),
)
gateway.config["trigger_time"] = runtime_config.get(
"trigger_time",
gateway.config.get("trigger_time", "09:30"),
)
if gateway.scheduler:
gateway.scheduler.reconfigure(
mode=gateway.config["schedule_mode"],
trigger_time=gateway.config["trigger_time"],
interval_minutes=gateway.config["interval_minutes"],
)
pm_apply_result = gateway.pipeline.pm.apply_runtime_portfolio_config(
margin_requirement=runtime_config["margin_requirement"],
)
gateway.config["margin_requirement"] = gateway.pipeline.pm.portfolio.get(
"margin_requirement",
runtime_config["margin_requirement"],
)
requested_initial_cash = float(runtime_config["initial_cash"])
current_initial_cash = float(gateway.storage.initial_cash)
initial_cash_applied = requested_initial_cash == current_initial_cash
if not initial_cash_applied:
if (
gateway.storage.can_apply_initial_cash()
and gateway.pipeline.pm.can_apply_initial_cash()
):
initial_cash_applied = gateway.storage.apply_initial_cash(
requested_initial_cash,
)
if initial_cash_applied:
gateway.pipeline.pm.apply_runtime_portfolio_config(
initial_cash=requested_initial_cash,
)
gateway.config["initial_cash"] = gateway.storage.initial_cash
else:
warnings.append(
"initial_cash changed in BOOTSTRAP.md but was not applied "
"because the run already has positions, margin usage, or trades.",
)
requested_enable_memory = bool(runtime_config["enable_memory"])
current_enable_memory = bool(gateway.config.get("enable_memory", False))
if requested_enable_memory != current_enable_memory:
warnings.append(
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
"because long-term memory contexts are created at startup.",
)
sync_runtime_state(gateway)
return {
"runtime_config_requested": runtime_config,
"runtime_config_applied": {
"tickers": list(gateway.config.get("tickers", [])),
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
"interval_minutes": gateway.config.get("interval_minutes", 60),
"trigger_time": gateway.config.get("trigger_time", "09:30"),
"initial_cash": gateway.storage.initial_cash,
"margin_requirement": gateway.config["margin_requirement"],
"max_comm_cycles": gateway.config["max_comm_cycles"],
"enable_memory": gateway.config.get("enable_memory", False),
},
"runtime_config_status": {
"tickers": True,
"schedule_mode": True,
"interval_minutes": True,
"trigger_time": True,
"initial_cash": initial_cash_applied,
"margin_requirement": pm_apply_result["margin_requirement"],
"max_comm_cycles": True,
"enable_memory": requested_enable_memory == current_enable_memory,
},
"ticker_changes": ticker_changes,
"runtime_config_warnings": warnings,
}
def sync_runtime_state(gateway: Any) -> None:
"""Refresh persisted state after runtime config changes."""
gateway.state_sync.update_state("tickers", gateway.config.get("tickers", []))
gateway.state_sync.update_state(
"runtime_config",
{
"tickers": gateway.config.get("tickers", []),
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
"interval_minutes": gateway.config.get("interval_minutes", 60),
"trigger_time": gateway.config.get("trigger_time", "09:30"),
"initial_cash": gateway.storage.initial_cash,
"margin_requirement": gateway.config.get("margin_requirement"),
"max_comm_cycles": gateway.config.get("max_comm_cycles"),
"enable_memory": gateway.config.get("enable_memory", False),
},
)
gateway.storage.update_server_state_from_dashboard(gateway.state_sync.state)
gateway.state_sync.save_state()

View File

@@ -0,0 +1,716 @@
# -*- coding: utf-8 -*-
"""Stock-related Gateway handlers extracted from the main Gateway module."""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timedelta
from typing import Any
from backend.data.provider_utils import normalize_symbol
from backend.domains import news as news_domain
from backend.domains import trading as trading_domain
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.tools.data_tools import prices_to_df
from shared.client import NewsServiceClient, TradingServiceClient
logger = logging.getLogger(__name__)
async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_history_loaded",
"ticker": "",
"prices": [],
"source": None,
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
prices = []
source = "polygon"
response = await gateway._call_trading_service(
"get_prices for history",
lambda client: client.get_prices(ticker=ticker, start_date=start_date, end_date=end_date),
)
if response is not None:
prices = response.prices
source = "trading_service"
if not prices:
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
if not prices:
payload = await asyncio.to_thread(
trading_domain.get_prices_payload,
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
prices = payload.get("prices") or []
usage_snapshot = gateway._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
if prices:
await asyncio.to_thread(
gateway.storage.market_store.upsert_ohlc,
ticker,
[price.model_dump() for price in prices],
source=source or "provider",
)
await websocket.send(json.dumps({
"type": "stock_history_loaded",
"ticker": ticker,
"prices": [price if isinstance(price, dict) else price.model_dump() for price in prices][-120:],
"source": source,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_explain_events(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
snapshot = gateway.storage.runtime_db.get_stock_explain_snapshot(ticker)
await websocket.send(json.dumps({
"type": "stock_explain_events_loaded",
"ticker": ticker,
"events": snapshot.get("events", []),
"signals": snapshot.get("signals", []),
"trades": snapshot.get("trades", []),
}, ensure_ascii=False, default=str))
async def handle_get_stock_news(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_loaded",
"ticker": "",
"news": [],
"source": None,
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 30)
limit = data.get("limit", 12)
try:
lookback_days = max(7, min(int(lookback_days), 180))
except (TypeError, ValueError):
lookback_days = 30
try:
limit = max(1, min(int(limit), 30))
except (TypeError, ValueError):
limit = 12
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
news_rows = []
source = "polygon"
response = await gateway._call_news_service(
"get_enriched_news",
lambda client: client.get_enriched_news(
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
),
)
if response is not None:
news_rows = response.get("news") or []
source = "news_service"
if not news_rows:
payload = await asyncio.to_thread(
news_domain.get_enriched_news,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=max(limit, 50),
refresh_if_stale=False,
)
news_rows = (payload.get("news") or [])[-limit:]
source = "market_store"
await websocket.send(json.dumps({
"type": "stock_news_loaded",
"ticker": ticker,
"news": news_rows[-limit:],
"source": source,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_for_date(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
trade_date = str(data.get("date") or "").strip()
if not ticker or not trade_date:
await websocket.send(json.dumps({
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": [],
"error": "ticker and date are required",
}, ensure_ascii=False))
return
limit = data.get("limit", 20)
try:
limit = max(1, min(int(limit), 50))
except (TypeError, ValueError):
limit = 20
source = "market_store"
news_rows = []
response = await gateway._call_news_service(
"get_news_for_date",
lambda client: client.get_news_for_date(ticker=ticker, date=trade_date, limit=limit),
)
if response is not None:
news_rows = response.get("news") or []
source = "news_service"
if not news_rows:
payload = await asyncio.to_thread(
news_domain.get_news_for_date,
gateway.storage.market_store,
ticker=ticker,
date=trade_date,
limit=limit,
refresh_if_stale=False,
)
news_rows = payload.get("news") or []
source = "market_store"
await websocket.send(json.dumps({
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": news_rows,
"source": source,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_timeline(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_timeline_loaded",
"ticker": "",
"timeline": [],
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
timeline = []
response = await gateway._call_news_service(
"get_news_timeline",
lambda client: client.get_news_timeline(ticker=ticker, start_date=start_date, end_date=end_date),
)
if response is not None:
timeline = response.get("timeline") or []
if not timeline:
payload = await asyncio.to_thread(
news_domain.get_news_timeline,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
refresh_if_stale=False,
)
timeline = payload.get("timeline") or []
await websocket.send(json.dumps({
"type": "stock_news_timeline_loaded",
"ticker": ticker,
"timeline": timeline,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_categories(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_categories_loaded",
"ticker": "",
"categories": {},
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
categories = {}
response = await gateway._call_news_service(
"get_categories",
lambda client: client.get_categories(
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=200,
),
)
if response is not None:
categories = response.get("categories") or {}
if not categories:
payload = await asyncio.to_thread(
news_domain.get_news_categories,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=200,
refresh_if_stale=False,
)
categories = payload.get("categories") or {}
await websocket.send(json.dumps({
"type": "stock_news_categories_loaded",
"ticker": ticker,
"categories": categories,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_range_explain(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()
end_date = str(data.get("end_date") or "").strip()
if not ticker or not start_date or not end_date:
await websocket.send(json.dumps({
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": {"error": "ticker, start_date, end_date are required"},
}, ensure_ascii=False))
return
article_ids = data.get("article_ids")
result = None
response = await gateway._call_news_service(
"get_range_explain",
lambda client: client.get_range_explain(
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids if isinstance(article_ids, list) else None,
limit=100,
),
)
if response is not None:
result = response.get("result")
if result is None:
payload = await asyncio.to_thread(
news_domain.get_range_explain_payload,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids if isinstance(article_ids, list) else None,
limit=100,
refresh_if_stale=False,
)
result = payload.get("result")
await websocket.send(json.dumps({
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": result,
}, ensure_ascii=False, default=str))
async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_insider_trades_loaded",
"ticker": "",
"trades": [],
"error": "invalid ticker",
}, ensure_ascii=False))
return
end_date = str(data.get("end_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
start_date = str(data.get("start_date") or "").strip()[:10]
limit = int(data.get("limit", 50))
trades = []
response = await gateway._call_trading_service(
"get_insider_trades",
lambda client: client.get_insider_trades(
ticker=ticker,
end_date=end_date,
start_date=start_date if start_date else None,
limit=limit,
),
)
if response is not None:
trades = response.insider_trades
if not trades:
payload = await asyncio.to_thread(
trading_domain.get_insider_trades_payload,
ticker=ticker,
end_date=end_date,
start_date=start_date if start_date else None,
limit=limit,
)
trades = payload.get("insider_trades") or []
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
formatted_trades = [{
"ticker": t.ticker,
"name": t.name,
"title": t.title,
"is_board_director": t.is_board_director,
"transaction_date": t.transaction_date,
"transaction_shares": t.transaction_shares,
"transaction_price_per_share": t.transaction_price_per_share,
"transaction_value": t.transaction_value,
"shares_owned_before_transaction": t.shares_owned_before_transaction,
"shares_owned_after_transaction": t.shares_owned_after_transaction,
"security_title": t.security_title,
"filing_date": t.filing_date,
"holding_change": (
(t.shares_owned_after_transaction or 0) - (t.shares_owned_before_transaction or 0)
if t.shares_owned_after_transaction and t.shares_owned_before_transaction else None
),
"is_buy": ((t.transaction_shares or 0) > 0) if t.transaction_shares is not None else None,
} for t in sorted_trades]
await websocket.send(json.dumps({
"type": "stock_insider_trades_loaded",
"ticker": ticker,
"start_date": start_date or None,
"end_date": end_date,
"trades": formatted_trades,
}, ensure_ascii=False, default=str))
async def handle_get_stock_story(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_story_loaded",
"ticker": "",
"story": "",
"error": "invalid ticker",
}, ensure_ascii=False))
return
as_of_date = str(data.get("as_of_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
result = await gateway._call_news_service(
"get_story",
lambda client: client.get_story(ticker=ticker, as_of_date=as_of_date),
)
if result is None:
result = await asyncio.to_thread(
news_domain.get_story_payload,
gateway.storage.market_store,
ticker=ticker,
as_of_date=as_of_date,
)
await websocket.send(json.dumps({
"type": "stock_story_loaded",
"ticker": ticker,
"as_of_date": as_of_date,
"story": result.get("story") or "",
"source": result.get("source") or "local",
}, ensure_ascii=False, default=str))
async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
target_date = str(data.get("date") or "").strip()[:10]
if not ticker or not target_date:
await websocket.send(json.dumps({
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
"items": [],
"error": "ticker and date are required",
}, ensure_ascii=False))
return
top_k = data.get("top_k", 8)
try:
top_k = max(1, min(int(top_k), 20))
except (TypeError, ValueError):
top_k = 8
result = await gateway._call_news_service(
"get_similar_days",
lambda client: client.get_similar_days(ticker=ticker, date=target_date, n_similar=top_k),
)
if result is None:
result = await asyncio.to_thread(
news_domain.get_similar_days_payload,
gateway.storage.market_store,
ticker=ticker,
date=target_date,
n_similar=top_k,
)
await websocket.send(json.dumps({
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
**result,
}, ensure_ascii=False, default=str))
async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "ticker is required",
}, ensure_ascii=False))
return
try:
end_date = datetime.now()
start_date = end_date - timedelta(days=250)
prices = None
response = await gateway._call_trading_service(
"get_prices",
lambda client: client.get_prices(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
),
)
if response is not None:
prices = response.prices
if prices is None:
payload = trading_domain.get_prices_payload(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
)
prices = payload.get("prices") or []
if not prices or len(prices) < 20:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "Insufficient price data",
}, ensure_ascii=False))
return
df = prices_to_df(prices)
signal = gateway._technical_analyzer.analyze(ticker, df)
import pandas as pd
df_sorted = df.sort_values("time").reset_index(drop=True)
df_sorted["returns"] = df_sorted["close"].pct_change()
vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
ma_distance = {}
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
ma_value = getattr(signal, ma_key, None)
ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None
indicators = {
"ticker": ticker,
"current_price": signal.current_price,
"ma": {
"ma5": signal.ma5,
"ma10": signal.ma10,
"ma20": signal.ma20,
"ma50": signal.ma50,
"ma200": signal.ma200,
"distance": ma_distance,
},
"rsi": {
"rsi14": signal.rsi14,
"status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
},
"macd": {
"macd": signal.macd,
"signal": signal.macd_signal,
"histogram": signal.macd - signal.macd_signal,
},
"bollinger": {
"upper": signal.bollinger_upper,
"mid": signal.bollinger_mid,
"lower": signal.bollinger_lower,
},
"volatility": {
"vol_10d": vol_10,
"vol_20d": vol_20,
"vol_60d": vol_60,
"annualized": signal.annualized_volatility_pct,
"risk_level": signal.risk_level,
},
"trend": signal.trend,
"mean_reversion": signal.mean_reversion_signal,
}
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": indicators,
}, ensure_ascii=False, default=str))
except Exception as exc:
logger.exception("Error getting technical indicators for %s", ticker)
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": str(exc),
}, ensure_ascii=False))
async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()[:10]
end_date = str(data.get("end_date") or "").strip()[:10]
story_date = str(data.get("story_date") or end_date or "").strip()[:10]
target_date = str(data.get("target_date") or "").strip()[:10]
force = bool(data.get("force", False))
rebuild_story = bool(data.get("rebuild_story", True))
rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
only_local_to_llm = bool(data.get("only_local_to_llm", False))
limit = data.get("limit", 200)
try:
limit = max(10, min(int(limit), 500))
except (TypeError, ValueError):
limit = 200
if not ticker or not start_date or not end_date:
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "ticker, start_date, end_date are required",
}, ensure_ascii=False))
return
if only_local_to_llm and not llm_enrichment_enabled():
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
}, ensure_ascii=False))
return
result = await asyncio.to_thread(
enrich_news_for_symbol,
gateway.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
skip_existing=not force,
only_reanalyze_local=only_local_to_llm,
)
story_status = None
if rebuild_story and story_date:
await asyncio.to_thread(gateway.storage.market_store.delete_story_cache, ticker, as_of_date=story_date)
story_result = await asyncio.to_thread(
news_domain.get_story_payload,
gateway.storage.market_store,
ticker=ticker,
as_of_date=story_date,
)
story_status = {"as_of_date": story_date, "source": story_result.get("source") or "local"}
similar_status = None
if rebuild_similar_days and target_date:
await asyncio.to_thread(gateway.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date)
similar_result = await asyncio.to_thread(
news_domain.get_similar_days_payload,
gateway.storage.market_store,
ticker=ticker,
date=target_date,
n_similar=8,
)
similar_status = {
"target_date": target_date,
"count": len(similar_result.get("items") or []),
"error": similar_result.get("error"),
}
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"story_date": story_date or None,
"target_date": target_date or None,
"force": force,
"only_local_to_llm": only_local_to_llm,
"stats": result,
"story_status": story_status,
"similar_status": similar_status,
}, ensure_ascii=False, default=str))

687
backend/services/market.py Normal file
View File

@@ -0,0 +1,687 @@
# -*- coding: utf-8 -*-
"""
Market Data Service
Supports live and backtest modes
"""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional
from zoneinfo import ZoneInfo
import pandas_market_calendars as mcal
from backend.config.data_config import get_data_sources
from backend.data.provider_utils import normalize_symbol
logger = logging.getLogger(__name__)
# NYSE timezone and calendar
NYSE_TZ = ZoneInfo("America/New_York")
NYSE_CALENDAR = mcal.get_calendar("NYSE")
class MarketStatus:
"""Market status enum-like class"""
OPEN = "open"
CLOSED = "closed"
PREMARKET = "premarket"
AFTERHOURS = "afterhours"
class MarketService:
"""Market data service for price management"""
def __init__(
self,
tickers: List[str],
poll_interval: int = 10,
backtest_mode: bool = False,
api_key: Optional[str] = None,
backtest_start_date: Optional[str] = None,
backtest_end_date: Optional[str] = None,
):
self.tickers = [normalize_symbol(ticker) for ticker in tickers]
self.poll_interval = poll_interval
self.backtest_mode = backtest_mode
self.api_key = api_key
self.backtest_start_date = backtest_start_date
self.backtest_end_date = backtest_end_date
self.cache: Dict[str, Dict[str, Any]] = {}
self.running = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._broadcast_func: Optional[Callable] = None
self._price_record_func: Optional[Callable[..., None]] = None
self._price_manager: Optional[Any] = None
self._current_date: Optional[str] = None
# Market status tracking
self._last_market_status: Optional[str] = None
# Session tracking for live returns
self._session_start_values: Optional[Dict[str, float]] = None
self._session_start_timestamp: Optional[int] = None
def get_live_quote_provider(self) -> Optional[str]:
"""Return the active live quote provider for UI/debugging."""
if self.backtest_mode:
return "backtest"
if self._price_manager and hasattr(self._price_manager, "provider"):
provider = getattr(self._price_manager, "provider", None)
if isinstance(provider, str) and provider.strip():
return provider.strip().lower()
return None
@property
def mode_name(self) -> str:
if self.backtest_mode:
return "BACKTEST"
return "LIVE"
async def start(self, broadcast_func: Callable):
"""Start market data service"""
if self.running:
return
self.running = True
self._loop = asyncio.get_running_loop()
self._broadcast_func = broadcast_func
if self.backtest_mode:
self._start_backtest_mode()
else:
self._start_real_mode()
logger.info(
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
)
def set_price_recorder(self, recorder: Optional[Callable[..., None]]):
"""Register an optional callback for persisting runtime price points."""
self._price_record_func = recorder
def _make_price_callback(self) -> Callable:
"""Create thread-safe price callback"""
def callback(price_data: Dict[str, Any]):
symbol = price_data["symbol"]
self.cache[symbol] = price_data
loop = self._loop
if loop and loop.is_running() and self._broadcast_func:
asyncio.run_coroutine_threadsafe(
self._broadcast_price_update(price_data),
loop,
)
return callback
def _start_real_mode(self):
from backend.data.polling_price_manager import PollingPriceManager
provider = self._resolve_live_quote_provider()
if provider == "finnhub" and not self.api_key:
raise ValueError("API key required for live mode")
self._price_manager = PollingPriceManager(
api_key=self.api_key,
poll_interval=self.poll_interval,
provider=provider,
)
self._price_manager.add_price_callback(self._make_price_callback())
self._price_manager.subscribe(self.tickers)
self._price_manager.start()
def _resolve_live_quote_provider(self) -> str:
"""Pick the first configured provider that supports live quote polling."""
for provider in get_data_sources():
if provider in {"finnhub", "yfinance"}:
return provider
return "yfinance"
def _start_backtest_mode(self):
from backend.data.historical_price_manager import (
HistoricalPriceManager,
)
self._price_manager = HistoricalPriceManager()
self._price_manager.add_price_callback(self._make_price_callback())
self._price_manager.subscribe(self.tickers)
if self.backtest_start_date and self.backtest_end_date:
self._price_manager.preload_data(
self.backtest_start_date,
self.backtest_end_date,
)
self._price_manager.start()
async def _broadcast_price_update(self, price_data: Dict[str, Any]):
"""Broadcast price update to frontend"""
if not self._broadcast_func:
return
symbol = price_data["symbol"]
price = price_data["price"]
open_price = price_data.get("open", price)
ret = (
((price - open_price) / open_price) * 100 if open_price > 0 else 0
)
if self._price_record_func:
try:
self._price_record_func(
ticker=symbol,
timestamp=str(price_data.get("timestamp") or datetime.now().isoformat()),
price=float(price),
open_price=float(open_price) if open_price is not None else None,
ret=float(ret),
source=self.mode_name.lower(),
meta=price_data,
)
except Exception as exc:
logger.warning(
"Failed to record price point for %s: %s",
symbol,
exc,
)
await self._broadcast_func(
{
"type": "price_update",
"symbol": symbol,
"price": price,
"open": open_price,
"ret": ret,
"timestamp": price_data.get("timestamp"),
"realtime_prices": {
t: self._get_cached_price(t) for t in self.tickers
},
},
)
def _get_cached_price(self, ticker: str) -> Dict[str, Any]:
"""Get cached price data for a ticker"""
if ticker in self.cache:
return self.cache[ticker]
# Return from price manager if not in cache
if self._price_manager:
price = self._price_manager.get_latest_price(ticker)
if price:
return {"price": price, "symbol": ticker}
return {"price": 0, "symbol": ticker}
def stop(self):
"""Stop market service"""
if not self.running:
return
self.running = False
if self._price_manager:
self._price_manager.stop()
self._price_manager = None
self._loop = None
self._broadcast_func = None
def update_tickers(self, tickers: List[str]) -> Dict[str, List[str]]:
"""Hot-update subscribed tickers without restarting the service."""
normalized: List[str] = []
for ticker in tickers:
symbol = normalize_symbol(ticker)
if symbol and symbol not in normalized:
normalized.append(symbol)
previous = list(self.tickers)
removed = [ticker for ticker in previous if ticker not in normalized]
added = [ticker for ticker in normalized if ticker not in previous]
self.tickers = normalized
if self._price_manager:
if removed:
self._price_manager.unsubscribe(removed)
if added:
self._price_manager.subscribe(added)
if self.backtest_mode and self._current_date:
self._price_manager.set_date(self._current_date)
for ticker in removed:
self.cache.pop(ticker, None)
return {
"added": added,
"removed": removed,
"active": list(self.tickers),
}
# Backtest methods
def set_backtest_date(self, date: str):
"""Set current backtest date"""
if not self.backtest_mode or not self._price_manager:
return
self._current_date = date
self._price_manager.set_date(date)
logger.info(f"Backtest date: {date}")
async def emit_market_open(self):
"""Emit market open prices"""
if self.backtest_mode and self._price_manager:
self._price_manager.emit_open_prices()
# Log prices for debugging
prices = self.get_open_prices()
logger.info(f"Open prices: {prices}")
async def emit_market_close(self):
"""Emit market close prices"""
if self.backtest_mode and self._price_manager:
self._price_manager.emit_close_prices()
# Log prices for debugging
prices = self.get_close_prices()
logger.info(f"Close prices: {prices}")
def get_open_prices(self) -> Dict[str, float]:
"""Get open prices for all tickers"""
prices = {}
for ticker in self.tickers:
price = None
# Try price manager first
if self.backtest_mode and self._price_manager:
price = self._price_manager.get_open_price(ticker)
# Fallback to cache
if price is None or price <= 0:
cached = self.cache.get(ticker, {})
price = cached.get("open") or cached.get("price")
prices[ticker] = price if price and price > 0 else 0.0
return prices
def get_close_prices(self) -> Dict[str, float]:
"""Get close prices for all tickers"""
prices = {}
for ticker in self.tickers:
price = None
# Try price manager first
if self.backtest_mode and self._price_manager:
price = self._price_manager.get_close_price(ticker)
# Fallback to cache
if price is None or price <= 0:
cached = self.cache.get(ticker, {})
price = cached.get("close") or cached.get("price")
prices[ticker] = price if price and price > 0 else 0.0
return prices
def get_price_for_date(
self,
ticker: str,
date: str,
price_type: str = "close",
) -> Optional[float]:
"""Get price for a specific date"""
if self.backtest_mode and self._price_manager:
return self._price_manager.get_price_for_date(
ticker,
date,
price_type,
)
return self.get_price_sync(ticker)
# Common methods
def get_price_sync(self, ticker: str) -> Optional[float]:
"""Get latest price synchronously"""
# Try cache first
data = self.cache.get(ticker)
if data and data.get("price"):
return data["price"]
# Try price manager
if self._price_manager:
return self._price_manager.get_latest_price(ticker)
return None
def get_all_prices(self) -> Dict[str, float]:
"""Get all latest prices"""
prices = {}
for ticker in self.tickers:
price = self.get_price_sync(ticker)
prices[ticker] = price if price and price > 0 else 0.0
return prices
# Live mode async waiting methods
def _now_nyse(self) -> datetime:
"""Get current time in NYSE timezone"""
return datetime.now(NYSE_TZ)
def _is_trading_day(self, date: datetime) -> bool:
"""Check if date is a NYSE trading day"""
date_str = date.strftime("%Y-%m-%d")
valid_days = NYSE_CALENDAR.valid_days(
start_date=date_str,
end_date=date_str,
)
return len(valid_days) > 0
def _get_market_hours(self, date: datetime) -> tuple:
"""Get market open and close times for a given date"""
date_str = date.strftime("%Y-%m-%d")
schedule = NYSE_CALENDAR.schedule(
start_date=date_str,
end_date=date_str,
)
if schedule.empty:
return None, None
market_open = schedule.iloc[0]["market_open"].to_pydatetime()
market_close = schedule.iloc[0]["market_close"].to_pydatetime()
return market_open, market_close
def _next_trading_day(self, from_date: datetime) -> datetime:
"""Find the next trading day from given date"""
check_date = from_date + timedelta(days=1)
for _ in range(10): # Max 10 days ahead (handles holidays)
if self._is_trading_day(check_date):
return check_date
check_date += timedelta(days=1)
return check_date
def _get_trading_date_for_execution(self) -> tuple:
"""
Determine the trading date for execution.
Returns:
(trading_date, market_open_time, market_close_time)
Logic:
- If today is a trading day and market has opened: use today
- If today is a trading day but market hasn't opened: wait for open
- If today is not a trading day: use next trading day
"""
now = self._now_nyse()
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
if self._is_trading_day(today):
market_open, market_close = self._get_market_hours(today)
return today, market_open, market_close
else:
# Weekend or holiday - find next trading day
next_day = self._next_trading_day(today)
market_open, market_close = self._get_market_hours(next_day)
return next_day, market_open, market_close
async def wait_for_open_prices(self) -> Dict[str, float]:
"""
Wait for market open and return open prices.
Behavior:
- If market is already open today: return current prices immediately
- If market hasn't opened yet today: wait until open
- If not a trading day: wait until next trading day opens
"""
now = self._now_nyse()
trading_date, market_open, _ = self._get_trading_date_for_execution()
if market_open is None:
logger.warning("Could not determine market hours")
return self.get_all_prices()
trading_date_str = trading_date.strftime("%Y-%m-%d")
# Check if we need to wait
if now < market_open:
wait_seconds = (market_open - now).total_seconds()
logger.info(
f"Waiting {wait_seconds/60:.1f} min for market open "
f"({trading_date_str} {market_open.strftime('%H:%M')} ET)",
)
await asyncio.sleep(wait_seconds)
# Small delay to ensure prices are available
await asyncio.sleep(5)
else:
logger.info(
f"Market already open for {trading_date_str}, "
f"getting current prices",
)
# Poll until we have valid prices
prices = await self._poll_for_prices()
logger.info(f"Got open prices for {trading_date_str}: {prices}")
return prices
async def wait_for_close_prices(self) -> Dict[str, float]:
"""
Wait for market close and return close prices.
Behavior:
- If market is already closed today: return current prices immediately
- If market hasn't closed yet: wait until close
"""
now = self._now_nyse()
trading_date, _, market_close = self._get_trading_date_for_execution()
if market_close is None:
logger.warning("Could not determine market hours")
return self.get_all_prices()
trading_date_str = trading_date.strftime("%Y-%m-%d")
# Check if we need to wait
if now < market_close:
wait_seconds = (market_close - now).total_seconds()
logger.info(
f"Waiting {wait_seconds/60:.1f} min for market close "
f"({trading_date_str} {market_close.strftime('%H:%M')} ET)",
)
await asyncio.sleep(wait_seconds)
# Small delay to ensure final prices settle
await asyncio.sleep(10)
else:
logger.info(
f"Market already closed for {trading_date_str}, "
f"getting close prices",
)
# Get final prices
prices = await self._poll_for_prices()
logger.info(f"Got close prices for {trading_date_str}: {prices}")
return prices
def get_live_trading_date(self) -> str:
"""Get the trading date that will be used for live execution"""
trading_date, _, _ = self._get_trading_date_for_execution()
return trading_date.strftime("%Y-%m-%d")
async def _poll_for_prices(
self,
max_retries: int = 12,
) -> Dict[str, float]:
"""Poll until all prices are available"""
for _ in range(max_retries):
prices = self.get_all_prices()
if all(p > 0 for p in prices.values()):
return prices
logger.debug("Waiting for prices to be available...")
await asyncio.sleep(5)
# Return whatever we have
return self.get_all_prices()
# ========== Market Status Methods ==========
def get_market_status(self) -> Dict[str, Any]:
"""
Get current market status
Returns:
Dict with status info:
- status: 'open' | 'closed' | 'premarket' | 'afterhours'
- status_text: Human readable status
- is_trading_day: Whether today is a trading day
- market_open: Market open time (if trading day)
- market_close: Market close time (if trading day)
"""
if self.backtest_mode:
# In backtest mode, always return open
return {
"status": MarketStatus.OPEN,
"status_text": "Backtest Mode",
"is_trading_day": True,
"live_quote_provider": self.get_live_quote_provider(),
}
now = self._now_nyse()
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
is_trading = self._is_trading_day(today)
if not is_trading:
return {
"status": MarketStatus.CLOSED,
"status_text": "Market Closed (Non-trading Day)",
"is_trading_day": False,
"live_quote_provider": self.get_live_quote_provider(),
}
market_open, market_close = self._get_market_hours(today)
if market_open is None or market_close is None:
return {
"status": MarketStatus.CLOSED,
"status_text": "Market Closed",
"is_trading_day": is_trading,
"live_quote_provider": self.get_live_quote_provider(),
}
# Determine status based on current time
if now < market_open:
return {
"status": MarketStatus.PREMARKET,
"status_text": "Pre-Market",
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
"live_quote_provider": self.get_live_quote_provider(),
}
elif now > market_close:
return {
"status": MarketStatus.CLOSED,
"status_text": "Market Closed",
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
"live_quote_provider": self.get_live_quote_provider(),
}
else:
return {
"status": MarketStatus.OPEN,
"status_text": "Market Open",
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
"live_quote_provider": self.get_live_quote_provider(),
}
async def check_and_broadcast_market_status(self):
"""Check market status and broadcast if changed"""
status = self.get_market_status()
current_status = status["status"]
if current_status != self._last_market_status:
self._last_market_status = current_status
await self._broadcast_market_status(status)
# Handle session transitions
if current_status == MarketStatus.OPEN:
await self._on_session_start()
elif (
current_status == MarketStatus.CLOSED
and self._session_start_values is not None
):
self._on_session_end()
async def _broadcast_market_status(self, status: Dict[str, Any]):
"""Broadcast market status update to frontend"""
if not self._broadcast_func:
return
await self._broadcast_func(
{
"type": "market_status_update",
"market_status": status,
"timestamp": datetime.now(NYSE_TZ).isoformat(),
},
)
logger.info(f"Market status: {status['status_text']}")
async def _on_session_start(self):
"""Called when market session starts - capture baseline values"""
# Wait briefly for prices to be available
await asyncio.sleep(2)
prices = self.get_all_prices()
if prices and any(p > 0 for p in prices.values()):
self._session_start_values = prices.copy()
self._session_start_timestamp = int(
datetime.now().timestamp() * 1000,
)
logger.info(f"Session started with prices: {prices}")
def _on_session_end(self):
"""Called when market session ends - clear session data"""
self._session_start_values = None
self._session_start_timestamp = None
logger.info("Session ended, cleared session data")
def get_session_returns(
self,
current_prices: Dict[str, float],
portfolio_value: Optional[float] = None,
session_start_portfolio_value: Optional[float] = None,
) -> Optional[Dict[str, Any]]:
"""
Calculate session returns (from session start to now)
Args:
current_prices: Current prices for tickers
portfolio_value: Current portfolio value (optional)
session_start_portfolio_value:
Returns:
Dict with return data or None if session not started
"""
if self._session_start_values is None:
return None
timestamp = int(datetime.now().timestamp() * 1000)
returns = {}
# Calculate individual ticker returns
for ticker, start_price in self._session_start_values.items():
current = current_prices.get(ticker)
if current and start_price and start_price > 0:
ret = ((current - start_price) / start_price) * 100
returns[ticker] = round(ret, 4)
result = {
"timestamp": timestamp,
"ticker_returns": returns,
}
# Calculate portfolio return if values provided
if (
portfolio_value is not None
and session_start_portfolio_value is not None
):
if session_start_portfolio_value > 0:
portfolio_ret = (
(portfolio_value - session_start_portfolio_value)
/ session_start_portfolio_value
) * 100
result["portfolio_return"] = round(portfolio_ret, 4)
return result
@property
def session_start_values(self) -> Optional[Dict[str, float]]:
"""Get session start values for external use"""
return self._session_start_values
@property
def session_start_timestamp(self) -> Optional[int]:
"""Get session start timestamp"""
return self._session_start_timestamp

View File

@@ -0,0 +1,754 @@
# -*- coding: utf-8 -*-
"""Thin service wrapper around the OpenClaw CLI."""
from __future__ import annotations
import json
import os
import shlex
import shutil
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from shared.models.openclaw import (
AgentSummary,
AgentsList,
ApprovalRequest,
ApprovalsList,
CronJob,
CronList,
DaemonStatus,
HookStatusEntry,
HookStatusReport,
ModelAliasesList,
ModelFallbacksList,
ModelRow,
ModelsList,
OpenClawStatus,
PairingListResponse,
PluginDiagnostic,
PluginRecord,
PluginsList,
QrCodeResponse,
SecretsAuditReport,
SecurityAuditResponse,
SecurityAuditReport,
SessionEntry,
SessionHistory,
SessionsList,
SkillStatusEntry,
SkillStatusReport,
SkillUpdateResult,
UpdateCheckResult,
UpdateStatusResponse,
normalize_agents,
normalize_approvals,
normalize_cron_jobs,
normalize_daemon_status,
normalize_hooks,
normalize_model_aliases,
normalize_model_fallbacks,
normalize_models,
normalize_pairing,
normalize_plugins,
normalize_qr,
normalize_security_audit,
normalize_secrets_audit,
normalize_session_history,
normalize_sessions,
normalize_skill_update,
normalize_skills,
normalize_status,
normalize_update_status,
)
PROJECT_ROOT = Path(__file__).resolve().parents[2]
REFERENCE_OPENCLAW_ROOT = PROJECT_ROOT / "reference" / "openclaw"
REFERENCE_OPENCLAW_ENTRY = REFERENCE_OPENCLAW_ROOT / "openclaw.mjs"
class OpenClawCliError(RuntimeError):
"""Raised when the OpenClaw CLI invocation fails."""
def __init__(
self,
message: str,
*,
command: list[str],
exit_code: int | None = None,
stdout: str = "",
stderr: str = "",
) -> None:
super().__init__(message)
self.command = command
self.exit_code = exit_code
self.stdout = stdout
self.stderr = stderr
@dataclass(frozen=True)
class OpenClawCliResult:
"""Command execution result."""
command: list[str]
exit_code: int
stdout: str
stderr: str
def resolve_openclaw_base_command() -> list[str]:
"""Resolve the command prefix used to launch OpenClaw."""
explicit = os.getenv("OPENCLAW_CMD", "").strip()
if explicit:
return shlex.split(explicit)
installed = shutil.which("openclaw")
if installed:
return [installed]
if REFERENCE_OPENCLAW_ENTRY.exists():
return [sys.executable if sys.executable.endswith("node") else "node", str(REFERENCE_OPENCLAW_ENTRY)]
return ["openclaw"]
def resolve_openclaw_cwd() -> Path:
"""Resolve the working directory for CLI execution."""
explicit = os.getenv("OPENCLAW_CWD", "").strip()
if explicit:
return Path(explicit).expanduser()
if REFERENCE_OPENCLAW_ROOT.exists():
return REFERENCE_OPENCLAW_ROOT
return PROJECT_ROOT
class OpenClawCliService:
"""OpenClaw CLI integration service."""
def __init__(
self,
*,
base_command: list[str] | None = None,
cwd: Path | None = None,
timeout_seconds: float | None = None,
) -> None:
self.base_command = list(base_command or resolve_openclaw_base_command())
self.cwd = cwd or resolve_openclaw_cwd()
self.timeout_seconds = timeout_seconds or float(
os.getenv("OPENCLAW_TIMEOUT_SECONDS", "15")
)
def health(self) -> dict[str, Any]:
"""Return the current CLI wiring state."""
binary = self.base_command[0] if self.base_command else "openclaw"
resolved = shutil.which(binary) if len(self.base_command) == 1 else binary
return {
"status": "healthy",
"service": "openclaw-service",
"base_command": self.base_command,
"cwd": str(self.cwd),
"binary_resolved": resolved is not None,
"reference_entry_available": REFERENCE_OPENCLAW_ENTRY.exists(),
"timeout_seconds": self.timeout_seconds,
}
def status(self) -> dict[str, Any]:
"""Read `openclaw status --json`."""
return self.run_json(["status", "--json"])
def list_sessions(self) -> dict[str, Any]:
"""Read `openclaw sessions --json`."""
return self.run_json(["sessions", "--json"])
def get_session(self, session_key: str) -> dict[str, Any]:
"""Resolve a single session out of the sessions list."""
payload = self.list_sessions()
sessions = payload.get("sessions") or []
for item in sessions:
if not isinstance(item, dict):
continue
if item.get("key") == session_key or item.get("sessionKey") == session_key:
return item
raise KeyError(session_key)
def get_session_history(self, session_key: str, *, limit: int = 20) -> dict[str, Any]:
"""Read session history with a JSON-first fallback to raw text."""
args = ["sessions", "history", session_key, "--json", "--limit", str(limit)]
try:
return self.run_json(args)
except OpenClawCliError as exc:
raise exc
except json.JSONDecodeError:
result = self.run(args)
return {
"sessionKey": session_key,
"limit": limit,
"rawText": result.stdout,
}
def list_cron_jobs(self) -> dict[str, Any]:
"""Read `openclaw cron list --json`."""
return self.run_json(["cron", "list", "--json"])
def list_approvals(self) -> dict[str, Any]:
"""Read `openclaw approvals get --json`."""
return self.run_json(["approvals", "get", "--json"])
def list_agents(self) -> dict[str, Any]:
"""Read `openclaw agents list --json`."""
return self.run_json(["agents", "list", "--json"])
def list_skills(self) -> dict[str, Any]:
"""Read `openclaw skills list --json`."""
return self.run_json(["skills", "list", "--json"])
def list_models(self) -> dict[str, Any]:
"""Read `openclaw models list --json`."""
return self.run_json(["models", "list", "--json"])
def list_hooks(self) -> dict[str, Any]:
"""Read `openclaw hooks list --json`."""
return self.run_json(["hooks", "list", "--json"])
def list_plugins(self) -> dict[str, Any]:
"""Read `openclaw plugins list --json`."""
return self.run_json(["plugins", "list", "--json"])
def secrets_audit(self) -> dict[str, Any]:
"""Read `openclaw secrets audit --json`."""
return self.run_json(["secrets", "audit", "--json"])
def security_audit(self) -> dict[str, Any]:
"""Read `openclaw security audit --json`."""
return self.run_json(["security", "audit", "--json"])
def daemon_status(self) -> dict[str, Any]:
"""Read `openclaw daemon status --json`."""
return self.run_json(["daemon", "status", "--json"])
def pairing_list(self) -> dict[str, Any]:
"""Read `openclaw pairing list --json`."""
return self.run_json(["pairing", "list", "--json"])
def qr_code(self) -> dict[str, Any]:
"""Read `openclaw qr --json`."""
return self.run_json(["qr", "--json"])
def update_status(self) -> dict[str, Any]:
"""Read `openclaw update status --json`."""
return self.run_json(["update", "status", "--json"])
def list_model_aliases(self) -> dict[str, Any]:
"""Read `openclaw models aliases list --json`."""
return self.run_json(["models", "aliases", "list", "--json"])
def list_model_fallbacks(self) -> dict[str, Any]:
"""Read `openclaw models fallbacks list --json`."""
return self.run_json(["models", "fallbacks", "list", "--json"])
def list_model_image_fallbacks(self) -> dict[str, Any]:
"""Read `openclaw models image-fallbacks list --json`."""
return self.run_json(["models", "image-fallbacks", "list", "--json"])
def skill_update(self, *, slug: str | None = None, all: bool = False) -> dict[str, Any]:
"""Read `openclaw skills update --json`."""
args = ["skills", "update", "--json"]
if slug:
args.append(slug)
if all:
args.append("--all")
return self.run_json(args)
def models_status(self, *, probe: bool = False) -> dict[str, Any]:
"""Read `openclaw models status --json [--probe]`."""
args = ["models", "status", "--json"]
if probe:
args.append("--probe")
return self.run_json(args)
def channels_status(self, *, probe: bool = False) -> dict[str, Any]:
"""Read `openclaw channels status [--probe] --json`."""
args = ["channels", "status", "--json"]
if probe:
args.append("--probe")
return self.run_json(args)
def list_workspace_files(self, workspace_path: str) -> dict[str, Any]:
"""List .md files in an OpenClaw agent workspace with their content.
Reads the workspace directory and returns metadata + content for each .md file.
"""
import json
from pathlib import Path
wp = Path(workspace_path).expanduser().resolve()
if not wp.exists() or not wp.is_dir():
return {"workspace": str(wp), "files": [], "error": "workspace not found"}
md_files = sorted(wp.glob("*.md"))
files = []
for md_file in md_files:
try:
content = md_file.read_text(encoding="utf-8")
# Preview: first 300 chars
preview = content[:300].strip()
files.append({
"name": md_file.name,
"path": str(md_file),
"size": len(content),
"preview": preview,
"previewTruncated": len(content) > 300,
})
except OSError as exc:
files.append({
"name": md_file.name,
"path": str(md_file),
"size": 0,
"preview": "",
"error": str(exc),
})
return {"workspace": str(wp), "files": files}
def channels_list(self) -> dict[str, Any]:
"""Read `openclaw channels list --json`."""
return self.run_json(["channels", "list", "--json"])
def hook_info(self, name: str) -> dict[str, Any]:
"""Read `openclaw hooks info <name> --json`."""
args = ["hooks", "info", name, "--json"]
try:
return self.run_json(args)
except json.JSONDecodeError:
result = self.run(args)
return {"raw": result.stdout}
def hooks_check(self) -> dict[str, Any]:
"""Read `openclaw hooks check --json`."""
return self.run_json(["hooks", "check", "--json"])
def plugins_inspect(self, *, plugin_id: str | None = None, all: bool = False) -> dict[str, Any]:
"""Read `openclaw plugins inspect [--json] [--all]`."""
args = ["plugins", "inspect", "--json"]
if all:
args.append("--all")
elif plugin_id:
args.append(plugin_id)
return self.run_json(args)
# -------------------------------------------------------------------------
# Typed variants — these use Pydantic models and are the preferred path.
# -------------------------------------------------------------------------
def status_model(self) -> OpenClawStatus:
"""Read and parse `openclaw status --json` into a typed model."""
raw = self.status()
return normalize_status(raw)
def list_sessions_model(self) -> SessionsList:
"""Read and parse `openclaw sessions --json` into a typed model."""
raw = self.list_sessions()
return normalize_sessions(raw)
def get_session_model(self, session_key: str) -> SessionEntry:
"""Resolve a single session and return a typed model."""
raw = self.get_session(session_key)
return SessionEntry.model_validate(raw, strict=False)
def get_session_history_model(self, session_key: str, *, limit: int = 20) -> SessionHistory:
"""Read session history and return a typed model."""
raw = self.get_session_history(session_key, limit=limit)
return normalize_session_history(raw, session_key=session_key)
def list_cron_jobs_model(self) -> CronList:
"""Read and parse `openclaw cron list --json` into a typed model."""
raw = self.list_cron_jobs()
return normalize_cron_jobs(raw)
def list_approvals_model(self) -> ApprovalsList:
"""Read and parse `openclaw approvals get --json` into a typed model."""
raw = self.list_approvals()
return normalize_approvals(raw)
# -------------------------------------------------------------------------
# Typed variants
# -------------------------------------------------------------------------
def list_agents_model(self) -> AgentsList:
"""Read and parse `openclaw agents list --json` into a typed model."""
raw = self.list_agents()
if isinstance(raw, list):
return AgentsList(agents=[AgentSummary.model_validate(a, strict=False) for a in raw if isinstance(a, dict)])
return normalize_agents(raw)
def list_skills_model(self) -> SkillStatusReport:
"""Read and parse `openclaw skills list --json` into a typed model."""
raw = self.list_skills()
return normalize_skills(raw)
def list_models_model(self) -> ModelsList:
"""Read and parse `openclaw models list --json` into a typed model."""
raw = self.list_models()
if isinstance(raw, list):
return ModelsList(models=[ModelRow.model_validate(m, strict=False) for m in raw if isinstance(m, dict)])
return normalize_models(raw)
def list_hooks_model(self) -> HookStatusReport:
raw = self.list_hooks()
return normalize_hooks(raw)
def list_plugins_model(self) -> PluginsList:
raw = self.list_plugins()
return normalize_plugins(raw)
def secrets_audit_model(self) -> SecretsAuditReport:
raw = self.secrets_audit()
return normalize_secrets_audit(raw)
def security_audit_model(self) -> SecurityAuditResponse:
raw = self.security_audit()
return normalize_security_audit(raw)
def daemon_status_model(self) -> DaemonStatus:
raw = self.daemon_status()
return normalize_daemon_status(raw)
def pairing_list_model(self) -> PairingListResponse:
raw = self.pairing_list()
return normalize_pairing(raw)
def qr_code_model(self) -> QrCodeResponse:
raw = self.qr_code()
return normalize_qr(raw)
def update_status_model(self) -> UpdateStatusResponse:
raw = self.update_status()
return normalize_update_status(raw)
def list_model_aliases_model(self) -> ModelAliasesList:
raw = self.list_model_aliases()
return normalize_model_aliases(raw)
def list_model_fallbacks_model(self) -> ModelFallbacksList:
raw = self.list_model_fallbacks()
return normalize_model_fallbacks(raw)
def list_model_image_fallbacks_model(self) -> ModelFallbacksList:
raw = self.list_model_image_fallbacks()
return normalize_model_fallbacks(raw)
def skill_update_model(self, *, slug: str | None = None, all: bool = False) -> SkillUpdateResult:
raw = self.skill_update(slug=slug, all=all)
return normalize_skill_update(raw)
def models_status_model(self, *, probe: bool = False) -> dict[str, Any]:
"""Read `openclaw models status --json` and return the raw dict."""
return self.models_status(probe=probe)
def channels_status_model(self, *, probe: bool = False) -> dict[str, Any]:
"""Read `openclaw channels status --json` and return the raw dict."""
return self.channels_status(probe=probe)
def channels_list_model(self) -> dict[str, Any]:
"""Read `openclaw channels list --json` and return the raw dict."""
return self.channels_list()
def hook_info_model(self, name: str) -> dict[str, Any]:
"""Read `openclaw hooks info <name> --json` and return the raw dict."""
return self.hook_info(name)
def hooks_check_model(self) -> dict[str, Any]:
"""Read `openclaw hooks check --json` and return the raw dict."""
return self.hooks_check()
def plugins_inspect_model(self, *, plugin_id: str | None = None, all: bool = False) -> dict[str, Any]:
"""Read `openclaw plugins inspect --json [--all]` and return the raw dict."""
return self.plugins_inspect(plugin_id=plugin_id, all=all)
def agents_bindings(self, *, agent: str | None = None) -> dict[str, Any]:
"""Read `openclaw agents bindings --json [--agent <id>]`."""
args = ["agents", "bindings", "--json"]
if agent:
args.extend(["--agent", agent])
return self.run_json(args)
def agents_bindings_model(self, *, agent: str | None = None) -> dict[str, Any]:
"""Read `openclaw agents bindings --json` and return the raw dict."""
return self.agents_bindings(agent=agent)
def agents_presence(self) -> dict[str, Any]:
"""Read session presence for all agents from runtime session files.
Reads ~/.openclaw/agents/{agentId}/sessions/sessions.json for each agent
and counts sessions in active states within a recency window.
"""
import json
from pathlib import Path
openclaw_home = Path.home() / ".openclaw"
agents_path = openclaw_home / "agents"
if not agents_path.exists():
return {"status": "not_connected", "agents": {}}
ACTIVE_STATES = {
"running", "active", "busy", "blocked", "waiting_approval",
"working", "in_progress", "processing", "thinking", "executing", "streaming",
}
RECENCY_WINDOW_MS = 45 * 60 * 1000 # 45 minutes
result: dict[str, Any] = {"status": "connected", "agents": {}}
try:
for agent_dir in agents_path.iterdir():
if not agent_dir.is_dir():
continue
sessions_file = agent_dir / "sessions" / "sessions.json"
if not sessions_file.exists():
continue
try:
sessions_data = json.loads(sessions_file.read_text())
except (json.JSONDecodeError, OSError):
continue
sessions = sessions_data if isinstance(sessions_data, list) else []
now_ms = 0 # placeholder; we'll skip recency check if no ts field
active_count = 0
for session in sessions:
if not isinstance(session, dict):
continue
state = str(session.get("state") or session.get("status") or "").lower()
if state in ACTIVE_STATES:
active_count += 1
if active_count > 0:
result["agents"][agent_dir.name] = {
"activeSessions": active_count,
"status": "active",
}
else:
result["agents"][agent_dir.name] = {
"activeSessions": 0,
"status": "idle",
}
except OSError:
result["status"] = "partial"
return result
def agents_from_config(self) -> dict[str, Any]:
"""Read agent list directly from openclaw.json config file.
Falls back to scanning ~/.openclaw/agents/ directories when config is absent.
This avoids the CLI timeout from `agents list --json`.
"""
import json
openclaw_home = Path.home() / ".openclaw"
config_path = openclaw_home / "openclaw.json"
if not config_path.exists():
return {"status": "not_connected", "agents": []}
try:
raw = json.loads(config_path.read_text())
except (json.JSONDecodeError, OSError):
return {"status": "partial", "agents": []}
agents_list = raw.get("agents", {}).get("list", [])
if not agents_list:
return {"status": "partial", "agents": [], "detail": "agents.list is empty"}
agents = []
for entry in agents_list:
if not isinstance(entry, dict):
continue
agent_id = entry.get("id", "").strip()
if not agent_id:
continue
agents.append({
"id": agent_id,
"name": entry.get("name", "").strip() or agent_id,
"model": entry.get("model") or "",
"workspace": entry.get("workspace") or "",
"is_default": entry.get("id") == raw.get("agents", {}).get("defaults", {}).get("id"),
})
return {"status": "connected", "agents": agents}
def gateway_status(self, *, url: str | None = None, token: str | None = None) -> dict[str, Any]:
"""Read `openclaw gateway status --json [--url <url>] [--token <token>]`. May fail if gateway is unreachable."""
args = ["gateway", "status", "--json"]
if url:
args.extend(["--url", url])
if token:
args.extend(["--token", token])
return self.run_json(args)
def memory_status(self, *, agent: str | None = None, deep: bool = False) -> dict[str, Any]:
"""Read `openclaw memory status --json [--agent <id>] [--deep]`. Returns array of per-agent status."""
args = ["memory", "status", "--json"]
if agent:
args.extend(["--agent", agent])
if deep:
args.append("--deep")
return self.run_json(args)
# -------------------------------------------------------------------------
# Write agents commands
# -------------------------------------------------------------------------
def agents_add(
self,
name: str,
*,
workspace: str | None = None,
model: str | None = None,
agent_dir: str | None = None,
bind: list[str] | None = None,
non_interactive: bool = False,
) -> dict[str, Any]:
"""Run `openclaw agents add <name> [--workspace <dir>] [--model <id>] [--agent-dir <dir>] [--bind <spec>] [--non-interactive] --json`."""
args = ["agents", "add", name, "--json"]
if workspace:
args.extend(["--workspace", workspace])
if model:
args.extend(["--model", model])
if agent_dir:
args.extend(["--agent-dir", agent_dir])
if bind:
for b in bind:
args.extend(["--bind", b])
if non_interactive:
args.append("--non-interactive")
return self.run_json(args)
def agents_delete(self, id: str, *, force: bool = False) -> dict[str, Any]:
"""Run `openclaw agents delete <id> [--force] --json`."""
args = ["agents", "delete", id, "--json"]
if force:
args.append("--force")
return self.run_json(args)
def agents_bind(
self,
*,
agent: str | None = None,
bind: list[str] | None = None,
) -> dict[str, Any]:
"""Run `openclaw agents bind [--agent <id>] [--bind <spec>] --json`."""
args = ["agents", "bind", "--json"]
if agent:
args.extend(["--agent", agent])
if bind:
for b in bind:
args.extend(["--bind", b])
return self.run_json(args)
def agents_unbind(
self,
*,
agent: str | None = None,
bind: list[str] | None = None,
all: bool = False,
) -> dict[str, Any]:
"""Run `openclaw agents unbind [--agent <id>] [--bind <spec>] [--all] --json`."""
args = ["agents", "unbind", "--json"]
if agent:
args.extend(["--agent", agent])
if bind:
for b in bind:
args.extend(["--bind", b])
if all:
args.append("--all")
return self.run_json(args)
def agents_set_identity(
self,
*,
agent: str | None = None,
workspace: str | None = None,
identity_file: str | None = None,
name: str | None = None,
emoji: str | None = None,
theme: str | None = None,
avatar: str | None = None,
from_identity: bool = False,
) -> dict[str, Any]:
"""Run `openclaw agents set-identity [--agent <id>] [--workspace <dir>] [--identity-file <path>] [--from-identity] [--name <n>] [--emoji <e>] [--theme <t>] [--avatar <a>] --json`."""
args = ["agents", "set-identity", "--json"]
if agent:
args.extend(["--agent", agent])
if workspace:
args.extend(["--workspace", workspace])
if identity_file:
args.extend(["--identity-file", identity_file])
if from_identity:
args.append("--from-identity")
if name:
args.extend(["--name", name])
if emoji:
args.extend(["--emoji", emoji])
if theme:
args.extend(["--theme", theme])
if avatar:
args.extend(["--avatar", avatar])
return self.run_json(args)
def run_json(self, args: list[str]) -> dict[str, Any]:
"""Run the CLI and decode JSON stdout, falling back to stderr."""
result = self.run(args)
text = result.stdout.strip() or result.stderr.strip()
if not text:
return {}
return json.loads(text)
def run(self, args: list[str]) -> OpenClawCliResult:
"""Run the CLI and return stdout/stderr."""
command = [*self.base_command, *args]
env = os.environ.copy()
try:
completed = subprocess.run(
command,
cwd=self.cwd,
env=env,
capture_output=True,
text=True,
timeout=self.timeout_seconds,
check=False,
)
except FileNotFoundError as exc:
raise OpenClawCliError(
"OpenClaw CLI executable was not found.",
command=command,
) from exc
except subprocess.TimeoutExpired as exc:
raise OpenClawCliError(
f"OpenClaw CLI timed out after {self.timeout_seconds:.1f}s.",
command=command,
stdout=exc.stdout or "",
stderr=exc.stderr or "",
) from exc
if completed.returncode != 0:
raise OpenClawCliError(
"OpenClaw CLI command failed.",
command=command,
exit_code=completed.returncode,
stdout=completed.stdout,
stderr=completed.stderr,
)
return OpenClawCliResult(
command=command,
exit_code=completed.returncode,
stdout=completed.stdout,
stderr=completed.stderr,
)

View File

@@ -0,0 +1,280 @@
# -*- coding: utf-8 -*-
"""Query-oriented storage for explain/research data."""
from __future__ import annotations
import json
import sqlite3
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable
from shared.schema import CompanyNews
SCHEMA = """
CREATE TABLE IF NOT EXISTS news_items (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
published_at TEXT,
trade_date TEXT,
source TEXT,
title TEXT NOT NULL,
summary TEXT,
url TEXT,
related TEXT,
category TEXT,
raw_json TEXT NOT NULL,
ingest_run_date TEXT,
created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_news_items_ticker_date
ON news_items (ticker, trade_date DESC, published_at DESC);
"""
def _json_dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
def _resolve_news_id(ticker: str, item: CompanyNews, fallback_index: int) -> str:
base = item.url or item.title or f"{ticker}-{fallback_index}"
return f"{ticker}:{base}"
def _resolve_trade_date(date_value: str | None) -> str | None:
if not date_value:
return None
normalized = str(date_value).strip()
if not normalized:
return None
if "T" in normalized:
return normalized.split("T", 1)[0]
if " " in normalized:
return normalized.split(" ", 1)[0]
return normalized[:10]
class ResearchDb:
"""Small SQLite helper for explain-oriented news storage."""
def __init__(self, db_path: Path):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
return conn
def _init_db(self):
with self._connect() as conn:
conn.executescript(SCHEMA)
def upsert_news_items(
self,
*,
ticker: str,
items: Iterable[CompanyNews],
ingest_run_date: str | None = None,
) -> list[dict[str, Any]]:
"""Persist provider news and return normalized rows."""
normalized_rows: list[dict[str, Any]] = []
timestamp = datetime.utcnow().isoformat(timespec="seconds")
symbol = str(ticker or "").strip().upper()
if not symbol:
return normalized_rows
with self._connect() as conn:
for index, item in enumerate(items):
news_id = _resolve_news_id(symbol, item, index)
trade_date = _resolve_trade_date(item.date)
payload = item.model_dump()
row = {
"id": news_id,
"ticker": symbol,
"published_at": item.date,
"trade_date": trade_date,
"source": item.source,
"title": item.title,
"summary": item.summary,
"url": item.url,
"related": item.related,
"category": item.category,
"raw_json": _json_dumps(payload),
"ingest_run_date": ingest_run_date,
"created_at": timestamp,
}
conn.execute(
"""
INSERT INTO news_items
(id, ticker, published_at, trade_date, source, title, summary, url,
related, category, raw_json, ingest_run_date, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
ticker = excluded.ticker,
published_at = excluded.published_at,
trade_date = excluded.trade_date,
source = excluded.source,
title = excluded.title,
summary = excluded.summary,
url = excluded.url,
related = excluded.related,
category = excluded.category,
raw_json = excluded.raw_json,
ingest_run_date = excluded.ingest_run_date
""",
(
row["id"],
row["ticker"],
row["published_at"],
row["trade_date"],
row["source"],
row["title"],
row["summary"],
row["url"],
row["related"],
row["category"],
row["raw_json"],
row["ingest_run_date"],
row["created_at"],
),
)
normalized_rows.append(row)
return normalized_rows
def get_news_items(
self,
*,
ticker: str,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 20,
) -> list[dict[str, Any]]:
"""Return normalized news rows for explain UI."""
symbol = str(ticker or "").strip().upper()
if not symbol:
return []
sql = """
SELECT id, ticker, published_at, trade_date, source, title, summary,
url, related, category
FROM news_items
WHERE ticker = ?
"""
params: list[Any] = [symbol]
if start_date:
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
params.append(start_date)
if end_date:
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
params.append(end_date)
sql += " ORDER BY COALESCE(published_at, trade_date) DESC LIMIT ?"
params.append(max(1, int(limit)))
with self._connect() as conn:
rows = conn.execute(sql, params).fetchall()
return [
{
"id": row["id"],
"ticker": row["ticker"],
"date": row["published_at"] or row["trade_date"],
"trade_date": row["trade_date"],
"source": row["source"],
"title": row["title"],
"summary": row["summary"],
"url": row["url"],
"related": row["related"],
"category": row["category"],
}
for row in rows
]
def get_news_timeline(
self,
*,
ticker: str,
start_date: str | None = None,
end_date: str | None = None,
) -> list[dict[str, Any]]:
"""Aggregate news counts per trade date for chart markers."""
symbol = str(ticker or "").strip().upper()
if not symbol:
return []
sql = """
SELECT COALESCE(trade_date, substr(published_at, 1, 10)) AS date,
COUNT(*) AS count,
COUNT(DISTINCT source) AS source_count,
MAX(title) AS top_title
FROM news_items
WHERE ticker = ?
"""
params: list[Any] = [symbol]
if start_date:
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
params.append(start_date)
if end_date:
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
params.append(end_date)
sql += """
GROUP BY COALESCE(trade_date, substr(published_at, 1, 10))
ORDER BY date ASC
"""
with self._connect() as conn:
rows = conn.execute(sql, params).fetchall()
return [
{
"date": row["date"],
"count": int(row["count"] or 0),
"source_count": int(row["source_count"] or 0),
"top_title": row["top_title"] or "",
}
for row in rows
if row["date"]
]
def get_news_by_ids(
self,
*,
ticker: str,
article_ids: Iterable[str],
) -> list[dict[str, Any]]:
"""Return selected persisted news items."""
symbol = str(ticker or "").strip().upper()
ids = [str(article_id).strip() for article_id in article_ids if str(article_id).strip()]
if not symbol or not ids:
return []
placeholders = ",".join("?" for _ in ids)
sql = f"""
SELECT id, ticker, published_at, trade_date, source, title, summary,
url, related, category
FROM news_items
WHERE ticker = ? AND id IN ({placeholders})
ORDER BY COALESCE(published_at, trade_date) DESC
"""
with self._connect() as conn:
rows = conn.execute(sql, [symbol, *ids]).fetchall()
return [
{
"id": row["id"],
"ticker": row["ticker"],
"date": row["published_at"] or row["trade_date"],
"trade_date": row["trade_date"],
"source": row["source"],
"title": row["title"],
"summary": row["summary"],
"url": row["url"],
"related": row["related"],
"category": row["category"],
}
for row in rows
]

View File

@@ -0,0 +1,512 @@
# -*- coding: utf-8 -*-
"""Run-scoped SQLite storage for query-oriented runtime history."""
from __future__ import annotations
import hashlib
import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, Iterable, Optional
SCHEMA = """
CREATE TABLE IF NOT EXISTS events (
id TEXT PRIMARY KEY,
event_type TEXT NOT NULL,
timestamp TEXT,
agent_id TEXT,
agent_name TEXT,
ticker TEXT,
title TEXT,
content TEXT,
payload_json TEXT NOT NULL,
run_date TEXT
);
CREATE INDEX IF NOT EXISTS idx_events_type_time ON events(event_type, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_events_ticker_time ON events(ticker, timestamp DESC);
CREATE TABLE IF NOT EXISTS trades (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
side TEXT,
qty REAL,
price REAL,
timestamp TEXT,
trading_date TEXT,
agent_id TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_trades_ticker_time ON trades(ticker, timestamp DESC);
CREATE TABLE IF NOT EXISTS signals (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
agent_id TEXT,
agent_name TEXT,
role TEXT,
signal TEXT,
confidence REAL,
reasoning_json TEXT,
reasons_json TEXT,
risks_json TEXT,
invalidation TEXT,
next_action TEXT,
intrinsic_value REAL,
fair_value_range_json TEXT,
value_gap_pct REAL,
valuation_methods_json TEXT,
real_return REAL,
is_correct TEXT,
trade_date TEXT,
created_at TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_signals_ticker_date ON signals(ticker, trade_date DESC);
CREATE INDEX IF NOT EXISTS idx_signals_agent_date ON signals(agent_id, trade_date DESC);
CREATE TABLE IF NOT EXISTS price_points (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
timestamp TEXT NOT NULL,
price REAL NOT NULL,
open_price REAL,
ret REAL,
source TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_price_points_ticker_time ON price_points(ticker, timestamp DESC);
"""
def _json_dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
def _hash_key(*parts: Any) -> str:
raw = "::".join("" if part is None else str(part) for part in parts)
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
class RuntimeDb:
"""Small SQLite helper for append-mostly runtime data."""
def __init__(self, db_path: Path):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
return conn
def _init_db(self):
with self._connect() as conn:
conn.executescript(SCHEMA)
def insert_event(self, event: Dict[str, Any]):
payload = dict(event or {})
if not payload:
return
event_id = payload.get("id") or _hash_key(
payload.get("type"),
payload.get("timestamp"),
payload.get("agentId") or payload.get("agent_id"),
payload.get("content"),
payload.get("title"),
)
ticker = payload.get("ticker")
if not ticker and isinstance(payload.get("tickers"), list) and len(payload["tickers"]) == 1:
ticker = payload["tickers"][0]
with self._connect() as conn:
conn.execute(
"""
INSERT OR IGNORE INTO events
(id, event_type, timestamp, agent_id, agent_name, ticker, title, content, payload_json, run_date)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
event_id,
payload.get("type"),
payload.get("timestamp"),
payload.get("agentId") or payload.get("agent_id"),
payload.get("agentName") or payload.get("agent_name"),
ticker,
payload.get("title"),
payload.get("content"),
_json_dumps(payload),
payload.get("date") or payload.get("trading_date") or payload.get("run_date"),
),
)
def get_recent_feed_events(
self,
*,
limit: int = 200,
event_types: Optional[Iterable[str]] = None,
) -> list[Dict[str, Any]]:
"""Return recent persisted feed events in newest-first order."""
event_types = tuple(event_types or ())
sql = """
SELECT payload_json
FROM events
"""
params: list[Any] = []
if event_types:
placeholders = ",".join("?" for _ in event_types)
sql += f" WHERE event_type IN ({placeholders})"
params.extend(event_types)
sql += " ORDER BY timestamp DESC LIMIT ?"
params.append(max(1, int(limit)))
with self._connect() as conn:
rows = conn.execute(sql, params).fetchall()
items: list[Dict[str, Any]] = []
for row in rows:
try:
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
except json.JSONDecodeError:
payload = {}
if payload:
items.append(payload)
return items
def get_last_day_feed_events(
self,
*,
current_date: Optional[str] = None,
limit: int = 200,
event_types: Optional[Iterable[str]] = None,
) -> list[Dict[str, Any]]:
"""Return latest trading day events in newest-first order for replay."""
event_types = tuple(event_types or ())
target_date = str(current_date or "").strip() or None
with self._connect() as conn:
if not target_date:
row = conn.execute(
"""
SELECT run_date
FROM events
WHERE run_date IS NOT NULL AND TRIM(run_date) != ''
ORDER BY run_date DESC
LIMIT 1
"""
).fetchone()
target_date = row["run_date"] if row else None
if not target_date:
return []
sql = """
SELECT payload_json
FROM events
WHERE run_date = ?
"""
params: list[Any] = [target_date]
if event_types:
placeholders = ",".join("?" for _ in event_types)
sql += f" AND event_type IN ({placeholders})"
params.extend(event_types)
sql += " ORDER BY timestamp DESC LIMIT ?"
params.append(max(1, int(limit)))
rows = conn.execute(sql, params).fetchall()
items: list[Dict[str, Any]] = []
for row in rows:
try:
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
except json.JSONDecodeError:
payload = {}
if payload:
items.append(payload)
return items
def upsert_trade(self, trade: Dict[str, Any]):
payload = dict(trade or {})
if not payload:
return
trade_id = payload.get("id") or _hash_key(
payload.get("ticker"),
payload.get("timestamp") or payload.get("ts"),
payload.get("side"),
payload.get("qty"),
payload.get("price"),
)
with self._connect() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO trades
(id, ticker, side, qty, price, timestamp, trading_date, agent_id, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
trade_id,
payload.get("ticker"),
payload.get("side"),
payload.get("qty"),
payload.get("price"),
payload.get("timestamp") or payload.get("ts"),
payload.get("trading_date"),
payload.get("agentId") or payload.get("agent_id"),
_json_dumps(payload),
),
)
def upsert_signal(self, signal: Dict[str, Any], *, agent_id: str, agent_name: str, role: str):
payload = dict(signal or {})
ticker = payload.get("ticker")
if not ticker:
return
signal_id = _hash_key(
agent_id,
ticker,
payload.get("date"),
payload.get("signal"),
payload.get("confidence"),
)
with self._connect() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO signals
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
reasons_json, risks_json, invalidation, next_action, intrinsic_value,
fair_value_range_json, value_gap_pct, valuation_methods_json,
real_return, is_correct, trade_date, created_at, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
signal_id,
ticker,
agent_id,
agent_name,
role,
payload.get("signal"),
payload.get("confidence"),
_json_dumps(payload.get("reasoning")),
_json_dumps(payload.get("reasons")),
_json_dumps(payload.get("risks")),
payload.get("invalidation"),
payload.get("next_action"),
payload.get("intrinsic_value"),
_json_dumps(payload.get("fair_value_range")),
payload.get("value_gap_pct"),
_json_dumps(payload.get("valuation_methods")),
payload.get("real_return"),
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
payload.get("date"),
payload.get("created_at") or payload.get("date"),
_json_dumps(payload),
),
)
def replace_signals_for_leaderboard(self, leaderboard: Iterable[Dict[str, Any]]):
with self._connect() as conn:
conn.execute("DELETE FROM signals")
for agent in leaderboard:
agent_id = agent.get("agentId")
agent_name = agent.get("name")
role = agent.get("role")
for signal in agent.get("signals", []) or []:
payload = dict(signal or {})
ticker = payload.get("ticker")
if not ticker:
continue
signal_id = _hash_key(
agent_id,
ticker,
payload.get("date"),
payload.get("signal"),
payload.get("confidence"),
)
conn.execute(
"""
INSERT INTO signals
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
reasons_json, risks_json, invalidation, next_action, intrinsic_value,
fair_value_range_json, value_gap_pct, valuation_methods_json,
real_return, is_correct, trade_date, created_at, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
signal_id,
ticker,
agent_id,
agent_name,
role,
payload.get("signal"),
payload.get("confidence"),
_json_dumps(payload.get("reasoning")),
_json_dumps(payload.get("reasons")),
_json_dumps(payload.get("risks")),
payload.get("invalidation"),
payload.get("next_action"),
payload.get("intrinsic_value"),
_json_dumps(payload.get("fair_value_range")),
payload.get("value_gap_pct"),
_json_dumps(payload.get("valuation_methods")),
payload.get("real_return"),
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
payload.get("date"),
payload.get("created_at") or payload.get("date"),
_json_dumps(payload),
),
)
def insert_price_point(
self,
*,
ticker: str,
timestamp: str,
price: float,
open_price: Optional[float] = None,
ret: Optional[float] = None,
source: Optional[str] = None,
meta: Optional[Dict[str, Any]] = None,
):
price_id = _hash_key(ticker, timestamp, price, open_price, ret)
with self._connect() as conn:
conn.execute(
"""
INSERT OR IGNORE INTO price_points
(id, ticker, timestamp, price, open_price, ret, source, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
price_id,
ticker,
timestamp,
price,
open_price,
ret,
source,
_json_dumps(meta or {}),
),
)
def get_stock_explain_snapshot(
self,
ticker: str,
*,
limit_events: int = 24,
limit_trades: int = 12,
limit_signals: int = 12,
) -> Dict[str, list[Dict[str, Any]]]:
"""Fetch query-oriented history for a single ticker."""
symbol = str(ticker or "").strip().upper()
if not symbol:
return {"events": [], "trades": [], "signals": []}
with self._connect() as conn:
trade_rows = conn.execute(
"""
SELECT * FROM trades
WHERE ticker = ?
ORDER BY timestamp DESC
LIMIT ?
""",
(symbol, limit_trades),
).fetchall()
signal_rows = conn.execute(
"""
SELECT * FROM signals
WHERE ticker = ?
ORDER BY trade_date DESC, created_at DESC
LIMIT ?
""",
(symbol, limit_signals),
).fetchall()
event_rows = conn.execute(
"""
SELECT * FROM events
WHERE payload_json LIKE ? OR content LIKE ? OR title LIKE ? OR ticker = ?
ORDER BY timestamp DESC
LIMIT ?
""",
(f"%{symbol}%", f"%{symbol}%", f"%{symbol}%", symbol, limit_events * 3),
).fetchall()
normalized_events = []
seen_event_ids: set[str] = set()
for row in event_rows:
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
content = str(row["content"] or payload.get("content") or "")
title = str(row["title"] or payload.get("title") or "")
if symbol not in f"{title} {content}".upper() and str(row["ticker"] or "").upper() != symbol:
continue
event_id = row["id"]
if event_id in seen_event_ids:
continue
seen_event_ids.add(event_id)
normalized_events.append(
{
"id": event_id,
"type": "mention",
"timestamp": row["timestamp"],
"title": title or f"{row['agent_name'] or '未知角色'}提及 {symbol}",
"meta": payload.get("conferenceTitle")
or payload.get("feedType")
or row["event_type"],
"body": content,
"tone": "neutral",
"agent": row["agent_name"] or payload.get("agentName") or payload.get("agent"),
},
)
if len(normalized_events) >= limit_events:
break
normalized_trades = [
{
"id": row["id"],
"type": "trade",
"timestamp": row["timestamp"],
"title": f"{row['side']} {int(row['qty'] or 0)}",
"meta": "交易执行",
"body": f"成交价 ${float(row['price'] or 0):.2f}",
"tone": "positive" if row["side"] == "LONG" else "negative" if row["side"] == "SHORT" else "neutral",
}
for row in trade_rows
]
normalized_signals = [
{
"id": row["id"],
"type": "signal",
"timestamp": f"{row['trade_date']}T08:00:00" if row["trade_date"] else row["created_at"],
"title": f"{row['agent_name']} 给出{row['signal'] or '中性'}信号",
"meta": row["role"],
"body": (
f"后验收益 {float(row['real_return']) * 100:+.2f}%"
if row["real_return"] is not None
else "该信号暂未完成后验评估"
),
"tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral",
# Extended signal fields
"signal": row["signal"],
"confidence": row["confidence"],
"reasoning": json.loads(row["reasoning_json"]) if row["reasoning_json"] else None,
"reasons": json.loads(row["reasons_json"]) if row["reasons_json"] else None,
"risks": json.loads(row["risks_json"]) if row["risks_json"] else None,
"invalidation": row["invalidation"],
"next_action": row["next_action"],
"intrinsic_value": row["intrinsic_value"],
"fair_value_range": json.loads(row["fair_value_range_json"]) if row["fair_value_range_json"] else None,
"value_gap_pct": row["value_gap_pct"],
"valuation_methods": json.loads(row["valuation_methods_json"]) if row["valuation_methods_json"] else None,
}
for row in signal_rows
]
return {
"events": normalized_events,
"trades": normalized_trades,
"signals": normalized_signals,
}

1256
backend/services/storage.py Normal file

File diff suppressed because it is too large Load Diff