2 Commits

Author SHA1 Message Date
3926a6bd07 feat: 架构修复 - P0/P1 问题全面修复
P0 修复:
- runtimeStore: 添加缺失的 lastDayHistory 字段
- Gateway/RuntimeService: 状态同步改为内存优先,消除 glob 竞态
- App.jsx: 从 3075 行重构到 ~500 行,提取 8 个独立文件

P1 修复:
- CORS: 4 个服务改为从环境变量读取允许 origins
- MarketStore: 改为模块级单例模式
- Domain 层: 删除 trading thin wrapper,保留 news 真实逻辑
- 测试: 补齐 77 个 gateway/runtime 测试

新增文件:
- backend/tests/test_gateway.py (43 tests)
- frontend/src/hooks/useWebSocketHandler.js
- frontend/src/hooks/useStockRequestCallbacks.js
- frontend/src/hooks/useAgentCallbacks.js
- frontend/src/hooks/useRuntimeCallbacks.js
- frontend/src/hooks/useWatchlistCallbacks.js
- frontend/src/components/TickerBar.jsx
- frontend/src/components/HeaderRight.jsx
- frontend/src/components/ChartTabs.jsx

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 18:45:57 +08:00
80256a4079 fix(frontend): 添加缺失的 lastDayHistory 字段到 runtimeStore
修复 App.jsx 中使用不存在的 store 字段导致的潜在运行时错误。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 18:20:29 +08:00
22 changed files with 4284 additions and 2790 deletions

View File

