确认PokieTicker新闻库数据源

This commit is contained in:
2026-03-16 02:19:25 +08:00
parent 78f133617f
commit 564c92c0c8
182 changed files with 6436 additions and 1050 deletions

View File

@@ -5,12 +5,18 @@ WebSocket Gateway for frontend communication
import asyncio
import json
import logging
from datetime import datetime
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set
import websockets
from websockets.server import WebSocketServerProtocol
from websockets.asyncio.server import ServerConnection
from backend.config.bootstrap_config import (
resolve_runtime_config,
update_bootstrap_values_for_run,
)
from backend.data.provider_utils import normalize_symbol
from backend.utils.msg_adapter import FrontendAdapter
from backend.utils.terminal_dashboard import get_dashboard
from backend.core.pipeline import TradingPipeline
@@ -18,6 +24,7 @@ from backend.core.state_sync import StateSync
from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.data.provider_router import get_provider_router
from backend.tools.data_tools import get_prices
logger = logging.getLogger(__name__)
@@ -51,7 +58,7 @@ class Gateway:
self.state_sync.set_broadcast_fn(self.broadcast)
self.pipeline.state_sync = self.state_sync
self.connected_clients: Set[WebSocketServerProtocol] = set()
self.connected_clients: Set[ServerConnection] = set()
self.lock = asyncio.Lock()
self._backtest_task: Optional[asyncio.Task] = None
self._backtest_start_date: Optional[str] = None
@@ -63,6 +70,7 @@ class Gateway:
self._session_start_portfolio_value: Optional[float] = None
self._provider_router = get_provider_router()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._project_root = Path(__file__).resolve().parents[2]
async def start(self, host: str = "0.0.0.0", port: int = 8766):
"""Start gateway server"""
@@ -87,6 +95,7 @@ class Gateway:
self._dashboard.start()
self.state_sync.load_state()
self.market_service.set_price_recorder(self.storage.record_price_point)
self.state_sync.update_state("status", "running")
self.state_sync.update_state("server_mode", self.mode)
self.state_sync.update_state("is_backtest", self.is_backtest)
@@ -94,6 +103,20 @@ class Gateway:
"is_mock_mode",
self.config.get("mock_mode", False),
)
self.state_sync.update_state("tickers", self.config.get("tickers", []))
self.state_sync.update_state(
"runtime_config",
{
"tickers": self.config.get("tickers", []),
"initial_cash": self.config.get(
"initial_cash",
self.storage.initial_cash,
),
"margin_requirement": self.config.get("margin_requirement"),
"max_comm_cycles": self.config.get("max_comm_cycles"),
"enable_memory": self.config.get("enable_memory", False),
},
)
self.state_sync.update_state(
"data_sources",
self._provider_router.get_usage_snapshot(),
@@ -159,7 +182,7 @@ class Gateway:
def state(self) -> Dict[str, Any]:
return self.state_sync.state
async def handle_client(self, websocket: WebSocketServerProtocol):
async def handle_client(self, websocket: ServerConnection):
"""Handle WebSocket client connection"""
async with self.lock:
self.connected_clients.add(websocket)
@@ -170,7 +193,7 @@ class Gateway:
async with self.lock:
self.connected_clients.discard(websocket)
async def _send_initial_state(self, websocket: WebSocketServerProtocol):
async def _send_initial_state(self, websocket: ServerConnection):
state_payload = self.state_sync.get_initial_state_payload(
include_dashboard=True,
)
@@ -198,7 +221,7 @@ class Gateway:
async def _handle_client_messages(
self,
websocket: WebSocketServerProtocol,
websocket: ServerConnection,
):
try:
async for message in websocket:
@@ -221,12 +244,104 @@ class Gateway:
await self._handle_start_backtest(data)
elif msg_type == "reload_runtime_assets":
await self._handle_reload_runtime_assets()
elif msg_type == "update_watchlist":
await self._handle_update_watchlist(websocket, data)
elif msg_type == "get_stock_history":
await self._handle_get_stock_history(websocket, data)
elif msg_type == "get_stock_explain_events":
await self._handle_get_stock_explain_events(websocket, data)
except websockets.ConnectionClosed:
pass
except json.JSONDecodeError:
pass
async def _handle_get_stock_history(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(
json.dumps(
{
"type": "stock_history_loaded",
"ticker": "",
"prices": [],
"source": None,
"error": "invalid ticker",
},
ensure_ascii=False,
),
)
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = self.state_sync.state.get("current_date")
if not end_date:
end_date = datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime(
"%Y-%m-%d",
)
prices = await asyncio.to_thread(
get_prices,
ticker,
start_date,
end_date,
)
usage_snapshot = self._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
await websocket.send(
json.dumps(
{
"type": "stock_history_loaded",
"ticker": ticker,
"prices": [price.model_dump() for price in prices][-120:],
"source": source,
"start_date": start_date,
"end_date": end_date,
},
ensure_ascii=False,
default=str,
),
)
async def _handle_get_stock_explain_events(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
snapshot = self.storage.runtime_db.get_stock_explain_snapshot(ticker)
await websocket.send(
json.dumps(
{
"type": "stock_explain_events_loaded",
"ticker": ticker,
"events": snapshot.get("events", []),
"signals": snapshot.get("signals", []),
"trades": snapshot.get("trades", []),
},
ensure_ascii=False,
default=str,
),
)
async def _handle_start_backtest(self, data: Dict[str, Any]):
if not self.is_backtest:
return
@@ -239,8 +354,15 @@ class Gateway:
self._backtest_task = task
async def _handle_reload_runtime_assets(self):
"""Reload prompt assets and active skills without restarting the server."""
result = self.pipeline.reload_runtime_assets()
"""Reload prompt, skills, and safe runtime config without restart."""
config_name = self.config.get("config_name", "default")
runtime_config = resolve_runtime_config(
project_root=self._project_root,
config_name=config_name,
enable_memory=self.config.get("enable_memory", False),
)
result = self.pipeline.reload_runtime_assets(runtime_config=runtime_config)
runtime_updates = self._apply_runtime_config(runtime_config)
await self.state_sync.on_system_message(
"Runtime assets reloaded.",
)
@@ -248,9 +370,174 @@ class Gateway:
{
"type": "runtime_assets_reloaded",
**result,
**runtime_updates,
},
)
async def _handle_update_watchlist(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
"""Persist a new watchlist to BOOTSTRAP.md and hot-reload it."""
tickers = self._normalize_watchlist(data.get("tickers"))
if not tickers:
await websocket.send(
json.dumps(
{
"type": "error",
"message": "update_watchlist requires at least one valid ticker.",
},
ensure_ascii=False,
),
)
return
config_name = self.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=self._project_root,
config_name=config_name,
updates={"tickers": tickers},
)
await self.state_sync.on_system_message(
f"Watchlist updated: {', '.join(tickers)}",
)
await self.broadcast(
{
"type": "watchlist_updated",
"config_name": config_name,
"tickers": tickers,
},
)
await self._handle_reload_runtime_assets()
@staticmethod
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
"""Parse watchlist payloads from websocket messages."""
if raw_tickers is None:
return []
if isinstance(raw_tickers, str):
candidates = raw_tickers.split(",")
elif isinstance(raw_tickers, list):
candidates = raw_tickers
else:
candidates = [raw_tickers]
tickers: List[str] = []
for candidate in candidates:
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
if symbol and symbol not in tickers:
tickers.append(symbol)
return tickers
def _apply_runtime_config(
self,
runtime_config: Dict[str, Any],
) -> Dict[str, Any]:
"""Apply runtime config to gateway-owned services and state."""
warnings: List[str] = []
ticker_changes = self.market_service.update_tickers(
runtime_config.get("tickers", []),
)
self.config["tickers"] = ticker_changes["active"]
self.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
self.config["max_comm_cycles"] = self.pipeline.max_comm_cycles
pm_apply_result = self.pipeline.pm.apply_runtime_portfolio_config(
margin_requirement=runtime_config["margin_requirement"],
)
self.config["margin_requirement"] = self.pipeline.pm.portfolio.get(
"margin_requirement",
runtime_config["margin_requirement"],
)
requested_initial_cash = float(runtime_config["initial_cash"])
current_initial_cash = float(self.storage.initial_cash)
initial_cash_applied = requested_initial_cash == current_initial_cash
if not initial_cash_applied:
if (
self.storage.can_apply_initial_cash()
and self.pipeline.pm.can_apply_initial_cash()
):
initial_cash_applied = self.storage.apply_initial_cash(
requested_initial_cash,
)
if initial_cash_applied:
self.pipeline.pm.apply_runtime_portfolio_config(
initial_cash=requested_initial_cash,
)
self.config["initial_cash"] = self.storage.initial_cash
else:
warnings.append(
"initial_cash changed in BOOTSTRAP.md but was not applied "
"because the run already has positions, margin usage, or trades.",
)
requested_enable_memory = bool(runtime_config["enable_memory"])
current_enable_memory = bool(self.config.get("enable_memory", False))
if requested_enable_memory != current_enable_memory:
warnings.append(
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
"because long-term memory contexts are created at startup.",
)
self._sync_runtime_state()
return {
"runtime_config_requested": runtime_config,
"runtime_config_applied": {
"tickers": list(self.config.get("tickers", [])),
"initial_cash": self.storage.initial_cash,
"margin_requirement": self.config["margin_requirement"],
"max_comm_cycles": self.config["max_comm_cycles"],
"enable_memory": self.config.get("enable_memory", False),
},
"runtime_config_status": {
"tickers": True,
"initial_cash": initial_cash_applied,
"margin_requirement": pm_apply_result["margin_requirement"],
"max_comm_cycles": True,
"enable_memory": requested_enable_memory == current_enable_memory,
},
"ticker_changes": ticker_changes,
"runtime_config_warnings": warnings,
}
def _sync_runtime_state(self) -> None:
"""Refresh persisted state and dashboard after runtime config changes."""
self.state_sync.update_state("tickers", self.config.get("tickers", []))
self.state_sync.update_state(
"runtime_config",
{
"tickers": self.config.get("tickers", []),
"initial_cash": self.storage.initial_cash,
"margin_requirement": self.config.get("margin_requirement"),
"max_comm_cycles": self.config.get("max_comm_cycles"),
"enable_memory": self.config.get("enable_memory", False),
},
)
self.storage.update_server_state_from_dashboard(self.state_sync.state)
self.state_sync.save_state()
self._dashboard.tickers = list(self.config.get("tickers", []))
self._dashboard.initial_cash = self.storage.initial_cash
self._dashboard.enable_memory = bool(
self.config.get("enable_memory", False),
)
summary = self.storage.load_file("summary") or {}
holdings = self.storage.load_file("holdings") or []
trades = self.storage.load_file("trades") or []
self._dashboard.update(
portfolio=summary,
holdings=holdings,
trades=trades,
)
async def broadcast(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients"""
if not self.connected_clients:
@@ -269,7 +556,7 @@ class Gateway:
async def _send_to_client(
self,
client: WebSocketServerProtocol,
client: ServerConnection,
message: str,
):
try:

View File

@@ -54,6 +54,7 @@ class MarketService:
self.running = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._broadcast_func: Optional[Callable] = None
self._price_record_func: Optional[Callable[..., None]] = None
self._price_manager: Optional[Any] = None
self._current_date: Optional[str] = None
@@ -92,6 +93,10 @@ class MarketService:
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
)
def set_price_recorder(self, recorder: Optional[Callable[..., None]]):
"""Register an optional callback for persisting runtime price points."""
self._price_record_func = recorder
def _make_price_callback(self) -> Callable:
"""Create thread-safe price callback"""
@@ -169,6 +174,24 @@ class MarketService:
((price - open_price) / open_price) * 100 if open_price > 0 else 0
)
if self._price_record_func:
try:
self._price_record_func(
ticker=symbol,
timestamp=str(price_data.get("timestamp") or datetime.now().isoformat()),
price=float(price),
open_price=float(open_price) if open_price is not None else None,
ret=float(ret),
source=self.mode_name.lower(),
meta=price_data,
)
except Exception as exc:
logger.warning(
"Failed to record price point for %s: %s",
symbol,
exc,
)
await self._broadcast_func(
{
"type": "price_update",
@@ -205,6 +228,43 @@ class MarketService:
self._loop = None
self._broadcast_func = None
def update_tickers(self, tickers: List[str]) -> Dict[str, List[str]]:
"""Hot-update subscribed tickers without restarting the service."""
normalized: List[str] = []
for ticker in tickers:
symbol = normalize_symbol(ticker)
if symbol and symbol not in normalized:
normalized.append(symbol)
previous = list(self.tickers)
removed = [ticker for ticker in previous if ticker not in normalized]
added = [ticker for ticker in normalized if ticker not in previous]
self.tickers = normalized
if self._price_manager:
if removed:
self._price_manager.unsubscribe(removed)
if added:
if self.mock_mode:
self._price_manager.subscribe(
added,
base_prices={ticker: 100.0 for ticker in added},
)
else:
self._price_manager.subscribe(added)
if self.backtest_mode and self._current_date:
self._price_manager.set_date(self._current_date)
for ticker in removed:
self.cache.pop(ticker, None)
return {
"added": added,
"removed": removed,
"active": list(self.tickers),
}
# Backtest methods
def set_backtest_date(self, date: str):
"""Set current backtest date"""

View File

@@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-
"""Run-scoped SQLite storage for query-oriented runtime history."""
from __future__ import annotations
import hashlib
import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, Iterable, Optional
SCHEMA = """
CREATE TABLE IF NOT EXISTS events (
id TEXT PRIMARY KEY,
event_type TEXT NOT NULL,
timestamp TEXT,
agent_id TEXT,
agent_name TEXT,
ticker TEXT,
title TEXT,
content TEXT,
payload_json TEXT NOT NULL,
run_date TEXT
);
CREATE INDEX IF NOT EXISTS idx_events_type_time ON events(event_type, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_events_ticker_time ON events(ticker, timestamp DESC);
CREATE TABLE IF NOT EXISTS trades (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
side TEXT,
qty REAL,
price REAL,
timestamp TEXT,
trading_date TEXT,
agent_id TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_trades_ticker_time ON trades(ticker, timestamp DESC);
CREATE TABLE IF NOT EXISTS signals (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
agent_id TEXT,
agent_name TEXT,
role TEXT,
signal TEXT,
confidence REAL,
reasoning_json TEXT,
real_return REAL,
is_correct TEXT,
trade_date TEXT,
created_at TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_signals_ticker_date ON signals(ticker, trade_date DESC);
CREATE INDEX IF NOT EXISTS idx_signals_agent_date ON signals(agent_id, trade_date DESC);
CREATE TABLE IF NOT EXISTS price_points (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
timestamp TEXT NOT NULL,
price REAL NOT NULL,
open_price REAL,
ret REAL,
source TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_price_points_ticker_time ON price_points(ticker, timestamp DESC);
"""
def _json_dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
def _hash_key(*parts: Any) -> str:
raw = "::".join("" if part is None else str(part) for part in parts)
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
class RuntimeDb:
"""Small SQLite helper for append-mostly runtime data."""
def __init__(self, db_path: Path):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
return conn
def _init_db(self):
with self._connect() as conn:
conn.executescript(SCHEMA)
def insert_event(self, event: Dict[str, Any]):
payload = dict(event or {})
if not payload:
return
event_id = payload.get("id") or _hash_key(
payload.get("type"),
payload.get("timestamp"),
payload.get("agentId") or payload.get("agent_id"),
payload.get("content"),
payload.get("title"),
)
ticker = payload.get("ticker")
if not ticker and isinstance(payload.get("tickers"), list) and len(payload["tickers"]) == 1:
ticker = payload["tickers"][0]
with self._connect() as conn:
conn.execute(
"""
INSERT OR IGNORE INTO events
(id, event_type, timestamp, agent_id, agent_name, ticker, title, content, payload_json, run_date)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
event_id,
payload.get("type"),
payload.get("timestamp"),
payload.get("agentId") or payload.get("agent_id"),
payload.get("agentName") or payload.get("agent_name"),
ticker,
payload.get("title"),
payload.get("content"),
_json_dumps(payload),
payload.get("date") or payload.get("trading_date") or payload.get("run_date"),
),
)
def upsert_trade(self, trade: Dict[str, Any]):
payload = dict(trade or {})
if not payload:
return
trade_id = payload.get("id") or _hash_key(
payload.get("ticker"),
payload.get("timestamp") or payload.get("ts"),
payload.get("side"),
payload.get("qty"),
payload.get("price"),
)
with self._connect() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO trades
(id, ticker, side, qty, price, timestamp, trading_date, agent_id, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
trade_id,
payload.get("ticker"),
payload.get("side"),
payload.get("qty"),
payload.get("price"),
payload.get("timestamp") or payload.get("ts"),
payload.get("trading_date"),
payload.get("agentId") or payload.get("agent_id"),
_json_dumps(payload),
),
)
def upsert_signal(self, signal: Dict[str, Any], *, agent_id: str, agent_name: str, role: str):
payload = dict(signal or {})
ticker = payload.get("ticker")
if not ticker:
return
signal_id = _hash_key(
agent_id,
ticker,
payload.get("date"),
payload.get("signal"),
payload.get("confidence"),
)
with self._connect() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO signals
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
real_return, is_correct, trade_date, created_at, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
signal_id,
ticker,
agent_id,
agent_name,
role,
payload.get("signal"),
payload.get("confidence"),
_json_dumps(payload.get("reasoning")),
payload.get("real_return"),
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
payload.get("date"),
payload.get("created_at") or payload.get("date"),
_json_dumps(payload),
),
)
def replace_signals_for_leaderboard(self, leaderboard: Iterable[Dict[str, Any]]):
with self._connect() as conn:
conn.execute("DELETE FROM signals")
for agent in leaderboard:
agent_id = agent.get("agentId")
agent_name = agent.get("name")
role = agent.get("role")
for signal in agent.get("signals", []) or []:
payload = dict(signal or {})
ticker = payload.get("ticker")
if not ticker:
continue
signal_id = _hash_key(
agent_id,
ticker,
payload.get("date"),
payload.get("signal"),
payload.get("confidence"),
)
conn.execute(
"""
INSERT INTO signals
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
real_return, is_correct, trade_date, created_at, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
signal_id,
ticker,
agent_id,
agent_name,
role,
payload.get("signal"),
payload.get("confidence"),
_json_dumps(payload.get("reasoning")),
payload.get("real_return"),
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
payload.get("date"),
payload.get("created_at") or payload.get("date"),
_json_dumps(payload),
),
)
def insert_price_point(
self,
*,
ticker: str,
timestamp: str,
price: float,
open_price: Optional[float] = None,
ret: Optional[float] = None,
source: Optional[str] = None,
meta: Optional[Dict[str, Any]] = None,
):
price_id = _hash_key(ticker, timestamp, price, open_price, ret)
with self._connect() as conn:
conn.execute(
"""
INSERT OR IGNORE INTO price_points
(id, ticker, timestamp, price, open_price, ret, source, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
price_id,
ticker,
timestamp,
price,
open_price,
ret,
source,
_json_dumps(meta or {}),
),
)
def get_stock_explain_snapshot(
self,
ticker: str,
*,
limit_events: int = 24,
limit_trades: int = 12,
limit_signals: int = 12,
) -> Dict[str, list[Dict[str, Any]]]:
"""Fetch query-oriented history for a single ticker."""
symbol = str(ticker or "").strip().upper()
if not symbol:
return {"events": [], "trades": [], "signals": []}
with self._connect() as conn:
trade_rows = conn.execute(
"""
SELECT * FROM trades
WHERE ticker = ?
ORDER BY timestamp DESC
LIMIT ?
""",
(symbol, limit_trades),
).fetchall()
signal_rows = conn.execute(
"""
SELECT * FROM signals
WHERE ticker = ?
ORDER BY trade_date DESC, created_at DESC
LIMIT ?
""",
(symbol, limit_signals),
).fetchall()
event_rows = conn.execute(
"""
SELECT * FROM events
WHERE payload_json LIKE ? OR content LIKE ? OR title LIKE ? OR ticker = ?
ORDER BY timestamp DESC
LIMIT ?
""",
(f"%{symbol}%", f"%{symbol}%", f"%{symbol}%", symbol, limit_events * 3),
).fetchall()
normalized_events = []
seen_event_ids: set[str] = set()
for row in event_rows:
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
content = str(row["content"] or payload.get("content") or "")
title = str(row["title"] or payload.get("title") or "")
if symbol not in f"{title} {content}".upper() and str(row["ticker"] or "").upper() != symbol:
continue
event_id = row["id"]
if event_id in seen_event_ids:
continue
seen_event_ids.add(event_id)
normalized_events.append(
{
"id": event_id,
"type": "mention",
"timestamp": row["timestamp"],
"title": title or f"{row['agent_name'] or '未知角色'}提及 {symbol}",
"meta": payload.get("conferenceTitle")
or payload.get("feedType")
or row["event_type"],
"body": content,
"tone": "neutral",
"agent": row["agent_name"] or payload.get("agentName") or payload.get("agent"),
},
)
if len(normalized_events) >= limit_events:
break
normalized_trades = [
{
"id": row["id"],
"type": "trade",
"timestamp": row["timestamp"],
"title": f"{row['side']} {int(row['qty'] or 0)}",
"meta": "交易执行",
"body": f"成交价 ${float(row['price'] or 0):.2f}",
"tone": "positive" if row["side"] == "LONG" else "negative" if row["side"] == "SHORT" else "neutral",
}
for row in trade_rows
]
normalized_signals = [
{
"id": row["id"],
"type": "signal",
"timestamp": f"{row['trade_date']}T08:00:00" if row["trade_date"] else row["created_at"],
"title": f"{row['agent_name']} 给出{row['signal'] or '中性'}信号",
"meta": row["role"],
"body": (
f"后验收益 {float(row['real_return']) * 100:+.2f}%"
if row["real_return"] is not None
else "该信号暂未完成后验评估"
),
"tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral",
}
for row in signal_rows
]
return {
"events": normalized_events,
"trades": normalized_trades,
"signals": normalized_signals,
}

View File

@@ -10,6 +10,8 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from .runtime_db import RuntimeDb
logger = logging.getLogger(__name__)
@@ -61,6 +63,7 @@ class StorageService:
self.state_dir = self.dashboard_dir.parent / "state"
self.state_dir.mkdir(parents=True, exist_ok=True)
self.server_state_file = self.state_dir / "server_state.json"
self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
# Feed history (for agent messages)
self.max_feed_history = 200
@@ -114,6 +117,11 @@ class StorageService:
try:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
if file_type == "leaderboard" and isinstance(data, list):
self.runtime_db.replace_signals_for_leaderboard(data)
elif file_type == "trades" and isinstance(data, list):
for trade in data:
self.runtime_db.upsert_trade(trade)
except Exception as e:
logger.error(f"Failed to save {file_type}.json: {e}")
@@ -211,6 +219,7 @@ class StorageService:
try:
with open(self.internal_state_file, "w", encoding="utf-8") as f:
json.dump(state, f, indent=2, ensure_ascii=False)
self._sync_price_history_to_db(state.get("price_history", {}))
except Exception as e:
logger.error(f"Failed to save internal state: {e}")
@@ -231,6 +240,41 @@ class StorageService:
"margin_requirement": 0.25, # Default 25% margin requirement
}
@staticmethod
def _portfolio_is_pristine(portfolio_state: Dict[str, Any]) -> bool:
"""Return whether the persisted portfolio can be safely rebased."""
positions = portfolio_state.get("positions", {})
has_positions = any(
position.get("long", 0) or position.get("short", 0)
for position in positions.values()
)
margin_used = float(portfolio_state.get("margin_used", 0.0) or 0.0)
return not has_positions and margin_used == 0.0
def can_apply_initial_cash(self) -> bool:
"""Only allow initial cash changes before the run has traded."""
state = self.load_internal_state()
if not self._portfolio_is_pristine(state.get("portfolio_state", {})):
return False
if state.get("all_trades"):
return False
return len(state.get("equity_history", [])) <= 1
def apply_initial_cash(self, initial_cash: float) -> bool:
"""Rebase storage state to a new initial cash when the run is pristine."""
if not self.can_apply_initial_cash():
return False
self.initial_cash = float(initial_cash)
if self.internal_state_file.exists():
self.internal_state_file.unlink()
self.initialize_empty_dashboard()
state = self.load_server_state()
self.update_server_state_from_dashboard(state)
self.save_server_state(state)
return True
def save_portfolio_state(self, portfolio: Dict[str, Any]):
"""
Save portfolio state to internal state
@@ -750,6 +794,7 @@ class StorageService:
"last_day_history": [],
"trading_days_total": 0,
"trading_days_completed": 0,
"price_history": {},
}
if not self.server_state_file.exists():
@@ -771,6 +816,11 @@ class StorageService:
)
logger.info(f"Trades: {len(saved_state.get('trades', []))} records")
for event in saved_state.get("feed_history", []):
self.runtime_db.insert_event(event)
for trade in saved_state.get("trades", []):
self.runtime_db.upsert_trade(trade)
return saved_state
def save_server_state(self, state: Dict[str, Any]):
@@ -852,6 +902,7 @@ class StorageService:
state["feed_history"] = []
state["feed_history"].insert(0, feed_msg)
self.runtime_db.insert_event(feed_msg)
# Trim to max size
if len(state["feed_history"]) > self.max_feed_history:
@@ -861,6 +912,69 @@ class StorageService:
return True
def record_price_point(
self,
*,
ticker: str,
timestamp: str,
price: float,
open_price: Optional[float] = None,
ret: Optional[float] = None,
source: Optional[str] = None,
meta: Optional[Dict[str, Any]] = None,
):
"""Persist a runtime price point for later query-oriented reads."""
if not ticker or not timestamp:
return
try:
self.runtime_db.insert_price_point(
ticker=ticker,
timestamp=timestamp,
price=price,
open_price=open_price,
ret=ret,
source=source,
meta=meta,
)
except Exception as exc:
logger.warning("Failed to record price point for %s: %s", ticker, exc)
def _sync_price_history_to_db(self, price_history: Dict[str, Any]):
"""Backfill structured price points from serialized internal state."""
if not isinstance(price_history, dict):
return
for ticker, points in price_history.items():
if not ticker or not isinstance(points, list):
continue
for point in points:
if isinstance(point, (list, tuple)) and len(point) >= 2:
timestamp, price = point[0], point[1]
try:
self.record_price_point(
ticker=str(ticker),
timestamp=str(timestamp),
price=float(price),
)
except (TypeError, ValueError):
continue
elif isinstance(point, dict):
timestamp = point.get("timestamp") or point.get("label") or point.get("date")
price = point.get("price") or point.get("close") or point.get("value")
if not timestamp or price is None:
continue
try:
self.record_price_point(
ticker=str(ticker),
timestamp=str(timestamp),
price=float(price),
open_price=point.get("open"),
ret=point.get("ret"),
source=point.get("source"),
meta=point,
)
except (TypeError, ValueError):
continue
def _get_default_stats(self) -> Dict[str, Any]:
"""Get default stats structure"""
return {
@@ -889,6 +1003,7 @@ class StorageService:
stats = self.load_file("stats") or self._get_default_stats()
trades = self.load_file("trades") or []
leaderboard = self.load_file("leaderboard") or []
internal_state = self.load_internal_state()
# Update state
state["portfolio"] = {
@@ -910,6 +1025,9 @@ class StorageService:
state["stats"] = stats
state["trades"] = trades
state["leaderboard"] = leaderboard
state["price_history"] = internal_state.get("price_history", {})
self.runtime_db.replace_signals_for_leaderboard(leaderboard)
self._sync_price_history_to_db(state["price_history"])
# ========== Live Returns Tracking ==========