确认PokieTicker新闻库数据源
This commit is contained in:
@@ -48,15 +48,19 @@ class AnalystAgent(ReActAgent):
|
||||
f"Must be one of: {list(ANALYST_TYPES.keys())}",
|
||||
)
|
||||
|
||||
self.analyst_type_key = analyst_type
|
||||
self.analyst_persona = ANALYST_TYPES[analyst_type]["display_name"]
|
||||
object.__setattr__(self, "analyst_type_key", analyst_type)
|
||||
object.__setattr__(
|
||||
self,
|
||||
"analyst_persona",
|
||||
ANALYST_TYPES[analyst_type]["display_name"],
|
||||
)
|
||||
|
||||
if agent_id is None:
|
||||
agent_id = analyst_type
|
||||
self.agent_id = agent_id
|
||||
object.__setattr__(self, "agent_id", agent_id)
|
||||
|
||||
self.config = config or {}
|
||||
self.toolkit = toolkit
|
||||
object.__setattr__(self, "config", config or {})
|
||||
object.__setattr__(self, "toolkit", toolkit)
|
||||
sys_prompt = self._load_system_prompt()
|
||||
|
||||
kwargs = {
|
||||
@@ -125,4 +129,12 @@ class AnalystAgent(ReActAgent):
|
||||
self.config.get("config_name", "default"),
|
||||
active_skill_dirs=active_skill_dirs,
|
||||
)
|
||||
self.sys_prompt = self._load_system_prompt()
|
||||
self._apply_runtime_sys_prompt(self._load_system_prompt())
|
||||
|
||||
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
|
||||
"""Update the prompt used by future turns and the cached system msg."""
|
||||
self._sys_prompt = sys_prompt
|
||||
for msg, _marks in self.memory.content:
|
||||
if getattr(msg, "role", None) == "system":
|
||||
msg.content = sys_prompt
|
||||
break
|
||||
|
||||
@@ -38,21 +38,29 @@ class PMAgent(ReActAgent):
|
||||
toolkit_factory_kwargs: Optional[Dict[str, Any]] = None,
|
||||
toolkit: Optional[Toolkit] = None,
|
||||
):
|
||||
self.config = config or {}
|
||||
object.__setattr__(self, "config", config or {})
|
||||
|
||||
# Portfolio state
|
||||
self.portfolio = {
|
||||
"cash": initial_cash,
|
||||
"positions": {},
|
||||
"margin_used": 0.0,
|
||||
"margin_requirement": margin_requirement,
|
||||
}
|
||||
object.__setattr__(
|
||||
self,
|
||||
"portfolio",
|
||||
{
|
||||
"cash": initial_cash,
|
||||
"positions": {},
|
||||
"margin_used": 0.0,
|
||||
"margin_requirement": margin_requirement,
|
||||
},
|
||||
)
|
||||
|
||||
# Decisions made in current cycle
|
||||
self._decisions: Dict[str, Dict] = {}
|
||||
object.__setattr__(self, "_decisions", {})
|
||||
toolkit_factory_kwargs = toolkit_factory_kwargs or {}
|
||||
self._toolkit_factory = toolkit_factory
|
||||
self._toolkit_factory_kwargs = toolkit_factory_kwargs
|
||||
object.__setattr__(self, "_toolkit_factory", toolkit_factory)
|
||||
object.__setattr__(
|
||||
self,
|
||||
"_toolkit_factory_kwargs",
|
||||
toolkit_factory_kwargs,
|
||||
)
|
||||
|
||||
# Create toolkit after local state is ready so bound tool methods can be registered.
|
||||
if toolkit is None:
|
||||
@@ -65,7 +73,7 @@ class PMAgent(ReActAgent):
|
||||
)
|
||||
else:
|
||||
toolkit = self._create_toolkit()
|
||||
self.toolkit = toolkit
|
||||
object.__setattr__(self, "toolkit", toolkit)
|
||||
|
||||
sys_prompt = build_agent_system_prompt(
|
||||
agent_id=name,
|
||||
@@ -205,6 +213,42 @@ class PMAgent(ReActAgent):
|
||||
"""Update portfolio after external execution"""
|
||||
self.portfolio.update(portfolio)
|
||||
|
||||
def _has_open_positions(self) -> bool:
|
||||
"""Return whether the current portfolio still has non-zero positions."""
|
||||
for position in self.portfolio.get("positions", {}).values():
|
||||
if position.get("long", 0) or position.get("short", 0):
|
||||
return True
|
||||
return False
|
||||
|
||||
def can_apply_initial_cash(self) -> bool:
|
||||
"""Only allow cash rebasing before any positions or margin exist."""
|
||||
return (
|
||||
not self._has_open_positions()
|
||||
and float(self.portfolio.get("margin_used", 0.0) or 0.0) == 0.0
|
||||
)
|
||||
|
||||
def apply_runtime_portfolio_config(
|
||||
self,
|
||||
*,
|
||||
margin_requirement: Optional[float] = None,
|
||||
initial_cash: Optional[float] = None,
|
||||
) -> Dict[str, bool]:
|
||||
"""Apply safe run-time portfolio config updates."""
|
||||
result = {
|
||||
"margin_requirement": False,
|
||||
"initial_cash": False,
|
||||
}
|
||||
|
||||
if margin_requirement is not None:
|
||||
self.portfolio["margin_requirement"] = float(margin_requirement)
|
||||
result["margin_requirement"] = True
|
||||
|
||||
if initial_cash is not None and self.can_apply_initial_cash():
|
||||
self.portfolio["cash"] = float(initial_cash)
|
||||
result["initial_cash"] = True
|
||||
|
||||
return result
|
||||
|
||||
def reload_runtime_assets(self, active_skill_dirs: Optional[list] = None) -> None:
|
||||
"""Reload toolkit and system prompt from current run assets."""
|
||||
from .toolkit_factory import create_agent_toolkit
|
||||
@@ -221,8 +265,18 @@ class PMAgent(ReActAgent):
|
||||
owner=self,
|
||||
**toolkit_kwargs,
|
||||
)
|
||||
self.sys_prompt = build_agent_system_prompt(
|
||||
agent_id=self.name,
|
||||
config_name=self.config.get("config_name", "default"),
|
||||
toolkit=self.toolkit,
|
||||
self._apply_runtime_sys_prompt(
|
||||
build_agent_system_prompt(
|
||||
agent_id=self.name,
|
||||
config_name=self.config.get("config_name", "default"),
|
||||
toolkit=self.toolkit,
|
||||
),
|
||||
)
|
||||
|
||||
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
|
||||
"""Update the prompt used by future turns and the cached system msg."""
|
||||
self._sys_prompt = sys_prompt
|
||||
for msg, _marks in self.memory.content:
|
||||
if getattr(msg, "role", None) == "system":
|
||||
msg.content = sys_prompt
|
||||
break
|
||||
|
||||
@@ -39,12 +39,12 @@ class RiskAgent(ReActAgent):
|
||||
config: Configuration dictionary
|
||||
long_term_memory: Optional ReMeTaskLongTermMemory instance
|
||||
"""
|
||||
self.config = config or {}
|
||||
self.agent_id = name
|
||||
object.__setattr__(self, "config", config or {})
|
||||
object.__setattr__(self, "agent_id", name)
|
||||
|
||||
if toolkit is None:
|
||||
toolkit = Toolkit()
|
||||
self.toolkit = toolkit
|
||||
object.__setattr__(self, "toolkit", toolkit)
|
||||
|
||||
sys_prompt = self._load_system_prompt()
|
||||
|
||||
@@ -99,4 +99,12 @@ class RiskAgent(ReActAgent):
|
||||
self.config.get("config_name", "default"),
|
||||
active_skill_dirs=active_skill_dirs,
|
||||
)
|
||||
self.sys_prompt = self._load_system_prompt()
|
||||
self._apply_runtime_sys_prompt(self._load_system_prompt())
|
||||
|
||||
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
|
||||
"""Update the prompt used by future turns and the cached system msg."""
|
||||
self._sys_prompt = sys_prompt
|
||||
for msg, _marks in self.memory.content:
|
||||
if getattr(msg, "role", None) == "system":
|
||||
msg.content = sys_prompt
|
||||
break
|
||||
|
||||
@@ -62,6 +62,59 @@ class SkillsManager:
|
||||
|
||||
raise FileNotFoundError(f"Unknown skill: {skill_name}")
|
||||
|
||||
def _persist_runtime_edits(
|
||||
self,
|
||||
config_name: str,
|
||||
skill_name: str,
|
||||
active_dir: Path,
|
||||
) -> None:
|
||||
"""
|
||||
Persist run-time edits from active skills into customized skills.
|
||||
|
||||
This keeps active skill experiments from being lost on the next reload
|
||||
while still allowing the active directory to be re-synced cleanly.
|
||||
"""
|
||||
if not active_dir.exists():
|
||||
return
|
||||
|
||||
source_dir = self._resolve_source_dir(skill_name)
|
||||
if active_dir.resolve() == source_dir.resolve():
|
||||
return
|
||||
|
||||
if not self._directories_match(active_dir, source_dir):
|
||||
customized_dir = self.customized_root / skill_name
|
||||
customized_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
if customized_dir.exists():
|
||||
shutil.rmtree(customized_dir)
|
||||
shutil.copytree(active_dir, customized_dir)
|
||||
|
||||
@staticmethod
|
||||
def _directories_match(left: Path, right: Path) -> bool:
|
||||
"""Compare two directory trees by file contents."""
|
||||
if not left.exists() or not right.exists():
|
||||
return False
|
||||
|
||||
left_items = sorted(
|
||||
path.relative_to(left)
|
||||
for path in left.rglob("*")
|
||||
)
|
||||
right_items = sorted(
|
||||
path.relative_to(right)
|
||||
for path in right.rglob("*")
|
||||
)
|
||||
if left_items != right_items:
|
||||
return False
|
||||
|
||||
for relative_path in left_items:
|
||||
left_path = left / relative_path
|
||||
right_path = right / relative_path
|
||||
if left_path.is_dir() != right_path.is_dir():
|
||||
return False
|
||||
if left_path.is_file():
|
||||
if left_path.read_bytes() != right_path.read_bytes():
|
||||
return False
|
||||
return True
|
||||
|
||||
def resolve_agent_skill_names(
|
||||
self,
|
||||
config_name: str,
|
||||
@@ -103,12 +156,22 @@ class SkillsManager:
|
||||
|
||||
for existing in active_root.iterdir():
|
||||
if existing.is_dir() and existing.name not in wanted:
|
||||
self._persist_runtime_edits(
|
||||
config_name=config_name,
|
||||
skill_name=existing.name,
|
||||
active_dir=existing,
|
||||
)
|
||||
shutil.rmtree(existing)
|
||||
|
||||
for skill_name in skill_names:
|
||||
source_dir = self._resolve_source_dir(skill_name)
|
||||
target_dir = active_root / skill_name
|
||||
if target_dir.exists():
|
||||
self._persist_runtime_edits(
|
||||
config_name=config_name,
|
||||
skill_name=skill_name,
|
||||
active_dir=target_dir,
|
||||
)
|
||||
shutil.rmtree(target_dir)
|
||||
shutil.copytree(source_dir, target_dir)
|
||||
synced_paths.append(target_dir)
|
||||
|
||||
@@ -49,9 +49,8 @@ def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None:
|
||||
config_name: Configuration name for the run
|
||||
auto_clean: If True, skip confirmation and clean automatically
|
||||
"""
|
||||
# logs_dir = get_project_root() / "logs"
|
||||
logs_dir = get_project_root()
|
||||
base_data_dir = logs_dir / config_name
|
||||
workspace_manager = WorkspaceManager(project_root=get_project_root())
|
||||
base_data_dir = workspace_manager.get_run_dir(config_name)
|
||||
|
||||
# Check if historical data exists
|
||||
if not base_data_dir.exists() or not any(base_data_dir.iterdir()):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Parse run-scoped BOOTSTRAP.md into structured configuration."""
|
||||
"""Parse run-scoped BOOTSTRAP.md into structured and runtime config."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -8,6 +8,8 @@ import re
|
||||
|
||||
import yaml
|
||||
|
||||
from backend.config.env_config import get_env_float, get_env_int, get_env_list
|
||||
|
||||
|
||||
BOOTSTRAP_FRONT_MATTER_RE = re.compile(
|
||||
r"^---\s*\n(.*?)\n---\s*\n?(.*)$",
|
||||
@@ -63,3 +65,84 @@ def get_bootstrap_config_for_run(
|
||||
return load_bootstrap_config(
|
||||
project_root / "runs" / config_name / "BOOTSTRAP.md",
|
||||
)
|
||||
|
||||
|
||||
def save_bootstrap_config(bootstrap_path: Path, config: BootstrapConfig) -> None:
|
||||
"""Persist structured bootstrap config back to BOOTSTRAP.md."""
|
||||
bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
values = config.values if isinstance(config.values, dict) else {}
|
||||
front_matter = yaml.safe_dump(
|
||||
values,
|
||||
allow_unicode=True,
|
||||
sort_keys=False,
|
||||
).strip()
|
||||
body = (config.prompt_body or "").strip()
|
||||
|
||||
content = f"---\n{front_matter}\n---"
|
||||
if body:
|
||||
content += f"\n\n{body}\n"
|
||||
else:
|
||||
content += "\n"
|
||||
|
||||
bootstrap_path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
def update_bootstrap_values_for_run(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
updates: Dict[str, Any],
|
||||
) -> BootstrapConfig:
|
||||
"""Patch selected front matter keys for a run and persist them."""
|
||||
bootstrap_path = project_root / "runs" / config_name / "BOOTSTRAP.md"
|
||||
existing = load_bootstrap_config(bootstrap_path)
|
||||
values = dict(existing.values)
|
||||
values.update(updates)
|
||||
updated = BootstrapConfig(values=values, prompt_body=existing.prompt_body)
|
||||
save_bootstrap_config(bootstrap_path, updated)
|
||||
return updated
|
||||
|
||||
|
||||
def _coerce_bool(value: Any) -> bool:
|
||||
"""Parse booleans from bootstrap-friendly string values."""
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip().lower()
|
||||
if normalized in {"1", "true", "yes", "on"}:
|
||||
return True
|
||||
if normalized in {"0", "false", "no", "off"}:
|
||||
return False
|
||||
return bool(value)
|
||||
|
||||
|
||||
def resolve_runtime_config(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
enable_memory: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Merge env defaults with run-scoped bootstrap front matter."""
|
||||
bootstrap = get_bootstrap_config_for_run(project_root, config_name)
|
||||
return {
|
||||
"tickers": bootstrap.get("tickers")
|
||||
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
|
||||
"initial_cash": float(
|
||||
bootstrap.get(
|
||||
"initial_cash",
|
||||
get_env_float("INITIAL_CASH", 100000.0),
|
||||
),
|
||||
),
|
||||
"margin_requirement": float(
|
||||
bootstrap.get(
|
||||
"margin_requirement",
|
||||
get_env_float("MARGIN_REQUIREMENT", 0.0),
|
||||
),
|
||||
),
|
||||
"max_comm_cycles": int(
|
||||
bootstrap.get(
|
||||
"max_comm_cycles",
|
||||
get_env_int("MAX_COMM_CYCLES", 2),
|
||||
),
|
||||
),
|
||||
"enable_memory": bool(enable_memory)
|
||||
or _coerce_bool(bootstrap.get("enable_memory", False)),
|
||||
}
|
||||
|
||||
@@ -226,12 +226,18 @@ class TradingPipeline:
|
||||
"settlement_result": settlement_result,
|
||||
}
|
||||
|
||||
def reload_runtime_assets(self) -> Dict[str, Any]:
|
||||
"""Reload prompt assets, bootstrap config, and active skills for all agents."""
|
||||
def reload_runtime_assets(
|
||||
self,
|
||||
runtime_config: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Reload prompt assets and safe in-process runtime settings."""
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import load_agent_profiles
|
||||
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
if runtime_config and "max_comm_cycles" in runtime_config:
|
||||
self.max_comm_cycles = int(runtime_config["max_comm_cycles"])
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
profiles = load_agent_profiles()
|
||||
active_skill_map = skills_manager.prepare_active_skills(
|
||||
@@ -262,6 +268,7 @@ class TradingPipeline:
|
||||
agent_id: [path.name for path in paths]
|
||||
for agent_id, paths in active_skill_map.items()
|
||||
},
|
||||
"max_comm_cycles": self.max_comm_cycles,
|
||||
}
|
||||
|
||||
async def _clear_all_agent_memory(self):
|
||||
|
||||
@@ -438,6 +438,8 @@ class StateSync:
|
||||
"server_mode": self._state.get("server_mode", "live"),
|
||||
"is_mock_mode": self._state.get("is_mock_mode", False),
|
||||
"is_backtest": self._state.get("is_backtest", False),
|
||||
"tickers": self._state.get("tickers"),
|
||||
"runtime_config": self._state.get("runtime_config"),
|
||||
"feed_history": self._state.get("feed_history", []),
|
||||
"current_date": self._state.get("current_date"),
|
||||
"trading_days_total": self._state.get("trading_days_total", 0),
|
||||
@@ -452,6 +454,7 @@ class StateSync:
|
||||
"portfolio": self._state.get("portfolio", {}),
|
||||
"realtime_prices": self._state.get("realtime_prices", {}),
|
||||
"data_sources": self._state.get("data_sources", {}),
|
||||
"price_history": self._state.get("price_history", {}),
|
||||
}
|
||||
|
||||
if include_dashboard:
|
||||
|
||||
@@ -30,6 +30,25 @@ logger = logging.getLogger(__name__)
|
||||
_DATA_DIR = Path(__file__).parent / "ret_data"
|
||||
|
||||
|
||||
def _format_provider_error(exc: Exception) -> str:
|
||||
"""Condense common provider failures into short, readable messages."""
|
||||
message = str(exc).strip().replace("\n", " ")
|
||||
if "429" in message:
|
||||
return "rate limit reached"
|
||||
if "402" in message:
|
||||
return "insufficient credits"
|
||||
if "422" in message or "Missing parameters" in message:
|
||||
return "invalid request parameters"
|
||||
if "Quote not found" in message:
|
||||
return "quote not found"
|
||||
return message
|
||||
|
||||
|
||||
def _has_valid_ticker(ticker: str) -> bool:
|
||||
"""Return whether the normalized ticker is non-empty."""
|
||||
return bool((ticker or "").strip())
|
||||
|
||||
|
||||
class DataProviderRouter:
|
||||
"""Route data requests across configured providers with fallbacks."""
|
||||
|
||||
@@ -56,6 +75,8 @@ class DataProviderRouter:
|
||||
end_date: str,
|
||||
) -> tuple[list[Price], DataSource]:
|
||||
"""Fetch prices using preferred providers with fallback."""
|
||||
if not _has_valid_ticker(ticker):
|
||||
return [], "local_csv"
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for source in self.price_sources():
|
||||
@@ -78,7 +99,12 @@ class DataProviderRouter:
|
||||
return prices, source
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
logger.warning("Price source %s failed for %s: %s", source, ticker, exc)
|
||||
logger.warning(
|
||||
"Price source %s failed for %s: %s",
|
||||
source,
|
||||
ticker,
|
||||
_format_provider_error(exc),
|
||||
)
|
||||
|
||||
if last_error:
|
||||
raise last_error
|
||||
@@ -92,6 +118,8 @@ class DataProviderRouter:
|
||||
limit: int = 10,
|
||||
) -> tuple[list[FinancialMetrics], DataSource]:
|
||||
"""Fetch financial metrics with API provider fallback."""
|
||||
if not _has_valid_ticker(ticker):
|
||||
return [], "local_csv"
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for source in self.api_sources():
|
||||
@@ -126,7 +154,7 @@ class DataProviderRouter:
|
||||
"Financial metrics source %s failed for %s: %s",
|
||||
source,
|
||||
ticker,
|
||||
exc,
|
||||
_format_provider_error(exc),
|
||||
)
|
||||
|
||||
if last_error:
|
||||
@@ -142,6 +170,8 @@ class DataProviderRouter:
|
||||
limit: int = 10,
|
||||
) -> list[LineItem]:
|
||||
"""Line items are only supported via Financial Datasets."""
|
||||
if not _has_valid_ticker(ticker):
|
||||
return []
|
||||
if "financial_datasets" not in self.api_sources():
|
||||
return []
|
||||
try:
|
||||
@@ -155,7 +185,11 @@ class DataProviderRouter:
|
||||
self._record_success("line_items", "financial_datasets")
|
||||
return results
|
||||
except Exception as exc:
|
||||
logger.warning("Line items source failed for %s: %s", ticker, exc)
|
||||
logger.warning(
|
||||
"Line items source failed for %s: %s",
|
||||
ticker,
|
||||
_format_provider_error(exc),
|
||||
)
|
||||
return []
|
||||
|
||||
def get_insider_trades(
|
||||
@@ -166,6 +200,8 @@ class DataProviderRouter:
|
||||
limit: int = 1000,
|
||||
) -> tuple[list[InsiderTrade], DataSource]:
|
||||
"""Fetch insider trades with provider fallback."""
|
||||
if not _has_valid_ticker(ticker):
|
||||
return [], "local_csv"
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for source in self.api_sources():
|
||||
@@ -193,7 +229,7 @@ class DataProviderRouter:
|
||||
"Insider trades source %s failed for %s: %s",
|
||||
source,
|
||||
ticker,
|
||||
exc,
|
||||
_format_provider_error(exc),
|
||||
)
|
||||
|
||||
if last_error:
|
||||
@@ -208,6 +244,8 @@ class DataProviderRouter:
|
||||
limit: int = 1000,
|
||||
) -> tuple[list[CompanyNews], DataSource]:
|
||||
"""Fetch company news with provider fallback."""
|
||||
if not _has_valid_ticker(ticker):
|
||||
return [], "local_csv"
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for source in self.api_sources():
|
||||
@@ -244,7 +282,7 @@ class DataProviderRouter:
|
||||
"Company news source %s failed for %s: %s",
|
||||
source,
|
||||
ticker,
|
||||
exc,
|
||||
_format_provider_error(exc),
|
||||
)
|
||||
|
||||
if last_error:
|
||||
@@ -258,6 +296,8 @@ class DataProviderRouter:
|
||||
metrics_lookup,
|
||||
) -> tuple[Optional[float], DataSource]:
|
||||
"""Fetch market cap using facts API or financial metrics fallback."""
|
||||
if not _has_valid_ticker(ticker):
|
||||
return None, "local_csv"
|
||||
today = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
if end_date == today and "financial_datasets" in self.api_sources():
|
||||
try:
|
||||
@@ -267,7 +307,7 @@ class DataProviderRouter:
|
||||
logger.warning(
|
||||
"Market cap facts source failed for %s: %s",
|
||||
ticker,
|
||||
exc,
|
||||
_format_provider_error(exc),
|
||||
)
|
||||
|
||||
metrics, source = metrics_lookup(ticker, end_date)
|
||||
|
||||
@@ -18,9 +18,8 @@ from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
from backend.config.bootstrap_config import resolve_runtime_config
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.config.env_config import get_env_float, get_env_int, get_env_list
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
@@ -36,35 +35,20 @@ loguru.logger.disable("reme_ai")
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
|
||||
def _get_run_dir(config_name: str) -> Path:
|
||||
"""Return the canonical run-scoped directory for a config."""
|
||||
project_root = Path(__file__).resolve().parents[1]
|
||||
return WorkspaceManager(project_root=project_root).get_run_dir(config_name)
|
||||
|
||||
|
||||
def _resolve_runtime_config(args) -> dict:
|
||||
"""Merge env defaults with run-scoped bootstrap config."""
|
||||
project_root = Path(__file__).resolve().parents[1]
|
||||
bootstrap = get_bootstrap_config_for_run(project_root, args.config_name)
|
||||
|
||||
return {
|
||||
"tickers": bootstrap.get("tickers")
|
||||
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
|
||||
"initial_cash": float(
|
||||
bootstrap.get(
|
||||
"initial_cash",
|
||||
get_env_float("INITIAL_CASH", 100000.0),
|
||||
),
|
||||
),
|
||||
"margin_requirement": float(
|
||||
bootstrap.get(
|
||||
"margin_requirement",
|
||||
get_env_float("MARGIN_REQUIREMENT", 0.0),
|
||||
),
|
||||
),
|
||||
"max_comm_cycles": int(
|
||||
bootstrap.get(
|
||||
"max_comm_cycles",
|
||||
get_env_int("MAX_COMM_CYCLES", 2),
|
||||
),
|
||||
),
|
||||
"enable_memory": args.enable_memory
|
||||
or bool(bootstrap.get("enable_memory", False)),
|
||||
}
|
||||
return resolve_runtime_config(
|
||||
project_root=project_root,
|
||||
config_name=args.config_name,
|
||||
enable_memory=args.enable_memory,
|
||||
)
|
||||
|
||||
|
||||
def create_long_term_memory(agent_name: str, config_name: str):
|
||||
@@ -82,7 +66,7 @@ def create_long_term_memory(agent_name: str, config_name: str):
|
||||
logger.warning("MEMORY_API_KEY not set, long-term memory disabled")
|
||||
return None
|
||||
|
||||
memory_dir = str(Path(config_name) / "memory")
|
||||
memory_dir = str(_get_run_dir(config_name) / "memory")
|
||||
|
||||
return ReMeTaskLongTermMemory(
|
||||
agent_name=agent_name,
|
||||
@@ -241,7 +225,7 @@ async def run_with_gateway(args):
|
||||
|
||||
# Create storage service
|
||||
storage_service = StorageService(
|
||||
dashboard_dir=Path(config_name) / "team_dashboard",
|
||||
dashboard_dir=_get_run_dir(config_name) / "team_dashboard",
|
||||
initial_cash=initial_cash,
|
||||
config_name=config_name,
|
||||
)
|
||||
@@ -316,6 +300,10 @@ async def run_with_gateway(args):
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": config_name,
|
||||
"initial_cash": initial_cash,
|
||||
"margin_requirement": margin_requirement,
|
||||
"max_comm_cycles": runtime_config["max_comm_cycles"],
|
||||
"enable_memory": runtime_config["enable_memory"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -5,12 +5,18 @@ WebSocket Gateway for frontend communication
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
from websockets.asyncio.server import ServerConnection
|
||||
|
||||
from backend.config.bootstrap_config import (
|
||||
resolve_runtime_config,
|
||||
update_bootstrap_values_for_run,
|
||||
)
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
from backend.utils.terminal_dashboard import get_dashboard
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
@@ -18,6 +24,7 @@ from backend.core.state_sync import StateSync
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.data.provider_router import get_provider_router
|
||||
from backend.tools.data_tools import get_prices
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,7 +58,7 @@ class Gateway:
|
||||
self.state_sync.set_broadcast_fn(self.broadcast)
|
||||
self.pipeline.state_sync = self.state_sync
|
||||
|
||||
self.connected_clients: Set[WebSocketServerProtocol] = set()
|
||||
self.connected_clients: Set[ServerConnection] = set()
|
||||
self.lock = asyncio.Lock()
|
||||
self._backtest_task: Optional[asyncio.Task] = None
|
||||
self._backtest_start_date: Optional[str] = None
|
||||
@@ -63,6 +70,7 @@ class Gateway:
|
||||
self._session_start_portfolio_value: Optional[float] = None
|
||||
self._provider_router = get_provider_router()
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._project_root = Path(__file__).resolve().parents[2]
|
||||
|
||||
async def start(self, host: str = "0.0.0.0", port: int = 8766):
|
||||
"""Start gateway server"""
|
||||
@@ -87,6 +95,7 @@ class Gateway:
|
||||
self._dashboard.start()
|
||||
|
||||
self.state_sync.load_state()
|
||||
self.market_service.set_price_recorder(self.storage.record_price_point)
|
||||
self.state_sync.update_state("status", "running")
|
||||
self.state_sync.update_state("server_mode", self.mode)
|
||||
self.state_sync.update_state("is_backtest", self.is_backtest)
|
||||
@@ -94,6 +103,20 @@ class Gateway:
|
||||
"is_mock_mode",
|
||||
self.config.get("mock_mode", False),
|
||||
)
|
||||
self.state_sync.update_state("tickers", self.config.get("tickers", []))
|
||||
self.state_sync.update_state(
|
||||
"runtime_config",
|
||||
{
|
||||
"tickers": self.config.get("tickers", []),
|
||||
"initial_cash": self.config.get(
|
||||
"initial_cash",
|
||||
self.storage.initial_cash,
|
||||
),
|
||||
"margin_requirement": self.config.get("margin_requirement"),
|
||||
"max_comm_cycles": self.config.get("max_comm_cycles"),
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
)
|
||||
self.state_sync.update_state(
|
||||
"data_sources",
|
||||
self._provider_router.get_usage_snapshot(),
|
||||
@@ -159,7 +182,7 @@ class Gateway:
|
||||
def state(self) -> Dict[str, Any]:
|
||||
return self.state_sync.state
|
||||
|
||||
async def handle_client(self, websocket: WebSocketServerProtocol):
|
||||
async def handle_client(self, websocket: ServerConnection):
|
||||
"""Handle WebSocket client connection"""
|
||||
async with self.lock:
|
||||
self.connected_clients.add(websocket)
|
||||
@@ -170,7 +193,7 @@ class Gateway:
|
||||
async with self.lock:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
async def _send_initial_state(self, websocket: WebSocketServerProtocol):
|
||||
async def _send_initial_state(self, websocket: ServerConnection):
|
||||
state_payload = self.state_sync.get_initial_state_payload(
|
||||
include_dashboard=True,
|
||||
)
|
||||
@@ -198,7 +221,7 @@ class Gateway:
|
||||
|
||||
async def _handle_client_messages(
|
||||
self,
|
||||
websocket: WebSocketServerProtocol,
|
||||
websocket: ServerConnection,
|
||||
):
|
||||
try:
|
||||
async for message in websocket:
|
||||
@@ -221,12 +244,104 @@ class Gateway:
|
||||
await self._handle_start_backtest(data)
|
||||
elif msg_type == "reload_runtime_assets":
|
||||
await self._handle_reload_runtime_assets()
|
||||
elif msg_type == "update_watchlist":
|
||||
await self._handle_update_watchlist(websocket, data)
|
||||
elif msg_type == "get_stock_history":
|
||||
await self._handle_get_stock_history(websocket, data)
|
||||
elif msg_type == "get_stock_explain_events":
|
||||
await self._handle_get_stock_explain_events(websocket, data)
|
||||
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def _handle_get_stock_history(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
):
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_history_loaded",
|
||||
"ticker": "",
|
||||
"prices": [],
|
||||
"source": None,
|
||||
"error": "invalid ticker",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
lookback_days = data.get("lookback_days", 90)
|
||||
try:
|
||||
lookback_days = max(7, min(int(lookback_days), 365))
|
||||
except (TypeError, ValueError):
|
||||
lookback_days = 90
|
||||
|
||||
end_date = self.state_sync.state.get("current_date")
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
end_dt = datetime.now()
|
||||
end_date = end_dt.strftime("%Y-%m-%d")
|
||||
|
||||
start_date = (end_dt - timedelta(days=lookback_days)).strftime(
|
||||
"%Y-%m-%d",
|
||||
)
|
||||
|
||||
prices = await asyncio.to_thread(
|
||||
get_prices,
|
||||
ticker,
|
||||
start_date,
|
||||
end_date,
|
||||
)
|
||||
usage_snapshot = self._provider_router.get_usage_snapshot()
|
||||
source = usage_snapshot.get("last_success", {}).get("prices")
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_history_loaded",
|
||||
"ticker": ticker,
|
||||
"prices": [price.model_dump() for price in prices][-120:],
|
||||
"source": source,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_get_stock_explain_events(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
):
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
snapshot = self.storage.runtime_db.get_stock_explain_snapshot(ticker)
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_explain_events_loaded",
|
||||
"ticker": ticker,
|
||||
"events": snapshot.get("events", []),
|
||||
"signals": snapshot.get("signals", []),
|
||||
"trades": snapshot.get("trades", []),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
||||
if not self.is_backtest:
|
||||
return
|
||||
@@ -239,8 +354,15 @@ class Gateway:
|
||||
self._backtest_task = task
|
||||
|
||||
async def _handle_reload_runtime_assets(self):
|
||||
"""Reload prompt assets and active skills without restarting the server."""
|
||||
result = self.pipeline.reload_runtime_assets()
|
||||
"""Reload prompt, skills, and safe runtime config without restart."""
|
||||
config_name = self.config.get("config_name", "default")
|
||||
runtime_config = resolve_runtime_config(
|
||||
project_root=self._project_root,
|
||||
config_name=config_name,
|
||||
enable_memory=self.config.get("enable_memory", False),
|
||||
)
|
||||
result = self.pipeline.reload_runtime_assets(runtime_config=runtime_config)
|
||||
runtime_updates = self._apply_runtime_config(runtime_config)
|
||||
await self.state_sync.on_system_message(
|
||||
"Runtime assets reloaded.",
|
||||
)
|
||||
@@ -248,9 +370,174 @@ class Gateway:
|
||||
{
|
||||
"type": "runtime_assets_reloaded",
|
||||
**result,
|
||||
**runtime_updates,
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_update_watchlist(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Persist a new watchlist to BOOTSTRAP.md and hot-reload it."""
|
||||
tickers = self._normalize_watchlist(data.get("tickers"))
|
||||
if not tickers:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "error",
|
||||
"message": "update_watchlist requires at least one valid ticker.",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
config_name = self.config.get("config_name", "default")
|
||||
update_bootstrap_values_for_run(
|
||||
project_root=self._project_root,
|
||||
config_name=config_name,
|
||||
updates={"tickers": tickers},
|
||||
)
|
||||
await self.state_sync.on_system_message(
|
||||
f"Watchlist updated: {', '.join(tickers)}",
|
||||
)
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "watchlist_updated",
|
||||
"config_name": config_name,
|
||||
"tickers": tickers,
|
||||
},
|
||||
)
|
||||
await self._handle_reload_runtime_assets()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
|
||||
"""Parse watchlist payloads from websocket messages."""
|
||||
if raw_tickers is None:
|
||||
return []
|
||||
|
||||
if isinstance(raw_tickers, str):
|
||||
candidates = raw_tickers.split(",")
|
||||
elif isinstance(raw_tickers, list):
|
||||
candidates = raw_tickers
|
||||
else:
|
||||
candidates = [raw_tickers]
|
||||
|
||||
tickers: List[str] = []
|
||||
for candidate in candidates:
|
||||
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
|
||||
if symbol and symbol not in tickers:
|
||||
tickers.append(symbol)
|
||||
return tickers
|
||||
|
||||
def _apply_runtime_config(
|
||||
self,
|
||||
runtime_config: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply runtime config to gateway-owned services and state."""
|
||||
warnings: List[str] = []
|
||||
|
||||
ticker_changes = self.market_service.update_tickers(
|
||||
runtime_config.get("tickers", []),
|
||||
)
|
||||
self.config["tickers"] = ticker_changes["active"]
|
||||
|
||||
self.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
|
||||
self.config["max_comm_cycles"] = self.pipeline.max_comm_cycles
|
||||
|
||||
pm_apply_result = self.pipeline.pm.apply_runtime_portfolio_config(
|
||||
margin_requirement=runtime_config["margin_requirement"],
|
||||
)
|
||||
self.config["margin_requirement"] = self.pipeline.pm.portfolio.get(
|
||||
"margin_requirement",
|
||||
runtime_config["margin_requirement"],
|
||||
)
|
||||
|
||||
requested_initial_cash = float(runtime_config["initial_cash"])
|
||||
current_initial_cash = float(self.storage.initial_cash)
|
||||
initial_cash_applied = requested_initial_cash == current_initial_cash
|
||||
if not initial_cash_applied:
|
||||
if (
|
||||
self.storage.can_apply_initial_cash()
|
||||
and self.pipeline.pm.can_apply_initial_cash()
|
||||
):
|
||||
initial_cash_applied = self.storage.apply_initial_cash(
|
||||
requested_initial_cash,
|
||||
)
|
||||
if initial_cash_applied:
|
||||
self.pipeline.pm.apply_runtime_portfolio_config(
|
||||
initial_cash=requested_initial_cash,
|
||||
)
|
||||
self.config["initial_cash"] = self.storage.initial_cash
|
||||
else:
|
||||
warnings.append(
|
||||
"initial_cash changed in BOOTSTRAP.md but was not applied "
|
||||
"because the run already has positions, margin usage, or trades.",
|
||||
)
|
||||
|
||||
requested_enable_memory = bool(runtime_config["enable_memory"])
|
||||
current_enable_memory = bool(self.config.get("enable_memory", False))
|
||||
if requested_enable_memory != current_enable_memory:
|
||||
warnings.append(
|
||||
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
|
||||
"because long-term memory contexts are created at startup.",
|
||||
)
|
||||
|
||||
self._sync_runtime_state()
|
||||
|
||||
return {
|
||||
"runtime_config_requested": runtime_config,
|
||||
"runtime_config_applied": {
|
||||
"tickers": list(self.config.get("tickers", [])),
|
||||
"initial_cash": self.storage.initial_cash,
|
||||
"margin_requirement": self.config["margin_requirement"],
|
||||
"max_comm_cycles": self.config["max_comm_cycles"],
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
"runtime_config_status": {
|
||||
"tickers": True,
|
||||
"initial_cash": initial_cash_applied,
|
||||
"margin_requirement": pm_apply_result["margin_requirement"],
|
||||
"max_comm_cycles": True,
|
||||
"enable_memory": requested_enable_memory == current_enable_memory,
|
||||
},
|
||||
"ticker_changes": ticker_changes,
|
||||
"runtime_config_warnings": warnings,
|
||||
}
|
||||
|
||||
def _sync_runtime_state(self) -> None:
|
||||
"""Refresh persisted state and dashboard after runtime config changes."""
|
||||
self.state_sync.update_state("tickers", self.config.get("tickers", []))
|
||||
self.state_sync.update_state(
|
||||
"runtime_config",
|
||||
{
|
||||
"tickers": self.config.get("tickers", []),
|
||||
"initial_cash": self.storage.initial_cash,
|
||||
"margin_requirement": self.config.get("margin_requirement"),
|
||||
"max_comm_cycles": self.config.get("max_comm_cycles"),
|
||||
"enable_memory": self.config.get("enable_memory", False),
|
||||
},
|
||||
)
|
||||
|
||||
self.storage.update_server_state_from_dashboard(self.state_sync.state)
|
||||
self.state_sync.save_state()
|
||||
|
||||
self._dashboard.tickers = list(self.config.get("tickers", []))
|
||||
self._dashboard.initial_cash = self.storage.initial_cash
|
||||
self._dashboard.enable_memory = bool(
|
||||
self.config.get("enable_memory", False),
|
||||
)
|
||||
|
||||
summary = self.storage.load_file("summary") or {}
|
||||
holdings = self.storage.load_file("holdings") or []
|
||||
trades = self.storage.load_file("trades") or []
|
||||
self._dashboard.update(
|
||||
portfolio=summary,
|
||||
holdings=holdings,
|
||||
trades=trades,
|
||||
)
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients"""
|
||||
if not self.connected_clients:
|
||||
@@ -269,7 +556,7 @@ class Gateway:
|
||||
|
||||
async def _send_to_client(
|
||||
self,
|
||||
client: WebSocketServerProtocol,
|
||||
client: ServerConnection,
|
||||
message: str,
|
||||
):
|
||||
try:
|
||||
|
||||
@@ -54,6 +54,7 @@ class MarketService:
|
||||
self.running = False
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._broadcast_func: Optional[Callable] = None
|
||||
self._price_record_func: Optional[Callable[..., None]] = None
|
||||
self._price_manager: Optional[Any] = None
|
||||
self._current_date: Optional[str] = None
|
||||
|
||||
@@ -92,6 +93,10 @@ class MarketService:
|
||||
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
|
||||
)
|
||||
|
||||
def set_price_recorder(self, recorder: Optional[Callable[..., None]]):
|
||||
"""Register an optional callback for persisting runtime price points."""
|
||||
self._price_record_func = recorder
|
||||
|
||||
def _make_price_callback(self) -> Callable:
|
||||
"""Create thread-safe price callback"""
|
||||
|
||||
@@ -169,6 +174,24 @@ class MarketService:
|
||||
((price - open_price) / open_price) * 100 if open_price > 0 else 0
|
||||
)
|
||||
|
||||
if self._price_record_func:
|
||||
try:
|
||||
self._price_record_func(
|
||||
ticker=symbol,
|
||||
timestamp=str(price_data.get("timestamp") or datetime.now().isoformat()),
|
||||
price=float(price),
|
||||
open_price=float(open_price) if open_price is not None else None,
|
||||
ret=float(ret),
|
||||
source=self.mode_name.lower(),
|
||||
meta=price_data,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to record price point for %s: %s",
|
||||
symbol,
|
||||
exc,
|
||||
)
|
||||
|
||||
await self._broadcast_func(
|
||||
{
|
||||
"type": "price_update",
|
||||
@@ -205,6 +228,43 @@ class MarketService:
|
||||
self._loop = None
|
||||
self._broadcast_func = None
|
||||
|
||||
def update_tickers(self, tickers: List[str]) -> Dict[str, List[str]]:
|
||||
"""Hot-update subscribed tickers without restarting the service."""
|
||||
normalized: List[str] = []
|
||||
for ticker in tickers:
|
||||
symbol = normalize_symbol(ticker)
|
||||
if symbol and symbol not in normalized:
|
||||
normalized.append(symbol)
|
||||
|
||||
previous = list(self.tickers)
|
||||
removed = [ticker for ticker in previous if ticker not in normalized]
|
||||
added = [ticker for ticker in normalized if ticker not in previous]
|
||||
self.tickers = normalized
|
||||
|
||||
if self._price_manager:
|
||||
if removed:
|
||||
self._price_manager.unsubscribe(removed)
|
||||
if added:
|
||||
if self.mock_mode:
|
||||
self._price_manager.subscribe(
|
||||
added,
|
||||
base_prices={ticker: 100.0 for ticker in added},
|
||||
)
|
||||
else:
|
||||
self._price_manager.subscribe(added)
|
||||
|
||||
if self.backtest_mode and self._current_date:
|
||||
self._price_manager.set_date(self._current_date)
|
||||
|
||||
for ticker in removed:
|
||||
self.cache.pop(ticker, None)
|
||||
|
||||
return {
|
||||
"added": added,
|
||||
"removed": removed,
|
||||
"active": list(self.tickers),
|
||||
}
|
||||
|
||||
# Backtest methods
|
||||
def set_backtest_date(self, date: str):
|
||||
"""Set current backtest date"""
|
||||
|
||||
388
backend/services/runtime_db.py
Normal file
388
backend/services/runtime_db.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Run-scoped SQLite storage for query-oriented runtime history."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
|
||||
SCHEMA = """
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id TEXT PRIMARY KEY,
|
||||
event_type TEXT NOT NULL,
|
||||
timestamp TEXT,
|
||||
agent_id TEXT,
|
||||
agent_name TEXT,
|
||||
ticker TEXT,
|
||||
title TEXT,
|
||||
content TEXT,
|
||||
payload_json TEXT NOT NULL,
|
||||
run_date TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_type_time ON events(event_type, timestamp DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_events_ticker_time ON events(ticker, timestamp DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS trades (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
side TEXT,
|
||||
qty REAL,
|
||||
price REAL,
|
||||
timestamp TEXT,
|
||||
trading_date TEXT,
|
||||
agent_id TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_trades_ticker_time ON trades(ticker, timestamp DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS signals (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
agent_id TEXT,
|
||||
agent_name TEXT,
|
||||
role TEXT,
|
||||
signal TEXT,
|
||||
confidence REAL,
|
||||
reasoning_json TEXT,
|
||||
real_return REAL,
|
||||
is_correct TEXT,
|
||||
trade_date TEXT,
|
||||
created_at TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_signals_ticker_date ON signals(ticker, trade_date DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_signals_agent_date ON signals(agent_id, trade_date DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS price_points (
|
||||
id TEXT PRIMARY KEY,
|
||||
ticker TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
price REAL NOT NULL,
|
||||
open_price REAL,
|
||||
ret REAL,
|
||||
source TEXT,
|
||||
meta_json TEXT
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_price_points_ticker_time ON price_points(ticker, timestamp DESC);
|
||||
"""
|
||||
|
||||
|
||||
def _json_dumps(value: Any) -> str:
|
||||
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||
|
||||
|
||||
def _hash_key(*parts: Any) -> str:
|
||||
raw = "::".join("" if part is None else str(part) for part in parts)
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
class RuntimeDb:
|
||||
"""Small SQLite helper for append-mostly runtime data."""
|
||||
|
||||
def __init__(self, db_path: Path):
|
||||
self.db_path = Path(db_path)
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
return conn
|
||||
|
||||
def _init_db(self):
|
||||
with self._connect() as conn:
|
||||
conn.executescript(SCHEMA)
|
||||
|
||||
def insert_event(self, event: Dict[str, Any]):
|
||||
payload = dict(event or {})
|
||||
if not payload:
|
||||
return
|
||||
|
||||
event_id = payload.get("id") or _hash_key(
|
||||
payload.get("type"),
|
||||
payload.get("timestamp"),
|
||||
payload.get("agentId") or payload.get("agent_id"),
|
||||
payload.get("content"),
|
||||
payload.get("title"),
|
||||
)
|
||||
ticker = payload.get("ticker")
|
||||
if not ticker and isinstance(payload.get("tickers"), list) and len(payload["tickers"]) == 1:
|
||||
ticker = payload["tickers"][0]
|
||||
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO events
|
||||
(id, event_type, timestamp, agent_id, agent_name, ticker, title, content, payload_json, run_date)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
event_id,
|
||||
payload.get("type"),
|
||||
payload.get("timestamp"),
|
||||
payload.get("agentId") or payload.get("agent_id"),
|
||||
payload.get("agentName") or payload.get("agent_name"),
|
||||
ticker,
|
||||
payload.get("title"),
|
||||
payload.get("content"),
|
||||
_json_dumps(payload),
|
||||
payload.get("date") or payload.get("trading_date") or payload.get("run_date"),
|
||||
),
|
||||
)
|
||||
|
||||
def upsert_trade(self, trade: Dict[str, Any]):
|
||||
payload = dict(trade or {})
|
||||
if not payload:
|
||||
return
|
||||
|
||||
trade_id = payload.get("id") or _hash_key(
|
||||
payload.get("ticker"),
|
||||
payload.get("timestamp") or payload.get("ts"),
|
||||
payload.get("side"),
|
||||
payload.get("qty"),
|
||||
payload.get("price"),
|
||||
)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO trades
|
||||
(id, ticker, side, qty, price, timestamp, trading_date, agent_id, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
trade_id,
|
||||
payload.get("ticker"),
|
||||
payload.get("side"),
|
||||
payload.get("qty"),
|
||||
payload.get("price"),
|
||||
payload.get("timestamp") or payload.get("ts"),
|
||||
payload.get("trading_date"),
|
||||
payload.get("agentId") or payload.get("agent_id"),
|
||||
_json_dumps(payload),
|
||||
),
|
||||
)
|
||||
|
||||
def upsert_signal(self, signal: Dict[str, Any], *, agent_id: str, agent_name: str, role: str):
|
||||
payload = dict(signal or {})
|
||||
ticker = payload.get("ticker")
|
||||
if not ticker:
|
||||
return
|
||||
|
||||
signal_id = _hash_key(
|
||||
agent_id,
|
||||
ticker,
|
||||
payload.get("date"),
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO signals
|
||||
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
|
||||
real_return, is_correct, trade_date, created_at, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
signal_id,
|
||||
ticker,
|
||||
agent_id,
|
||||
agent_name,
|
||||
role,
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
_json_dumps(payload.get("reasoning")),
|
||||
payload.get("real_return"),
|
||||
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
|
||||
payload.get("date"),
|
||||
payload.get("created_at") or payload.get("date"),
|
||||
_json_dumps(payload),
|
||||
),
|
||||
)
|
||||
|
||||
def replace_signals_for_leaderboard(self, leaderboard: Iterable[Dict[str, Any]]):
|
||||
with self._connect() as conn:
|
||||
conn.execute("DELETE FROM signals")
|
||||
for agent in leaderboard:
|
||||
agent_id = agent.get("agentId")
|
||||
agent_name = agent.get("name")
|
||||
role = agent.get("role")
|
||||
for signal in agent.get("signals", []) or []:
|
||||
payload = dict(signal or {})
|
||||
ticker = payload.get("ticker")
|
||||
if not ticker:
|
||||
continue
|
||||
signal_id = _hash_key(
|
||||
agent_id,
|
||||
ticker,
|
||||
payload.get("date"),
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO signals
|
||||
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
|
||||
real_return, is_correct, trade_date, created_at, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
signal_id,
|
||||
ticker,
|
||||
agent_id,
|
||||
agent_name,
|
||||
role,
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
_json_dumps(payload.get("reasoning")),
|
||||
payload.get("real_return"),
|
||||
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
|
||||
payload.get("date"),
|
||||
payload.get("created_at") or payload.get("date"),
|
||||
_json_dumps(payload),
|
||||
),
|
||||
)
|
||||
|
||||
def insert_price_point(
|
||||
self,
|
||||
*,
|
||||
ticker: str,
|
||||
timestamp: str,
|
||||
price: float,
|
||||
open_price: Optional[float] = None,
|
||||
ret: Optional[float] = None,
|
||||
source: Optional[str] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
price_id = _hash_key(ticker, timestamp, price, open_price, ret)
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR IGNORE INTO price_points
|
||||
(id, ticker, timestamp, price, open_price, ret, source, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
price_id,
|
||||
ticker,
|
||||
timestamp,
|
||||
price,
|
||||
open_price,
|
||||
ret,
|
||||
source,
|
||||
_json_dumps(meta or {}),
|
||||
),
|
||||
)
|
||||
|
||||
def get_stock_explain_snapshot(
|
||||
self,
|
||||
ticker: str,
|
||||
*,
|
||||
limit_events: int = 24,
|
||||
limit_trades: int = 12,
|
||||
limit_signals: int = 12,
|
||||
) -> Dict[str, list[Dict[str, Any]]]:
|
||||
"""Fetch query-oriented history for a single ticker."""
|
||||
symbol = str(ticker or "").strip().upper()
|
||||
if not symbol:
|
||||
return {"events": [], "trades": [], "signals": []}
|
||||
|
||||
with self._connect() as conn:
|
||||
trade_rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM trades
|
||||
WHERE ticker = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(symbol, limit_trades),
|
||||
).fetchall()
|
||||
signal_rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM signals
|
||||
WHERE ticker = ?
|
||||
ORDER BY trade_date DESC, created_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(symbol, limit_signals),
|
||||
).fetchall()
|
||||
event_rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM events
|
||||
WHERE payload_json LIKE ? OR content LIKE ? OR title LIKE ? OR ticker = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(f"%{symbol}%", f"%{symbol}%", f"%{symbol}%", symbol, limit_events * 3),
|
||||
).fetchall()
|
||||
|
||||
normalized_events = []
|
||||
seen_event_ids: set[str] = set()
|
||||
for row in event_rows:
|
||||
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
|
||||
content = str(row["content"] or payload.get("content") or "")
|
||||
title = str(row["title"] or payload.get("title") or "")
|
||||
if symbol not in f"{title} {content}".upper() and str(row["ticker"] or "").upper() != symbol:
|
||||
continue
|
||||
event_id = row["id"]
|
||||
if event_id in seen_event_ids:
|
||||
continue
|
||||
seen_event_ids.add(event_id)
|
||||
normalized_events.append(
|
||||
{
|
||||
"id": event_id,
|
||||
"type": "mention",
|
||||
"timestamp": row["timestamp"],
|
||||
"title": title or f"{row['agent_name'] or '未知角色'}提及 {symbol}",
|
||||
"meta": payload.get("conferenceTitle")
|
||||
or payload.get("feedType")
|
||||
or row["event_type"],
|
||||
"body": content,
|
||||
"tone": "neutral",
|
||||
"agent": row["agent_name"] or payload.get("agentName") or payload.get("agent"),
|
||||
},
|
||||
)
|
||||
if len(normalized_events) >= limit_events:
|
||||
break
|
||||
|
||||
normalized_trades = [
|
||||
{
|
||||
"id": row["id"],
|
||||
"type": "trade",
|
||||
"timestamp": row["timestamp"],
|
||||
"title": f"{row['side']} {int(row['qty'] or 0)} 股",
|
||||
"meta": "交易执行",
|
||||
"body": f"成交价 ${float(row['price'] or 0):.2f}",
|
||||
"tone": "positive" if row["side"] == "LONG" else "negative" if row["side"] == "SHORT" else "neutral",
|
||||
}
|
||||
for row in trade_rows
|
||||
]
|
||||
|
||||
normalized_signals = [
|
||||
{
|
||||
"id": row["id"],
|
||||
"type": "signal",
|
||||
"timestamp": f"{row['trade_date']}T08:00:00" if row["trade_date"] else row["created_at"],
|
||||
"title": f"{row['agent_name']} 给出{row['signal'] or '中性'}信号",
|
||||
"meta": row["role"],
|
||||
"body": (
|
||||
f"后验收益 {float(row['real_return']) * 100:+.2f}%"
|
||||
if row["real_return"] is not None
|
||||
else "该信号暂未完成后验评估"
|
||||
),
|
||||
"tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral",
|
||||
}
|
||||
for row in signal_rows
|
||||
]
|
||||
|
||||
return {
|
||||
"events": normalized_events,
|
||||
"trades": normalized_trades,
|
||||
"signals": normalized_signals,
|
||||
}
|
||||
@@ -10,6 +10,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .runtime_db import RuntimeDb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -61,6 +63,7 @@ class StorageService:
|
||||
self.state_dir = self.dashboard_dir.parent / "state"
|
||||
self.state_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.server_state_file = self.state_dir / "server_state.json"
|
||||
self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
|
||||
|
||||
# Feed history (for agent messages)
|
||||
self.max_feed_history = 200
|
||||
@@ -114,6 +117,11 @@ class StorageService:
|
||||
try:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
if file_type == "leaderboard" and isinstance(data, list):
|
||||
self.runtime_db.replace_signals_for_leaderboard(data)
|
||||
elif file_type == "trades" and isinstance(data, list):
|
||||
for trade in data:
|
||||
self.runtime_db.upsert_trade(trade)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save {file_type}.json: {e}")
|
||||
|
||||
@@ -211,6 +219,7 @@ class StorageService:
|
||||
try:
|
||||
with open(self.internal_state_file, "w", encoding="utf-8") as f:
|
||||
json.dump(state, f, indent=2, ensure_ascii=False)
|
||||
self._sync_price_history_to_db(state.get("price_history", {}))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save internal state: {e}")
|
||||
|
||||
@@ -231,6 +240,41 @@ class StorageService:
|
||||
"margin_requirement": 0.25, # Default 25% margin requirement
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _portfolio_is_pristine(portfolio_state: Dict[str, Any]) -> bool:
|
||||
"""Return whether the persisted portfolio can be safely rebased."""
|
||||
positions = portfolio_state.get("positions", {})
|
||||
has_positions = any(
|
||||
position.get("long", 0) or position.get("short", 0)
|
||||
for position in positions.values()
|
||||
)
|
||||
margin_used = float(portfolio_state.get("margin_used", 0.0) or 0.0)
|
||||
return not has_positions and margin_used == 0.0
|
||||
|
||||
def can_apply_initial_cash(self) -> bool:
|
||||
"""Only allow initial cash changes before the run has traded."""
|
||||
state = self.load_internal_state()
|
||||
if not self._portfolio_is_pristine(state.get("portfolio_state", {})):
|
||||
return False
|
||||
if state.get("all_trades"):
|
||||
return False
|
||||
return len(state.get("equity_history", [])) <= 1
|
||||
|
||||
def apply_initial_cash(self, initial_cash: float) -> bool:
|
||||
"""Rebase storage state to a new initial cash when the run is pristine."""
|
||||
if not self.can_apply_initial_cash():
|
||||
return False
|
||||
|
||||
self.initial_cash = float(initial_cash)
|
||||
if self.internal_state_file.exists():
|
||||
self.internal_state_file.unlink()
|
||||
|
||||
self.initialize_empty_dashboard()
|
||||
state = self.load_server_state()
|
||||
self.update_server_state_from_dashboard(state)
|
||||
self.save_server_state(state)
|
||||
return True
|
||||
|
||||
def save_portfolio_state(self, portfolio: Dict[str, Any]):
|
||||
"""
|
||||
Save portfolio state to internal state
|
||||
@@ -750,6 +794,7 @@ class StorageService:
|
||||
"last_day_history": [],
|
||||
"trading_days_total": 0,
|
||||
"trading_days_completed": 0,
|
||||
"price_history": {},
|
||||
}
|
||||
|
||||
if not self.server_state_file.exists():
|
||||
@@ -771,6 +816,11 @@ class StorageService:
|
||||
)
|
||||
logger.info(f"Trades: {len(saved_state.get('trades', []))} records")
|
||||
|
||||
for event in saved_state.get("feed_history", []):
|
||||
self.runtime_db.insert_event(event)
|
||||
for trade in saved_state.get("trades", []):
|
||||
self.runtime_db.upsert_trade(trade)
|
||||
|
||||
return saved_state
|
||||
|
||||
def save_server_state(self, state: Dict[str, Any]):
|
||||
@@ -852,6 +902,7 @@ class StorageService:
|
||||
state["feed_history"] = []
|
||||
|
||||
state["feed_history"].insert(0, feed_msg)
|
||||
self.runtime_db.insert_event(feed_msg)
|
||||
|
||||
# Trim to max size
|
||||
if len(state["feed_history"]) > self.max_feed_history:
|
||||
@@ -861,6 +912,69 @@ class StorageService:
|
||||
|
||||
return True
|
||||
|
||||
def record_price_point(
|
||||
self,
|
||||
*,
|
||||
ticker: str,
|
||||
timestamp: str,
|
||||
price: float,
|
||||
open_price: Optional[float] = None,
|
||||
ret: Optional[float] = None,
|
||||
source: Optional[str] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
"""Persist a runtime price point for later query-oriented reads."""
|
||||
if not ticker or not timestamp:
|
||||
return
|
||||
try:
|
||||
self.runtime_db.insert_price_point(
|
||||
ticker=ticker,
|
||||
timestamp=timestamp,
|
||||
price=price,
|
||||
open_price=open_price,
|
||||
ret=ret,
|
||||
source=source,
|
||||
meta=meta,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to record price point for %s: %s", ticker, exc)
|
||||
|
||||
def _sync_price_history_to_db(self, price_history: Dict[str, Any]):
|
||||
"""Backfill structured price points from serialized internal state."""
|
||||
if not isinstance(price_history, dict):
|
||||
return
|
||||
for ticker, points in price_history.items():
|
||||
if not ticker or not isinstance(points, list):
|
||||
continue
|
||||
for point in points:
|
||||
if isinstance(point, (list, tuple)) and len(point) >= 2:
|
||||
timestamp, price = point[0], point[1]
|
||||
try:
|
||||
self.record_price_point(
|
||||
ticker=str(ticker),
|
||||
timestamp=str(timestamp),
|
||||
price=float(price),
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
elif isinstance(point, dict):
|
||||
timestamp = point.get("timestamp") or point.get("label") or point.get("date")
|
||||
price = point.get("price") or point.get("close") or point.get("value")
|
||||
if not timestamp or price is None:
|
||||
continue
|
||||
try:
|
||||
self.record_price_point(
|
||||
ticker=str(ticker),
|
||||
timestamp=str(timestamp),
|
||||
price=float(price),
|
||||
open_price=point.get("open"),
|
||||
ret=point.get("ret"),
|
||||
source=point.get("source"),
|
||||
meta=point,
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
|
||||
def _get_default_stats(self) -> Dict[str, Any]:
|
||||
"""Get default stats structure"""
|
||||
return {
|
||||
@@ -889,6 +1003,7 @@ class StorageService:
|
||||
stats = self.load_file("stats") or self._get_default_stats()
|
||||
trades = self.load_file("trades") or []
|
||||
leaderboard = self.load_file("leaderboard") or []
|
||||
internal_state = self.load_internal_state()
|
||||
|
||||
# Update state
|
||||
state["portfolio"] = {
|
||||
@@ -910,6 +1025,9 @@ class StorageService:
|
||||
state["stats"] = stats
|
||||
state["trades"] = trades
|
||||
state["leaderboard"] = leaderboard
|
||||
state["price_history"] = internal_state.get("price_history", {})
|
||||
self.runtime_db.replace_signals_for_leaderboard(leaderboard)
|
||||
self._sync_price_history_to_db(state["price_history"])
|
||||
|
||||
# ========== Live Returns Tracking ==========
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ Returns human-readable text format for easy LLM consumption.
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=C0301,W0613
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
@@ -20,6 +21,7 @@ import pandas as pd
|
||||
from agentscope.message import TextBlock
|
||||
from agentscope.tool import ToolResponse
|
||||
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.tools.data_tools import (
|
||||
get_company_news,
|
||||
get_financial_metrics,
|
||||
@@ -53,6 +55,16 @@ def _parse_tickers(tickers: Union[str, List[str], None]) -> List[str]:
|
||||
Returns:
|
||||
List of stock tickers.
|
||||
"""
|
||||
def _sanitize(values: List[object]) -> List[str]:
|
||||
cleaned: List[str] = []
|
||||
for value in values:
|
||||
if value is None:
|
||||
continue
|
||||
symbol = normalize_symbol(str(value).strip().strip("\"'"))
|
||||
if symbol and symbol not in cleaned:
|
||||
cleaned.append(symbol)
|
||||
return cleaned
|
||||
|
||||
if tickers is None:
|
||||
return []
|
||||
|
||||
@@ -60,17 +72,22 @@ def _parse_tickers(tickers: Union[str, List[str], None]) -> List[str]:
|
||||
try:
|
||||
parsed = json.loads(tickers)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
# If it's a single string, wrap in list
|
||||
return [parsed]
|
||||
return _sanitize(parsed)
|
||||
return _sanitize([parsed])
|
||||
except json.JSONDecodeError:
|
||||
# If not valid JSON, treat as comma-separated string
|
||||
return [t.strip() for t in tickers.split(",") if t.strip()]
|
||||
try:
|
||||
parsed = ast.literal_eval(tickers)
|
||||
if isinstance(parsed, list):
|
||||
return _sanitize(parsed)
|
||||
return _sanitize([parsed])
|
||||
except (SyntaxError, ValueError):
|
||||
pass
|
||||
return _sanitize(tickers.split(","))
|
||||
|
||||
if isinstance(tickers, list):
|
||||
return tickers
|
||||
return _sanitize(tickers)
|
||||
|
||||
return []
|
||||
return _sanitize([tickers])
|
||||
|
||||
|
||||
def _safe_float(value, default=0.0) -> float:
|
||||
@@ -350,6 +367,7 @@ def get_financial_metrics_tool(
|
||||
"""
|
||||
|
||||
current_date = _resolved_date(current_date)
|
||||
tickers = _parse_tickers(tickers)
|
||||
lines = [
|
||||
f"=== Comprehensive Financial Metrics ({current_date}, {period}) ===\n",
|
||||
]
|
||||
|
||||
@@ -96,13 +96,19 @@ def get_prices(
|
||||
list[Price]: List of Price objects
|
||||
"""
|
||||
ticker = normalize_symbol(ticker)
|
||||
if not ticker:
|
||||
return []
|
||||
cached_sources = _router.price_sources()
|
||||
for source in cached_sources:
|
||||
cache_key = f"{ticker}_{start_date}_{end_date}_{source}"
|
||||
if cached_data := _cache.get_prices(cache_key):
|
||||
return [Price(**price) for price in cached_data]
|
||||
|
||||
prices, data_source = _router.get_prices(ticker, start_date, end_date)
|
||||
try:
|
||||
prices, data_source = _router.get_prices(ticker, start_date, end_date)
|
||||
except Exception as exc:
|
||||
logger.info("Price lookup failed for %s: %s", ticker, exc)
|
||||
return []
|
||||
|
||||
if not prices:
|
||||
return []
|
||||
@@ -133,17 +139,23 @@ def get_financial_metrics(
|
||||
list[FinancialMetrics]: List of financial metrics
|
||||
"""
|
||||
ticker = normalize_symbol(ticker)
|
||||
if not ticker:
|
||||
return []
|
||||
for source in _router.api_sources():
|
||||
cache_key = f"{ticker}_{period}_{end_date}_{limit}_{source}"
|
||||
if cached_data := _cache.get_financial_metrics(cache_key):
|
||||
return [FinancialMetrics(**metric) for metric in cached_data]
|
||||
|
||||
financial_metrics, data_source = _router.get_financial_metrics(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
try:
|
||||
financial_metrics, data_source = _router.get_financial_metrics(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.info("Financial metrics lookup failed for %s: %s", ticker, exc)
|
||||
return []
|
||||
|
||||
if not financial_metrics:
|
||||
return []
|
||||
@@ -169,6 +181,8 @@ def search_line_items(
|
||||
"""
|
||||
try:
|
||||
ticker = normalize_symbol(ticker)
|
||||
if not ticker:
|
||||
return []
|
||||
return _router.search_line_items(
|
||||
ticker=ticker,
|
||||
line_items=line_items,
|
||||
@@ -190,6 +204,8 @@ def get_insider_trades(
|
||||
) -> list[InsiderTrade]:
|
||||
"""Fetch insider trades from cache or API."""
|
||||
ticker = normalize_symbol(ticker)
|
||||
if not ticker:
|
||||
return []
|
||||
for source in _router.api_sources():
|
||||
cache_key = (
|
||||
f"{ticker}_{start_date or 'none'}_{end_date}_{limit}_{source}"
|
||||
@@ -197,12 +213,16 @@ def get_insider_trades(
|
||||
if cached_data := _cache.get_insider_trades(cache_key):
|
||||
return [InsiderTrade(**trade) for trade in cached_data]
|
||||
|
||||
all_trades, data_source = _router.get_insider_trades(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
try:
|
||||
all_trades, data_source = _router.get_insider_trades(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.info("Insider trades lookup failed for %s: %s", ticker, exc)
|
||||
return []
|
||||
|
||||
if not all_trades:
|
||||
return []
|
||||
@@ -219,6 +239,8 @@ def get_company_news(
|
||||
) -> list[CompanyNews]:
|
||||
"""Fetch company news from cache or API."""
|
||||
ticker = normalize_symbol(ticker)
|
||||
if not ticker:
|
||||
return []
|
||||
for source in _router.api_sources():
|
||||
cache_key = (
|
||||
f"{ticker}_{start_date or 'none'}_{end_date}_{limit}_{source}"
|
||||
@@ -226,12 +248,16 @@ def get_company_news(
|
||||
if cached_data := _cache.get_company_news(cache_key):
|
||||
return [CompanyNews(**news) for news in cached_data]
|
||||
|
||||
all_news, data_source = _router.get_company_news(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
try:
|
||||
all_news, data_source = _router.get_company_news(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.info("Company news lookup failed for %s: %s", ticker, exc)
|
||||
return []
|
||||
|
||||
if not all_news:
|
||||
return []
|
||||
@@ -243,6 +269,8 @@ def get_company_news(
|
||||
def get_market_cap(ticker: str, end_date: str) -> float | None:
|
||||
"""Fetch market cap from the API. Finnhub values are converted from millions."""
|
||||
ticker = normalize_symbol(ticker)
|
||||
if not ticker:
|
||||
return None
|
||||
|
||||
def _metrics_lookup(symbol: str, date: str):
|
||||
for source in _router.api_sources():
|
||||
@@ -256,11 +284,15 @@ def get_market_cap(ticker: str, end_date: str) -> float | None:
|
||||
limit=10,
|
||||
)
|
||||
|
||||
market_cap, _ = _router.get_market_cap(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
metrics_lookup=_metrics_lookup,
|
||||
)
|
||||
try:
|
||||
market_cap, _ = _router.get_market_cap(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
metrics_lookup=_metrics_lookup,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.info("Market cap lookup failed for %s: %s", ticker, exc)
|
||||
return None
|
||||
return market_cap
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user