Initial commit of integrated agent system
This commit is contained in:
2
backend/services/__init__.py
Normal file
2
backend/services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Services layer for infrastructure components"""
|
||||
923
backend/services/gateway.py
Normal file
923
backend/services/gateway.py
Normal file
@@ -0,0 +1,923 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
WebSocket Gateway for frontend communication
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import websockets
|
||||
from websockets.asyncio.server import ServerConnection
|
||||
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.domains import news as news_domain
|
||||
from backend.llm.models import get_agent_model_info
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.state_sync import StateSync
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.data.provider_router import get_provider_router
|
||||
from backend.tools.technical_signals import StockTechnicalAnalyzer
|
||||
from backend.core.scheduler import Scheduler
|
||||
from backend.services import gateway_admin_handlers
|
||||
from backend.services import gateway_cycle_support
|
||||
from backend.services import gateway_openclaw_handlers
|
||||
from backend.services import gateway_runtime_support
|
||||
from backend.services import gateway_stock_handlers
|
||||
from shared.client import NewsServiceClient
|
||||
from shared.client import TradingServiceClient
|
||||
from shared.client.openclaw_websocket_client import OpenClawWebSocketClient, DEFAULT_GATEWAY_URL as OPENCLAW_WS_URL
|
||||
|
||||
logger = logging.getLogger(__name__)

# Allowlist of agent-workspace files the frontend is permitted to read/edit
# (enforced via _normalize_agent_workspace_filename before any file access).
EDITABLE_AGENT_WORKSPACE_FILES = {
    "SOUL.md",
    "PROFILE.md",
    "AGENTS.md",
    "MEMORY.md",
    "POLICY.md",
}
|
||||
|
||||
|
||||
class Gateway:
|
||||
"""WebSocket Gateway for frontend communication"""
|
||||
|
||||
    def __init__(
        self,
        market_service: MarketService,
        storage_service: StorageService,
        pipeline: TradingPipeline,
        state_sync: Optional[StateSync] = None,
        scheduler_callback: Optional[Callable] = None,
        scheduler: Optional[Scheduler] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Wire the gateway to its collaborating services.

        Args:
            market_service: Push source for live price updates.
            storage_service: Persistence layer (portfolio, price history).
            pipeline: Trading pipeline driven by strategy triggers.
            state_sync: Shared state container; a fresh one backed by
                ``storage_service`` is created when not supplied.
            scheduler_callback: Legacy scheduler entry point, used only
                when ``scheduler`` is not provided (see ``start``).
            scheduler: Scheduler instance that fires trading cycles.
            config: Runtime configuration dict; defaults to empty.
        """
        self.market_service = market_service
        self.storage = storage_service
        self.pipeline = pipeline
        self.scheduler_callback = scheduler_callback
        self.scheduler = scheduler
        self.config = config or {}

        # Backtest is active when either the mode string says so or the
        # legacy "backtest_mode" flag is truthy.
        self.mode = self.config.get("mode", "live")
        self.is_backtest = self.mode == "backtest" or self.config.get(
            "backtest_mode",
            False,
        )

        self.state_sync = state_sync or StateSync(storage=storage_service)
        # self.state_sync.set_mode(self.is_backtest)
        # State changes fan out to all connected WebSocket clients, and the
        # pipeline shares the same state container.
        self.state_sync.set_broadcast_fn(self.broadcast)
        self.pipeline.state_sync = self.state_sync

        # Connection bookkeeping; `lock` guards connected_clients,
        # `_cycle_lock` serializes trading cycles.
        self.connected_clients: Set[ServerConnection] = set()
        self.lock = asyncio.Lock()
        self._cycle_lock = asyncio.Lock()
        self._backtest_task: Optional[asyncio.Task] = None
        self._manual_cycle_task: Optional[asyncio.Task] = None
        self._backtest_start_date: Optional[str] = None
        self._backtest_end_date: Optional[str] = None
        self._market_status_task: Optional[asyncio.Task] = None
        self._watchlist_ingest_task: Optional[asyncio.Task] = None

        # Session tracking for live returns
        self._session_start_portfolio_value: Optional[float] = None
        self._provider_router = get_provider_router()
        # Captured in start(); lets thread-side callbacks schedule coroutines.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._project_root = Path(__file__).resolve().parents[2]
        self._technical_analyzer = StockTechnicalAnalyzer()
        # Connected lazily in start(); None when the OpenClaw gateway is down.
        self._openclaw_ws: OpenClawWebSocketClient | None = None
|
||||
|
||||
    async def start(self, host: str = "0.0.0.0", port: int = 8766):
        """Start gateway server with proper initialization order.

        Phase 1: Start WebSocket server first so frontend can connect immediately
        Phase 2: Start market data service (pushes data to connected clients)
        Phase 3: Start market status monitoring (live mode only)
        Phase 4: Start scheduler last (triggers trading cycles)

        This coroutine never returns under normal operation: it parks on an
        ``asyncio.Future()`` after startup to keep the server alive.
        """
        logger.info(f"Starting gateway on {host}:{port}")
        # Remember the loop so thread-side router callbacks can schedule
        # broadcasts via run_coroutine_threadsafe.
        self._loop = asyncio.get_running_loop()
        self._provider_router.add_listener(self._on_provider_usage_changed)

        # Restore persisted state and seed the state store with the runtime
        # configuration the frontend needs on connect.
        self.state_sync.load_state()
        self.market_service.set_price_recorder(self.storage.record_price_point)
        self.state_sync.update_state("status", "initializing")
        self.state_sync.update_state("server_mode", self.mode)
        self.state_sync.update_state("is_backtest", self.is_backtest)
        self.state_sync.update_state("tickers", self.config.get("tickers", []))
        self.state_sync.update_state(
            "runtime_config",
            {
                "tickers": self.config.get("tickers", []),
                "schedule_mode": self.config.get("schedule_mode", "daily"),
                "interval_minutes": self.config.get("interval_minutes", 60),
                "trigger_time": self.config.get("trigger_time", "09:30"),
                "initial_cash": self.config.get(
                    "initial_cash",
                    self.storage.initial_cash,
                ),
                "margin_requirement": self.config.get("margin_requirement"),
                "max_comm_cycles": self.config.get("max_comm_cycles"),
                "enable_memory": self.config.get("enable_memory", False),
            },
        )
        self.state_sync.update_state(
            "data_sources",
            self._provider_router.get_usage_snapshot(),
        )

        # Load and display existing portfolio state if available
        dashboard_snapshot = self.storage.build_dashboard_snapshot_from_state(self.state_sync.state)
        summary = dashboard_snapshot.get("summary")
        if summary:
            logger.info(
                "Loaded existing portfolio: $%s",
                f"{summary.get('totalAssetValue', 0):,.2f}",
            )

        # ======================================================================
        # PHASE 1: Start WebSocket server first
        # This allows frontend to connect immediately and receive status updates
        # ======================================================================
        logger.info("[Phase 1/4] Starting WebSocket server...")
        self.state_sync.update_state("status", "websocket_ready")

        # websockets.serve() returns a running server; keeping the reference
        # alive is sufficient — the coroutine parks on a Future at the end
        # instead of using the server as a context manager.
        server = await websockets.serve(
            self.handle_client,
            host,
            port,
            ping_interval=30,
            ping_timeout=60,
        )
        logger.info(f"WebSocket server ready: ws://{host}:{port}")

        # Give a brief moment for any existing clients to reconnect
        await asyncio.sleep(0.1)

        # Connect to OpenClaw Gateway (18789) via WebSocket; failure here is
        # non-fatal — OpenClaw features simply stay unavailable.
        logger.info("Connecting to OpenClaw Gateway...")
        try:
            self._openclaw_ws = OpenClawWebSocketClient(
                url=OPENCLAW_WS_URL,
                client_name="gateway-client",
                client_version="1.0.0",
            )
            await self._openclaw_ws.connect()
            logger.info("OpenClaw Gateway WebSocket connected")
        except Exception as e:
            logger.warning("Failed to connect to OpenClaw Gateway: %s", e)
            self._openclaw_ws = None

        # ======================================================================
        # PHASE 2: Start market data service
        # Now frontend is connected, start pushing price updates
        # ======================================================================
        logger.info("[Phase 2/4] Starting market data service...")
        self.state_sync.update_state("status", "market_service_starting")
        await self.market_service.start(broadcast_func=self.broadcast)
        self.state_sync.update_state("status", "market_service_ready")
        logger.info("Market data service ready - price updates active")

        # ======================================================================
        # PHASE 3: Start market status monitoring
        # Monitors market open/close and broadcasts status (live mode only)
        # ======================================================================
        logger.info("[Phase 3/4] Starting market status monitoring...")
        if not self.is_backtest:
            self._market_status_task = asyncio.create_task(
                self._market_status_monitor(),
            )

        # ======================================================================
        # PHASE 4: Start scheduler last
        # Only start trading after everything else is ready
        # ======================================================================
        logger.info("[Phase 4/4] Starting scheduler...")
        self.state_sync.update_state("status", "scheduler_starting")

        if self.scheduler:
            # Wire up heartbeat callback if heartbeat is configured
            heartbeat_interval = self.config.get("heartbeat_interval", 0)
            if heartbeat_interval and heartbeat_interval > 0:
                self.scheduler.set_heartbeat_callback(self.on_heartbeat_trigger)
                logger.info(
                    f"[Heartbeat] Registered heartbeat callback (interval={heartbeat_interval}s)",
                )
            await self.scheduler.start(self.on_strategy_trigger)
        elif self.scheduler_callback:
            # Legacy path: older callers supply a callback instead of a
            # Scheduler instance.
            await self.scheduler_callback(callback=self.on_strategy_trigger)

        self.state_sync.update_state("status", "running")
        logger.info(
            f"Gateway fully operational: ws://{host}:{port}, mode={self.mode}",
        )

        # Keep server running (parks forever; cancelled on shutdown)
        await asyncio.Future()
|
||||
|
||||
    def _on_provider_usage_changed(self, snapshot: Dict[str, Any]):
        """Handle provider routing updates from the shared router.

        Updates the shared state and pushes a ``data_sources_update`` to all
        connected clients.  NOTE(review): the use of
        ``run_coroutine_threadsafe`` suggests the router may invoke this
        listener from a worker thread rather than the event loop — confirm
        against the provider router's contract.
        """
        self.state_sync.update_state("data_sources", snapshot)
        # Only schedule the broadcast when start() has captured a running
        # loop; before that there is nobody to broadcast to.
        if self._loop and self._loop.is_running():
            asyncio.run_coroutine_threadsafe(
                self.broadcast(
                    {
                        "type": "data_sources_update",
                        "data_sources": snapshot,
                    },
                ),
                self._loop,
            )
|
||||
|
||||
    @property
    def state(self) -> Dict[str, Any]:
        """Live view of the shared state dict owned by ``state_sync``."""
        return self.state_sync.state
|
||||
|
||||
    @staticmethod
    def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool:
        """Delegate to the news domain's enrichment check for stored rows."""
        return news_domain.news_rows_need_enrichment(rows)
|
||||
|
||||
def _news_service_url(self) -> str | None:
|
||||
"""Return configured news-service base URL, if any."""
|
||||
candidate = self.config.get("news_service_url") or os.getenv(
|
||||
"NEWS_SERVICE_URL",
|
||||
"",
|
||||
)
|
||||
value = str(candidate or "").strip()
|
||||
return value or None
|
||||
|
||||
def _trading_service_url(self) -> str | None:
|
||||
"""Return configured trading-service base URL, if any."""
|
||||
candidate = self.config.get("trading_service_url") or os.getenv(
|
||||
"TRADING_SERVICE_URL",
|
||||
"",
|
||||
)
|
||||
value = str(candidate or "").strip()
|
||||
return value or None
|
||||
|
||||
async def _call_news_service(
|
||||
self,
|
||||
action: str,
|
||||
callback: Callable[[NewsServiceClient], Any],
|
||||
) -> Any | None:
|
||||
"""Call news-service when configured, otherwise return None."""
|
||||
service_url = self._news_service_url()
|
||||
if not service_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
async with NewsServiceClient(service_url) as client:
|
||||
return await callback(client)
|
||||
except Exception as exc:
|
||||
logger.warning("news-service %s failed: %s", action, exc)
|
||||
return None
|
||||
|
||||
async def _call_trading_service(
|
||||
self,
|
||||
action: str,
|
||||
callback: Callable[[TradingServiceClient], Any],
|
||||
) -> Any | None:
|
||||
"""Call trading-service when configured, otherwise return None."""
|
||||
service_url = self._trading_service_url()
|
||||
if not service_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
async with TradingServiceClient(service_url) as client:
|
||||
return await callback(client)
|
||||
except Exception as exc:
|
||||
logger.warning("trading-service %s failed: %s", action, exc)
|
||||
return None
|
||||
|
||||
async def handle_client(self, websocket: ServerConnection):
|
||||
"""Handle WebSocket client connection"""
|
||||
async with self.lock:
|
||||
self.connected_clients.add(websocket)
|
||||
|
||||
await self._send_initial_state(websocket)
|
||||
await self._handle_client_messages(websocket)
|
||||
|
||||
async with self.lock:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
    async def _send_initial_state(self, websocket: ServerConnection):
        """Send the full state snapshot a newly connected client needs.

        The payload combines the state-sync snapshot (with dashboard data),
        current provider usage, market status, and — when a live session is
        active — live-return figures merged into the portfolio section.
        On failure a generic ``error`` message is sent instead; this method
        never raises.
        """
        try:
            logger.info("[Gateway] Sending initial state to client...")
            state_payload = self.state_sync.get_initial_state_payload(
                include_dashboard=True,
            )
            state_payload["data_sources"] = (
                self._provider_router.get_usage_snapshot()
            )
            # Include market status in initial state
            state_payload[
                "market_status"
            ] = self.market_service.get_market_status()

            # Include live returns if session is active
            if self.storage.is_live_session_active:
                live_returns = self.storage.get_live_returns()
                if "portfolio" in state_payload:
                    state_payload["portfolio"].update(live_returns)

            # default=str stringifies anything json can't encode (e.g.
            # datetimes) rather than failing the whole snapshot.
            await websocket.send(
                json.dumps(
                    {"type": "initial_state", "state": state_payload},
                    ensure_ascii=False,
                    default=str,
                ),
            )
            logger.info("[Gateway] Initial state sent successfully")
        except Exception as e:
            logger.exception(f"[Gateway] Failed to send initial state: {e}")
            # Send error response so client knows something went wrong
            try:
                await websocket.send(
                    json.dumps(
                        {"type": "error", "message": "Failed to load initial state"},
                        ensure_ascii=False,
                    ),
                )
            except Exception as e:
                # Socket is likely already gone; nothing more to do.
                logger.warning(f"Failed to send error response to client: {e}")
|
||||
|
||||
    async def _handle_client_messages(
        self,
        websocket: ServerConnection,
    ):
        """Receive-loop for one client: dispatch messages by their ``type``.

        Unknown message types are silently ignored.  NOTE(review): the
        ``try`` wraps the whole receive loop, so a single malformed JSON
        message terminates message handling for this client without any
        response; consider per-message error handling.  The long if/elif
        chain could become a dispatch table, but is kept verbatim here.
        """
        try:
            async for message in websocket:
                data = json.loads(message)
                msg_type = data.get("type", "unknown")

                # --- connection / state ---
                if msg_type == "ping":
                    await websocket.send(
                        json.dumps(
                            {
                                "type": "pong",
                                "timestamp": datetime.now().isoformat(),
                            },
                            ensure_ascii=False,
                        ),
                    )
                elif msg_type == "get_state":
                    await self._send_initial_state(websocket)
                # --- trading control ---
                elif msg_type == "start_backtest":
                    await self._handle_start_backtest(data)
                elif msg_type == "trigger_strategy":
                    await self._handle_manual_trigger(websocket, data)
                # --- runtime configuration ---
                elif msg_type == "update_runtime_config":
                    await self._handle_update_runtime_config(websocket, data)
                elif msg_type == "reload_runtime_assets":
                    await self._handle_reload_runtime_assets()
                elif msg_type == "update_watchlist":
                    await self._handle_update_watchlist(websocket, data)
                # --- agent skills / workspace ---
                elif msg_type == "get_agent_skills":
                    await self._handle_get_agent_skills(websocket, data)
                elif msg_type == "get_agent_profile":
                    await self._handle_get_agent_profile(websocket, data)
                elif msg_type == "get_skill_detail":
                    await self._handle_get_skill_detail(websocket, data)
                elif msg_type == "create_agent_local_skill":
                    await self._handle_create_agent_local_skill(websocket, data)
                elif msg_type == "update_agent_local_skill":
                    await self._handle_update_agent_local_skill(websocket, data)
                elif msg_type == "delete_agent_local_skill":
                    await self._handle_delete_agent_local_skill(websocket, data)
                elif msg_type == "remove_agent_skill":
                    await self._handle_remove_agent_skill(websocket, data)
                elif msg_type == "update_agent_skill":
                    await self._handle_update_agent_skill(websocket, data)
                elif msg_type == "get_agent_workspace_file":
                    await self._handle_get_agent_workspace_file(websocket, data)
                elif msg_type == "update_agent_workspace_file":
                    await self._handle_update_agent_workspace_file(websocket, data)
                # --- per-stock queries ---
                elif msg_type == "get_stock_history":
                    await self._handle_get_stock_history(websocket, data)
                elif msg_type == "get_stock_explain_events":
                    await self._handle_get_stock_explain_events(websocket, data)
                elif msg_type == "get_stock_news":
                    await self._handle_get_stock_news(websocket, data)
                elif msg_type == "get_stock_news_for_date":
                    await self._handle_get_stock_news_for_date(websocket, data)
                elif msg_type == "get_stock_news_timeline":
                    await self._handle_get_stock_news_timeline(websocket, data)
                elif msg_type == "get_stock_news_categories":
                    await self._handle_get_stock_news_categories(websocket, data)
                elif msg_type == "get_stock_range_explain":
                    await self._handle_get_stock_range_explain(websocket, data)
                elif msg_type == "get_stock_insider_trades":
                    await self._handle_get_stock_insider_trades(websocket, data)
                elif msg_type == "get_stock_story":
                    await self._handle_get_stock_story(websocket, data)
                elif msg_type == "get_stock_similar_days":
                    await self._handle_get_stock_similar_days(websocket, data)
                elif msg_type == "get_stock_technical_indicators":
                    await self._handle_get_stock_technical_indicators(websocket, data)
                elif msg_type == "run_stock_enrich":
                    await self._handle_run_stock_enrich(websocket, data)
                # --- OpenClaw queries routed through wrapper methods ---
                elif msg_type == "get_openclaw_status":
                    await self._handle_get_openclaw_status(websocket, data)
                elif msg_type == "get_openclaw_sessions":
                    await self._handle_get_openclaw_sessions(websocket, data)
                elif msg_type == "get_openclaw_session_detail":
                    await self._handle_get_openclaw_session_detail(websocket, data)
                elif msg_type == "get_openclaw_session_history":
                    await self._handle_get_openclaw_session_history(websocket, data)
                elif msg_type == "get_openclaw_cron":
                    await self._handle_get_openclaw_cron(websocket, data)
                elif msg_type == "get_openclaw_approvals":
                    await self._handle_get_openclaw_approvals(websocket, data)
                elif msg_type == "get_openclaw_agents":
                    await self._handle_get_openclaw_agents(websocket, data)
                elif msg_type == "get_openclaw_agents_presence":
                    await self._handle_get_openclaw_agents_presence(websocket, data)
                elif msg_type == "get_openclaw_skills":
                    await self._handle_get_openclaw_skills(websocket, data)
                elif msg_type == "get_openclaw_models":
                    await self._handle_get_openclaw_models(websocket, data)
                # --- OpenClaw queries routed straight to the handler module
                # (NOTE(review): inconsistent with the wrapper-method style
                # above; presumably newer additions) ---
                elif msg_type == "get_openclaw_hooks":
                    await gateway_openclaw_handlers.handle_get_openclaw_hooks(self, websocket, data)
                elif msg_type == "get_openclaw_plugins":
                    await gateway_openclaw_handlers.handle_get_openclaw_plugins(self, websocket, data)
                elif msg_type == "get_openclaw_secrets_audit":
                    await gateway_openclaw_handlers.handle_get_openclaw_secrets_audit(self, websocket, data)
                elif msg_type == "get_openclaw_security_audit":
                    await gateway_openclaw_handlers.handle_get_openclaw_security_audit(self, websocket, data)
                elif msg_type == "get_openclaw_daemon_status":
                    await gateway_openclaw_handlers.handle_get_openclaw_daemon_status(self, websocket, data)
                elif msg_type == "get_openclaw_pairing":
                    await gateway_openclaw_handlers.handle_get_openclaw_pairing(self, websocket, data)
                elif msg_type == "get_openclaw_qr":
                    await gateway_openclaw_handlers.handle_get_openclaw_qr(self, websocket, data)
                elif msg_type == "get_openclaw_update_status":
                    await gateway_openclaw_handlers.handle_get_openclaw_update_status(self, websocket, data)
                elif msg_type == "get_openclaw_models_aliases":
                    await gateway_openclaw_handlers.handle_get_openclaw_models_aliases(self, websocket, data)
                elif msg_type == "get_openclaw_models_fallbacks":
                    await gateway_openclaw_handlers.handle_get_openclaw_models_fallbacks(self, websocket, data)
                elif msg_type == "get_openclaw_models_image_fallbacks":
                    await gateway_openclaw_handlers.handle_get_openclaw_models_image_fallbacks(self, websocket, data)
                elif msg_type == "get_openclaw_skill_update":
                    await gateway_openclaw_handlers.handle_get_openclaw_skill_update(self, websocket, data)
                elif msg_type == "get_openclaw_workspace_files":
                    await gateway_openclaw_handlers.handle_get_openclaw_workspace_files(self, websocket, data)
                elif msg_type == "get_openclaw_workspace_file":
                    await gateway_openclaw_handlers.handle_get_openclaw_workspace_file(self, websocket, data)
                # --- OpenClaw session lifecycle ---
                elif msg_type == "openclaw_resolve_session":
                    await gateway_openclaw_handlers.handle_openclaw_resolve_session(self, websocket, data)
                elif msg_type == "openclaw_create_session":
                    await gateway_openclaw_handlers.handle_openclaw_create_session(self, websocket, data)
                elif msg_type == "openclaw_send_message":
                    await gateway_openclaw_handlers.handle_openclaw_send_message(self, websocket, data)
                elif msg_type == "openclaw_subscribe_session":
                    await gateway_openclaw_handlers.handle_openclaw_subscribe_session(self, websocket, data)
                elif msg_type == "openclaw_unsubscribe_session":
                    await gateway_openclaw_handlers.handle_openclaw_unsubscribe_session(self, websocket, data)
                elif msg_type == "openclaw_reset_session":
                    await gateway_openclaw_handlers.handle_openclaw_reset_session(self, websocket, data)
                elif msg_type == "openclaw_delete_session":
                    await gateway_openclaw_handlers.handle_openclaw_delete_session(self, websocket, data)

        except websockets.ConnectionClosed:
            # Normal disconnect path.
            pass
        except json.JSONDecodeError:
            # Malformed frame; see NOTE in the docstring.
            pass
        finally:
            # Drop any OpenClaw session subscriptions held for this socket.
            # getattr() because the map is created lazily by the openclaw
            # subscribe handler, not in __init__.
            subscriber_map = getattr(self, "_openclaw_session_subscribers", None)
            if isinstance(subscriber_map, dict):
                subscriber_map.pop(websocket, None)
