确认PokieTicker新闻库数据源
This commit is contained in:
@@ -5,12 +5,18 @@ WebSocket Gateway for frontend communication
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.asyncio.server import ServerConnection
|
||||
|
||||
from backend.config.bootstrap_config import (
|
||||
resolve_runtime_config,
|
||||
update_bootstrap_values_for_run,
|
||||
)
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
from backend.utils.terminal_dashboard import get_dashboard
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
@@ -18,6 +24,7 @@ from backend.core.state_sync import StateSync
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.data.provider_router import get_provider_router
|
||||
from backend.tools.data_tools import get_prices
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +58,7 @@ class Gateway:
|
||||
self.state_sync.set_broadcast_fn(self.broadcast)
|
||||
self.pipeline.state_sync = self.state_sync
|
||||
|
||||
self.connected_clients: Set[WebSocketServerProtocol] = set()
|
||||
self.connected_clients: Set[ServerConnection] = set()
|
||||
self.lock = asyncio.Lock()
|
||||
self._backtest_task: Optional[asyncio.Task] = None
|
||||
self._backtest_start_date: Optional[str] = None
|
||||
@@ -63,6 +70,7 @@ class Gateway:
|
||||
self._session_start_portfolio_value: Optional[float] = None
|
||||
self._provider_router = get_provider_router()
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._project_root = Path(__file__).resolve().parents[2]
|
||||
|
||||
async def start(self, host: str = "0.0.0.0", port: int = 8766):
|
||||
"""Start gateway server"""
|
||||
@@ -87,6 +95,7 @@ class Gateway:
|
||||
self._dashboard.start()
|
||||
|
||||
self.state_sync.load_state()
|
||||
self.market_service.set_price_recorder(self.storage.record_price_point)
|
||||
self.state_sync.update_state("status", "running")
|
||||
self.state_sync.update_state("server_mode", self.mode)
|
||||
self.state_sync.update_state("is_backtest", self.is_backtest)
|
||||
@@ -94,6 +103,20 @@ class Gateway:
|
||||
"is_mock_mode",
|
||||
self.config.get("mock_mode", False),
|
||||
)
|
||||
self.state_sync.update_state("tickers", self.config.get("tickers", []))
|
||||
self.state_sync.update_state(
|
||||
"runtime_config",
|
||||
{
|
||||
"tickers": self.config.get("tickers", []),
|
||||
"initial_cash": self.config.get(
|
||||
"initial_cash",
|
||||
self.storage.initial_cash,
|
||||
),
|
||||
"margin_requirement": self.config.get("margin_requirement"),
|
||||
"max_comm_cycles": self.config.get("max_comm_cycles"),
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
)
|
||||
self.state_sync.update_state(
|
||||
"data_sources",
|
||||
self._provider_router.get_usage_snapshot(),
|
||||
@@ -159,7 +182,7 @@ class Gateway:
|
||||
def state(self) -> Dict[str, Any]:
|
||||
return self.state_sync.state
|
||||
|
||||
async def handle_client(self, websocket: WebSocketServerProtocol):
|
||||
async def handle_client(self, websocket: ServerConnection):
|
||||
"""Handle WebSocket client connection"""
|
||||
async with self.lock:
|
||||
self.connected_clients.add(websocket)
|
||||
@@ -170,7 +193,7 @@ class Gateway:
|
||||
async with self.lock:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
async def _send_initial_state(self, websocket: WebSocketServerProtocol):
|
||||
async def _send_initial_state(self, websocket: ServerConnection):
|
||||
state_payload = self.state_sync.get_initial_state_payload(
|
||||
include_dashboard=True,
|
||||
)
|
||||
@@ -198,7 +221,7 @@ class Gateway:
|
||||
|
||||
async def _handle_client_messages(
|
||||
self,
|
||||
websocket: WebSocketServerProtocol,
|
||||
websocket: ServerConnection,
|
||||
):
|
||||
try:
|
||||
async for message in websocket:
|
||||
@@ -221,12 +244,104 @@ class Gateway:
|
||||
await self._handle_start_backtest(data)
|
||||
elif msg_type == "reload_runtime_assets":
|
||||
await self._handle_reload_runtime_assets()
|
||||
elif msg_type == "update_watchlist":
|
||||
await self._handle_update_watchlist(websocket, data)
|
||||
elif msg_type == "get_stock_history":
|
||||
await self._handle_get_stock_history(websocket, data)
|
||||
elif msg_type == "get_stock_explain_events":
|
||||
await self._handle_get_stock_explain_events(websocket, data)
|
||||
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def _handle_get_stock_history(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
):
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_history_loaded",
|
||||
"ticker": "",
|
||||
"prices": [],
|
||||
"source": None,
|
||||
"error": "invalid ticker",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
lookback_days = data.get("lookback_days", 90)
|
||||
try:
|
||||
lookback_days = max(7, min(int(lookback_days), 365))
|
||||
except (TypeError, ValueError):
|
||||
lookback_days = 90
|
||||
|
||||
end_date = self.state_sync.state.get("current_date")
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
end_dt = datetime.now()
|
||||
end_date = end_dt.strftime("%Y-%m-%d")
|
||||
|
||||
start_date = (end_dt - timedelta(days=lookback_days)).strftime(
|
||||
"%Y-%m-%d",
|
||||
)
|
||||
|
||||
prices = await asyncio.to_thread(
|
||||
get_prices,
|
||||
ticker,
|
||||
start_date,
|
||||
end_date,
|
||||
)
|
||||
usage_snapshot = self._provider_router.get_usage_snapshot()
|
||||
source = usage_snapshot.get("last_success", {}).get("prices")
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_history_loaded",
|
||||
"ticker": ticker,
|
||||
"prices": [price.model_dump() for price in prices][-120:],
|
||||
"source": source,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_get_stock_explain_events(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
):
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
snapshot = self.storage.runtime_db.get_stock_explain_snapshot(ticker)
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_explain_events_loaded",
|
||||
"ticker": ticker,
|
||||
"events": snapshot.get("events", []),
|
||||
"signals": snapshot.get("signals", []),
|
||||
"trades": snapshot.get("trades", []),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
||||
if not self.is_backtest:
|
||||
return
|
||||
@@ -239,8 +354,15 @@ class Gateway:
|
||||
self._backtest_task = task
|
||||
|
||||
async def _handle_reload_runtime_assets(self):
|
||||
"""Reload prompt assets and active skills without restarting the server."""
|
||||
result = self.pipeline.reload_runtime_assets()
|
||||
"""Reload prompt, skills, and safe runtime config without restart."""
|
||||
config_name = self.config.get("config_name", "default")
|
||||
runtime_config = resolve_runtime_config(
|
||||
project_root=self._project_root,
|
||||
config_name=config_name,
|
||||
enable_memory=self.config.get("enable_memory", False),
|
||||
)
|
||||
result = self.pipeline.reload_runtime_assets(runtime_config=runtime_config)
|
||||
runtime_updates = self._apply_runtime_config(runtime_config)
|
||||
await self.state_sync.on_system_message(
|
||||
"Runtime assets reloaded.",
|
||||
)
|
||||
@@ -248,9 +370,174 @@ class Gateway:
|
||||
{
|
||||
"type": "runtime_assets_reloaded",
|
||||
**result,
|
||||
**runtime_updates,
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_update_watchlist(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Persist a new watchlist to BOOTSTRAP.md and hot-reload it."""
|
||||
tickers = self._normalize_watchlist(data.get("tickers"))
|
||||
if not tickers:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "update_watchlist requires at least one valid ticker.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
config_name = self.config.get("config_name", "default")
|
||||
update_bootstrap_values_for_run(
|
||||
project_root=self._project_root,
|
||||
config_name=config_name,
|
||||
updates={"tickers": tickers},
|
||||
)
|
||||
await self.state_sync.on_system_message(
|
||||
f"Watchlist updated: {', '.join(tickers)}",
|
||||
)
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "watchlist_updated",
|
||||
"config_name": config_name,
|
||||
"tickers": tickers,
|
||||
},
|
||||
)
|
||||
await self._handle_reload_runtime_assets()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
|
||||
"""Parse watchlist payloads from websocket messages."""
|
||||
if raw_tickers is None:
|
||||
return []
|
||||
|
||||
if isinstance(raw_tickers, str):
|
||||
candidates = raw_tickers.split(",")
|
||||
elif isinstance(raw_tickers, list):
|
||||
candidates = raw_tickers
|
||||
else:
|
||||
candidates = [raw_tickers]
|
||||
|
||||
tickers: List[str] = []
|
||||
for candidate in candidates:
|
||||
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
|
||||
if symbol and symbol not in tickers:
|
||||
tickers.append(symbol)
|
||||
return tickers
|
||||
|
||||
def _apply_runtime_config(
|
||||
self,
|
||||
runtime_config: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply runtime config to gateway-owned services and state."""
|
||||
warnings: List[str] = []
|
||||
|
||||
ticker_changes = self.market_service.update_tickers(
|
||||
runtime_config.get("tickers", []),
|
||||
)
|
||||
self.config["tickers"] = ticker_changes["active"]
|
||||
|
||||
self.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
|
||||
self.config["max_comm_cycles"] = self.pipeline.max_comm_cycles
|
||||
|
||||
pm_apply_result = self.pipeline.pm.apply_runtime_portfolio_config(
|
||||
margin_requirement=runtime_config["margin_requirement"],
|
||||
)
|
||||
self.config["margin_requirement"] = self.pipeline.pm.portfolio.get(
|
||||
"margin_requirement",
|
||||
runtime_config["margin_requirement"],
|
||||
)
|
||||
|
||||
requested_initial_cash = float(runtime_config["initial_cash"])
|
||||
current_initial_cash = float(self.storage.initial_cash)
|
||||
initial_cash_applied = requested_initial_cash == current_initial_cash
|
||||
if not initial_cash_applied:
|
||||
if (
|
||||
self.storage.can_apply_initial_cash()
|
||||
and self.pipeline.pm.can_apply_initial_cash()
|
||||
):
|
||||
initial_cash_applied = self.storage.apply_initial_cash(
|
||||
requested_initial_cash,
|
||||
)
|
||||
if initial_cash_applied:
|
||||
self.pipeline.pm.apply_runtime_portfolio_config(
|
||||
initial_cash=requested_initial_cash,
|
||||
)
|
||||
self.config["initial_cash"] = self.storage.initial_cash
|
||||
else:
|
||||
warnings.append(
|
||||
"initial_cash changed in BOOTSTRAP.md but was not applied "
|
||||
"because the run already has positions, margin usage, or trades.",
|
||||
)
|
||||
|
||||
requested_enable_memory = bool(runtime_config["enable_memory"])
|
||||
current_enable_memory = bool(self.config.get("enable_memory", False))
|
||||
if requested_enable_memory != current_enable_memory:
|
||||
warnings.append(
|
||||
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
|
||||
"because long-term memory contexts are created at startup.",
|
||||
)
|
||||
|
||||
self._sync_runtime_state()
|
||||
|
||||
return {
|
||||
"runtime_config_requested": runtime_config,
|
||||
"runtime_config_applied": {
|
||||
"tickers": list(self.config.get("tickers", [])),
|
||||
"initial_cash": self.storage.initial_cash,
|
||||
"margin_requirement": self.config["margin_requirement"],
|
||||
"max_comm_cycles": self.config["max_comm_cycles"],
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
"runtime_config_status": {
|
||||
"tickers": True,
|
||||
"initial_cash": initial_cash_applied,
|
||||
"margin_requirement": pm_apply_result["margin_requirement"],
|
||||
"max_comm_cycles": True,
|
||||
"enable_memory": requested_enable_memory == current_enable_memory,
|
||||
},
|
||||
"ticker_changes": ticker_changes,
|
||||
"runtime_config_warnings": warnings,
|
||||
}
|
||||
|
||||
def _sync_runtime_state(self) -> None:
|
||||
"""Refresh persisted state and dashboard after runtime config changes."""
|
||||
self.state_sync.update_state("tickers", self.config.get("tickers", []))
|
||||
self.state_sync.update_state(
|
||||
"runtime_config",
|
||||
{
|
||||
"tickers": self.config.get("tickers", []),
|
||||
"initial_cash": self.storage.initial_cash,
|
||||
"margin_requirement": self.config.get("margin_requirement"),
|
||||
"max_comm_cycles": self.config.get("max_comm_cycles"),
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
)
|
||||
|
||||
self.storage.update_server_state_from_dashboard(self.state_sync.state)
|
||||
self.state_sync.save_state()
|
||||
|
||||
self._dashboard.tickers = list(self.config.get("tickers", []))
|
||||
self._dashboard.initial_cash = self.storage.initial_cash
|
||||
self._dashboard.enable_memory = bool(
|
||||
self.config.get("enable_memory", False),
|
||||
)
|
||||
|
||||
summary = self.storage.load_file("summary") or {}
|
||||
holdings = self.storage.load_file("holdings") or []
|
||||
trades = self.storage.load_file("trades") or []
|
||||
self._dashboard.update(
|
||||
portfolio=summary,
|
||||
holdings=holdings,
|
||||
trades=trades,
|
||||
)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients"""
|
||||
if not self.connected_clients:
|
||||
@@ -269,7 +556,7 @@ class Gateway:
|
||||
|
||||
async def _send_to_client(
|
||||
self,
|
||||
client: WebSocketServerProtocol,
|
||||
client: ServerConnection,
|
||||
message: str,
|
||||
):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user