确认PokieTicker新闻库数据源
This commit is contained in:
@@ -5,12 +5,18 @@ WebSocket Gateway for frontend communication
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.asyncio.server import ServerConnection
|
||||
|
||||
from backend.config.bootstrap_config import (
|
||||
resolve_runtime_config,
|
||||
update_bootstrap_values_for_run,
|
||||
)
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
from backend.utils.terminal_dashboard import get_dashboard
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
@@ -18,6 +24,7 @@ from backend.core.state_sync import StateSync
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.data.provider_router import get_provider_router
|
||||
from backend.tools.data_tools import get_prices
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +58,7 @@ class Gateway:
|
||||
self.state_sync.set_broadcast_fn(self.broadcast)
|
||||
self.pipeline.state_sync = self.state_sync
|
||||
|
||||
self.connected_clients: Set[WebSocketServerProtocol] = set()
|
||||
self.connected_clients: Set[ServerConnection] = set()
|
||||
self.lock = asyncio.Lock()
|
||||
self._backtest_task: Optional[asyncio.Task] = None
|
||||
self._backtest_start_date: Optional[str] = None
|
||||
@@ -63,6 +70,7 @@ class Gateway:
|
||||
self._session_start_portfolio_value: Optional[float] = None
|
||||
self._provider_router = get_provider_router()
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._project_root = Path(__file__).resolve().parents[2]
|
||||
|
||||
async def start(self, host: str = "0.0.0.0", port: int = 8766):
|
||||
"""Start gateway server"""
|
||||
@@ -87,6 +95,7 @@ class Gateway:
|
||||
self._dashboard.start()
|
||||
|
||||
self.state_sync.load_state()
|
||||
self.market_service.set_price_recorder(self.storage.record_price_point)
|
||||
self.state_sync.update_state("status", "running")
|
||||
self.state_sync.update_state("server_mode", self.mode)
|
||||
self.state_sync.update_state("is_backtest", self.is_backtest)
|
||||
@@ -94,6 +103,20 @@ class Gateway:
|
||||
"is_mock_mode",
|
||||
self.config.get("mock_mode", False),
|
||||
)
|
||||
self.state_sync.update_state("tickers", self.config.get("tickers", []))
|
||||
self.state_sync.update_state(
|
||||
"runtime_config",
|
||||
{
|
||||
"tickers": self.config.get("tickers", []),
|
||||
"initial_cash": self.config.get(
|
||||
"initial_cash",
|
||||
self.storage.initial_cash,
|
||||
),
|
||||
"margin_requirement": self.config.get("margin_requirement"),
|
||||
"max_comm_cycles": self.config.get("max_comm_cycles"),
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
)
|
||||
self.state_sync.update_state(
|
||||
"data_sources",
|
||||
self._provider_router.get_usage_snapshot(),
|
||||
@@ -159,7 +182,7 @@ class Gateway:
|
||||
def state(self) -> Dict[str, Any]:
|
||||
return self.state_sync.state
|
||||
|
||||
async def handle_client(self, websocket: WebSocketServerProtocol):
|
||||
async def handle_client(self, websocket: ServerConnection):
|
||||
"""Handle WebSocket client connection"""
|
||||
async with self.lock:
|
||||
self.connected_clients.add(websocket)
|
||||
@@ -170,7 +193,7 @@ class Gateway:
|
||||
async with self.lock:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
async def _send_initial_state(self, websocket: WebSocketServerProtocol):
|
||||
async def _send_initial_state(self, websocket: ServerConnection):
|
||||
state_payload = self.state_sync.get_initial_state_payload(
|
||||
include_dashboard=True,
|
||||
)
|
||||
@@ -198,7 +221,7 @@ class Gateway:
|
||||
|
||||
async def _handle_client_messages(
|
||||
self,
|
||||
websocket: WebSocketServerProtocol,
|
||||
websocket: ServerConnection,
|
||||
):
|
||||
try:
|
||||
async for message in websocket:
|
||||
@@ -221,12 +244,104 @@ class Gateway:
|
||||
await self._handle_start_backtest(data)
|
||||
elif msg_type == "reload_runtime_assets":
|
||||
await self._handle_reload_runtime_assets()
|
||||
elif msg_type == "update_watchlist":
|
||||
await self._handle_update_watchlist(websocket, data)
|
||||
elif msg_type == "get_stock_history":
|
||||
await self._handle_get_stock_history(websocket, data)
|
||||
elif msg_type == "get_stock_explain_events":
|
||||
await self._handle_get_stock_explain_events(websocket, data)
|
||||
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def _handle_get_stock_history(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
):
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_history_loaded",
|
||||
"ticker": "",
|
||||
"prices": [],
|
||||
"source": None,
|
||||
"error": "invalid ticker",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
lookback_days = data.get("lookback_days", 90)
|
||||
try:
|
||||
lookback_days = max(7, min(int(lookback_days), 365))
|
||||
except (TypeError, ValueError):
|
||||
lookback_days = 90
|
||||
|
||||
end_date = self.state_sync.state.get("current_date")
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
end_dt = datetime.now()
|
||||
end_date = end_dt.strftime("%Y-%m-%d")
|
||||
|
||||
start_date = (end_dt - timedelta(days=lookback_days)).strftime(
|
||||
"%Y-%m-%d",
|
||||
)
|
||||
|
||||
prices = await asyncio.to_thread(
|
||||
get_prices,
|
||||
ticker,
|
||||
start_date,
|
||||
end_date,
|
||||
)
|
||||
usage_snapshot = self._provider_router.get_usage_snapshot()
|
||||
source = usage_snapshot.get("last_success", {}).get("prices")
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_history_loaded",
|
||||
"ticker": ticker,
|
||||
"prices": [price.model_dump() for price in prices][-120:],
|
||||
"source": source,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_get_stock_explain_events(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
):
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
snapshot = self.storage.runtime_db.get_stock_explain_snapshot(ticker)
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_explain_events_loaded",
|
||||
"ticker": ticker,
|
||||
"events": snapshot.get("events", []),
|
||||
"signals": snapshot.get("signals", []),
|
||||
"trades": snapshot.get("trades", []),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
||||
if not self.is_backtest:
|
||||
return
|
||||
@@ -239,8 +354,15 @@ class Gateway:
|
||||
self._backtest_task = task
|
||||
|
||||
async def _handle_reload_runtime_assets(self):
|
||||
"""Reload prompt assets and active skills without restarting the server."""
|
||||
result = self.pipeline.reload_runtime_assets()
|
||||
"""Reload prompt, skills, and safe runtime config without restart."""
|
||||
config_name = self.config.get("config_name", "default")
|
||||
runtime_config = resolve_runtime_config(
|
||||
project_root=self._project_root,
|
||||
config_name=config_name,
|
||||
enable_memory=self.config.get("enable_memory", False),
|
||||
)
|
||||
result = self.pipeline.reload_runtime_assets(runtime_config=runtime_config)
|
||||
runtime_updates = self._apply_runtime_config(runtime_config)
|
||||
await self.state_sync.on_system_message(
|
||||
"Runtime assets reloaded.",
|
||||
)
|
||||
@@ -248,9 +370,174 @@ class Gateway:
|
||||
{
|
||||
"type": "runtime_assets_reloaded",
|
||||
**result,
|
||||
**runtime_updates,
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_update_watchlist(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Persist a new watchlist to BOOTSTRAP.md and hot-reload it."""
|
||||
tickers = self._normalize_watchlist(data.get("tickers"))
|
||||
if not tickers:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "update_watchlist requires at least one valid ticker.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
config_name = self.config.get("config_name", "default")
|
||||
update_bootstrap_values_for_run(
|
||||
project_root=self._project_root,
|
||||
config_name=config_name,
|
||||
updates={"tickers": tickers},
|
||||
)
|
||||
await self.state_sync.on_system_message(
|
||||
f"Watchlist updated: {', '.join(tickers)}",
|
||||
)
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "watchlist_updated",
|
||||
"config_name": config_name,
|
||||
"tickers": tickers,
|
||||
},
|
||||
)
|
||||
await self._handle_reload_runtime_assets()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
|
||||
"""Parse watchlist payloads from websocket messages."""
|
||||
if raw_tickers is None:
|
||||
return []
|
||||
|
||||
if isinstance(raw_tickers, str):
|
||||
candidates = raw_tickers.split(",")
|
||||
elif isinstance(raw_tickers, list):
|
||||
candidates = raw_tickers
|
||||
else:
|
||||
candidates = [raw_tickers]
|
||||
|
||||
tickers: List[str] = []
|
||||
for candidate in candidates:
|
||||
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
|
||||
if symbol and symbol not in tickers:
|
||||
tickers.append(symbol)
|
||||
return tickers
|
||||
|
||||
def _apply_runtime_config(
|
||||
self,
|
||||
runtime_config: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply runtime config to gateway-owned services and state."""
|
||||
warnings: List[str] = []
|
||||
|
||||
ticker_changes = self.market_service.update_tickers(
|
||||
runtime_config.get("tickers", []),
|
||||
)
|
||||
self.config["tickers"] = ticker_changes["active"]
|
||||
|
||||
self.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
|
||||
self.config["max_comm_cycles"] = self.pipeline.max_comm_cycles
|
||||
|
||||
pm_apply_result = self.pipeline.pm.apply_runtime_portfolio_config(
|
||||
margin_requirement=runtime_config["margin_requirement"],
|
||||
)
|
||||
self.config["margin_requirement"] = self.pipeline.pm.portfolio.get(
|
||||
"margin_requirement",
|
||||
runtime_config["margin_requirement"],
|
||||
)
|
||||
|
||||
requested_initial_cash = float(runtime_config["initial_cash"])
|
||||
current_initial_cash = float(self.storage.initial_cash)
|
||||
initial_cash_applied = requested_initial_cash == current_initial_cash
|
||||
if not initial_cash_applied:
|
||||
if (
|
||||
self.storage.can_apply_initial_cash()
|
||||
and self.pipeline.pm.can_apply_initial_cash()
|
||||
):
|
||||
initial_cash_applied = self.storage.apply_initial_cash(
|
||||
requested_initial_cash,
|
||||
)
|
||||
if initial_cash_applied:
|
||||
self.pipeline.pm.apply_runtime_portfolio_config(
|
||||
initial_cash=requested_initial_cash,
|
||||
)
|
||||
self.config["initial_cash"] = self.storage.initial_cash
|
||||
else:
|
||||
warnings.append(
|
||||
"initial_cash changed in BOOTSTRAP.md but was not applied "
|
||||
"because the run already has positions, margin usage, or trades.",
|
||||
)
|
||||
|
||||
requested_enable_memory = bool(runtime_config["enable_memory"])
|
||||
current_enable_memory = bool(self.config.get("enable_memory", False))
|
||||
if requested_enable_memory != current_enable_memory:
|
||||
warnings.append(
|
||||
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
|
||||
"because long-term memory contexts are created at startup.",
|
||||
)
|
||||
|
||||
self._sync_runtime_state()
|
||||
|
||||
return {
|
||||
"runtime_config_requested": runtime_config,
|
||||
"runtime_config_applied": {
|
||||
"tickers": list(self.config.get("tickers", [])),
|
||||
"initial_cash": self.storage.initial_cash,
|
||||
"margin_requirement": self.config["margin_requirement"],
|
||||
"max_comm_cycles": self.config["max_comm_cycles"],
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
"runtime_config_status": {
|
||||
"tickers": True,
|
||||
"initial_cash": initial_cash_applied,
|
||||
"margin_requirement": pm_apply_result["margin_requirement"],
|
||||
"max_comm_cycles": True,
|
||||
"enable_memory": requested_enable_memory == current_enable_memory,
|
||||
},
|
||||
"ticker_changes": ticker_changes,
|
||||
"runtime_config_warnings": warnings,
|
||||
}
|
||||
|
||||
def _sync_runtime_state(self) -> None:
|
||||
"""Refresh persisted state and dashboard after runtime config changes."""
|
||||
self.state_sync.update_state("tickers", self.config.get("tickers", []))
|
||||
self.state_sync.update_state(
|
||||
"runtime_config",
|
||||
{
|
||||
"tickers": self.config.get("tickers", []),
|
||||
"initial_cash": self.storage.initial_cash,
|
||||
"margin_requirement": self.config.get("margin_requirement"),
|
||||
"max_comm_cycles": self.config.get("max_comm_cycles"),
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
)
|
||||
|
||||
self.storage.update_server_state_from_dashboard(self.state_sync.state)
|
||||
self.state_sync.save_state()
|
||||
|
||||
self._dashboard.tickers = list(self.config.get("tickers", []))
|
||||
self._dashboard.initial_cash = self.storage.initial_cash
|
||||
self._dashboard.enable_memory = bool(
|
||||
self.config.get("enable_memory", False),
|
||||
)
|
||||
|
||||
summary = self.storage.load_file("summary") or {}
|
||||
holdings = self.storage.load_file("holdings") or []
|
||||
trades = self.storage.load_file("trades") or []
|
||||
self._dashboard.update(
|
||||
portfolio=summary,
|
||||
holdings=holdings,
|
||||
trades=trades,
|
||||
)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients"""
|
||||
if not self.connected_clients:
|
||||
@@ -269,7 +556,7 @@ class Gateway:
|
||||
|
||||
async def _send_to_client(
|
||||
self,
|
||||
client: WebSocketServerProtocol,
|
||||
client: ServerConnection,
|
||||
message: str,
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -54,6 +54,7 @@ class MarketService:
|
||||
self.running = False
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._broadcast_func: Optional[Callable] = None
|
||||
self._price_record_func: Optional[Callable[..., None]] = None
|
||||
self._price_manager: Optional[Any] = None
|
||||
self._current_date: Optional[str] = None
|
||||
|
||||
@@ -92,6 +93,10 @@ class MarketService:
|
||||
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
|
||||
)
|
||||
|
||||
def set_price_recorder(self, recorder: Optional[Callable[..., None]]):
|
||||
"""Register an optional callback for persisting runtime price points."""
|
||||
self._price_record_func = recorder
|
||||
|
||||
def _make_price_callback(self) -> Callable:
|
||||
"""Create thread-safe price callback"""
|
||||
|
||||
@@ -169,6 +174,24 @@ class MarketService:
|
||||
((price - open_price) / open_price) * 100 if open_price > 0 else 0
|
||||
)
|
||||
|
||||
if self._price_record_func:
|
||||
try:
|
||||
self._price_record_func(
|
||||
ticker=symbol,
|
||||
timestamp=str(price_data.get("timestamp") or datetime.now().isoformat()),
|
||||
price=float(price),
|
||||
open_price=float(open_price) if open_price is not None else None,
|
||||
ret=float(ret),
|
||||
source=self.mode_name.lower(),
|
||||
meta=price_data,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to record price point for %s: %s",
|
||||
symbol,
|
||||
exc,
|
||||
)
|
||||
|
||||
await self._broadcast_func(
|
||||
{
|
||||
"type": "price_update",
|
||||
@@ -205,6 +228,43 @@ class MarketService:
|
||||
self._loop = None
|
||||
self._broadcast_func = None
|
||||
|
||||
def update_tickers(self, tickers: List[str]) -> Dict[str, List[str]]:
|
||||
"""Hot-update subscribed tickers without restarting the service."""
|
||||
normalized: List[str] = []
|
||||
for ticker in tickers:
|
||||
symbol = normalize_symbol(ticker)
|
||||
if symbol and symbol not in normalized:
|
||||
normalized.append(symbol)
|
||||
|
||||
previous = list(self.tickers)
|
||||
removed = [ticker for ticker in previous if ticker not in normalized]
|
||||
added = [ticker for ticker in normalized if ticker not in previous]
|
||||
self.tickers = normalized
|
||||
|
||||
if self._price_manager:
|
||||
if removed:
|
||||
self._price_manager.unsubscribe(removed)
|
||||
if added:
|
||||
if self.mock_mode:
|
||||
self._price_manager.subscribe(
|
||||
added,
|
||||
base_prices={ticker: 100.0 for ticker in added},
|
||||
)
|
||||
else:
|
||||
self._price_manager.subscribe(added)
|
||||
|
||||
if self.backtest_mode and self._current_date:
|
||||
self._price_manager.set_date(self._current_date)
|
||||
|
||||
for ticker in removed:
|
||||
self.cache.pop(ticker, None)
|
||||
|
||||
return {
|
||||
"added": added,
|
||||
"removed": removed,
|
||||
"active": list(self.tickers),
|
||||
}
|
||||
|
||||
# Backtest methods
|
||||
def set_backtest_date(self, date: str):
|
||||
"""Set current backtest date"""
|
||||
|
||||
388
backend/services/runtime_db.py
Normal file
388
backend/services/runtime_db.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Run-scoped SQLite storage for query-oriented runtime history."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
|
||||
SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id TEXT PRIMARY KEY,
|
||||
event_type TEXT NOT NULL,
|
||||
timestamp TEXT,
|
||||
agent_id TEXT,
|
||||
agent_name TEXT,
|
||||
ticker TEXT,
|
||||
title TEXT,
|
||||
content TEXT,
|
||||
payload_json TEXT NOT NULL,
|
||||
run_date TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_type_time ON events(event_type, timestamp DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_ticker_time ON events(ticker, timestamp DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
side TEXT,
|
||||
qty REAL,
|
||||
price REAL,
|
||||
timestamp TEXT,
|
||||
trading_date TEXT,
|
||||
agent_id TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_trades_ticker_time ON trades(ticker, timestamp DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS signals (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
agent_id TEXT,
|
||||
agent_name TEXT,
|
||||
role TEXT,
|
||||
signal TEXT,
|
||||
confidence REAL,
|
||||
reasoning_json TEXT,
|
||||
real_return REAL,
|
||||
is_correct TEXT,
|
||||
trade_date TEXT,
|
||||
created_at TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_signals_ticker_date ON signals(ticker, trade_date DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_signals_agent_date ON signals(agent_id, trade_date DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS price_points (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
open_price REAL,
|
||||
ret REAL,
|
||||
source TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_price_points_ticker_time ON price_points(ticker, timestamp DESC);
|
||||
"""
|
||||
|
||||
|
||||
def _json_dumps(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _hash_key(*parts: Any) -> str:
|
||||
raw = "::".join("" if part is None else str(part) for part in parts)
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
class RuntimeDb:
|
||||
"""Small SQLite helper for append-mostly runtime data."""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
return conn
|
||||
|
||||
def _init_db(self):
|
||||
with self._connect() as conn:
|
||||
conn.executescript(SCHEMA)
|
||||
|
||||
def insert_event(self, event: Dict[str, Any]):
|
||||
payload = dict(event or {})
|
||||
if not payload:
|
||||
return
|
||||
|
||||
event_id = payload.get("id") or _hash_key(
|
||||
payload.get("type"),
|
||||
payload.get("timestamp"),
|
||||
payload.get("agentId") or payload.get("agent_id"),
|
||||
payload.get("content"),
|
||||
payload.get("title"),
|
||||
)
|
||||
ticker = payload.get("ticker")
|
||||
if not ticker and isinstance(payload.get("tickers"), list) and len(payload["tickers"]) == 1:
|
||||
ticker = payload["tickers"][0]
|
||||
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO events
|
||||
(id, event_type, timestamp, agent_id, agent_name, ticker, title, content, payload_json, run_date)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
event_id,
|
||||
payload.get("type"),
|
||||
payload.get("timestamp"),
|
||||
payload.get("agentId") or payload.get("agent_id"),
|
||||
payload.get("agentName") or payload.get("agent_name"),
|
||||
ticker,
|
||||
payload.get("title"),
|
||||
payload.get("content"),
|
||||
_json_dumps(payload),
|
||||
payload.get("date") or payload.get("trading_date") or payload.get("run_date"),
|
||||
),
|
||||
)
|
||||
|
||||
def upsert_trade(self, trade: Dict[str, Any]):
|
||||
payload = dict(trade or {})
|
||||
if not payload:
|
||||
return
|
||||
|
||||
trade_id = payload.get("id") or _hash_key(
|
||||
payload.get("ticker"),
|
||||
payload.get("timestamp") or payload.get("ts"),
|
||||
payload.get("side"),
|
||||
payload.get("qty"),
|
||||
payload.get("price"),
|
||||
)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO trades
|
||||
(id, ticker, side, qty, price, timestamp, trading_date, agent_id, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
trade_id,
|
||||
payload.get("ticker"),
|
||||
payload.get("side"),
|
||||
payload.get("qty"),
|
||||
payload.get("price"),
|
||||
payload.get("timestamp") or payload.get("ts"),
|
||||
payload.get("trading_date"),
|
||||
payload.get("agentId") or payload.get("agent_id"),
|
||||
_json_dumps(payload),
|
||||
),
|
||||
)
|
||||
|
||||
def upsert_signal(self, signal: Dict[str, Any], *, agent_id: str, agent_name: str, role: str):
|
||||
payload = dict(signal or {})
|
||||
ticker = payload.get("ticker")
|
||||
if not ticker:
|
||||
return
|
||||
|
||||
signal_id = _hash_key(
|
||||
agent_id,
|
||||
ticker,
|
||||
payload.get("date"),
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO signals
|
||||
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
|
||||
real_return, is_correct, trade_date, created_at, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
signal_id,
|
||||
ticker,
|
||||
agent_id,
|
||||
agent_name,
|
||||
role,
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
_json_dumps(payload.get("reasoning")),
|
||||
payload.get("real_return"),
|
||||
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
|
||||
payload.get("date"),
|
||||
payload.get("created_at") or payload.get("date"),
|
||||
_json_dumps(payload),
|
||||
),
|
||||
)
|
||||
|
||||
def replace_signals_for_leaderboard(self, leaderboard: Iterable[Dict[str, Any]]):
|
||||
with self._connect() as conn:
|
||||
conn.execute("DELETE FROM signals")
|
||||
for agent in leaderboard:
|
||||
agent_id = agent.get("agentId")
|
||||
agent_name = agent.get("name")
|
||||
role = agent.get("role")
|
||||
for signal in agent.get("signals", []) or []:
|
||||
payload = dict(signal or {})
|
||||
ticker = payload.get("ticker")
|
||||
if not ticker:
|
||||
continue
|
||||
signal_id = _hash_key(
|
||||
agent_id,
|
||||
ticker,
|
||||
payload.get("date"),
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO signals
|
||||
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
|
||||
real_return, is_correct, trade_date, created_at, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
signal_id,
|
||||
ticker,
|
||||
agent_id,
|
||||
agent_name,
|
||||
role,
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
_json_dumps(payload.get("reasoning")),
|
||||
payload.get("real_return"),
|
||||
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
|
||||
payload.get("date"),
|
||||
payload.get("created_at") or payload.get("date"),
|
||||
_json_dumps(payload),
|
||||
),
|
||||
)
|
||||
|
||||
def insert_price_point(
|
||||
self,
|
||||
*,
|
||||
ticker: str,
|
||||
timestamp: str,
|
||||
price: float,
|
||||
open_price: Optional[float] = None,
|
||||
ret: Optional[float] = None,
|
||||
source: Optional[str] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
price_id = _hash_key(ticker, timestamp, price, open_price, ret)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO price_points
|
||||
(id, ticker, timestamp, price, open_price, ret, source, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
price_id,
|
||||
ticker,
|
||||
timestamp,
|
||||
price,
|
||||
open_price,
|
||||
ret,
|
||||
source,
|
||||
_json_dumps(meta or {}),
|
||||
),
|
||||
)
|
||||
|
||||
def get_stock_explain_snapshot(
|
||||
self,
|
||||
ticker: str,
|
||||
*,
|
||||
limit_events: int = 24,
|
||||
limit_trades: int = 12,
|
||||
limit_signals: int = 12,
|
||||
) -> Dict[str, list[Dict[str, Any]]]:
|
||||
"""Fetch query-oriented history for a single ticker."""
|
||||
symbol = str(ticker or "").strip().upper()
|
||||
if not symbol:
|
||||
return {"events": [], "trades": [], "signals": []}
|
||||
|
||||
with self._connect() as conn:
|
||||
trade_rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM trades
|
||||
WHERE ticker = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(symbol, limit_trades),
|
||||
).fetchall()
|
||||
signal_rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM signals
|
||||
WHERE ticker = ?
|
||||
ORDER BY trade_date DESC, created_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(symbol, limit_signals),
|
||||
).fetchall()
|
||||
event_rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM events
|
||||
WHERE payload_json LIKE ? OR content LIKE ? OR title LIKE ? OR ticker = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(f"%{symbol}%", f"%{symbol}%", f"%{symbol}%", symbol, limit_events * 3),
|
||||
).fetchall()
|
||||
|
||||
normalized_events = []
|
||||
seen_event_ids: set[str] = set()
|
||||
for row in event_rows:
|
||||
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
|
||||
content = str(row["content"] or payload.get("content") or "")
|
||||
title = str(row["title"] or payload.get("title") or "")
|
||||
if symbol not in f"{title} {content}".upper() and str(row["ticker"] or "").upper() != symbol:
|
||||
continue
|
||||
event_id = row["id"]
|
||||
if event_id in seen_event_ids:
|
||||
continue
|
||||
seen_event_ids.add(event_id)
|
||||
normalized_events.append(
|
||||
{
|
||||
"id": event_id,
|
||||
"type": "mention",
|
||||
"timestamp": row["timestamp"],
|
||||
"title": title or f"{row['agent_name'] or '未知角色'}提及 {symbol}",
|
||||
"meta": payload.get("conferenceTitle")
|
||||
or payload.get("feedType")
|
||||
or row["event_type"],
|
||||
"body": content,
|
||||
"tone": "neutral",
|
||||
"agent": row["agent_name"] or payload.get("agentName") or payload.get("agent"),
|
||||
},
|
||||
)
|
||||
if len(normalized_events) >= limit_events:
|
||||
break
|
||||
|
||||
normalized_trades = [
|
||||
{
|
||||
"id": row["id"],
|
||||
"type": "trade",
|
||||
"timestamp": row["timestamp"],
|
||||
"title": f"{row['side']} {int(row['qty'] or 0)} 股",
|
||||
"meta": "交易执行",
|
||||
"body": f"成交价 ${float(row['price'] or 0):.2f}",
|
||||
"tone": "positive" if row["side"] == "LONG" else "negative" if row["side"] == "SHORT" else "neutral",
|
||||
}
|
||||
for row in trade_rows
|
||||
]
|
||||
|
||||
normalized_signals = [
|
||||
{
|
||||
"id": row["id"],
|
||||
"type": "signal",
|
||||
"timestamp": f"{row['trade_date']}T08:00:00" if row["trade_date"] else row["created_at"],
|
||||
"title": f"{row['agent_name']} 给出{row['signal'] or '中性'}信号",
|
||||
"meta": row["role"],
|
||||
"body": (
|
||||
f"后验收益 {float(row['real_return']) * 100:+.2f}%"
|
||||
if row["real_return"] is not None
|
||||
else "该信号暂未完成后验评估"
|
||||
),
|
||||
"tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral",
|
||||
}
|
||||
for row in signal_rows
|
||||
]
|
||||
|
||||
return {
|
||||
"events": normalized_events,
|
||||
"trades": normalized_trades,
|
||||
"signals": normalized_signals,
|
||||
}
|
||||
@@ -10,6 +10,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .runtime_db import RuntimeDb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -61,6 +63,7 @@ class StorageService:
|
||||
self.state_dir = self.dashboard_dir.parent / "state"
|
||||
self.state_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.server_state_file = self.state_dir / "server_state.json"
|
||||
self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
|
||||
|
||||
# Feed history (for agent messages)
|
||||
self.max_feed_history = 200
|
||||
@@ -114,6 +117,11 @@ class StorageService:
|
||||
try:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
if file_type == "leaderboard" and isinstance(data, list):
|
||||
self.runtime_db.replace_signals_for_leaderboard(data)
|
||||
elif file_type == "trades" and isinstance(data, list):
|
||||
for trade in data:
|
||||
self.runtime_db.upsert_trade(trade)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save {file_type}.json: {e}")
|
||||
|
||||
@@ -211,6 +219,7 @@ class StorageService:
|
||||
try:
|
||||
with open(self.internal_state_file, "w", encoding="utf-8") as f:
|
||||
json.dump(state, f, indent=2, ensure_ascii=False)
|
||||
self._sync_price_history_to_db(state.get("price_history", {}))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save internal state: {e}")
|
||||
|
||||
@@ -231,6 +240,41 @@ class StorageService:
|
||||
"margin_requirement": 0.25, # Default 25% margin requirement
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _portfolio_is_pristine(portfolio_state: Dict[str, Any]) -> bool:
|
||||
"""Return whether the persisted portfolio can be safely rebased."""
|
||||
positions = portfolio_state.get("positions", {})
|
||||
has_positions = any(
|
||||
position.get("long", 0) or position.get("short", 0)
|
||||
for position in positions.values()
|
||||
)
|
||||
margin_used = float(portfolio_state.get("margin_used", 0.0) or 0.0)
|
||||
return not has_positions and margin_used == 0.0
|
||||
|
||||
def can_apply_initial_cash(self) -> bool:
    """Whether the run is still fresh enough to change initial cash.

    Rebasing is allowed only when the persisted portfolio is pristine,
    no trades were ever recorded, and the equity curve holds at most its
    seed point.
    """
    state = self.load_internal_state()
    pristine = self._portfolio_is_pristine(state.get("portfolio_state", {}))
    never_traded = not state.get("all_trades")
    fresh_equity = len(state.get("equity_history", [])) <= 1
    return pristine and never_traded and fresh_equity
