# -*- coding: utf-8 -*-
|
|
"""
|
|
WebSocket Gateway for frontend communication
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, List, Optional, Set
|
|
|
|
import websockets
|
|
from websockets.asyncio.server import ServerConnection
|
|
|
|
from backend.config.bootstrap_config import (
|
|
get_bootstrap_config_for_run,
|
|
resolve_runtime_config,
|
|
update_bootstrap_values_for_run,
|
|
)
|
|
from backend.agents.agent_workspace import load_agent_workspace_config
|
|
from backend.agents.skills_manager import SkillsManager
|
|
from backend.agents.toolkit_factory import load_agent_profiles
|
|
from backend.data.provider_utils import normalize_symbol
|
|
from backend.data.market_ingest import ingest_symbols
|
|
from backend.enrich.llm_enricher import llm_enrichment_enabled
|
|
from backend.enrich.news_enricher import enrich_news_for_symbol
|
|
from backend.explain.range_explainer import build_range_explanation
|
|
from backend.explain.similarity_service import find_similar_days
|
|
from backend.explain.story_service import get_or_create_stock_story
|
|
from backend.llm.models import get_agent_model_info
|
|
from backend.utils.msg_adapter import FrontendAdapter
|
|
from backend.utils.terminal_dashboard import get_dashboard
|
|
from backend.core.pipeline import TradingPipeline
|
|
from backend.core.state_sync import StateSync
|
|
from backend.services.market import MarketService
|
|
from backend.services.storage import StorageService
|
|
from backend.data.provider_router import get_provider_router
|
|
from backend.tools.data_tools import get_prices
|
|
from backend.tools.data_tools import get_company_news
|
|
from backend.core.scheduler import Scheduler
|
|
|
|
logger = logging.getLogger(__name__)
|
|
# Workspace markdown files the frontend is permitted to read and edit
# through the gateway's get/update_agent_workspace_file messages.
EDITABLE_AGENT_WORKSPACE_FILES = {
    "SOUL.md",
    "PROFILE.md",
    "AGENTS.md",
    "MEMORY.md",
    "POLICY.md",
    "HEARTBEAT.md",
    "ROLE.md",
    "STYLE.md",
}
class Gateway:
|
|
"""WebSocket Gateway for frontend communication"""
|
|
|
|
def __init__(
|
|
self,
|
|
market_service: MarketService,
|
|
storage_service: StorageService,
|
|
pipeline: TradingPipeline,
|
|
state_sync: Optional[StateSync] = None,
|
|
scheduler_callback: Optional[Callable] = None,
|
|
scheduler: Optional[Scheduler] = None,
|
|
config: Dict[str, Any] = None,
|
|
):
|
|
self.market_service = market_service
|
|
self.storage = storage_service
|
|
self.pipeline = pipeline
|
|
self.scheduler_callback = scheduler_callback
|
|
self.scheduler = scheduler
|
|
self.config = config or {}
|
|
|
|
self.mode = self.config.get("mode", "live")
|
|
self.is_backtest = self.mode == "backtest" or self.config.get(
|
|
"backtest_mode",
|
|
False,
|
|
)
|
|
|
|
self.state_sync = state_sync or StateSync(storage=storage_service)
|
|
# self.state_sync.set_mode(self.is_backtest)
|
|
self.state_sync.set_broadcast_fn(self.broadcast)
|
|
self.pipeline.state_sync = self.state_sync
|
|
|
|
self.connected_clients: Set[ServerConnection] = set()
|
|
self.lock = asyncio.Lock()
|
|
self._cycle_lock = asyncio.Lock()
|
|
self._backtest_task: Optional[asyncio.Task] = None
|
|
self._manual_cycle_task: Optional[asyncio.Task] = None
|
|
self._backtest_start_date: Optional[str] = None
|
|
self._backtest_end_date: Optional[str] = None
|
|
self._dashboard = get_dashboard()
|
|
self._market_status_task: Optional[asyncio.Task] = None
|
|
self._watchlist_ingest_task: Optional[asyncio.Task] = None
|
|
|
|
# Session tracking for live returns
|
|
self._session_start_portfolio_value: Optional[float] = None
|
|
self._provider_router = get_provider_router()
|
|
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
|
self._project_root = Path(__file__).resolve().parents[2]
|
|
|
|
    async def start(self, host: str = "0.0.0.0", port: int = 8766):
        """Start the gateway: dashboard, state, services, scheduler, WS server.

        Runs forever (``await asyncio.Future()``) until cancelled.

        Args:
            host: Bind address for the WebSocket server.
            port: Bind port for the WebSocket server.
        """
        logger.info(f"Starting gateway on {host}:{port}")
        # Capture the loop so thread-side callbacks can schedule broadcasts.
        self._loop = asyncio.get_running_loop()
        self._provider_router.add_listener(self._on_provider_usage_changed)

        # Initialize terminal dashboard
        self._dashboard.set_config(
            mode=self.mode,
            config_name=self.config.get("config_name", "default"),
            host=host,
            port=port,
            poll_interval=self.config.get("poll_interval", 10),
            mock=self.config.get("mock_mode", False),
            tickers=self.config.get("tickers", []),
            initial_cash=self.storage.initial_cash,
            start_date=self._backtest_start_date or "",
            end_date=self._backtest_end_date or "",
            data_sources=self._provider_router.get_usage_snapshot(),
        )
        self._dashboard.start()

        # Restore persisted state and publish the server's identity flags.
        self.state_sync.load_state()
        self.market_service.set_price_recorder(self.storage.record_price_point)
        self.state_sync.update_state("status", "running")
        self.state_sync.update_state("server_mode", self.mode)
        self.state_sync.update_state("is_backtest", self.is_backtest)
        self.state_sync.update_state(
            "is_mock_mode",
            self.config.get("mock_mode", False),
        )
        self.state_sync.update_state("tickers", self.config.get("tickers", []))
        self.state_sync.update_state(
            "runtime_config",
            {
                "tickers": self.config.get("tickers", []),
                "schedule_mode": self.config.get("schedule_mode", "daily"),
                "interval_minutes": self.config.get("interval_minutes", 60),
                "trigger_time": self.config.get("trigger_time", "09:30"),
                "initial_cash": self.config.get(
                    "initial_cash",
                    self.storage.initial_cash,
                ),
                "margin_requirement": self.config.get("margin_requirement"),
                "max_comm_cycles": self.config.get("max_comm_cycles"),
                "enable_memory": self.config.get("enable_memory", False),
            },
        )
        self.state_sync.update_state(
            "data_sources",
            self._provider_router.get_usage_snapshot(),
        )

        # Load and display existing portfolio state if available
        summary = self.storage.load_file("summary")
        if summary:
            holdings = self.storage.load_file("holdings") or []
            trades = self.storage.load_file("trades") or []
            current_date = self.state_sync.state.get("current_date")
            self._dashboard.update(
                date=current_date or "-",
                status="running",
                portfolio=summary,
                holdings=holdings,
                trades=trades,
            )
            logger.info(
                "Loaded existing portfolio: $%s",
                f"{summary.get('totalAssetValue', 0):,.2f}",
            )

        await self.market_service.start(broadcast_func=self.broadcast)

        # Prefer the Scheduler object; fall back to the legacy callback hook.
        if self.scheduler:
            await self.scheduler.start(self.on_strategy_trigger)
        elif self.scheduler_callback:
            await self.scheduler_callback(callback=self.on_strategy_trigger)

        # Start market status monitoring (only for live mode)
        if not self.is_backtest:
            self._market_status_task = asyncio.create_task(
                self._market_status_monitor(),
            )

        async with websockets.serve(
            self.handle_client,
            host,
            port,
            ping_interval=30,
            ping_timeout=60,
        ):
            logger.info(
                f"Gateway started: ws://{host}:{port}, mode={self.mode}",
            )
            # Block forever; the server lives until this task is cancelled.
            await asyncio.Future()
def _on_provider_usage_changed(self, snapshot: Dict[str, Any]):
|
|
"""Handle provider routing updates from the shared router."""
|
|
self.state_sync.update_state("data_sources", snapshot)
|
|
self._dashboard.update(data_sources=snapshot)
|
|
if self._loop and self._loop.is_running():
|
|
asyncio.run_coroutine_threadsafe(
|
|
self.broadcast(
|
|
{
|
|
"type": "data_sources_update",
|
|
"data_sources": snapshot,
|
|
},
|
|
),
|
|
self._loop,
|
|
)
|
|
|
|
    @property
    def state(self) -> Dict[str, Any]:
        """Shared runtime state dict, delegated to the StateSync instance."""
        return self.state_sync.state
@staticmethod
|
|
def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool:
|
|
if not rows:
|
|
return True
|
|
return all(
|
|
not row.get("sentiment")
|
|
and not row.get("relevance")
|
|
and not row.get("key_discussion")
|
|
for row in rows
|
|
)
|
|
|
|
async def handle_client(self, websocket: ServerConnection):
|
|
"""Handle WebSocket client connection"""
|
|
async with self.lock:
|
|
self.connected_clients.add(websocket)
|
|
|
|
await self._send_initial_state(websocket)
|
|
await self._handle_client_messages(websocket)
|
|
|
|
async with self.lock:
|
|
self.connected_clients.discard(websocket)
|
|
|
|
    async def _send_initial_state(self, websocket: ServerConnection):
        """Send the full initial-state snapshot to one client.

        The payload is the StateSync snapshot augmented with provider usage,
        market status, and (if a live session is active) live return figures.
        """
        state_payload = self.state_sync.get_initial_state_payload(
            include_dashboard=True,
        )
        state_payload["data_sources"] = (
            self._provider_router.get_usage_snapshot()
        )
        # Include market status in initial state
        state_payload[
            "market_status"
        ] = self.market_service.get_market_status()

        # Include live returns if session is active
        if self.storage.is_live_session_active:
            live_returns = self.storage.get_live_returns()
            if "portfolio" in state_payload:
                state_payload["portfolio"].update(live_returns)

        await websocket.send(
            json.dumps(
                {"type": "initial_state", "state": state_payload},
                ensure_ascii=False,
                default=str,
            ),
        )
async def _handle_client_messages(
|
|
self,
|
|
websocket: ServerConnection,
|
|
):
|
|
try:
|
|
async for message in websocket:
|
|
data = json.loads(message)
|
|
msg_type = data.get("type", "unknown")
|
|
|
|
if msg_type == "ping":
|
|
await websocket.send(
|
|
json.dumps(
|
|
{
|
|
"type": "pong",
|
|
"timestamp": datetime.now().isoformat(),
|
|
},
|
|
ensure_ascii=False,
|
|
),
|
|
)
|
|
elif msg_type == "get_state":
|
|
await self._send_initial_state(websocket)
|
|
elif msg_type == "start_backtest":
|
|
await self._handle_start_backtest(data)
|
|
elif msg_type == "trigger_strategy":
|
|
await self._handle_manual_trigger(websocket, data)
|
|
elif msg_type == "update_runtime_config":
|
|
await self._handle_update_runtime_config(websocket, data)
|
|
elif msg_type == "reload_runtime_assets":
|
|
await self._handle_reload_runtime_assets()
|
|
elif msg_type == "update_watchlist":
|
|
await self._handle_update_watchlist(websocket, data)
|
|
elif msg_type == "get_agent_skills":
|
|
await self._handle_get_agent_skills(websocket, data)
|
|
elif msg_type == "get_agent_profile":
|
|
await self._handle_get_agent_profile(websocket, data)
|
|
elif msg_type == "get_skill_detail":
|
|
await self._handle_get_skill_detail(websocket, data)
|
|
elif msg_type == "create_agent_local_skill":
|
|
await self._handle_create_agent_local_skill(websocket, data)
|
|
elif msg_type == "update_agent_local_skill":
|
|
await self._handle_update_agent_local_skill(websocket, data)
|
|
elif msg_type == "delete_agent_local_skill":
|
|
await self._handle_delete_agent_local_skill(websocket, data)
|
|
elif msg_type == "remove_agent_skill":
|
|
await self._handle_remove_agent_skill(websocket, data)
|
|
elif msg_type == "update_agent_skill":
|
|
await self._handle_update_agent_skill(websocket, data)
|
|
elif msg_type == "get_agent_workspace_file":
|
|
await self._handle_get_agent_workspace_file(websocket, data)
|
|
elif msg_type == "update_agent_workspace_file":
|
|
await self._handle_update_agent_workspace_file(websocket, data)
|
|
elif msg_type == "get_stock_history":
|
|
await self._handle_get_stock_history(websocket, data)
|
|
elif msg_type == "get_stock_explain_events":
|
|
await self._handle_get_stock_explain_events(websocket, data)
|
|
elif msg_type == "get_stock_news":
|
|
await self._handle_get_stock_news(websocket, data)
|
|
elif msg_type == "get_stock_news_for_date":
|
|
await self._handle_get_stock_news_for_date(websocket, data)
|
|
elif msg_type == "get_stock_news_timeline":
|
|
await self._handle_get_stock_news_timeline(websocket, data)
|
|
elif msg_type == "get_stock_news_categories":
|
|
await self._handle_get_stock_news_categories(websocket, data)
|
|
elif msg_type == "get_stock_range_explain":
|
|
await self._handle_get_stock_range_explain(websocket, data)
|
|
elif msg_type == "get_stock_story":
|
|
await self._handle_get_stock_story(websocket, data)
|
|
elif msg_type == "get_stock_similar_days":
|
|
await self._handle_get_stock_similar_days(websocket, data)
|
|
elif msg_type == "run_stock_enrich":
|
|
await self._handle_run_stock_enrich(websocket, data)
|
|
|
|
except websockets.ConnectionClosed:
|
|
pass
|
|
except json.JSONDecodeError:
|
|
pass
|
|
|
|
    async def _handle_get_stock_history(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Reply with OHLC history for a ticker (``stock_history_loaded``).

        Reads from the local market store first; on a miss, falls back to the
        provider via ``get_prices`` and caches the result. At most the last
        120 bars are returned.
        """
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_history_loaded",
                        "ticker": "",
                        "prices": [],
                        "source": None,
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # Clamp the lookback window to [7, 365] days; default 90.
        lookback_days = data.get("lookback_days", 90)
        try:
            lookback_days = max(7, min(int(lookback_days), 365))
        except (TypeError, ValueError):
            lookback_days = 90

        # Anchor the window at the simulation/current date when available.
        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )

        prices = await asyncio.to_thread(
            self.storage.market_store.get_ohlc,
            ticker,
            start_date,
            end_date,
        )
        # NOTE(review): label for locally stored rows is hard-coded to
        # "polygon" — confirm the store only ever holds polygon-sourced OHLC.
        source = "polygon"
        if not prices:
            prices = await asyncio.to_thread(
                get_prices,
                ticker,
                start_date,
                end_date,
            )
            usage_snapshot = self._provider_router.get_usage_snapshot()
            source = usage_snapshot.get("last_success", {}).get("prices")
            if prices:
                # Cache the freshly fetched bars for subsequent requests.
                await asyncio.to_thread(
                    self.storage.market_store.upsert_ohlc,
                    ticker,
                    [price.model_dump() for price in prices],
                    source=source or "provider",
                )

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_history_loaded",
                    "ticker": ticker,
                    "prices": [
                        price if isinstance(price, dict) else price.model_dump()
                        for price in prices
                    ][-120:],
                    "source": source,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