@@ -302,36 +302,28 @@ def _start_gateway_process(
@router.get("/context", response_model=RunContextResponse) @router.get("/context", response_model=RunContextResponse)
async def get_run_context() -> RunContextResponse: async def get_run_context() -> RunContextResponse:
"""Return the most recent run context.""" """Return the current run context from in-memory state (avoids glob race condition)."""
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") manager = _runtime_state.runtime_manager
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) if manager is None or manager.context is None:
if not snapshots:
raise HTTPException(status_code=404, detail="No run context available") raise HTTPException(status_code=404, detail="No run context available")
latest = json.loads(snapshots[0].read_text(encoding="utf-8")) context = manager.context
context = latest.get("context")
if context is None:
raise HTTPException(status_code=404, detail="Run context is not ready")
return RunContextResponse( return RunContextResponse(
config_name=context["config_name"], config_name=context.config_name,
run_dir=context["run_dir"], run_dir=str(context.run_dir),
bootstrap_values=context["bootstrap_values"], bootstrap_values=context.bootstrap_values,
) )
@router.get("/agents", response_model=RuntimeAgentsResponse) @router.get("/agents", response_model=RuntimeAgentsResponse)
async def get_runtime_agents() -> RuntimeAgentsResponse: async def get_runtime_agents() -> RuntimeAgentsResponse:
"""Return agent states from the most recent run.""" """Return agent states from the in-memory runtime manager (avoids glob race condition)."""
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") manager = _runtime_state.runtime_manager
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) if manager is None:
if not snapshots:
raise HTTPException(status_code=404, detail="No runtime state available") raise HTTPException(status_code=404, detail="No runtime state available")
latest = json.loads(snapshots[0].read_text(encoding="utf-8")) snapshot = manager.build_snapshot()
agents = latest.get("agents", []) agents = snapshot.get("agents", [])
return RuntimeAgentsResponse( return RuntimeAgentsResponse(
agents=[RuntimeAgentState(**a) for a in agents] agents=[RuntimeAgentState(**a) for a in agents]
@@ -340,15 +332,13 @@ async def get_runtime_agents() -> RuntimeAgentsResponse:
@router.get("/events", response_model=RuntimeEventsResponse) @router.get("/events", response_model=RuntimeEventsResponse)
async def get_runtime_events() -> RuntimeEventsResponse: async def get_runtime_events() -> RuntimeEventsResponse:
"""Return events from the most recent run.""" """Return events from the in-memory runtime manager (avoids glob race condition)."""
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") manager = _runtime_state.runtime_manager
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) if manager is None:
if not snapshots:
raise HTTPException(status_code=404, detail="No runtime state available") raise HTTPException(status_code=404, detail="No runtime state available")
latest = json.loads(snapshots[0].read_text(encoding="utf-8")) snapshot = manager.build_snapshot()
events = latest.get("events", []) events = snapshot.get("events", [])
return RuntimeEventsResponse( return RuntimeEventsResponse(
events=[RuntimeEvent(**e) for e in events] events=[RuntimeEvent(**e) for e in events]
@@ -362,15 +352,10 @@ async def get_gateway_status() -> GatewayStatusResponse:
run_id = None run_id = None
if is_running: if is_running:
# Try to find run_id from runtime state # Get run_id from in-memory runtime manager (avoids glob race condition)
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") manager = _runtime_state.runtime_manager
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) if manager is not None and manager.context is not None:
if snapshots: run_id = manager.context.config_name
try:
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
run_id = latest.get("context", {}).get("config_name")
except Exception as e:
logger.warning(f"Failed to parse latest snapshot: {e}")
return GatewayStatusResponse( return GatewayStatusResponse(
is_running=is_running, is_running=is_running,
@@ -404,8 +389,28 @@ def _build_gateway_ws_url(request: Request, port: int) -> str:
return f"{ws_scheme}://{host}:{port}" return f"{ws_scheme}://{host}:{port}"
def _load_latest_runtime_snapshot() -> Dict[str, Any]: def _get_current_runtime_context() -> Dict[str, Any]:
"""Load the latest persisted runtime snapshot.""" """Return the active runtime context from the in-memory manager (avoids glob race condition).
Falls back to file-based lookup only when the in-memory manager is not available
(e.g., after a service restart). File-based lookup is deprecated and exists
only for backward compatibility.
"""
if not _is_gateway_running():
raise HTTPException(status_code=404, detail="No runtime is currently running")
# Primary: use in-memory manager (always correct for current process)
manager = _runtime_state.runtime_manager
if manager is not None and manager.context is not None:
ctx = manager.context
return {
"config_name": ctx.config_name,
"run_dir": str(ctx.run_dir),
"bootstrap_values": ctx.bootstrap_values,
}
# Deprecated fallback: scan filesystem (only for backward compatibility
# after service restart without a restart of the runtime itself)
snapshots = sorted( snapshots = sorted(
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"), PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
key=lambda p: p.stat().st_mtime, key=lambda p: p.stat().st_mtime,
@@ -413,14 +418,7 @@ def _load_latest_runtime_snapshot() -> Dict[str, Any]:
) )
if not snapshots: if not snapshots:
raise HTTPException(status_code=404, detail="No runtime information available") raise HTTPException(status_code=404, detail="No runtime information available")
return json.loads(snapshots[0].read_text(encoding="utf-8")) latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
def _get_current_runtime_context() -> Dict[str, Any]:
"""Return the active runtime context from the latest snapshot."""
if not _is_gateway_running():
raise HTTPException(status_code=404, detail="No runtime is currently running")
latest = _load_latest_runtime_snapshot()
context = latest.get("context") or {} context = latest.get("context") or {}
if not context.get("config_name"): if not context.get("config_name"):
raise HTTPException(status_code=404, detail="No runtime context available") raise HTTPException(status_code=404, detail="No runtime context available")
@@ -663,15 +661,8 @@ async def get_current_runtime():
if not _is_gateway_running(): if not _is_gateway_running():
raise HTTPException(status_code=404, detail="No runtime is currently running") raise HTTPException(status_code=404, detail="No runtime is currently running")
# Find latest runtime state # Get context from in-memory manager (avoids glob race condition)
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") context = _get_current_runtime_context()
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
if not snapshots:
raise HTTPException(status_code=404, detail="No runtime information available")
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
context = latest.get("context", {})
return { return {
"run_id": context.get("config_name"), "run_id": context.get("config_name"),

View File

@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
from backend.api import agents_router, guard_router, workspaces_router from backend.api import agents_router, guard_router, workspaces_router
from backend.agents import AgentFactory, WorkspaceManager, get_registry from backend.agents import AgentFactory, WorkspaceManager, get_registry
from backend.config.env_config import get_cors_origins
# Global instances (initialized on startup) # Global instances (initialized on startup)
agent_factory: AgentFactory | None = None agent_factory: AgentFactory | None = None
@@ -49,7 +50,7 @@ def create_app(project_root: Path | None = None) -> FastAPI:
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=get_cors_origins(),
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
from backend.data.market_store import MarketStore from backend.data.market_store import MarketStore
from backend.domains import news as news_domain from backend.domains import news as news_domain
from backend.config.env_config import get_cors_origins
def get_market_store() -> MarketStore: def get_market_store() -> MarketStore:
@@ -27,7 +28,7 @@ def create_app() -> FastAPI:
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=get_cors_origins(),
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@@ -8,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
from backend.api import runtime_router from backend.api import runtime_router
from backend.api.runtime import get_runtime_state from backend.api.runtime import get_runtime_state
from backend.config.env_config import get_cors_origins
def create_app() -> FastAPI: def create_app() -> FastAPI:
@@ -20,7 +21,7 @@ def create_app() -> FastAPI:
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=get_cors_origins(),
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],

View File

@@ -8,7 +8,16 @@ from typing import Any
from fastapi import FastAPI, Query from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from backend.domains import trading as trading_domain from backend.config.env_config import get_cors_origins
from backend.services.market import MarketService
from backend.tools.data_tools import (
get_company_news,
get_financial_metrics,
get_insider_trades,
get_market_cap,
get_prices,
search_line_items,
)
from shared.schema import ( from shared.schema import (
CompanyNewsResponse, CompanyNewsResponse,
FinancialMetricsResponse, FinancialMetricsResponse,
@@ -28,7 +37,7 @@ def create_app() -> FastAPI:
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=get_cors_origins(),
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
@@ -45,12 +54,8 @@ def create_app() -> FastAPI:
start_date: str = Query(...), start_date: str = Query(...),
end_date: str = Query(...), end_date: str = Query(...),
) -> PriceResponse: ) -> PriceResponse:
payload = trading_domain.get_prices_payload( prices = get_prices(ticker=ticker, start_date=start_date, end_date=end_date)
ticker=ticker, return PriceResponse(ticker=ticker, prices=prices)
start_date=start_date,
end_date=end_date,
)
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
@app.get("/api/financials", response_model=FinancialMetricsResponse) @app.get("/api/financials", response_model=FinancialMetricsResponse)
async def api_get_financials( async def api_get_financials(
@@ -59,13 +64,13 @@ def create_app() -> FastAPI:
period: str = Query("ttm"), period: str = Query("ttm"),
limit: int = Query(10, ge=1, le=100), limit: int = Query(10, ge=1, le=100),
) -> FinancialMetricsResponse: ) -> FinancialMetricsResponse:
payload = trading_domain.get_financials_payload( metrics = get_financial_metrics(
ticker=ticker, ticker=ticker,
end_date=end_date, end_date=end_date,
period=period, period=period,
limit=limit, limit=limit,
) )
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"]) return FinancialMetricsResponse(financial_metrics=metrics)
@app.get("/api/news", response_model=CompanyNewsResponse) @app.get("/api/news", response_model=CompanyNewsResponse)
async def api_get_news( async def api_get_news(
@@ -74,13 +79,13 @@ def create_app() -> FastAPI:
start_date: str | None = Query(None), start_date: str | None = Query(None),
limit: int = Query(1000, ge=1, le=5000), limit: int = Query(1000, ge=1, le=5000),
) -> CompanyNewsResponse: ) -> CompanyNewsResponse:
payload = trading_domain.get_news_payload( news = get_company_news(
ticker=ticker, ticker=ticker,
end_date=end_date, end_date=end_date,
start_date=start_date, start_date=start_date,
limit=limit, limit=limit,
) )
return CompanyNewsResponse(news=payload["news"]) return CompanyNewsResponse(news=news)
@app.get("/api/insider-trades", response_model=InsiderTradeResponse) @app.get("/api/insider-trades", response_model=InsiderTradeResponse)
async def api_get_insider_trades( async def api_get_insider_trades(
@@ -89,18 +94,19 @@ def create_app() -> FastAPI:
start_date: str | None = Query(None), start_date: str | None = Query(None),
limit: int = Query(1000, ge=1, le=5000), limit: int = Query(1000, ge=1, le=5000),
) -> InsiderTradeResponse: ) -> InsiderTradeResponse:
payload = trading_domain.get_insider_trades_payload( trades = get_insider_trades(
ticker=ticker, ticker=ticker,
end_date=end_date, end_date=end_date,
start_date=start_date, start_date=start_date,
limit=limit, limit=limit,
) )
return InsiderTradeResponse(insider_trades=payload["insider_trades"]) return InsiderTradeResponse(insider_trades=trades)
@app.get("/api/market/status") @app.get("/api/market/status")
async def api_get_market_status() -> dict[str, Any]: async def api_get_market_status() -> dict[str, Any]:
"""Return current market status using the existing market service logic.""" """Return current market status using the existing market service logic."""
return trading_domain.get_market_status_payload() service = MarketService(tickers=[])
return service.get_market_status()
@app.get("/api/market-cap") @app.get("/api/market-cap")
async def api_get_market_cap( async def api_get_market_cap(
@@ -108,10 +114,12 @@ def create_app() -> FastAPI:
end_date: str = Query(...), end_date: str = Query(...),
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Return market cap for one ticker/date.""" """Return market cap for one ticker/date."""
return trading_domain.get_market_cap_payload( market_cap = get_market_cap(ticker=ticker, end_date=end_date)
ticker=ticker, return {
end_date=end_date, "ticker": ticker,
) "end_date": end_date,
"market_cap": market_cap,
}
@app.get("/api/line-items", response_model=LineItemResponse) @app.get("/api/line-items", response_model=LineItemResponse)
async def api_get_line_items( async def api_get_line_items(
@@ -121,14 +129,14 @@ def create_app() -> FastAPI:
period: str = Query("ttm"), period: str = Query("ttm"),
limit: int = Query(10, ge=1, le=100), limit: int = Query(10, ge=1, le=100),
) -> LineItemResponse: ) -> LineItemResponse:
payload = trading_domain.get_line_items_payload( items = search_line_items(
ticker=ticker, ticker=ticker,
line_items=line_items, line_items=line_items,
end_date=end_date, end_date=end_date,
period=period, period=period,
limit=limit, limit=limit,
) )
return LineItemResponse(search_results=payload["search_results"]) return LineItemResponse(search_results=items)
return app return app

View File

@@ -3,6 +3,7 @@
"""Environment config helpers with light validation and normalization.""" """Environment config helpers with light validation and normalization."""
import os import os
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
@@ -16,6 +17,36 @@ PROVIDER_ALIASES = {
"vertexai": "GEMINI", "vertexai": "GEMINI",
} }
# Default dev CORS origins (localhost variants used by common dev servers)
_LOCALHOST_ORIGINS = [
"http://localhost:5173",
"http://localhost:3000",
"http://localhost:8000",
"http://127.0.0.1:5173",
"http://127.0.0.1:3000",
"http://127.0.0.1:8000",
]
def get_cors_origins() -> list[str]:
"""Get CORS allowed origins from environment.
Reads CORS_ALLOWED_ORIGINS env var (comma-separated).
Falls back to localhost dev origins if not set.
Warns if "*" is configured (only acceptable for local dev).
"""
origins = get_env_list("CORS_ALLOWED_ORIGINS", default=[])
if origins:
if "*" in origins:
warnings.warn(
"CORS_ALLOWED_ORIGINS contains '*' — this allows any origin. "
"Only use in local development, never in production.",
UserWarning,
)
return origins
# Fallback: local dev only
return _LOCALHOST_ORIGINS
@dataclass(frozen=True) @dataclass(frozen=True)
class AgentModelConfig: class AgentModelConfig:

View File

@@ -8,7 +8,7 @@ import logging
from typing import Any from typing import Any
from backend.data.market_ingest import ingest_symbols from backend.data.market_ingest import ingest_symbols
from backend.domains import trading as trading_domain from backend.tools.data_tools import get_market_cap
from backend.utils.msg_adapter import FrontendAdapter from backend.utils.msg_adapter import FrontendAdapter
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -265,8 +265,7 @@ async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[s
if response is not None: if response is not None:
market_cap = response.get("market_cap") market_cap = response.get("market_cap")
if market_cap is None: if market_cap is None:
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date) market_cap = get_market_cap(ticker=ticker, end_date=date)
market_cap = payload.get("market_cap")
market_caps[ticker] = market_cap if market_cap else 1e9 market_caps[ticker] = market_cap if market_cap else 1e9
except Exception as exc: except Exception as exc:
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc) logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)

View File

@@ -11,10 +11,9 @@ from typing import Any
from backend.data.provider_utils import normalize_symbol from backend.data.provider_utils import normalize_symbol
from backend.domains import news as news_domain from backend.domains import news as news_domain
from backend.domains import trading as trading_domain
from backend.enrich.news_enricher import enrich_news_for_symbol from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.enrich.llm_enricher import llm_enrichment_enabled from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.tools.data_tools import prices_to_df from backend.tools.data_tools import get_insider_trades, get_prices, prices_to_df
from shared.client import NewsServiceClient, TradingServiceClient from shared.client import NewsServiceClient, TradingServiceClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -59,13 +58,12 @@ async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str,
if not prices: if not prices:
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date) prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
if not prices: if not prices:
payload = await asyncio.to_thread( prices = await asyncio.to_thread(
trading_domain.get_prices_payload, get_prices,
ticker=ticker, ticker=ticker,
start_date=start_date, start_date=start_date,
end_date=end_date, end_date=end_date,
) )
prices = payload.get("prices") or []
usage_snapshot = gateway._provider_router.get_usage_snapshot() usage_snapshot = gateway._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices") source = usage_snapshot.get("last_success", {}).get("prices")
if prices: if prices:
@@ -400,14 +398,13 @@ async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: di
trades = response.insider_trades trades = response.insider_trades
if not trades: if not trades:
payload = await asyncio.to_thread( trades = await asyncio.to_thread(
trading_domain.get_insider_trades_payload, get_insider_trades,
ticker=ticker, ticker=ticker,
end_date=end_date, end_date=end_date,
start_date=start_date if start_date else None, start_date=start_date if start_date else None,
limit=limit, limit=limit,
) )
trades = payload.get("insider_trades") or []
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True) sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
formatted_trades = [{ formatted_trades = [{
@@ -540,12 +537,11 @@ async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, da
prices = response.prices prices = response.prices
if prices is None: if prices is None:
payload = trading_domain.get_prices_payload( prices = get_prices(
ticker=ticker, ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"), start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"), end_date=end_date.strftime("%Y-%m-%d"),
) )
prices = payload.get("prices") or []
if not prices or len(prices) < 20: if not prices or len(prices) < 20:
await websocket.send(json.dumps({ await websocket.send(json.dumps({

View File

@@ -0,0 +1,549 @@
# -*- coding: utf-8 -*-
"""Tests for the Gateway main class - core behavior and fallback paths."""
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module
class DummyWebSocket:
def __init__(self):
self.messages = []
self.closed = False
self._queue = []
def queue(self, data: str):
"""Queue a raw message string to be yielded by the async iterator."""
self._queue.append(data)
def __aiter__(self):
return self
async def __anext__(self):
if not self._queue:
raise StopAsyncIteration
return self._queue.pop(0)
async def send(self, payload: str):
self.messages.append(json.loads(payload))
async def close(self):
self.closed = True
class DummyStateSync:
def __init__(self, current_date="2026-03-16"):
self.state = {"current_date": current_date}
self.system_messages = []
self.saved = False
self.initial_state_payload = {}
def set_broadcast_fn(self, _fn):
return None
def update_state(self, key, value):
self.state[key] = value
def save_state(self):
self.saved = True
async def on_system_message(self, message):
self.system_messages.append(message)
def get_initial_state_payload(self, include_dashboard=True):
return {
"status": "running",
"current_date": self.state.get("current_date", ""),
"portfolio": {},
"holdings": [],
"trades": [],
}
class DummyMarketService:
def __init__(self):
self.broadcast_func = None
self.market_status = {"is_open": True, "session": "regular"}
def set_price_recorder(self, _fn):
return None
async def start(self, broadcast_func=None):
self.broadcast_func = broadcast_func
def get_market_status(self):
return self.market_status
class DummyStorage:
def __init__(self, initial_cash=100000.0, live_session=False):
self.initial_cash = initial_cash
self.is_live_session_active = live_session
self._market_store = SimpleNamespace()
@property
def market_store(self):
return self._market_store
def load_file(self, name):
if name == "summary":
return {"totalAssetValue": self.initial_cash}
if name in ("holdings", "trades"):
return []
return None
def get_live_returns(self):
return {"session_pnl": 0.0, "session_return": 0.0}
def make_gateway(market_service=None, storage=None, state_sync=None, config=None):
storage = storage or DummyStorage()
state_sync = state_sync or DummyStateSync()
market_service = market_service or DummyMarketService()
pipeline = SimpleNamespace(state_sync=state_sync, max_comm_cycles=0, pm=SimpleNamespace(portfolio={"margin_requirement": 0.0}))
return Gateway(
market_service=market_service,
storage_service=storage,
pipeline=pipeline,
state_sync=state_sync,
config=config or {"mode": "live"},
)
# =============================================================================
# Gateway initialization and core properties
# =============================================================================
def test_gateway_init_sets_live_mode():
gateway = make_gateway(config={"mode": "live"})
assert gateway.mode == "live"
assert gateway.is_backtest is False
def test_gateway_init_sets_backtest_mode_from_config():
gateway = make_gateway(config={"mode": "backtest"})
assert gateway.mode == "backtest"
assert gateway.is_backtest is True
def test_gateway_init_sets_backtest_mode_from_flag():
gateway = make_gateway(config={"backtest_mode": True, "mode": "live"})
assert gateway.is_backtest is True
def test_gateway_init_defaults_to_live_mode():
gateway = make_gateway(config={})
assert gateway.mode == "live"
assert gateway.is_backtest is False
def test_gateway_state_property_returns_state_sync_state():
state_sync = DummyStateSync()
state_sync.state["foo"] = "bar"
gateway = make_gateway(state_sync=state_sync)
assert gateway.state["foo"] == "bar"
def test_gateway_news_rows_need_enrichment_delegates_to_news_domain():
rows = [{"id": "1"}, {"id": "2"}]
with patch.object(gateway_module.news_domain, "news_rows_need_enrichment", return_value=True) as mock:
result = Gateway._news_rows_need_enrichment(rows)
mock.assert_called_once_with(rows)
assert result is True
# =============================================================================
# Service URL helpers and fallback paths
# =============================================================================
def test_news_service_url_returns_config_value(monkeypatch):
gateway = make_gateway(config={"news_service_url": "http://custom-news:9000"})
assert gateway._news_service_url() == "http://custom-news:9000"
def test_news_service_url_falls_back_to_env(monkeypatch):
monkeypatch.setenv("NEWS_SERVICE_URL", "http://env-news:9001")
gateway = make_gateway(config={})
assert gateway._news_service_url() == "http://env-news:9001"
def test_news_service_url_returns_none_when_unset(monkeypatch):
monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
gateway = make_gateway(config={})
assert gateway._news_service_url() is None
def test_news_service_url_strips_whitespace(monkeypatch):
gateway = make_gateway(config={"news_service_url": " http://whitespace-news:9000 "})
assert gateway._news_service_url() == "http://whitespace-news:9000"
def test_trading_service_url_returns_config_value(monkeypatch):
gateway = make_gateway(config={"trading_service_url": "http://custom-trading:9000"})
assert gateway._trading_service_url() == "http://custom-trading:9000"
def test_trading_service_url_falls_back_to_env(monkeypatch):
monkeypatch.setenv("TRADING_SERVICE_URL", "http://env-trading:9001")
gateway = make_gateway(config={})
assert gateway._trading_service_url() == "http://env-trading:9001"
def test_trading_service_url_returns_none_when_unset(monkeypatch):
monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
gateway = make_gateway(config={})
assert gateway._trading_service_url() is None
def test_trading_service_url_strips_whitespace(monkeypatch):
gateway = make_gateway(config={"trading_service_url": " http://whitespace-trading:9000 "})
assert gateway._trading_service_url() == "http://whitespace-trading:9000"
@pytest.mark.asyncio
async def test_call_news_service_returns_none_when_url_not_set(monkeypatch):
monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
gateway = make_gateway(config={})
async def dummy_callback(client):
return "should not be called"
result = await gateway._call_news_service("test_action", dummy_callback)
assert result is None
@pytest.mark.asyncio
async def test_call_news_service_calls_callback_and_returns():
gateway = make_gateway(config={"news_service_url": "http://news:9000"})
async def callback(client):
return {"result": "ok"}
result = await gateway._call_news_service("test_action", callback)
assert result == {"result": "ok"}
@pytest.mark.asyncio
async def test_call_news_service_returns_none_on_exception():
gateway = make_gateway(config={"news_service_url": "http://news:9000"})
async def failing_callback(client):
raise RuntimeError("connection failed")
result = await gateway._call_news_service("test_action", failing_callback)
assert result is None
@pytest.mark.asyncio
async def test_call_trading_service_returns_none_when_url_not_set(monkeypatch):
monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
gateway = make_gateway(config={})
result = await gateway._call_trading_service("test_action", lambda c: None)
assert result is None
@pytest.mark.asyncio
async def test_call_trading_service_calls_callback_and_returns():
gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})
async def callback(client):
return {"result": "ok"}
result = await gateway._call_trading_service("test_action", callback)
assert result == {"result": "ok"}
@pytest.mark.asyncio
async def test_call_trading_service_returns_none_on_exception():
gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})
async def failing_callback(client):
raise RuntimeError("connection failed")
result = await gateway._call_trading_service("test_action", failing_callback)
assert result is None
# =============================================================================
# WebSocket message handlers
# =============================================================================
@pytest.mark.asyncio
async def test_handle_client_messages_ping_returns_pong():
"""Ping message type results in a pong response."""
gateway = make_gateway()
ws = DummyWebSocket()
ws.queue(json.dumps({"type": "ping"}))
await gateway._handle_client_messages(ws)
assert ws.messages[-1]["type"] == "pong"
assert "timestamp" in ws.messages[-1]
@pytest.mark.asyncio
async def test_handle_client_messages_get_state_sends_initial_state():
"""get_state message type triggers _send_initial_state."""
gateway = make_gateway()
ws = DummyWebSocket()
ws.queue(json.dumps({"type": "get_state"}))
with patch.object(gateway, "_send_initial_state", AsyncMock()) as mock_send:
await gateway._handle_client_messages(ws)
mock_send.assert_called_once_with(ws)
@pytest.mark.asyncio
async def test_handle_client_messages_unknown_type_is_silently_ignored():
"""Unknown message types are silently ignored without error."""
gateway = make_gateway()
ws = DummyWebSocket()
ws.queue(json.dumps({"type": "unknown_type"}))
# Should not raise
await gateway._handle_client_messages(ws)
assert len(ws.messages) == 0
@pytest.mark.asyncio
async def test_handle_client_messages_json_decode_error_is_silently_ignored():
"""Invalid JSON messages are caught by the handler's except block."""
gateway = make_gateway()
ws = DummyWebSocket()
ws.queue("not valid json")
# Should not raise
await gateway._handle_client_messages(ws)
assert len(ws.messages) == 0
# =============================================================================
# Backtest handling
# =============================================================================
@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_not_backtest_mode():
gateway = make_gateway(config={"mode": "live"})
# Should not raise - backtest is ignored in live mode
await gateway._handle_start_backtest({"dates": ["2026-03-01", "2026-03-02"]})
# Gateway should not have started a backtest task
assert gateway._backtest_task is None
@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_task_already_running():
gateway = make_gateway(config={"mode": "backtest"})
# Pre-set a backtest task
dummy_task = MagicMock()
dummy_task.done.return_value = False
gateway._backtest_task = dummy_task
# Should not start a new task
await gateway._handle_start_backtest({"dates": ["2026-03-01"]})
assert gateway._backtest_task is dummy_task # unchanged
# =============================================================================
# Manual trigger (live/mock mode)
# =============================================================================
@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_in_backtest_mode():
gateway = make_gateway(config={"mode": "backtest"})
ws = DummyWebSocket()
await gateway._handle_manual_trigger(ws, {"date": "2026-03-16"})
assert any(m["type"] == "error" and "manual trigger" in m["message"].lower() for m in ws.messages)
@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_when_cycle_already_running():
gateway = make_gateway(config={"mode": "live"})
ws = DummyWebSocket()
# Simulate a running cycle task
dummy_task = MagicMock()
dummy_task.done.return_value = False
gateway._manual_cycle_task = dummy_task
await gateway._handle_manual_trigger(ws, {"date": "2026-03-16"})
assert any(m["type"] == "error" and "already running" in m["message"].lower() for m in ws.messages)
# =============================================================================
# Normalization helpers
# =============================================================================
def test_normalize_watchlist_filters_empty_and_dedupes():
result = Gateway._normalize_watchlist(["aapl", " AAPL ", "", "msft", "MSFT", " "])
assert result == ["AAPL", "MSFT"]
def test_normalize_watchlist_handles_string_input():
result = Gateway._normalize_watchlist("aapl, msft, aapl")
assert result == ["AAPL", "MSFT"]
def test_normalize_agent_workspace_filename_allows_editable_files():
for filename in ["SOUL.md", "PROFILE.md", "AGENTS.md", "MEMORY.md", "POLICY.md"]:
result = Gateway._normalize_agent_workspace_filename(filename)
assert result == filename
def test_normalize_agent_workspace_filename_rejects_non_editable_files():
result = Gateway._normalize_agent_workspace_filename("README.md")
assert result is None
def test_normalize_agent_workspace_filename_rejects_arbitrary_paths():
result = Gateway._normalize_agent_workspace_filename("../etc/passwd")
assert result is None
# =============================================================================
# Broadcast
# =============================================================================
@pytest.mark.asyncio
async def test_broadcast_skips_when_no_clients():
gateway = make_gateway()
gateway.connected_clients = set()
# Should not raise
await gateway.broadcast({"type": "test"})
@pytest.mark.asyncio
async def test_broadcast_sends_to_all_connected_clients():
gateway = make_gateway()
ws1 = DummyWebSocket()
ws2 = DummyWebSocket()
gateway.connected_clients = {ws1, ws2}
await gateway.broadcast({"type": "market_update", "data": "test"})
assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
assert ws1.messages[0]["data"] == "test"
assert ws2.messages[0]["data"] == "test"
@pytest.mark.asyncio
async def test_broadcast_removes_closed_connections():
"""Verify closed connections are removed from connected_clients set.
The broadcast method's _send_to_client helper removes a client
when it raises websockets.ConnectionClosed.
"""
gateway = make_gateway()
closed_ws = DummyWebSocket()
open_ws = DummyWebSocket()
gateway.connected_clients = {closed_ws, open_ws}
# Make closed_ws.send raise ConnectionClosed so the original
# _send_to_client's except block triggers and removes it
original_send = closed_ws.send
async def raising_send(payload):
raise gateway_module.websockets.ConnectionClosed(None, None)
closed_ws.send = raising_send
try:
await gateway.broadcast({"type": "test"})
except gateway_module.websockets.ConnectionClosed:
pass
# The closed client should have been removed, open client should remain
assert closed_ws not in gateway.connected_clients
assert open_ws in gateway.connected_clients
@pytest.mark.asyncio
async def test_broadcast_sends_to_all_connected_clients():
"""Verify broadcast sends to all connected clients and collects results."""
gateway = make_gateway()
ws1 = DummyWebSocket()
ws2 = DummyWebSocket()
gateway.connected_clients = {ws1, ws2}
await gateway.broadcast({"type": "market_update", "data": "test"})
assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
assert ws1.messages[0]["data"] == "test"
assert ws2.messages[0]["data"] == "test"
# =============================================================================
# Stop
# =============================================================================
def test_stop_gateway_calls_cycle_support():
gateway = make_gateway()
with patch.object(gateway_module.gateway_cycle_support, "stop_gateway") as mock:
gateway.stop()
mock.assert_called_once_with(gateway)
# =============================================================================
# set_backtest_dates
# =============================================================================
def test_set_backtest_dates_delegates_to_cycle_support():
gateway = make_gateway()
with patch.object(gateway_module.gateway_cycle_support, "set_backtest_dates") as mock:
gateway.set_backtest_dates(["2026-03-01", "2026-03-02"])
mock.assert_called_once_with(gateway, ["2026-03-01", "2026-03-02"])
# =============================================================================
# Provider usage change callback
# =============================================================================
def test_on_provider_usage_changed_updates_state_sync():
"""_on_provider_usage_changed updates state_sync with the provider snapshot."""
gateway = make_gateway()
gateway._loop = None # no loop set
snapshot = {"provider": "finnhub", "calls": 10}
gateway._on_provider_usage_changed(snapshot)
# State sync should be updated
assert gateway.state_sync.state.get("data_sources") == snapshot
# =============================================================================
# handle_client lifecycle
# =============================================================================
@pytest.mark.asyncio
async def test_handle_client_adds_and_removes_client_from_connected_set():
gateway = make_gateway()
ws = DummyWebSocket()
with patch.object(gateway, "_send_initial_state", AsyncMock()):
with patch.object(gateway, "_handle_client_messages", AsyncMock()):
await gateway.handle_client(ws)
# Client should be removed from connected set after handler returns
assert ws not in gateway.connected_clients
@pytest.mark.asyncio
async def test_handle_client_adds_client_before_handler():
gateway = make_gateway()
ws = DummyWebSocket()
with patch.object(gateway, "_send_initial_state", AsyncMock()):
with patch.object(gateway, "_handle_client_messages", AsyncMock()):
await gateway.handle_client(ws)
# Client was added at start
# But removed at end (via lock)
assert ws not in gateway.connected_clients

View File

@@ -1,14 +1,31 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Tests for the extracted runtime service app surface.""" """Tests for the extracted runtime service app surface."""
import asyncio
import json import json
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from backend.api import runtime as runtime_module from backend.api import runtime as runtime_module
from backend.apps.runtime_service import create_app from backend.apps.runtime_service import create_app
@pytest.fixture(autouse=True)
def reset_runtime_module_state():
"""Reset module-level runtime_manager before each test."""
runtime_module.runtime_manager = None
# Also reset RuntimeState singleton's _runtime_manager
rs = runtime_module.get_runtime_state()
rs._runtime_manager = None
yield
runtime_module.runtime_manager = None
rs = runtime_module.get_runtime_state()
rs._runtime_manager = None
def test_runtime_service_routes_are_exposed(): def test_runtime_service_routes_are_exposed():
app = create_app() app = create_app()
paths = {route.path for route in app.routes} paths = {route.path for route in app.routes}
@@ -153,7 +170,9 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
) )
class _DummyContext: class _DummyContext:
def __init__(self): def __init__(self, run_dir):
self.config_name = "demo"
self.run_dir = run_dir
self.bootstrap_values = { self.bootstrap_values = {
"tickers": ["AAPL"], "tickers": ["AAPL"],
"schedule_mode": "daily", "schedule_mode": "daily",
@@ -165,8 +184,17 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
class _DummyManager: class _DummyManager:
def __init__(self): def __init__(self):
self.config_name = "demo" self.config_name = "demo"
self.bootstrap = dict(_DummyContext().bootstrap_values) self.bootstrap = dict(_DummyContext(run_dir).bootstrap_values)
self.context = _DummyContext() self.context = _DummyContext(run_dir)
def build_snapshot(self):
return {
"context": {
"config_name": self.context.config_name,
"run_dir": str(self.context.run_dir),
"bootstrap_values": self.context.bootstrap_values,
}
}
def _persist_snapshot(self): def _persist_snapshot(self):
return None return None
@@ -192,3 +220,385 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
assert payload["bootstrap"]["schedule_mode"] == "intraday" assert payload["bootstrap"]["schedule_mode"] == "intraday"
assert payload["resolved"]["interval_minutes"] == 15 assert payload["resolved"]["interval_minutes"] == 15
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8") assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")
# =============================================================================
# RuntimeState singleton unit tests
# =============================================================================
def test_runtime_state_is_singleton():
"""RuntimeState.__new__ returns the same instance across calls."""
state1 = runtime_module.RuntimeState()
state2 = runtime_module.RuntimeState()
assert state1 is state2
def test_runtime_state_get_runtime_state_returns_same_instance():
"""get_runtime_state() returns the module singleton."""
instance = runtime_module.get_runtime_state()
assert instance is runtime_module._runtime_state
def test_runtime_state_default_values():
"""RuntimeState initializes with sensible defaults on first instantiation."""
# Reset singleton to get fresh __init__ values
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
assert state._runtime_manager is None
assert state._gateway_process is None
assert state._gateway_port == 8765
def test_runtime_state_gateway_port_property():
"""gateway_port property getter and setter work correctly."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
state.gateway_port = 9999
assert state.gateway_port == 9999
state.gateway_port = 1234
assert state.gateway_port == 1234
def test_runtime_state_gateway_process_property():
"""gateway_process property getter and setter work correctly."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
assert state.gateway_process is None
fake_process = object()
state.gateway_process = fake_process
assert state.gateway_process is fake_process
state.gateway_process = None
assert state.gateway_process is None
def test_runtime_state_runtime_manager_property():
"""runtime_manager property getter and setter work correctly."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
assert state.runtime_manager is None
fake_manager = object()
state.runtime_manager = fake_manager
assert state.runtime_manager is fake_manager
state.runtime_manager = None
assert state.runtime_manager is None
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_runtime_state_lock_property_is_async():
"""lock is an async property that returns a coroutine producing an asyncio.Lock."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
lock_coro = state.lock
assert asyncio.iscoroutine(lock_coro)
@pytest.mark.asyncio
async def test_runtime_state_async_set_get_gateway_port():
"""Async setters and getters for gateway_port with lock protection."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
await state.set_gateway_port(8888)
assert await state.get_gateway_port() == 8888
await state.set_gateway_port(7777)
assert await state.get_gateway_port() == 7777
@pytest.mark.asyncio
async def test_runtime_state_async_set_get_gateway_process():
"""Async setters and getters for gateway_process with lock protection."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
await state.set_gateway_process(None)
assert await state.get_gateway_process() is None
fake_process = object()
await state.set_gateway_process(fake_process)
assert await state.get_gateway_process() is fake_process
@pytest.mark.asyncio
async def test_runtime_state_async_set_get_runtime_manager():
"""Async setters and getters for runtime_manager with lock protection."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
await state.set_runtime_manager(None)
assert await state.get_runtime_manager() is None
fake_manager = object()
await state.set_runtime_manager(fake_manager)
assert await state.get_runtime_manager() is fake_manager
# =============================================================================
# _is_gateway_running helper tests
# =============================================================================
def test_is_gateway_running_returns_false_when_process_is_none():
"""_is_gateway_running returns False when gateway_process is None."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
new_state = runtime_module.RuntimeState()
new_state._gateway_process = None
runtime_module._runtime_state = new_state
assert runtime_module._is_gateway_running() is False
def test_is_gateway_running_returns_false_when_process_exited():
"""_is_gateway_running returns False when process has terminated."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
runtime_module._runtime_state = state
mock_process = MagicMock()
mock_process.poll.return_value = 1 # non-None = process has exited
state._gateway_process = mock_process
assert runtime_module._is_gateway_running() is False
def test_is_gateway_running_returns_true_when_process_running():
"""_is_gateway_running returns True when process is alive."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
runtime_module._runtime_state = state
mock_process = MagicMock()
mock_process.poll.return_value = None # None = still running
state._gateway_process = mock_process
assert runtime_module._is_gateway_running() is True
# =============================================================================
# _stop_gateway helper tests
# =============================================================================
def test_stop_gateway_returns_false_when_no_process():
"""_stop_gateway returns False if no gateway process exists."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
new_state = runtime_module.RuntimeState()
new_state._gateway_process = None
runtime_module._runtime_state = new_state
result = runtime_module._stop_gateway()
assert result is False
def test_stop_gateway_sets_process_to_none_after_stop():
"""_stop_gateway sets _gateway_process to None after stopping."""
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
runtime_module._runtime_state = state
mock_process = MagicMock()
mock_process.poll.return_value = None
mock_process.wait.return_value = 0
state._gateway_process = mock_process
result = runtime_module._stop_gateway()
assert result is True
assert state._gateway_process is None
mock_process.terminate.assert_called_once()
mock_process.wait.assert_called_once()
def test_stop_gateway_kills_when_terminate_times_out():
"""_stop_gateway kills the process if terminate times out."""
import subprocess
runtime_module.RuntimeState._instance = None
runtime_module.RuntimeState._lock = asyncio.Lock()
state = runtime_module.RuntimeState()
runtime_module._runtime_state = state
mock_process = MagicMock()
mock_process.poll.return_value = None
mock_process.wait.side_effect = subprocess.TimeoutExpired("cmd", 5)
mock_process.kill.return_value = None
state._gateway_process = mock_process
result = runtime_module._stop_gateway()
assert result is True
assert state._gateway_process is None
mock_process.kill.assert_called_once()
# =============================================================================
# _build_gateway_ws_url helper tests
# =============================================================================
def test_build_gateway_ws_url_defaults_to_ws():
from fastapi import Request
mock_request = MagicMock(spec=Request)
mock_request.headers.get.side_effect = lambda k, d="": d
mock_request.url.scheme = "http"
mock_request.url.hostname = "localhost"
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
assert url == "ws://localhost:8765"
def test_build_gateway_ws_url_uses_wss_for_https():
from fastapi import Request
mock_request = MagicMock(spec=Request)
mock_request.headers.get.side_effect = lambda k, d="": d
mock_request.url.scheme = "https"
mock_request.url.hostname = "example.com"
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
assert url == "wss://example.com:8765"
def test_build_gateway_ws_url_respects_forwarded_proto():
from fastapi import Request
mock_request = MagicMock(spec=Request)
def header_get(key, default=""):
if key == "x-forwarded-proto":
return "https"
return default
mock_request.headers.get.side_effect = header_get
mock_request.url.scheme = "http"
mock_request.url.hostname = "internal.example"
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
assert url == "wss://internal.example:8765"
def test_build_gateway_ws_url_respects_forwarded_host():
from fastapi import Request
mock_request = MagicMock(spec=Request)
mock_request.headers.get.side_effect = lambda k, d="": {
"x-forwarded-host": "external.example.com"
}.get(k, d)
mock_request.url.scheme = "http"
mock_request.url.hostname = "internal.example"
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
assert url == "ws://external.example.com:8765"
# =============================================================================
# _normalize_runtime_config_updates tests
# =============================================================================
def test_normalize_runtime_config_updates_validates_schedule_mode():
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode="invalid")
with pytest.raises(HTTPException) as exc_info:
runtime_module._normalize_runtime_config_updates(req)
assert "schedule_mode" in str(exc_info.value.detail).lower()
def test_normalize_runtime_config_updates_validates_schedule_mode_values():
for invalid in ["weekly", "monthly", "once"]:
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode=invalid)
with pytest.raises(HTTPException):
runtime_module._normalize_runtime_config_updates(req)
def test_normalize_runtime_config_updates_accepts_daily_and_intraday():
for valid in ["daily", "intraday", "DAILY", "IntraDay"]:
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode=valid)
result = runtime_module._normalize_runtime_config_updates(req)
assert "schedule_mode" in result
def test_normalize_runtime_config_updates_validates_trigger_time_format():
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time="25:99")
with pytest.raises(HTTPException) as exc_info:
runtime_module._normalize_runtime_config_updates(req)
assert "trigger_time" in str(exc_info.value.detail).lower()
def test_normalize_runtime_config_updates_accepts_now_trigger_time():
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time="now")
result = runtime_module._normalize_runtime_config_updates(req)
assert result["trigger_time"] == "now"
def test_normalize_runtime_config_updates_defaults_empty_trigger_time():
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time=" ")
result = runtime_module._normalize_runtime_config_updates(req)
assert result["trigger_time"] == "09:30"
def test_normalize_runtime_config_updates_rejects_no_updates():
req = runtime_module.UpdateRuntimeConfigRequest()
with pytest.raises(HTTPException) as exc_info:
runtime_module._normalize_runtime_config_updates(req)
assert "no runtime config updates" in str(exc_info.value.detail).lower()
def test_normalize_runtime_config_updates_coerces_types():
req = runtime_module.UpdateRuntimeConfigRequest(
schedule_mode="intraday",
interval_minutes="30", # string from JSON
initial_cash="50000.0", # string from JSON
margin_requirement="0.25",
)
result = runtime_module._normalize_runtime_config_updates(req)
assert result["schedule_mode"] == "intraday"
assert result["interval_minutes"] == 30
assert result["initial_cash"] == 50000.0
assert result["margin_requirement"] == 0.25
# =============================================================================
# register_runtime_manager / unregister_runtime_manager tests
# =============================================================================
def test_register_runtime_manager_sets_module_and_singleton():
runtime_module._runtime_state._initialized = True # prevent re-init
fake_manager = object()
runtime_module.register_runtime_manager(fake_manager)
assert runtime_module.runtime_manager is fake_manager
assert runtime_module._runtime_state.runtime_manager is fake_manager
def test_unregister_runtime_manager_clears_module_and_singleton():
runtime_module._runtime_state._initialized = True # prevent re-init
runtime_module._runtime_state.runtime_manager = object()
runtime_module.runtime_manager = object()
runtime_module.unregister_runtime_manager()
assert runtime_module.runtime_manager is None
assert runtime_module._runtime_state.runtime_manager is None
# =============================================================================
# _generate_run_id tests
# =============================================================================
def test_generate_run_id_returns_timestamp_format():
run_id = runtime_module._generate_run_id()
# Format: YYYYMMDD_HHMMSS - length is 15
assert len(run_id) == 15
assert run_id[8] == "_" # separator between date and time
assert run_id[:8].isdigit() # YYYYMMDD
assert run_id[9:].isdigit() # HHMMSS

View File

@@ -1,47 +1,21 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""Unit tests for the trading domain helpers.""" """Unit tests for data_tools functions (replaces the deleted trading_domain)."""
from backend.domains import trading as trading_domain from backend.tools.data_tools import (
get_company_news,
get_financial_metrics,
get_insider_trades,
get_market_cap,
get_prices,
search_line_items,
)
def test_trading_domain_payload_wrappers(monkeypatch): def test_data_tools_functions_exist():
monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}]) """Verify that all data_tools functions are importable and callable."""
monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}]) assert callable(get_prices)
monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}]) assert callable(get_financial_metrics)
monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}]) assert callable(get_company_news)
monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12) assert callable(get_insider_trades)
assert callable(get_market_cap)
assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == { assert callable(search_line_items)
"ticker": "AAPL",
"prices": [{"close": 1}],
}
assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
"financial_metrics": [{"ticker": "AAPL"}],
}
assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
"news": [{"ticker": "AAPL"}],
}
assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
"insider_trades": [{"ticker": "AAPL"}],
}
assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
"ticker": "AAPL",
"end_date": "2026-03-16",
"market_cap": 2.5e12,
}
def test_get_market_status_payload_uses_market_service(monkeypatch):
class _FakeMarketService:
def __init__(self, tickers):
self.tickers = tickers
def get_market_status(self):
return {"status": "open", "status_text": "Open"}
monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)
assert trading_domain.get_market_status_payload() == {
"status": "open",
"status_text": "Open",
}

View File

@@ -24,10 +24,8 @@ def test_trading_service_routes_are_exposed():
def test_trading_service_prices_endpoint(monkeypatch): def test_trading_service_prices_endpoint(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"backend.domains.trading.get_prices_payload", "backend.apps.trading_service.get_prices",
lambda ticker, start_date, end_date: { lambda ticker, start_date, end_date: [
"ticker": ticker,
"prices": [
Price( Price(
open=1.0, open=1.0,
close=2.0, close=2.0,
@@ -37,7 +35,6 @@ def test_trading_service_prices_endpoint(monkeypatch):
time="2026-03-20", time="2026-03-20",
) )
], ],
},
) )
with TestClient(create_app()) as client: with TestClient(create_app()) as client:
@@ -57,9 +54,8 @@ def test_trading_service_prices_endpoint(monkeypatch):
def test_trading_service_financials_endpoint(monkeypatch): def test_trading_service_financials_endpoint(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"backend.domains.trading.get_financials_payload", "backend.apps.trading_service.get_financial_metrics",
lambda ticker, end_date, period, limit: { lambda ticker, end_date, period, limit: [
"financial_metrics": [
FinancialMetrics( FinancialMetrics(
ticker=ticker, ticker=ticker,
report_period=end_date, report_period=end_date,
@@ -105,8 +101,7 @@ def test_trading_service_financials_endpoint(monkeypatch):
book_value_per_share=None, book_value_per_share=None,
free_cash_flow_per_share=None, free_cash_flow_per_share=None,
) )
] ],
},
) )
with TestClient(create_app()) as client: with TestClient(create_app()) as client:
@@ -121,9 +116,8 @@ def test_trading_service_financials_endpoint(monkeypatch):
def test_trading_service_news_and_insider_endpoints(monkeypatch): def test_trading_service_news_and_insider_endpoints(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"backend.domains.trading.get_news_payload", "backend.apps.trading_service.get_company_news",
lambda ticker, end_date, start_date=None, limit=1000: { lambda ticker, end_date, start_date=None, limit=1000: [
"news": [
CompanyNews( CompanyNews(
ticker=ticker, ticker=ticker,
title="News title", title="News title",
@@ -131,16 +125,13 @@ def test_trading_service_news_and_insider_endpoints(monkeypatch):
url="https://example.com/news", url="https://example.com/news",
date=end_date, date=end_date,
) )
] ],
},
) )
monkeypatch.setattr( monkeypatch.setattr(
"backend.domains.trading.get_insider_trades_payload", "backend.apps.trading_service.get_insider_trades",
lambda ticker, end_date, start_date=None, limit=1000: { lambda ticker, end_date, start_date=None, limit=1000: [
"insider_trades": [
InsiderTrade(ticker=ticker, filing_date=end_date) InsiderTrade(ticker=ticker, filing_date=end_date)
] ],
},
) )
with TestClient(create_app()) as client: with TestClient(create_app()) as client:
@@ -165,8 +156,8 @@ def test_trading_service_market_status_endpoint(monkeypatch):
return {"status": "open", "status_text": "Open"} return {"status": "open", "status_text": "Open"}
monkeypatch.setattr( monkeypatch.setattr(
"backend.domains.trading.get_market_status_payload", "backend.apps.trading_service.MarketService",
lambda: _FakeMarketService().get_market_status(), lambda tickers: _FakeMarketService(),
) )
with TestClient(create_app()) as client: with TestClient(create_app()) as client:
@@ -178,12 +169,8 @@ def test_trading_service_market_status_endpoint(monkeypatch):
def test_trading_service_market_cap_endpoint(monkeypatch): def test_trading_service_market_cap_endpoint(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"backend.domains.trading.get_market_cap_payload", "backend.apps.trading_service.get_market_cap",
lambda ticker, end_date: { lambda ticker, end_date: 3.5e12,
"ticker": ticker,
"end_date": end_date,
"market_cap": 3.5e12,
},
) )
with TestClient(create_app()) as client: with TestClient(create_app()) as client:
@@ -202,9 +189,8 @@ def test_trading_service_market_cap_endpoint(monkeypatch):
def test_trading_service_line_items_endpoint(monkeypatch): def test_trading_service_line_items_endpoint(monkeypatch):
monkeypatch.setattr( monkeypatch.setattr(
"backend.domains.trading.get_line_items_payload", "backend.apps.trading_service.search_line_items",
lambda ticker, line_items, end_date, period, limit: { lambda ticker, line_items, end_date, period, limit: [
"search_results": [
LineItem( LineItem(
ticker=ticker, ticker=ticker,
report_period=end_date, report_period=end_date,
@@ -212,8 +198,7 @@ def test_trading_service_line_items_endpoint(monkeypatch):
currency="USD", currency="USD",
free_cash_flow=123.0, free_cash_flow=123.0,
) )
] ],
},
) )
with TestClient(create_app()) as client: with TestClient(create_app()) as client:

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,18 @@
import React from 'react';
export default function ChartTabs({
chartTab,
setChartTab,
isLiveEnabled
}) {
return (
<div className="chart-tabs-floating">
<button
className={`chart-tab ${chartTab === 'all' ? 'active' : ''}`}
onClick={() => setChartTab('all')}
>
日线
</button>
</div>
);
}

View File

@@ -0,0 +1,293 @@
import React from 'react';
import RuntimeSettingsPanel from './RuntimeSettingsPanel.jsx';
export default function HeaderRight({
// Connection state
isConnected,
// Virtual time
virtualTime,
now,
// Market & server
marketStatus,
marketStatusLabel,
serverMode,
// Labels
runtimeSummaryLabel,
livePriceSourceLabel,
historicalPriceSourceLabel,
// Settings state
isRuntimeSettingsOpen,
isRuntimeConfigSaving,
isWatchlistSaving,
runtimeConfigFeedback,
watchlistFeedback,
// Settings panel props
scheduleModeDraft,
intervalMinutesDraft,
triggerTimeDraft,
maxCommCyclesDraft,
initialCashDraft,
marginRequirementDraft,
enableMemoryDraft,
modeDraft,
pollIntervalDraft,
startDateDraft,
endDateDraft,
enableMockDraft,
watchlistDraftSymbols,
watchlistInputValue,
watchlistSuggestions,
// Callbacks
onRuntimeSettingsToggle,
onCloseSettings,
onScheduleModeChange,
onIntervalMinutesChange,
onTriggerTimeChange,
onMaxCommCyclesChange,
onInitialCashChange,
onMarginRequirementChange,
onEnableMemoryChange,
onModeChange,
onPollIntervalChange,
onStartDateChange,
onEndDateChange,
onEnableMockChange,
onWatchlistInputChange,
onWatchlistInputKeyDown,
onWatchlistAdd,
onWatchlistRemove,
onWatchlistRestoreCurrent,
onWatchlistRestoreDefault,
onWatchlistSuggestionClick,
onLaunchConfigSave,
onRestoreDefaults,
onManualTrigger,
clientRef
}) {
return (
<div className="header-right" style={{ display: 'flex', alignItems: 'center', gap: 24, marginLeft: 'auto', flexWrap: 'wrap', minWidth: 0 }}>
{/* Mock Mode Indicator */}
{virtualTime && (
<div style={{
display: 'flex',
alignItems: 'center',
gap: 6,
padding: '4px 10px',
borderRadius: 4,
background: '#FF9800',
border: '1px solid #FFB74D'
}}>
<span style={{
fontSize: '11px',
fontWeight: 600,
color: '#FFFFFF',
fontFamily: '"Courier New", monospace',
letterSpacing: '0.5px'
}}>
模拟模式
</span>
</div>
)}
{/* Clock Display (only in Mock mode) */}
{virtualTime && (
<div style={{
display: 'flex',
alignItems: 'center',
gap: 8
}}>
<div style={{
display: 'flex',
flexDirection: 'column',
alignItems: 'flex-end',
gap: 2,
padding: '4px 12px',
borderRadius: 4,
background: '#1A237E',
border: '1px solid #3F51B5'
}}>
<span style={{
fontSize: '11px',
color: '#999',
fontFamily: '"Courier New", monospace',
textTransform: 'uppercase',
letterSpacing: '0.5px'
}}>
虚拟时间
</span>
<span style={{
fontSize: '14px',
fontWeight: 700,
color: '#FFFFFF',
fontFamily: '"Courier New", monospace',
letterSpacing: '1px'
}}>
{now.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false })}
</span>
<span style={{
fontSize: '10px',
color: '#999',
fontFamily: '"Courier New", monospace'
}}>
{now.toLocaleDateString('en-US', { month: 'short', day: 'numeric', year: 'numeric' })}
</span>
</div>
{/* Fast Forward Button (only in Mock mode) */}
<button
onClick={() => {
if (clientRef.current) {
const success = clientRef.current.send({
type: 'fast_forward_time',
minutes: 30
});
if (!success) {
console.error('Failed to send fast forward request');
}
}
}}
style={{
padding: '6px 12px',
borderRadius: 4,
background: '#3F51B5',
border: '1px solid #5C6BC0',
color: '#FFFFFF',
fontSize: '12px',
fontFamily: '"Courier New", monospace',
fontWeight: 600,
cursor: 'pointer',
transition: 'all 0.2s',
display: 'flex',
alignItems: 'center',
gap: 4,
textTransform: 'uppercase',
letterSpacing: '0.5px'
}}
onMouseEnter={(e) => {
e.target.style.background = '#5C6BC0';
e.target.style.borderColor = '#7986CB';
}}
onMouseLeave={(e) => {
e.target.style.background = '#3F51B5';
e.target.style.borderColor = '#5C6BC0';
}}
title="快进30分钟 (Mock模式)"
>
+30min
</button>
</div>
)}
{/* Unified Status Indicator */}
<div className="header-status-inline">
<span className={`status-dot ${isConnected ? 'live' : 'offline'}`} />
<span className={`status-text ${isConnected ? 'live' : 'offline'}`}>
{isConnected ? '在线' : '离线'}
</span>
{marketStatus && (
<>
<span className="status-sep">·</span>
<span className={`market-text ${serverMode === 'backtest' ? 'backtest' : (marketStatus.status === 'open' ? 'open' : 'closed')}`}>
{marketStatusLabel}
</span>
</>
)}
{livePriceSourceLabel && (
<>
<span className="status-sep">·</span>
<span className="market-text backtest">
{livePriceSourceLabel}
</span>
</>
)}
{historicalPriceSourceLabel && (
<>
<span className="status-sep">·</span>
<span className="market-text backtest">
{historicalPriceSourceLabel}
</span>
</>
)}
{runtimeSummaryLabel && (
<>
<span className="status-sep">·</span>
<span className="market-text backtest" title="当前运行配置">
{runtimeSummaryLabel}
</span>
</>
)}
<span className="status-sep">·</span>
<span className="time-text">{now.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false })}</span>
</div>
{serverMode !== 'backtest' && (
<button
onClick={onManualTrigger}
disabled={!isConnected}
style={{
padding: '6px 12px',
borderRadius: 4,
background: isConnected ? '#111111' : '#8a8a8a',
border: '1px solid #111111',
color: '#FFFFFF',
fontSize: '11px',
fontFamily: '"Courier New", monospace',
fontWeight: 700,
cursor: isConnected ? 'pointer' : 'not-allowed',
letterSpacing: '0.4px',
textTransform: 'uppercase'
}}
title="手动触发一轮分析与交易决策"
>
手动运行
</button>
)}
<RuntimeSettingsPanel
showTrigger={false}
isOpen={isRuntimeSettingsOpen}
isConnected={isConnected}
isSaving={isRuntimeConfigSaving || isWatchlistSaving}
feedback={runtimeConfigFeedback || watchlistFeedback}
scheduleMode={scheduleModeDraft}
intervalMinutes={intervalMinutesDraft}
triggerTime={triggerTimeDraft}
maxCommCycles={maxCommCyclesDraft}
initialCash={initialCashDraft}
marginRequirement={marginRequirementDraft}
enableMemory={enableMemoryDraft}
mode={modeDraft}
pollInterval={pollIntervalDraft}
startDate={startDateDraft}
endDate={endDateDraft}
enableMock={enableMockDraft}
watchlistSymbols={watchlistDraftSymbols}
watchlistInputValue={watchlistInputValue}
watchlistSuggestions={watchlistSuggestions}
onToggle={onRuntimeSettingsToggle}
onClose={onCloseSettings}
onScheduleModeChange={onScheduleModeChange}
onIntervalMinutesChange={onIntervalMinutesChange}
onTriggerTimeChange={onTriggerTimeChange}
onMaxCommCyclesChange={onMaxCommCyclesChange}
onInitialCashChange={onInitialCashChange}
onMarginRequirementChange={onMarginRequirementChange}
onEnableMemoryChange={onEnableMemoryChange}
onModeChange={onModeChange}
onPollIntervalChange={onPollIntervalChange}
onStartDateChange={onStartDateChange}
onEndDateChange={onEndDateChange}
onEnableMockChange={onEnableMockChange}
onWatchlistInputChange={onWatchlistInputChange}
onWatchlistInputKeyDown={onWatchlistInputKeyDown}
onWatchlistAdd={onWatchlistAdd}
onWatchlistRemove={onWatchlistRemove}
onWatchlistRestoreCurrent={onWatchlistRestoreCurrent}
onWatchlistRestoreDefault={onWatchlistRestoreDefault}
onWatchlistSuggestionClick={onWatchlistSuggestionClick}
onSave={onLaunchConfigSave}
onRestoreDefaults={onRestoreDefaults}
/>
</div>
);
}

View File

@@ -0,0 +1,52 @@
import React from 'react';
import StockLogo from './StockLogo';
import { formatNumber, formatTickerPrice } from '../utils/formatters';
export default function TickerBar({
displayTickers,
rollingTickers,
portfolioData,
onTickerSelect
}) {
return (
<div className="ticker-bar">
<div className="ticker-track">
{[0, 1].map((groupIdx) => (
<div key={groupIdx} className="ticker-group">
{displayTickers.map(ticker => (
<div
key={`${ticker.symbol}-${groupIdx}`}
className="ticker-item"
onClick={() => onTickerSelect && onTickerSelect(ticker.symbol)}
style={{ cursor: onTickerSelect ? 'pointer' : 'default' }}
>
<StockLogo ticker={ticker.symbol} size={16} />
<span className="ticker-symbol">{ticker.symbol}</span>
<span className="ticker-price">
<span className={`ticker-price-value ${rollingTickers[ticker.symbol] ? 'rolling' : ''}`}>
{ticker.price !== null && ticker.price !== undefined
? `$${formatTickerPrice(ticker.price)}`
: '-'}
</span>
</span>
<span className={`ticker-change ${
ticker.change === null || ticker.change === undefined
? ''
: ticker.change >= 0 ? 'positive' : 'negative'
}`}>
{ticker.change !== null && ticker.change !== undefined
? `${ticker.change >= 0 ? '+' : ''}${ticker.change.toFixed(2)}%`
: '-'}
</span>
</div>
))}
</div>
))}
</div>
<div className="portfolio-value">
<span className="portfolio-label">投资组合</span>
<span className="portfolio-amount">${formatNumber(portfolioData.netValue)}</span>
</div>
</div>
);
}

View File

@@ -0,0 +1,308 @@
import { useCallback, useEffect } from 'react';
import { uploadAgentSkillZip } from '../services/runtimeApi';
/**
* Extracts agent/skill-related callbacks from App.jsx into a single hook.
*/
export function useAgentCallbacks({
clientRef,
selectedSkillAgentId,
selectedWorkspaceFile,
workspaceDraftContent,
agentProfilesByAgent,
agentSkillsByAgent,
workspaceFilesByAgent,
AGENTS,
setters
}) {
const {
setIsAgentSkillsLoading,
setAgentSkillsFeedback,
setSkillDetailLoadingKey,
setAgentSkillsSavingKey,
setIsWorkspaceFileLoading,
setWorkspaceFileSavingKey,
setWorkspaceFileFeedback,
setLocalSkillDraftsByKey,
setAgentSkillsByAgent,
setAgentProfilesByAgent,
setSkillDetailsByName,
setWorkspaceFilesByAgent,
setSelectedSkillAgentId,
setSelectedWorkspaceFile,
setWorkspaceDraftContent
} = setters;
const requestAgentSkills = useCallback((agentId) => {
const normalized = typeof agentId === 'string' ? agentId.trim() : '';
if (!normalized || !clientRef.current) {
return false;
}
setIsAgentSkillsLoading(true);
setAgentSkillsFeedback(null);
return clientRef.current.send({
type: 'get_agent_skills',
agent_id: normalized
});
}, [clientRef, setIsAgentSkillsLoading, setAgentSkillsFeedback]);
const requestAgentProfile = useCallback((agentId) => {
const normalized = typeof agentId === 'string' ? agentId.trim() : '';
if (!normalized || !clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_agent_profile',
agent_id: normalized
});
}, [clientRef]);
const requestSkillDetail = useCallback((skillName) => {
const normalized = typeof skillName === 'string' ? skillName.trim() : '';
if (!normalized || !clientRef.current) {
return false;
}
const detailKey = `${selectedSkillAgentId}:${normalized}`;
setSkillDetailLoadingKey(detailKey);
return clientRef.current.send({
type: 'get_skill_detail',
agent_id: selectedSkillAgentId,
skill_name: normalized
});
}, [clientRef, selectedSkillAgentId, setSkillDetailLoadingKey]);
const requestWorkspaceFile = useCallback((agentId, filename) => {
const normalizedAgentId = typeof agentId === 'string' ? agentId.trim() : '';
const normalizedFilename = typeof filename === 'string' ? filename.trim() : '';
if (!normalizedAgentId || !normalizedFilename || !clientRef.current) {
return false;
}
setIsWorkspaceFileLoading(true);
setWorkspaceFileFeedback(null);
return clientRef.current.send({
type: 'get_agent_workspace_file',
agent_id: normalizedAgentId,
filename: normalizedFilename
});
}, [clientRef, setIsWorkspaceFileLoading, setWorkspaceFileFeedback]);
const handleCreateLocalSkill = useCallback((skillName) => {
const normalized = typeof skillName === 'string' ? skillName.trim() : '';
if (!normalized) {
setAgentSkillsFeedback({ type: 'error', text: '技能名称不能为空' });
return;
}
if (!clientRef.current) {
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
return;
}
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${normalized}:create`);
setAgentSkillsFeedback(null);
const success = clientRef.current.send({
type: 'create_agent_local_skill',
agent_id: selectedSkillAgentId,
skill_name: normalized
});
if (!success) {
setAgentSkillsSavingKey(null);
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
}
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
const handleLocalSkillDraftChange = useCallback((skillName, content) => {
const detailKey = `${selectedSkillAgentId}:${skillName}`;
setLocalSkillDraftsByKey((prev) => ({
...prev,
[detailKey]: content
}));
}, [selectedSkillAgentId, setLocalSkillDraftsByKey]);
const handleLocalSkillSave = useCallback((skillName) => {
if (!clientRef.current) {
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
return;
}
const detailKey = `${selectedSkillAgentId}:${skillName}`;
const content = setters.localSkillDraftsByKey[detailKey];
if (typeof content !== 'string') {
return;
}
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:content`);
setAgentSkillsFeedback(null);
const success = clientRef.current.send({
type: 'update_agent_local_skill',
agent_id: selectedSkillAgentId,
skill_name: skillName,
content
});
if (!success) {
setAgentSkillsSavingKey(null);
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
}
}, [clientRef, selectedSkillAgentId, setters.localSkillDraftsByKey, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
const handleLocalSkillDelete = useCallback((skillName) => {
if (!clientRef.current) {
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
return;
}
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:delete`);
setAgentSkillsFeedback(null);
const success = clientRef.current.send({
type: 'delete_agent_local_skill',
agent_id: selectedSkillAgentId,
skill_name: skillName
});
if (!success) {
setAgentSkillsSavingKey(null);
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
}
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
const handleRemoveSharedSkill = useCallback((skillName) => {
if (!clientRef.current) {
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
return;
}
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:remove`);
setAgentSkillsFeedback(null);
const success = clientRef.current.send({
type: 'remove_agent_skill',
agent_id: selectedSkillAgentId,
skill_name: skillName
});
if (!success) {
setAgentSkillsSavingKey(null);
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
}
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
const handleAgentSkillToggle = useCallback((skillName, enabled) => {
if (!clientRef.current) {
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
return;
}
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}`);
setAgentSkillsFeedback(null);
const success = clientRef.current.send({
type: 'update_agent_skill',
agent_id: selectedSkillAgentId,
skill_name: skillName,
enabled
});
if (!success) {
setAgentSkillsSavingKey(null);
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
}
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
const handleSkillAgentChange = useCallback((agentId) => {
setSelectedSkillAgentId(agentId);
requestAgentProfile(agentId);
requestAgentSkills(agentId);
requestWorkspaceFile(agentId, selectedWorkspaceFile);
}, [requestAgentProfile, requestAgentSkills, requestWorkspaceFile, selectedWorkspaceFile, setSelectedSkillAgentId]);
const handleWorkspaceFileChange = useCallback((filename) => {
setSelectedWorkspaceFile(filename);
requestWorkspaceFile(selectedSkillAgentId, filename);
}, [requestWorkspaceFile, selectedSkillAgentId, setSelectedWorkspaceFile]);
const handleWorkspaceFileSave = useCallback(() => {
if (!clientRef.current) {
setWorkspaceFileFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
return;
}
const key = `${selectedSkillAgentId}:${selectedWorkspaceFile}`;
setWorkspaceFileSavingKey(key);
setWorkspaceFileFeedback(null);
const success = clientRef.current.send({
type: 'update_agent_workspace_file',
agent_id: selectedSkillAgentId,
filename: selectedWorkspaceFile,
content: workspaceDraftContent
});
if (!success) {
setWorkspaceFileSavingKey(null);
setWorkspaceFileFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
}
}, [clientRef, selectedSkillAgentId, selectedWorkspaceFile, workspaceDraftContent, setWorkspaceFileSavingKey, setWorkspaceFileFeedback]);
const handleUploadExternalSkill = useCallback(async (file) => {
if (!(file instanceof File)) {
setAgentSkillsFeedback({ type: 'error', text: '请选择 zip 文件后再上传' });
return;
}
if (!selectedSkillAgentId) {
setAgentSkillsFeedback({ type: 'error', text: '未选择目标 Agent' });
return;
}
setAgentSkillsSavingKey(`${selectedSkillAgentId}:__upload__`);
setAgentSkillsFeedback(null);
try {
const result = await uploadAgentSkillZip({
agentId: selectedSkillAgentId,
file,
activate: true
});
setAgentSkillsFeedback({
type: 'success',
text: `已上传并安装技能 ${result.skill_name || ''}`.trim()
});
requestAgentSkills(selectedSkillAgentId);
} catch (error) {
setAgentSkillsFeedback({
type: 'error',
text: `上传失败: ${error.message || '未知错误'}`
});
} finally {
setAgentSkillsSavingKey(null);
}
}, [selectedSkillAgentId, requestAgentSkills, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
// Sync workspace draft content when selected content changes
useEffect(() => {
const selectedWorkspaceContent = workspaceFilesByAgent[selectedSkillAgentId]?.[selectedWorkspaceFile] || '';
setWorkspaceDraftContent(selectedWorkspaceContent);
}, [selectedWorkspaceFile, selectedSkillAgentId, workspaceFilesByAgent, setWorkspaceDraftContent]);
// Load agent profiles and skills when view changes
const currentView = setters.currentView;
const isConnected = setters.isConnected;
useEffect(() => {
if (currentView !== 'traders' || !isConnected) {
return;
}
AGENTS.forEach((agent) => {
if (!agentProfilesByAgent[agent.id]) {
requestAgentProfile(agent.id);
}
if (!agentSkillsByAgent[agent.id]) {
requestAgentSkills(agent.id);
}
if (!workspaceFilesByAgent[agent.id]?.['MEMORY.md']) {
requestWorkspaceFile(agent.id, 'MEMORY.md');
}
});
}, [agentProfilesByAgent, agentSkillsByAgent, currentView, isConnected, requestAgentProfile, requestAgentSkills, requestWorkspaceFile, workspaceFilesByAgent, AGENTS]);
return {
requestAgentSkills,
requestAgentProfile,
requestSkillDetail,
requestWorkspaceFile,
handleCreateLocalSkill,
handleLocalSkillDraftChange,
handleLocalSkillSave,
handleLocalSkillDelete,
handleRemoveSharedSkill,
handleAgentSkillToggle,
handleSkillAgentChange,
handleWorkspaceFileChange,
handleWorkspaceFileSave,
handleUploadExternalSkill
};
}

View File

@@ -0,0 +1,257 @@
import { useCallback } from 'react';
import { startRuntime } from '../services/runtimeApi';
/**
* Extracts runtime config callbacks from App.jsx into a single hook.
*/
export function useRuntimeCallbacks({
clientRef,
addSystemMessage,
parseWatchlistInput,
setters
}) {
const {
setScheduleModeDraft,
setIntervalMinutesDraft,
setTriggerTimeDraft,
setMaxCommCyclesDraft,
setInitialCashDraft,
setMarginRequirementDraft,
setEnableMemoryDraft,
setModeDraft,
setPollIntervalDraft,
setStartDateDraft,
setEndDateDraft,
setEnableMockDraft,
setRuntimeConfigFeedback,
setIsRuntimeConfigSaving,
setIsWatchlistSaving,
setIsRuntimeSettingsOpen,
watchlistDraftSymbols,
watchlistInputValue,
scheduleModeDraft,
intervalMinutesDraft,
maxCommCyclesDraft,
initialCashDraft,
marginRequirementDraft,
enableMemoryDraft,
modeDraft,
pollIntervalDraft,
startDateDraft,
endDateDraft,
enableMockDraft
} = setters;
const handleRuntimeConfigSave = useCallback(() => {
if (!clientRef.current) {
setRuntimeConfigFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
return;
}
const interval = Number(intervalMinutesDraft);
const maxCommCycles = Number(maxCommCyclesDraft);
if (!Number.isInteger(interval) || interval <= 0) {
setRuntimeConfigFeedback({ type: 'error', text: '间隔必须是正整数分钟' });
return;
}
if (!Number.isInteger(maxCommCycles) || maxCommCycles <= 0) {
setRuntimeConfigFeedback({ type: 'error', text: '讨论轮数必须是正整数' });
return;
}
setIsRuntimeConfigSaving(true);
setRuntimeConfigFeedback(null);
const success = clientRef.current.send({
type: 'update_runtime_config',
schedule_mode: scheduleModeDraft,
interval_minutes: interval,
trigger_time: triggerTimeDraft,
max_comm_cycles: maxCommCycles,
initial_cash: Number(initialCashDraft),
margin_requirement: Number(marginRequirementDraft),
enable_memory: Boolean(enableMemoryDraft)
});
if (!success) {
setIsRuntimeConfigSaving(false);
setRuntimeConfigFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
}
}, [
clientRef,
intervalMinutesDraft,
maxCommCyclesDraft,
scheduleModeDraft,
triggerTimeDraft,
initialCashDraft,
marginRequirementDraft,
enableMemoryDraft,
setIsRuntimeConfigSaving,
setRuntimeConfigFeedback
]);
const handleLaunchConfigSave = useCallback(async () => {
const pendingTickers = parseWatchlistInput(watchlistInputValue);
const nextTickers = Array.from(new Set([...watchlistDraftSymbols, ...pendingTickers]));
if (nextTickers.length === 0) {
setRuntimeConfigFeedback({ type: 'error', text: '至少输入 1 个有效股票代码' });
return;
}
const interval = Number(intervalMinutesDraft);
const maxCommCycles = Number(maxCommCyclesDraft);
const initialCash = Number(initialCashDraft);
const marginRequirement = Number(marginRequirementDraft);
if (!Number.isInteger(interval) || interval <= 0) {
setRuntimeConfigFeedback({ type: 'error', text: '间隔必须是正整数分钟' });
return;
}
if (!Number.isInteger(maxCommCycles) || maxCommCycles <= 0) {
setRuntimeConfigFeedback({ type: 'error', text: '讨论轮数必须是正整数' });
return;
}
if (!Number.isFinite(initialCash) || initialCash <= 0) {
setRuntimeConfigFeedback({ type: 'error', text: '初始资金必须是正数' });
return;
}
if (!Number.isFinite(marginRequirement) || marginRequirement < 0) {
setRuntimeConfigFeedback({ type: 'error', text: '保证金要求不能为负数' });
return;
}
setIsRuntimeConfigSaving(true);
setIsWatchlistSaving(true);
setRuntimeConfigFeedback(null);
setters.setWatchlistFeedback(null);
setters.setWatchlistDraftSymbols(nextTickers);
setters.setWatchlistInputValue('');
try {
const result = await startRuntime({
tickers: nextTickers,
schedule_mode: scheduleModeDraft,
interval_minutes: interval,
trigger_time: triggerTimeDraft,
max_comm_cycles: maxCommCycles,
initial_cash: initialCash,
margin_requirement: marginRequirement,
enable_memory: Boolean(enableMemoryDraft),
mode: modeDraft || 'live',
poll_interval: Number(pollIntervalDraft) || 10,
start_date: startDateDraft || null,
end_date: endDateDraft || null,
enable_mock: Boolean(enableMockDraft)
});
setIsRuntimeConfigSaving(false);
setIsWatchlistSaving(false);
setIsRuntimeSettingsOpen(false);
setRuntimeConfigFeedback({
type: 'success',
text: `任务已启动: ${result.run_id}`
});
addSystemMessage(`新任务已启动: ${result.run_id}`);
} catch (error) {
setIsRuntimeConfigSaving(false);
setIsWatchlistSaving(false);
setRuntimeConfigFeedback({
type: 'error',
text: `启动失败: ${error.message}`
});
}
}, [
parseWatchlistInput,
watchlistInputValue,
watchlistDraftSymbols,
intervalMinutesDraft,
maxCommCyclesDraft,
initialCashDraft,
marginRequirementDraft,
enableMemoryDraft,
scheduleModeDraft,
triggerTimeDraft,
modeDraft,
pollIntervalDraft,
startDateDraft,
endDateDraft,
enableMockDraft,
setters,
setIsRuntimeConfigSaving,
setIsWatchlistSaving,
setRuntimeConfigFeedback,
setIsRuntimeSettingsOpen,
addSystemMessage
]);
const handleRuntimeDefaultsRestore = useCallback(() => {
setScheduleModeDraft('daily');
setIntervalMinutesDraft('60');
setTriggerTimeDraft('09:30');
setMaxCommCyclesDraft('2');
setInitialCashDraft('100000');
setMarginRequirementDraft('0');
setEnableMemoryDraft(false);
setModeDraft('live');
setPollIntervalDraft('10');
setStartDateDraft('');
setEndDateDraft('');
setEnableMockDraft(false);
setRuntimeConfigFeedback(null);
}, [
setScheduleModeDraft,
setIntervalMinutesDraft,
setTriggerTimeDraft,
setMaxCommCyclesDraft,
setInitialCashDraft,
setMarginRequirementDraft,
setEnableMemoryDraft,
setModeDraft,
setPollIntervalDraft,
setStartDateDraft,
setEndDateDraft,
setEnableMockDraft,
setRuntimeConfigFeedback
]);
const handleRuntimeSettingsToggle = useCallback(() => {
setRuntimeConfigFeedback(null);
setters.setAgentSkillsFeedback(null);
setters.setWorkspaceFileFeedback(null);
setIsRuntimeSettingsOpen((prev) => {
const nextOpen = !prev;
if (nextOpen) {
// Initialize watchlist draft when opening settings
setters.setWatchlistDraftSymbols(settlers.runtimeWatchlistSymbols);
setters.setWatchlistInputValue('');
setters.setWatchlistFeedback(null);
}
return nextOpen;
});
setters.setIsWatchlistPanelOpen(false);
}, [setRuntimeConfigFeedback, setters, setIsRuntimeSettingsOpen]);
const handleManualTrigger = useCallback(() => {
if (!clientRef.current) {
addSystemMessage('连接未就绪,无法手动触发');
return;
}
const success = clientRef.current.send({
type: 'trigger_strategy'
});
if (!success) {
addSystemMessage('手动触发发送失败,请检查连接状态');
return;
}
addSystemMessage('已发送手动触发请求');
}, [clientRef, addSystemMessage]);
return {
handleRuntimeConfigSave,
handleLaunchConfigSave,
handleRuntimeDefaultsRestore,
handleRuntimeSettingsToggle,
handleManualTrigger
};
}

View File

@@ -0,0 +1,584 @@
import { useCallback } from 'react';
import {
fetchNewsCategoriesDirect,
fetchNewsForDateDirect,
fetchRangeExplainDirect,
fetchSimilarDaysDirect,
fetchStockStoryDirect,
hasDirectNewsService
} from '../services/newsApi';
import {
fetchInsiderTradesDirect,
fetchStockHistoryDirect,
hasDirectTradingService
} from '../services/tradingApi';
/**
* Extracts all requestStock* callbacks from App.jsx into a single hook.
*/
export function useStockRequestCallbacks({
clientRef,
currentDate,
requestedStockHistoryRef,
setters,
apiHelpers
}) {
const {
setOhlcHistoryByTicker,
setHistorySourceByTicker,
setExplainEventsByTicker,
setNewsByTicker,
setInsiderTradesByTicker,
setTechnicalIndicatorsByTicker,
setPriceHistoryByTicker
} = setters;
const {
hasDirectTradingService: _hasDirectTradingService,
fetchStockHistoryDirect: _fetchStockHistoryDirect,
hasDirectNewsService: _hasDirectNewsService,
fetchNewsForDateDirect: _fetchNewsForDateDirect,
fetchNewsCategoriesDirect: _fetchNewsCategoriesDirect,
fetchInsiderTradesDirect: _fetchInsiderTradesDirect,
fetchRangeExplainDirect: _fetchRangeExplainDirect,
fetchStockStoryDirect: _fetchStockStoryDirect,
fetchSimilarDaysDirect: _fetchSimilarDaysDirect
} = apiHelpers;
const buildTickersFromSymbols = useCallback((symbols, previousTickers = []) => {
if (!Array.isArray(symbols) || symbols.length === 0) {
return previousTickers;
}
return symbols
.filter((symbol) => typeof symbol === 'string' && symbol.trim())
.map((symbol) => {
const normalized = symbol.trim().toUpperCase();
const existing = previousTickers.find((ticker) => ticker.symbol === normalized);
return existing || {
symbol: normalized,
price: null,
change: null
};
});
}, []);
const normalizePriceHistory = useCallback((payload) => {
if (!payload || typeof payload !== 'object') {
return {};
}
const normalized = {};
Object.entries(payload).forEach(([symbol, points]) => {
const ticker = String(symbol || '').trim().toUpperCase();
if (!ticker || !Array.isArray(points)) {
return;
}
normalized[ticker] = points
.map((point) => {
if (Array.isArray(point) && point.length >= 2) {
const [label, value] = point;
const price = Number(value);
if (!label || !Number.isFinite(price)) return null;
return {
timestamp: String(label),
label: String(label),
price
};
}
if (point && typeof point === 'object') {
const rawTimestamp = point.timestamp ?? point.t ?? point.date ?? point.label;
const price = Number(point.price ?? point.v ?? point.value ?? point.close);
if (!rawTimestamp || !Number.isFinite(price)) return null;
return {
timestamp: String(rawTimestamp),
label: String(rawTimestamp),
price
};
}
return null;
})
.filter(Boolean)
.slice(-120);
});
return normalized;
}, []);
const requestStockHistory = useCallback((symbol, { force = false } = {}) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized) {
return false;
}
if (!force && requestedStockHistoryRef.current.has(normalized)) {
return false;
}
const endDate = currentDate
? String(currentDate).slice(0, 10)
: new Date().toISOString().slice(0, 10);
const end = new Date(`${endDate}T00:00:00`);
const start = new Date(end);
start.setDate(start.getDate() - 120);
const startDate = start.toISOString().slice(0, 10);
if (_hasDirectTradingService()) {
void _fetchStockHistoryDirect(normalized, startDate, endDate)
.then((payload) => {
const prices = Array.isArray(payload?.prices) ? payload.prices : [];
setOhlcHistoryByTicker((prev) => ({
...prev,
[normalized]: prices
}));
setPriceHistoryByTicker((prev) => ({
...prev,
[normalized]: prices
.map((point) => {
const price = Number(point?.close);
const timestamp = point?.time;
if (!timestamp || !Number.isFinite(price)) {
return null;
}
return {
timestamp: String(timestamp),
label: String(timestamp),
price
};
})
.filter(Boolean)
}));
setHistorySourceByTicker((prev) => ({
...prev,
[normalized]: 'trading_service'
}));
})
.catch((error) => {
console.error('Direct stock-history fetch failed, falling back to websocket:', error);
if (clientRef.current) {
const success = clientRef.current.send({
type: 'get_stock_history',
ticker: normalized,
lookback_days: 120
});
if (success) {
requestedStockHistoryRef.current.add(normalized);
}
}
});
requestedStockHistoryRef.current.add(normalized);
return true;
}
if (!clientRef.current) {
return false;
}
const success = clientRef.current.send({
type: 'get_stock_history',
ticker: normalized,
lookback_days: 120
});
if (success) {
requestedStockHistoryRef.current.add(normalized);
}
return success;
}, [currentDate, _hasDirectTradingService, _fetchStockHistoryDirect, clientRef, requestedStockHistoryRef, setOhlcHistoryByTicker, setPriceHistoryByTicker, setHistorySourceByTicker]);
const requestStockExplainEvents = useCallback((symbol) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized || !clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_explain_events',
ticker: normalized
});
}, [clientRef]);
const requestStockNews = useCallback((symbol) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized || !clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_news',
ticker: normalized,
lookback_days: 45,
limit: 12
});
}, [clientRef]);
const requestStockNewsForDate = useCallback((symbol, date) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized || !date) {
return false;
}
if (_hasDirectNewsService()) {
void _fetchNewsForDateDirect(normalized, date, 20)
.then((payload) => {
const targetDate = typeof payload?.date === 'string' ? payload.date.trim() : date;
const news = Array.isArray(payload?.news) ? payload.news : [];
const freshness = payload?.freshness || null;
setNewsByTicker((prev) => ({
...prev,
[normalized]: {
...(prev[normalized] || {}),
byDate: {
...((prev[normalized] && prev[normalized].byDate) || {}),
[targetDate]: news
},
byDateFreshness: {
...((prev[normalized] && prev[normalized].byDateFreshness) || {}),
[targetDate]: freshness
}
}
}));
})
.catch((error) => {
console.error('Direct news-for-date fetch failed, falling back to websocket:', error);
if (clientRef.current) {
clientRef.current.send({
type: 'get_stock_news_for_date',
ticker: normalized,
date,
limit: 20
});
}
});
return true;
}
if (!clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_news_for_date',
ticker: normalized,
date,
limit: 20
});
}, [clientRef, _hasDirectNewsService, _fetchNewsForDateDirect, setNewsByTicker]);
const requestStockNewsTimeline = useCallback((symbol) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized || !clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_news_timeline',
ticker: normalized,
lookback_days: 90
});
}, [clientRef]);
const requestStockNewsCategories = useCallback((symbol) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized) {
return false;
}
const endDate = currentDate
? String(currentDate).slice(0, 10)
: new Date().toISOString().slice(0, 10);
const end = new Date(`${endDate}T00:00:00`);
const start = new Date(end);
start.setDate(start.getDate() - 90);
const startDate = start.toISOString().slice(0, 10);
if (_hasDirectNewsService()) {
void _fetchNewsCategoriesDirect(normalized, startDate, endDate, 200)
.then((payload) => {
const freshness = payload?.freshness || null;
setNewsByTicker((prev) => ({
...prev,
[normalized]: {
...(prev[normalized] || {}),
categories: payload?.categories || {},
categoriesStartDate: startDate,
categoriesEndDate: endDate,
categoriesFreshness: freshness
}
}));
})
.catch((error) => {
console.error('Direct news-categories fetch failed, falling back to websocket:', error);
if (clientRef.current) {
clientRef.current.send({
type: 'get_stock_news_categories',
ticker: normalized,
lookback_days: 90
});
}
});
return true;
}
if (!clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_news_categories',
ticker: normalized,
lookback_days: 90
});
}, [currentDate, clientRef, _hasDirectNewsService, _fetchNewsCategoriesDirect, setNewsByTicker]);
const requestStockInsiderTrades = useCallback((symbol, startDate = null, endDate = null) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized) {
return false;
}
if (_hasDirectTradingService()) {
void _fetchInsiderTradesDirect(normalized, startDate, endDate, 50)
.then((payload) => {
const rows = Array.isArray(payload?.insider_trades) ? payload.insider_trades : [];
setInsiderTradesByTicker((prev) => ({
...prev,
[normalized]: {
ticker: normalized,
startDate: startDate || null,
endDate: endDate || null,
trades: rows
}
}));
})
.catch((error) => {
console.error('Direct insider-trades fetch failed, falling back to websocket:', error);
if (clientRef.current) {
clientRef.current.send({
type: 'get_stock_insider_trades',
ticker: normalized,
start_date: startDate,
end_date: endDate,
limit: 50
});
}
});
return true;
}
if (!clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_insider_trades',
ticker: normalized,
start_date: startDate,
end_date: endDate,
limit: 50
});
}, [clientRef, _hasDirectTradingService, _fetchInsiderTradesDirect, setInsiderTradesByTicker]);
const requestStockTechnicalIndicators = useCallback((symbol) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized || !clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_technical_indicators',
ticker: normalized
});
}, [clientRef]);
const requestStockRangeExplain = useCallback((symbol, startDate, endDate, articleIds = []) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized || !startDate || !endDate) {
return false;
}
if (_hasDirectNewsService()) {
void _fetchRangeExplainDirect(normalized, startDate, endDate, articleIds)
.then((payload) => {
const result = payload?.result && typeof payload.result === 'object' ? payload.result : null;
const freshness = payload?.freshness || null;
if (!result?.start_date || !result?.end_date) {
return;
}
const cacheKey = `${result.start_date}:${result.end_date}`;
setNewsByTicker((prev) => ({
...prev,
[normalized]: {
...(prev[normalized] || {}),
rangeExplainCache: {
...((prev[normalized] && prev[normalized].rangeExplainCache) || {}),
[cacheKey]: {
...result,
freshness
}
}
}
}));
})
.catch((error) => {
console.error('Direct range explain fetch failed, falling back to websocket:', error);
if (clientRef.current) {
clientRef.current.send({
type: 'get_stock_range_explain',
ticker: normalized,
start_date: startDate,
end_date: endDate,
article_ids: Array.isArray(articleIds) ? articleIds : []
});
}
});
return true;
}
if (!clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_range_explain',
ticker: normalized,
start_date: startDate,
end_date: endDate,
article_ids: Array.isArray(articleIds) ? articleIds : []
});
}, [clientRef, _hasDirectNewsService, _fetchRangeExplainDirect, setNewsByTicker]);
const requestStockStory = useCallback((symbol, asOfDate) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
const date = typeof asOfDate === 'string' ? asOfDate.trim() : '';
if (!normalized || !date) {
return false;
}
if (_hasDirectNewsService()) {
void _fetchStockStoryDirect(normalized, date)
.then((payload) => {
setNewsByTicker((prev) => ({
...prev,
[normalized]: {
...(prev[normalized] || {}),
storyCache: {
...((prev[normalized] && prev[normalized].storyCache) || {}),
[date]: {
story: payload?.story || '',
source: payload?.source || null,
asOfDate: date,
freshness: payload?.freshness || null
}
}
}
}));
})
.catch((error) => {
console.error('Direct story fetch failed, falling back to websocket:', error);
if (clientRef.current) {
clientRef.current.send({
type: 'get_stock_story',
ticker: normalized,
as_of_date: date
});
}
});
return true;
}
if (!clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_story',
ticker: normalized,
as_of_date: date
});
}, [clientRef, _hasDirectNewsService, _fetchStockStoryDirect, setNewsByTicker]);
const requestStockSimilarDays = useCallback((symbol, targetDate, lookbackDays = 365) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
const date = typeof targetDate === 'string' ? targetDate.trim() : '';
if (!normalized || !date) {
return false;
}
if (_hasDirectNewsService()) {
void _fetchSimilarDaysDirect(normalized, date, lookbackDays)
.then((payload) => {
setNewsByTicker((prev) => ({
...prev,
[normalized]: {
...(prev[normalized] || {}),
similarDaysCache: {
...((prev[normalized] && prev[normalized].similarDaysCache) || {}),
[date]: {
target_features: payload?.target_features || {},
items: Array.isArray(payload?.items) ? payload?.items : [],
error: payload?.error || null,
freshness: payload?.freshness || null
}
}
}
}));
})
.catch((error) => {
console.error('Direct similar-days fetch failed, falling back to websocket:', error);
if (clientRef.current) {
clientRef.current.send({
type: 'get_stock_similar_days',
ticker: normalized,
target_date: date,
lookback_days: lookbackDays
});
}
});
return true;
}
if (!clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'get_stock_similar_days',
ticker: normalized,
target_date: date,
lookback_days: lookbackDays
});
}, [clientRef, _hasDirectNewsService, _fetchSimilarDaysDirect, setNewsByTicker]);
const requestStockEnrich = useCallback((symbol, startDate, endDate, { force = false, onlyLocalToLlm = false } = {}) => {
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
if (!normalized || !clientRef.current) {
return false;
}
return clientRef.current.send({
type: 'enrich_stock_news',
ticker: normalized,
start_date: startDate,
end_date: endDate,
force: Boolean(force),
only_local_to_llm: Boolean(onlyLocalToLlm)
});
}, [clientRef]);
return {
buildTickersFromSymbols,
normalizePriceHistory,
requestStockHistory,
requestStockExplainEvents,
requestStockNews,
requestStockNewsForDate,
requestStockNewsTimeline,
requestStockNewsCategories,
requestStockInsiderTrades,
requestStockTechnicalIndicators,
requestStockRangeExplain,
requestStockStory,
requestStockSimilarDays,
requestStockEnrich
};
}

View File

@@ -0,0 +1,144 @@
import { useCallback, useMemo } from 'react';
import { INITIAL_TICKERS } from '../config/constants';
/**
* Extracts watchlist-related callbacks from App.jsx into a single hook.
*/
export function useWatchlistCallbacks({
clientRef,
runtimeWatchlistSymbols,
watchlistDraftSymbols,
watchlistInputValue,
watchlistFeedback,
setters
}) {
const {
setWatchlistDraftSymbols,
setWatchlistInputValue,
setWatchlistFeedback
} = setters;
const parseWatchlistInput = useCallback((value) => {
if (typeof value !== 'string') {
return [];
}
return Array.from(
new Set(
value
.split(/[\s,]+/)
.map((symbol) => symbol.trim().toUpperCase())
.filter(Boolean)
)
);
}, []);
const commitWatchlistInput = useCallback((value) => {
const parsed = parseWatchlistInput(value);
if (parsed.length === 0) {
return [];
}
setWatchlistDraftSymbols((prev) => Array.from(new Set([...prev, ...parsed])));
setWatchlistInputValue('');
if (watchlistFeedback) {
setWatchlistFeedback(null);
}
return parsed;
}, [parseWatchlistInput, watchlistFeedback, setWatchlistDraftSymbols, setWatchlistInputValue, setWatchlistFeedback, setters]);
const handleWatchlistRemove = useCallback((symbolToRemove) => {
setWatchlistDraftSymbols((prev) => prev.filter((symbol) => symbol !== symbolToRemove));
if (watchlistFeedback) {
setWatchlistFeedback(null);
}
}, [watchlistFeedback, setWatchlistDraftSymbols, setWatchlistFeedback]);
const handleWatchlistInputChange = useCallback((value) => {
setWatchlistInputValue(value);
if (watchlistFeedback) {
setWatchlistFeedback(null);
}
}, [watchlistFeedback, setWatchlistInputValue, setWatchlistFeedback]);
const handleWatchlistInputKeyDown = useCallback((e) => {
if (e.key === 'Enter' || e.key === ',') {
e.preventDefault();
commitWatchlistInput(watchlistInputValue);
}
}, [commitWatchlistInput, watchlistInputValue]);
const handleWatchlistSuggestionClick = useCallback((symbol) => {
if (watchlistDraftSymbols.includes(symbol)) {
return;
}
setWatchlistDraftSymbols((prev) => [...prev, symbol]);
if (watchlistFeedback) {
setWatchlistFeedback(null);
}
}, [watchlistDraftSymbols, watchlistFeedback, setWatchlistDraftSymbols, setWatchlistFeedback]);
const handleWatchlistRestoreCurrent = useCallback(() => {
setWatchlistDraftSymbols(runtimeWatchlistSymbols);
setWatchlistInputValue('');
setWatchlistFeedback(null);
}, [runtimeWatchlistSymbols, setWatchlistDraftSymbols, setWatchlistInputValue, setWatchlistFeedback]);
const handleWatchlistSave = useCallback(() => {
const pendingTickers = parseWatchlistInput(watchlistInputValue);
const nextTickers = Array.from(new Set([...watchlistDraftSymbols, ...pendingTickers]));
if (nextTickers.length === 0) {
setWatchlistFeedback({ type: 'error', text: '至少输入 1 个有效股票代码' });
return;
}
if (!clientRef.current) {
setWatchlistFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
return;
}
setters.setIsWatchlistSaving(true);
setWatchlistFeedback(null);
setWatchlistDraftSymbols(nextTickers);
setWatchlistInputValue('');
const success = clientRef.current.send({
type: 'update_watchlist',
tickers: nextTickers
});
if (!success) {
setters.setIsWatchlistSaving(false);
setWatchlistFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
}
}, [parseWatchlistInput, watchlistDraftSymbols, watchlistInputValue, clientRef, setters.setIsWatchlistSaving, setWatchlistFeedback, setWatchlistDraftSymbols, setWatchlistInputValue]);
const watchlistSuggestions = useMemo(
() => INITIAL_TICKERS.map((ticker) => ticker.symbol).filter((symbol, index, list) => list.indexOf(symbol) === index),
[]
);
const isWatchlistDraftDirty = useMemo(() => {
if (watchlistInputValue.trim()) {
return true;
}
if (watchlistDraftSymbols.length !== runtimeWatchlistSymbols.length) {
return true;
}
return watchlistDraftSymbols.some((symbol, index) => symbol !== runtimeWatchlistSymbols[index]);
}, [runtimeWatchlistSymbols, watchlistDraftSymbols, watchlistInputValue]);
return {
parseWatchlistInput,
commitWatchlistInput,
handleWatchlistRemove,
handleWatchlistInputChange,
handleWatchlistInputKeyDown,
handleWatchlistSuggestionClick,
handleWatchlistRestoreCurrent,
handleWatchlistSave,
watchlistSuggestions,
isWatchlistDraftDirty
};
}

File diff suppressed because it is too large Load Diff

View File

@@ -87,4 +87,8 @@ export const useRuntimeStore = create((set) => ({
isRuntimeConfigSaving: false, isRuntimeConfigSaving: false,
setRuntimeConfigFeedback: (runtimeConfigFeedback) => set({ runtimeConfigFeedback }), setRuntimeConfigFeedback: (runtimeConfigFeedback) => set({ runtimeConfigFeedback }),
setIsRuntimeConfigSaving: (isRuntimeConfigSaving) => set({ isRuntimeConfigSaving }), setIsRuntimeConfigSaving: (isRuntimeConfigSaving) => set({ isRuntimeConfigSaving }),
// Last day history (for replay)
lastDayHistory: [],
setLastDayHistory: (lastDayHistory) => set({ lastDayHistory }),
})); }));