Confirm the PokieTicker news library data source

2026-03-16 02:19:25 +08:00
parent 78f133617f
commit 564c92c0c8
182 changed files with 6436 additions and 1050 deletions

View File

@@ -48,15 +48,19 @@ class AnalystAgent(ReActAgent):
f"Must be one of: {list(ANALYST_TYPES.keys())}",
)
self.analyst_type_key = analyst_type
self.analyst_persona = ANALYST_TYPES[analyst_type]["display_name"]
object.__setattr__(self, "analyst_type_key", analyst_type)
object.__setattr__(
self,
"analyst_persona",
ANALYST_TYPES[analyst_type]["display_name"],
)
if agent_id is None:
agent_id = analyst_type
self.agent_id = agent_id
object.__setattr__(self, "agent_id", agent_id)
self.config = config or {}
self.toolkit = toolkit
object.__setattr__(self, "config", config or {})
object.__setattr__(self, "toolkit", toolkit)
sys_prompt = self._load_system_prompt()
kwargs = {
@@ -125,4 +129,12 @@ class AnalystAgent(ReActAgent):
self.config.get("config_name", "default"),
active_skill_dirs=active_skill_dirs,
)
self.sys_prompt = self._load_system_prompt()
self._apply_runtime_sys_prompt(self._load_system_prompt())
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
"""Update the prompt used by future turns and the cached system msg."""
self._sys_prompt = sys_prompt
for msg, _marks in self.memory.content:
if getattr(msg, "role", None) == "system":
msg.content = sys_prompt
break
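A short aside on the `object.__setattr__` calls above: that is the standard escape hatch for classes that intercept attribute assignment (frozen dataclasses, pydantic-style models, and the like). A minimal sketch of the mechanism, assuming `ReActAgent` enforces such a restriction (the base class is not shown in this diff):

import dataclasses

@dataclasses.dataclass(frozen=True)
class Restricted:
    name: str

r = Restricted("analyst")
# r.name = "pm"  would raise FrozenInstanceError here
object.__setattr__(r, "name", "pm")  # bypasses the class-level __setattr__
print(r.name)  # -> pm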

View File

@@ -38,21 +38,29 @@ class PMAgent(ReActAgent):
toolkit_factory_kwargs: Optional[Dict[str, Any]] = None,
toolkit: Optional[Toolkit] = None,
):
self.config = config or {}
object.__setattr__(self, "config", config or {})
# Portfolio state
self.portfolio = {
"cash": initial_cash,
"positions": {},
"margin_used": 0.0,
"margin_requirement": margin_requirement,
}
object.__setattr__(
self,
"portfolio",
{
"cash": initial_cash,
"positions": {},
"margin_used": 0.0,
"margin_requirement": margin_requirement,
},
)
# Decisions made in current cycle
self._decisions: Dict[str, Dict] = {}
object.__setattr__(self, "_decisions", {})
toolkit_factory_kwargs = toolkit_factory_kwargs or {}
self._toolkit_factory = toolkit_factory
self._toolkit_factory_kwargs = toolkit_factory_kwargs
object.__setattr__(self, "_toolkit_factory", toolkit_factory)
object.__setattr__(
self,
"_toolkit_factory_kwargs",
toolkit_factory_kwargs,
)
# Create toolkit after local state is ready so bound tool methods can be registered.
if toolkit is None:
@@ -65,7 +73,7 @@ class PMAgent(ReActAgent):
)
else:
toolkit = self._create_toolkit()
self.toolkit = toolkit
object.__setattr__(self, "toolkit", toolkit)
sys_prompt = build_agent_system_prompt(
agent_id=name,
@@ -205,6 +213,42 @@ class PMAgent(ReActAgent):
"""Update portfolio after external execution"""
self.portfolio.update(portfolio)
def _has_open_positions(self) -> bool:
"""Return whether the current portfolio still has non-zero positions."""
for position in self.portfolio.get("positions", {}).values():
if position.get("long", 0) or position.get("short", 0):
return True
return False
def can_apply_initial_cash(self) -> bool:
"""Only allow cash rebasing before any positions or margin exist."""
return (
not self._has_open_positions()
and float(self.portfolio.get("margin_used", 0.0) or 0.0) == 0.0
)
def apply_runtime_portfolio_config(
self,
*,
margin_requirement: Optional[float] = None,
initial_cash: Optional[float] = None,
) -> Dict[str, bool]:
"""Apply safe run-time portfolio config updates."""
result = {
"margin_requirement": False,
"initial_cash": False,
}
if margin_requirement is not None:
self.portfolio["margin_requirement"] = float(margin_requirement)
result["margin_requirement"] = True
if initial_cash is not None and self.can_apply_initial_cash():
self.portfolio["cash"] = float(initial_cash)
result["initial_cash"] = True
return result
def reload_runtime_assets(self, active_skill_dirs: Optional[list] = None) -> None:
"""Reload toolkit and system prompt from current run assets."""
from .toolkit_factory import create_agent_toolkit
@@ -221,8 +265,18 @@ class PMAgent(ReActAgent):
owner=self,
**toolkit_kwargs,
)
self.sys_prompt = build_agent_system_prompt(
agent_id=self.name,
config_name=self.config.get("config_name", "default"),
toolkit=self.toolkit,
self._apply_runtime_sys_prompt(
build_agent_system_prompt(
agent_id=self.name,
config_name=self.config.get("config_name", "default"),
toolkit=self.toolkit,
),
)
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
"""Update the prompt used by future turns and the cached system msg."""
self._sys_prompt = sys_prompt
for msg, _marks in self.memory.content:
if getattr(msg, "role", None) == "system":
msg.content = sys_prompt
break
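A usage sketch of the guarded update above (constructor arguments are abbreviated; see the `__init__` hunk earlier in this file): `margin_requirement` is always applied, while `initial_cash` is only rebased while the run is still pristine.

# Hypothetical session, values illustrative:
pm = PMAgent(name="pm", initial_cash=100000.0, margin_requirement=0.0)
print(pm.apply_runtime_portfolio_config(
    margin_requirement=0.25,
    initial_cash=250000.0,
))
# -> {'margin_requirement': True, 'initial_cash': True}

pm.portfolio["positions"]["AAPL"] = {"long": 10, "short": 0}
print(pm.apply_runtime_portfolio_config(initial_cash=500000.0))
# -> {'margin_requirement': False, 'initial_cash': False}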

View File

@@ -39,12 +39,12 @@ class RiskAgent(ReActAgent):
config: Configuration dictionary
long_term_memory: Optional ReMeTaskLongTermMemory instance
"""
self.config = config or {}
self.agent_id = name
object.__setattr__(self, "config", config or {})
object.__setattr__(self, "agent_id", name)
if toolkit is None:
toolkit = Toolkit()
self.toolkit = toolkit
object.__setattr__(self, "toolkit", toolkit)
sys_prompt = self._load_system_prompt()
@@ -99,4 +99,12 @@ class RiskAgent(ReActAgent):
self.config.get("config_name", "default"),
active_skill_dirs=active_skill_dirs,
)
self.sys_prompt = self._load_system_prompt()
self._apply_runtime_sys_prompt(self._load_system_prompt())
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
"""Update the prompt used by future turns and the cached system msg."""
self._sys_prompt = sys_prompt
for msg, _marks in self.memory.content:
if getattr(msg, "role", None) == "system":
msg.content = sys_prompt
break

View File

@@ -62,6 +62,59 @@ class SkillsManager:
raise FileNotFoundError(f"Unknown skill: {skill_name}")
def _persist_runtime_edits(
self,
config_name: str,
skill_name: str,
active_dir: Path,
) -> None:
"""
Persist run-time edits from active skills into customized skills.
This keeps active skill experiments from being lost on the next reload
while still allowing the active directory to be re-synced cleanly.
"""
if not active_dir.exists():
return
source_dir = self._resolve_source_dir(skill_name)
if active_dir.resolve() == source_dir.resolve():
return
if not self._directories_match(active_dir, source_dir):
customized_dir = self.customized_root / skill_name
customized_dir.parent.mkdir(parents=True, exist_ok=True)
if customized_dir.exists():
shutil.rmtree(customized_dir)
shutil.copytree(active_dir, customized_dir)
@staticmethod
def _directories_match(left: Path, right: Path) -> bool:
"""Compare two directory trees by file contents."""
if not left.exists() or not right.exists():
return False
left_items = sorted(
path.relative_to(left)
for path in left.rglob("*")
)
right_items = sorted(
path.relative_to(right)
for path in right.rglob("*")
)
if left_items != right_items:
return False
for relative_path in left_items:
left_path = left / relative_path
right_path = right / relative_path
if left_path.is_dir() != right_path.is_dir():
return False
if left_path.is_file():
if left_path.read_bytes() != right_path.read_bytes():
return False
return True
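For reference, the standard library can express a similar deep comparison; a sketch built on `filecmp`, where `shallow=False` forces byte-level file comparison instead of an os.stat check:

import filecmp
from pathlib import Path

def trees_match(left: Path, right: Path) -> bool:
    cmp = filecmp.dircmp(left, right)
    if cmp.left_only or cmp.right_only or cmp.funny_files:
        return False
    _, mismatch, errors = filecmp.cmpfiles(
        left, right, cmp.common_files, shallow=False,
    )
    if mismatch or errors:
        return False
    return all(
        trees_match(left / sub, right / sub) for sub in cmp.common_dirs
    )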
def resolve_agent_skill_names(
self,
config_name: str,
@@ -103,12 +156,22 @@ class SkillsManager:
for existing in active_root.iterdir():
if existing.is_dir() and existing.name not in wanted:
self._persist_runtime_edits(
config_name=config_name,
skill_name=existing.name,
active_dir=existing,
)
shutil.rmtree(existing)
for skill_name in skill_names:
source_dir = self._resolve_source_dir(skill_name)
target_dir = active_root / skill_name
if target_dir.exists():
self._persist_runtime_edits(
config_name=config_name,
skill_name=skill_name,
active_dir=target_dir,
)
shutil.rmtree(target_dir)
shutil.copytree(source_dir, target_dir)
synced_paths.append(target_dir)

View File

@@ -49,9 +49,8 @@ def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None:
config_name: Configuration name for the run
auto_clean: If True, skip confirmation and clean automatically
"""
# logs_dir = get_project_root() / "logs"
logs_dir = get_project_root()
base_data_dir = logs_dir / config_name
workspace_manager = WorkspaceManager(project_root=get_project_root())
base_data_dir = workspace_manager.get_run_dir(config_name)
# Check if historical data exists
if not base_data_dir.exists() or not any(base_data_dir.iterdir()):

View File

@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
"""Parse run-scoped BOOTSTRAP.md into structured configuration."""
"""Parse run-scoped BOOTSTRAP.md into structured and runtime config."""
from dataclasses import dataclass, field
from pathlib import Path
@@ -8,6 +8,8 @@ import re
import yaml
from backend.config.env_config import get_env_float, get_env_int, get_env_list
BOOTSTRAP_FRONT_MATTER_RE = re.compile(
r"^---\s*\n(.*?)\n---\s*\n?(.*)$",
@@ -63,3 +65,84 @@ def get_bootstrap_config_for_run(
return load_bootstrap_config(
project_root / "runs" / config_name / "BOOTSTRAP.md",
)
def save_bootstrap_config(bootstrap_path: Path, config: BootstrapConfig) -> None:
"""Persist structured bootstrap config back to BOOTSTRAP.md."""
bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
values = config.values if isinstance(config.values, dict) else {}
front_matter = yaml.safe_dump(
values,
allow_unicode=True,
sort_keys=False,
).strip()
body = (config.prompt_body or "").strip()
content = f"---\n{front_matter}\n---"
if body:
content += f"\n\n{body}\n"
else:
content += "\n"
bootstrap_path.write_text(content, encoding="utf-8")
def update_bootstrap_values_for_run(
project_root: Path,
config_name: str,
updates: Dict[str, Any],
) -> BootstrapConfig:
"""Patch selected front matter keys for a run and persist them."""
bootstrap_path = project_root / "runs" / config_name / "BOOTSTRAP.md"
existing = load_bootstrap_config(bootstrap_path)
values = dict(existing.values)
values.update(updates)
updated = BootstrapConfig(values=values, prompt_body=existing.prompt_body)
save_bootstrap_config(bootstrap_path, updated)
return updated
def _coerce_bool(value: Any) -> bool:
"""Parse booleans from bootstrap-friendly string values."""
if isinstance(value, bool):
return value
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "on"}:
return True
if normalized in {"0", "false", "no", "off"}:
return False
return bool(value)
def resolve_runtime_config(
project_root: Path,
config_name: str,
enable_memory: bool = False,
) -> Dict[str, Any]:
"""Merge env defaults with run-scoped bootstrap front matter."""
bootstrap = get_bootstrap_config_for_run(project_root, config_name)
return {
"tickers": bootstrap.get("tickers")
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
"initial_cash": float(
bootstrap.get(
"initial_cash",
get_env_float("INITIAL_CASH", 100000.0),
),
),
"margin_requirement": float(
bootstrap.get(
"margin_requirement",
get_env_float("MARGIN_REQUIREMENT", 0.0),
),
),
"max_comm_cycles": int(
bootstrap.get(
"max_comm_cycles",
get_env_int("MAX_COMM_CYCLES", 2),
),
),
"enable_memory": bool(enable_memory)
or _coerce_bool(bootstrap.get("enable_memory", False)),
}
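To make the merge order concrete, a minimal sketch of the front matter shape this module round-trips (key names come from the code above; values are illustrative). Keys present in BOOTSTRAP.md win over the env defaults, and `_coerce_bool` accepts the usual string spellings:

import yaml

content = """---
tickers: [AAPL, NVDA]
initial_cash: 250000
max_comm_cycles: 3
enable_memory: "yes"
---

Trade conservatively this week.
"""

_, front_matter, _body = content.split("---", 2)
values = yaml.safe_load(front_matter)
assert values["initial_cash"] == 250000       # overrides INITIAL_CASH
assert _coerce_bool(values["enable_memory"])  # "yes" -> True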

View File

@@ -226,12 +226,18 @@ class TradingPipeline:
"settlement_result": settlement_result,
}
def reload_runtime_assets(self) -> Dict[str, Any]:
"""Reload prompt assets, bootstrap config, and active skills for all agents."""
def reload_runtime_assets(
self,
runtime_config: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Reload prompt assets and safe in-process runtime settings."""
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import load_agent_profiles
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
if runtime_config and "max_comm_cycles" in runtime_config:
self.max_comm_cycles = int(runtime_config["max_comm_cycles"])
skills_manager = SkillsManager()
profiles = load_agent_profiles()
active_skill_map = skills_manager.prepare_active_skills(
@@ -262,6 +268,7 @@ class TradingPipeline:
agent_id: [path.name for path in paths]
for agent_id, paths in active_skill_map.items()
},
"max_comm_cycles": self.max_comm_cycles,
}
async def _clear_all_agent_memory(self):

View File

@@ -438,6 +438,8 @@ class StateSync:
"server_mode": self._state.get("server_mode", "live"),
"is_mock_mode": self._state.get("is_mock_mode", False),
"is_backtest": self._state.get("is_backtest", False),
"tickers": self._state.get("tickers"),
"runtime_config": self._state.get("runtime_config"),
"feed_history": self._state.get("feed_history", []),
"current_date": self._state.get("current_date"),
"trading_days_total": self._state.get("trading_days_total", 0),
@@ -452,6 +454,7 @@ class StateSync:
"portfolio": self._state.get("portfolio", {}),
"realtime_prices": self._state.get("realtime_prices", {}),
"data_sources": self._state.get("data_sources", {}),
"price_history": self._state.get("price_history", {}),
}
if include_dashboard:

View File

@@ -30,6 +30,25 @@ logger = logging.getLogger(__name__)
_DATA_DIR = Path(__file__).parent / "ret_data"
def _format_provider_error(exc: Exception) -> str:
"""Condense common provider failures into short, readable messages."""
message = str(exc).strip().replace("\n", " ")
if "429" in message:
return "rate limit reached"
if "402" in message:
return "insufficient credits"
if "422" in message or "Missing parameters" in message:
return "invalid request parameters"
if "Quote not found" in message:
return "quote not found"
return message
def _has_valid_ticker(ticker: str) -> bool:
"""Return whether the normalized ticker is non-empty."""
return bool((ticker or "").strip())
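The condensing is substring-based, so it catches both bare status codes and full HTTP error strings. Illustrative results:

_format_provider_error(Exception("HTTP 429 Too Many Requests"))
# -> "rate limit reached"
_format_provider_error(Exception("402 Payment Required\nupgrade plan"))
# -> "insufficient credits"
_format_provider_error(Exception("connection reset by peer"))
# -> "connection reset by peer"  (unrecognized errors pass through)

One trade-off of the substring check: any message that happens to contain "429", for example inside a URL, is condensed to the rate-limit wording.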
class DataProviderRouter:
"""Route data requests across configured providers with fallbacks."""
@@ -56,6 +75,8 @@ class DataProviderRouter:
end_date: str,
) -> tuple[list[Price], DataSource]:
"""Fetch prices using preferred providers with fallback."""
if not _has_valid_ticker(ticker):
return [], "local_csv"
last_error: Optional[Exception] = None
for source in self.price_sources():
@@ -78,7 +99,12 @@ class DataProviderRouter:
return prices, source
except Exception as exc:
last_error = exc
logger.warning("Price source %s failed for %s: %s", source, ticker, exc)
logger.warning(
"Price source %s failed for %s: %s",
source,
ticker,
_format_provider_error(exc),
)
if last_error:
raise last_error
@@ -92,6 +118,8 @@ class DataProviderRouter:
limit: int = 10,
) -> tuple[list[FinancialMetrics], DataSource]:
"""Fetch financial metrics with API provider fallback."""
if not _has_valid_ticker(ticker):
return [], "local_csv"
last_error: Optional[Exception] = None
for source in self.api_sources():
@@ -126,7 +154,7 @@ class DataProviderRouter:
"Financial metrics source %s failed for %s: %s",
source,
ticker,
exc,
_format_provider_error(exc),
)
if last_error:
@@ -142,6 +170,8 @@ class DataProviderRouter:
limit: int = 10,
) -> list[LineItem]:
"""Line items are only supported via Financial Datasets."""
if not _has_valid_ticker(ticker):
return []
if "financial_datasets" not in self.api_sources():
return []
try:
@@ -155,7 +185,11 @@ class DataProviderRouter:
self._record_success("line_items", "financial_datasets")
return results
except Exception as exc:
logger.warning("Line items source failed for %s: %s", ticker, exc)
logger.warning(
"Line items source failed for %s: %s",
ticker,
_format_provider_error(exc),
)
return []
def get_insider_trades(
@@ -166,6 +200,8 @@ class DataProviderRouter:
limit: int = 1000,
) -> tuple[list[InsiderTrade], DataSource]:
"""Fetch insider trades with provider fallback."""
if not _has_valid_ticker(ticker):
return [], "local_csv"
last_error: Optional[Exception] = None
for source in self.api_sources():
@@ -193,7 +229,7 @@ class DataProviderRouter:
"Insider trades source %s failed for %s: %s",
source,
ticker,
exc,
_format_provider_error(exc),
)
if last_error:
@@ -208,6 +244,8 @@ class DataProviderRouter:
limit: int = 1000,
) -> tuple[list[CompanyNews], DataSource]:
"""Fetch company news with provider fallback."""
if not _has_valid_ticker(ticker):
return [], "local_csv"
last_error: Optional[Exception] = None
for source in self.api_sources():
@@ -244,7 +282,7 @@ class DataProviderRouter:
"Company news source %s failed for %s: %s",
source,
ticker,
exc,
_format_provider_error(exc),
)
if last_error:
@@ -258,6 +296,8 @@ class DataProviderRouter:
metrics_lookup,
) -> tuple[Optional[float], DataSource]:
"""Fetch market cap using facts API or financial metrics fallback."""
if not _has_valid_ticker(ticker):
return None, "local_csv"
today = datetime.datetime.now().strftime("%Y-%m-%d")
if end_date == today and "financial_datasets" in self.api_sources():
try:
@@ -267,7 +307,7 @@ class DataProviderRouter:
logger.warning(
"Market cap facts source failed for %s: %s",
ticker,
exc,
_format_provider_error(exc),
)
metrics, source = metrics_lookup(ticker, end_date)

View File

@@ -18,9 +18,8 @@ from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
from backend.agents.prompt_loader import PromptLoader
from backend.agents.workspace_manager import WorkspaceManager
from backend.config.bootstrap_config import get_bootstrap_config_for_run
from backend.config.bootstrap_config import resolve_runtime_config
from backend.config.constants import ANALYST_TYPES
from backend.config.env_config import get_env_float, get_env_int, get_env_list
from backend.core.pipeline import TradingPipeline
from backend.core.scheduler import BacktestScheduler, Scheduler
from backend.utils.settlement import SettlementCoordinator
@@ -36,35 +35,20 @@ loguru.logger.disable("reme_ai")
_prompt_loader = PromptLoader()
def _get_run_dir(config_name: str) -> Path:
"""Return the canonical run-scoped directory for a config."""
project_root = Path(__file__).resolve().parents[1]
return WorkspaceManager(project_root=project_root).get_run_dir(config_name)
def _resolve_runtime_config(args) -> dict:
"""Merge env defaults with run-scoped bootstrap config."""
project_root = Path(__file__).resolve().parents[1]
bootstrap = get_bootstrap_config_for_run(project_root, args.config_name)
return {
"tickers": bootstrap.get("tickers")
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
"initial_cash": float(
bootstrap.get(
"initial_cash",
get_env_float("INITIAL_CASH", 100000.0),
),
),
"margin_requirement": float(
bootstrap.get(
"margin_requirement",
get_env_float("MARGIN_REQUIREMENT", 0.0),
),
),
"max_comm_cycles": int(
bootstrap.get(
"max_comm_cycles",
get_env_int("MAX_COMM_CYCLES", 2),
),
),
"enable_memory": args.enable_memory
or bool(bootstrap.get("enable_memory", False)),
}
return resolve_runtime_config(
project_root=project_root,
config_name=args.config_name,
enable_memory=args.enable_memory,
)
def create_long_term_memory(agent_name: str, config_name: str):
@@ -82,7 +66,7 @@ def create_long_term_memory(agent_name: str, config_name: str):
logger.warning("MEMORY_API_KEY not set, long-term memory disabled")
return None
memory_dir = str(Path(config_name) / "memory")
memory_dir = str(_get_run_dir(config_name) / "memory")
return ReMeTaskLongTermMemory(
agent_name=agent_name,
@@ -241,7 +225,7 @@ async def run_with_gateway(args):
# Create storage service
storage_service = StorageService(
dashboard_dir=Path(config_name) / "team_dashboard",
dashboard_dir=_get_run_dir(config_name) / "team_dashboard",
initial_cash=initial_cash,
config_name=config_name,
)
@@ -316,6 +300,10 @@ async def run_with_gateway(args):
"backtest_mode": is_backtest,
"tickers": tickers,
"config_name": config_name,
"initial_cash": initial_cash,
"margin_requirement": margin_requirement,
"max_comm_cycles": runtime_config["max_comm_cycles"],
"enable_memory": runtime_config["enable_memory"],
},
)

View File

@@ -5,12 +5,18 @@ WebSocket Gateway for frontend communication
import asyncio
import json
import logging
from datetime import datetime
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set
import websockets
from websockets.server import WebSocketServerProtocol
from websockets.asyncio.server import ServerConnection
from backend.config.bootstrap_config import (
resolve_runtime_config,
update_bootstrap_values_for_run,
)
from backend.data.provider_utils import normalize_symbol
from backend.utils.msg_adapter import FrontendAdapter
from backend.utils.terminal_dashboard import get_dashboard
from backend.core.pipeline import TradingPipeline
@@ -18,6 +24,7 @@ from backend.core.state_sync import StateSync
from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.data.provider_router import get_provider_router
from backend.tools.data_tools import get_prices
logger = logging.getLogger(__name__)
@@ -51,7 +58,7 @@ class Gateway:
self.state_sync.set_broadcast_fn(self.broadcast)
self.pipeline.state_sync = self.state_sync
self.connected_clients: Set[WebSocketServerProtocol] = set()
self.connected_clients: Set[ServerConnection] = set()
self.lock = asyncio.Lock()
self._backtest_task: Optional[asyncio.Task] = None
self._backtest_start_date: Optional[str] = None
@@ -63,6 +70,7 @@ class Gateway:
self._session_start_portfolio_value: Optional[float] = None
self._provider_router = get_provider_router()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._project_root = Path(__file__).resolve().parents[2]
async def start(self, host: str = "0.0.0.0", port: int = 8766):
"""Start gateway server"""
@@ -87,6 +95,7 @@ class Gateway:
self._dashboard.start()
self.state_sync.load_state()
self.market_service.set_price_recorder(self.storage.record_price_point)
self.state_sync.update_state("status", "running")
self.state_sync.update_state("server_mode", self.mode)
self.state_sync.update_state("is_backtest", self.is_backtest)
@@ -94,6 +103,20 @@ class Gateway:
"is_mock_mode",
self.config.get("mock_mode", False),
)
self.state_sync.update_state("tickers", self.config.get("tickers", []))
self.state_sync.update_state(
"runtime_config",
{
"tickers": self.config.get("tickers", []),
"initial_cash": self.config.get(
"initial_cash",
self.storage.initial_cash,
),
"margin_requirement": self.config.get("margin_requirement"),
"max_comm_cycles": self.config.get("max_comm_cycles"),
"enable_memory": self.config.get("enable_memory", False),
},
)
self.state_sync.update_state(
"data_sources",
self._provider_router.get_usage_snapshot(),
@@ -159,7 +182,7 @@ class Gateway:
def state(self) -> Dict[str, Any]:
return self.state_sync.state
async def handle_client(self, websocket: WebSocketServerProtocol):
async def handle_client(self, websocket: ServerConnection):
"""Handle WebSocket client connection"""
async with self.lock:
self.connected_clients.add(websocket)
@@ -170,7 +193,7 @@ class Gateway:
async with self.lock:
self.connected_clients.discard(websocket)
async def _send_initial_state(self, websocket: WebSocketServerProtocol):
async def _send_initial_state(self, websocket: ServerConnection):
state_payload = self.state_sync.get_initial_state_payload(
include_dashboard=True,
)
@@ -198,7 +221,7 @@ class Gateway:
async def _handle_client_messages(
self,
websocket: WebSocketServerProtocol,
websocket: ServerConnection,
):
try:
async for message in websocket:
@@ -221,12 +244,104 @@ class Gateway:
await self._handle_start_backtest(data)
elif msg_type == "reload_runtime_assets":
await self._handle_reload_runtime_assets()
elif msg_type == "update_watchlist":
await self._handle_update_watchlist(websocket, data)
elif msg_type == "get_stock_history":
await self._handle_get_stock_history(websocket, data)
elif msg_type == "get_stock_explain_events":
await self._handle_get_stock_explain_events(websocket, data)
except websockets.ConnectionClosed:
pass
except json.JSONDecodeError:
pass
async def _handle_get_stock_history(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(
json.dumps(
{
"type": "stock_history_loaded",
"ticker": "",
"prices": [],
"source": None,
"error": "invalid ticker",
},
ensure_ascii=False,
),
)
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = self.state_sync.state.get("current_date")
if not end_date:
end_date = datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime(
"%Y-%m-%d",
)
prices = await asyncio.to_thread(
get_prices,
ticker,
start_date,
end_date,
)
usage_snapshot = self._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
await websocket.send(
json.dumps(
{
"type": "stock_history_loaded",
"ticker": ticker,
"prices": [price.model_dump() for price in prices][-120:],
"source": source,
"start_date": start_date,
"end_date": end_date,
},
ensure_ascii=False,
default=str,
),
)
async def _handle_get_stock_explain_events(
self,
websocket: ServerConnection,
data: Dict[str, Any],
):
ticker = normalize_symbol(data.get("ticker", ""))
snapshot = self.storage.runtime_db.get_stock_explain_snapshot(ticker)
await websocket.send(
json.dumps(
{
"type": "stock_explain_events_loaded",
"ticker": ticker,
"events": snapshot.get("events", []),
"signals": snapshot.get("signals", []),
"trades": snapshot.get("trades", []),
},
ensure_ascii=False,
default=str,
),
)
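A minimal client-side sketch for the new handlers (assumes the gateway's default port 8766 from `start()`; the message type names come straight from the dispatch table above):

import asyncio
import json

import websockets

async def main():
    async with websockets.connect("ws://localhost:8766") as ws:
        await ws.send(json.dumps({
            "type": "get_stock_history",
            "ticker": "AAPL",
            "lookback_days": 30,
        }))
        while True:
            reply = json.loads(await ws.recv())
            if reply.get("type") == "stock_history_loaded":
                print(reply["source"], len(reply["prices"]))
                break

asyncio.run(main())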
async def _handle_start_backtest(self, data: Dict[str, Any]):
if not self.is_backtest:
return
@@ -239,8 +354,15 @@ class Gateway:
self._backtest_task = task
async def _handle_reload_runtime_assets(self):
"""Reload prompt assets and active skills without restarting the server."""
result = self.pipeline.reload_runtime_assets()
"""Reload prompt, skills, and safe runtime config without restart."""
config_name = self.config.get("config_name", "default")
runtime_config = resolve_runtime_config(
project_root=self._project_root,
config_name=config_name,
enable_memory=self.config.get("enable_memory", False),
)
result = self.pipeline.reload_runtime_assets(runtime_config=runtime_config)
runtime_updates = self._apply_runtime_config(runtime_config)
await self.state_sync.on_system_message(
"Runtime assets reloaded.",
)
@@ -248,9 +370,174 @@ class Gateway:
{
"type": "runtime_assets_reloaded",
**result,
**runtime_updates,
},
)
async def _handle_update_watchlist(
self,
websocket: ServerConnection,
data: Dict[str, Any],
) -> None:
"""Persist a new watchlist to BOOTSTRAP.md and hot-reload it."""
tickers = self._normalize_watchlist(data.get("tickers"))
if not tickers:
await websocket.send(
json.dumps(
{
"type": "error",
"message": "update_watchlist requires at least one valid ticker.",
},
ensure_ascii=False,
),
)
return
config_name = self.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=self._project_root,
config_name=config_name,
updates={"tickers": tickers},
)
await self.state_sync.on_system_message(
f"Watchlist updated: {', '.join(tickers)}",
)
await self.broadcast(
{
"type": "watchlist_updated",
"config_name": config_name,
"tickers": tickers,
},
)
await self._handle_reload_runtime_assets()
@staticmethod
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
"""Parse watchlist payloads from websocket messages."""
if raw_tickers is None:
return []
if isinstance(raw_tickers, str):
candidates = raw_tickers.split(",")
elif isinstance(raw_tickers, list):
candidates = raw_tickers
else:
candidates = [raw_tickers]
tickers: List[str] = []
for candidate in candidates:
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
if symbol and symbol not in tickers:
tickers.append(symbol)
return tickers
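Illustrative inputs and outputs for the parser above, assuming `normalize_symbol` upper-cases and strips whitespace (its exact rules live in backend.data.provider_utils):

Gateway._normalize_watchlist("aapl, msft, aapl")
# -> ["AAPL", "MSFT"]  (comma split, de-duplicated, order preserved)
Gateway._normalize_watchlist(['"NVDA"', "tsla "])
# -> ["NVDA", "TSLA"]  (stray quotes and padding stripped)
Gateway._normalize_watchlist(None)
# -> []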
def _apply_runtime_config(
self,
runtime_config: Dict[str, Any],
) -> Dict[str, Any]:
"""Apply runtime config to gateway-owned services and state."""
warnings: List[str] = []
ticker_changes = self.market_service.update_tickers(
runtime_config.get("tickers", []),
)
self.config["tickers"] = ticker_changes["active"]
self.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
self.config["max_comm_cycles"] = self.pipeline.max_comm_cycles
pm_apply_result = self.pipeline.pm.apply_runtime_portfolio_config(
margin_requirement=runtime_config["margin_requirement"],
)
self.config["margin_requirement"] = self.pipeline.pm.portfolio.get(
"margin_requirement",
runtime_config["margin_requirement"],
)
requested_initial_cash = float(runtime_config["initial_cash"])
current_initial_cash = float(self.storage.initial_cash)
initial_cash_applied = requested_initial_cash == current_initial_cash
if not initial_cash_applied:
if (
self.storage.can_apply_initial_cash()
and self.pipeline.pm.can_apply_initial_cash()
):
initial_cash_applied = self.storage.apply_initial_cash(
requested_initial_cash,
)
if initial_cash_applied:
self.pipeline.pm.apply_runtime_portfolio_config(
initial_cash=requested_initial_cash,
)
self.config["initial_cash"] = self.storage.initial_cash
else:
warnings.append(
"initial_cash changed in BOOTSTRAP.md but was not applied "
"because the run already has positions, margin usage, or trades.",
)
requested_enable_memory = bool(runtime_config["enable_memory"])
current_enable_memory = bool(self.config.get("enable_memory", False))
if requested_enable_memory != current_enable_memory:
warnings.append(
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
"because long-term memory contexts are created at startup.",
)
self._sync_runtime_state()
return {
"runtime_config_requested": runtime_config,
"runtime_config_applied": {
"tickers": list(self.config.get("tickers", [])),
"initial_cash": self.storage.initial_cash,
"margin_requirement": self.config["margin_requirement"],
"max_comm_cycles": self.config["max_comm_cycles"],
"enable_memory": self.config.get("enable_memory", False),
},
"runtime_config_status": {
"tickers": True,
"initial_cash": initial_cash_applied,
"margin_requirement": pm_apply_result["margin_requirement"],
"max_comm_cycles": True,
"enable_memory": requested_enable_memory == current_enable_memory,
},
"ticker_changes": ticker_changes,
"runtime_config_warnings": warnings,
}
def _sync_runtime_state(self) -> None:
"""Refresh persisted state and dashboard after runtime config changes."""
self.state_sync.update_state("tickers", self.config.get("tickers", []))
self.state_sync.update_state(
"runtime_config",
{
"tickers": self.config.get("tickers", []),
"initial_cash": self.storage.initial_cash,
"margin_requirement": self.config.get("margin_requirement"),
"max_comm_cycles": self.config.get("max_comm_cycles"),
"enable_memory": self.config.get("enable_memory", False),
},
)
self.storage.update_server_state_from_dashboard(self.state_sync.state)
self.state_sync.save_state()
self._dashboard.tickers = list(self.config.get("tickers", []))
self._dashboard.initial_cash = self.storage.initial_cash
self._dashboard.enable_memory = bool(
self.config.get("enable_memory", False),
)
summary = self.storage.load_file("summary") or {}
holdings = self.storage.load_file("holdings") or []
trades = self.storage.load_file("trades") or []
self._dashboard.update(
portfolio=summary,
holdings=holdings,
trades=trades,
)
async def broadcast(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients"""
if not self.connected_clients:
@@ -269,7 +556,7 @@ class Gateway:
async def _send_to_client(
self,
client: WebSocketServerProtocol,
client: ServerConnection,
message: str,
):
try:

View File

@@ -54,6 +54,7 @@ class MarketService:
self.running = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._broadcast_func: Optional[Callable] = None
self._price_record_func: Optional[Callable[..., None]] = None
self._price_manager: Optional[Any] = None
self._current_date: Optional[str] = None
@@ -92,6 +93,10 @@ class MarketService:
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
)
def set_price_recorder(self, recorder: Optional[Callable[..., None]]):
"""Register an optional callback for persisting runtime price points."""
self._price_record_func = recorder
def _make_price_callback(self) -> Callable:
"""Create thread-safe price callback"""
@@ -169,6 +174,24 @@ class MarketService:
((price - open_price) / open_price) * 100 if open_price > 0 else 0
)
if self._price_record_func:
try:
self._price_record_func(
ticker=symbol,
timestamp=str(price_data.get("timestamp") or datetime.now().isoformat()),
price=float(price),
open_price=float(open_price) if open_price is not None else None,
ret=float(ret),
source=self.mode_name.lower(),
meta=price_data,
)
except Exception as exc:
logger.warning(
"Failed to record price point for %s: %s",
symbol,
exc,
)
await self._broadcast_func(
{
"type": "price_update",
@@ -205,6 +228,43 @@ class MarketService:
self._loop = None
self._broadcast_func = None
def update_tickers(self, tickers: List[str]) -> Dict[str, List[str]]:
"""Hot-update subscribed tickers without restarting the service."""
normalized: List[str] = []
for ticker in tickers:
symbol = normalize_symbol(ticker)
if symbol and symbol not in normalized:
normalized.append(symbol)
previous = list(self.tickers)
removed = [ticker for ticker in previous if ticker not in normalized]
added = [ticker for ticker in normalized if ticker not in previous]
self.tickers = normalized
if self._price_manager:
if removed:
self._price_manager.unsubscribe(removed)
if added:
if self.mock_mode:
self._price_manager.subscribe(
added,
base_prices={ticker: 100.0 for ticker in added},
)
else:
self._price_manager.subscribe(added)
if self.backtest_mode and self._current_date:
self._price_manager.set_date(self._current_date)
for ticker in removed:
self.cache.pop(ticker, None)
return {
"added": added,
"removed": removed,
"active": list(self.tickers),
}
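The returned diff makes the hot-swap observable to callers. For example, with ["AAPL", "MSFT"] currently subscribed (and again assuming `normalize_symbol` upper-cases):

changes = market_service.update_tickers(["aapl", "NVDA"])
# changes == {
#     "added": ["NVDA"],
#     "removed": ["MSFT"],
#     "active": ["AAPL", "NVDA"],
# }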
# Backtest methods
def set_backtest_date(self, date: str):
"""Set current backtest date"""

View File

@@ -0,0 +1,388 @@
# -*- coding: utf-8 -*-
"""Run-scoped SQLite storage for query-oriented runtime history."""
from __future__ import annotations
import hashlib
import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, Iterable, Optional
SCHEMA = """
CREATE TABLE IF NOT EXISTS events (
id TEXT PRIMARY KEY,
event_type TEXT NOT NULL,
timestamp TEXT,
agent_id TEXT,
agent_name TEXT,
ticker TEXT,
title TEXT,
content TEXT,
payload_json TEXT NOT NULL,
run_date TEXT
);
CREATE INDEX IF NOT EXISTS idx_events_type_time ON events(event_type, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_events_ticker_time ON events(ticker, timestamp DESC);
CREATE TABLE IF NOT EXISTS trades (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
side TEXT,
qty REAL,
price REAL,
timestamp TEXT,
trading_date TEXT,
agent_id TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_trades_ticker_time ON trades(ticker, timestamp DESC);
CREATE TABLE IF NOT EXISTS signals (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
agent_id TEXT,
agent_name TEXT,
role TEXT,
signal TEXT,
confidence REAL,
reasoning_json TEXT,
real_return REAL,
is_correct TEXT,
trade_date TEXT,
created_at TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_signals_ticker_date ON signals(ticker, trade_date DESC);
CREATE INDEX IF NOT EXISTS idx_signals_agent_date ON signals(agent_id, trade_date DESC);
CREATE TABLE IF NOT EXISTS price_points (
id TEXT PRIMARY KEY,
ticker TEXT NOT NULL,
timestamp TEXT NOT NULL,
price REAL NOT NULL,
open_price REAL,
ret REAL,
source TEXT,
meta_json TEXT
);
CREATE INDEX IF NOT EXISTS idx_price_points_ticker_time ON price_points(ticker, timestamp DESC);
"""
def _json_dumps(value: Any) -> str:
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
def _hash_key(*parts: Any) -> str:
raw = "::".join("" if part is None else str(part) for part in parts)
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
class RuntimeDb:
"""Small SQLite helper for append-mostly runtime data."""
def __init__(self, db_path: Path):
self.db_path = Path(db_path)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._init_db()
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self.db_path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
return conn
def _init_db(self):
with self._connect() as conn:
conn.executescript(SCHEMA)
def insert_event(self, event: Dict[str, Any]):
payload = dict(event or {})
if not payload:
return
event_id = payload.get("id") or _hash_key(
payload.get("type"),
payload.get("timestamp"),
payload.get("agentId") or payload.get("agent_id"),
payload.get("content"),
payload.get("title"),
)
ticker = payload.get("ticker")
if not ticker and isinstance(payload.get("tickers"), list) and len(payload["tickers"]) == 1:
ticker = payload["tickers"][0]
with self._connect() as conn:
conn.execute(
"""
INSERT OR IGNORE INTO events
(id, event_type, timestamp, agent_id, agent_name, ticker, title, content, payload_json, run_date)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
event_id,
payload.get("type"),
payload.get("timestamp"),
payload.get("agentId") or payload.get("agent_id"),
payload.get("agentName") or payload.get("agent_name"),
ticker,
payload.get("title"),
payload.get("content"),
_json_dumps(payload),
payload.get("date") or payload.get("trading_date") or payload.get("run_date"),
),
)
def upsert_trade(self, trade: Dict[str, Any]):
payload = dict(trade or {})
if not payload:
return
trade_id = payload.get("id") or _hash_key(
payload.get("ticker"),
payload.get("timestamp") or payload.get("ts"),
payload.get("side"),
payload.get("qty"),
payload.get("price"),
)
with self._connect() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO trades
(id, ticker, side, qty, price, timestamp, trading_date, agent_id, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
trade_id,
payload.get("ticker"),
payload.get("side"),
payload.get("qty"),
payload.get("price"),
payload.get("timestamp") or payload.get("ts"),
payload.get("trading_date"),
payload.get("agentId") or payload.get("agent_id"),
_json_dumps(payload),
),
)
def upsert_signal(self, signal: Dict[str, Any], *, agent_id: str, agent_name: str, role: str):
payload = dict(signal or {})
ticker = payload.get("ticker")
if not ticker:
return
signal_id = _hash_key(
agent_id,
ticker,
payload.get("date"),
payload.get("signal"),
payload.get("confidence"),
)
with self._connect() as conn:
conn.execute(
"""
INSERT OR REPLACE INTO signals
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
real_return, is_correct, trade_date, created_at, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
signal_id,
ticker,
agent_id,
agent_name,
role,
payload.get("signal"),
payload.get("confidence"),
_json_dumps(payload.get("reasoning")),
payload.get("real_return"),
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
payload.get("date"),
payload.get("created_at") or payload.get("date"),
_json_dumps(payload),
),
)
def replace_signals_for_leaderboard(self, leaderboard: Iterable[Dict[str, Any]]):
with self._connect() as conn:
conn.execute("DELETE FROM signals")
for agent in leaderboard:
agent_id = agent.get("agentId")
agent_name = agent.get("name")
role = agent.get("role")
for signal in agent.get("signals", []) or []:
payload = dict(signal or {})
ticker = payload.get("ticker")
if not ticker:
continue
signal_id = _hash_key(
agent_id,
ticker,
payload.get("date"),
payload.get("signal"),
payload.get("confidence"),
)
conn.execute(
"""
INSERT INTO signals
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
real_return, is_correct, trade_date, created_at, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
signal_id,
ticker,
agent_id,
agent_name,
role,
payload.get("signal"),
payload.get("confidence"),
_json_dumps(payload.get("reasoning")),
payload.get("real_return"),
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
payload.get("date"),
payload.get("created_at") or payload.get("date"),
_json_dumps(payload),
),
)
def insert_price_point(
self,
*,
ticker: str,
timestamp: str,
price: float,
open_price: Optional[float] = None,
ret: Optional[float] = None,
source: Optional[str] = None,
meta: Optional[Dict[str, Any]] = None,
):
price_id = _hash_key(ticker, timestamp, price, open_price, ret)
with self._connect() as conn:
conn.execute(
"""
INSERT OR IGNORE INTO price_points
(id, ticker, timestamp, price, open_price, ret, source, meta_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
(
price_id,
ticker,
timestamp,
price,
open_price,
ret,
source,
_json_dumps(meta or {}),
),
)
def get_stock_explain_snapshot(
self,
ticker: str,
*,
limit_events: int = 24,
limit_trades: int = 12,
limit_signals: int = 12,
) -> Dict[str, list[Dict[str, Any]]]:
"""Fetch query-oriented history for a single ticker."""
symbol = str(ticker or "").strip().upper()
if not symbol:
return {"events": [], "trades": [], "signals": []}
with self._connect() as conn:
trade_rows = conn.execute(
"""
SELECT * FROM trades
WHERE ticker = ?
ORDER BY timestamp DESC
LIMIT ?
""",
(symbol, limit_trades),
).fetchall()
signal_rows = conn.execute(
"""
SELECT * FROM signals
WHERE ticker = ?
ORDER BY trade_date DESC, created_at DESC
LIMIT ?
""",
(symbol, limit_signals),
).fetchall()
event_rows = conn.execute(
"""
SELECT * FROM events
WHERE payload_json LIKE ? OR content LIKE ? OR title LIKE ? OR ticker = ?
ORDER BY timestamp DESC
LIMIT ?
""",
(f"%{symbol}%", f"%{symbol}%", f"%{symbol}%", symbol, limit_events * 3),
).fetchall()
normalized_events = []
seen_event_ids: set[str] = set()
for row in event_rows:
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
content = str(row["content"] or payload.get("content") or "")
title = str(row["title"] or payload.get("title") or "")
if symbol not in f"{title} {content}".upper() and str(row["ticker"] or "").upper() != symbol:
continue
event_id = row["id"]
if event_id in seen_event_ids:
continue
seen_event_ids.add(event_id)
normalized_events.append(
{
"id": event_id,
"type": "mention",
"timestamp": row["timestamp"],
"title": title or f"{row['agent_name'] or 'Unknown agent'} mentioned {symbol}",
"meta": payload.get("conferenceTitle")
or payload.get("feedType")
or row["event_type"],
"body": content,
"tone": "neutral",
"agent": row["agent_name"] or payload.get("agentName") or payload.get("agent"),
},
)
if len(normalized_events) >= limit_events:
break
normalized_trades = [
{
"id": row["id"],
"type": "trade",
"timestamp": row["timestamp"],
"title": f"{row['side']} {int(row['qty'] or 0)}",
"meta": "Trade execution",
"body": f"Filled at ${float(row['price'] or 0):.2f}",
"tone": "positive" if row["side"] == "LONG" else "negative" if row["side"] == "SHORT" else "neutral",
}
for row in trade_rows
]
normalized_signals = [
{
"id": row["id"],
"type": "signal",
"timestamp": f"{row['trade_date']}T08:00:00" if row["trade_date"] else row["created_at"],
"title": f"{row['agent_name']} issued a {row['signal'] or 'neutral'} signal",
"meta": row["role"],
"body": (
f"Ex-post return {float(row['real_return']) * 100:+.2f}%"
if row["real_return"] is not None
else "This signal has not yet been evaluated ex post"
),
"tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral",
}
for row in signal_rows
]
return {
"events": normalized_events,
"trades": normalized_trades,
"signals": normalized_signals,
}
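A quick usage sketch of the helper above (the path is illustrative; in StorageService it lives under the run's state directory as runtime.db):

from pathlib import Path

db = RuntimeDb(Path("runs/default/state/runtime.db"))
db.upsert_trade({
    "ticker": "AAPL",
    "side": "LONG",
    "qty": 10,
    "price": 187.5,
    "timestamp": "2026-03-16T10:00:00",
})
snapshot = db.get_stock_explain_snapshot("aapl")  # lookup upper-cases the symbol
print(len(snapshot["trades"]), len(snapshot["signals"]))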

View File

@@ -10,6 +10,8 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from .runtime_db import RuntimeDb
logger = logging.getLogger(__name__)
@@ -61,6 +63,7 @@ class StorageService:
self.state_dir = self.dashboard_dir.parent / "state"
self.state_dir.mkdir(parents=True, exist_ok=True)
self.server_state_file = self.state_dir / "server_state.json"
self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
# Feed history (for agent messages)
self.max_feed_history = 200
@@ -114,6 +117,11 @@ class StorageService:
try:
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
if file_type == "leaderboard" and isinstance(data, list):
self.runtime_db.replace_signals_for_leaderboard(data)
elif file_type == "trades" and isinstance(data, list):
for trade in data:
self.runtime_db.upsert_trade(trade)
except Exception as e:
logger.error(f"Failed to save {file_type}.json: {e}")
@@ -211,6 +219,7 @@ class StorageService:
try:
with open(self.internal_state_file, "w", encoding="utf-8") as f:
json.dump(state, f, indent=2, ensure_ascii=False)
self._sync_price_history_to_db(state.get("price_history", {}))
except Exception as e:
logger.error(f"Failed to save internal state: {e}")
@@ -231,6 +240,41 @@ class StorageService:
"margin_requirement": 0.25, # Default 25% margin requirement
}
@staticmethod
def _portfolio_is_pristine(portfolio_state: Dict[str, Any]) -> bool:
"""Return whether the persisted portfolio can be safely rebased."""
positions = portfolio_state.get("positions", {})
has_positions = any(
position.get("long", 0) or position.get("short", 0)
for position in positions.values()
)
margin_used = float(portfolio_state.get("margin_used", 0.0) or 0.0)
return not has_positions and margin_used == 0.0
def can_apply_initial_cash(self) -> bool:
"""Only allow initial cash changes before the run has traded."""
state = self.load_internal_state()
if not self._portfolio_is_pristine(state.get("portfolio_state", {})):
return False
if state.get("all_trades"):
return False
return len(state.get("equity_history", [])) <= 1
def apply_initial_cash(self, initial_cash: float) -> bool:
"""Rebase storage state to a new initial cash when the run is pristine."""
if not self.can_apply_initial_cash():
return False
self.initial_cash = float(initial_cash)
if self.internal_state_file.exists():
self.internal_state_file.unlink()
self.initialize_empty_dashboard()
state = self.load_server_state()
self.update_server_state_from_dashboard(state)
self.save_server_state(state)
return True
def save_portfolio_state(self, portfolio: Dict[str, Any]):
"""
Save portfolio state to internal state
@@ -750,6 +794,7 @@ class StorageService:
"last_day_history": [],
"trading_days_total": 0,
"trading_days_completed": 0,
"price_history": {},
}
if not self.server_state_file.exists():
@@ -771,6 +816,11 @@ class StorageService:
)
logger.info(f"Trades: {len(saved_state.get('trades', []))} records")
for event in saved_state.get("feed_history", []):
self.runtime_db.insert_event(event)
for trade in saved_state.get("trades", []):
self.runtime_db.upsert_trade(trade)
return saved_state
def save_server_state(self, state: Dict[str, Any]):
@@ -852,6 +902,7 @@ class StorageService:
state["feed_history"] = []
state["feed_history"].insert(0, feed_msg)
self.runtime_db.insert_event(feed_msg)
# Trim to max size
if len(state["feed_history"]) > self.max_feed_history:
@@ -861,6 +912,69 @@ class StorageService:
return True
def record_price_point(
self,
*,
ticker: str,
timestamp: str,
price: float,
open_price: Optional[float] = None,
ret: Optional[float] = None,
source: Optional[str] = None,
meta: Optional[Dict[str, Any]] = None,
):
"""Persist a runtime price point for later query-oriented reads."""
if not ticker or not timestamp:
return
try:
self.runtime_db.insert_price_point(
ticker=ticker,
timestamp=timestamp,
price=price,
open_price=open_price,
ret=ret,
source=source,
meta=meta,
)
except Exception as exc:
logger.warning("Failed to record price point for %s: %s", ticker, exc)
def _sync_price_history_to_db(self, price_history: Dict[str, Any]):
"""Backfill structured price points from serialized internal state."""
if not isinstance(price_history, dict):
return
for ticker, points in price_history.items():
if not ticker or not isinstance(points, list):
continue
for point in points:
if isinstance(point, (list, tuple)) and len(point) >= 2:
timestamp, price = point[0], point[1]
try:
self.record_price_point(
ticker=str(ticker),
timestamp=str(timestamp),
price=float(price),
)
except (TypeError, ValueError):
continue
elif isinstance(point, dict):
timestamp = point.get("timestamp") or point.get("label") or point.get("date")
price = point.get("price") or point.get("close") or point.get("value")
if not timestamp or price is None:
continue
try:
self.record_price_point(
ticker=str(ticker),
timestamp=str(timestamp),
price=float(price),
open_price=point.get("open"),
ret=point.get("ret"),
source=point.get("source"),
meta=point,
)
except (TypeError, ValueError):
continue
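For reference, both point shapes the backfill accepts (keys are the ones probed above; `storage` stands in for a StorageService instance, values illustrative):

storage._sync_price_history_to_db({
    "AAPL": [
        ("2026-03-16T10:00:00", 187.5),  # (timestamp, price) pair
        {
            "timestamp": "2026-03-16T10:01:00",  # dict form with optional
            "price": 188.0,                      # open/ret/source fields
            "open": 186.9,
            "ret": 0.59,
        },
    ],
})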
def _get_default_stats(self) -> Dict[str, Any]:
"""Get default stats structure"""
return {
@@ -889,6 +1003,7 @@ class StorageService:
stats = self.load_file("stats") or self._get_default_stats()
trades = self.load_file("trades") or []
leaderboard = self.load_file("leaderboard") or []
internal_state = self.load_internal_state()
# Update state
state["portfolio"] = {
@@ -910,6 +1025,9 @@ class StorageService:
state["stats"] = stats
state["trades"] = trades
state["leaderboard"] = leaderboard
state["price_history"] = internal_state.get("price_history", {})
self.runtime_db.replace_signals_for_leaderboard(leaderboard)
self._sync_price_history_to_db(state["price_history"])
# ========== Live Returns Tracking ==========

View File

@@ -7,6 +7,7 @@ Returns human-readable text format for easy LLM consumption.
"""
# flake8: noqa: E501
# pylint: disable=C0301,W0613
import ast
import json
import logging
import traceback
@@ -20,6 +21,7 @@ import pandas as pd
from agentscope.message import TextBlock
from agentscope.tool import ToolResponse
from backend.data.provider_utils import normalize_symbol
from backend.tools.data_tools import (
get_company_news,
get_financial_metrics,
@@ -53,6 +55,16 @@ def _parse_tickers(tickers: Union[str, List[str], None]) -> List[str]:
Returns:
List of stock tickers.
"""
def _sanitize(values: List[object]) -> List[str]:
cleaned: List[str] = []
for value in values:
if value is None:
continue
symbol = normalize_symbol(str(value).strip().strip("\"'"))
if symbol and symbol not in cleaned:
cleaned.append(symbol)
return cleaned
if tickers is None:
return []
@@ -60,17 +72,22 @@ def _parse_tickers(tickers: Union[str, List[str], None]) -> List[str]:
try:
parsed = json.loads(tickers)
if isinstance(parsed, list):
return parsed
# If it's a single string, wrap in list
return [parsed]
return _sanitize(parsed)
return _sanitize([parsed])
except json.JSONDecodeError:
# If not valid JSON, treat as comma-separated string
return [t.strip() for t in tickers.split(",") if t.strip()]
try:
parsed = ast.literal_eval(tickers)
if isinstance(parsed, list):
return _sanitize(parsed)
return _sanitize([parsed])
except (SyntaxError, ValueError):
pass
return _sanitize(tickers.split(","))
if isinstance(tickers, list):
return tickers
return _sanitize(tickers)
return []
return _sanitize([tickers])
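The parser now degrades through three strategies: JSON, a Python literal via ast.literal_eval, then a comma split; every path funnels through `_sanitize`. Illustrative results, assuming `normalize_symbol` upper-cases:

_parse_tickers('["aapl", "msft"]')  # JSON list      -> ["AAPL", "MSFT"]
_parse_tickers("['aapl', 'msft']")  # Python literal -> ["AAPL", "MSFT"]
_parse_tickers("aapl, msft, aapl")  # comma fallback -> ["AAPL", "MSFT"]
_parse_tickers(["NVDA", "nvda"])    # list input     -> ["NVDA"]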
def _safe_float(value, default=0.0) -> float:
@@ -350,6 +367,7 @@ def get_financial_metrics_tool(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [
f"=== Comprehensive Financial Metrics ({current_date}, {period}) ===\n",
]

View File

@@ -96,13 +96,19 @@ def get_prices(
list[Price]: List of Price objects
"""
ticker = normalize_symbol(ticker)
if not ticker:
return []
cached_sources = _router.price_sources()
for source in cached_sources:
cache_key = f"{ticker}_{start_date}_{end_date}_{source}"
if cached_data := _cache.get_prices(cache_key):
return [Price(**price) for price in cached_data]
prices, data_source = _router.get_prices(ticker, start_date, end_date)
try:
prices, data_source = _router.get_prices(ticker, start_date, end_date)
except Exception as exc:
logger.info("Price lookup failed for %s: %s", ticker, exc)
return []
if not prices:
return []
@@ -133,17 +139,23 @@ def get_financial_metrics(
list[FinancialMetrics]: List of financial metrics
"""
ticker = normalize_symbol(ticker)
if not ticker:
return []
for source in _router.api_sources():
cache_key = f"{ticker}_{period}_{end_date}_{limit}_{source}"
if cached_data := _cache.get_financial_metrics(cache_key):
return [FinancialMetrics(**metric) for metric in cached_data]
financial_metrics, data_source = _router.get_financial_metrics(
ticker=ticker,
end_date=end_date,
period=period,
limit=limit,
)
try:
financial_metrics, data_source = _router.get_financial_metrics(
ticker=ticker,
end_date=end_date,
period=period,
limit=limit,
)
except Exception as exc:
logger.info("Financial metrics lookup failed for %s: %s", ticker, exc)
return []
if not financial_metrics:
return []
@@ -169,6 +181,8 @@ def search_line_items(
"""
try:
ticker = normalize_symbol(ticker)
if not ticker:
return []
return _router.search_line_items(
ticker=ticker,
line_items=line_items,
@@ -190,6 +204,8 @@ def get_insider_trades(
) -> list[InsiderTrade]:
"""Fetch insider trades from cache or API."""
ticker = normalize_symbol(ticker)
if not ticker:
return []
for source in _router.api_sources():
cache_key = (
f"{ticker}_{start_date or 'none'}_{end_date}_{limit}_{source}"
@@ -197,12 +213,16 @@ def get_insider_trades(
if cached_data := _cache.get_insider_trades(cache_key):
return [InsiderTrade(**trade) for trade in cached_data]
all_trades, data_source = _router.get_insider_trades(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
try:
all_trades, data_source = _router.get_insider_trades(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
except Exception as exc:
logger.info("Insider trades lookup failed for %s: %s", ticker, exc)
return []
if not all_trades:
return []
@@ -219,6 +239,8 @@ def get_company_news(
) -> list[CompanyNews]:
"""Fetch company news from cache or API."""
ticker = normalize_symbol(ticker)
if not ticker:
return []
for source in _router.api_sources():
cache_key = (
f"{ticker}_{start_date or 'none'}_{end_date}_{limit}_{source}"
@@ -226,12 +248,16 @@ def get_company_news(
if cached_data := _cache.get_company_news(cache_key):
return [CompanyNews(**news) for news in cached_data]
all_news, data_source = _router.get_company_news(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
try:
all_news, data_source = _router.get_company_news(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
except Exception as exc:
logger.info("Company news lookup failed for %s: %s", ticker, exc)
return []
if not all_news:
return []
@@ -243,6 +269,8 @@ def get_company_news(
def get_market_cap(ticker: str, end_date: str) -> float | None:
"""Fetch market cap from the API. Finnhub values are converted from millions."""
ticker = normalize_symbol(ticker)
if not ticker:
return None
def _metrics_lookup(symbol: str, date: str):
for source in _router.api_sources():
@@ -256,11 +284,15 @@ def get_market_cap(ticker: str, end_date: str) -> float | None:
limit=10,
)
market_cap, _ = _router.get_market_cap(
ticker=ticker,
end_date=end_date,
metrics_lookup=_metrics_lookup,
)
try:
market_cap, _ = _router.get_market_cap(
ticker=ticker,
end_date=end_date,
metrics_lookup=_metrics_lookup,
)
except Exception as exc:
logger.info("Market cap lookup failed for %s: %s", ticker, exc)
return None
return market_cap