async def _handle_get_stock_explain_events(
|
|
self,
|
|
websocket: ServerConnection,
|
|
data: Dict[str, Any],
|
|
):
|
|
ticker = normalize_symbol(data.get("ticker", ""))
|
|
snapshot = self.storage.runtime_db.get_stock_explain_snapshot(ticker)
|
|
await websocket.send(
|
|
json.dumps(
|
|
{
|
|
"type": "stock_explain_events_loaded",
|
|
"ticker": ticker,
|
|
"events": snapshot.get("events", []),
|
|
"signals": snapshot.get("signals", []),
|
|
"trades": snapshot.get("trades", []),
|
|
},
|
|
ensure_ascii=False,
|
|
default=str,
|
|
),
|
|
)
|
|
|
|
    async def _handle_get_stock_news(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Reply with enriched news for a ticker (``stock_news_loaded``).

        Serves stored enriched rows when available; otherwise fetches fresh
        articles from the provider, stores them, runs enrichment, and re-reads.
        """
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_loaded",
                        "ticker": "",
                        "news": [],
                        "source": None,
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # Clamp lookback to [7, 180] days and limit to [1, 30] items.
        lookback_days = data.get("lookback_days", 30)
        limit = data.get("limit", 12)
        try:
            lookback_days = max(7, min(int(lookback_days), 180))
        except (TypeError, ValueError):
            lookback_days = 30
        try:
            limit = max(1, min(int(limit), 30))
        except (TypeError, ValueError):
            limit = 12

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )

        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
        )
        # NOTE(review): local-store source label is hard-coded to "polygon" —
        # confirm against how news rows are ingested.
        source = "polygon"
        if self._news_rows_need_enrichment(news_rows):
            # Fetch fresh articles from the provider (note: end before start
            # in get_company_news's positional argument order).
            news = await asyncio.to_thread(
                get_company_news,
                ticker,
                end_date,
                start_date,
                limit,
            )
            if news:
                usage_snapshot = self._provider_router.get_usage_snapshot()
                source = usage_snapshot.get("last_success", {}).get("company_news")
                await asyncio.to_thread(
                    self.storage.market_store.upsert_news,
                    ticker,
                    [item.model_dump() for item in news],
                    source=source or "provider",
                )
            # Enrich whatever is stored for the window, then re-read.
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=max(limit, 50),
            )
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=limit,
            )
            source = source or "market_store"

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_loaded",
                    "ticker": ticker,
                    "news": news_rows[-limit:],
                    "source": source,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
    async def _handle_get_stock_news_for_date(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Reply with enriched news for one trade date
        (``stock_news_for_date_loaded``); enrich-and-retry on a cold cache."""
        ticker = normalize_symbol(data.get("ticker", ""))
        trade_date = str(data.get("date") or "").strip()
        if not ticker or not trade_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_for_date_loaded",
                        "ticker": ticker,
                        "date": trade_date,
                        "news": [],
                        "error": "ticker and date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # Clamp limit to [1, 50]; default 20.
        limit = data.get("limit", 20)
        try:
            limit = max(1, min(int(limit), 50))
        except (TypeError, ValueError):
            limit = 20

        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            trade_date=trade_date,
            limit=limit,
        )
        if self._news_rows_need_enrichment(news_rows):
            # Enrich just this single day, then re-read the rows.
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=trade_date,
                end_date=trade_date,
                limit=limit,
            )
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                trade_date=trade_date,
                limit=limit,
            )

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_for_date_loaded",
                    "ticker": ticker,
                    "date": trade_date,
                    "news": news_rows,
                    "source": "market_store",
                },
                ensure_ascii=False,
                default=str,
            ),
        )
    async def _handle_get_stock_news_timeline(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Reply with the enriched news timeline for a ticker
        (``stock_news_timeline_loaded``); enrich-and-retry when empty."""
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_timeline_loaded",
                        "ticker": "",
                        "timeline": [],
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # Clamp lookback to [7, 365] days; default 90.
        lookback_days = data.get("lookback_days", 90)
        try:
            lookback_days = max(7, min(int(lookback_days), 365))
        except (TypeError, ValueError):
            lookback_days = 90

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )
        timeline = await asyncio.to_thread(
            self.storage.market_store.get_news_timeline_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
        )
        if not timeline:
            # Cold cache: enrich the window, then re-read the timeline.
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=200,
            )
            timeline = await asyncio.to_thread(
                self.storage.market_store.get_news_timeline_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
            )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_timeline_loaded",
                    "ticker": ticker,
                    "timeline": timeline,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
    async def _handle_get_stock_news_categories(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Reply with enriched news category buckets for a ticker
        (``stock_news_categories_loaded``); enrich first if rows are bare."""
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_news_categories_loaded",
                        "ticker": "",
                        "categories": {},
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # Clamp lookback to [7, 365] days; default 90.
        lookback_days = data.get("lookback_days", 90)
        try:
            lookback_days = max(7, min(int(lookback_days), 365))
        except (TypeError, ValueError):
            lookback_days = 90

        end_date = self.state_sync.state.get("current_date")
        if not end_date:
            end_date = datetime.now().strftime("%Y-%m-%d")

        try:
            end_dt = datetime.strptime(end_date, "%Y-%m-%d")
        except ValueError:
            end_dt = datetime.now()
            end_date = end_dt.strftime("%Y-%m-%d")

        start_date = (end_dt - timedelta(days=lookback_days)).strftime(
            "%Y-%m-%d",
        )
        news_rows = await asyncio.to_thread(
            self.storage.market_store.get_news_items_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        )
        if self._news_rows_need_enrichment(news_rows):
            # Run enrichment before computing category buckets.
            await asyncio.to_thread(
                enrich_news_for_symbol,
                self.storage.market_store,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=200,
            )
        categories = await asyncio.to_thread(
            self.storage.market_store.get_news_categories_enriched,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=200,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_news_categories_loaded",
                    "ticker": ticker,
                    "categories": categories,
                    "start_date": start_date,
                    "end_date": end_date,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
    async def _handle_get_stock_range_explain(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Build and send a date-range explanation
        (``stock_range_explain_loaded``) from enriched news rows.

        If ``article_ids`` are supplied, the explanation is restricted to
        those articles; otherwise all enriched rows in the range are used.
        """
        ticker = normalize_symbol(data.get("ticker", ""))
        start_date = str(data.get("start_date") or "").strip()
        end_date = str(data.get("end_date") or "").strip()
        if not ticker or not start_date or not end_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_range_explain_loaded",
                        "ticker": ticker,
                        "result": {"error": "ticker, start_date, end_date are required"},
                    },
                    ensure_ascii=False,
                ),
            )
            return

        article_ids = data.get("article_ids")
        if isinstance(article_ids, list) and article_ids:
            # Caller pinned specific articles — resolve them by id.
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_by_ids_enriched,
                ticker,
                article_ids,
            )
            if self._news_rows_need_enrichment(news_rows):
                await asyncio.to_thread(
                    enrich_news_for_symbol,
                    self.storage.market_store,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )
                news_rows = await asyncio.to_thread(
                    self.storage.market_store.get_news_by_ids_enriched,
                    ticker,
                    article_ids,
                )
        else:
            news_rows = await asyncio.to_thread(
                self.storage.market_store.get_news_items_enriched,
                ticker,
                start_date=start_date,
                end_date=end_date,
                limit=100,
            )
            if not news_rows:
                # Cold cache: enrich the window, then re-read.
                await asyncio.to_thread(
                    enrich_news_for_symbol,
                    self.storage.market_store,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )
                news_rows = await asyncio.to_thread(
                    self.storage.market_store.get_news_items_enriched,
                    ticker,
                    start_date=start_date,
                    end_date=end_date,
                    limit=100,
                )

        result = await asyncio.to_thread(
            build_range_explanation,
            ticker=ticker,
            start_date=start_date,
            end_date=end_date,
            news_rows=news_rows,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_range_explain_loaded",
                    "ticker": ticker,
                    "result": result,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
    async def _handle_get_stock_story(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Send the (possibly cached) narrative story for a ticker
        (``stock_story_loaded``) as of a given date."""
        ticker = normalize_symbol(data.get("ticker", ""))
        if not ticker:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_story_loaded",
                        "ticker": "",
                        "story": "",
                        "error": "invalid ticker",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # Resolve the as-of date: explicit > simulation date > today;
        # truncate to YYYY-MM-DD.
        as_of_date = str(
            data.get("as_of_date")
            or self.state_sync.state.get("current_date")
            or datetime.now().strftime("%Y-%m-%d")
        ).strip()[:10]
        # Ensure news up to the as-of date is enriched before storytelling.
        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            end_date=as_of_date,
            limit=80,
        )
        result = await asyncio.to_thread(
            get_or_create_stock_story,
            self.storage.market_store,
            symbol=ticker,
            as_of_date=as_of_date,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_story_loaded",
                    "ticker": ticker,
                    "as_of_date": as_of_date,
                    "story": result.get("story") or "",
                    "source": result.get("source") or "local",
                },
                ensure_ascii=False,
                default=str,
            ),
        )
    async def _handle_get_stock_similar_days(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Find historically similar trading days for a ticker/date and reply
        with ``stock_similar_days_loaded`` (the service result is splatted
        into the payload)."""
        ticker = normalize_symbol(data.get("ticker", ""))
        target_date = str(data.get("date") or "").strip()[:10]
        if not ticker or not target_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_similar_days_loaded",
                        "ticker": ticker,
                        "date": target_date,
                        "items": [],
                        "error": "ticker and date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # Clamp top_k to [1, 20]; default 8.
        top_k = data.get("top_k", 8)
        try:
            top_k = max(1, min(int(top_k), 20))
        except (TypeError, ValueError):
            top_k = 8

        # Enrich news up to the target date before the similarity search.
        await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            end_date=target_date,
            limit=200,
        )
        result = await asyncio.to_thread(
            find_similar_days,
            self.storage.market_store,
            symbol=ticker,
            target_date=target_date,
            top_k=top_k,
        )
        await websocket.send(
            json.dumps(
                {
                    "type": "stock_similar_days_loaded",
                    "ticker": ticker,
                    "date": target_date,
                    **result,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
    async def _handle_run_stock_enrich(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ):
        """Run on-demand enrichment for a ticker/date range and reply with
        ``stock_enrich_completed``.

        Optionally rebuilds the cached story (``rebuild_story`` +
        ``story_date``) and the similar-days cache (``rebuild_similar_days``
        + ``target_date``). ``force`` re-enriches rows that already have
        enrichment; ``only_local_to_llm`` re-analyzes locally enriched rows
        with the LLM and requires LLM enrichment to be enabled.
        """
        ticker = normalize_symbol(data.get("ticker", ""))
        # All dates are truncated to YYYY-MM-DD.
        start_date = str(data.get("start_date") or "").strip()[:10]
        end_date = str(data.get("end_date") or "").strip()[:10]
        story_date = str(data.get("story_date") or end_date or "").strip()[:10]
        target_date = str(data.get("target_date") or "").strip()[:10]
        force = bool(data.get("force", False))
        rebuild_story = bool(data.get("rebuild_story", True))
        rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
        only_local_to_llm = bool(data.get("only_local_to_llm", False))
        limit = data.get("limit", 200)

        # Clamp limit to [10, 500]; default 200.
        try:
            limit = max(10, min(int(limit), 500))
        except (TypeError, ValueError):
            limit = 200

        if not ticker or not start_date or not end_date:
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_enrich_completed",
                        "ticker": ticker,
                        "start_date": start_date,
                        "end_date": end_date,
                        "error": "ticker, start_date, end_date are required",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        if only_local_to_llm and not llm_enrichment_enabled():
            await websocket.send(
                json.dumps(
                    {
                        "type": "stock_enrich_completed",
                        "ticker": ticker,
                        "start_date": start_date,
                        "end_date": end_date,
                        "error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        result = await asyncio.to_thread(
            enrich_news_for_symbol,
            self.storage.market_store,
            ticker,
            start_date=start_date,
            end_date=end_date,
            limit=limit,
            skip_existing=not force,
            only_reanalyze_local=only_local_to_llm,
        )

        story_status = None
        if rebuild_story and story_date:
            # Invalidate the cached story, then rebuild it.
            await asyncio.to_thread(
                self.storage.market_store.delete_story_cache,
                ticker,
                as_of_date=story_date,
            )
            story_result = await asyncio.to_thread(
                get_or_create_stock_story,
                self.storage.market_store,
                symbol=ticker,
                as_of_date=story_date,
            )
            story_status = {
                "as_of_date": story_date,
                "source": story_result.get("source") or "local",
            }

        similar_status = None
        if rebuild_similar_days and target_date:
            # Invalidate the similar-days cache, then recompute top-8.
            await asyncio.to_thread(
                self.storage.market_store.delete_similar_day_cache,
                ticker,
                target_date=target_date,
            )
            similar_result = await asyncio.to_thread(
                find_similar_days,
                self.storage.market_store,
                symbol=ticker,
                target_date=target_date,
                top_k=8,
            )
            similar_status = {
                "target_date": target_date,
                "count": len(similar_result.get("items") or []),
                "error": similar_result.get("error"),
            }

        await websocket.send(
            json.dumps(
                {
                    "type": "stock_enrich_completed",
                    "ticker": ticker,
                    "start_date": start_date,
                    "end_date": end_date,
                    "story_date": story_date or None,
                    "target_date": target_date or None,
                    "force": force,
                    "only_local_to_llm": only_local_to_llm,
                    "stats": result,
                    "story_status": story_status,
                    "similar_status": similar_status,
                },
                ensure_ascii=False,
                default=str,
            ),
        )
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
|
if not self.is_backtest:
|
|
return
|
|
dates = data.get("dates", [])
|
|
if dates and self._backtest_task is None:
|
|
task = asyncio.create_task(
|
|
self._run_backtest_dates(dates),
|
|
)
|
|
task.add_done_callback(self._handle_backtest_exception)
|
|
self._backtest_task = task
|
|
|
|
    async def _handle_manual_trigger(
        self,
        websocket: ServerConnection,
        data: Dict[str, Any],
    ) -> None:
        """Run one live/mock trading cycle on demand.

        Rejects the request in backtest mode and while a cycle is already
        running; otherwise schedules ``on_strategy_trigger`` as a task for
        the requested (or today's) date.
        """
        if self.is_backtest:
            await websocket.send(
                json.dumps(
                    {
                        "type": "error",
                        "message": "Manual trigger is only available in live/mock mode.",
                    },
                    ensure_ascii=False,
                ),
            )
            return

        # A cycle is in flight if the cycle lock is held or the previous
        # manual task has not finished yet.
        if (
            self._cycle_lock.locked()
            or (
                self._manual_cycle_task is not None
                and not self._manual_cycle_task.done()
            )
        ):
            await websocket.send(
                json.dumps(
                    {
                        "type": "error",
                        "message": "A trading cycle is already running.",
                    },
                    ensure_ascii=False,
                ),
            )
            # System message (user-facing, intentionally not translated):
            # "A task is already running; manual trigger ignored."
            await self.state_sync.on_system_message("已有任务在运行,已忽略手动触发")
            return

        requested_date = data.get("date")
        # "Manual trigger received; starting a new round of analysis/decisions."
        await self.state_sync.on_system_message("收到手动触发请求,准备开始新一轮分析与决策")
        task = asyncio.create_task(
            self.on_strategy_trigger(
                date=requested_date or datetime.now().strftime("%Y-%m-%d"),
            ),
        )
        task.add_done_callback(self._handle_manual_cycle_exception)
        self._manual_cycle_task = task
async def _handle_reload_runtime_assets(self):
    """Reload prompt, skills, and safe runtime config without restart.

    Re-resolves the runtime config from disk, pushes it into the
    pipeline and gateway-owned services, then notifies all clients.
    """
    config_name = self.config.get("config_name", "default")
    runtime_config = resolve_runtime_config(
        project_root=self._project_root,
        config_name=config_name,
        enable_memory=self.config.get("enable_memory", False),
        schedule_mode=self.config.get("schedule_mode", "daily"),
        interval_minutes=self.config.get("interval_minutes", 60),
        trigger_time=self.config.get("trigger_time", "09:30"),
    )
    reload_result = self.pipeline.reload_runtime_assets(runtime_config=runtime_config)
    applied_updates = self._apply_runtime_config(runtime_config)
    await self.state_sync.on_system_message("Runtime assets reloaded.")
    # Merge pipeline reload info and applied-config report into one event.
    await self.broadcast(
        {"type": "runtime_assets_reloaded", **reload_result, **applied_updates}
    )
|
|
|
|
async def _handle_update_runtime_config(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Persist selected runtime settings and hot-reload them.

    Validates ``schedule_mode``, ``interval_minutes``, ``trigger_time``,
    and ``max_comm_cycles`` from the websocket payload. On the first
    invalid field an error is sent and nothing is persisted; otherwise
    the accepted values are written to the run's BOOTSTRAP config and a
    runtime-asset reload is triggered.
    """

    async def _send_error(message: str) -> None:
        # Single place for the error envelope previously duplicated
        # across every validation branch.
        await websocket.send(
            json.dumps(
                {"type": "error", "message": message},
                ensure_ascii=False,
            ),
        )

    updates: Dict[str, Any] = {}

    schedule_mode = str(data.get("schedule_mode", "")).strip().lower()
    if schedule_mode:
        if schedule_mode not in {"daily", "intraday"}:
            await _send_error("schedule_mode must be 'daily' or 'intraday'.")
            return
        updates["schedule_mode"] = schedule_mode

    interval_minutes = data.get("interval_minutes")
    if interval_minutes is not None:
        try:
            parsed_interval = int(interval_minutes)
        except (TypeError, ValueError):
            parsed_interval = 0  # treat unparseable input as invalid
        if parsed_interval <= 0:
            await _send_error("interval_minutes must be a positive integer.")
            return
        updates["interval_minutes"] = parsed_interval

    trigger_time = data.get("trigger_time")
    if trigger_time is not None:
        raw_trigger = str(trigger_time).strip()
        # "now" is an accepted sentinel; anything else must be HH:MM.
        if raw_trigger and raw_trigger != "now":
            try:
                datetime.strptime(raw_trigger, "%H:%M")
            except ValueError:
                await _send_error("trigger_time must use HH:MM or 'now'.")
                return
        # An empty string falls back to the default trigger time.
        updates["trigger_time"] = raw_trigger or "09:30"

    max_comm_cycles = data.get("max_comm_cycles")
    if max_comm_cycles is not None:
        try:
            parsed_cycles = int(max_comm_cycles)
        except (TypeError, ValueError):
            parsed_cycles = 0
        if parsed_cycles <= 0:
            await _send_error("max_comm_cycles must be a positive integer.")
            return
        updates["max_comm_cycles"] = parsed_cycles

    if not updates:
        await _send_error("No runtime settings were provided.")
        return

    config_name = self.config.get("config_name", "default")
    update_bootstrap_values_for_run(
        project_root=self._project_root,
        config_name=config_name,
        updates=updates,
    )
    await self.state_sync.on_system_message("运行时调度配置已保存,正在热更新")
    await self._handle_reload_runtime_assets()
|
|
|
|
async def _handle_update_watchlist(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Persist a new watchlist to BOOTSTRAP.md and hot-reload it."""
    tickers = self._normalize_watchlist(data.get("tickers"))
    if not tickers:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "update_watchlist requires at least one valid ticker.",
                },
                ensure_ascii=False,
            )
        )
        return

    config_name = self.config.get("config_name", "default")
    update_bootstrap_values_for_run(
        project_root=self._project_root,
        config_name=config_name,
        updates={"tickers": tickers},
    )
    await self.state_sync.on_system_message(
        f"Watchlist updated: {', '.join(tickers)}",
    )
    await self.broadcast(
        {
            "type": "watchlist_updated",
            "config_name": config_name,
            "tickers": tickers,
        }
    )
    # Hot-reload config first, then warm the market store in the background.
    await self._handle_reload_runtime_assets()
    self._schedule_watchlist_market_store_refresh(tickers)
|
|
|
|
async def _handle_get_agent_skills(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Return skill catalog and status for one agent."""
    agent_id = str(data.get("agent_id", "")).strip()
    if not agent_id:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "get_agent_skills requires agent_id.",
                },
                ensure_ascii=False,
            )
        )
        return

    config_name = self.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=self._project_root)
    agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
    agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
    resolved_names = set(
        skills_manager.resolve_agent_skill_names(
            config_name=config_name,
            agent_id=agent_id,
            default_skills=[],
        )
    )
    explicitly_enabled = set(agent_config.enabled_skills)
    explicitly_disabled = set(agent_config.disabled_skills)

    def _status_for(name: str) -> str:
        # Explicit overrides win over the resolved (active) skill set.
        if name in explicitly_disabled:
            return "disabled"
        if name in explicitly_enabled:
            return "enabled"
        if name in resolved_names:
            return "active"
        return "available"

    catalog = [
        {
            "skill_name": item.skill_name,
            "name": item.name,
            "description": item.description,
            "version": item.version,
            "source": item.source,
            "tools": item.tools,
            "status": _status_for(item.skill_name),
        }
        for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
    ]

    await websocket.send(
        json.dumps(
            {
                "type": "agent_skills_loaded",
                "config_name": config_name,
                "agent_id": agent_id,
                "skills": catalog,
            },
            ensure_ascii=False,
        )
    )
