feat: 微服务架构拆分和前后端优化

后端:
- 拆分出 agent_service, runtime_service, trading_service, news_service
- Gateway 模块化拆分 (gateway_*.py)
- 添加 domains/ 领域层
- 新增 control_client, runtime_client
- 更新 start-dev.sh 支持 split 服务模式

前端:
- 完善 API 服务层 (newsApi, tradingApi)
- 更新 vite.config.js
- Explain 组件优化

测试:
- 添加多个服务 app 测试

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-23 17:45:39 +08:00
parent 0f1bc2bb39
commit 3448667b79
54 changed files with 5440 additions and 2947 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-
"""Runtime/workspace/skills handlers extracted from the main Gateway module."""
from __future__ import annotations
import json
from datetime import datetime
from typing import Any
from backend.agents.agent_workspace import load_agent_workspace_config
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import load_agent_profiles
from backend.config.bootstrap_config import (
get_bootstrap_config_for_run,
resolve_runtime_config,
update_bootstrap_values_for_run,
)
from backend.data.market_ingest import ingest_symbols
from backend.llm.models import get_agent_model_info
async def handle_reload_runtime_assets(gateway: Any) -> None:
config_name = gateway.config.get("config_name", "default")
runtime_config = resolve_runtime_config(
project_root=gateway._project_root,
config_name=config_name,
enable_memory=gateway.config.get("enable_memory", False),
schedule_mode=gateway.config.get("schedule_mode", "daily"),
interval_minutes=gateway.config.get("interval_minutes", 60),
trigger_time=gateway.config.get("trigger_time", "09:30"),
)
result = gateway.pipeline.reload_runtime_assets(runtime_config=runtime_config)
runtime_updates = gateway._apply_runtime_config(runtime_config)
await gateway.state_sync.on_system_message("Runtime assets reloaded.")
await gateway.broadcast({"type": "runtime_assets_reloaded", **result, **runtime_updates})
async def handle_update_runtime_config(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
updates: dict[str, Any] = {}
schedule_mode = str(data.get("schedule_mode", "")).strip().lower()
if schedule_mode:
if schedule_mode not in {"daily", "intraday"}:
await websocket.send(json.dumps({"type": "error", "message": "schedule_mode must be 'daily' or 'intraday'."}, ensure_ascii=False))
return
updates["schedule_mode"] = schedule_mode
interval_minutes = data.get("interval_minutes")
if interval_minutes is not None:
try:
parsed_interval = int(interval_minutes)
except (TypeError, ValueError):
parsed_interval = 0
if parsed_interval <= 0:
await websocket.send(json.dumps({"type": "error", "message": "interval_minutes must be a positive integer."}, ensure_ascii=False))
return
updates["interval_minutes"] = parsed_interval
trigger_time = data.get("trigger_time")
if trigger_time is not None:
raw_trigger = str(trigger_time).strip()
if raw_trigger and raw_trigger != "now":
try:
datetime.strptime(raw_trigger, "%H:%M")
except ValueError:
await websocket.send(json.dumps({"type": "error", "message": "trigger_time must use HH:MM or 'now'."}, ensure_ascii=False))
return
updates["trigger_time"] = raw_trigger or "09:30"
max_comm_cycles = data.get("max_comm_cycles")
if max_comm_cycles is not None:
try:
parsed_cycles = int(max_comm_cycles)
except (TypeError, ValueError):
parsed_cycles = 0
if parsed_cycles <= 0:
await websocket.send(json.dumps({"type": "error", "message": "max_comm_cycles must be a positive integer."}, ensure_ascii=False))
return
updates["max_comm_cycles"] = parsed_cycles
initial_cash = data.get("initial_cash")
if initial_cash is not None:
try:
parsed_initial_cash = float(initial_cash)
except (TypeError, ValueError):
parsed_initial_cash = 0.0
if parsed_initial_cash <= 0:
await websocket.send(json.dumps({"type": "error", "message": "initial_cash must be a positive number."}, ensure_ascii=False))
return
updates["initial_cash"] = parsed_initial_cash
margin_requirement = data.get("margin_requirement")
if margin_requirement is not None:
try:
parsed_margin_requirement = float(margin_requirement)
except (TypeError, ValueError):
parsed_margin_requirement = -1.0
if parsed_margin_requirement < 0:
await websocket.send(json.dumps({"type": "error", "message": "margin_requirement must be a non-negative number."}, ensure_ascii=False))
return
updates["margin_requirement"] = parsed_margin_requirement
enable_memory = data.get("enable_memory")
if enable_memory is not None:
updates["enable_memory"] = bool(enable_memory)
if not updates:
await websocket.send(json.dumps({"type": "error", "message": "No runtime settings were provided."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=gateway._project_root,
config_name=config_name,
updates=updates,
)
await gateway.state_sync.on_system_message("运行时调度配置已保存,正在热更新")
await handle_reload_runtime_assets(gateway)
async def handle_update_watchlist(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
tickers = gateway._normalize_watchlist(data.get("tickers"))
if not tickers:
await websocket.send(json.dumps({"type": "error", "message": "update_watchlist requires at least one valid ticker."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=gateway._project_root,
config_name=config_name,
updates={"tickers": tickers},
)
await gateway.state_sync.on_system_message(f"Watchlist updated: {', '.join(tickers)}")
await gateway.broadcast({"type": "watchlist_updated", "config_name": config_name, "tickers": tickers})
await handle_reload_runtime_assets(gateway)
gateway._schedule_watchlist_market_store_refresh(tickers)
async def handle_get_agent_skills(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
if not agent_id:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_skills requires agent_id."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=config_name, agent_id=agent_id, default_skills=[]))
enabled = set(agent_config.enabled_skills)
disabled = set(agent_config.disabled_skills)
payload = []
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id):
if item.skill_name in disabled:
status = "disabled"
elif item.skill_name in enabled:
status = "enabled"
elif item.skill_name in resolved_skills:
status = "active"
else:
status = "available"
payload.append({
"skill_name": item.skill_name,
"name": item.name,
"description": item.description,
"version": item.version,
"source": item.source,
"tools": item.tools,
"status": status,
})
await websocket.send(json.dumps({
"type": "agent_skills_loaded",
"config_name": config_name,
"agent_id": agent_id,
"skills": payload,
}, ensure_ascii=False))
async def handle_get_agent_profile(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
if not agent_id:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_profile requires agent_id."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
profiles = load_agent_profiles()
profile = profiles.get(agent_id, {})
bootstrap = get_bootstrap_config_for_run(gateway._project_root, config_name)
override = bootstrap.agent_override(agent_id)
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
if not isinstance(active_tool_groups, list):
active_tool_groups = []
disabled_tool_groups = agent_config.disabled_tool_groups
if disabled_tool_groups:
disabled_set = set(disabled_tool_groups)
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
default_skills = profile.get("skills", [])
if not isinstance(default_skills, list):
default_skills = []
resolved_skills = skills_manager.resolve_agent_skill_names(
config_name=config_name,
agent_id=agent_id,
default_skills=default_skills,
)
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
model_name, model_provider = get_agent_model_info(agent_id)
await websocket.send(json.dumps({
"type": "agent_profile_loaded",
"config_name": config_name,
"agent_id": agent_id,
"profile": {
"model_name": model_name,
"model_provider": model_provider,
"prompt_files": prompt_files,
"default_skills": default_skills,
"resolved_skills": resolved_skills,
"active_tool_groups": active_tool_groups,
"disabled_tool_groups": disabled_tool_groups,
"enabled_skills": agent_config.enabled_skills,
"disabled_skills": agent_config.disabled_skills,
},
}, ensure_ascii=False))
async def handle_get_skill_detail(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "get_skill_detail requires skill_name."}, ensure_ascii=False))
return
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
if agent_id:
config_name = gateway.config.get("config_name", "default")
detail = skills_manager.load_agent_skill_document(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
else:
detail = skills_manager.load_skill_document(skill_name)
except FileNotFoundError:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
return
await websocket.send(json.dumps({
"type": "skill_detail_loaded",
"agent_id": agent_id,
"skill": detail,
}, ensure_ascii=False))
async def handle_create_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "create_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.create_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
except (ValueError, FileExistsError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Created local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_created", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
async def handle_update_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
content = data.get("content")
if not agent_id or not skill_name or not isinstance(content, str):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_local_skill requires agent_id, skill_name, and string content."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.update_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name, content=content)
except (ValueError, FileNotFoundError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Updated local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_updated", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
async def handle_delete_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "delete_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.delete_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
skills_manager.forget_agent_skill_overrides(config_name=config_name, agent_id=agent_id, skill_names=[skill_name])
except (ValueError, FileNotFoundError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Deleted local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_deleted", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_remove_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "remove_agent_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
skill_names = {
item.skill_name
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
if item.source != "local"
}
if skill_name not in skill_names:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown shared skill: {skill_name}"}, ensure_ascii=False))
return
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
await gateway.state_sync.on_system_message(f"Removed shared skill {skill_name} from {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_skill_removed", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_update_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
enabled = data.get("enabled")
if not agent_id or not skill_name or not isinstance(enabled, bool):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
skill_names = {item.skill_name for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)}
if skill_name not in skill_names:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
return
if enabled:
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, enable=[skill_name])
await gateway.state_sync.on_system_message(f"Enabled skill {skill_name} for {agent_id}")
else:
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
await gateway.state_sync.on_system_message(f"Disabled skill {skill_name} for {agent_id}")
await websocket.send(json.dumps({
"type": "agent_skill_updated",
"agent_id": agent_id,
"skill_name": skill_name,
"enabled": enabled,
}, ensure_ascii=False))
await gateway._handle_reload_runtime_assets()
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_get_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
if not agent_id or not filename:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_workspace_file requires agent_id and supported filename."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
asset_dir.mkdir(parents=True, exist_ok=True)
path = asset_dir / filename
content = path.read_text(encoding="utf-8") if path.exists() else ""
await websocket.send(json.dumps({
"type": "agent_workspace_file_loaded",
"config_name": config_name,
"agent_id": agent_id,
"filename": filename,
"content": content,
}, ensure_ascii=False))
async def handle_update_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
content = data.get("content")
if not agent_id or not filename or not isinstance(content, str):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_workspace_file requires agent_id, supported filename, and string content."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
asset_dir.mkdir(parents=True, exist_ok=True)
path = asset_dir / filename
path.write_text(content, encoding="utf-8")
await gateway.state_sync.on_system_message(f"Updated {filename} for {agent_id}")
await websocket.send(json.dumps({"type": "agent_workspace_file_updated", "agent_id": agent_id, "filename": filename}, ensure_ascii=False))
await gateway._handle_reload_runtime_assets()
await handle_get_agent_workspace_file(gateway, websocket, {"agent_id": agent_id, "filename": filename})

View File

@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
"""Cycle and monitoring helpers extracted from the main Gateway module."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from backend.data.market_ingest import ingest_symbols
from backend.domains import trading as trading_domain
from backend.utils.msg_adapter import FrontendAdapter
logger = logging.getLogger(__name__)
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
"""Kick off a non-blocking market-store refresh for an updated watchlist."""
if not tickers:
return
if gateway._watchlist_ingest_task and not gateway._watchlist_ingest_task.done():
gateway._watchlist_ingest_task.cancel()
gateway._watchlist_ingest_task = asyncio.create_task(
refresh_market_store_for_watchlist(gateway, tickers),
)
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
"""Refresh the long-lived market store after a watchlist update."""
try:
await gateway.state_sync.on_system_message(
f"正在同步自选股市场数据: {', '.join(tickers)}",
)
results = await asyncio.to_thread(
ingest_symbols,
tickers,
mode="incremental",
)
summary = ", ".join(
f"{item['symbol']} prices={item['prices']} news={item['news']}"
for item in results
)
await gateway.state_sync.on_system_message(
f"自选股市场数据已同步: {summary}",
)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning("Watchlist market store refresh failed: %s", exc)
await gateway.state_sync.on_system_message(
f"自选股市场数据同步失败: {exc}",
)
async def market_status_monitor(gateway: Any) -> None:
"""Periodically check and broadcast market status changes."""
while True:
try:
await gateway.market_service.check_and_broadcast_market_status()
status = gateway.market_service.get_market_status()
if status["status"] == "open" and not gateway.storage.is_live_session_active:
gateway.storage.start_live_session()
summary = gateway.storage.load_file("summary") or {}
gateway._session_start_portfolio_value = summary.get(
"totalAssetValue",
gateway.storage.initial_cash,
)
logger.info(
"Session start portfolio: $%s",
f"{gateway._session_start_portfolio_value:,.2f}",
)
elif status["status"] != "open" and gateway.storage.is_live_session_active:
gateway.storage.end_live_session()
gateway._session_start_portfolio_value = None
if gateway.storage.is_live_session_active:
await update_and_broadcast_live_returns(gateway)
await asyncio.sleep(60)
except asyncio.CancelledError:
break
except Exception as exc:
logger.error("Market status monitor error: %s", exc)
await asyncio.sleep(60)
async def update_and_broadcast_live_returns(gateway: Any) -> None:
"""Calculate and broadcast live returns for current session."""
if not gateway.storage.is_live_session_active:
return
prices = gateway.market_service.get_all_prices()
if not prices or not any(p > 0 for p in prices.values()):
return
state = gateway.storage.load_internal_state()
equity_history = state.get("equity_history", [])
baseline_history = state.get("baseline_history", [])
baseline_vw_history = state.get("baseline_vw_history", [])
momentum_history = state.get("momentum_history", [])
current_equity = equity_history[-1]["v"] if equity_history else None
current_baseline = baseline_history[-1]["v"] if baseline_history else None
current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None
current_momentum = momentum_history[-1]["v"] if momentum_history else None
point = gateway.storage.update_live_returns(
current_equity=current_equity,
current_baseline=current_baseline,
current_baseline_vw=current_baseline_vw,
current_momentum=current_momentum,
)
if point:
live_returns = gateway.storage.get_live_returns()
await gateway.broadcast(
{
"type": "team_summary",
"equity_return": live_returns["equity_return"],
"baseline_return": live_returns["baseline_return"],
"baseline_vw_return": live_returns["baseline_vw_return"],
"momentum_return": live_returns["momentum_return"],
},
)
async def on_strategy_trigger(gateway: Any, date: str) -> None:
"""Handle trading cycle trigger."""
if gateway._cycle_lock.locked():
logger.warning("Trading cycle already running, skipping trigger for %s", date)
await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}")
return
async with gateway._cycle_lock:
logger.info("Strategy triggered for %s", date)
tickers = gateway.config.get("tickers", [])
if gateway.is_backtest:
await run_backtest_cycle(gateway, date, tickers)
else:
await run_live_cycle(gateway, date, tickers)
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
"""Run lightweight heartbeat check for all analysts."""
logger.info("[Heartbeat] Running heartbeat check for %s", date)
analysts = gateway.pipeline._all_analysts()
for analyst in analysts:
try:
ws_id = getattr(analyst, "workspace_id", None)
if ws_id:
from backend.agents.workspace_manager import get_workspace_dir
from pathlib import Path
from agentscope.message import Msg
ws_dir = get_workspace_dir(ws_id)
if ws_dir:
hb_path = Path(ws_dir) / "HEARTBEAT.md"
if hb_path.exists():
content = hb_path.read_text(encoding="utf-8").strip()
if content:
hb_task = f"# 定期主动检查\n\n{content}\n\n请执行上述检查并报告结果。"
logger.info("[Heartbeat] Running heartbeat for %s", analyst.name)
msg = Msg(role="user", content=hb_task, name="system")
await analyst.reply([msg])
logger.info("[Heartbeat] %s heartbeat complete", analyst.name)
continue
logger.debug("[Heartbeat] No HEARTBEAT.md for %s, skipping", analyst.name)
except Exception as exc:
logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True)
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
gateway.market_service.set_backtest_date(date)
await gateway.market_service.emit_market_open()
await gateway.state_sync.on_cycle_start(date)
gateway._dashboard.update(date=date, status="Analyzing...")
prices = gateway.market_service.get_open_prices()
close_prices = gateway.market_service.get_close_prices()
market_caps = await get_market_caps(gateway, tickers, date)
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=date,
prices=prices,
close_prices=close_prices,
market_caps=market_caps,
)
await gateway.market_service.emit_market_close()
settlement_result = result.get("settlement_result")
save_cycle_results(gateway, result, date, close_prices, settlement_result)
await broadcast_portfolio_updates(gateway, result, close_prices)
await finalize_cycle(gateway, date)
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
trading_date = gateway.market_service.get_live_trading_date()
logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)
await gateway.state_sync.on_cycle_start(trading_date)
gateway._dashboard.update(date=trading_date, status="Analyzing...")
market_caps = await get_market_caps(gateway, tickers, trading_date)
schedule_mode = gateway.config.get("schedule_mode", "daily")
market_status = gateway.market_service.get_market_status()
current_prices = gateway.market_service.get_all_prices()
if schedule_mode == "intraday":
execute_decisions = market_status.get("status") == "open"
if execute_decisions:
await gateway.state_sync.on_system_message("定时任务触发:当前处于交易时段,本轮将执行交易决策")
else:
await gateway.state_sync.on_system_message("定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易")
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
prices=current_prices,
market_caps=market_caps,
execute_decisions=execute_decisions,
)
close_prices = current_prices
else:
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
market_caps=market_caps,
get_open_prices_fn=gateway.market_service.wait_for_open_prices,
get_close_prices_fn=gateway.market_service.wait_for_close_prices,
)
close_prices = gateway.market_service.get_all_prices()
settlement_result = result.get("settlement_result")
save_cycle_results(gateway, result, trading_date, close_prices, settlement_result)
await broadcast_portfolio_updates(gateway, result, close_prices)
await finalize_cycle(gateway, trading_date)
async def finalize_cycle(gateway: Any, date: str) -> None:
summary = gateway.storage.load_file("summary") or {}
if gateway.storage.is_live_session_active:
summary.update(gateway.storage.get_live_returns())
await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary)
holdings = gateway.storage.load_file("holdings") or []
trades = gateway.storage.load_file("trades") or []
leaderboard = gateway.storage.load_file("leaderboard") or []
if leaderboard:
await gateway.state_sync.on_leaderboard_update(leaderboard)
gateway._dashboard.update(date=date, status="Running", portfolio=summary, holdings=holdings, trades=trades)
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
market_caps: dict[str, float] = {}
for ticker in tickers:
try:
market_cap = None
response = await gateway._call_trading_service(
f"get_market_cap for {ticker}",
lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
)
if response is not None:
market_cap = response.get("market_cap")
if market_cap is None:
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
market_cap = payload.get("market_cap")
market_caps[ticker] = market_cap if market_cap else 1e9
except Exception as exc:
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
market_caps[ticker] = 1e9
return market_caps
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
portfolio = result.get("portfolio", {})
if portfolio:
holdings = FrontendAdapter.build_holdings(portfolio, prices)
if holdings:
await gateway.state_sync.on_holdings_update(holdings)
stats = FrontendAdapter.build_stats(portfolio, prices)
if stats:
await gateway.state_sync.on_stats_update(stats)
executed_trades = result.get("executed_trades", [])
if executed_trades:
await gateway.state_sync.on_trades_executed(executed_trades)
def save_cycle_results(
gateway: Any,
result: dict[str, Any],
date: str,
prices: dict[str, float],
settlement_result: dict[str, Any] | None = None,
) -> None:
portfolio = result.get("portfolio", {})
executed_trades = result.get("executed_trades", [])
baseline_values = settlement_result.get("baseline_values") if settlement_result else None
if portfolio:
gateway.storage.update_dashboard_after_cycle(
portfolio=portfolio,
prices=prices,
date=date,
executed_trades=executed_trades,
baseline_values=baseline_values,
)
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
gateway.state_sync.set_backtest_dates(dates)
gateway._dashboard.update(days_total=len(dates), days_completed=0)
await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
try:
for i, date in enumerate(dates):
gateway._dashboard.update(days_completed=i)
await gateway.on_strategy_trigger(date=date)
await asyncio.sleep(0.1)
await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
summary = gateway.storage.load_file("summary") or {}
gateway._dashboard.update(status="Complete", portfolio=summary, days_completed=len(dates))
gateway._dashboard.stop()
gateway._dashboard.print_final_summary()
except Exception as exc:
error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
logger.error(error_msg, exc_info=True)
asyncio.create_task(gateway.state_sync.on_system_message(error_msg))
gateway._dashboard.update(status=f"Failed: {str(exc)}")
gateway._dashboard.stop()
raise
finally:
gateway._backtest_task = None
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
try:
task.result()
except asyncio.CancelledError:
logger.info("Backtest task was cancelled")
except Exception as exc:
logger.error("Backtest task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
gateway._manual_cycle_task = None
try:
task.result()
except asyncio.CancelledError:
logger.info("Manual cycle task was cancelled")
except Exception as exc:
logger.error("Manual cycle task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
gateway.state_sync.set_backtest_dates(dates)
if dates:
gateway._backtest_start_date = dates[0]
gateway._backtest_end_date = dates[-1]
gateway._dashboard.days_total = len(dates)
def stop_gateway(gateway: Any) -> None:
gateway.state_sync.save_state()
gateway.market_service.stop()
if gateway._backtest_task:
gateway._backtest_task.cancel()
if gateway._market_status_task:
gateway._market_status_task.cancel()
if gateway._watchlist_ingest_task:
gateway._watchlist_ingest_task.cancel()
gateway._dashboard.stop()

View File

@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
"""Runtime/state support helpers extracted from the main Gateway module."""
from __future__ import annotations
from typing import Any
from backend.data.provider_utils import normalize_symbol
def normalize_watchlist(raw_tickers: Any) -> list[str]:
"""Parse watchlist payloads from websocket messages."""
if raw_tickers is None:
return []
if isinstance(raw_tickers, str):
candidates = raw_tickers.split(",")
elif isinstance(raw_tickers, list):
candidates = raw_tickers
else:
candidates = [raw_tickers]
tickers: list[str] = []
for candidate in candidates:
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
if symbol and symbol not in tickers:
tickers.append(symbol)
return tickers
def normalize_agent_workspace_filename(
raw_name: Any,
*,
allowlist: set[str],
) -> str | None:
"""Restrict editable workspace files to a safe allowlist."""
filename = str(raw_name or "").strip()
if filename in allowlist:
return filename
return None
def apply_runtime_config(gateway: Any, runtime_config: dict[str, Any]) -> dict[str, Any]:
"""Apply runtime config to gateway-owned services and state."""
warnings: list[str] = []
ticker_changes = gateway.market_service.update_tickers(
runtime_config.get("tickers", []),
)
gateway.config["tickers"] = ticker_changes["active"]
gateway.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
gateway.config["max_comm_cycles"] = gateway.pipeline.max_comm_cycles
gateway.config["schedule_mode"] = runtime_config.get(
"schedule_mode",
gateway.config.get("schedule_mode", "daily"),
)
gateway.config["interval_minutes"] = int(
runtime_config.get(
"interval_minutes",
gateway.config.get("interval_minutes", 60),
),
)
gateway.config["trigger_time"] = runtime_config.get(
"trigger_time",
gateway.config.get("trigger_time", "09:30"),
)
if gateway.scheduler:
gateway.scheduler.reconfigure(
mode=gateway.config["schedule_mode"],
trigger_time=gateway.config["trigger_time"],
interval_minutes=gateway.config["interval_minutes"],
)
pm_apply_result = gateway.pipeline.pm.apply_runtime_portfolio_config(
margin_requirement=runtime_config["margin_requirement"],
)
gateway.config["margin_requirement"] = gateway.pipeline.pm.portfolio.get(
"margin_requirement",
runtime_config["margin_requirement"],
)
requested_initial_cash = float(runtime_config["initial_cash"])
current_initial_cash = float(gateway.storage.initial_cash)
initial_cash_applied = requested_initial_cash == current_initial_cash
if not initial_cash_applied:
if (
gateway.storage.can_apply_initial_cash()
and gateway.pipeline.pm.can_apply_initial_cash()
):
initial_cash_applied = gateway.storage.apply_initial_cash(
requested_initial_cash,
)
if initial_cash_applied:
gateway.pipeline.pm.apply_runtime_portfolio_config(
initial_cash=requested_initial_cash,
)
gateway.config["initial_cash"] = gateway.storage.initial_cash
else:
warnings.append(
"initial_cash changed in BOOTSTRAP.md but was not applied "
"because the run already has positions, margin usage, or trades.",
)
requested_enable_memory = bool(runtime_config["enable_memory"])
current_enable_memory = bool(gateway.config.get("enable_memory", False))
if requested_enable_memory != current_enable_memory:
warnings.append(
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
"because long-term memory contexts are created at startup.",
)
sync_runtime_state(gateway)
return {
"runtime_config_requested": runtime_config,
"runtime_config_applied": {
"tickers": list(gateway.config.get("tickers", [])),
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
"interval_minutes": gateway.config.get("interval_minutes", 60),
"trigger_time": gateway.config.get("trigger_time", "09:30"),
"initial_cash": gateway.storage.initial_cash,
"margin_requirement": gateway.config["margin_requirement"],
"max_comm_cycles": gateway.config["max_comm_cycles"],
"enable_memory": gateway.config.get("enable_memory", False),
},
"runtime_config_status": {
"tickers": True,
"schedule_mode": True,
"interval_minutes": True,
"trigger_time": True,
"initial_cash": initial_cash_applied,
"margin_requirement": pm_apply_result["margin_requirement"],
"max_comm_cycles": True,
"enable_memory": requested_enable_memory == current_enable_memory,
},
"ticker_changes": ticker_changes,
"runtime_config_warnings": warnings,
}
def sync_runtime_state(gateway: Any) -> None:
"""Refresh persisted state and dashboard after runtime config changes."""
gateway.state_sync.update_state("tickers", gateway.config.get("tickers", []))
gateway.state_sync.update_state(
"runtime_config",
{
"tickers": gateway.config.get("tickers", []),
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
"interval_minutes": gateway.config.get("interval_minutes", 60),
"trigger_time": gateway.config.get("trigger_time", "09:30"),
"initial_cash": gateway.storage.initial_cash,
"margin_requirement": gateway.config.get("margin_requirement"),
"max_comm_cycles": gateway.config.get("max_comm_cycles"),
"enable_memory": gateway.config.get("enable_memory", False),
},
)
gateway.storage.update_server_state_from_dashboard(gateway.state_sync.state)
gateway.state_sync.save_state()
gateway._dashboard.tickers = list(gateway.config.get("tickers", []))
gateway._dashboard.initial_cash = gateway.storage.initial_cash
gateway._dashboard.enable_memory = bool(gateway.config.get("enable_memory", False))
summary = gateway.storage.load_file("summary") or {}
holdings = gateway.storage.load_file("holdings") or []
trades = gateway.storage.load_file("trades") or []
gateway._dashboard.update(
portfolio=summary,
holdings=holdings,
trades=trades,
)

View File

@@ -0,0 +1,711 @@
# -*- coding: utf-8 -*-
"""Stock-related Gateway handlers extracted from the main Gateway module."""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timedelta
from typing import Any
from backend.data.provider_utils import normalize_symbol
from backend.domains import news as news_domain
from backend.domains import trading as trading_domain
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.tools.data_tools import prices_to_df
from shared.client import NewsServiceClient, TradingServiceClient
logger = logging.getLogger(__name__)
async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_history_loaded",
"ticker": "",
"prices": [],
"source": None,
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
prices = []
source = "polygon"
response = await gateway._call_trading_service(
"get_prices for history",
lambda client: client.get_prices(ticker=ticker, start_date=start_date, end_date=end_date),
)
if response is not None:
prices = response.prices
source = "trading_service"
if not prices:
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
if not prices:
payload = await asyncio.to_thread(
trading_domain.get_prices_payload,
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
prices = payload.get("prices") or []
usage_snapshot = gateway._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
if prices:
await asyncio.to_thread(
gateway.storage.market_store.upsert_ohlc,
ticker,
[price.model_dump() for price in prices],
source=source or "provider",
)
await websocket.send(json.dumps({
"type": "stock_history_loaded",
"ticker": ticker,
"prices": [price if isinstance(price, dict) else price.model_dump() for price in prices][-120:],
"source": source,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_explain_events(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
snapshot = gateway.storage.runtime_db.get_stock_explain_snapshot(ticker)
await websocket.send(json.dumps({
"type": "stock_explain_events_loaded",
"ticker": ticker,
"events": snapshot.get("events", []),
"signals": snapshot.get("signals", []),
"trades": snapshot.get("trades", []),
}, ensure_ascii=False, default=str))
async def handle_get_stock_news(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_loaded",
"ticker": "",
"news": [],
"source": None,
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 30)
limit = data.get("limit", 12)
try:
lookback_days = max(7, min(int(lookback_days), 180))
except (TypeError, ValueError):
lookback_days = 30
try:
limit = max(1, min(int(limit), 30))
except (TypeError, ValueError):
limit = 12
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
news_rows = []
source = "polygon"
response = await gateway._call_news_service(
"get_enriched_news",
lambda client: client.get_enriched_news(
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
),
)
if response is not None:
news_rows = response.get("news") or []
source = "news_service"
if not news_rows:
payload = await asyncio.to_thread(
news_domain.get_enriched_news,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=max(limit, 50),
)
news_rows = (payload.get("news") or [])[-limit:]
source = "market_store"
await websocket.send(json.dumps({
"type": "stock_news_loaded",
"ticker": ticker,
"news": news_rows[-limit:],
"source": source,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_for_date(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
trade_date = str(data.get("date") or "").strip()
if not ticker or not trade_date:
await websocket.send(json.dumps({
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": [],
"error": "ticker and date are required",
}, ensure_ascii=False))
return
limit = data.get("limit", 20)
try:
limit = max(1, min(int(limit), 50))
except (TypeError, ValueError):
limit = 20
source = "market_store"
news_rows = []
response = await gateway._call_news_service(
"get_news_for_date",
lambda client: client.get_news_for_date(ticker=ticker, date=trade_date, limit=limit),
)
if response is not None:
news_rows = response.get("news") or []
source = "news_service"
if not news_rows:
payload = await asyncio.to_thread(
news_domain.get_news_for_date,
gateway.storage.market_store,
ticker=ticker,
date=trade_date,
limit=limit,
)
news_rows = payload.get("news") or []
source = "market_store"
await websocket.send(json.dumps({
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": news_rows,
"source": source,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_timeline(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_timeline_loaded",
"ticker": "",
"timeline": [],
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
timeline = []
response = await gateway._call_news_service(
"get_news_timeline",
lambda client: client.get_news_timeline(ticker=ticker, start_date=start_date, end_date=end_date),
)
if response is not None:
timeline = response.get("timeline") or []
if not timeline:
payload = await asyncio.to_thread(
news_domain.get_news_timeline,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
timeline = payload.get("timeline") or []
await websocket.send(json.dumps({
"type": "stock_news_timeline_loaded",
"ticker": ticker,
"timeline": timeline,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_categories(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_categories_loaded",
"ticker": "",
"categories": {},
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
categories = {}
response = await gateway._call_news_service(
"get_categories",
lambda client: client.get_categories(
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=200,
),
)
if response is not None:
categories = response.get("categories") or {}
if not categories:
payload = await asyncio.to_thread(
news_domain.get_news_categories,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=200,
)
categories = payload.get("categories") or {}
await websocket.send(json.dumps({
"type": "stock_news_categories_loaded",
"ticker": ticker,
"categories": categories,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_range_explain(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()
end_date = str(data.get("end_date") or "").strip()
if not ticker or not start_date or not end_date:
await websocket.send(json.dumps({
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": {"error": "ticker, start_date, end_date are required"},
}, ensure_ascii=False))
return
article_ids = data.get("article_ids")
result = None
response = await gateway._call_news_service(
"get_range_explain",
lambda client: client.get_range_explain(
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids if isinstance(article_ids, list) else None,
limit=100,
),
)
if response is not None:
result = response.get("result")
if result is None:
payload = await asyncio.to_thread(
news_domain.get_range_explain_payload,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids if isinstance(article_ids, list) else None,
limit=100,
)
result = payload.get("result")
await websocket.send(json.dumps({
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": result,
}, ensure_ascii=False, default=str))
async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_insider_trades_loaded",
"ticker": "",
"trades": [],
"error": "invalid ticker",
}, ensure_ascii=False))
return
end_date = str(data.get("end_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
start_date = str(data.get("start_date") or "").strip()[:10]
limit = int(data.get("limit", 50))
trades = []
response = await gateway._call_trading_service(
"get_insider_trades",
lambda client: client.get_insider_trades(
ticker=ticker,
end_date=end_date,
start_date=start_date if start_date else None,
limit=limit,
),
)
if response is not None:
trades = response.insider_trades
if not trades:
payload = await asyncio.to_thread(
trading_domain.get_insider_trades_payload,
ticker=ticker,
end_date=end_date,
start_date=start_date if start_date else None,
limit=limit,
)
trades = payload.get("insider_trades") or []
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
formatted_trades = [{
"ticker": t.ticker,
"name": t.name,
"title": t.title,
"is_board_director": t.is_board_director,
"transaction_date": t.transaction_date,
"transaction_shares": t.transaction_shares,
"transaction_price_per_share": t.transaction_price_per_share,
"transaction_value": t.transaction_value,
"shares_owned_before_transaction": t.shares_owned_before_transaction,
"shares_owned_after_transaction": t.shares_owned_after_transaction,
"security_title": t.security_title,
"filing_date": t.filing_date,
"holding_change": (
(t.shares_owned_after_transaction or 0) - (t.shares_owned_before_transaction or 0)
if t.shares_owned_after_transaction and t.shares_owned_before_transaction else None
),
"is_buy": ((t.transaction_shares or 0) > 0) if t.transaction_shares is not None else None,
} for t in sorted_trades]
await websocket.send(json.dumps({
"type": "stock_insider_trades_loaded",
"ticker": ticker,
"start_date": start_date or None,
"end_date": end_date,
"trades": formatted_trades,
}, ensure_ascii=False, default=str))
async def handle_get_stock_story(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_story_loaded",
"ticker": "",
"story": "",
"error": "invalid ticker",
}, ensure_ascii=False))
return
as_of_date = str(data.get("as_of_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
result = await gateway._call_news_service(
"get_story",
lambda client: client.get_story(ticker=ticker, as_of_date=as_of_date),
)
if result is None:
result = await asyncio.to_thread(
news_domain.get_story_payload,
gateway.storage.market_store,
ticker=ticker,
as_of_date=as_of_date,
)
await websocket.send(json.dumps({
"type": "stock_story_loaded",
"ticker": ticker,
"as_of_date": as_of_date,
"story": result.get("story") or "",
"source": result.get("source") or "local",
}, ensure_ascii=False, default=str))
async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
target_date = str(data.get("date") or "").strip()[:10]
if not ticker or not target_date:
await websocket.send(json.dumps({
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
"items": [],
"error": "ticker and date are required",
}, ensure_ascii=False))
return
top_k = data.get("top_k", 8)
try:
top_k = max(1, min(int(top_k), 20))
except (TypeError, ValueError):
top_k = 8
result = await gateway._call_news_service(
"get_similar_days",
lambda client: client.get_similar_days(ticker=ticker, date=target_date, n_similar=top_k),
)
if result is None:
result = await asyncio.to_thread(
news_domain.get_similar_days_payload,
gateway.storage.market_store,
ticker=ticker,
date=target_date,
n_similar=top_k,
)
await websocket.send(json.dumps({
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
**result,
}, ensure_ascii=False, default=str))
async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "ticker is required",
}, ensure_ascii=False))
return
try:
end_date = datetime.now()
start_date = end_date - timedelta(days=250)
prices = None
response = await gateway._call_trading_service(
"get_prices",
lambda client: client.get_prices(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
),
)
if response is not None:
prices = response.prices
if prices is None:
payload = trading_domain.get_prices_payload(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
)
prices = payload.get("prices") or []
if not prices or len(prices) < 20:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "Insufficient price data",
}, ensure_ascii=False))
return
df = prices_to_df(prices)
signal = gateway._technical_analyzer.analyze(ticker, df)
import pandas as pd
df_sorted = df.sort_values("time").reset_index(drop=True)
df_sorted["returns"] = df_sorted["close"].pct_change()
vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
ma_distance = {}
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
ma_value = getattr(signal, ma_key, None)
ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None
indicators = {
"ticker": ticker,
"current_price": signal.current_price,
"ma": {
"ma5": signal.ma5,
"ma10": signal.ma10,
"ma20": signal.ma20,
"ma50": signal.ma50,
"ma200": signal.ma200,
"distance": ma_distance,
},
"rsi": {
"rsi14": signal.rsi14,
"status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
},
"macd": {
"macd": signal.macd,
"signal": signal.macd_signal,
"histogram": signal.macd - signal.macd_signal,
},
"bollinger": {
"upper": signal.bollinger_upper,
"mid": signal.bollinger_mid,
"lower": signal.bollinger_lower,
},
"volatility": {
"vol_10d": vol_10,
"vol_20d": vol_20,
"vol_60d": vol_60,
"annualized": signal.annualized_volatility_pct,
"risk_level": signal.risk_level,
},
"trend": signal.trend,
"mean_reversion": signal.mean_reversion_signal,
}
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": indicators,
}, ensure_ascii=False, default=str))
except Exception as exc:
logger.exception("Error getting technical indicators for %s", ticker)
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": str(exc),
}, ensure_ascii=False))
async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()[:10]
end_date = str(data.get("end_date") or "").strip()[:10]
story_date = str(data.get("story_date") or end_date or "").strip()[:10]
target_date = str(data.get("target_date") or "").strip()[:10]
force = bool(data.get("force", False))
rebuild_story = bool(data.get("rebuild_story", True))
rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
only_local_to_llm = bool(data.get("only_local_to_llm", False))
limit = data.get("limit", 200)
try:
limit = max(10, min(int(limit), 500))
except (TypeError, ValueError):
limit = 200
if not ticker or not start_date or not end_date:
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "ticker, start_date, end_date are required",
}, ensure_ascii=False))
return
if only_local_to_llm and not llm_enrichment_enabled():
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
}, ensure_ascii=False))
return
result = await asyncio.to_thread(
enrich_news_for_symbol,
gateway.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
skip_existing=not force,
only_reanalyze_local=only_local_to_llm,
)
story_status = None
if rebuild_story and story_date:
await asyncio.to_thread(gateway.storage.market_store.delete_story_cache, ticker, as_of_date=story_date)
story_result = await asyncio.to_thread(
news_domain.get_story_payload,
gateway.storage.market_store,
ticker=ticker,
as_of_date=story_date,
)
story_status = {"as_of_date": story_date, "source": story_result.get("source") or "local"}
similar_status = None
if rebuild_similar_days and target_date:
await asyncio.to_thread(gateway.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date)
similar_result = await asyncio.to_thread(
news_domain.get_similar_days_payload,
gateway.storage.market_store,
ticker=ticker,
date=target_date,
n_similar=8,
)
similar_status = {
"target_date": target_date,
"count": len(similar_result.get("items") or []),
"error": similar_result.get("error"),
}
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"story_date": story_date or None,
"target_date": target_date or None,
"force": force,
"only_local_to_llm": only_local_to_llm,
"stats": result,
"story_status": story_status,
"similar_status": similar_status,
}, ensure_ascii=False, default=str))

View File

@@ -9,7 +9,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable
from backend.data.schema import CompanyNews
from shared.schema import CompanyNews
SCHEMA = """