|
||||
|
||||
def apply_initial_cash(self, initial_cash: float) -> bool:
    """Rebase storage state to a new initial cash amount.

    Returns ``False`` (and changes nothing) unless the run is still
    pristine; otherwise wipes the serialized internal state, rebuilds an
    empty dashboard, refreshes the server state from it, and returns
    ``True``.
    """
    if not self.can_apply_initial_cash():
        return False

    self.initial_cash = float(initial_cash)
    # Drop the old internal state file so it is regenerated from scratch
    # with the new cash basis.
    if self.internal_state_file.exists():
        self.internal_state_file.unlink()

    self.initialize_empty_dashboard()
    server_state = self.load_server_state()
    self.update_server_state_from_dashboard(server_state)
    self.save_server_state(server_state)
    return True
|
||||
|
||||
def save_portfolio_state(self, portfolio: Dict[str, Any]):
|
||||
"""
|
||||
Save portfolio state to internal state
|
||||
@@ -750,6 +794,7 @@ class StorageService:
|
||||
"last_day_history": [],
|
||||
"trading_days_total": 0,
|
||||
"trading_days_completed": 0,
|
||||
"price_history": {},
|
||||
}
|
||||
|
||||
if not self.server_state_file.exists():
|
||||
@@ -771,6 +816,11 @@ class StorageService:
|
||||
)
|
||||
logger.info(f"Trades: {len(saved_state.get('trades', []))} records")
|
||||
|
||||
for event in saved_state.get("feed_history", []):
|
||||
self.runtime_db.insert_event(event)
|
||||
for trade in saved_state.get("trades", []):
|
||||
self.runtime_db.upsert_trade(trade)
|
||||
|
||||
return saved_state
|
||||
|
||||
def save_server_state(self, state: Dict[str, Any]):
|
||||
@@ -852,6 +902,7 @@ class StorageService:
|
||||
state["feed_history"] = []
|
||||
|
||||
state["feed_history"].insert(0, feed_msg)
|
||||
self.runtime_db.insert_event(feed_msg)
|
||||
|
||||
# Trim to max size
|
||||
if len(state["feed_history"]) > self.max_feed_history:
|
||||
@@ -861,6 +912,69 @@ class StorageService:
|
||||
|
||||
return True
|
||||
|
||||
def record_price_point(
    self,
    *,
    ticker: str,
    timestamp: str,
    price: float,
    open_price: Optional[float] = None,
    ret: Optional[float] = None,
    source: Optional[str] = None,
    meta: Optional[Dict[str, Any]] = None,
):
    """Store one runtime price observation in the runtime database.

    Blank tickers/timestamps are silently skipped, and any database
    failure is downgraded to a warning so price capture can never
    interrupt the trading flow.
    """
    if not (ticker and timestamp):
        return
    try:
        self.runtime_db.insert_price_point(
            ticker=ticker,
            timestamp=timestamp,
            price=price,
            open_price=open_price,
            ret=ret,
            source=source,
            meta=meta,
        )
    except Exception as err:
        logger.warning("Failed to record price point for %s: %s", ticker, err)
|
||||
|
||||
def _sync_price_history_to_db(self, price_history: Dict[str, Any]):
    """Backfill structured price points from serialized internal state.

    ``price_history`` maps ticker -> list of points, where each point is
    either a ``(timestamp, price, ...)`` sequence or a dict carrying a
    timestamp under ``timestamp``/``label``/``date`` and a price under
    ``price``/``close``/``value``. Malformed entries are skipped rather
    than raising, since this is a best-effort backfill.
    """
    if not isinstance(price_history, dict):
        return
    for ticker, points in price_history.items():
        if not ticker or not isinstance(points, list):
            continue
        for point in points:
            if isinstance(point, (list, tuple)) and len(point) >= 2:
                timestamp, price = point[0], point[1]
                try:
                    self.record_price_point(
                        ticker=str(ticker),
                        timestamp=str(timestamp),
                        price=float(price),
                    )
                except (TypeError, ValueError):
                    continue
            elif isinstance(point, dict):
                timestamp = point.get("timestamp") or point.get("label") or point.get("date")
                # Bug fix: use an explicit None check when picking the price
                # so a legitimate 0.0 value is recorded instead of being
                # treated as missing (a plain `or` chain skips falsy prices
                # and wrongly falls through to the next key).
                price = next(
                    (
                        point[key]
                        for key in ("price", "close", "value")
                        if point.get(key) is not None
                    ),
                    None,
                )
                if not timestamp or price is None:
                    continue
                try:
                    self.record_price_point(
                        ticker=str(ticker),
                        timestamp=str(timestamp),
                        price=float(price),
                        open_price=point.get("open"),
                        ret=point.get("ret"),
                        source=point.get("source"),
                        meta=point,
                    )
                except (TypeError, ValueError):
                    continue
|
||||
|
||||
def _get_default_stats(self) -> Dict[str, Any]:
|
||||
"""Get default stats structure"""
|
||||
return {
|
||||
@@ -889,6 +1003,7 @@ class StorageService:
|
||||
stats = self.load_file("stats") or self._get_default_stats()
|
||||
trades = self.load_file("trades") or []
|
||||
leaderboard = self.load_file("leaderboard") or []
|
||||
internal_state = self.load_internal_state()
|
||||
|
||||
# Update state
|
||||
state["portfolio"] = {
|
||||
@@ -910,6 +1025,9 @@ class StorageService:
|
||||
state["stats"] = stats
|
||||
state["trades"] = trades
|
||||
state["leaderboard"] = leaderboard
|
||||
state["price_history"] = internal_state.get("price_history", {})
|
||||
self.runtime_db.replace_signals_for_leaderboard(leaderboard)
|
||||
self._sync_price_history_to_db(state["price_history"])
|
||||
|
||||
# ========== Live Returns Tracking ==========
|
||||
|
||||
|
||||
Reference in New Issue
Block a user