|
|
|
|
async def _handle_get_agent_profile(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Return structured profile/config summary for one agent."""
    agent_id = str(data.get("agent_id", "")).strip()
    if not agent_id:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "get_agent_profile requires agent_id.",
                },
                ensure_ascii=False,
            )
        )
        return

    config_name = self.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=self._project_root)
    asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
    agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
    profile = load_agent_profiles().get(agent_id, {})
    bootstrap = get_bootstrap_config_for_run(self._project_root, config_name)
    override = bootstrap.agent_override(agent_id)

    # Tool-group precedence: BOOTSTRAP override > agent.yaml > static profile.
    active_tool_groups = override.get(
        "active_tool_groups",
        agent_config.active_tool_groups or profile.get("active_tool_groups", []),
    )
    if not isinstance(active_tool_groups, list):
        active_tool_groups = []
    disabled_tool_groups = agent_config.disabled_tool_groups
    if disabled_tool_groups:
        blocked = set(disabled_tool_groups)
        active_tool_groups = [
            group for group in active_tool_groups if group not in blocked
        ]

    default_skills = profile.get("skills", [])
    if not isinstance(default_skills, list):
        default_skills = []
    resolved_skills = skills_manager.resolve_agent_skill_names(
        config_name=config_name,
        agent_id=agent_id,
        default_skills=default_skills,
    )
    # Fall back to the canonical workspace prompt files when unset.
    prompt_files = agent_config.prompt_files or [
        "SOUL.md",
        "PROFILE.md",
        "AGENTS.md",
        "POLICY.md",
        "MEMORY.md",
    ]
    model_name, model_provider = get_agent_model_info(agent_id)

    await websocket.send(
        json.dumps(
            {
                "type": "agent_profile_loaded",
                "config_name": config_name,
                "agent_id": agent_id,
                "profile": {
                    "model_name": model_name,
                    "model_provider": model_provider,
                    "prompt_files": prompt_files,
                    "default_skills": default_skills,
                    "resolved_skills": resolved_skills,
                    "active_tool_groups": active_tool_groups,
                    "disabled_tool_groups": disabled_tool_groups,
                    "enabled_skills": agent_config.enabled_skills,
                    "disabled_skills": agent_config.disabled_skills,
                },
            },
            ensure_ascii=False,
        )
    )
|
|
|
|
async def _handle_get_skill_detail(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Return full SKILL.md body for one skill."""
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    if not skill_name:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "get_skill_detail requires skill_name.",
                },
                ensure_ascii=False,
            )
        )
        return

    skills_manager = SkillsManager(project_root=self._project_root)
    try:
        if agent_id:
            # Agent-scoped lookup honors per-agent local skills.
            detail = skills_manager.load_agent_skill_document(
                config_name=self.config.get("config_name", "default"),
                agent_id=agent_id,
                skill_name=skill_name,
            )
        else:
            detail = skills_manager.load_skill_document(skill_name)
    except FileNotFoundError:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": f"Unknown skill: {skill_name}",
                },
                ensure_ascii=False,
            )
        )
        return

    await websocket.send(
        json.dumps(
            {
                "type": "skill_detail_loaded",
                "agent_id": agent_id,
                "skill": detail,
            },
            ensure_ascii=False,
        )
    )