|
||||
|
||||
    # Thin wrappers: per-stock queries delegate to gateway_stock_handlers,
    # which receives the gateway instance for state/storage access.

    async def _handle_get_stock_history(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_history."""
        await gateway_stock_handlers.handle_get_stock_history(self, websocket, data)

    async def _handle_get_stock_explain_events(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_explain_events."""
        await gateway_stock_handlers.handle_get_stock_explain_events(self, websocket, data)

    async def _handle_get_stock_news(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_news."""
        await gateway_stock_handlers.handle_get_stock_news(self, websocket, data)

    async def _handle_get_stock_news_for_date(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_news_for_date."""
        await gateway_stock_handlers.handle_get_stock_news_for_date(self, websocket, data)

    async def _handle_get_stock_news_timeline(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_news_timeline."""
        await gateway_stock_handlers.handle_get_stock_news_timeline(self, websocket, data)

    async def _handle_get_stock_news_categories(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_news_categories."""
        await gateway_stock_handlers.handle_get_stock_news_categories(self, websocket, data)

    async def _handle_get_stock_range_explain(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_range_explain."""
        await gateway_stock_handlers.handle_get_stock_range_explain(self, websocket, data)

    async def _handle_get_stock_insider_trades(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_insider_trades."""
        await gateway_stock_handlers.handle_get_stock_insider_trades(self, websocket, data)

    async def _handle_get_stock_story(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_story."""
        await gateway_stock_handlers.handle_get_stock_story(self, websocket, data)

    async def _handle_get_stock_similar_days(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_similar_days."""
        await gateway_stock_handlers.handle_get_stock_similar_days(self, websocket, data)

    async def _handle_get_stock_technical_indicators(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_get_stock_technical_indicators."""
        await gateway_stock_handlers.handle_get_stock_technical_indicators(self, websocket, data)

    async def _handle_run_stock_enrich(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Delegate to gateway_stock_handlers.handle_run_stock_enrich."""
        await gateway_stock_handlers.handle_run_stock_enrich(self, websocket, data)
|
||||
|
||||
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
||||
if not self.is_backtest:
|
||||
return
|
||||
dates = data.get("dates", [])
|
||||
if dates and self._backtest_task is None:
|
||||
task = asyncio.create_task(
|
||||
self._run_backtest_dates(dates),
|
||||
)
|
||||
task.add_done_callback(self._handle_backtest_exception)
|
||||
self._backtest_task = task
|
||||
|
||||
    async def _handle_manual_trigger(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Run one live trading cycle on demand.

        Rejected (with an ``error`` message to the requesting client) when
        in backtest mode, or when a cycle is already running — detected via
        the cycle lock or a still-pending manual task.  The cycle itself
        runs as a background task; ``data["date"]`` may pin the cycle date,
        otherwise today's date is used.
        """
        if self.is_backtest:
            await websocket.send(
                json.dumps(
                    {
                        "type": "error",
                        "message": "Manual trigger is only available in live mode.",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # Busy check: either a scheduled cycle holds the lock, or a previous
        # manual cycle has not finished yet.
        if (
            self._cycle_lock.locked()
            or (
                self._manual_cycle_task is not None
                and not self._manual_cycle_task.done()
            )
        ):
            await websocket.send(
                json.dumps(
                    {
                        "type": "error",
                        "message": "A trading cycle is already running.",
                    },
                    ensure_ascii=False,
                ),
            )
            # System log (zh): "a task is already running; manual trigger ignored"
            await self.state_sync.on_system_message("已有任务在运行,已忽略手动触发")
            return

        requested_date = data.get("date")
        # System log (zh): "manual trigger received; starting a new analysis/decision round"
        await self.state_sync.on_system_message("收到手动触发请求,准备开始新一轮分析与决策")
        task = asyncio.create_task(
            self.on_strategy_trigger(
                date=requested_date or datetime.now().strftime("%Y-%m-%d"),
            ),
        )
        # Surface background failures instead of losing them with the task.
        task.add_done_callback(self._handle_manual_cycle_exception)
        self._manual_cycle_task = task
|
||||
|
||||
    # Thin wrappers: admin/runtime/agent-skill operations delegate to
    # gateway_admin_handlers, which receives the gateway instance.

    async def _handle_reload_runtime_assets(self):
        """Delegate to gateway_admin_handlers.handle_reload_runtime_assets."""
        await gateway_admin_handlers.handle_reload_runtime_assets(self)

    async def _handle_update_runtime_config(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_update_runtime_config."""
        await gateway_admin_handlers.handle_update_runtime_config(self, websocket, data)

    async def _handle_update_watchlist(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_update_watchlist."""
        await gateway_admin_handlers.handle_update_watchlist(self, websocket, data)

    async def _handle_get_agent_skills(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_get_agent_skills."""
        await gateway_admin_handlers.handle_get_agent_skills(self, websocket, data)

    async def _handle_get_agent_profile(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_get_agent_profile."""
        await gateway_admin_handlers.handle_get_agent_profile(self, websocket, data)

    async def _handle_get_skill_detail(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_get_skill_detail."""
        await gateway_admin_handlers.handle_get_skill_detail(self, websocket, data)

    async def _handle_create_agent_local_skill(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_create_agent_local_skill."""
        await gateway_admin_handlers.handle_create_agent_local_skill(self, websocket, data)

    async def _handle_update_agent_local_skill(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_update_agent_local_skill."""
        await gateway_admin_handlers.handle_update_agent_local_skill(self, websocket, data)

    async def _handle_delete_agent_local_skill(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_delete_agent_local_skill."""
        await gateway_admin_handlers.handle_delete_agent_local_skill(self, websocket, data)

    async def _handle_remove_agent_skill(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_remove_agent_skill."""
        await gateway_admin_handlers.handle_remove_agent_skill(self, websocket, data)

    async def _handle_update_agent_skill(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_update_agent_skill."""
        await gateway_admin_handlers.handle_update_agent_skill(self, websocket, data)

    async def _handle_get_agent_workspace_file(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_get_agent_workspace_file."""
        await gateway_admin_handlers.handle_get_agent_workspace_file(self, websocket, data)

    async def _handle_update_agent_workspace_file(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Delegate to gateway_admin_handlers.handle_update_agent_workspace_file."""
        await gateway_admin_handlers.handle_update_agent_workspace_file(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_status(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_status(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_sessions(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_sessions(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_session_detail(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_session_detail(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_session_history(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_session_history(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_cron(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_cron(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_approvals(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_approvals(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_agents(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_agents(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_agents_presence(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_agents_presence(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_skills(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await gateway_openclaw_handlers.handle_get_openclaw_skills(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_models(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Delegate the openclaw models request to the handler module."""
    await gateway_openclaw_handlers.handle_get_openclaw_models(self, websocket, data)
|
||||
|
||||
async def _handle_get_openclaw_workspace_files(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Delegate the openclaw workspace-file listing to the handler module."""
    await gateway_openclaw_handlers.handle_get_openclaw_workspace_files(self, websocket, data)
|
||||
|
||||
@staticmethod
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
    """Normalize a raw ticker payload into a clean list (delegates to runtime support)."""
    return gateway_runtime_support.normalize_watchlist(raw_tickers)
|
||||
|
||||
@staticmethod
def _normalize_agent_workspace_filename(raw_name: Any) -> Optional[str]:
    """Validate a workspace filename against the editable-file allowlist.

    Returns the normalized name, or None when the name is not permitted.
    """
    return gateway_runtime_support.normalize_agent_workspace_filename(
        raw_name,
        allowlist=EDITABLE_AGENT_WORKSPACE_FILES,
    )
|
||||
|
||||
def _apply_runtime_config(
    self,
    runtime_config: Dict[str, Any],
) -> Dict[str, Any]:
    """Apply a resolved runtime config to this gateway (delegates to runtime support)."""
    return gateway_runtime_support.apply_runtime_config(self, runtime_config)
|
||||
|
||||
def _sync_runtime_state(self) -> None:
    """Synchronize gateway runtime state (delegates to runtime support)."""
    gateway_runtime_support.sync_runtime_state(self)
|
||||
|
||||
def _schedule_watchlist_market_store_refresh(
    self,
    tickers: List[str],
) -> None:
    """Schedule a background market-store refresh for the given tickers."""
    gateway_cycle_support.schedule_watchlist_market_store_refresh(self, tickers)
|
||||
|
||||
async def _refresh_market_store_for_watchlist(
    self,
    tickers: List[str],
) -> None:
    """Refresh the market store for the given tickers (delegates to cycle support)."""
    await gateway_cycle_support.refresh_market_store_for_watchlist(self, tickers)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]):
    """Fan a message out to every connected client.

    The payload is serialized once up front; per-client sends run
    concurrently and individual failures are swallowed
    (``return_exceptions=True``) so one bad client cannot break the
    broadcast to the rest.
    """
    if not self.connected_clients:
        return

    payload = json.dumps(message, ensure_ascii=False, default=str)

    # Snapshot the client set under the lock so the actual sends can
    # proceed lock-free.
    async with self.lock:
        snapshot = self.connected_clients.copy()
        pending = [self._send_to_client(client, payload) for client in snapshot]

    if pending:
        await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
async def _send_to_client(
    self,
    client: ServerConnection,
    message: str,
):
    """Send one serialized payload to one client.

    A client whose connection has closed is silently removed from the
    connected set so future broadcasts skip it.
    """
    try:
        await client.send(message)
    except websockets.ConnectionClosed:
        # Stale connection: drop it under the lock shared with broadcast().
        async with self.lock:
            self.connected_clients.discard(client)
|
||||
|
||||
async def _market_status_monitor(self):
    """Run the periodic market-status monitor loop (delegates to cycle support)."""
    await gateway_cycle_support.market_status_monitor(self)
|
||||
|
||||
async def _update_and_broadcast_live_returns(self):
    """Compute and broadcast live session returns (delegates to cycle support)."""
    await gateway_cycle_support.update_and_broadcast_live_returns(self)
|
||||
|
||||
async def on_strategy_trigger(self, date: str):
    """Scheduler callback for a trading-cycle trigger (delegates to cycle support)."""
    await gateway_cycle_support.on_strategy_trigger(self, date)
|
||||
|
||||
async def on_heartbeat_trigger(self, date: str):
    """Scheduler callback for a heartbeat trigger (delegates to cycle support)."""
    await gateway_cycle_support.on_heartbeat_trigger(self, date)
|
||||
|
||||
async def _run_backtest_cycle(self, date: str, tickers: List[str]):
    """Run one backtest trading cycle (delegates to cycle support)."""
    await gateway_cycle_support.run_backtest_cycle(self, date, tickers)
|
||||
|
||||
async def _run_live_cycle(self, date: str, tickers: List[str]):
    """Run one live trading cycle (delegates to cycle support)."""
    await gateway_cycle_support.run_live_cycle(self, date, tickers)
|
||||
|
||||
async def _finalize_cycle(self, date: str):
    """Finalize a completed cycle (delegates to cycle support)."""
    await gateway_cycle_support.finalize_cycle(self, date)
|
||||
|
||||
async def _get_market_caps(
    self,
    tickers: List[str],
    date: str,
) -> Dict[str, float]:
    """Fetch market caps for the tickers on a date (delegates to cycle support)."""
    return await gateway_cycle_support.get_market_caps(self, tickers, date)
|
||||
|
||||
async def _broadcast_portfolio_updates(
    self,
    result: Dict[str, Any],
    prices: Dict[str, float],
):
    """Broadcast post-cycle portfolio state to clients (delegates to cycle support)."""
    await gateway_cycle_support.broadcast_portfolio_updates(self, result, prices)
|
||||
|
||||
def _save_cycle_results(
    self,
    result: Dict[str, Any],
    date: str,
    prices: Dict[str, float],
    settlement_result: Optional[Dict[str, Any]] = None,
):
    """Persist one cycle's outputs (delegates to cycle support)."""
    gateway_cycle_support.save_cycle_results(
        self,
        result,
        date,
        prices,
        settlement_result,
    )
|
||||
|
||||
async def _run_backtest_dates(self, dates: List[str]):
    """Run backtest cycles for a list of dates (delegates to cycle support)."""
    await gateway_cycle_support.run_backtest_dates(self, dates)
|
||||
|
||||
def _handle_backtest_exception(self, task: asyncio.Task):
    """Task done-callback for backtest failures (delegates to cycle support)."""
    gateway_cycle_support.handle_backtest_exception(self, task)
|
||||
|
||||
def _handle_manual_cycle_exception(self, task: asyncio.Task):
    """Task done-callback for manual-cycle failures (delegates to cycle support)."""
    gateway_cycle_support.handle_manual_cycle_exception(self, task)
|
||||
|
||||
def set_backtest_dates(self, dates: List[str]):
    """Configure the dates to backtest (delegates to cycle support)."""
    gateway_cycle_support.set_backtest_dates(self, dates)
|
||||
|
||||
def stop(self):
    """Stop the gateway's background activity (delegates to cycle support)."""
    gateway_cycle_support.stop_gateway(self)
|
||||
426
backend/services/gateway_admin_handlers.py
Normal file
426
backend/services/gateway_admin_handlers.py
Normal file
@@ -0,0 +1,426 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Runtime/workspace/skills handlers extracted from the main Gateway module.
|
||||
|
||||
Deprecated note:
|
||||
Agent/workspace/skill read-write operations are being migrated to
|
||||
agent_service REST endpoints. These websocket handlers remain as a
|
||||
compatibility fallback and should not be considered the primary control
|
||||
plane path for frontend reads/writes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import load_agent_profiles
|
||||
from backend.config.bootstrap_config import (
|
||||
get_bootstrap_config_for_run,
|
||||
resolve_runtime_config,
|
||||
update_bootstrap_values_for_run,
|
||||
)
|
||||
from backend.data.market_ingest import ingest_symbols
|
||||
from backend.llm.models import get_agent_model_info
|
||||
|
||||
|
||||
async def handle_reload_runtime_assets(gateway: Any) -> None:
    """Re-resolve the runtime config and hot-reload pipeline runtime assets.

    Re-reads the bootstrap config for the gateway's active run, reloads the
    pipeline's runtime assets, applies the resolved config to the gateway,
    and broadcasts a combined reload result to all clients.
    """
    config_name = gateway.config.get("config_name", "default")
    runtime_config = resolve_runtime_config(
        project_root=gateway._project_root,
        config_name=config_name,
        enable_memory=gateway.config.get("enable_memory", False),
        schedule_mode=gateway.config.get("schedule_mode", "daily"),
        interval_minutes=gateway.config.get("interval_minutes", 60),
        trigger_time=gateway.config.get("trigger_time", "09:30"),
    )
    result = gateway.pipeline.reload_runtime_assets(runtime_config=runtime_config)
    runtime_updates = gateway._apply_runtime_config(runtime_config)
    await gateway.state_sync.on_system_message("Runtime assets reloaded.")
    await gateway.broadcast({"type": "runtime_assets_reloaded", **result, **runtime_updates})
|
||||
|
||||
|
||||
async def _send_error(websocket: Any, message: str) -> None:
    """Send a standard ``error`` frame to one websocket client."""
    await websocket.send(json.dumps({"type": "error", "message": message}, ensure_ascii=False))


def _parse_int(value: Any) -> int:
    """Coerce *value* to int; unparseable input maps to 0 (out-of-range)."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return 0


def _parse_float(value: Any, invalid: float) -> float:
    """Coerce *value* to float; unparseable input maps to the *invalid* sentinel."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return invalid


async def handle_update_runtime_config(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Validate and persist runtime scheduling settings, then hot-reload.

    Recognized keys: ``schedule_mode``, ``interval_minutes``,
    ``trigger_time``, ``max_comm_cycles``, ``initial_cash``,
    ``margin_requirement``, ``enable_memory``.  The first invalid value
    aborts with an ``error`` frame and nothing is saved; an empty update
    set is also an error.  On success the values are written to the run's
    bootstrap config and the runtime assets are reloaded.
    """
    updates: dict[str, Any] = {}

    schedule_mode = str(data.get("schedule_mode", "")).strip().lower()
    if schedule_mode:
        if schedule_mode not in {"daily", "intraday"}:
            await _send_error(websocket, "schedule_mode must be 'daily' or 'intraday'.")
            return
        updates["schedule_mode"] = schedule_mode

    interval_minutes = data.get("interval_minutes")
    if interval_minutes is not None:
        parsed_interval = _parse_int(interval_minutes)
        if parsed_interval <= 0:
            await _send_error(websocket, "interval_minutes must be a positive integer.")
            return
        updates["interval_minutes"] = parsed_interval

    trigger_time = data.get("trigger_time")
    if trigger_time is not None:
        raw_trigger = str(trigger_time).strip()
        # "now" and the empty string are accepted; anything else must be HH:MM.
        if raw_trigger and raw_trigger != "now":
            try:
                datetime.strptime(raw_trigger, "%H:%M")
            except ValueError:
                await _send_error(websocket, "trigger_time must use HH:MM or 'now'.")
                return
        updates["trigger_time"] = raw_trigger or "09:30"

    max_comm_cycles = data.get("max_comm_cycles")
    if max_comm_cycles is not None:
        parsed_cycles = _parse_int(max_comm_cycles)
        if parsed_cycles <= 0:
            await _send_error(websocket, "max_comm_cycles must be a positive integer.")
            return
        updates["max_comm_cycles"] = parsed_cycles

    initial_cash = data.get("initial_cash")
    if initial_cash is not None:
        parsed_initial_cash = _parse_float(initial_cash, invalid=0.0)
        if parsed_initial_cash <= 0:
            await _send_error(websocket, "initial_cash must be a positive number.")
            return
        updates["initial_cash"] = parsed_initial_cash

    margin_requirement = data.get("margin_requirement")
    if margin_requirement is not None:
        # Zero is a valid margin requirement, so the sentinel is negative.
        parsed_margin_requirement = _parse_float(margin_requirement, invalid=-1.0)
        if parsed_margin_requirement < 0:
            await _send_error(websocket, "margin_requirement must be a non-negative number.")
            return
        updates["margin_requirement"] = parsed_margin_requirement

    enable_memory = data.get("enable_memory")
    if enable_memory is not None:
        updates["enable_memory"] = bool(enable_memory)

    if not updates:
        await _send_error(websocket, "No runtime settings were provided.")
        return

    config_name = gateway.config.get("config_name", "default")
    update_bootstrap_values_for_run(
        project_root=gateway._project_root,
        config_name=config_name,
        updates=updates,
    )
    await gateway.state_sync.on_system_message("运行时调度配置已保存,正在热更新")
    await handle_reload_runtime_assets(gateway)
|
||||
|
||||
|
||||
async def handle_update_watchlist(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Replace the run's watchlist, hot-reload, and refresh market data.

    Rejects payloads that normalize to an empty ticker list.  On success
    the tickers are persisted to the bootstrap config, clients are
    notified, runtime assets reload, and a background market-store
    refresh is scheduled for the new tickers.
    """
    tickers = gateway._normalize_watchlist(data.get("tickers"))
    if not tickers:
        await websocket.send(json.dumps({"type": "error", "message": "update_watchlist requires at least one valid ticker."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    update_bootstrap_values_for_run(
        project_root=gateway._project_root,
        config_name=config_name,
        updates={"tickers": tickers},
    )
    await gateway.state_sync.on_system_message(f"Watchlist updated: {', '.join(tickers)}")
    await gateway.broadcast({"type": "watchlist_updated", "config_name": config_name, "tickers": tickers})
    await handle_reload_runtime_assets(gateway)
    # Non-blocking: data ingestion for the new tickers runs in the background.
    gateway._schedule_watchlist_market_store_refresh(tickers)
|
||||
|
||||
|
||||
async def handle_get_agent_skills(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Report the full skill catalog for one agent, each entry tagged with a status.

    Status precedence per skill: explicit disable > explicit enable >
    active (picked up by runtime resolution) > available (known, unused).
    """
    agent_id = str(data.get("agent_id", "")).strip()
    if not agent_id:
        await websocket.send(json.dumps({"type": "error", "message": "get_agent_skills requires agent_id."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
    agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
    resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=config_name, agent_id=agent_id, default_skills=[]))
    enabled = set(agent_config.enabled_skills)
    disabled = set(agent_config.disabled_skills)

    def status_of(name: str) -> str:
        # Override lists win over runtime resolution.
        if name in disabled:
            return "disabled"
        if name in enabled:
            return "enabled"
        return "active" if name in resolved_skills else "available"

    payload = [
        {
            "skill_name": item.skill_name,
            "name": item.name,
            "description": item.description,
            "version": item.version,
            "source": item.source,
            "tools": item.tools,
            "status": status_of(item.skill_name),
        }
        for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
    ]

    await websocket.send(json.dumps({
        "type": "agent_skills_loaded",
        "config_name": config_name,
        "agent_id": agent_id,
        "skills": payload,
    }, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_get_agent_profile(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Assemble and send one agent's effective profile.

    Merges three layers of configuration: the static agent profile, the
    per-agent workspace config (agent.yaml), and the run's bootstrap
    override.  Tool-group precedence: bootstrap override > agent.yaml >
    static profile, with agent.yaml's disabled groups filtered out last.
    """
    agent_id = str(data.get("agent_id", "")).strip()
    if not agent_id:
        await websocket.send(json.dumps({"type": "error", "message": "get_agent_profile requires agent_id."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
    agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
    profiles = load_agent_profiles()
    profile = profiles.get(agent_id, {})
    bootstrap = get_bootstrap_config_for_run(gateway._project_root, config_name)
    override = bootstrap.agent_override(agent_id)
    # Bootstrap override wins; otherwise agent.yaml, then the static profile.
    active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
    if not isinstance(active_tool_groups, list):
        active_tool_groups = []
    disabled_tool_groups = agent_config.disabled_tool_groups
    if disabled_tool_groups:
        # Disabled groups are removed regardless of which layer activated them.
        disabled_set = set(disabled_tool_groups)
        active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]

    default_skills = profile.get("skills", [])
    if not isinstance(default_skills, list):
        default_skills = []
    resolved_skills = skills_manager.resolve_agent_skill_names(
        config_name=config_name,
        agent_id=agent_id,
        default_skills=default_skills,
    )
    # Fallback prompt-file set used when agent.yaml does not declare one.
    prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
    model_name, model_provider = get_agent_model_info(agent_id)

    await websocket.send(json.dumps({
        "type": "agent_profile_loaded",
        "config_name": config_name,
        "agent_id": agent_id,
        "profile": {
            "model_name": model_name,
            "model_provider": model_provider,
            "prompt_files": prompt_files,
            "default_skills": default_skills,
            "resolved_skills": resolved_skills,
            "active_tool_groups": active_tool_groups,
            "disabled_tool_groups": disabled_tool_groups,
            "enabled_skills": agent_config.enabled_skills,
            "disabled_skills": agent_config.disabled_skills,
        },
    }, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_get_skill_detail(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Load one skill document, scoped to an agent when agent_id is given."""
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    if not skill_name:
        await websocket.send(json.dumps({"type": "error", "message": "get_skill_detail requires skill_name."}, ensure_ascii=False))
        return

    skills_manager = SkillsManager(project_root=gateway._project_root)
    try:
        if agent_id:
            # Agent scope: the document may be a local override for that agent.
            detail = skills_manager.load_agent_skill_document(
                config_name=gateway.config.get("config_name", "default"),
                agent_id=agent_id,
                skill_name=skill_name,
            )
        else:
            detail = skills_manager.load_skill_document(skill_name)
    except FileNotFoundError:
        await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
        return

    response = {
        "type": "skill_detail_loaded",
        "agent_id": agent_id,
        "skill": detail,
    }
    await websocket.send(json.dumps(response, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_create_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Create a new local (agent-private) skill, then hot-reload and echo state.

    Requires agent_id and skill_name; creation failures (bad name, already
    exists) are reported as error frames.  On success the handler confirms
    creation and pushes the refreshed skill list plus the new skill's detail.
    """
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    if not agent_id or not skill_name:
        await websocket.send(json.dumps({"type": "error", "message": "create_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    try:
        skills_manager.create_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
    except (ValueError, FileExistsError) as exc:
        await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
        return

    await gateway.state_sync.on_system_message(f"Created local skill {skill_name} for {agent_id}")
    # NOTE(review): this goes through the gateway method while sibling
    # handlers call handle_reload_runtime_assets(gateway) directly —
    # confirm both paths are equivalent.
    await gateway._handle_reload_runtime_assets()
    await websocket.send(json.dumps({"type": "agent_local_skill_created", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
    await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
    await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
|
||||
|
||||
|
||||
async def handle_update_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Overwrite the document of an existing local skill, then hot-reload.

    Requires agent_id, skill_name, and string content; update failures
    (bad name, missing skill) are reported as error frames.
    """
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    content = data.get("content")
    if not agent_id or not skill_name or not isinstance(content, str):
        await websocket.send(json.dumps({"type": "error", "message": "update_agent_local_skill requires agent_id, skill_name, and string content."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    try:
        skills_manager.update_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name, content=content)
    except (ValueError, FileNotFoundError) as exc:
        await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
        return

    await gateway.state_sync.on_system_message(f"Updated local skill {skill_name} for {agent_id}")
    await gateway._handle_reload_runtime_assets()
    await websocket.send(json.dumps({"type": "agent_local_skill_updated", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
    # Echo the stored document back so the frontend editor reflects disk state.
    await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
|
||||
|
||||
|
||||
async def handle_delete_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Delete a local skill and forget its overrides, then hot-reload.

    Requires agent_id and skill_name; deletion failures (bad name,
    missing skill) are reported as error frames.
    """
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    if not agent_id or not skill_name:
        await websocket.send(json.dumps({"type": "error", "message": "delete_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    try:
        skills_manager.delete_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
        # Drop any enable/disable override so the name does not linger.
        skills_manager.forget_agent_skill_overrides(config_name=config_name, agent_id=agent_id, skill_names=[skill_name])
    except (ValueError, FileNotFoundError) as exc:
        await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
        return

    await gateway.state_sync.on_system_message(f"Deleted local skill {skill_name} for {agent_id}")
    await gateway._handle_reload_runtime_assets()
    await websocket.send(json.dumps({"type": "agent_local_skill_deleted", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
    await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||
|
||||
|
||||
async def handle_remove_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Disable a shared (non-local) skill for one agent via an override.

    Local skills are excluded on purpose — they are removed with the
    delete handler, not disabled with an override.
    """
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    if not (agent_id and skill_name):
        await websocket.send(json.dumps({"type": "error", "message": "remove_agent_skill requires agent_id and skill_name."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    shared_names = {
        entry.skill_name
        for entry in skills_manager.list_agent_skill_catalog(config_name, agent_id)
        if entry.source != "local"
    }
    if skill_name not in shared_names:
        await websocket.send(json.dumps({"type": "error", "message": f"Unknown shared skill: {skill_name}"}, ensure_ascii=False))
        return

    skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
    await gateway.state_sync.on_system_message(f"Removed shared skill {skill_name} from {agent_id}")
    await gateway._handle_reload_runtime_assets()
    await websocket.send(json.dumps({"type": "agent_skill_removed", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
    await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||
|
||||
|
||||
async def handle_update_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Enable or disable one catalog skill for an agent via overrides."""
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    enabled = data.get("enabled")
    if not agent_id or not skill_name or not isinstance(enabled, bool):
        await websocket.send(json.dumps({"type": "error", "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    known = {entry.skill_name for entry in skills_manager.list_agent_skill_catalog(config_name, agent_id)}
    if skill_name not in known:
        await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
        return

    # Route the override to the matching list and word the notice accordingly.
    if enabled:
        skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, enable=[skill_name])
        notice = f"Enabled skill {skill_name} for {agent_id}"
    else:
        skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
        notice = f"Disabled skill {skill_name} for {agent_id}"
    await gateway.state_sync.on_system_message(notice)

    await websocket.send(json.dumps({
        "type": "agent_skill_updated",
        "agent_id": agent_id,
        "skill_name": skill_name,
        "enabled": enabled,
    }, ensure_ascii=False))
    await gateway._handle_reload_runtime_assets()
    await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||
|
||||
|
||||
async def handle_get_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Return the content of one editable agent workspace file.

    A missing file is reported as empty content (not an error) so the
    frontend can open a blank editor for not-yet-created files.
    """
    agent_id = str(data.get("agent_id", "")).strip()
    # Allowlist-normalized; returns None for unsupported filenames.
    filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
    if not agent_id or not filename:
        await websocket.send(json.dumps({"type": "error", "message": "get_agent_workspace_file requires agent_id and supported filename."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
    # Ensure the asset directory exists even on this read path.
    asset_dir.mkdir(parents=True, exist_ok=True)
    path = asset_dir / filename
    content = path.read_text(encoding="utf-8") if path.exists() else ""
    await websocket.send(json.dumps({
        "type": "agent_workspace_file_loaded",
        "config_name": config_name,
        "agent_id": agent_id,
        "filename": filename,
        "content": content,
    }, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_update_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Overwrite one editable workspace file for an agent, then hot-reload.

    Requires agent_id, a filename accepted by the gateway's editable-file
    allowlist normalizer, and string content.  On success it confirms the
    write, reloads runtime assets, and echoes the stored file back.
    """
    agent_id = str(data.get("agent_id", "")).strip()
    # Allowlist-normalized; returns None for unsupported filenames.
    filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
    content = data.get("content")
    if not agent_id or not filename or not isinstance(content, str):
        await websocket.send(json.dumps({"type": "error", "message": "update_agent_workspace_file requires agent_id, supported filename, and string content."}, ensure_ascii=False))
        return

    config_name = gateway.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=gateway._project_root)
    asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
    asset_dir.mkdir(parents=True, exist_ok=True)
    path = asset_dir / filename
    path.write_text(content, encoding="utf-8")
    # Fix: the notice previously read "Updated (unknown) for ..." — a leftover
    # placeholder; report the filename that was actually written, matching
    # the message style of the sibling skill handlers.
    await gateway.state_sync.on_system_message(f"Updated {filename} for {agent_id}")
    await websocket.send(json.dumps({"type": "agent_workspace_file_updated", "agent_id": agent_id, "filename": filename}, ensure_ascii=False))
    await gateway._handle_reload_runtime_assets()
    await handle_get_agent_workspace_file(gateway, websocket, {"agent_id": agent_id, "filename": filename})
|
||||
372
backend/services/gateway_cycle_support.py
Normal file
372
backend/services/gateway_cycle_support.py
Normal file
@@ -0,0 +1,372 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Cycle and monitoring helpers extracted from the main Gateway module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_ingest import ingest_symbols, refresh_news_for_symbols
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
    """Kick off a non-blocking market-store refresh for an updated watchlist.

    An empty ticker list is a no-op; an in-flight refresh task is
    cancelled before the new one is scheduled.
    """
    if not tickers:
        return
    pending = gateway._watchlist_ingest_task
    if pending and not pending.done():
        pending.cancel()
    refresh_coro = refresh_market_store_for_watchlist(gateway, tickers)
    gateway._watchlist_ingest_task = asyncio.create_task(refresh_coro)
|
||||
|
||||
|
||||
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
    """Refresh the long-lived market store after a watchlist update.

    Runs the blocking ingest in a worker thread, reporting progress and
    a per-symbol summary through system messages.  Cancellation is
    re-raised; any other failure is logged and reported, never raised.
    """
    try:
        await gateway.state_sync.on_system_message(
            f"正在同步自选股市场数据: {', '.join(tickers)}",
        )
        # ingest_symbols is blocking I/O; keep the event loop responsive.
        results = await asyncio.to_thread(
            ingest_symbols,
            tickers,
            mode="incremental",
        )
        summary = ", ".join(
            f"{item['symbol']} prices={item['prices']} news={item['news']}"
            for item in results
        )
        await gateway.state_sync.on_system_message(
            f"自选股市场数据已同步: {summary}",
        )
    except asyncio.CancelledError:
        # Propagate cancellation (a newer refresh superseded this one).
        raise
    except Exception as exc:
        logger.warning("Watchlist market store refresh failed: %s", exc)
        await gateway.state_sync.on_system_message(
            f"自选股市场数据同步失败: {exc}",
        )
|
||||
|
||||
|
||||
async def market_status_monitor(gateway: Any) -> None:
    """Periodically check and broadcast market status changes.

    Runs until cancelled: every 60s it refreshes the market status,
    starts/ends the live session around market open/close, and, while a
    session is active, pushes live return updates to clients.  Transient
    errors are logged and the loop keeps going.
    """
    while True:
        try:
            await gateway.market_service.check_and_broadcast_market_status()

            status = gateway.market_service.get_market_status()
            if status["status"] == "open" and not gateway.storage.is_live_session_active:
                # Market just opened: start a session and record the portfolio
                # value used as the session-return baseline.
                gateway.storage.start_live_session()
                summary = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state).get("summary") or {}
                gateway._session_start_portfolio_value = summary.get(
                    "totalAssetValue",
                    gateway.storage.initial_cash,
                )
                logger.info(
                    "Session start portfolio: $%s",
                    f"{gateway._session_start_portfolio_value:,.2f}",
                )
            elif status["status"] != "open" and gateway.storage.is_live_session_active:
                # Market just closed: end the session and clear the baseline.
                gateway.storage.end_live_session()
                gateway._session_start_portfolio_value = None

            if gateway.storage.is_live_session_active:
                await update_and_broadcast_live_returns(gateway)

            await asyncio.sleep(60)
        except asyncio.CancelledError:
            break
        except Exception as exc:
            # Keep the monitor alive on transient failures; retry next minute.
            logger.error("Market status monitor error: %s", exc)
            await asyncio.sleep(60)
|
||||
|
||||
|
||||
async def update_and_broadcast_live_returns(gateway: Any) -> None:
    """Calculate and broadcast live returns for the current session.

    No-op when no session is active or no positive price is available.
    The latest point of each history series feeds storage's live-return
    update; a broadcast only happens when storage reports a new point.
    """
    if not gateway.storage.is_live_session_active:
        return

    prices = gateway.market_service.get_all_prices()
    # Skip entirely when there is no usable (positive) price yet.
    if not prices or not any(p > 0 for p in prices.values()):
        return

    state = gateway.storage.load_internal_state()
    equity_history = state.get("equity_history", [])
    baseline_history = state.get("baseline_history", [])
    baseline_vw_history = state.get("baseline_vw_history", [])
    momentum_history = state.get("momentum_history", [])

    # Latest value of each series, or None when the series is empty.
    current_equity = equity_history[-1]["v"] if equity_history else None
    current_baseline = baseline_history[-1]["v"] if baseline_history else None
    current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None
    current_momentum = momentum_history[-1]["v"] if momentum_history else None

    point = gateway.storage.update_live_returns(
        current_equity=current_equity,
        current_baseline=current_baseline,
        current_baseline_vw=current_baseline_vw,
        current_momentum=current_momentum,
    )
    if point:
        live_returns = gateway.storage.get_live_returns()
        await gateway.broadcast(
            {
                "type": "team_summary",
                "equity_return": live_returns["equity_return"],
                "baseline_return": live_returns["baseline_return"],
                "baseline_vw_return": live_returns["baseline_vw_return"],
                "momentum_return": live_returns["momentum_return"],
            },
        )
|
||||
|
||||
|
||||
async def on_strategy_trigger(gateway: Any, date: str) -> None:
    """Handle a trading-cycle trigger.

    Skips (with a user-visible notice) when a cycle is already running,
    then dispatches to the backtest or live cycle runner under the lock.
    """
    if gateway._cycle_lock.locked():
        logger.warning("Trading cycle already running, skipping trigger for %s", date)
        await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}")
        return

    async with gateway._cycle_lock:
        logger.info("Strategy triggered for %s", date)
        tickers = gateway.config.get("tickers", [])
        if gateway.is_backtest:
            await run_backtest_cycle(gateway, date, tickers)
        else:
            await run_live_cycle(gateway, date, tickers)
|
||||
|
||||
|
||||
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
    """Run lightweight heartbeat check for all analysts.

    NOTE(review): the try-body currently only emits a debug log — this looks
    like a placeholder for a real per-analyst heartbeat; confirm intent.
    """
    logger.info("[Heartbeat] Running heartbeat check for %s", date)
    analysts = gateway.pipeline._all_analysts()

    for analyst in analysts:
        try:
            logger.debug(
                "[Heartbeat] No heartbeat configured for %s, skipping",
                analyst.name,
            )
        except Exception as exc:
            # One failing analyst must not abort the heartbeat sweep.
            logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True)
|
||||
|
||||
|
||||
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
    """Run one simulated trading day: open, pipeline cycle, close, persist, broadcast."""
    # Pin the market service to the simulated date before emitting the open event.
    gateway.market_service.set_backtest_date(date)
    await gateway.market_service.emit_market_open()

    await gateway.state_sync.on_cycle_start(date)

    prices = gateway.market_service.get_open_prices()
    close_prices = gateway.market_service.get_close_prices()
    market_caps = await get_market_caps(gateway, tickers, date)

    # Both open and close prices are supplied up front because the whole
    # simulated day is processed in a single pass.
    result = await gateway.pipeline.run_cycle(
        tickers=tickers,
        date=date,
        prices=prices,
        close_prices=close_prices,
        market_caps=market_caps,
    )

    await gateway.market_service.emit_market_close()
    settlement_result = result.get("settlement_result")
    save_cycle_results(gateway, result, date, close_prices, settlement_result)
    await broadcast_portfolio_updates(gateway, result, close_prices)
    await finalize_cycle(gateway, date)
|
||||
|
||||
|
||||
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
    """Run one live trading cycle for the current live trading date.

    Refreshes news best-effort, then runs the pipeline either in intraday
    mode (trade only while the market is open) or daily mode (wait for real
    open/close prices), and finally persists and broadcasts the results.
    """
    # The scheduler's trigger date may differ from the exchange's trading date.
    trading_date = gateway.market_service.get_live_trading_date()
    logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)

    try:
        # News refresh is blocking I/O, so push it onto a worker thread.
        news_refresh = await asyncio.to_thread(
            refresh_news_for_symbols,
            tickers,
            end_date=trading_date,
            store=gateway.storage.market_store,
        )
        logger.info(
            "News refresh complete: %s",
            ", ".join(
                f"{item['symbol']} news={item['news']}"
                for item in news_refresh
            ) or "no symbols",
        )
    except Exception as exc:
        # Best-effort: a failed news refresh must not abort the trading cycle.
        logger.warning("Live cycle news refresh failed: %s", exc)

    await gateway.state_sync.on_cycle_start(trading_date)

    market_caps = await get_market_caps(gateway, tickers, trading_date)
    schedule_mode = gateway.config.get("schedule_mode", "daily")
    market_status = gateway.market_service.get_market_status()
    current_prices = gateway.market_service.get_all_prices()

    if schedule_mode == "intraday":
        # Intraday mode always analyses, but only executes while the market is open.
        execute_decisions = market_status.get("status") == "open"
        if execute_decisions:
            await gateway.state_sync.on_system_message("定时任务触发:当前处于交易时段,本轮将执行交易决策")
        else:
            await gateway.state_sync.on_system_message("定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易")

        result = await gateway.pipeline.run_cycle(
            tickers=tickers,
            date=trading_date,
            prices=current_prices,
            market_caps=market_caps,
            execute_decisions=execute_decisions,
        )
        close_prices = current_prices
    else:
        # Daily mode: the pipeline pulls open/close prices itself, blocking
        # on the market service until they become available.
        result = await gateway.pipeline.run_cycle(
            tickers=tickers,
            date=trading_date,
            market_caps=market_caps,
            get_open_prices_fn=gateway.market_service.wait_for_open_prices,
            get_close_prices_fn=gateway.market_service.wait_for_close_prices,
        )
        close_prices = gateway.market_service.get_all_prices()

    settlement_result = result.get("settlement_result")
    save_cycle_results(gateway, result, trading_date, close_prices, settlement_result)
    await broadcast_portfolio_updates(gateway, result, close_prices)
    await finalize_cycle(gateway, trading_date)
|
||||
|
||||
|
||||
async def finalize_cycle(gateway: Any, date: str) -> None:
    """Emit end-of-cycle summary and leaderboard updates to the frontend."""
    snapshot = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state)

    # Live sessions overlay current return figures on top of the snapshot summary.
    summary_payload = snapshot.get("summary") or {}
    if gateway.storage.is_live_session_active:
        summary_payload.update(gateway.storage.get_live_returns())

    await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary_payload)

    board = snapshot.get("leaderboard") or []
    if board:
        await gateway.state_sync.on_leaderboard_update(board)
|
||||
|
||||
|
||||
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
    """Fetch a market cap per ticker, falling back to 1e9 when unavailable.

    Resolution order per ticker: remote trading service first, then the local
    trading-domain payload; any exception or falsy value degrades to the 1e9
    default rather than aborting the caller's cycle.
    """
    market_caps: dict[str, float] = {}
    for ticker in tickers:
        try:
            market_cap = None
            response = await gateway._call_trading_service(
                f"get_market_cap for {ticker}",
                # Bind the loop variable as a default so the lambda keeps its own ticker.
                lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
            )
            if response is not None:
                market_cap = response.get("market_cap")
            if market_cap is None:
                # NOTE(review): `trading_domain` is not among the imports visible
                # in this chunk — confirm it is imported by this module.
                payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
                market_cap = payload.get("market_cap")
            # Falsy values (None or 0) also fall back to the default.
            market_caps[ticker] = market_cap if market_cap else 1e9
        except Exception as exc:
            logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
            market_caps[ticker] = 1e9
    return market_caps
|
||||
|
||||
|
||||
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
    """Push holdings, stats, and executed-trade updates from a cycle result."""
    snapshot = result.get("portfolio", {})
    if snapshot:
        holdings_payload = FrontendAdapter.build_holdings(snapshot, prices)
        if holdings_payload:
            await gateway.state_sync.on_holdings_update(holdings_payload)

        stats_payload = FrontendAdapter.build_stats(snapshot, prices)
        if stats_payload:
            await gateway.state_sync.on_stats_update(stats_payload)

    trades = result.get("executed_trades", [])
    if trades:
        await gateway.state_sync.on_trades_executed(trades)
|
||||
|
||||
|
||||
def save_cycle_results(
|
||||
gateway: Any,
|
||||
result: dict[str, Any],
|
||||
date: str,
|
||||
prices: dict[str, float],
|
||||
settlement_result: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
portfolio = result.get("portfolio", {})
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
baseline_values = settlement_result.get("baseline_values") if settlement_result else None
|
||||
if portfolio:
|
||||
gateway.storage.update_dashboard_after_cycle(
|
||||
portfolio=portfolio,
|
||||
prices=prices,
|
||||
date=date,
|
||||
executed_trades=executed_trades,
|
||||
baseline_values=baseline_values,
|
||||
)
|
||||
|
||||
|
||||
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
    """Drive the strategy trigger across every backtest date in order."""
    gateway.state_sync.set_backtest_dates(dates)
    await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
    try:
        for date in dates:
            await gateway.on_strategy_trigger(date=date)
            # Yield briefly between days so other websocket traffic can flow.
            await asyncio.sleep(0.1)
        await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
    except Exception as exc:
        error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
        logger.error(error_msg, exc_info=True)
        # Fire-and-forget so delivering the notification cannot mask the re-raise.
        asyncio.create_task(gateway.state_sync.on_system_message(error_msg))
        raise
    finally:
        # Always clear the task handle so a new backtest can be started.
        gateway._backtest_task = None
|
||||
|
||||
|
||||
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
    """Done-callback: surface a finished backtest task's outcome in the logs.

    Never re-raises; cancellation is logged at info, any other failure at error.
    """
    try:
        task.result()
    except asyncio.CancelledError:
        logger.info("Backtest task was cancelled")
    except Exception as err:
        logger.error(
            "Backtest task failed with exception:%s:%s",
            type(err).__name__,
            err,
            exc_info=True,
        )
|
||||
|
||||
|
||||
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
    """Done-callback for a manual cycle: clear the task handle, then log the outcome."""
    # Clear first so a new manual cycle can start even if logging below fails.
    gateway._manual_cycle_task = None
    try:
        task.result()
    except asyncio.CancelledError:
        logger.info("Manual cycle task was cancelled")
    except Exception as err:
        logger.error(
            "Manual cycle task failed with exception:%s:%s",
            type(err).__name__,
            err,
            exc_info=True,
        )
|
||||
|
||||
|
||||
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
    """Record the backtest date range on the gateway and forward it to state sync."""
    gateway.state_sync.set_backtest_dates(dates)
    if not dates:
        return
    gateway._backtest_start_date, gateway._backtest_end_date = dates[0], dates[-1]
|
||||
|
||||
|
||||
def stop_gateway(gateway: Any) -> None:
    """Shut down gateway-owned services and cancel background tasks.

    Best-effort: state is saved and the market service stopped first, then
    background tasks are cancelled, and finally the OpenClaw WebSocket is
    disconnected. Shutdown must never raise, so socket-close failures are
    swallowed.
    """
    gateway.state_sync.save_state()
    gateway.market_service.stop()

    # Cancel whichever background tasks are currently alive.
    for task in (
        gateway._backtest_task,
        gateway._market_status_task,
        gateway._watchlist_ingest_task,
    ):
        if task:
            task.cancel()

    # Close the OpenClaw WebSocket connection. Shutdown may be invoked from
    # either a sync or an async context, so detect a running loop explicitly
    # instead of the deprecated get_event_loop() dance.
    if gateway._openclaw_ws:
        try:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None
            if loop is not None:
                # Inside a running loop: schedule the disconnect and move on.
                loop.create_task(gateway._openclaw_ws.disconnect())
            else:
                # No loop running: drive the coroutine to completion ourselves.
                asyncio.run(gateway._openclaw_ws.disconnect())
        except Exception:
            # Best-effort cleanup; never fail shutdown because of the socket.
            pass
|
||||
534
backend/services/gateway_openclaw_handlers.py
Normal file
534
backend/services/gateway_openclaw_handlers.py
Normal file
@@ -0,0 +1,534 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""OpenClaw WebSocket handlers — gateway calls OpenClaw Gateway via WebSocket."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from backend.services.gateway import Gateway
    from shared.client.openclaw_websocket_client import OpenClawWebSocketClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ensure_session_bridge(gateway) -> None:
    """Forward OpenClaw session events into 大时代 frontend websockets.

    Registers a single event handler on the OpenClaw WebSocket client that
    fans each session event out to the frontend sockets subscribed to that
    session key. Idempotent per gateway instance.
    """
    # Register at most once per gateway.
    if getattr(gateway, "_openclaw_session_bridge_ready", False):
        return

    async def _forward(event) -> None:
        # Only events carrying a session key can be routed to subscribers.
        payload = event.payload or {}
        session_key = str(payload.get("sessionKey") or payload.get("key") or "").strip()
        if not session_key:
            return

        # Snapshot the subscriber map so concurrent mutation cannot break iteration.
        subscriber_map = getattr(gateway, "_openclaw_session_subscribers", {})
        targets = [
            ws
            for ws, session_keys in list(subscriber_map.items())
            if session_key in session_keys
        ]
        if not targets:
            return

        message = json.dumps(
            {
                "type": "openclaw_session_event",
                "event": event.event,
                "session_key": session_key,
                "payload": payload,
            }
        )
        stale = []
        for ws in targets:
            try:
                await ws.send(message)
            except Exception:
                # A failed send means the frontend socket is gone.
                stale.append(ws)

        # Drop dead sockets so we stop trying to deliver to them.
        for ws in stale:
            try:
                subscriber_map.pop(ws, None)
            except Exception:
                pass

    def _handler(event) -> None:
        # The client invokes handlers synchronously; schedule the async fan-out.
        try:
            import asyncio

            asyncio.create_task(_forward(event))
        except Exception as exc:
            logger.debug("OpenClaw session bridge skipped event: %s", exc)

    client = _get_ws_client(gateway)
    client.add_event_handler(_handler)
    # Mark ready and keep a reference to the handler for later removal/debugging.
    gateway._openclaw_session_bridge_ready = True
    gateway._openclaw_session_bridge_handler = _handler
    if not hasattr(gateway, "_openclaw_session_subscribers"):
        gateway._openclaw_session_subscribers = {}
|
||||
|
||||
|
||||
def _get_ws_client(gateway) -> "OpenClawWebSocketClient":
|
||||
"""Get the OpenClaw WebSocket client from gateway."""
|
||||
from shared.client.openclaw_websocket_client import OpenClawWebSocketClient
|
||||
client = gateway._openclaw_ws
|
||||
if client is None:
|
||||
raise RuntimeError("OpenClaw Gateway not connected")
|
||||
return client
|
||||
|
||||
|
||||
async def _ws_call(gateway, method: str, params: dict | None = None) -> dict:
    """Invoke *method* on the OpenClaw Gateway, mapping any failure to an error dict."""
    try:
        response = await _get_ws_client(gateway).call_method(method, params)
    except Exception as exc:
        # Degrade to an error payload so handlers can always reply to the frontend.
        logger.warning("OpenClaw Gateway call failed for %s: %s", method, exc)
        return {"error": str(exc)[:200]}
    return response
|
||||
|
||||
|
||||
async def handle_get_openclaw_status(gateway, websocket, data: dict) -> None:
    """Fetch OpenClaw gateway status and reply on the websocket."""
    reply = {"type": "openclaw_status_loaded", "data": await _ws_call(gateway, "status")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_sessions(gateway, websocket, data: dict) -> None:
    """List recent OpenClaw sessions (capped at 50) and reply."""
    sessions = await _ws_call(gateway, "sessions.list", {"limit": 50, "includeLastMessage": True})
    await websocket.send(json.dumps({"type": "openclaw_sessions_loaded", "data": sessions}))
|
||||
|
||||
|
||||
async def handle_get_openclaw_session_detail(gateway, websocket, data: dict) -> None:
    """Look up one session by key within the session listing and reply with it."""
    session_key = data.get("session_key", "")
    listing = await _ws_call(gateway, "sessions.list", {"limit": 200, "includeLastMessage": True})

    match = None
    if isinstance(listing, dict):
        # Sessions may expose their key under either "key" or "sessionKey".
        match = next(
            (
                item
                for item in (listing.get("sessions", []) or [])
                if isinstance(item, dict)
                and session_key in (item.get("key"), item.get("sessionKey"))
            ),
            None,
        )

    await websocket.send(json.dumps({
        "type": "openclaw_session_detail_loaded",
        "data": {"session": match, "error": None if match else f"session '{session_key}' not found"},
        "session_key": session_key,
    }))
|
||||
|
||||
|
||||
async def handle_get_openclaw_session_history(gateway, websocket, data: dict) -> None:
    """Load a session's message history via the OpenClaw CLI service and reply."""
    session_key = data.get("session_key", "")
    limit = data.get("limit", 20)
    try:
        # Imported lazily so the CLI dependency is only paid when history is requested.
        from backend.services.openclaw_cli import OpenClawCliService

        result = OpenClawCliService().get_session_history_model(session_key, limit=limit)
        payload = {
            "session_key": result.session_key,
            "session_id": result.session_id,
            # "history" and "events" carry the same data for frontend compatibility.
            "history": result.events,
            "events": result.events,
            "raw_text": result.raw_text,
        }
    except Exception as exc:
        payload = {"error": str(exc)[:200], "history": []}
    await websocket.send(json.dumps({
        "type": "openclaw_session_history_loaded",
        "data": payload,
        "session_key": session_key,
    }))
|
||||
|
||||
|
||||
async def handle_openclaw_resolve_session(gateway, websocket, data: dict) -> None:
    """Resolve a session from agent id / label / channel hints and reply."""
    params: dict = {}
    # Only forward the hints that were actually supplied (non-empty after trim).
    for field, param_name in (("agent_id", "agentId"), ("label", "label"), ("channel", "channel")):
        value = str(data.get(field) or "").strip()
        if value:
            params[param_name] = value
    params["includeGlobal"] = bool(data.get("include_global", True))
    result = await _ws_call(gateway, "sessions.resolve", params)
    await websocket.send(json.dumps({"type": "openclaw_session_resolved", "data": result}))
|
||||
|
||||
|
||||
async def handle_openclaw_create_session(gateway, websocket, data: dict) -> None:
    """Create a new OpenClaw session, forwarding only the fields that were supplied."""
    params: dict = {}
    # Frontend field name -> OpenClaw parameter name.
    for field, param_name in (
        ("agent_id", "agentId"),
        ("label", "label"),
        ("model", "model"),
        ("initial_message", "message"),
    ):
        value = str(data.get(field) or "").strip()
        if value:
            params[param_name] = value
    result = await _ws_call(gateway, "sessions.create", params)
    await websocket.send(json.dumps({"type": "openclaw_session_created", "data": result}))
|
||||
|
||||
|
||||
async def handle_openclaw_send_message(gateway, websocket, data: dict) -> None:
    """Send a chat message into an existing OpenClaw session and reply with the result."""
    session_key = str(data.get("session_key") or "").strip()
    message = str(data.get("message") or "").strip()
    thinking = str(data.get("thinking") or "").strip()

    # Both the target session and a non-empty message are mandatory.
    if not (session_key and message):
        await websocket.send(json.dumps({
            "type": "openclaw_message_sent",
            "data": {"error": "session_key and message are required"},
        }))
        return

    params = {"key": session_key, "message": message}
    if thinking:
        params["thinking"] = thinking
    result = await _ws_call(gateway, "sessions.send", params)
    await websocket.send(json.dumps({
        "type": "openclaw_message_sent",
        "data": result,
        "session_key": session_key,
    }))
|
||||
|
||||
|
||||
async def handle_openclaw_subscribe_session(gateway, websocket, data: dict) -> None:
    """Subscribe this frontend socket to live message events for one session.

    Ensures the OpenClaw->frontend event bridge is installed, subscribes on
    the OpenClaw side, and on success records the (socket, session_key)
    mapping used by the bridge to route events.
    """
    session_key = str(data.get("session_key") or "").strip()
    if not session_key:
        await websocket.send(
            json.dumps(
                {
                    "type": "openclaw_session_subscribed",
                    "data": {"error": "session_key is required"},
                }
            )
        )
        return

    _ensure_session_bridge(gateway)
    result = await _ws_call(gateway, "sessions.messages.subscribe", {"key": session_key})
    # Record the subscription unless OpenClaw reported an error.
    if not isinstance(result, dict) or not result.get("error"):
        subscriber_map = getattr(gateway, "_openclaw_session_subscribers", {})
        subscriber_map.setdefault(websocket, set()).add(session_key)
        gateway._openclaw_session_subscribers = subscriber_map
    await websocket.send(
        json.dumps(
            {
                "type": "openclaw_session_subscribed",
                "data": result,
                "session_key": session_key,
            }
        )
    )
|
||||
|
||||
|
||||
async def handle_openclaw_unsubscribe_session(gateway, websocket, data: dict) -> None:
    """Unsubscribe this frontend socket from one session's message events.

    Unsubscribes on the OpenClaw side and removes the local routing entry;
    the socket is dropped from the map entirely once it has no sessions left.
    """
    session_key = str(data.get("session_key") or "").strip()
    if not session_key:
        await websocket.send(
            json.dumps(
                {
                    "type": "openclaw_session_unsubscribed",
                    "data": {"error": "session_key is required"},
                }
            )
        )
        return

    result = await _ws_call(gateway, "sessions.messages.unsubscribe", {"key": session_key})
    subscriber_map = getattr(gateway, "_openclaw_session_subscribers", {})
    session_keys = subscriber_map.get(websocket)
    if isinstance(session_keys, set):
        session_keys.discard(session_key)
        # Remove sockets with no remaining subscriptions.
        if not session_keys:
            subscriber_map.pop(websocket, None)
    gateway._openclaw_session_subscribers = subscriber_map
    await websocket.send(
        json.dumps(
            {
                "type": "openclaw_session_unsubscribed",
                "data": result,
                "session_key": session_key,
            }
        )
    )
|
||||
|
||||
|
||||
async def handle_openclaw_reset_session(gateway, websocket, data: dict) -> None:
    """Reset an OpenClaw session's conversation state and reply."""
    session_key = str(data.get("session_key") or "").strip()
    if not session_key:
        await websocket.send(json.dumps({
            "type": "openclaw_session_reset",
            "data": {"error": "session_key is required"},
        }))
        return

    result = await _ws_call(gateway, "sessions.reset", {"key": session_key})
    await websocket.send(json.dumps({
        "type": "openclaw_session_reset",
        "data": result,
        "session_key": session_key,
    }))
|
||||
|
||||
|
||||
async def handle_openclaw_delete_session(gateway, websocket, data: dict) -> None:
    """Delete an OpenClaw session and reply."""
    session_key = str(data.get("session_key") or "").strip()
    if not session_key:
        await websocket.send(json.dumps({
            "type": "openclaw_session_deleted",
            "data": {"error": "session_key is required"},
        }))
        return

    result = await _ws_call(gateway, "sessions.delete", {"key": session_key})
    await websocket.send(json.dumps({
        "type": "openclaw_session_deleted",
        "data": result,
        "session_key": session_key,
    }))
|
||||
|
||||
|
||||
async def handle_get_openclaw_cron(gateway, websocket, data: dict) -> None:
    """List OpenClaw cron jobs and reply."""
    reply = {"type": "openclaw_cron_loaded", "data": await _ws_call(gateway, "cron.list")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_approvals(gateway, websocket, data: dict) -> None:
    """Fetch pending execution approvals and reply."""
    reply = {"type": "openclaw_approvals_loaded", "data": await _ws_call(gateway, "exec.approvals.get")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_agents(gateway, websocket, data: dict) -> None:
    """Return the agent list, enriched with model and skills information.

    Merges three sources: ``agents.list`` (the base list), ``sessions.list``
    (per-agent and default models observed on live sessions), and
    ``config.get`` (configured per-agent and default skills). Each agent dict
    gains a best-effort ``model`` and ``skills`` field if missing.
    """
    result = await _ws_call(gateway, "agents.list")
    sessions_result = await _ws_call(
        gateway,
        "sessions.list",
        {"limit": 200, "includeLastMessage": True},
    )
    config_result = await _ws_call(gateway, "config.get")
    # agent id -> model observed on that agent's sessions.
    session_model_by_agent: dict[str, str] = {}
    default_session_model: str | None = None
    # agent id -> configured skills (None means "not configured").
    agent_skills_by_id: dict[str, list[str] | None] = {}
    default_agent_skills: list[str] | None = None

    # --- Skills from the parsed gateway config -----------------------------
    parsed_config = config_result.get("parsed") if isinstance(config_result, dict) else None
    if isinstance(parsed_config, dict):
        agents_cfg = parsed_config.get("agents")
        if isinstance(agents_cfg, dict):
            defaults_cfg = agents_cfg.get("defaults")
            if isinstance(defaults_cfg, dict):
                default_skills = defaults_cfg.get("skills")
                if isinstance(default_skills, list):
                    # Normalize to trimmed, non-empty strings.
                    default_agent_skills = [
                        str(skill).strip()
                        for skill in default_skills
                        if str(skill).strip()
                    ]
            list_cfg = agents_cfg.get("list")
            if isinstance(list_cfg, list):
                for entry in list_cfg:
                    if not isinstance(entry, dict):
                        continue
                    agent_id = str(entry.get("id") or "").strip()
                    if not agent_id:
                        continue
                    skills = entry.get("skills")
                    if isinstance(skills, list):
                        agent_skills_by_id[agent_id] = [
                            str(skill).strip()
                            for skill in skills
                            if str(skill).strip()
                        ]
                    elif skills == []:
                        # NOTE(review): unreachable — [] is a list, so the
                        # isinstance branch above already handles it.
                        agent_skills_by_id[agent_id] = []

    # --- Models observed on sessions ---------------------------------------
    if isinstance(sessions_result, dict) and isinstance(sessions_result.get("sessions"), list):
        defaults = sessions_result.get("defaults")
        if isinstance(defaults, dict):
            # Accept all naming variants the gateway may emit.
            value = (
                defaults.get("model")
                or defaults.get("modelName")
                or defaults.get("model_name")
            )
            if value:
                default_session_model = str(value)
        for session in sessions_result.get("sessions", []):
            if not isinstance(session, dict):
                continue
            agent_id = str(
                session.get("agentId")
                or session.get("agent_id")
                or ""
            ).strip()
            if not agent_id:
                # Fall back to parsing "agent:<id>:..." session keys.
                key = str(session.get("key") or session.get("sessionKey") or "").strip()
                parts = key.split(":")
                if len(parts) >= 3 and parts[0] == "agent":
                    agent_id = parts[1]
            model_value = (
                session.get("model")
                or session.get("modelName")
                or session.get("model_name")
                or session.get("resolvedModel")
                or session.get("resolved_model")
                or session.get("defaultModel")
                or session.get("default_model")
            )
            # First session wins per agent.
            if agent_id and model_value and agent_id not in session_model_by_agent:
                session_model_by_agent[agent_id] = str(model_value)

    # --- Enrich the agent list ---------------------------------------------
    if isinstance(result, dict) and isinstance(result.get("agents"), list):
        normalized_agents = []
        for agent in result.get("agents", []):
            if not isinstance(agent, dict):
                # Pass through non-dict entries untouched.
                normalized_agents.append(agent)
                continue
            normalized = dict(agent)
            if not normalized.get("model"):
                normalized["model"] = (
                    normalized.get("modelName")
                    or normalized.get("model_name")
                    or normalized.get("resolvedModel")
                    or normalized.get("resolved_model")
                    or normalized.get("defaultModel")
                    or normalized.get("default_model")
                    or session_model_by_agent.get(str(normalized.get("id") or "").strip())
                    or default_session_model
                )
            agent_id = str(normalized.get("id") or "").strip()
            if "skills" not in normalized:
                normalized["skills"] = agent_skills_by_id.get(agent_id, default_agent_skills)
            normalized_agents.append(normalized)
        result = {**result, "agents": normalized_agents}
    await websocket.send(json.dumps({"type": "openclaw_agents_loaded", "data": result}))
|
||||
|
||||
|
||||
async def handle_get_openclaw_agents_presence(gateway, websocket, data: dict) -> None:
    """Report agent presence via the node list."""
    reply = {"type": "openclaw_agents_presence_loaded", "data": await _ws_call(gateway, "node.list")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_skills(gateway, websocket, data: dict) -> None:
    """Fetch skill status, optionally scoped to one agent."""
    agent_id = str(data.get("agent_id") or "").strip()
    status = await _ws_call(gateway, "skills.status", {"agentId": agent_id} if agent_id else {})
    await websocket.send(json.dumps({"type": "openclaw_skills_loaded", "data": status}))
|
||||
|
||||
|
||||
async def handle_get_openclaw_models(gateway, websocket, data: dict) -> None:
    """List available models and reply."""
    reply = {"type": "openclaw_models_loaded", "data": await _ws_call(gateway, "models.list")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_hooks(gateway, websocket, data: dict) -> None:
    """Reply with the tool catalog (the gateway exposes hooks through it)."""
    reply = {"type": "openclaw_hooks_loaded", "data": await _ws_call(gateway, "tools.catalog")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_plugins(gateway, websocket, data: dict) -> None:
    """Reply with the full gateway config, from which the frontend reads plugins."""
    reply = {"type": "openclaw_plugins_loaded", "data": await _ws_call(gateway, "config.get")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_secrets_audit(gateway, websocket, data: dict) -> None:
    """Reply with the result of reloading secrets.

    NOTE(review): this "audit" handler invokes ``secrets.reload``, which by
    name mutates state rather than reading it — confirm the intended method
    against the OpenClaw Gateway API.
    """
    result = await _ws_call(gateway, "secrets.reload")
    await websocket.send(json.dumps({"type": "openclaw_secrets_audit_loaded", "data": result}))
|
||||
|
||||
|
||||
async def handle_get_openclaw_security_audit(gateway, websocket, data: dict) -> None:
    """Reply with the gateway identity as the security-audit payload."""
    reply = {"type": "openclaw_security_audit_loaded", "data": await _ws_call(gateway, "gateway.identity.get")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_daemon_status(gateway, websocket, data: dict) -> None:
    """Reply with the memory-doctor status as the daemon status payload."""
    reply = {"type": "openclaw_daemon_status_loaded", "data": await _ws_call(gateway, "doctor.memory.status")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_pairing(gateway, websocket, data: dict) -> None:
    """List paired devices and reply."""
    reply = {"type": "openclaw_pairing_loaded", "data": await _ws_call(gateway, "device.pair.list")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_qr(gateway, websocket, data: dict) -> None:
    """QR codes cannot be produced over this transport; always reply with an error."""
    reply = {"type": "openclaw_qr_loaded", "data": {"error": "QR code not available via WebSocket"}}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_update_status(gateway, websocket, data: dict) -> None:
    """Reply with the result of the update method.

    NOTE(review): a status query invokes ``update.run``, which by name
    *triggers* an update rather than reporting one — confirm the intended
    method against the OpenClaw Gateway API.
    """
    result = await _ws_call(gateway, "update.run")
    await websocket.send(json.dumps({"type": "openclaw_update_status_loaded", "data": result}))
|
||||
|
||||
|
||||
async def handle_get_openclaw_models_aliases(gateway, websocket, data: dict) -> None:
    """Reply with the model list; the frontend extracts aliases from it."""
    reply = {"type": "openclaw_models_aliases_loaded", "data": await _ws_call(gateway, "models.list")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_models_fallbacks(gateway, websocket, data: dict) -> None:
    """Reply with the model list; the frontend extracts fallbacks from it."""
    reply = {"type": "openclaw_models_fallbacks_loaded", "data": await _ws_call(gateway, "models.list")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_models_image_fallbacks(gateway, websocket, data: dict) -> None:
    """Reply with the model list; the frontend extracts image fallbacks from it."""
    reply = {"type": "openclaw_models_image_fallbacks_loaded", "data": await _ws_call(gateway, "models.list")}
    await websocket.send(json.dumps(reply))
|
||||
|
||||
|
||||
async def handle_get_openclaw_skill_update(gateway, websocket, data: dict) -> None:
    """Trigger an OpenClaw skill update for one slug or for all skills."""
    params: dict = {}
    slug = data.get("slug")
    if slug is not None:
        params["slug"] = slug
    if data.get("all", False):
        # The gateway expects the flag as the string "true".
        params["all"] = "true"
    result = await _ws_call(gateway, "skills.update", params)
    await websocket.send(json.dumps({"type": "openclaw_skill_update_loaded", "data": result}))
|
||||
|
||||
|
||||
async def handle_get_openclaw_workspace_files(gateway, websocket, data: dict) -> None:
    """List an agent's workspace files.

    The frontend sends the agent id in the ``workspace`` field; an empty
    value falls back to the "main" agent.
    """
    agent_id = data.get("workspace", "") or "main"
    result = await _ws_call(gateway, "agents.files.list", {"agentId": agent_id})
    if isinstance(result, dict):
        # Echo the workspace back so the frontend can correlate the response.
        result["workspace"] = agent_id
    await websocket.send(json.dumps({"type": "openclaw_workspace_files_loaded", "data": result}))
|
||||
|
||||
|
||||
async def handle_get_openclaw_workspace_file(gateway, websocket, data: dict) -> None:
    """Fetch one workspace file's content for an agent."""
    agent_id = data.get("agent_id", "main")
    file_name = data.get("file_name", "")
    if not file_name:
        reply = {"type": "openclaw_workspace_file_loaded", "data": {"error": "file_name is required"}}
        await websocket.send(json.dumps(reply))
        return
    result = await _ws_call(gateway, "agents.files.get", {"agentId": agent_id, "name": file_name})
    await websocket.send(json.dumps({"type": "openclaw_workspace_file_loaded", "data": result}))
|
||||
161
backend/services/gateway_runtime_support.py
Normal file
161
backend/services/gateway_runtime_support.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Runtime/state support helpers extracted from the main Gateway module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
|
||||
def normalize_watchlist(raw_tickers: Any) -> list[str]:
    """Parse watchlist payloads from websocket messages into unique symbols.

    Accepts None, a comma-separated string, a list, or a single scalar;
    returns normalized symbols with duplicates removed, order preserved.
    """
    if raw_tickers is None:
        return []

    if isinstance(raw_tickers, str):
        items = raw_tickers.split(",")
    else:
        items = raw_tickers if isinstance(raw_tickers, list) else [raw_tickers]

    symbols: list[str] = []
    for item in items:
        # Strip whitespace plus any stray quoting around each candidate.
        cleaned = str(item).strip().strip("\"'")
        symbol = normalize_symbol(cleaned)
        if symbol and symbol not in symbols:
            symbols.append(symbol)
    return symbols
|
||||
|
||||
|
||||
def normalize_agent_workspace_filename(
|
||||
raw_name: Any,
|
||||
*,
|
||||
allowlist: set[str],
|
||||
) -> str | None:
|
||||
"""Restrict editable workspace files to a safe allowlist."""
|
||||
filename = str(raw_name or "").strip()
|
||||
if filename in allowlist:
|
||||
return filename
|
||||
return None
|
||||
|
||||
|
||||
def apply_runtime_config(gateway: Any, runtime_config: dict[str, Any]) -> dict[str, Any]:
    """Apply runtime config to gateway-owned services and state.

    Pushes a (re)loaded runtime configuration into the market service,
    pipeline, scheduler, portfolio manager and storage, then persists the
    result via ``sync_runtime_state``.

    Args:
        gateway: Gateway instance whose services are reconfigured in place.
        runtime_config: Parsed config. ``max_comm_cycles``,
            ``margin_requirement``, ``initial_cash`` and ``enable_memory``
            are required keys; the others fall back to current values.

    Returns:
        Report dict: the requested config, the values actually applied,
        a per-field applied status, the ticker diff, and warnings.

    Raises:
        KeyError/ValueError/TypeError: on missing or non-coercible
            required keys (propagated to the caller).
    """
    warnings: list[str] = []

    # Ticker list: the market service decides which symbols become active;
    # mirror its verdict back into the gateway config.
    ticker_changes = gateway.market_service.update_tickers(
        runtime_config.get("tickers", []),
    )
    gateway.config["tickers"] = ticker_changes["active"]

    gateway.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
    gateway.config["max_comm_cycles"] = gateway.pipeline.max_comm_cycles
    gateway.config["schedule_mode"] = runtime_config.get(
        "schedule_mode",
        gateway.config.get("schedule_mode", "daily"),
    )
    gateway.config["interval_minutes"] = int(
        runtime_config.get(
            "interval_minutes",
            gateway.config.get("interval_minutes", 60),
        ),
    )
    gateway.config["trigger_time"] = runtime_config.get(
        "trigger_time",
        gateway.config.get("trigger_time", "09:30"),
    )

    # Scheduler may be absent (e.g. one-shot runs); reconfigure only when
    # one is attached.
    if gateway.scheduler:
        gateway.scheduler.reconfigure(
            mode=gateway.config["schedule_mode"],
            trigger_time=gateway.config["trigger_time"],
            interval_minutes=gateway.config["interval_minutes"],
        )

    # Margin requirement: the portfolio manager may clamp/reject the value,
    # so read the applied value back from its portfolio dict.
    pm_apply_result = gateway.pipeline.pm.apply_runtime_portfolio_config(
        margin_requirement=runtime_config["margin_requirement"],
    )
    gateway.config["margin_requirement"] = gateway.pipeline.pm.portfolio.get(
        "margin_requirement",
        runtime_config["margin_requirement"],
    )

    # Initial cash can only change while the run has no positions, margin
    # usage, or trades; both storage and PM must agree before applying.
    requested_initial_cash = float(runtime_config["initial_cash"])
    current_initial_cash = float(gateway.storage.initial_cash)
    initial_cash_applied = requested_initial_cash == current_initial_cash
    if not initial_cash_applied:
        if (
            gateway.storage.can_apply_initial_cash()
            and gateway.pipeline.pm.can_apply_initial_cash()
        ):
            initial_cash_applied = gateway.storage.apply_initial_cash(
                requested_initial_cash,
            )
            if initial_cash_applied:
                gateway.pipeline.pm.apply_runtime_portfolio_config(
                    initial_cash=requested_initial_cash,
                )
                gateway.config["initial_cash"] = gateway.storage.initial_cash
        else:
            warnings.append(
                "initial_cash changed in BOOTSTRAP.md but was not applied "
                "because the run already has positions, margin usage, or trades.",
            )

    # Memory toggle cannot be hot-applied: contexts are built at startup.
    requested_enable_memory = bool(runtime_config["enable_memory"])
    current_enable_memory = bool(gateway.config.get("enable_memory", False))
    if requested_enable_memory != current_enable_memory:
        warnings.append(
            "enable_memory changed in BOOTSTRAP.md but still requires a restart "
            "because long-term memory contexts are created at startup.",
        )

    # Persist the now-updated config/state snapshot.
    sync_runtime_state(gateway)

    return {
        "runtime_config_requested": runtime_config,
        "runtime_config_applied": {
            "tickers": list(gateway.config.get("tickers", [])),
            "schedule_mode": gateway.config.get("schedule_mode", "daily"),
            "interval_minutes": gateway.config.get("interval_minutes", 60),
            "trigger_time": gateway.config.get("trigger_time", "09:30"),
            "initial_cash": gateway.storage.initial_cash,
            "margin_requirement": gateway.config["margin_requirement"],
            "max_comm_cycles": gateway.config["max_comm_cycles"],
            "enable_memory": gateway.config.get("enable_memory", False),
        },
        "runtime_config_status": {
            "tickers": True,
            "schedule_mode": True,
            "interval_minutes": True,
            "trigger_time": True,
            "initial_cash": initial_cash_applied,
            "margin_requirement": pm_apply_result["margin_requirement"],
            "max_comm_cycles": True,
            "enable_memory": requested_enable_memory == current_enable_memory,
        },
        "ticker_changes": ticker_changes,
        "runtime_config_warnings": warnings,
    }
|
||||
|
||||
|
||||
def sync_runtime_state(gateway: Any) -> None:
    """Push the current runtime configuration into persisted state.

    Mirrors the gateway config into the state-sync store, refreshes the
    storage layer's server snapshot from it, and writes the state out.
    """
    cfg = gateway.config
    runtime_snapshot = {
        "tickers": cfg.get("tickers", []),
        "schedule_mode": cfg.get("schedule_mode", "daily"),
        "interval_minutes": cfg.get("interval_minutes", 60),
        "trigger_time": cfg.get("trigger_time", "09:30"),
        "initial_cash": gateway.storage.initial_cash,
        "margin_requirement": cfg.get("margin_requirement"),
        "max_comm_cycles": cfg.get("max_comm_cycles"),
        "enable_memory": cfg.get("enable_memory", False),
    }
    state_sync = gateway.state_sync
    state_sync.update_state("tickers", runtime_snapshot["tickers"])
    state_sync.update_state("runtime_config", runtime_snapshot)
    gateway.storage.update_server_state_from_dashboard(state_sync.state)
    state_sync.save_state()
|
||||
716
backend/services/gateway_stock_handlers.py
Normal file
716
backend/services/gateway_stock_handlers.py
Normal file
@@ -0,0 +1,716 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Stock-related Gateway handlers extracted from the main Gateway module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.domains import news as news_domain
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||
from backend.enrich.llm_enricher import llm_enrichment_enabled
|
||||
from backend.tools.data_tools import prices_to_df
|
||||
from shared.client import NewsServiceClient, TradingServiceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Serve daily OHLC history for one ticker to the websocket client.

    Resolution order: trading service -> local market-store cache ->
    provider payload (which is then cached back into the market store).
    Replies with a ``stock_history_loaded`` message capped at the last
    120 bars.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        # No usable symbol: answer with an explicit error envelope.
        await websocket.send(json.dumps({
            "type": "stock_history_loaded",
            "ticker": "",
            "prices": [],
            "source": None,
            "error": "invalid ticker",
        }, ensure_ascii=False))
        return

    # Clamp lookback to [7, 365] days; default to 90 on bad input.
    lookback_days = data.get("lookback_days", 90)
    try:
        lookback_days = max(7, min(int(lookback_days), 365))
    except (TypeError, ValueError):
        lookback_days = 90

    # Anchor the window on the pipeline's current date (backtest-aware),
    # falling back to "today" when absent or malformed.
    end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
    try:
        end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError:
        end_dt = datetime.now()
    end_date = end_dt.strftime("%Y-%m-%d")
    start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")

    prices = []
    source = "polygon"  # default label; replaced by whichever path succeeds
    response = await gateway._call_trading_service(
        "get_prices for history",
        lambda client: client.get_prices(ticker=ticker, start_date=start_date, end_date=end_date),
    )
    if response is not None:
        prices = response.prices
        source = "trading_service"

    if not prices:
        # Second tier: locally persisted OHLC cache (off the event loop).
        prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
    if not prices:
        # Last resort: query the data providers directly, then persist the
        # rows so subsequent requests hit the local cache.
        payload = await asyncio.to_thread(
            trading_domain.get_prices_payload,
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
        )
        prices = payload.get("prices") or []
        usage_snapshot = gateway._provider_router.get_usage_snapshot()
        # Label the source with the provider that last served prices.
        source = usage_snapshot.get("last_success", {}).get("prices")
        if prices:
            # NOTE(review): assumes provider rows are pydantic models
            # exposing model_dump() -- confirm against get_prices_payload.
            await asyncio.to_thread(
                gateway.storage.market_store.upsert_ohlc,
                ticker,
                [price.model_dump() for price in prices],
                source=source or "provider",
            )

    await websocket.send(json.dumps({
        "type": "stock_history_loaded",
        "ticker": ticker,
        # Rows may be dicts (market store) or model objects (service/provider).
        "prices": [price if isinstance(price, dict) else price.model_dump() for price in prices][-120:],
        "source": source,
        "start_date": start_date,
        "end_date": end_date,
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_explain_events(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Send the cached explain snapshot (events, signals, trades) for a ticker."""
    symbol = normalize_symbol(data.get("ticker", ""))
    explain_snapshot = gateway.storage.runtime_db.get_stock_explain_snapshot(symbol)
    message: dict[str, Any] = {"type": "stock_explain_events_loaded", "ticker": symbol}
    for section in ("events", "signals", "trades"):
        message[section] = explain_snapshot.get(section, [])
    await websocket.send(json.dumps(message, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_news(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Serve enriched news for one ticker over a bounded lookback window.

    Prefers the remote news service; falls back to the local market store
    (without refreshing stale caches). Replies with a ``stock_news_loaded``
    message capped at ``limit`` rows.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        await websocket.send(json.dumps({
            "type": "stock_news_loaded",
            "ticker": "",
            "news": [],
            "source": None,
            "error": "invalid ticker",
        }, ensure_ascii=False))
        return

    # Clamp lookback to [7, 180] days and limit to [1, 30] rows,
    # falling back to the defaults on malformed input.
    lookback_days = data.get("lookback_days", 30)
    limit = data.get("limit", 12)
    try:
        lookback_days = max(7, min(int(lookback_days), 180))
    except (TypeError, ValueError):
        lookback_days = 30
    try:
        limit = max(1, min(int(limit), 30))
    except (TypeError, ValueError):
        limit = 12

    # Anchor the window on the pipeline's current date (backtest-aware).
    end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
    try:
        end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError:
        end_dt = datetime.now()
    end_date = end_dt.strftime("%Y-%m-%d")
    start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")

    news_rows = []
    source = "polygon"  # default label; replaced by the winning path
    response = await gateway._call_news_service(
        "get_enriched_news",
        lambda client: client.get_enriched_news(
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
        ),
    )
    if response is not None:
        news_rows = response.get("news") or []
        source = "news_service"

    if not news_rows:
        # Local fallback: over-fetch (>= 50) so the trailing slice below
        # still yields the newest rows; never refresh stale caches here.
        payload = await asyncio.to_thread(
            news_domain.get_enriched_news,
            gateway.storage.market_store,
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            limit=max(limit, 50),
            refresh_if_stale=False,
        )
        news_rows = (payload.get("news") or [])[-limit:]
        source = "market_store"

    await websocket.send(json.dumps({
        "type": "stock_news_loaded",
        "ticker": ticker,
        # Re-slicing is a no-op for the fallback path but trims
        # over-long service responses.
        "news": news_rows[-limit:],
        "source": source,
        "start_date": start_date,
        "end_date": end_date,
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_news_for_date(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Return news articles for one ticker on one trading date.

    Prefers the remote news service; falls back to the local market store
    without triggering a cache refresh.
    """
    symbol = normalize_symbol(data.get("ticker", ""))
    day = str(data.get("date") or "").strip()
    if not (symbol and day):
        error_message = {
            "type": "stock_news_for_date_loaded",
            "ticker": symbol,
            "date": day,
            "news": [],
            "error": "ticker and date are required",
        }
        await websocket.send(json.dumps(error_message, ensure_ascii=False))
        return

    # Clamp the row cap to [1, 50]; default 20 on malformed input.
    try:
        max_items = max(1, min(int(data.get("limit", 20)), 50))
    except (TypeError, ValueError):
        max_items = 20

    rows: list = []
    origin = "market_store"
    service_reply = await gateway._call_news_service(
        "get_news_for_date",
        lambda client: client.get_news_for_date(ticker=symbol, date=day, limit=max_items),
    )
    if service_reply is not None:
        rows = service_reply.get("news") or []
        origin = "news_service"

    if not rows:
        local_payload = await asyncio.to_thread(
            news_domain.get_news_for_date,
            gateway.storage.market_store,
            ticker=symbol,
            date=day,
            limit=max_items,
            refresh_if_stale=False,
        )
        rows = local_payload.get("news") or []
        origin = "market_store"

    await websocket.send(json.dumps({
        "type": "stock_news_for_date_loaded",
        "ticker": symbol,
        "date": day,
        "news": rows,
        "source": origin,
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_news_timeline(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Send the per-day news timeline for a ticker over a lookback window.

    Tries the news service first, then the local market store (without a
    refresh) when the service is unavailable or returned nothing.
    """
    symbol = normalize_symbol(data.get("ticker", ""))
    if not symbol:
        await websocket.send(json.dumps({
            "type": "stock_news_timeline_loaded",
            "ticker": "",
            "timeline": [],
            "error": "invalid ticker",
        }, ensure_ascii=False))
        return

    # Clamp lookback to [7, 365] days; default 90 on malformed input.
    try:
        window_days = max(7, min(int(data.get("lookback_days", 90)), 365))
    except (TypeError, ValueError):
        window_days = 90

    # Anchor the window on the pipeline's current date (backtest-aware).
    anchor = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
    try:
        anchor_dt = datetime.strptime(anchor, "%Y-%m-%d")
    except ValueError:
        anchor_dt = datetime.now()
    end_date = anchor_dt.strftime("%Y-%m-%d")
    start_date = (anchor_dt - timedelta(days=window_days)).strftime("%Y-%m-%d")

    service_reply = await gateway._call_news_service(
        "get_news_timeline",
        lambda client: client.get_news_timeline(ticker=symbol, start_date=start_date, end_date=end_date),
    )
    day_buckets = (service_reply.get("timeline") or []) if service_reply is not None else []

    if not day_buckets:
        local_payload = await asyncio.to_thread(
            news_domain.get_news_timeline,
            gateway.storage.market_store,
            ticker=symbol,
            start_date=start_date,
            end_date=end_date,
            refresh_if_stale=False,
        )
        day_buckets = local_payload.get("timeline") or []

    await websocket.send(json.dumps({
        "type": "stock_news_timeline_loaded",
        "ticker": symbol,
        "timeline": day_buckets,
        "start_date": start_date,
        "end_date": end_date,
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_news_categories(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Serve news category aggregates for a ticker over a lookback window.

    Prefers the remote news service; falls back to local aggregation from
    the market store (without refreshing stale caches).
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        await websocket.send(json.dumps({
            "type": "stock_news_categories_loaded",
            "ticker": "",
            "categories": {},
            "error": "invalid ticker",
        }, ensure_ascii=False))
        return

    # Clamp lookback to [7, 365] days; default 90 on malformed input.
    lookback_days = data.get("lookback_days", 90)
    try:
        lookback_days = max(7, min(int(lookback_days), 365))
    except (TypeError, ValueError):
        lookback_days = 90

    # Anchor the window on the pipeline's current date (backtest-aware).
    end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
    try:
        end_dt = datetime.strptime(end_date, "%Y-%m-%d")
    except ValueError:
        end_dt = datetime.now()
    end_date = end_dt.strftime("%Y-%m-%d")
    start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")

    categories = {}
    response = await gateway._call_news_service(
        "get_categories",
        lambda client: client.get_categories(
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        ),
    )
    if response is not None:
        categories = response.get("categories") or {}

    if not categories:
        # Local fallback: aggregate from the market store, no refresh.
        payload = await asyncio.to_thread(
            news_domain.get_news_categories,
            gateway.storage.market_store,
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
            refresh_if_stale=False,
        )
        categories = payload.get("categories") or {}

    await websocket.send(json.dumps({
        "type": "stock_news_categories_loaded",
        "ticker": ticker,
        "categories": categories,
        "start_date": start_date,
        "end_date": end_date,
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_range_explain(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Explain a price range for a ticker using news over [start, end].

    Prefers the remote news service; falls back to the local explain
    pipeline (without refreshing stale caches). Replies with a
    ``stock_range_explain_loaded`` message.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    start_date = str(data.get("start_date") or "").strip()
    end_date = str(data.get("end_date") or "").strip()
    if not ticker or not start_date or not end_date:
        await websocket.send(json.dumps({
            "type": "stock_range_explain_loaded",
            "ticker": ticker,
            "result": {"error": "ticker, start_date, end_date are required"},
        }, ensure_ascii=False))
        return

    # Optional pre-selected article ids; anything but a list is ignored.
    article_ids = data.get("article_ids")
    result = None
    response = await gateway._call_news_service(
        "get_range_explain",
        lambda client: client.get_range_explain(
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            article_ids=article_ids if isinstance(article_ids, list) else None,
            limit=100,
        ),
    )
    if response is not None:
        result = response.get("result")

    if result is None:
        # Local fallback; never refreshes stale caches from this path.
        payload = await asyncio.to_thread(
            news_domain.get_range_explain_payload,
            gateway.storage.market_store,
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            article_ids=article_ids if isinstance(article_ids, list) else None,
            limit=100,
            refresh_if_stale=False,
        )
        result = payload.get("result")

    await websocket.send(json.dumps({
        "type": "stock_range_explain_loaded",
        "ticker": ticker,
        "result": result,
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Serve insider-trade filings for one ticker, newest first.

    Resolution order: trading service -> local provider payload. Replies
    with a ``stock_insider_trades_loaded`` message whose rows are
    normalized per-filing dicts.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        await websocket.send(json.dumps({
            "type": "stock_insider_trades_loaded",
            "ticker": "",
            "trades": [],
            "error": "invalid ticker",
        }, ensure_ascii=False))
        return

    # Window defaults to the pipeline's current date (backtest-aware);
    # dates are truncated to YYYY-MM-DD.
    end_date = str(data.get("end_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
    start_date = str(data.get("start_date") or "").strip()[:10]
    # Fix: a malformed "limit" used to raise (ValueError/TypeError) and
    # kill the handler; fall back to the default like sibling handlers do.
    try:
        limit = int(data.get("limit", 50))
    except (TypeError, ValueError):
        limit = 50

    trades = []
    response = await gateway._call_trading_service(
        "get_insider_trades",
        lambda client: client.get_insider_trades(
            ticker=ticker,
            end_date=end_date,
            start_date=start_date if start_date else None,
            limit=limit,
        ),
    )
    if response is not None:
        trades = response.insider_trades

    if not trades:
        # Local fallback: query the provider payload off the event loop.
        payload = await asyncio.to_thread(
            trading_domain.get_insider_trades_payload,
            ticker=ticker,
            end_date=end_date,
            start_date=start_date if start_date else None,
            limit=limit,
        )
        trades = payload.get("insider_trades") or []

    # NOTE(review): both branches are assumed to yield model objects with
    # attribute access (t.transaction_date etc.) -- confirm the fallback
    # payload never returns plain dicts.
    sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
    formatted_trades = [{
        "ticker": t.ticker,
        "name": t.name,
        "title": t.title,
        "is_board_director": t.is_board_director,
        "transaction_date": t.transaction_date,
        "transaction_shares": t.transaction_shares,
        "transaction_price_per_share": t.transaction_price_per_share,
        "transaction_value": t.transaction_value,
        "shares_owned_before_transaction": t.shares_owned_before_transaction,
        "shares_owned_after_transaction": t.shares_owned_after_transaction,
        "security_title": t.security_title,
        "filing_date": t.filing_date,
        # NOTE(review): truthiness check means a 0-share before/after value
        # yields None even though the delta is computable -- confirm intent.
        "holding_change": (
            (t.shares_owned_after_transaction or 0) - (t.shares_owned_before_transaction or 0)
            if t.shares_owned_after_transaction and t.shares_owned_before_transaction else None
        ),
        "is_buy": ((t.transaction_shares or 0) > 0) if t.transaction_shares is not None else None,
    } for t in sorted_trades]

    await websocket.send(json.dumps({
        "type": "stock_insider_trades_loaded",
        "ticker": ticker,
        "start_date": start_date or None,
        "end_date": end_date,
        "trades": formatted_trades,
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_story(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Send the narrative "story" for a ticker as of a given date.

    Prefers the remote news service; computes a local story payload from
    the market store when the service is unavailable.
    """
    symbol = normalize_symbol(data.get("ticker", ""))
    if not symbol:
        await websocket.send(json.dumps({
            "type": "stock_story_loaded",
            "ticker": "",
            "story": "",
            "error": "invalid ticker",
        }, ensure_ascii=False))
        return

    # Resolve the as-of date: explicit request -> pipeline date -> today.
    default_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
    as_of_date = str(data.get("as_of_date") or default_date).strip()[:10]

    story_payload = await gateway._call_news_service(
        "get_story",
        lambda client: client.get_story(ticker=symbol, as_of_date=as_of_date),
    )
    if story_payload is None:
        story_payload = await asyncio.to_thread(
            news_domain.get_story_payload,
            gateway.storage.market_store,
            ticker=symbol,
            as_of_date=as_of_date,
        )

    await websocket.send(json.dumps({
        "type": "stock_story_loaded",
        "ticker": symbol,
        "as_of_date": as_of_date,
        "story": story_payload.get("story") or "",
        "source": story_payload.get("source") or "local",
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Serve historically similar trading days for (ticker, date).

    Prefers the remote news service; falls back to the local similarity
    payload computed from the market store.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    target_date = str(data.get("date") or "").strip()[:10]
    if not ticker or not target_date:
        await websocket.send(json.dumps({
            "type": "stock_similar_days_loaded",
            "ticker": ticker,
            "date": target_date,
            "items": [],
            "error": "ticker and date are required",
        }, ensure_ascii=False))
        return

    # Clamp the requested neighbour count to [1, 20]; default 8.
    top_k = data.get("top_k", 8)
    try:
        top_k = max(1, min(int(top_k), 20))
    except (TypeError, ValueError):
        top_k = 8

    result = await gateway._call_news_service(
        "get_similar_days",
        lambda client: client.get_similar_days(ticker=ticker, date=target_date, n_similar=top_k),
    )
    if result is None:
        # Local fallback computed off the event loop.
        result = await asyncio.to_thread(
            news_domain.get_similar_days_payload,
            gateway.storage.market_store,
            ticker=ticker,
            date=target_date,
            n_similar=top_k,
        )

    await websocket.send(json.dumps({
        "type": "stock_similar_days_loaded",
        "ticker": ticker,
        "date": target_date,
        # NOTE(review): result keys are spread last, so a payload that
        # contains "ticker"/"date" would override the values above.
        **result,
    }, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Compute and send a technical-indicator snapshot for one ticker.

    Pulls ~250 calendar days of prices (trading service first, provider
    payload as fallback), runs the shared technical analyzer, and adds
    realized-volatility and moving-average-distance statistics. Any
    failure is reported to the client as a structured error message.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    if not ticker:
        await websocket.send(json.dumps({
            "type": "stock_technical_indicators_loaded",
            "ticker": ticker,
            "indicators": None,
            "error": "ticker is required",
        }, ensure_ascii=False))
        return

    try:
        end_date = datetime.now()
        start_date = end_date - timedelta(days=250)

        prices = None
        response = await gateway._call_trading_service(
            "get_prices",
            lambda client: client.get_prices(
                ticker=ticker,
                start_date=start_date.strftime("%Y-%m-%d"),
                end_date=end_date.strftime("%Y-%m-%d"),
            ),
        )
        if response is not None:
            prices = response.prices

        if prices is None:
            # Fallback: synchronous provider payload (degraded path).
            payload = trading_domain.get_prices_payload(
                ticker=ticker,
                start_date=start_date.strftime("%Y-%m-%d"),
                end_date=end_date.strftime("%Y-%m-%d"),
            )
            prices = payload.get("prices") or []

        if not prices or len(prices) < 20:
            # The indicator windows below need at least ~20 bars.
            await websocket.send(json.dumps({
                "type": "stock_technical_indicators_loaded",
                "ticker": ticker,
                "indicators": None,
                "error": "Insufficient price data",
            }, ensure_ascii=False))
            return

        df = prices_to_df(prices)
        signal = gateway._technical_analyzer.analyze(ticker, df)

        # Fix: dropped an unused `import pandas as pd` here -- the alias
        # was never referenced in this function.
        # Annualized realized volatility (%) over 10/20/60-day windows.
        df_sorted = df.sort_values("time").reset_index(drop=True)
        df_sorted["returns"] = df_sorted["close"].pct_change()
        vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
        vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
        vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
        # Percentage distance of the last price from each moving average.
        ma_distance = {}
        for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
            ma_value = getattr(signal, ma_key, None)
            ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None

        indicators = {
            "ticker": ticker,
            "current_price": signal.current_price,
            "ma": {
                "ma5": signal.ma5,
                "ma10": signal.ma10,
                "ma20": signal.ma20,
                "ma50": signal.ma50,
                "ma200": signal.ma200,
                "distance": ma_distance,
            },
            "rsi": {
                "rsi14": signal.rsi14,
                # NOTE(review): assumes rsi14 is numeric; a None here would
                # raise and surface via the except path below -- confirm.
                "status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
            },
            "macd": {
                "macd": signal.macd,
                "signal": signal.macd_signal,
                "histogram": signal.macd - signal.macd_signal,
            },
            "bollinger": {
                "upper": signal.bollinger_upper,
                "mid": signal.bollinger_mid,
                "lower": signal.bollinger_lower,
            },
            "volatility": {
                "vol_10d": vol_10,
                "vol_20d": vol_20,
                "vol_60d": vol_60,
                "annualized": signal.annualized_volatility_pct,
                "risk_level": signal.risk_level,
            },
            "trend": signal.trend,
            "mean_reversion": signal.mean_reversion_signal,
        }

        await websocket.send(json.dumps({
            "type": "stock_technical_indicators_loaded",
            "ticker": ticker,
            "indicators": indicators,
        }, ensure_ascii=False, default=str))
    except Exception as exc:
        # Broad catch is deliberate: indicator math on partial data can
        # fail many ways, and the client should still get a structured error.
        logger.exception("Error getting technical indicators for %s", ticker)
        await websocket.send(json.dumps({
            "type": "stock_technical_indicators_loaded",
            "ticker": ticker,
            "indicators": None,
            "error": str(exc),
        }, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
    """Run the news-enrichment pipeline for one ticker and report results.

    Enriches news over [start_date, end_date], then optionally rebuilds
    the cached story (``rebuild_story``/``story_date``) and similar-day
    index (``rebuild_similar_days``/``target_date``). Replies with a
    ``stock_enrich_completed`` message containing stats for each stage.
    """
    ticker = normalize_symbol(data.get("ticker", ""))
    # All dates are truncated to YYYY-MM-DD; story defaults to end_date.
    start_date = str(data.get("start_date") or "").strip()[:10]
    end_date = str(data.get("end_date") or "").strip()[:10]
    story_date = str(data.get("story_date") or end_date or "").strip()[:10]
    target_date = str(data.get("target_date") or "").strip()[:10]
    force = bool(data.get("force", False))  # re-enrich already-enriched rows
    rebuild_story = bool(data.get("rebuild_story", True))
    rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
    only_local_to_llm = bool(data.get("only_local_to_llm", False))
    limit = data.get("limit", 200)

    # Clamp row cap to [10, 500]; default 200 on malformed input.
    try:
        limit = max(10, min(int(limit), 500))
    except (TypeError, ValueError):
        limit = 200

    if not ticker or not start_date or not end_date:
        await websocket.send(json.dumps({
            "type": "stock_enrich_completed",
            "ticker": ticker,
            "start_date": start_date,
            "end_date": end_date,
            "error": "ticker, start_date, end_date are required",
        }, ensure_ascii=False))
        return

    # LLM re-analysis mode is only meaningful with an enabled LLM enricher.
    if only_local_to_llm and not llm_enrichment_enabled():
        await websocket.send(json.dumps({
            "type": "stock_enrich_completed",
            "ticker": ticker,
            "start_date": start_date,
            "end_date": end_date,
            "error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
        }, ensure_ascii=False))
        return

    # CPU/IO-heavy enrichment runs off the event loop.
    result = await asyncio.to_thread(
        enrich_news_for_symbol,
        gateway.storage.market_store,
        ticker,
        start_date=start_date,
        end_date=end_date,
        limit=limit,
        skip_existing=not force,
        only_reanalyze_local=only_local_to_llm,
    )

    story_status = None
    if rebuild_story and story_date:
        # Invalidate first so get_story_payload regenerates the story.
        await asyncio.to_thread(gateway.storage.market_store.delete_story_cache, ticker, as_of_date=story_date)
        story_result = await asyncio.to_thread(
            news_domain.get_story_payload,
            gateway.storage.market_store,
            ticker=ticker,
            as_of_date=story_date,
        )
        story_status = {"as_of_date": story_date, "source": story_result.get("source") or "local"}

    similar_status = None
    if rebuild_similar_days and target_date:
        # Same invalidate-then-recompute dance for the similar-day cache.
        await asyncio.to_thread(gateway.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date)
        similar_result = await asyncio.to_thread(
            news_domain.get_similar_days_payload,
            gateway.storage.market_store,
            ticker=ticker,
            date=target_date,
            n_similar=8,
        )
        similar_status = {
            "target_date": target_date,
            "count": len(similar_result.get("items") or []),
            "error": similar_result.get("error"),
        }

    await websocket.send(json.dumps({
        "type": "stock_enrich_completed",
        "ticker": ticker,
        "start_date": start_date,
        "end_date": end_date,
        "story_date": story_date or None,
        "target_date": target_date or None,
        "force": force,
        "only_local_to_llm": only_local_to_llm,
        "stats": result,
        "story_status": story_status,
        "similar_status": similar_status,
    }, ensure_ascii=False, default=str))
|
||||
687
backend/services/market.py
Normal file
687
backend/services/market.py
Normal file
@@ -0,0 +1,687 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Market Data Service
|
||||
Supports live and backtest modes
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
from backend.config.data_config import get_data_sources
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NYSE timezone and calendar
|
||||
NYSE_TZ = ZoneInfo("America/New_York")
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
class MarketStatus:
    """Market status enum-like class"""

    OPEN = "open"              # regular trading hours
    CLOSED = "closed"          # market fully closed (overnight/weekend/holiday)
    PREMARKET = "premarket"    # before the opening bell
    AFTERHOURS = "afterhours"  # after the closing bell
|
||||
|
||||
|
||||
class MarketService:
    """Market data service for price management.

    Operates in one of two modes (see :attr:`mode_name`):

    - LIVE: a polling price manager fetches quotes (callbacks may arrive on
      the manager's own thread); each update is cached here and marshalled
      onto the asyncio loop for broadcasting to the frontend.
    - BACKTEST: a historical price manager replays stored prices for the
      configured date range, driven externally via :meth:`set_backtest_date`.
    """

    def __init__(
        self,
        tickers: List[str],
        poll_interval: int = 10,
        backtest_mode: bool = False,
        api_key: Optional[str] = None,
        backtest_start_date: Optional[str] = None,
        backtest_end_date: Optional[str] = None,
    ):
        self.tickers = [normalize_symbol(ticker) for ticker in tickers]
        self.poll_interval = poll_interval
        self.backtest_mode = backtest_mode
        self.api_key = api_key
        self.backtest_start_date = backtest_start_date
        self.backtest_end_date = backtest_end_date

        # Latest price payload per ticker, as delivered by the price manager.
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.running = False
        # Event loop captured in start(); used to hop from provider threads
        # back onto the asyncio loop (see _make_price_callback).
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._broadcast_func: Optional[Callable] = None
        self._price_record_func: Optional[Callable[..., None]] = None
        self._price_manager: Optional[Any] = None
        self._current_date: Optional[str] = None

        # Market status tracking
        self._last_market_status: Optional[str] = None

        # Session tracking for live returns
        self._session_start_values: Optional[Dict[str, float]] = None
        self._session_start_timestamp: Optional[int] = None

    def get_live_quote_provider(self) -> Optional[str]:
        """Return the active live quote provider for UI/debugging.

        Returns "backtest" in backtest mode, the lowercased provider name
        when a live price manager exposes one, otherwise None.
        """
        if self.backtest_mode:
            return "backtest"
        if self._price_manager and hasattr(self._price_manager, "provider"):
            provider = getattr(self._price_manager, "provider", None)
            if isinstance(provider, str) and provider.strip():
                return provider.strip().lower()
        return None

    @property
    def mode_name(self) -> str:
        """Human-readable mode label: "BACKTEST" or "LIVE"."""
        if self.backtest_mode:
            return "BACKTEST"
        return "LIVE"

    async def start(self, broadcast_func: Callable):
        """Start market data service.

        Args:
            broadcast_func: Async callable used to push updates to clients.
        """
        if self.running:
            return

        self.running = True
        self._loop = asyncio.get_running_loop()
        self._broadcast_func = broadcast_func

        if self.backtest_mode:
            self._start_backtest_mode()
        else:
            self._start_real_mode()

        logger.info(
            f"Market service started: {self.mode_name}, tickers={self.tickers}",  # noqa: E501
        )

    def set_price_recorder(self, recorder: Optional[Callable[..., None]]):
        """Register an optional callback for persisting runtime price points."""
        self._price_record_func = recorder

    def _make_price_callback(self) -> Callable:
        """Create thread-safe price callback.

        The returned callback may be invoked from the price manager's worker
        thread, so the broadcast coroutine is handed to the stored event
        loop via run_coroutine_threadsafe rather than awaited directly.
        """

        def callback(price_data: Dict[str, Any]):
            symbol = price_data["symbol"]
            self.cache[symbol] = price_data

            loop = self._loop
            if loop and loop.is_running() and self._broadcast_func:
                asyncio.run_coroutine_threadsafe(
                    self._broadcast_price_update(price_data),
                    loop,
                )

        return callback

    def _start_real_mode(self):
        """Start the LIVE-mode polling price manager."""
        from backend.data.polling_price_manager import PollingPriceManager

        provider = self._resolve_live_quote_provider()

        # Only finnhub requires an API key; yfinance works without one.
        if provider == "finnhub" and not self.api_key:
            raise ValueError("API key required for live mode")
        self._price_manager = PollingPriceManager(
            api_key=self.api_key,
            poll_interval=self.poll_interval,
            provider=provider,
        )
        self._price_manager.add_price_callback(self._make_price_callback())
        self._price_manager.subscribe(self.tickers)
        self._price_manager.start()

    def _resolve_live_quote_provider(self) -> str:
        """Pick the first configured provider that supports live quote polling."""
        for provider in get_data_sources():
            if provider in {"finnhub", "yfinance"}:
                return provider
        return "yfinance"

    def _start_backtest_mode(self):
        """Start the BACKTEST-mode historical price manager."""
        from backend.data.historical_price_manager import (
            HistoricalPriceManager,
        )

        self._price_manager = HistoricalPriceManager()
        self._price_manager.add_price_callback(self._make_price_callback())
        self._price_manager.subscribe(self.tickers)

        if self.backtest_start_date and self.backtest_end_date:
            self._price_manager.preload_data(
                self.backtest_start_date,
                self.backtest_end_date,
            )

        self._price_manager.start()

    async def _broadcast_price_update(self, price_data: Dict[str, Any]):
        """Broadcast price update to frontend.

        Also forwards the point to the optional price recorder; recorder
        failures are logged and never interrupt broadcasting.
        """
        if not self._broadcast_func:
            return

        symbol = price_data["symbol"]
        price = price_data["price"]
        open_price = price_data.get("open", price)
        # Intraday return in percent relative to the open price.
        ret = (
            ((price - open_price) / open_price) * 100 if open_price > 0 else 0
        )

        if self._price_record_func:
            try:
                self._price_record_func(
                    ticker=symbol,
                    timestamp=str(price_data.get("timestamp") or datetime.now().isoformat()),
                    price=float(price),
                    open_price=float(open_price) if open_price is not None else None,
                    ret=float(ret),
                    source=self.mode_name.lower(),
                    meta=price_data,
                )
            except Exception as exc:
                logger.warning(
                    "Failed to record price point for %s: %s",
                    symbol,
                    exc,
                )

        await self._broadcast_func(
            {
                "type": "price_update",
                "symbol": symbol,
                "price": price,
                "open": open_price,
                "ret": ret,
                "timestamp": price_data.get("timestamp"),
                "realtime_prices": {
                    t: self._get_cached_price(t) for t in self.tickers
                },
            },
        )

    def _get_cached_price(self, ticker: str) -> Dict[str, Any]:
        """Get cached price data for a ticker (zero price when unknown)."""
        if ticker in self.cache:
            return self.cache[ticker]
        # Return from price manager if not in cache
        if self._price_manager:
            price = self._price_manager.get_latest_price(ticker)
            if price:
                return {"price": price, "symbol": ticker}
        return {"price": 0, "symbol": ticker}

    def stop(self):
        """Stop market service and release the price manager and loop refs."""
        if not self.running:
            return
        self.running = False
        if self._price_manager:
            self._price_manager.stop()
            self._price_manager = None
        self._loop = None
        self._broadcast_func = None

    def update_tickers(self, tickers: List[str]) -> Dict[str, List[str]]:
        """Hot-update subscribed tickers without restarting the service.

        Args:
            tickers: Desired ticker list (normalized, de-duplicated in order).

        Returns:
            Dict with "added", "removed" and the resulting "active" tickers.
        """
        normalized: List[str] = []
        for ticker in tickers:
            symbol = normalize_symbol(ticker)
            if symbol and symbol not in normalized:
                normalized.append(symbol)

        previous = list(self.tickers)
        removed = [ticker for ticker in previous if ticker not in normalized]
        added = [ticker for ticker in normalized if ticker not in previous]
        self.tickers = normalized

        if self._price_manager:
            if removed:
                self._price_manager.unsubscribe(removed)
            if added:
                self._price_manager.subscribe(added)

            # Re-apply the current backtest date so newly added tickers
            # are positioned on the same replay day.
            if self.backtest_mode and self._current_date:
                self._price_manager.set_date(self._current_date)

        for ticker in removed:
            self.cache.pop(ticker, None)

        return {
            "added": added,
            "removed": removed,
            "active": list(self.tickers),
        }

    # Backtest methods
    def set_backtest_date(self, date: str):
        """Set current backtest date (no-op outside backtest mode)."""
        if not self.backtest_mode or not self._price_manager:
            return
        self._current_date = date
        self._price_manager.set_date(date)
        logger.info(f"Backtest date: {date}")

    async def emit_market_open(self):
        """Emit market open prices (backtest mode only)."""
        if self.backtest_mode and self._price_manager:
            self._price_manager.emit_open_prices()
            # Log prices for debugging
            prices = self.get_open_prices()
            logger.info(f"Open prices: {prices}")

    async def emit_market_close(self):
        """Emit market close prices (backtest mode only)."""
        if self.backtest_mode and self._price_manager:
            self._price_manager.emit_close_prices()
            # Log prices for debugging
            prices = self.get_close_prices()
            logger.info(f"Close prices: {prices}")

    def get_open_prices(self) -> Dict[str, float]:
        """Get open prices for all tickers (0.0 when unavailable)."""
        prices = {}
        for ticker in self.tickers:
            price = None
            # Try price manager first
            if self.backtest_mode and self._price_manager:
                price = self._price_manager.get_open_price(ticker)
            # Fallback to cache
            if price is None or price <= 0:
                cached = self.cache.get(ticker, {})
                price = cached.get("open") or cached.get("price")
            prices[ticker] = price if price and price > 0 else 0.0
        return prices

    def get_close_prices(self) -> Dict[str, float]:
        """Get close prices for all tickers (0.0 when unavailable)."""
        prices = {}
        for ticker in self.tickers:
            price = None
            # Try price manager first
            if self.backtest_mode and self._price_manager:
                price = self._price_manager.get_close_price(ticker)
            # Fallback to cache
            if price is None or price <= 0:
                cached = self.cache.get(ticker, {})
                price = cached.get("close") or cached.get("price")
            prices[ticker] = price if price and price > 0 else 0.0
        return prices

    def get_price_for_date(
        self,
        ticker: str,
        date: str,
        price_type: str = "close",
    ) -> Optional[float]:
        """Get price for a specific date.

        In backtest mode this queries the historical manager; in live mode
        the date is ignored and the latest price is returned instead.
        """
        if self.backtest_mode and self._price_manager:
            return self._price_manager.get_price_for_date(
                ticker,
                date,
                price_type,
            )
        return self.get_price_sync(ticker)

    # Common methods
    def get_price_sync(self, ticker: str) -> Optional[float]:
        """Get latest price synchronously (cache first, then manager)."""
        # Try cache first
        data = self.cache.get(ticker)
        if data and data.get("price"):
            return data["price"]
        # Try price manager
        if self._price_manager:
            return self._price_manager.get_latest_price(ticker)
        return None

    def get_all_prices(self) -> Dict[str, float]:
        """Get all latest prices (0.0 for tickers without a valid price)."""
        prices = {}
        for ticker in self.tickers:
            price = self.get_price_sync(ticker)
            prices[ticker] = price if price and price > 0 else 0.0
        return prices

    # Live mode async waiting methods

    def _now_nyse(self) -> datetime:
        """Get current time in NYSE timezone."""
        return datetime.now(NYSE_TZ)

    def _is_trading_day(self, date: datetime) -> bool:
        """Check if date is a NYSE trading day."""
        date_str = date.strftime("%Y-%m-%d")
        valid_days = NYSE_CALENDAR.valid_days(
            start_date=date_str,
            end_date=date_str,
        )
        return len(valid_days) > 0

    def _get_market_hours(self, date: datetime) -> tuple:
        """Get market open and close times for a given date.

        Returns:
            (market_open, market_close) datetimes, or (None, None) when the
            exchange has no schedule for that date.
        """
        date_str = date.strftime("%Y-%m-%d")
        schedule = NYSE_CALENDAR.schedule(
            start_date=date_str,
            end_date=date_str,
        )
        if schedule.empty:
            return None, None
        market_open = schedule.iloc[0]["market_open"].to_pydatetime()
        market_close = schedule.iloc[0]["market_close"].to_pydatetime()
        return market_open, market_close

    def _next_trading_day(self, from_date: datetime) -> datetime:
        """Find the next trading day from given date."""
        check_date = from_date + timedelta(days=1)
        for _ in range(10):  # Max 10 days ahead (handles holidays)
            if self._is_trading_day(check_date):
                return check_date
            check_date += timedelta(days=1)
        return check_date

    def _get_trading_date_for_execution(self) -> tuple:
        """
        Determine the trading date for execution.

        Returns:
            (trading_date, market_open_time, market_close_time)

        Logic:
        - If today is a trading day and market has opened: use today
        - If today is a trading day but market hasn't opened: wait for open
        - If today is not a trading day: use next trading day
        """
        now = self._now_nyse()
        today = now.replace(hour=0, minute=0, second=0, microsecond=0)

        if self._is_trading_day(today):
            market_open, market_close = self._get_market_hours(today)
            return today, market_open, market_close
        else:
            # Weekend or holiday - find next trading day
            next_day = self._next_trading_day(today)
            market_open, market_close = self._get_market_hours(next_day)
            return next_day, market_open, market_close

    async def wait_for_open_prices(self) -> Dict[str, float]:
        """
        Wait for market open and return open prices.

        Behavior:
        - If market is already open today: return current prices immediately
        - If market hasn't opened yet today: wait until open
        - If not a trading day: wait until next trading day opens
        """
        now = self._now_nyse()
        trading_date, market_open, _ = self._get_trading_date_for_execution()

        if market_open is None:
            logger.warning("Could not determine market hours")
            return self.get_all_prices()

        trading_date_str = trading_date.strftime("%Y-%m-%d")

        # Check if we need to wait
        if now < market_open:
            wait_seconds = (market_open - now).total_seconds()
            logger.info(
                f"Waiting {wait_seconds/60:.1f} min for market open "
                f"({trading_date_str} {market_open.strftime('%H:%M')} ET)",
            )
            await asyncio.sleep(wait_seconds)
            # Small delay to ensure prices are available
            await asyncio.sleep(5)
        else:
            logger.info(
                f"Market already open for {trading_date_str}, "
                f"getting current prices",
            )

        # Poll until we have valid prices
        prices = await self._poll_for_prices()
        logger.info(f"Got open prices for {trading_date_str}: {prices}")
        return prices

    async def wait_for_close_prices(self) -> Dict[str, float]:
        """
        Wait for market close and return close prices.

        Behavior:
        - If market is already closed today: return current prices immediately
        - If market hasn't closed yet: wait until close
        """
        now = self._now_nyse()
        trading_date, _, market_close = self._get_trading_date_for_execution()

        if market_close is None:
            logger.warning("Could not determine market hours")
            return self.get_all_prices()

        trading_date_str = trading_date.strftime("%Y-%m-%d")

        # Check if we need to wait
        if now < market_close:
            wait_seconds = (market_close - now).total_seconds()
            logger.info(
                f"Waiting {wait_seconds/60:.1f} min for market close "
                f"({trading_date_str} {market_close.strftime('%H:%M')} ET)",
            )
            await asyncio.sleep(wait_seconds)
            # Small delay to ensure final prices settle
            await asyncio.sleep(10)
        else:
            logger.info(
                f"Market already closed for {trading_date_str}, "
                f"getting close prices",
            )

        # Get final prices
        prices = await self._poll_for_prices()
        logger.info(f"Got close prices for {trading_date_str}: {prices}")
        return prices

    def get_live_trading_date(self) -> str:
        """Get the trading date that will be used for live execution."""
        trading_date, _, _ = self._get_trading_date_for_execution()
        return trading_date.strftime("%Y-%m-%d")

    async def _poll_for_prices(
        self,
        max_retries: int = 12,
    ) -> Dict[str, float]:
        """Poll until all prices are available, or retries are exhausted.

        Sleeps 5 seconds between attempts; on exhaustion returns whatever
        prices are currently known (possibly containing zeros).
        """
        for _ in range(max_retries):
            prices = self.get_all_prices()
            if all(p > 0 for p in prices.values()):
                return prices
            logger.debug("Waiting for prices to be available...")
            await asyncio.sleep(5)
        # Return whatever we have
        return self.get_all_prices()

    # ========== Market Status Methods ==========

    def get_market_status(self) -> Dict[str, Any]:
        """
        Get current market status

        Returns:
            Dict with status info:
            - status: 'open' | 'closed' | 'premarket' | 'afterhours'
            - status_text: Human readable status
            - is_trading_day: Whether today is a trading day
            - market_open: Market open time (if trading day)
            - market_close: Market close time (if trading day)
            - live_quote_provider: Active quote provider (always included)
        """
        if self.backtest_mode:
            # In backtest mode, always return open
            return {
                "status": MarketStatus.OPEN,
                "status_text": "Backtest Mode",
                "is_trading_day": True,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        now = self._now_nyse()
        today = now.replace(hour=0, minute=0, second=0, microsecond=0)

        is_trading = self._is_trading_day(today)

        if not is_trading:
            return {
                "status": MarketStatus.CLOSED,
                "status_text": "Market Closed (Non-trading Day)",
                "is_trading_day": False,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        market_open, market_close = self._get_market_hours(today)

        if market_open is None or market_close is None:
            return {
                "status": MarketStatus.CLOSED,
                "status_text": "Market Closed",
                "is_trading_day": is_trading,
                "live_quote_provider": self.get_live_quote_provider(),
            }

        # Determine status based on current time
        if now < market_open:
            return {
                "status": MarketStatus.PREMARKET,
                "status_text": "Pre-Market",
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }
        elif now > market_close:
            return {
                "status": MarketStatus.CLOSED,
                "status_text": "Market Closed",
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }
        else:
            return {
                "status": MarketStatus.OPEN,
                "status_text": "Market Open",
                "is_trading_day": True,
                "market_open": market_open.isoformat(),
                "market_close": market_close.isoformat(),
                "live_quote_provider": self.get_live_quote_provider(),
            }

    async def check_and_broadcast_market_status(self):
        """Check market status and broadcast if changed.

        Also drives session lifecycle: captures baseline prices when the
        market transitions to OPEN, and clears them on CLOSED.
        """
        status = self.get_market_status()
        current_status = status["status"]

        if current_status != self._last_market_status:
            self._last_market_status = current_status
            await self._broadcast_market_status(status)

            # Handle session transitions
            if current_status == MarketStatus.OPEN:
                await self._on_session_start()
            elif (
                current_status == MarketStatus.CLOSED
                and self._session_start_values is not None
            ):
                self._on_session_end()

    async def _broadcast_market_status(self, status: Dict[str, Any]):
        """Broadcast market status update to frontend"""
        if not self._broadcast_func:
            return

        await self._broadcast_func(
            {
                "type": "market_status_update",
                "market_status": status,
                "timestamp": datetime.now(NYSE_TZ).isoformat(),
            },
        )
        logger.info(f"Market status: {status['status_text']}")

    async def _on_session_start(self):
        """Called when market session starts - capture baseline values"""
        # Wait briefly for prices to be available
        await asyncio.sleep(2)

        prices = self.get_all_prices()
        if prices and any(p > 0 for p in prices.values()):
            self._session_start_values = prices.copy()
            self._session_start_timestamp = int(
                datetime.now().timestamp() * 1000,
            )
            logger.info(f"Session started with prices: {prices}")

    def _on_session_end(self):
        """Called when market session ends - clear session data"""
        self._session_start_values = None
        self._session_start_timestamp = None
        logger.info("Session ended, cleared session data")

    def get_session_returns(
        self,
        current_prices: Dict[str, float],
        portfolio_value: Optional[float] = None,
        session_start_portfolio_value: Optional[float] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Calculate session returns (from session start to now)

        Args:
            current_prices: Current prices for tickers
            portfolio_value: Current portfolio value (optional)
            session_start_portfolio_value: Portfolio value captured at
                session start; when provided together with portfolio_value,
                a "portfolio_return" entry is added to the result.

        Returns:
            Dict with return data or None if session not started
        """
        if self._session_start_values is None:
            return None

        timestamp = int(datetime.now().timestamp() * 1000)
        returns = {}

        # Calculate individual ticker returns
        for ticker, start_price in self._session_start_values.items():
            current = current_prices.get(ticker)
            if current and start_price and start_price > 0:
                ret = ((current - start_price) / start_price) * 100
                returns[ticker] = round(ret, 4)

        result = {
            "timestamp": timestamp,
            "ticker_returns": returns,
        }

        # Calculate portfolio return if values provided
        if (
            portfolio_value is not None
            and session_start_portfolio_value is not None
        ):
            if session_start_portfolio_value > 0:
                portfolio_ret = (
                    (portfolio_value - session_start_portfolio_value)
                    / session_start_portfolio_value
                ) * 100
                result["portfolio_return"] = round(portfolio_ret, 4)

        return result

    @property
    def session_start_values(self) -> Optional[Dict[str, float]]:
        """Get session start values for external use"""
        return self._session_start_values

    @property
    def session_start_timestamp(self) -> Optional[int]:
        """Get session start timestamp"""
        return self._session_start_timestamp
754
backend/services/openclaw_cli.py
Normal file
754
backend/services/openclaw_cli.py
Normal file
@@ -0,0 +1,754 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Thin service wrapper around the OpenClaw CLI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from shared.models.openclaw import (
|
||||
AgentSummary,
|
||||
AgentsList,
|
||||
ApprovalRequest,
|
||||
ApprovalsList,
|
||||
CronJob,
|
||||
CronList,
|
||||
DaemonStatus,
|
||||
HookStatusEntry,
|
||||
HookStatusReport,
|
||||
ModelAliasesList,
|
||||
ModelFallbacksList,
|
||||
ModelRow,
|
||||
ModelsList,
|
||||
OpenClawStatus,
|
||||
PairingListResponse,
|
||||
PluginDiagnostic,
|
||||
PluginRecord,
|
||||
PluginsList,
|
||||
QrCodeResponse,
|
||||
SecretsAuditReport,
|
||||
SecurityAuditResponse,
|
||||
SecurityAuditReport,
|
||||
SessionEntry,
|
||||
SessionHistory,
|
||||
SessionsList,
|
||||
SkillStatusEntry,
|
||||
SkillStatusReport,
|
||||
SkillUpdateResult,
|
||||
UpdateCheckResult,
|
||||
UpdateStatusResponse,
|
||||
normalize_agents,
|
||||
normalize_approvals,
|
||||
normalize_cron_jobs,
|
||||
normalize_daemon_status,
|
||||
normalize_hooks,
|
||||
normalize_model_aliases,
|
||||
normalize_model_fallbacks,
|
||||
normalize_models,
|
||||
normalize_pairing,
|
||||
normalize_plugins,
|
||||
normalize_qr,
|
||||
normalize_security_audit,
|
||||
normalize_secrets_audit,
|
||||
normalize_session_history,
|
||||
normalize_sessions,
|
||||
normalize_skill_update,
|
||||
normalize_skills,
|
||||
normalize_status,
|
||||
normalize_update_status,
|
||||
)
|
||||
|
||||
|
||||
# Repository root (two levels up from backend/services/).
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Vendored OpenClaw checkout used as a fallback when no CLI is installed.
REFERENCE_OPENCLAW_ROOT = PROJECT_ROOT / "reference" / "openclaw"
# ES-module entry point of the vendored checkout (run with node).
REFERENCE_OPENCLAW_ENTRY = REFERENCE_OPENCLAW_ROOT / "openclaw.mjs"
class OpenClawCliError(RuntimeError):
    """Signals that an OpenClaw CLI invocation did not succeed.

    Carries the attempted command line plus the captured exit code and
    output streams so callers can log or surface the failure context.
    """

    def __init__(
        self,
        message: str,
        *,
        command: list[str],
        exit_code: int | None = None,
        stdout: str = "",
        stderr: str = "",
    ) -> None:
        super().__init__(message)
        # Preserve the full invocation context for diagnostics.
        self.command = command
        self.exit_code = exit_code
        self.stdout = stdout
        self.stderr = stderr
@dataclass(frozen=True)
class OpenClawCliResult:
    """Immutable record of a single CLI invocation's outcome."""

    command: list[str]  # argv that was actually executed
    exit_code: int  # process return code
    stdout: str  # captured standard output
    stderr: str  # captured standard error
def resolve_openclaw_base_command() -> list[str]:
    """Resolve the command prefix used to launch OpenClaw.

    Resolution order:
      1. ``OPENCLAW_CMD`` environment variable, parsed with shell quoting
         rules (``shlex.split``);
      2. an ``openclaw`` binary found on ``PATH``;
      3. the vendored ``openclaw.mjs`` entry point, executed with ``node``;
      4. bare ``"openclaw"`` as a last resort (fails loudly at run time).

    Returns:
        The argv prefix to prepend to OpenClaw subcommands.
    """
    explicit = os.getenv("OPENCLAW_CMD", "").strip()
    if explicit:
        return shlex.split(explicit)

    installed = shutil.which("openclaw")
    if installed:
        return [installed]

    if REFERENCE_OPENCLAW_ENTRY.exists():
        # The entry point is an ES module and must be run with node.
        # (The previous conditional consulted sys.executable, which is the
        # Python interpreter and can never be node — dead code removed.)
        return ["node", str(REFERENCE_OPENCLAW_ENTRY)]

    return ["openclaw"]
def resolve_openclaw_cwd() -> Path:
    """Resolve the working directory for CLI execution.

    ``OPENCLAW_CWD`` wins when set; otherwise prefer the vendored OpenClaw
    checkout when it exists, falling back to the project root.
    """
    override = os.getenv("OPENCLAW_CWD", "").strip()
    if override:
        return Path(override).expanduser()
    return REFERENCE_OPENCLAW_ROOT if REFERENCE_OPENCLAW_ROOT.exists() else PROJECT_ROOT
class OpenClawCliService:
|
||||
"""OpenClaw CLI integration service."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
base_command: list[str] | None = None,
|
||||
cwd: Path | None = None,
|
||||
timeout_seconds: float | None = None,
|
||||
) -> None:
|
||||
self.base_command = list(base_command or resolve_openclaw_base_command())
|
||||
self.cwd = cwd or resolve_openclaw_cwd()
|
||||
self.timeout_seconds = timeout_seconds or float(
|
||||
os.getenv("OPENCLAW_TIMEOUT_SECONDS", "15")
|
||||
)
|
||||
|
||||
def health(self) -> dict[str, Any]:
|
||||
"""Return the current CLI wiring state."""
|
||||
binary = self.base_command[0] if self.base_command else "openclaw"
|
||||
resolved = shutil.which(binary) if len(self.base_command) == 1 else binary
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "openclaw-service",
|
||||
"base_command": self.base_command,
|
||||
"cwd": str(self.cwd),
|
||||
"binary_resolved": resolved is not None,
|
||||
"reference_entry_available": REFERENCE_OPENCLAW_ENTRY.exists(),
|
||||
"timeout_seconds": self.timeout_seconds,
|
||||
}
|
||||
|
||||
def status(self) -> dict[str, Any]:
|
||||
"""Read `openclaw status --json`."""
|
||||
return self.run_json(["status", "--json"])
|
||||
|
||||
def list_sessions(self) -> dict[str, Any]:
|
||||
"""Read `openclaw sessions --json`."""
|
||||
return self.run_json(["sessions", "--json"])
|
||||
|
||||
def get_session(self, session_key: str) -> dict[str, Any]:
|
||||
"""Resolve a single session out of the sessions list."""
|
||||
payload = self.list_sessions()
|
||||
sessions = payload.get("sessions") or []
|
||||
for item in sessions:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if item.get("key") == session_key or item.get("sessionKey") == session_key:
|
||||
return item
|
||||
raise KeyError(session_key)
|
||||
|
||||
def get_session_history(self, session_key: str, *, limit: int = 20) -> dict[str, Any]:
|
||||
"""Read session history with a JSON-first fallback to raw text."""
|
||||
args = ["sessions", "history", session_key, "--json", "--limit", str(limit)]
|
||||
try:
|
||||
return self.run_json(args)
|
||||
except OpenClawCliError as exc:
|
||||
raise exc
|
||||
except json.JSONDecodeError:
|
||||
result = self.run(args)
|
||||
return {
|
||||
"sessionKey": session_key,
|
||||
"limit": limit,
|
||||
"rawText": result.stdout,
|
||||
}
|
||||
|
||||
def list_cron_jobs(self) -> dict[str, Any]:
|
||||
"""Read `openclaw cron list --json`."""
|
||||
return self.run_json(["cron", "list", "--json"])
|
||||
|
||||
def list_approvals(self) -> dict[str, Any]:
|
||||
"""Read `openclaw approvals get --json`."""
|
||||
return self.run_json(["approvals", "get", "--json"])
|
||||
|
||||
def list_agents(self) -> dict[str, Any]:
|
||||
"""Read `openclaw agents list --json`."""
|
||||
return self.run_json(["agents", "list", "--json"])
|
||||
|
||||
def list_skills(self) -> dict[str, Any]:
|
||||
"""Read `openclaw skills list --json`."""
|
||||
return self.run_json(["skills", "list", "--json"])
|
||||
|
||||
def list_models(self) -> dict[str, Any]:
|
||||
"""Read `openclaw models list --json`."""
|
||||
return self.run_json(["models", "list", "--json"])
|
||||
|
||||
def list_hooks(self) -> dict[str, Any]:
|
||||
"""Read `openclaw hooks list --json`."""
|
||||
return self.run_json(["hooks", "list", "--json"])
|
||||
|
||||
def list_plugins(self) -> dict[str, Any]:
|
||||
"""Read `openclaw plugins list --json`."""
|
||||
return self.run_json(["plugins", "list", "--json"])
|
||||
|
||||
def secrets_audit(self) -> dict[str, Any]:
|
||||
"""Read `openclaw secrets audit --json`."""
|
||||
return self.run_json(["secrets", "audit", "--json"])
|
||||
|
||||
def security_audit(self) -> dict[str, Any]:
|
||||
"""Read `openclaw security audit --json`."""
|
||||
return self.run_json(["security", "audit", "--json"])
|
||||
|
||||
def daemon_status(self) -> dict[str, Any]:
|
||||
"""Read `openclaw daemon status --json`."""
|
||||
return self.run_json(["daemon", "status", "--json"])
|
||||
|
||||
def pairing_list(self) -> dict[str, Any]:
|
||||
"""Read `openclaw pairing list --json`."""
|
||||
return self.run_json(["pairing", "list", "--json"])
|
||||
|
||||
def qr_code(self) -> dict[str, Any]:
|
||||
"""Read `openclaw qr --json`."""
|
||||
return self.run_json(["qr", "--json"])
|
||||
|
||||
def update_status(self) -> dict[str, Any]:
|
||||
"""Read `openclaw update status --json`."""
|
||||
return self.run_json(["update", "status", "--json"])
|
||||
|
||||
def list_model_aliases(self) -> dict[str, Any]:
|
||||
"""Read `openclaw models aliases list --json`."""
|
||||
return self.run_json(["models", "aliases", "list", "--json"])
|
||||
|
||||
def list_model_fallbacks(self) -> dict[str, Any]:
|
||||
"""Read `openclaw models fallbacks list --json`."""
|
||||
return self.run_json(["models", "fallbacks", "list", "--json"])
|
||||
|
||||
def list_model_image_fallbacks(self) -> dict[str, Any]:
|
||||
"""Read `openclaw models image-fallbacks list --json`."""
|
||||
return self.run_json(["models", "image-fallbacks", "list", "--json"])
|
||||
|
||||
def skill_update(self, *, slug: str | None = None, all: bool = False) -> dict[str, Any]:
|
||||
"""Read `openclaw skills update --json`."""
|
||||
args = ["skills", "update", "--json"]
|
||||
if slug:
|
||||
args.append(slug)
|
||||
if all:
|
||||
args.append("--all")
|
||||
return self.run_json(args)
|
||||
|
||||
def models_status(self, *, probe: bool = False) -> dict[str, Any]:
|
||||
"""Read `openclaw models status --json [--probe]`."""
|
||||
args = ["models", "status", "--json"]
|
||||
if probe:
|
||||
args.append("--probe")
|
||||
return self.run_json(args)
|
||||
|
||||
def channels_status(self, *, probe: bool = False) -> dict[str, Any]:
    """Run `openclaw channels status --json`, optionally probing channels."""
    cmd = ["channels", "status", "--json"] + (["--probe"] if probe else [])
    return self.run_json(cmd)
|
||||
|
||||
def list_workspace_files(self, workspace_path: str) -> dict[str, Any]:
    """List .md files in an OpenClaw agent workspace with their content.

    Args:
        workspace_path: Directory to scan; ``~`` is expanded and the path
            resolved to an absolute path.

    Returns:
        ``{"workspace": <abs path>, "files": [...]}`` where each entry has
        ``name``/``path``/``size`` plus a 300-character ``preview``.  A file
        that cannot be read contributes an ``error`` entry instead of failing
        the whole listing; a missing directory yields a top-level ``error``.
    """
    # `Path` comes from the module-level import; the previous local
    # `import json` was unused and the local `Path` import was redundant.
    wp = Path(workspace_path).expanduser().resolve()
    if not wp.exists() or not wp.is_dir():
        return {"workspace": str(wp), "files": [], "error": "workspace not found"}

    files: list[dict[str, Any]] = []
    for md_file in sorted(wp.glob("*.md")):
        try:
            content = md_file.read_text(encoding="utf-8")
        except OSError as exc:
            # Keep the listing going; surface the failure per-file.
            files.append({
                "name": md_file.name,
                "path": str(md_file),
                "size": 0,
                "preview": "",
                "error": str(exc),
            })
            continue
        files.append({
            "name": md_file.name,
            "path": str(md_file),
            "size": len(content),
            # Preview: first 300 chars, whitespace-trimmed.
            "preview": content[:300].strip(),
            "previewTruncated": len(content) > 300,
        })

    return {"workspace": str(wp), "files": files}
|
||||
|
||||
def channels_list(self) -> dict[str, Any]:
    """Fetch configured channels (`openclaw channels list --json`)."""
    return self.run_json("channels list --json".split())
|
||||
|
||||
def hook_info(self, name: str) -> dict[str, Any]:
    """Read `openclaw hooks info <name> --json`.

    Falls back to wrapping raw stdout when the CLI emits non-JSON output.
    Unlike the previous implementation, the CLI is invoked only once: the
    old code re-ran the full command inside the ``except`` branch, executing
    a potentially slow (or side-effecting) command twice on parse failure.
    """
    result = self.run(["hooks", "info", name, "--json"])
    # Mirror run_json's stdout-then-stderr source selection.
    text = result.stdout.strip() or result.stderr.strip()
    if not text:
        return {}
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return {"raw": result.stdout}
|
||||
|
||||
def hooks_check(self) -> dict[str, Any]:
    """Fetch hook health report (`openclaw hooks check --json`)."""
    return self.run_json("hooks check --json".split())
|
||||
|
||||
def plugins_inspect(self, *, plugin_id: str | None = None, all: bool = False) -> dict[str, Any]:
    """Run `openclaw plugins inspect --json`; `all` takes priority over a single id."""
    cmd = ["plugins", "inspect", "--json"]
    if all:
        cmd.append("--all")
    elif plugin_id:
        cmd.append(plugin_id)
    return self.run_json(cmd)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Typed variants — these use Pydantic models and are the preferred path.
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def status_model(self) -> OpenClawStatus:
    """Typed variant of :meth:`status` (parses into a Pydantic model)."""
    return normalize_status(self.status())
|
||||
|
||||
def list_sessions_model(self) -> SessionsList:
    """Typed variant of :meth:`list_sessions`."""
    return normalize_sessions(self.list_sessions())
|
||||
|
||||
def get_session_model(self, session_key: str) -> SessionEntry:
    """Resolve a single session and parse it into a typed model."""
    return SessionEntry.model_validate(self.get_session(session_key), strict=False)
|
||||
|
||||
def get_session_history_model(self, session_key: str, *, limit: int = 20) -> SessionHistory:
    """Typed variant of :meth:`get_session_history`."""
    raw_history = self.get_session_history(session_key, limit=limit)
    return normalize_session_history(raw_history, session_key=session_key)
|
||||
|
||||
def list_cron_jobs_model(self) -> CronList:
    """Typed variant of :meth:`list_cron_jobs`."""
    return normalize_cron_jobs(self.list_cron_jobs())
|
||||
|
||||
def list_approvals_model(self) -> ApprovalsList:
    """Typed variant of :meth:`list_approvals`."""
    return normalize_approvals(self.list_approvals())
|
||||
|
||||
# -------------------------------------------------------------------------
# Typed variants (agents, skills, models, hooks, plugins, system reports)
# -------------------------------------------------------------------------
|
||||
|
||||
def list_agents_model(self) -> AgentsList:
    """Typed variant of :meth:`list_agents`; tolerates bare-list CLI output."""
    raw = self.list_agents()
    if not isinstance(raw, list):
        return normalize_agents(raw)
    rows = [AgentSummary.model_validate(a, strict=False) for a in raw if isinstance(a, dict)]
    return AgentsList(agents=rows)
|
||||
|
||||
def list_skills_model(self) -> SkillStatusReport:
    """Typed variant of :meth:`list_skills`."""
    return normalize_skills(self.list_skills())
|
||||
|
||||
def list_models_model(self) -> ModelsList:
    """Typed variant of :meth:`list_models`; tolerates bare-list CLI output."""
    raw = self.list_models()
    if not isinstance(raw, list):
        return normalize_models(raw)
    rows = [ModelRow.model_validate(m, strict=False) for m in raw if isinstance(m, dict)]
    return ModelsList(models=rows)
|
||||
|
||||
def list_hooks_model(self) -> HookStatusReport:
    """Typed variant of :meth:`list_hooks`."""
    return normalize_hooks(self.list_hooks())
|
||||
|
||||
def list_plugins_model(self) -> PluginsList:
    """Typed variant of :meth:`list_plugins`."""
    return normalize_plugins(self.list_plugins())
|
||||
|
||||
def secrets_audit_model(self) -> SecretsAuditReport:
    """Typed variant of :meth:`secrets_audit`."""
    return normalize_secrets_audit(self.secrets_audit())
|
||||
|
||||
def security_audit_model(self) -> SecurityAuditResponse:
    """Typed variant of :meth:`security_audit`."""
    return normalize_security_audit(self.security_audit())
|
||||
|
||||
def daemon_status_model(self) -> DaemonStatus:
    """Typed variant of :meth:`daemon_status`."""
    return normalize_daemon_status(self.daemon_status())
|
||||
|
||||
def pairing_list_model(self) -> PairingListResponse:
    """Typed variant of :meth:`pairing_list`."""
    return normalize_pairing(self.pairing_list())
|
||||
|
||||
def qr_code_model(self) -> QrCodeResponse:
    """Typed variant of :meth:`qr_code`."""
    return normalize_qr(self.qr_code())
|
||||
|
||||
def update_status_model(self) -> UpdateStatusResponse:
    """Typed variant of :meth:`update_status`."""
    return normalize_update_status(self.update_status())
|
||||
|
||||
def list_model_aliases_model(self) -> ModelAliasesList:
    """Typed variant of :meth:`list_model_aliases`."""
    return normalize_model_aliases(self.list_model_aliases())
|
||||
|
||||
def list_model_fallbacks_model(self) -> ModelFallbacksList:
    """Typed variant of :meth:`list_model_fallbacks`."""
    return normalize_model_fallbacks(self.list_model_fallbacks())
|
||||
|
||||
def list_model_image_fallbacks_model(self) -> ModelFallbacksList:
    """Typed variant of :meth:`list_model_image_fallbacks`."""
    return normalize_model_fallbacks(self.list_model_image_fallbacks())
|
||||
|
||||
def skill_update_model(self, *, slug: str | None = None, all: bool = False) -> SkillUpdateResult:
    """Typed variant of :meth:`skill_update`."""
    return normalize_skill_update(self.skill_update(slug=slug, all=all))
|
||||
|
||||
def models_status_model(self, *, probe: bool = False) -> dict[str, Any]:
    """Raw passthrough of :meth:`models_status` (no typed model exists yet)."""
    return self.models_status(probe=probe)
|
||||
|
||||
def channels_status_model(self, *, probe: bool = False) -> dict[str, Any]:
    """Raw passthrough of :meth:`channels_status` (no typed model exists yet)."""
    return self.channels_status(probe=probe)
|
||||
|
||||
def channels_list_model(self) -> dict[str, Any]:
    """Raw passthrough of :meth:`channels_list` (no typed model exists yet)."""
    return self.channels_list()
|
||||
|
||||
def hook_info_model(self, name: str) -> dict[str, Any]:
    """Raw passthrough of :meth:`hook_info` (no typed model exists yet)."""
    return self.hook_info(name)
|
||||
|
||||
def hooks_check_model(self) -> dict[str, Any]:
    """Raw passthrough of :meth:`hooks_check` (no typed model exists yet)."""
    return self.hooks_check()
|
||||
|
||||
def plugins_inspect_model(self, *, plugin_id: str | None = None, all: bool = False) -> dict[str, Any]:
    """Raw passthrough of :meth:`plugins_inspect` (no typed model exists yet)."""
    return self.plugins_inspect(plugin_id=plugin_id, all=all)
|
||||
|
||||
def agents_bindings(self, *, agent: str | None = None) -> dict[str, Any]:
    """Run `openclaw agents bindings --json`, optionally scoped to one agent."""
    cmd = ["agents", "bindings", "--json"]
    if agent:
        cmd += ["--agent", agent]
    return self.run_json(cmd)
|
||||
|
||||
def agents_bindings_model(self, *, agent: str | None = None) -> dict[str, Any]:
    """Raw passthrough of :meth:`agents_bindings` (no typed model exists yet)."""
    return self.agents_bindings(agent=agent)
|
||||
|
||||
def agents_presence(self) -> dict[str, Any]:
    """Read session presence for all agents from runtime session files.

    Scans ``~/.openclaw/agents/{agentId}/sessions/sessions.json`` and counts
    sessions whose ``state``/``status`` is in ``ACTIVE_STATES``.

    NOTE(review): a 45-minute recency window was planned but never wired in
    (the old body carried unused ``now_ms`` / ``RECENCY_WINDOW_MS`` dead
    code) — every session in the file is considered regardless of age.

    Returns:
        ``{"status": ..., "agents": {agent_id: {"activeSessions": n,
        "status": "active"|"idle"}}}``.  Status is ``not_connected`` when no
        agents directory exists, ``partial`` when iteration failed part-way.
    """
    # `json` and `Path` come from the module-level imports; the previous
    # local re-imports were redundant.
    agents_path = Path.home() / ".openclaw" / "agents"
    if not agents_path.exists():
        return {"status": "not_connected", "agents": {}}

    ACTIVE_STATES = {
        "running", "active", "busy", "blocked", "waiting_approval",
        "working", "in_progress", "processing", "thinking", "executing", "streaming",
    }

    result: dict[str, Any] = {"status": "connected", "agents": {}}
    try:
        for agent_dir in agents_path.iterdir():
            if not agent_dir.is_dir():
                continue
            sessions_file = agent_dir / "sessions" / "sessions.json"
            if not sessions_file.exists():
                continue

            try:
                sessions_data = json.loads(sessions_file.read_text())
            except (json.JSONDecodeError, OSError):
                # Unreadable/corrupt session file: skip this agent entirely.
                continue

            sessions = sessions_data if isinstance(sessions_data, list) else []
            active_count = sum(
                1
                for session in sessions
                if isinstance(session, dict)
                and str(session.get("state") or session.get("status") or "").lower()
                in ACTIVE_STATES
            )
            result["agents"][agent_dir.name] = {
                "activeSessions": active_count,
                "status": "active" if active_count > 0 else "idle",
            }
    except OSError:
        # Directory iteration failed mid-scan; report what we collected.
        result["status"] = "partial"

    return result
|
||||
|
||||
def agents_from_config(self) -> dict[str, Any]:
    """Read the agent list directly from the openclaw.json config file.

    This avoids the CLI timeout from `agents list --json`.

    NOTE(review): the old docstring claimed a fallback scan of
    ``~/.openclaw/agents/`` that was never implemented; an absent config
    simply yields ``not_connected``.

    Returns:
        ``{"status": "connected", "agents": [...]}`` on success;
        ``not_connected`` when the config file is absent; ``partial`` when
        the file is unreadable or ``agents.list`` is empty.
    """
    # `json` and `Path` come from the module-level imports.
    config_path = Path.home() / ".openclaw" / "openclaw.json"
    if not config_path.exists():
        return {"status": "not_connected", "agents": []}

    try:
        raw = json.loads(config_path.read_text())
    except (json.JSONDecodeError, OSError):
        return {"status": "partial", "agents": []}

    agents_list = raw.get("agents", {}).get("list", [])
    if not agents_list:
        return {"status": "partial", "agents": [], "detail": "agents.list is empty"}

    default_id = raw.get("agents", {}).get("defaults", {}).get("id")
    agents = []
    for entry in agents_list:
        if not isinstance(entry, dict):
            continue
        # Defensive: hand-edited configs may carry an explicit null id/name,
        # which the old `entry.get("id", "").strip()` crashed on.
        agent_id = str(entry.get("id") or "").strip()
        if not agent_id:
            continue
        agents.append({
            "id": agent_id,
            "name": str(entry.get("name") or "").strip() or agent_id,
            "model": entry.get("model") or "",
            "workspace": entry.get("workspace") or "",
            "is_default": entry.get("id") == default_id,
        })

    return {"status": "connected", "agents": agents}
|
||||
|
||||
def gateway_status(self, *, url: str | None = None, token: str | None = None) -> dict[str, Any]:
    """Run `openclaw gateway status --json`; may fail if the gateway is unreachable."""
    cmd = ["gateway", "status", "--json"]
    for flag, value in (("--url", url), ("--token", token)):
        if value:
            cmd += [flag, value]
    return self.run_json(cmd)
|
||||
|
||||
def memory_status(self, *, agent: str | None = None, deep: bool = False) -> dict[str, Any]:
    """Run `openclaw memory status --json`; returns per-agent status entries."""
    cmd = ["memory", "status", "--json"]
    if agent:
        cmd += ["--agent", agent]
    if deep:
        cmd.append("--deep")
    return self.run_json(cmd)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Write agents commands
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def agents_add(
    self,
    name: str,
    *,
    workspace: str | None = None,
    model: str | None = None,
    agent_dir: str | None = None,
    bind: list[str] | None = None,
    non_interactive: bool = False,
) -> dict[str, Any]:
    """Create an agent via `openclaw agents add <name> --json` with optional flags.

    Flag order mirrors the CLI help: workspace, model, agent-dir, repeated
    --bind specs, then --non-interactive.
    """
    cmd = ["agents", "add", name, "--json"]
    for flag, value in (
        ("--workspace", workspace),
        ("--model", model),
        ("--agent-dir", agent_dir),
    ):
        if value:
            cmd += [flag, value]
    for spec in bind or ():
        cmd += ["--bind", spec]
    if non_interactive:
        cmd.append("--non-interactive")
    return self.run_json(cmd)
|
||||
|
||||
def agents_delete(self, id: str, *, force: bool = False) -> dict[str, Any]:
    """Delete an agent via `openclaw agents delete <id> --json` (optionally forced)."""
    cmd = ["agents", "delete", id, "--json"] + (["--force"] if force else [])
    return self.run_json(cmd)
|
||||
|
||||
def agents_bind(
    self,
    *,
    agent: str | None = None,
    bind: list[str] | None = None,
) -> dict[str, Any]:
    """Bind specs to an agent via `openclaw agents bind --json`."""
    cmd = ["agents", "bind", "--json"]
    if agent:
        cmd += ["--agent", agent]
    for spec in bind or ():
        cmd += ["--bind", spec]
    return self.run_json(cmd)
|
||||
|
||||
def agents_unbind(
    self,
    *,
    agent: str | None = None,
    bind: list[str] | None = None,
    all: bool = False,
) -> dict[str, Any]:
    """Remove bindings via `openclaw agents unbind --json` (specific specs or --all)."""
    cmd = ["agents", "unbind", "--json"]
    if agent:
        cmd += ["--agent", agent]
    for spec in bind or ():
        cmd += ["--bind", spec]
    if all:
        cmd.append("--all")
    return self.run_json(cmd)
|
||||
|
||||
def agents_set_identity(
    self,
    *,
    agent: str | None = None,
    workspace: str | None = None,
    identity_file: str | None = None,
    name: str | None = None,
    emoji: str | None = None,
    theme: str | None = None,
    avatar: str | None = None,
    from_identity: bool = False,
) -> dict[str, Any]:
    """Update an agent's identity via `openclaw agents set-identity --json`.

    Flag order matches the original implementation: agent, workspace,
    identity-file, from-identity, then the identity attributes.
    """
    cmd = ["agents", "set-identity", "--json"]
    for flag, value in (
        ("--agent", agent),
        ("--workspace", workspace),
        ("--identity-file", identity_file),
    ):
        if value:
            cmd += [flag, value]
    if from_identity:
        cmd.append("--from-identity")
    for flag, value in (
        ("--name", name),
        ("--emoji", emoji),
        ("--theme", theme),
        ("--avatar", avatar),
    ):
        if value:
            cmd += [flag, value]
    return self.run_json(cmd)
|
||||
|
||||
def run_json(self, args: list[str]) -> dict[str, Any]:
    """Execute the CLI and parse JSON from stdout, falling back to stderr."""
    result = self.run(args)
    raw = result.stdout.strip() or result.stderr.strip()
    # Some subcommands print nothing on success; treat that as an empty dict.
    return json.loads(raw) if raw else {}
|
||||
|
||||
def run(self, args: list[str]) -> OpenClawCliResult:
    """Execute the CLI synchronously and return its captured output.

    Raises:
        OpenClawCliError: when the executable is missing, the call times
            out, or the process exits non-zero.
    """
    command = [*self.base_command, *args]
    try:
        proc = subprocess.run(
            command,
            cwd=self.cwd,
            env=os.environ.copy(),
            capture_output=True,
            text=True,
            timeout=self.timeout_seconds,
            check=False,
        )
    except FileNotFoundError as exc:
        raise OpenClawCliError(
            "OpenClaw CLI executable was not found.",
            command=command,
        ) from exc
    except subprocess.TimeoutExpired as exc:
        raise OpenClawCliError(
            f"OpenClaw CLI timed out after {self.timeout_seconds:.1f}s.",
            command=command,
            stdout=exc.stdout or "",
            stderr=exc.stderr or "",
        ) from exc

    if proc.returncode != 0:
        raise OpenClawCliError(
            "OpenClaw CLI command failed.",
            command=command,
            exit_code=proc.returncode,
            stdout=proc.stdout,
            stderr=proc.stderr,
        )

    return OpenClawCliResult(
        command=command,
        exit_code=proc.returncode,
        stdout=proc.stdout,
        stderr=proc.stderr,
    )
|
||||
280
backend/services/research_db.py
Normal file
280
backend/services/research_db.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Query-oriented storage for explain/research data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from shared.schema import CompanyNews
|
||||
|
||||
|
||||
SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS news_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
published_at TEXT,
|
||||
trade_date TEXT,
|
||||
source TEXT,
|
||||
title TEXT NOT NULL,
|
||||
summary TEXT,
|
||||
url TEXT,
|
||||
related TEXT,
|
||||
category TEXT,
|
||||
raw_json TEXT NOT NULL,
|
||||
ingest_run_date TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_news_items_ticker_date
|
||||
ON news_items (ticker, trade_date DESC, published_at DESC);
|
||||
"""
|
||||
|
||||
|
||||
def _json_dumps(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _resolve_news_id(ticker: str, item: CompanyNews, fallback_index: int) -> str:
|
||||
base = item.url or item.title or f"{ticker}-{fallback_index}"
|
||||
return f"{ticker}:{base}"
|
||||
|
||||
|
||||
def _resolve_trade_date(date_value: str | None) -> str | None:
|
||||
if not date_value:
|
||||
return None
|
||||
normalized = str(date_value).strip()
|
||||
if not normalized:
|
||||
return None
|
||||
if "T" in normalized:
|
||||
return normalized.split("T", 1)[0]
|
||||
if " " in normalized:
|
||||
return normalized.split(" ", 1)[0]
|
||||
return normalized[:10]
|
||||
|
||||
|
||||
class ResearchDb:
    """Small SQLite helper for explain-oriented news storage.

    One row per news article, keyed by a ticker-qualified id built from the
    article URL/title (see ``_resolve_news_id``). Writers upsert via
    :meth:`upsert_news_items`; readers serve the explain UI (article lists,
    per-day timelines, and id-based lookups).
    """

    def __init__(self, db_path: Path):
        # The parent directory must exist before sqlite can create the file.
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a connection with name-based row access, WAL mode and FK checks."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
        return conn

    def _init_db(self):
        """Apply the idempotent DDL in SCHEMA (IF NOT EXISTS guards)."""
        with self._connect() as conn:
            conn.executescript(SCHEMA)

    def upsert_news_items(
        self,
        *,
        ticker: str,
        items: Iterable[CompanyNews],
        ingest_run_date: str | None = None,
    ) -> list[dict[str, Any]]:
        """Persist provider news and return normalized rows."""
        normalized_rows: list[dict[str, Any]] = []
        # NOTE(review): datetime.utcnow() is deprecated on Python 3.12+; a naive
        # UTC timestamp is stored — confirm readers expect no tz offset before
        # migrating to datetime.now(timezone.utc).
        timestamp = datetime.utcnow().isoformat(timespec="seconds")
        symbol = str(ticker or "").strip().upper()
        if not symbol:
            return normalized_rows

        with self._connect() as conn:
            for index, item in enumerate(items):
                # Stable id from url/title (falls back to the enumeration index).
                news_id = _resolve_news_id(symbol, item, index)
                trade_date = _resolve_trade_date(item.date)
                payload = item.model_dump()
                row = {
                    "id": news_id,
                    "ticker": symbol,
                    "published_at": item.date,
                    "trade_date": trade_date,
                    "source": item.source,
                    "title": item.title,
                    "summary": item.summary,
                    "url": item.url,
                    "related": item.related,
                    "category": item.category,
                    "raw_json": _json_dumps(payload),
                    "ingest_run_date": ingest_run_date,
                    "created_at": timestamp,
                }
                # Upsert: re-ingesting the same article refreshes every column
                # except created_at (the first-seen timestamp is preserved).
                conn.execute(
                    """
                    INSERT INTO news_items
                        (id, ticker, published_at, trade_date, source, title, summary, url,
                         related, category, raw_json, ingest_run_date, created_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    ON CONFLICT(id) DO UPDATE SET
                        ticker = excluded.ticker,
                        published_at = excluded.published_at,
                        trade_date = excluded.trade_date,
                        source = excluded.source,
                        title = excluded.title,
                        summary = excluded.summary,
                        url = excluded.url,
                        related = excluded.related,
                        category = excluded.category,
                        raw_json = excluded.raw_json,
                        ingest_run_date = excluded.ingest_run_date
                    """,
                    (
                        row["id"],
                        row["ticker"],
                        row["published_at"],
                        row["trade_date"],
                        row["source"],
                        row["title"],
                        row["summary"],
                        row["url"],
                        row["related"],
                        row["category"],
                        row["raw_json"],
                        row["ingest_run_date"],
                        row["created_at"],
                    ),
                )
                normalized_rows.append(row)
        return normalized_rows

    def get_news_items(
        self,
        *,
        ticker: str,
        start_date: str | None = None,
        end_date: str | None = None,
        limit: int = 20,
    ) -> list[dict[str, Any]]:
        """Return normalized news rows for explain UI.

        Date filters compare against trade_date, falling back to the date
        part of published_at when trade_date is NULL.
        """
        symbol = str(ticker or "").strip().upper()
        if not symbol:
            return []

        sql = """
            SELECT id, ticker, published_at, trade_date, source, title, summary,
                   url, related, category
            FROM news_items
            WHERE ticker = ?
        """
        params: list[Any] = [symbol]
        if start_date:
            sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
            params.append(start_date)
        if end_date:
            sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
            params.append(end_date)
        sql += " ORDER BY COALESCE(published_at, trade_date) DESC LIMIT ?"
        # Clamp to at least one row; also coerces a str/float limit to int.
        params.append(max(1, int(limit)))

        with self._connect() as conn:
            rows = conn.execute(sql, params).fetchall()

        return [
            {
                "id": row["id"],
                "ticker": row["ticker"],
                "date": row["published_at"] or row["trade_date"],
                "trade_date": row["trade_date"],
                "source": row["source"],
                "title": row["title"],
                "summary": row["summary"],
                "url": row["url"],
                "related": row["related"],
                "category": row["category"],
            }
            for row in rows
        ]

    def get_news_timeline(
        self,
        *,
        ticker: str,
        start_date: str | None = None,
        end_date: str | None = None,
    ) -> list[dict[str, Any]]:
        """Aggregate news counts per trade date for chart markers."""
        symbol = str(ticker or "").strip().upper()
        if not symbol:
            return []

        # MAX(title) is an arbitrary-but-deterministic representative headline
        # for the day (lexicographically greatest title).
        sql = """
            SELECT COALESCE(trade_date, substr(published_at, 1, 10)) AS date,
                   COUNT(*) AS count,
                   COUNT(DISTINCT source) AS source_count,
                   MAX(title) AS top_title
            FROM news_items
            WHERE ticker = ?
        """
        params: list[Any] = [symbol]
        if start_date:
            sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
            params.append(start_date)
        if end_date:
            sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
            params.append(end_date)
        sql += """
            GROUP BY COALESCE(trade_date, substr(published_at, 1, 10))
            ORDER BY date ASC
        """

        with self._connect() as conn:
            rows = conn.execute(sql, params).fetchall()

        # Rows with no resolvable date are dropped (the trailing `if`).
        return [
            {
                "date": row["date"],
                "count": int(row["count"] or 0),
                "source_count": int(row["source_count"] or 0),
                "top_title": row["top_title"] or "",
            }
            for row in rows
            if row["date"]
        ]

    def get_news_by_ids(
        self,
        *,
        ticker: str,
        article_ids: Iterable[str],
    ) -> list[dict[str, Any]]:
        """Return selected persisted news items (unknown ids are silently skipped)."""
        symbol = str(ticker or "").strip().upper()
        ids = [str(article_id).strip() for article_id in article_ids if str(article_id).strip()]
        if not symbol or not ids:
            return []

        # Parameterized IN-list; only the placeholder count is interpolated.
        placeholders = ",".join("?" for _ in ids)
        sql = f"""
            SELECT id, ticker, published_at, trade_date, source, title, summary,
                   url, related, category
            FROM news_items
            WHERE ticker = ? AND id IN ({placeholders})
            ORDER BY COALESCE(published_at, trade_date) DESC
        """
        with self._connect() as conn:
            rows = conn.execute(sql, [symbol, *ids]).fetchall()

        return [
            {
                "id": row["id"],
                "ticker": row["ticker"],
                "date": row["published_at"] or row["trade_date"],
                "trade_date": row["trade_date"],
                "source": row["source"],
                "title": row["title"],
                "summary": row["summary"],
                "url": row["url"],
                "related": row["related"],
                "category": row["category"],
            }
            for row in rows
        ]
|
||||
512
backend/services/runtime_db.py
Normal file
512
backend/services/runtime_db.py
Normal file
@@ -0,0 +1,512 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Run-scoped SQLite storage for query-oriented runtime history."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
|
||||
SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id TEXT PRIMARY KEY,
|
||||
event_type TEXT NOT NULL,
|
||||
timestamp TEXT,
|
||||
agent_id TEXT,
|
||||
agent_name TEXT,
|
||||
ticker TEXT,
|
||||
title TEXT,
|
||||
content TEXT,
|
||||
payload_json TEXT NOT NULL,
|
||||
run_date TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_type_time ON events(event_type, timestamp DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_ticker_time ON events(ticker, timestamp DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
side TEXT,
|
||||
qty REAL,
|
||||
price REAL,
|
||||
timestamp TEXT,
|
||||
trading_date TEXT,
|
||||
agent_id TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_trades_ticker_time ON trades(ticker, timestamp DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS signals (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
agent_id TEXT,
|
||||
agent_name TEXT,
|
||||
role TEXT,
|
||||
signal TEXT,
|
||||
confidence REAL,
|
||||
reasoning_json TEXT,
|
||||
reasons_json TEXT,
|
||||
risks_json TEXT,
|
||||
invalidation TEXT,
|
||||
next_action TEXT,
|
||||
intrinsic_value REAL,
|
||||
fair_value_range_json TEXT,
|
||||
value_gap_pct REAL,
|
||||
valuation_methods_json TEXT,
|
||||
real_return REAL,
|
||||
is_correct TEXT,
|
||||
trade_date TEXT,
|
||||
created_at TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_signals_ticker_date ON signals(ticker, trade_date DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_signals_agent_date ON signals(agent_id, trade_date DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS price_points (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
open_price REAL,
|
||||
ret REAL,
|
||||
source TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_price_points_ticker_time ON price_points(ticker, timestamp DESC);
|
||||
"""
|
||||
|
||||
|
||||
def _json_dumps(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _hash_key(*parts: Any) -> str:
|
||||
raw = "::".join("" if part is None else str(part) for part in parts)
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
class RuntimeDb:
|
||||
"""Small SQLite helper for append-mostly runtime data."""
|
||||
|
||||
def __init__(self, db_path: Path):
    """Open (creating if needed) the runtime database at *db_path*."""
    self.db_path = Path(db_path)
    # Parent directory must exist before sqlite can create the file.
    self.db_path.parent.mkdir(parents=True, exist_ok=True)
    self._init_db()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
return conn
|
||||
|
||||
def _init_db(self):
    """Apply the DDL in SCHEMA; idempotent thanks to IF NOT EXISTS guards."""
    with self._connect() as conn:
        conn.executescript(SCHEMA)
|
||||
|
||||
def insert_event(self, event: Dict[str, Any]):
    """Persist one feed event; re-inserting the same id is a silent no-op.

    The id is taken from the payload when present, otherwise derived from a
    content hash so identical events de-duplicate naturally.
    """
    payload = dict(event or {})
    if not payload:
        return

    event_id = payload.get("id") or _hash_key(
        payload.get("type"),
        payload.get("timestamp"),
        payload.get("agentId") or payload.get("agent_id"),
        payload.get("content"),
        payload.get("title"),
    )

    # Promote a single-entry `tickers` list to the scalar ticker column.
    ticker = payload.get("ticker")
    tickers = payload.get("tickers")
    if not ticker and isinstance(tickers, list) and len(tickers) == 1:
        ticker = tickers[0]

    values = (
        event_id,
        payload.get("type"),
        payload.get("timestamp"),
        payload.get("agentId") or payload.get("agent_id"),
        payload.get("agentName") or payload.get("agent_name"),
        ticker,
        payload.get("title"),
        payload.get("content"),
        _json_dumps(payload),
        payload.get("date") or payload.get("trading_date") or payload.get("run_date"),
    )
    with self._connect() as conn:
        conn.execute(
            """
            INSERT OR IGNORE INTO events
            (id, event_type, timestamp, agent_id, agent_name, ticker, title, content, payload_json, run_date)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            values,
        )
|
||||
|
||||
def get_recent_feed_events(
    self,
    *,
    limit: int = 200,
    event_types: Optional[Iterable[str]] = None,
) -> list[Dict[str, Any]]:
    """Return recent persisted feed events in newest-first order.

    Events whose stored payload is empty or unparseable are skipped.
    """
    wanted = tuple(event_types or ())
    sql = """
        SELECT payload_json
        FROM events
    """
    params: list[Any] = []
    if wanted:
        marks = ",".join("?" for _ in wanted)
        sql += f" WHERE event_type IN ({marks})"
        params.extend(wanted)
    sql += " ORDER BY timestamp DESC LIMIT ?"
    params.append(max(1, int(limit)))

    with self._connect() as conn:
        rows = conn.execute(sql, params).fetchall()

    events: list[Dict[str, Any]] = []
    for row in rows:
        raw = row["payload_json"]
        try:
            decoded = json.loads(raw) if raw else {}
        except json.JSONDecodeError:
            decoded = {}
        if decoded:
            events.append(decoded)
    return events
|
||||
|
||||
def get_last_day_feed_events(
    self,
    *,
    current_date: Optional[str] = None,
    limit: int = 200,
    event_types: Optional[Iterable[str]] = None,
) -> list[Dict[str, Any]]:
    """Return latest trading day events in newest-first order for replay.

    Args:
        current_date: Trading day to replay (``run_date``). When omitted or
            blank, the most recent non-empty ``run_date`` in the table is used.
        limit: Maximum number of events to return; clamped to at least 1.
        event_types: Optional whitelist of ``event_type`` values.

    Returns:
        Decoded event payload dicts for that day, newest first; an empty list
        when no trading day can be resolved.
    """
    wanted_types = tuple(event_types or ())
    replay_date = str(current_date or "").strip() or None

    with self._connect() as conn:
        if not replay_date:
            # No explicit date given: fall back to the newest recorded day.
            latest = conn.execute(
                """
                SELECT run_date
                FROM events
                WHERE run_date IS NOT NULL AND TRIM(run_date) != ''
                ORDER BY run_date DESC
                LIMIT 1
                """
            ).fetchone()
            replay_date = latest["run_date"] if latest else None

        if not replay_date:
            return []

        query = """
            SELECT payload_json
            FROM events
            WHERE run_date = ?
        """
        bind_values: list[Any] = [replay_date]
        if wanted_types:
            marks = ",".join("?" for _ in wanted_types)
            query += f" AND event_type IN ({marks})"
            bind_values.extend(wanted_types)
        query += " ORDER BY timestamp DESC LIMIT ?"
        bind_values.append(max(1, int(limit)))
        fetched = conn.execute(query, bind_values).fetchall()

    decoded: list[Dict[str, Any]] = []
    for record in fetched:
        raw = record["payload_json"]
        try:
            event = json.loads(raw) if raw else {}
        except json.JSONDecodeError:
            event = {}
        if event:
            decoded.append(event)
    return decoded
|
||||
|
||||
def upsert_trade(self, trade: Dict[str, Any]):
    """Insert or replace one trade row, keyed by its (possibly derived) id.

    Args:
        trade: Trade payload; accepts both camelCase (``agentId``) and
            snake_case (``agent_id``) agent keys and either ``timestamp``
            or ``ts`` for the execution time. Empty/None input is a no-op.
    """
    record = dict(trade or {})
    if not record:
        return

    executed_at = record.get("timestamp") or record.get("ts")
    # Derive a deterministic id when the caller did not provide one, so
    # re-ingesting the same trade replaces rather than duplicates it.
    row_id = record.get("id") or _hash_key(
        record.get("ticker"),
        executed_at,
        record.get("side"),
        record.get("qty"),
        record.get("price"),
    )
    row = (
        row_id,
        record.get("ticker"),
        record.get("side"),
        record.get("qty"),
        record.get("price"),
        executed_at,
        record.get("trading_date"),
        record.get("agentId") or record.get("agent_id"),
        _json_dumps(record),
    )
    with self._connect() as conn:
        conn.execute(
            """
            INSERT OR REPLACE INTO trades
            (id, ticker, side, qty, price, timestamp, trading_date, agent_id, meta_json)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            row,
        )
|
||||
|
||||
def upsert_signal(self, signal: Dict[str, Any], *, agent_id: str, agent_name: str, role: str):
    """Insert or replace one agent signal row for a ticker.

    Args:
        signal: Signal payload; must contain ``ticker`` or the call is a no-op.
        agent_id: Stable id of the emitting agent (part of the row key).
        agent_name: Display name of the agent.
        role: Agent role label stored alongside the signal.
    """
    data = dict(signal or {})
    ticker = data.get("ticker")
    if not ticker:
        return

    # Key on (agent, ticker, date, signal, confidence) so the same emission
    # is idempotent under INSERT OR REPLACE.
    row_id = _hash_key(
        agent_id,
        ticker,
        data.get("date"),
        data.get("signal"),
        data.get("confidence"),
    )
    correctness = data.get("is_correct")
    row = (
        row_id,
        ticker,
        agent_id,
        agent_name,
        role,
        data.get("signal"),
        data.get("confidence"),
        _json_dumps(data.get("reasoning")),
        _json_dumps(data.get("reasons")),
        _json_dumps(data.get("risks")),
        data.get("invalidation"),
        data.get("next_action"),
        data.get("intrinsic_value"),
        _json_dumps(data.get("fair_value_range")),
        data.get("value_gap_pct"),
        _json_dumps(data.get("valuation_methods")),
        data.get("real_return"),
        None if correctness is None else str(correctness),
        data.get("date"),
        data.get("created_at") or data.get("date"),
        _json_dumps(data),
    )
    with self._connect() as conn:
        conn.execute(
            """
            INSERT OR REPLACE INTO signals
            (id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
             reasons_json, risks_json, invalidation, next_action, intrinsic_value,
             fair_value_range_json, value_gap_pct, valuation_methods_json,
             real_return, is_correct, trade_date, created_at, meta_json)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            row,
        )
|
||||
|
||||
def replace_signals_for_leaderboard(self, leaderboard: Iterable[Dict[str, Any]]):
    """Rebuild the signals table from a full leaderboard snapshot.

    Deletes every existing row, then inserts one row per (agent, signal)
    pair found in *leaderboard*. Signals without a ticker are skipped.
    DELETE and all INSERTs run inside a single connection context so a
    failure does not leave the table half-rebuilt.

    Args:
        leaderboard: Agent dicts, each with ``agentId``, ``name``, ``role``
            and a ``signals`` list of signal payload dicts.
    """
    with self._connect() as conn:
        conn.execute("DELETE FROM signals")
        for agent in leaderboard:
            agent_id = agent.get("agentId")
            agent_name = agent.get("name")
            role = agent.get("role")
            for signal in agent.get("signals", []) or []:
                payload = dict(signal or {})
                ticker = payload.get("ticker")
                if not ticker:
                    continue
                # Same deterministic key as upsert_signal, so rebuilt rows
                # match incrementally-upserted ones.
                signal_id = _hash_key(
                    agent_id,
                    ticker,
                    payload.get("date"),
                    payload.get("signal"),
                    payload.get("confidence"),
                )
                conn.execute(
                    # BUG FIX: the VALUES clause previously had 22 "?"
                    # placeholders for 21 columns and 21 bound parameters,
                    # which raises sqlite3.ProgrammingError on every insert.
                    """
                    INSERT INTO signals
                    (id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
                     reasons_json, risks_json, invalidation, next_action, intrinsic_value,
                     fair_value_range_json, value_gap_pct, valuation_methods_json,
                     real_return, is_correct, trade_date, created_at, meta_json)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        signal_id,
                        ticker,
                        agent_id,
                        agent_name,
                        role,
                        payload.get("signal"),
                        payload.get("confidence"),
                        _json_dumps(payload.get("reasoning")),
                        _json_dumps(payload.get("reasons")),
                        _json_dumps(payload.get("risks")),
                        payload.get("invalidation"),
                        payload.get("next_action"),
                        payload.get("intrinsic_value"),
                        _json_dumps(payload.get("fair_value_range")),
                        payload.get("value_gap_pct"),
                        _json_dumps(payload.get("valuation_methods")),
                        payload.get("real_return"),
                        None if payload.get("is_correct") is None else str(payload.get("is_correct")),
                        payload.get("date"),
                        payload.get("created_at") or payload.get("date"),
                        _json_dumps(payload),
                    ),
                )
|
||||
|
||||
def insert_price_point(
    self,
    *,
    ticker: str,
    timestamp: str,
    price: float,
    open_price: Optional[float] = None,
    ret: Optional[float] = None,
    source: Optional[str] = None,
    meta: Optional[Dict[str, Any]] = None,
):
    """Record one price observation; duplicate ticks are silently ignored.

    Args:
        ticker: Symbol the price belongs to.
        timestamp: Observation time string (stored as-is).
        price: Last/close price for the tick.
        open_price: Optional opening price.
        ret: Optional return value for the tick.
        source: Optional provider/source label.
        meta: Optional extra metadata, serialized to JSON.
    """
    # Deterministic id makes re-ingestion idempotent via INSERT OR IGNORE.
    row_id = _hash_key(ticker, timestamp, price, open_price, ret)
    row = (
        row_id,
        ticker,
        timestamp,
        price,
        open_price,
        ret,
        source,
        _json_dumps(meta or {}),
    )
    with self._connect() as conn:
        conn.execute(
            """
            INSERT OR IGNORE INTO price_points
            (id, ticker, timestamp, price, open_price, ret, source, meta_json)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            row,
        )
|
||||
|
||||
def get_stock_explain_snapshot(
    self,
    ticker: str,
    *,
    limit_events: int = 24,
    limit_trades: int = 12,
    limit_signals: int = 12,
) -> Dict[str, list[Dict[str, Any]]]:
    """Fetch query-oriented history for a single ticker.

    Pulls recent trades, signals, and event mentions for *ticker* and
    normalizes each into a uniform timeline-item shape (id / type /
    timestamp / title / meta / body / tone) for the frontend explain view.

    Args:
        ticker: Symbol to look up; normalized to stripped upper-case.
        limit_events: Maximum number of mention events returned.
        limit_trades: Maximum number of trade items returned.
        limit_signals: Maximum number of signal items returned.

    Returns:
        Dict with "events", "trades", and "signals" lists (all empty when
        the ticker is blank).
    """
    symbol = str(ticker or "").strip().upper()
    if not symbol:
        return {"events": [], "trades": [], "signals": []}

    with self._connect() as conn:
        trade_rows = conn.execute(
            """
            SELECT * FROM trades
            WHERE ticker = ?
            ORDER BY timestamp DESC
            LIMIT ?
            """,
            (symbol, limit_trades),
        ).fetchall()
        signal_rows = conn.execute(
            """
            SELECT * FROM signals
            WHERE ticker = ?
            ORDER BY trade_date DESC, created_at DESC
            LIMIT ?
            """,
            (symbol, limit_signals),
        ).fetchall()
        # Over-fetch (3x) because the LIKE match is fuzzy; rows are
        # re-filtered precisely in Python below before the limit applies.
        event_rows = conn.execute(
            """
            SELECT * FROM events
            WHERE payload_json LIKE ? OR content LIKE ? OR title LIKE ? OR ticker = ?
            ORDER BY timestamp DESC
            LIMIT ?
            """,
            (f"%{symbol}%", f"%{symbol}%", f"%{symbol}%", symbol, limit_events * 3),
        ).fetchall()

    normalized_events = []
    seen_event_ids: set[str] = set()
    for row in event_rows:
        payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
        content = str(row["content"] or payload.get("content") or "")
        title = str(row["title"] or payload.get("title") or "")
        # Drop fuzzy LIKE hits: keep only rows whose title/content actually
        # mention the symbol, or whose ticker column matches exactly.
        if symbol not in f"{title} {content}".upper() and str(row["ticker"] or "").upper() != symbol:
            continue
        event_id = row["id"]
        if event_id in seen_event_ids:
            continue
        seen_event_ids.add(event_id)
        normalized_events.append(
            {
                "id": event_id,
                "type": "mention",
                "timestamp": row["timestamp"],
                # Fallback title text is Chinese UI copy ("unknown role
                # mentions <symbol>") — intentional, do not translate.
                "title": title or f"{row['agent_name'] or '未知角色'}提及 {symbol}",
                "meta": payload.get("conferenceTitle")
                or payload.get("feedType")
                or row["event_type"],
                "body": content,
                "tone": "neutral",
                "agent": row["agent_name"] or payload.get("agentName") or payload.get("agent"),
            },
        )
        if len(normalized_events) >= limit_events:
            break

    normalized_trades = [
        {
            "id": row["id"],
            "type": "trade",
            "timestamp": row["timestamp"],
            "title": f"{row['side']} {int(row['qty'] or 0)} 股",
            "meta": "交易执行",
            "body": f"成交价 ${float(row['price'] or 0):.2f}",
            # Tone mirrors direction: LONG -> positive, SHORT -> negative.
            "tone": "positive" if row["side"] == "LONG" else "negative" if row["side"] == "SHORT" else "neutral",
        }
        for row in trade_rows
    ]

    normalized_signals = [
        {
            "id": row["id"],
            "type": "signal",
            # Signals only carry a trade date, so anchor them at 08:00 local
            # for timeline ordering; fall back to created_at when missing.
            "timestamp": f"{row['trade_date']}T08:00:00" if row["trade_date"] else row["created_at"],
            "title": f"{row['agent_name']} 给出{row['signal'] or '中性'}信号",
            "meta": row["role"],
            "body": (
                f"后验收益 {float(row['real_return']) * 100:+.2f}%"
                if row["real_return"] is not None
                else "该信号暂未完成后验评估"
            ),
            "tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral",
            # Extended signal fields
            "signal": row["signal"],
            "confidence": row["confidence"],
            "reasoning": json.loads(row["reasoning_json"]) if row["reasoning_json"] else None,
            "reasons": json.loads(row["reasons_json"]) if row["reasons_json"] else None,
            "risks": json.loads(row["risks_json"]) if row["risks_json"] else None,
            "invalidation": row["invalidation"],
            "next_action": row["next_action"],
            "intrinsic_value": row["intrinsic_value"],
            "fair_value_range": json.loads(row["fair_value_range_json"]) if row["fair_value_range_json"] else None,
            "value_gap_pct": row["value_gap_pct"],
            "valuation_methods": json.loads(row["valuation_methods_json"]) if row["valuation_methods_json"] else None,
        }
        for row in signal_rows
    ]

    return {
        "events": normalized_events,
        "trades": normalized_trades,
        "signals": normalized_signals,
    }
|
||||
1256
backend/services/storage.py
Normal file
1256
backend/services/storage.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user