|
|
|
|
async def _handle_create_agent_local_skill(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Create a new local skill for one agent and hot-reload runtime assets."""
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    if not agent_id or not skill_name:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "create_agent_local_skill requires agent_id and skill_name.",
                },
                ensure_ascii=False,
            )
        )
        return

    config_name = self.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=self._project_root)
    try:
        skills_manager.create_agent_local_skill(
            config_name=config_name,
            agent_id=agent_id,
            skill_name=skill_name,
        )
    except (ValueError, FileExistsError) as exc:
        # Invalid names and duplicates are reported back to the client.
        await websocket.send(
            json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False)
        )
        return

    await self.state_sync.on_system_message(
        f"Created local skill {skill_name} for {agent_id}",
    )
    await self._handle_reload_runtime_assets()
    await websocket.send(
        json.dumps(
            {
                "type": "agent_local_skill_created",
                "agent_id": agent_id,
                "skill_name": skill_name,
            },
            ensure_ascii=False,
        )
    )
    # Push the refreshed catalog and the new skill body to the requester.
    await self._handle_get_agent_skills(websocket, {"agent_id": agent_id})
    await self._handle_get_skill_detail(
        websocket,
        {"agent_id": agent_id, "skill_name": skill_name},
    )
|
|
|
|
async def _handle_update_agent_local_skill(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Update one agent-local SKILL.md and hot-reload runtime assets."""
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    content = data.get("content")
    if not agent_id or not skill_name or not isinstance(content, str):
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "update_agent_local_skill requires agent_id, skill_name, and string content.",
                },
                ensure_ascii=False,
            )
        )
        return

    skills_manager = SkillsManager(project_root=self._project_root)
    try:
        skills_manager.update_agent_local_skill(
            config_name=self.config.get("config_name", "default"),
            agent_id=agent_id,
            skill_name=skill_name,
            content=content,
        )
    except (ValueError, FileNotFoundError) as exc:
        await websocket.send(
            json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False)
        )
        return

    await self.state_sync.on_system_message(
        f"Updated local skill {skill_name} for {agent_id}",
    )
    await self._handle_reload_runtime_assets()
    await websocket.send(
        json.dumps(
            {
                "type": "agent_local_skill_updated",
                "agent_id": agent_id,
                "skill_name": skill_name,
            },
            ensure_ascii=False,
        )
    )
    # Echo the persisted document back so the editor can refresh.
    await self._handle_get_skill_detail(
        websocket,
        {"agent_id": agent_id, "skill_name": skill_name},
    )
|
|
|
|
async def _handle_delete_agent_local_skill(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Delete one agent-local skill and hot-reload runtime assets."""
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    if not agent_id or not skill_name:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "delete_agent_local_skill requires agent_id and skill_name.",
                },
                ensure_ascii=False,
            )
        )
        return

    config_name = self.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=self._project_root)
    try:
        skills_manager.delete_agent_local_skill(
            config_name=config_name,
            agent_id=agent_id,
            skill_name=skill_name,
        )
        # Drop any enable/disable override that referenced the deleted skill.
        skills_manager.forget_agent_skill_overrides(
            config_name=config_name,
            agent_id=agent_id,
            skill_names=[skill_name],
        )
    except (ValueError, FileNotFoundError) as exc:
        await websocket.send(
            json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False)
        )
        return

    await self.state_sync.on_system_message(
        f"Deleted local skill {skill_name} for {agent_id}",
    )
    await self._handle_reload_runtime_assets()
    await websocket.send(
        json.dumps(
            {
                "type": "agent_local_skill_deleted",
                "agent_id": agent_id,
                "skill_name": skill_name,
            },
            ensure_ascii=False,
        )
    )
    await self._handle_get_agent_skills(websocket, {"agent_id": agent_id})
|
|
|
|
async def _handle_remove_agent_skill(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Remove one shared skill from the agent's installed set."""
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    if not agent_id or not skill_name:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "remove_agent_skill requires agent_id and skill_name.",
                },
                ensure_ascii=False,
            )
        )
        return

    config_name = self.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=self._project_root)
    # Only shared (non-local) skills can be removed through this path.
    shared_names = {
        item.skill_name
        for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
        if item.source != "local"
    }
    if skill_name not in shared_names:
        await websocket.send(
            json.dumps(
                {"type": "error", "message": f"Unknown shared skill: {skill_name}"},
                ensure_ascii=False,
            )
        )
        return

    # Removal is implemented as a persistent "disabled" override.
    skills_manager.update_agent_skill_overrides(
        config_name=config_name,
        agent_id=agent_id,
        disable=[skill_name],
    )
    await self.state_sync.on_system_message(
        f"Removed shared skill {skill_name} from {agent_id}",
    )
    await self._handle_reload_runtime_assets()
    await websocket.send(
        json.dumps(
            {
                "type": "agent_skill_removed",
                "agent_id": agent_id,
                "skill_name": skill_name,
            },
            ensure_ascii=False,
        )
    )
    await self._handle_get_agent_skills(websocket, {"agent_id": agent_id})
|
|
|
|
async def _handle_update_agent_skill(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Enable or disable one skill for one agent and hot-reload assets."""
    agent_id = str(data.get("agent_id", "")).strip()
    skill_name = str(data.get("skill_name", "")).strip()
    enabled = data.get("enabled")
    if not agent_id or not skill_name or not isinstance(enabled, bool):
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled.",
                },
                ensure_ascii=False,
            )
        )
        return

    config_name = self.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=self._project_root)
    known_names = {
        item.skill_name
        for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
    }
    if skill_name not in known_names:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": f"Unknown skill: {skill_name}",
                },
                ensure_ascii=False,
            )
        )
        return

    # Persist the override, then report what happened.
    if enabled:
        skills_manager.update_agent_skill_overrides(
            config_name=config_name,
            agent_id=agent_id,
            enable=[skill_name],
        )
        await self.state_sync.on_system_message(
            f"Enabled skill {skill_name} for {agent_id}",
        )
    else:
        skills_manager.update_agent_skill_overrides(
            config_name=config_name,
            agent_id=agent_id,
            disable=[skill_name],
        )
        await self.state_sync.on_system_message(
            f"Disabled skill {skill_name} for {agent_id}",
        )

    await websocket.send(
        json.dumps(
            {
                "type": "agent_skill_updated",
                "agent_id": agent_id,
                "skill_name": skill_name,
                "enabled": enabled,
            },
            ensure_ascii=False,
        )
    )
    await self._handle_reload_runtime_assets()
    await self._handle_get_agent_skills(websocket, {"agent_id": agent_id})
|
|
|
|
async def _handle_get_agent_workspace_file(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Load one editable agent workspace markdown file."""
    agent_id = str(data.get("agent_id", "")).strip()
    # Filename is allowlisted; anything else resolves to None and is rejected.
    filename = self._normalize_agent_workspace_filename(data.get("filename"))
    if not agent_id or not filename:
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "get_agent_workspace_file requires agent_id and supported filename.",
                },
                ensure_ascii=False,
            )
        )
        return

    config_name = self.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=self._project_root)
    asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
    asset_dir.mkdir(parents=True, exist_ok=True)
    target = asset_dir / filename
    # Missing files are reported as empty content, not as errors.
    file_body = target.read_text(encoding="utf-8") if target.exists() else ""
    await websocket.send(
        json.dumps(
            {
                "type": "agent_workspace_file_loaded",
                "config_name": config_name,
                "agent_id": agent_id,
                "filename": filename,
                "content": file_body,
            },
            ensure_ascii=False,
        )
    )
|
|
|
|
async def _handle_update_agent_workspace_file(
    self,
    websocket: ServerConnection,
    data: Dict[str, Any],
) -> None:
    """Persist one editable agent workspace markdown file and hot-reload.

    The filename is restricted by ``_normalize_agent_workspace_filename``,
    so only allowlisted workspace files can be written.
    """
    agent_id = str(data.get("agent_id", "")).strip()
    filename = self._normalize_agent_workspace_filename(data.get("filename"))
    content = data.get("content")
    if not agent_id or not filename or not isinstance(content, str):
        await websocket.send(
            json.dumps(
                {
                    "type": "error",
                    "message": "update_agent_workspace_file requires agent_id, supported filename, and string content.",
                },
                ensure_ascii=False,
            ),
        )
        return

    config_name = self.config.get("config_name", "default")
    skills_manager = SkillsManager(project_root=self._project_root)
    asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
    asset_dir.mkdir(parents=True, exist_ok=True)
    path = asset_dir / filename
    path.write_text(content, encoding="utf-8")
    # Fix: include the actual filename in the system message (was a
    # literal "(unknown)" placeholder), matching the sibling handlers'
    # "Updated <object> for <agent>" message shape.
    await self.state_sync.on_system_message(
        f"Updated {filename} for {agent_id}",
    )
    await websocket.send(
        json.dumps(
            {
                "type": "agent_workspace_file_updated",
                "agent_id": agent_id,
                "filename": filename,
            },
            ensure_ascii=False,
        ),
    )
    await self._handle_reload_runtime_assets()
    # Echo the persisted file back so the editor view refreshes.
    await self._handle_get_agent_workspace_file(
        websocket,
        {"agent_id": agent_id, "filename": filename},
    )
|
|
|
|
@staticmethod
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
    """Parse watchlist payloads from websocket messages.

    Accepts a comma-separated string, a list, or a single scalar and
    returns normalized, de-duplicated symbols in first-seen order.
    """
    if raw_tickers is None:
        return []

    if isinstance(raw_tickers, str):
        raw_items = raw_tickers.split(",")
    elif isinstance(raw_tickers, list):
        raw_items = raw_tickers
    else:
        raw_items = [raw_tickers]

    normalized: List[str] = []
    for raw_item in raw_items:
        # Strip whitespace and stray quotes before symbol normalization.
        symbol = normalize_symbol(str(raw_item).strip().strip("\"'"))
        if symbol and symbol not in normalized:
            normalized.append(symbol)
    return normalized
|
|
|
|
@staticmethod
def _normalize_agent_workspace_filename(raw_name: Any) -> Optional[str]:
    """Restrict editable workspace files to a safe allowlist.

    Returns the filename unchanged when it is allowlisted, else None —
    this is the guard that prevents arbitrary-path writes.
    """
    candidate = str(raw_name or "").strip()
    return candidate if candidate in EDITABLE_AGENT_WORKSPACE_FILES else None
|
|
|
|
def _apply_runtime_config(
    self,
    runtime_config: Dict[str, Any],
) -> Dict[str, Any]:
    """Apply runtime config to gateway-owned services and state.

    Pushes the resolved runtime config into the market service, the
    pipeline, the scheduler, and storage, then returns a report of what
    was requested, what was actually applied, per-field success flags,
    ticker diff info, and human-readable warnings for settings that
    could not be hot-applied.
    """
    warnings: List[str] = []

    # Ticker list: market service computes added/removed/active sets.
    ticker_changes = self.market_service.update_tickers(
        runtime_config.get("tickers", []),
    )
    self.config["tickers"] = ticker_changes["active"]

    # Scheduling knobs: always applied; missing keys keep current values.
    self.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
    self.config["max_comm_cycles"] = self.pipeline.max_comm_cycles
    self.config["schedule_mode"] = runtime_config.get(
        "schedule_mode",
        self.config.get("schedule_mode", "daily"),
    )
    self.config["interval_minutes"] = int(
        runtime_config.get(
            "interval_minutes",
            self.config.get("interval_minutes", 60),
        ),
    )
    self.config["trigger_time"] = runtime_config.get(
        "trigger_time",
        self.config.get("trigger_time", "09:30"),
    )

    # Hot-reconfigure the scheduler when one exists (live/mock mode).
    if self.scheduler:
        self.scheduler.reconfigure(
            mode=self.config["schedule_mode"],
            trigger_time=self.config["trigger_time"],
            interval_minutes=self.config["interval_minutes"],
        )

    # Margin requirement: the portfolio manager decides what it accepts;
    # read the effective value back from pm.portfolio.
    pm_apply_result = self.pipeline.pm.apply_runtime_portfolio_config(
        margin_requirement=runtime_config["margin_requirement"],
    )
    self.config["margin_requirement"] = self.pipeline.pm.portfolio.get(
        "margin_requirement",
        runtime_config["margin_requirement"],
    )

    # Initial cash: only applied when both storage and the portfolio
    # manager report a clean slate (no positions/margin/trades).
    requested_initial_cash = float(runtime_config["initial_cash"])
    current_initial_cash = float(self.storage.initial_cash)
    initial_cash_applied = requested_initial_cash == current_initial_cash
    if not initial_cash_applied:
        if (
            self.storage.can_apply_initial_cash()
            and self.pipeline.pm.can_apply_initial_cash()
        ):
            initial_cash_applied = self.storage.apply_initial_cash(
                requested_initial_cash,
            )
            if initial_cash_applied:
                # Keep the portfolio manager in sync with storage.
                self.pipeline.pm.apply_runtime_portfolio_config(
                    initial_cash=requested_initial_cash,
                )
                self.config["initial_cash"] = self.storage.initial_cash
        else:
            warnings.append(
                "initial_cash changed in BOOTSTRAP.md but was not applied "
                "because the run already has positions, margin usage, or trades.",
            )

    # enable_memory cannot be hot-applied; warn so the operator restarts.
    requested_enable_memory = bool(runtime_config["enable_memory"])
    current_enable_memory = bool(self.config.get("enable_memory", False))
    if requested_enable_memory != current_enable_memory:
        warnings.append(
            "enable_memory changed in BOOTSTRAP.md but still requires a restart "
            "because long-term memory contexts are created at startup.",
        )

    # Persist the new state and refresh the terminal dashboard.
    self._sync_runtime_state()

    return {
        "runtime_config_requested": runtime_config,
        "runtime_config_applied": {
            "tickers": list(self.config.get("tickers", [])),
            "schedule_mode": self.config.get("schedule_mode", "daily"),
            "interval_minutes": self.config.get("interval_minutes", 60),
            "trigger_time": self.config.get("trigger_time", "09:30"),
            "initial_cash": self.storage.initial_cash,
            "margin_requirement": self.config["margin_requirement"],
            "max_comm_cycles": self.config["max_comm_cycles"],
            "enable_memory": self.config.get("enable_memory", False),
        },
        "runtime_config_status": {
            "tickers": True,
            "schedule_mode": True,
            "interval_minutes": True,
            "trigger_time": True,
            "initial_cash": initial_cash_applied,
            "margin_requirement": pm_apply_result["margin_requirement"],
            "max_comm_cycles": True,
            "enable_memory": requested_enable_memory == current_enable_memory,
        },
        "ticker_changes": ticker_changes,
        "runtime_config_warnings": warnings,
    }
|
|
|
|
def _sync_runtime_state(self) -> None:
    """Refresh persisted state and dashboard after runtime config changes."""
    tickers = self.config.get("tickers", [])
    self.state_sync.update_state("tickers", tickers)
    self.state_sync.update_state(
        "runtime_config",
        {
            "tickers": tickers,
            "schedule_mode": self.config.get("schedule_mode", "daily"),
            "interval_minutes": self.config.get("interval_minutes", 60),
            "trigger_time": self.config.get("trigger_time", "09:30"),
            "initial_cash": self.storage.initial_cash,
            "margin_requirement": self.config.get("margin_requirement"),
            "max_comm_cycles": self.config.get("max_comm_cycles"),
            "enable_memory": self.config.get("enable_memory", False),
        },
    )

    # Persist the synced state to disk.
    self.storage.update_server_state_from_dashboard(self.state_sync.state)
    self.state_sync.save_state()

    # Mirror the headline config onto the terminal dashboard.
    self._dashboard.tickers = list(tickers)
    self._dashboard.initial_cash = self.storage.initial_cash
    self._dashboard.enable_memory = bool(self.config.get("enable_memory", False))

    self._dashboard.update(
        portfolio=self.storage.load_file("summary") or {},
        holdings=self.storage.load_file("holdings") or [],
        trades=self.storage.load_file("trades") or [],
    )
|
|
|
|
def _schedule_watchlist_market_store_refresh(
    self,
    tickers: List[str],
) -> None:
    """Kick off a non-blocking Polygon refresh for the updated watchlist."""
    if not tickers:
        return
    # Supersede any refresh that is still running for a stale watchlist.
    previous = self._watchlist_ingest_task
    if previous and not previous.done():
        previous.cancel()
    self._watchlist_ingest_task = asyncio.create_task(
        self._refresh_market_store_for_watchlist(tickers)
    )
|
|
|
|
async def _refresh_market_store_for_watchlist(
    self,
    tickers: List[str],
) -> None:
    """Refresh the long-lived market store after a watchlist update."""
    try:
        await self.state_sync.on_system_message(
            f"正在同步自选股市场数据: {', '.join(tickers)}",
        )
        # Ingestion is blocking I/O; run it off the event loop.
        results = await asyncio.to_thread(
            ingest_symbols,
            tickers,
            mode="incremental",
        )
        per_symbol = (
            f"{item['symbol']} prices={item['prices']} news={item['news']}"
            for item in results
        )
        await self.state_sync.on_system_message(
            f"自选股市场数据已同步: {', '.join(per_symbol)}",
        )
    except asyncio.CancelledError:
        # Cancellation comes from a newer watchlist update; propagate.
        raise
    except Exception as exc:
        logger.warning("Watchlist market store refresh failed: %s", exc)
        await self.state_sync.on_system_message(
            f"自选股市场数据同步失败: {exc}",
        )
|
|
|
|
async def broadcast(self, message: Dict[str, Any]):
    """Broadcast message to all connected clients."""
    if not self.connected_clients:
        return

    payload = json.dumps(message, ensure_ascii=False, default=str)

    # Snapshot the client set under the lock; send outside it so
    # _send_to_client can re-acquire the lock to drop dead connections.
    async with self.lock:
        send_tasks = [
            self._send_to_client(client, payload)
            for client in self.connected_clients.copy()
        ]

    if send_tasks:
        await asyncio.gather(*send_tasks, return_exceptions=True)
|
|
|
|
async def _send_to_client(
    self,
    client: ServerConnection,
    message: str,
):
    """Send one pre-serialized message to one client.

    On a closed connection the client is removed from the broadcast
    set under the lock; other send errors propagate to the caller
    (broadcast() gathers them with return_exceptions=True).
    """
    try:
        await client.send(message)
    except websockets.ConnectionClosed:
        async with self.lock:
            # discard() is idempotent — safe if another task already
            # removed this client.
            self.connected_clients.discard(client)
|
|
|
|
async def _market_status_monitor(self):
    """Periodically check and broadcast market status changes.

    Runs forever (until cancelled), once per minute:
    - broadcasts market status changes via the market service;
    - starts live-session tracking on market open, recording the
      portfolio value at session start;
    - ends the session when the market is no longer open;
    - while a session is active, updates and broadcasts live returns.
    Non-cancellation errors are logged and the loop keeps running.
    """
    while True:
        try:
            await self.market_service.check_and_broadcast_market_status()

            # On market open, start live session tracking
            status = self.market_service.get_market_status()
            if (
                status["status"] == "open"
                and not self.storage.is_live_session_active
            ):
                self.storage.start_live_session()
                summary = self.storage.load_file("summary") or {}
                # Baseline for this session's return calculation; falls
                # back to initial cash when no summary exists yet.
                self._session_start_portfolio_value = summary.get(
                    "totalAssetValue",
                    self.storage.initial_cash,
                )
                logger.info(
                    "Session start portfolio: "
                    f"${self._session_start_portfolio_value:,.2f}",
                )
            elif (
                status["status"] != "open"
                and self.storage.is_live_session_active
            ):
                # Market closed (or paused): close out the live session.
                self.storage.end_live_session()
                self._session_start_portfolio_value = None

            # Update and broadcast live returns if session is active
            if self.storage.is_live_session_active:
                await self._update_and_broadcast_live_returns()

            await asyncio.sleep(60)  # Check every minute
        except asyncio.CancelledError:
            break
        except Exception as e:
            # Keep the monitor alive across transient failures.
            logger.error(f"Market status monitor error: {e}")
            await asyncio.sleep(60)
|
|
|
|
async def _update_and_broadcast_live_returns(self):
    """Calculate and broadcast live returns for current session."""
    if not self.storage.is_live_session_active:
        return

    # Skip the update entirely when no usable prices are available yet.
    prices = self.market_service.get_all_prices()
    if not prices or not any(p > 0 for p in prices.values()):
        return

    internal_state = self.storage.load_internal_state()

    def _latest(series_key: str):
        # Each history entry is {"v": value, ...}; None when empty.
        series = internal_state.get(series_key, [])
        return series[-1]["v"] if series else None

    # Fold the latest history values into the live-returns tracker.
    point = self.storage.update_live_returns(
        current_equity=_latest("equity_history"),
        current_baseline=_latest("baseline_history"),
        current_baseline_vw=_latest("baseline_vw_history"),
        current_momentum=_latest("momentum_history"),
    )

    # Only broadcast when the tracker produced a fresh data point.
    if point:
        live_returns = self.storage.get_live_returns()
        await self.broadcast(
            {
                "type": "team_summary",
                "equity_return": live_returns["equity_return"],
                "baseline_return": live_returns["baseline_return"],
                "baseline_vw_return": live_returns["baseline_vw_return"],
                "momentum_return": live_returns["momentum_return"],
            }
        )
|
|
|
|
async def on_strategy_trigger(self, date: str):
    """Handle trading cycle trigger.

    Skips the trigger when a cycle is already in flight; otherwise runs
    a backtest or live cycle under the cycle lock.
    """
    if self._cycle_lock.locked():
        logger.warning("Trading cycle already running, skipping trigger for %s", date)
        await self.state_sync.on_system_message(
            f"已有交易周期在运行,跳过本次触发: {date}",
        )
        return

    async with self._cycle_lock:
        logger.info(f"Strategy triggered for {date}")
        tickers = self.config.get("tickers", [])
        run_cycle = (
            self._run_backtest_cycle if self.is_backtest else self._run_live_cycle
        )
        await run_cycle(date, tickers)
|
|
|
|
async def _run_backtest_cycle(self, date: str, tickers: List[str]):
    """Run one backtest day against pre-loaded open/close prices."""
    # Pin the market service to the simulated day and announce the open.
    self.market_service.set_backtest_date(date)
    await self.market_service.emit_market_open()

    await self.state_sync.on_cycle_start(date)
    self._dashboard.update(date=date, status="Analyzing...")

    # Gather the day's inputs: market caps plus preloaded open/close prices.
    caps = self._get_market_caps(tickers, date)
    open_px = self.market_service.get_open_prices()
    close_px = self.market_service.get_close_prices()

    cycle_result = await self.pipeline.run_cycle(
        tickers=tickers,
        date=date,
        prices=open_px,
        close_prices=close_px,
        market_caps=caps,
    )

    # Close out the simulated session, persist, and broadcast results.
    await self.market_service.emit_market_close()
    self._save_cycle_results(
        cycle_result,
        date,
        close_px,
        cycle_result.get("settlement_result"),
    )
    await self._broadcast_portfolio_updates(cycle_result, close_px)
    await self._finalize_cycle(date)
|
async def _run_live_cycle(self, date: str, tickers: List[str]):
    """
    Run live cycle with real market timing.

    - Analysis runs immediately
    - Daily mode waits for open/close as before
    - Intraday mode executes only during market open
      and skips trading outside market hours

    Args:
        date: Date string the trigger fired with (may fall on a weekend).
        tickers: Symbols to analyze/trade this cycle.
    """
    # Get actual trading date (might be next trading day if weekend)
    trading_date = self.market_service.get_live_trading_date()
    logger.info(
        f"Live cycle: triggered={date}, trading_date={trading_date}",
    )

    await self.state_sync.on_cycle_start(trading_date)
    self._dashboard.update(date=trading_date, status="Analyzing...")

    market_caps = self._get_market_caps(tickers, trading_date)
    schedule_mode = self.config.get("schedule_mode", "daily")
    market_status = self.market_service.get_market_status()
    # Snapshot of latest prices; intraday mode also settles against it.
    current_prices = self.market_service.get_all_prices()

    if schedule_mode == "intraday":
        # Trade only while the market is open; outside hours the cycle
        # still refreshes data/analysis but skips order execution.
        execute_decisions = market_status.get("status") == "open"
        if execute_decisions:
            # System message (zh): market is open, decisions will execute.
            await self.state_sync.on_system_message(
                "定时任务触发:当前处于交易时段,本轮将执行交易决策",
            )
        else:
            # System message (zh): outside trading hours, analysis only.
            await self.state_sync.on_system_message(
                "定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易",
            )

        result = await self.pipeline.run_cycle(
            tickers=tickers,
            date=trading_date,
            prices=current_prices,
            market_caps=market_caps,
            execute_decisions=execute_decisions,
        )
        close_prices = current_prices
    else:
        # Daily mode keeps the original full-session behavior:
        # the pipeline blocks on real open/close via these callbacks.
        result = await self.pipeline.run_cycle(
            tickers=tickers,
            date=trading_date,
            market_caps=market_caps,
            get_open_prices_fn=self.market_service.wait_for_open_prices,
            get_close_prices_fn=self.market_service.wait_for_close_prices,
        )
        # Re-read prices after the session so settlement uses end-of-day values.
        close_prices = self.market_service.get_all_prices()

    settlement_result = result.get("settlement_result")
    self._save_cycle_results(
        result,
        trading_date,
        close_prices,
        settlement_result,
    )
    await self._broadcast_portfolio_updates(result, close_prices)
    await self._finalize_cycle(trading_date)
|
async def _finalize_cycle(self, date: str):
|
|
"""Finalize cycle: broadcast state and update dashboard"""
|
|
summary = self.storage.load_file("summary") or {}
|
|
|
|
# Include live returns if session is active
|
|
if self.storage.is_live_session_active:
|
|
live_returns = self.storage.get_live_returns()
|
|
summary.update(live_returns)
|
|
|
|
await self.state_sync.on_cycle_end(date, portfolio_summary=summary)
|
|
|
|
holdings = self.storage.load_file("holdings") or []
|
|
trades = self.storage.load_file("trades") or []
|
|
leaderboard = self.storage.load_file("leaderboard") or []
|
|
|
|
if leaderboard:
|
|
await self.state_sync.on_leaderboard_update(leaderboard)
|
|
|
|
self._dashboard.update(
|
|
date=date,
|
|
status="Running",
|
|
portfolio=summary,
|
|
holdings=holdings,
|
|
trades=trades,
|
|
)
|
|
|
|
def _get_market_caps(
    self,
    tickers: List[str],
    date: str,
    default_cap: float = 1e9,
) -> Dict[str, float]:
    """
    Look up market caps for the given tickers, with a safe fallback.

    Args:
        tickers: List of ticker symbols to resolve.
        date: Trading date passed to the data provider lookup.
        default_cap: Value used when the provider returns a falsy result
            or raises, so downstream weighting never sees missing data.
            Defaults to 1e9 (the previously hard-coded fallback).

    Returns:
        Dict mapping every ticker in *tickers* to a market cap.
    """
    from ..tools.data_tools import get_market_cap

    market_caps: Dict[str, float] = {}
    for ticker in tickers:
        try:
            # Treat falsy provider results (None / 0) the same as a miss.
            market_caps[ticker] = get_market_cap(ticker, date) or default_cap
        except Exception:
            # Best-effort: a provider failure must not abort the cycle.
            market_caps[ticker] = default_cap
    return market_caps
|
async def _broadcast_portfolio_updates(
|
|
self,
|
|
result: Dict[str, Any],
|
|
prices: Dict[str, float],
|
|
):
|
|
portfolio = result.get("portfolio", {})
|
|
|
|
if portfolio:
|
|
holdings = FrontendAdapter.build_holdings(portfolio, prices)
|
|
if holdings:
|
|
await self.state_sync.on_holdings_update(holdings)
|
|
|
|
stats = FrontendAdapter.build_stats(portfolio, prices)
|
|
if stats:
|
|
await self.state_sync.on_stats_update(stats)
|
|
|
|
executed_trades = result.get("executed_trades", [])
|
|
if executed_trades:
|
|
await self.state_sync.on_trades_executed(executed_trades)
|
|
|
|
def _save_cycle_results(
|
|
self,
|
|
result: Dict[str, Any],
|
|
date: str,
|
|
prices: Dict[str, float],
|
|
settlement_result: Optional[Dict[str, Any]] = None,
|
|
):
|
|
portfolio = result.get("portfolio", {})
|
|
executed_trades = result.get("executed_trades", [])
|
|
|
|
# Extract baseline values from settlement result
|
|
baseline_values = None
|
|
if settlement_result:
|
|
baseline_values = settlement_result.get("baseline_values")
|
|
|
|
if portfolio:
|
|
self.storage.update_dashboard_after_cycle(
|
|
portfolio=portfolio,
|
|
prices=prices,
|
|
date=date,
|
|
executed_trades=executed_trades,
|
|
baseline_values=baseline_values,
|
|
)
|
|
|
|
async def _run_backtest_dates(self, dates: List[str]):
    """Drive the full backtest loop over *dates*, one trading cycle per day."""
    self.state_sync.set_backtest_dates(dates)
    self._dashboard.update(days_total=len(dates), days_completed=0)

    await self.state_sync.on_system_message(
        f"Starting backtest - {len(dates)} trading days",
    )

    try:
        for i, date in enumerate(dates):
            # Progress reflects days *completed*, so update before day i runs.
            self._dashboard.update(days_completed=i)
            await self.on_strategy_trigger(date=date)
            # Brief yield between days so broadcasts/UI can keep up.
            await asyncio.sleep(0.1)

        await self.state_sync.on_system_message(
            f"Backtest complete - {len(dates)} days",
        )

        # Update dashboard with final state
        summary = self.storage.load_file("summary") or {}
        self._dashboard.update(
            status="Complete",
            portfolio=summary,
            days_completed=len(dates),
        )
        self._dashboard.stop()
        self._dashboard.print_final_summary()
    except Exception as e:
        # Surface the failure to clients and the dashboard, then re-raise so
        # the task's done-callback also sees the exception.
        error_msg = f"Backtest failed: {type(e).__name__}: {str(e)}"
        logger.error(error_msg, exc_info=True)
        await self.state_sync.on_system_message(error_msg)
        self._dashboard.update(status=f"Failed: {str(e)}")
        self._dashboard.stop()
        raise
    finally:
        # Clear the slot so a new backtest can be started afterwards.
        self._backtest_task = None
|
def _handle_backtest_exception(self, task: asyncio.Task):
|
|
"""Handle exceptions from backtest task"""
|
|
try:
|
|
task.result()
|
|
except asyncio.CancelledError:
|
|
logger.info("Backtest task was cancelled")
|
|
except Exception as e:
|
|
logger.error(
|
|
f"Backtest task failed with exception:{type(e).__name__}:{e}",
|
|
exc_info=True,
|
|
)
|
|
|
|
def _handle_manual_cycle_exception(self, task: asyncio.Task):
|
|
"""Handle exceptions from manually-triggered live cycles."""
|
|
self._manual_cycle_task = None
|
|
try:
|
|
task.result()
|
|
except asyncio.CancelledError:
|
|
logger.info("Manual cycle task was cancelled")
|
|
except Exception as exc:
|
|
logger.error(
|
|
"Manual cycle task failed with exception:%s:%s",
|
|
type(exc).__name__,
|
|
exc,
|
|
exc_info=True,
|
|
)
|
|
|
|
def set_backtest_dates(self, dates: List[str]):
    """Record the backtest date range and size the dashboard progress bar."""
    self.state_sync.set_backtest_dates(dates)
    if not dates:
        return
    first, last = dates[0], dates[-1]
    self._backtest_start_date = first
    self._backtest_end_date = last
    self._dashboard.days_total = len(dates)
|
def stop(self):
    """Shut down: persist state, halt services, cancel background tasks."""
    self.state_sync.save_state()
    self.market_service.stop()
    # Cancel whichever background tasks were started, in fixed order.
    for pending in (
        self._backtest_task,
        self._market_status_task,
        self._watchlist_ingest_task,
    ):
        if pending:
            pending.cancel()
    self._dashboard.stop()