feat: 架构修复 - P0/P1 问题全面修复
P0 修复: - runtimeStore: 添加缺失的 lastDayHistory 字段 - Gateway/RuntimeService: 状态同步改为内存优先,消除 glob 竞态 - App.jsx: 从 3075 行重构到 ~500 行,提取 8 个独立文件 P1 修复: - CORS: 4 个服务改为从环境变量读取允许 origins - MarketStore: 改为模块级单例模式 - Domain 层: 删除 trading thin wrapper,保留 news 真实逻辑 - 测试: 补齐 77 个 gateway/runtime 测试 新增文件: - backend/tests/test_gateway.py (43 tests) - frontend/src/hooks/useWebSocketHandler.js - frontend/src/hooks/useStockRequestCallbacks.js - frontend/src/hooks/useAgentCallbacks.js - frontend/src/hooks/useRuntimeCallbacks.js - frontend/src/hooks/useWatchlistCallbacks.js - frontend/src/components/TickerBar.jsx - frontend/src/components/HeaderRight.jsx - frontend/src/components/ChartTabs.jsx Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -302,36 +302,28 @@ def _start_gateway_process(
|
||||
|
||||
@router.get("/context", response_model=RunContextResponse)
|
||||
async def get_run_context() -> RunContextResponse:
|
||||
"""Return the most recent run context."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
"""Return the current run context from in-memory state (avoids glob race condition)."""
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is None or manager.context is None:
|
||||
raise HTTPException(status_code=404, detail="No run context available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context")
|
||||
if context is None:
|
||||
raise HTTPException(status_code=404, detail="Run context is not ready")
|
||||
|
||||
context = manager.context
|
||||
return RunContextResponse(
|
||||
config_name=context["config_name"],
|
||||
run_dir=context["run_dir"],
|
||||
bootstrap_values=context["bootstrap_values"],
|
||||
config_name=context.config_name,
|
||||
run_dir=str(context.run_dir),
|
||||
bootstrap_values=context.bootstrap_values,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/agents", response_model=RuntimeAgentsResponse)
|
||||
async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||
"""Return agent states from the most recent run."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
"""Return agent states from the in-memory runtime manager (avoids glob race condition)."""
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
agents = latest.get("agents", [])
|
||||
snapshot = manager.build_snapshot()
|
||||
agents = snapshot.get("agents", [])
|
||||
|
||||
return RuntimeAgentsResponse(
|
||||
agents=[RuntimeAgentState(**a) for a in agents]
|
||||
@@ -340,15 +332,13 @@ async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||
|
||||
@router.get("/events", response_model=RuntimeEventsResponse)
|
||||
async def get_runtime_events() -> RuntimeEventsResponse:
|
||||
"""Return events from the most recent run."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
"""Return events from the in-memory runtime manager (avoids glob race condition)."""
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
events = latest.get("events", [])
|
||||
snapshot = manager.build_snapshot()
|
||||
events = snapshot.get("events", [])
|
||||
|
||||
return RuntimeEventsResponse(
|
||||
events=[RuntimeEvent(**e) for e in events]
|
||||
@@ -362,15 +352,10 @@ async def get_gateway_status() -> GatewayStatusResponse:
|
||||
run_id = None
|
||||
|
||||
if is_running:
|
||||
# Try to find run_id from runtime state
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
if snapshots:
|
||||
try:
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
run_id = latest.get("context", {}).get("config_name")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse latest snapshot: {e}")
|
||||
# Get run_id from in-memory runtime manager (avoids glob race condition)
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is not None and manager.context is not None:
|
||||
run_id = manager.context.config_name
|
||||
|
||||
return GatewayStatusResponse(
|
||||
is_running=is_running,
|
||||
@@ -404,8 +389,28 @@ def _build_gateway_ws_url(request: Request, port: int) -> str:
|
||||
return f"{ws_scheme}://{host}:{port}"
|
||||
|
||||
|
||||
def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
||||
"""Load the latest persisted runtime snapshot."""
|
||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context from the in-memory manager (avoids glob race condition).
|
||||
|
||||
Falls back to file-based lookup only when the in-memory manager is not available
|
||||
(e.g., after a service restart). File-based lookup is deprecated and exists
|
||||
only for backward compatibility.
|
||||
"""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
# Primary: use in-memory manager (always correct for current process)
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is not None and manager.context is not None:
|
||||
ctx = manager.context
|
||||
return {
|
||||
"config_name": ctx.config_name,
|
||||
"run_dir": str(ctx.run_dir),
|
||||
"bootstrap_values": ctx.bootstrap_values,
|
||||
}
|
||||
|
||||
# Deprecated fallback: scan filesystem (only for backward compatibility
|
||||
# after service restart without a restart of the runtime itself)
|
||||
snapshots = sorted(
|
||||
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
@@ -413,14 +418,7 @@ def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
||||
)
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||
return json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context from the latest snapshot."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
latest = _load_latest_runtime_snapshot()
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
raise HTTPException(status_code=404, detail="No runtime context available")
|
||||
@@ -663,15 +661,8 @@ async def get_current_runtime():
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
# Find latest runtime state
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context", {})
|
||||
# Get context from in-memory manager (avoids glob race condition)
|
||||
context = _get_current_runtime_context()
|
||||
|
||||
return {
|
||||
"run_id": context.get("config_name"),
|
||||
|
||||
@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.api import agents_router, guard_router, workspaces_router
|
||||
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||
from backend.config.env_config import get_cors_origins
|
||||
|
||||
# Global instances (initialized on startup)
|
||||
agent_factory: AgentFactory | None = None
|
||||
@@ -49,7 +50,7 @@ def create_app(project_root: Path | None = None) -> FastAPI:
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.domains import news as news_domain
|
||||
from backend.config.env_config import get_cors_origins
|
||||
|
||||
|
||||
def get_market_store() -> MarketStore:
|
||||
@@ -27,7 +28,7 @@ def create_app() -> FastAPI:
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.api import runtime_router
|
||||
from backend.api.runtime import get_runtime_state
|
||||
from backend.config.env_config import get_cors_origins
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -20,7 +21,7 @@ def create_app() -> FastAPI:
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -8,7 +8,16 @@ from typing import Any
|
||||
from fastapi import FastAPI, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.config.env_config import get_cors_origins
|
||||
from backend.services.market import MarketService
|
||||
from backend.tools.data_tools import (
|
||||
get_company_news,
|
||||
get_financial_metrics,
|
||||
get_insider_trades,
|
||||
get_market_cap,
|
||||
get_prices,
|
||||
search_line_items,
|
||||
)
|
||||
from shared.schema import (
|
||||
CompanyNewsResponse,
|
||||
FinancialMetricsResponse,
|
||||
@@ -28,7 +37,7 @@ def create_app() -> FastAPI:
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
@@ -45,12 +54,8 @@ def create_app() -> FastAPI:
|
||||
start_date: str = Query(...),
|
||||
end_date: str = Query(...),
|
||||
) -> PriceResponse:
|
||||
payload = trading_domain.get_prices_payload(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
|
||||
prices = get_prices(ticker=ticker, start_date=start_date, end_date=end_date)
|
||||
return PriceResponse(ticker=ticker, prices=prices)
|
||||
|
||||
@app.get("/api/financials", response_model=FinancialMetricsResponse)
|
||||
async def api_get_financials(
|
||||
@@ -59,13 +64,13 @@ def create_app() -> FastAPI:
|
||||
period: str = Query("ttm"),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
) -> FinancialMetricsResponse:
|
||||
payload = trading_domain.get_financials_payload(
|
||||
metrics = get_financial_metrics(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
|
||||
return FinancialMetricsResponse(financial_metrics=metrics)
|
||||
|
||||
@app.get("/api/news", response_model=CompanyNewsResponse)
|
||||
async def api_get_news(
|
||||
@@ -74,13 +79,13 @@ def create_app() -> FastAPI:
|
||||
start_date: str | None = Query(None),
|
||||
limit: int = Query(1000, ge=1, le=5000),
|
||||
) -> CompanyNewsResponse:
|
||||
payload = trading_domain.get_news_payload(
|
||||
news = get_company_news(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
return CompanyNewsResponse(news=payload["news"])
|
||||
return CompanyNewsResponse(news=news)
|
||||
|
||||
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
|
||||
async def api_get_insider_trades(
|
||||
@@ -89,18 +94,19 @@ def create_app() -> FastAPI:
|
||||
start_date: str | None = Query(None),
|
||||
limit: int = Query(1000, ge=1, le=5000),
|
||||
) -> InsiderTradeResponse:
|
||||
payload = trading_domain.get_insider_trades_payload(
|
||||
trades = get_insider_trades(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
|
||||
return InsiderTradeResponse(insider_trades=trades)
|
||||
|
||||
@app.get("/api/market/status")
|
||||
async def api_get_market_status() -> dict[str, Any]:
|
||||
"""Return current market status using the existing market service logic."""
|
||||
return trading_domain.get_market_status_payload()
|
||||
service = MarketService(tickers=[])
|
||||
return service.get_market_status()
|
||||
|
||||
@app.get("/api/market-cap")
|
||||
async def api_get_market_cap(
|
||||
@@ -108,10 +114,12 @@ def create_app() -> FastAPI:
|
||||
end_date: str = Query(...),
|
||||
) -> dict[str, Any]:
|
||||
"""Return market cap for one ticker/date."""
|
||||
return trading_domain.get_market_cap_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
)
|
||||
market_cap = get_market_cap(ticker=ticker, end_date=end_date)
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"end_date": end_date,
|
||||
"market_cap": market_cap,
|
||||
}
|
||||
|
||||
@app.get("/api/line-items", response_model=LineItemResponse)
|
||||
async def api_get_line_items(
|
||||
@@ -121,14 +129,14 @@ def create_app() -> FastAPI:
|
||||
period: str = Query("ttm"),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
) -> LineItemResponse:
|
||||
payload = trading_domain.get_line_items_payload(
|
||||
items = search_line_items(
|
||||
ticker=ticker,
|
||||
line_items=line_items,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
return LineItemResponse(search_results=payload["search_results"])
|
||||
return LineItemResponse(search_results=items)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""Environment config helpers with light validation and normalization."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@@ -16,6 +17,36 @@ PROVIDER_ALIASES = {
|
||||
"vertexai": "GEMINI",
|
||||
}
|
||||
|
||||
# Default dev CORS origins (localhost variants used by common dev servers)
|
||||
_LOCALHOST_ORIGINS = [
|
||||
"http://localhost:5173",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
"http://127.0.0.1:5173",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:8000",
|
||||
]
|
||||
|
||||
|
||||
def get_cors_origins() -> list[str]:
|
||||
"""Get CORS allowed origins from environment.
|
||||
|
||||
Reads CORS_ALLOWED_ORIGINS env var (comma-separated).
|
||||
Falls back to localhost dev origins if not set.
|
||||
Warns if "*" is configured (only acceptable for local dev).
|
||||
"""
|
||||
origins = get_env_list("CORS_ALLOWED_ORIGINS", default=[])
|
||||
if origins:
|
||||
if "*" in origins:
|
||||
warnings.warn(
|
||||
"CORS_ALLOWED_ORIGINS contains '*' — this allows any origin. "
|
||||
"Only use in local development, never in production.",
|
||||
UserWarning,
|
||||
)
|
||||
return origins
|
||||
# Fallback: local dev only
|
||||
return _LOCALHOST_ORIGINS
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentModelConfig:
|
||||
|
||||
@@ -8,7 +8,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_ingest import ingest_symbols
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.tools.data_tools import get_market_cap
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -265,8 +265,7 @@ async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[s
|
||||
if response is not None:
|
||||
market_cap = response.get("market_cap")
|
||||
if market_cap is None:
|
||||
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
|
||||
market_cap = payload.get("market_cap")
|
||||
market_cap = get_market_cap(ticker=ticker, end_date=date)
|
||||
market_caps[ticker] = market_cap if market_cap else 1e9
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
|
||||
|
||||
@@ -11,10 +11,9 @@ from typing import Any
|
||||
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.domains import news as news_domain
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||
from backend.enrich.llm_enricher import llm_enrichment_enabled
|
||||
from backend.tools.data_tools import prices_to_df
|
||||
from backend.tools.data_tools import get_insider_trades, get_prices, prices_to_df
|
||||
from shared.client import NewsServiceClient, TradingServiceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -59,13 +58,12 @@ async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str,
|
||||
if not prices:
|
||||
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
|
||||
if not prices:
|
||||
payload = await asyncio.to_thread(
|
||||
trading_domain.get_prices_payload,
|
||||
prices = await asyncio.to_thread(
|
||||
get_prices,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
prices = payload.get("prices") or []
|
||||
usage_snapshot = gateway._provider_router.get_usage_snapshot()
|
||||
source = usage_snapshot.get("last_success", {}).get("prices")
|
||||
if prices:
|
||||
@@ -400,14 +398,13 @@ async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: di
|
||||
trades = response.insider_trades
|
||||
|
||||
if not trades:
|
||||
payload = await asyncio.to_thread(
|
||||
trading_domain.get_insider_trades_payload,
|
||||
trades = await asyncio.to_thread(
|
||||
get_insider_trades,
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date if start_date else None,
|
||||
limit=limit,
|
||||
)
|
||||
trades = payload.get("insider_trades") or []
|
||||
|
||||
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
|
||||
formatted_trades = [{
|
||||
@@ -540,12 +537,11 @@ async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, da
|
||||
prices = response.prices
|
||||
|
||||
if prices is None:
|
||||
payload = trading_domain.get_prices_payload(
|
||||
prices = get_prices(
|
||||
ticker=ticker,
|
||||
start_date=start_date.strftime("%Y-%m-%d"),
|
||||
end_date=end_date.strftime("%Y-%m-%d"),
|
||||
)
|
||||
prices = payload.get("prices") or []
|
||||
|
||||
if not prices or len(prices) < 20:
|
||||
await websocket.send(json.dumps({
|
||||
|
||||
549
backend/tests/test_gateway.py
Normal file
549
backend/tests/test_gateway.py
Normal file
@@ -0,0 +1,549 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the Gateway main class - core behavior and fallback paths."""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.services.gateway import Gateway
|
||||
import backend.services.gateway as gateway_module
|
||||
|
||||
|
||||
class DummyWebSocket:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self.closed = False
|
||||
self._queue = []
|
||||
|
||||
def queue(self, data: str):
|
||||
"""Queue a raw message string to be yielded by the async iterator."""
|
||||
self._queue.append(data)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._queue:
|
||||
raise StopAsyncIteration
|
||||
return self._queue.pop(0)
|
||||
|
||||
async def send(self, payload: str):
|
||||
self.messages.append(json.loads(payload))
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class DummyStateSync:
|
||||
def __init__(self, current_date="2026-03-16"):
|
||||
self.state = {"current_date": current_date}
|
||||
self.system_messages = []
|
||||
self.saved = False
|
||||
self.initial_state_payload = {}
|
||||
|
||||
def set_broadcast_fn(self, _fn):
|
||||
return None
|
||||
|
||||
def update_state(self, key, value):
|
||||
self.state[key] = value
|
||||
|
||||
def save_state(self):
|
||||
self.saved = True
|
||||
|
||||
async def on_system_message(self, message):
|
||||
self.system_messages.append(message)
|
||||
|
||||
def get_initial_state_payload(self, include_dashboard=True):
|
||||
return {
|
||||
"status": "running",
|
||||
"current_date": self.state.get("current_date", ""),
|
||||
"portfolio": {},
|
||||
"holdings": [],
|
||||
"trades": [],
|
||||
}
|
||||
|
||||
|
||||
class DummyMarketService:
|
||||
def __init__(self):
|
||||
self.broadcast_func = None
|
||||
self.market_status = {"is_open": True, "session": "regular"}
|
||||
|
||||
def set_price_recorder(self, _fn):
|
||||
return None
|
||||
|
||||
async def start(self, broadcast_func=None):
|
||||
self.broadcast_func = broadcast_func
|
||||
|
||||
def get_market_status(self):
|
||||
return self.market_status
|
||||
|
||||
|
||||
class DummyStorage:
|
||||
def __init__(self, initial_cash=100000.0, live_session=False):
|
||||
self.initial_cash = initial_cash
|
||||
self.is_live_session_active = live_session
|
||||
self._market_store = SimpleNamespace()
|
||||
|
||||
@property
|
||||
def market_store(self):
|
||||
return self._market_store
|
||||
|
||||
def load_file(self, name):
|
||||
if name == "summary":
|
||||
return {"totalAssetValue": self.initial_cash}
|
||||
if name in ("holdings", "trades"):
|
||||
return []
|
||||
return None
|
||||
|
||||
def get_live_returns(self):
|
||||
return {"session_pnl": 0.0, "session_return": 0.0}
|
||||
|
||||
|
||||
def make_gateway(market_service=None, storage=None, state_sync=None, config=None):
|
||||
storage = storage or DummyStorage()
|
||||
state_sync = state_sync or DummyStateSync()
|
||||
market_service = market_service or DummyMarketService()
|
||||
pipeline = SimpleNamespace(state_sync=state_sync, max_comm_cycles=0, pm=SimpleNamespace(portfolio={"margin_requirement": 0.0}))
|
||||
return Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage,
|
||||
pipeline=pipeline,
|
||||
state_sync=state_sync,
|
||||
config=config or {"mode": "live"},
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gateway initialization and core properties
|
||||
# =============================================================================
|
||||
|
||||
def test_gateway_init_sets_live_mode():
|
||||
gateway = make_gateway(config={"mode": "live"})
|
||||
assert gateway.mode == "live"
|
||||
assert gateway.is_backtest is False
|
||||
|
||||
|
||||
def test_gateway_init_sets_backtest_mode_from_config():
|
||||
gateway = make_gateway(config={"mode": "backtest"})
|
||||
assert gateway.mode == "backtest"
|
||||
assert gateway.is_backtest is True
|
||||
|
||||
|
||||
def test_gateway_init_sets_backtest_mode_from_flag():
|
||||
gateway = make_gateway(config={"backtest_mode": True, "mode": "live"})
|
||||
assert gateway.is_backtest is True
|
||||
|
||||
|
||||
def test_gateway_init_defaults_to_live_mode():
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway.mode == "live"
|
||||
assert gateway.is_backtest is False
|
||||
|
||||
|
||||
def test_gateway_state_property_returns_state_sync_state():
|
||||
state_sync = DummyStateSync()
|
||||
state_sync.state["foo"] = "bar"
|
||||
gateway = make_gateway(state_sync=state_sync)
|
||||
assert gateway.state["foo"] == "bar"
|
||||
|
||||
|
||||
def test_gateway_news_rows_need_enrichment_delegates_to_news_domain():
|
||||
rows = [{"id": "1"}, {"id": "2"}]
|
||||
with patch.object(gateway_module.news_domain, "news_rows_need_enrichment", return_value=True) as mock:
|
||||
result = Gateway._news_rows_need_enrichment(rows)
|
||||
mock.assert_called_once_with(rows)
|
||||
assert result is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Service URL helpers and fallback paths
|
||||
# =============================================================================
|
||||
|
||||
def test_news_service_url_returns_config_value(monkeypatch):
|
||||
gateway = make_gateway(config={"news_service_url": "http://custom-news:9000"})
|
||||
assert gateway._news_service_url() == "http://custom-news:9000"
|
||||
|
||||
|
||||
def test_news_service_url_falls_back_to_env(monkeypatch):
|
||||
monkeypatch.setenv("NEWS_SERVICE_URL", "http://env-news:9001")
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway._news_service_url() == "http://env-news:9001"
|
||||
|
||||
|
||||
def test_news_service_url_returns_none_when_unset(monkeypatch):
|
||||
monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway._news_service_url() is None
|
||||
|
||||
|
||||
def test_news_service_url_strips_whitespace(monkeypatch):
|
||||
gateway = make_gateway(config={"news_service_url": " http://whitespace-news:9000 "})
|
||||
assert gateway._news_service_url() == "http://whitespace-news:9000"
|
||||
|
||||
|
||||
def test_trading_service_url_returns_config_value(monkeypatch):
|
||||
gateway = make_gateway(config={"trading_service_url": "http://custom-trading:9000"})
|
||||
assert gateway._trading_service_url() == "http://custom-trading:9000"
|
||||
|
||||
|
||||
def test_trading_service_url_falls_back_to_env(monkeypatch):
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://env-trading:9001")
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway._trading_service_url() == "http://env-trading:9001"
|
||||
|
||||
|
||||
def test_trading_service_url_returns_none_when_unset(monkeypatch):
|
||||
monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway._trading_service_url() is None
|
||||
|
||||
|
||||
def test_trading_service_url_strips_whitespace(monkeypatch):
|
||||
gateway = make_gateway(config={"trading_service_url": " http://whitespace-trading:9000 "})
|
||||
assert gateway._trading_service_url() == "http://whitespace-trading:9000"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_news_service_returns_none_when_url_not_set(monkeypatch):
|
||||
monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
|
||||
gateway = make_gateway(config={})
|
||||
|
||||
async def dummy_callback(client):
|
||||
return "should not be called"
|
||||
|
||||
result = await gateway._call_news_service("test_action", dummy_callback)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_news_service_calls_callback_and_returns():
|
||||
gateway = make_gateway(config={"news_service_url": "http://news:9000"})
|
||||
|
||||
async def callback(client):
|
||||
return {"result": "ok"}
|
||||
|
||||
result = await gateway._call_news_service("test_action", callback)
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_news_service_returns_none_on_exception():
|
||||
gateway = make_gateway(config={"news_service_url": "http://news:9000"})
|
||||
|
||||
async def failing_callback(client):
|
||||
raise RuntimeError("connection failed")
|
||||
|
||||
result = await gateway._call_news_service("test_action", failing_callback)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_trading_service_returns_none_when_url_not_set(monkeypatch):
|
||||
monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
|
||||
gateway = make_gateway(config={})
|
||||
|
||||
result = await gateway._call_trading_service("test_action", lambda c: None)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_trading_service_calls_callback_and_returns():
|
||||
gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})
|
||||
|
||||
async def callback(client):
|
||||
return {"result": "ok"}
|
||||
result = await gateway._call_trading_service("test_action", callback)
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_trading_service_returns_none_on_exception():
|
||||
gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})
|
||||
|
||||
async def failing_callback(client):
|
||||
raise RuntimeError("connection failed")
|
||||
|
||||
result = await gateway._call_trading_service("test_action", failing_callback)
|
||||
assert result is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WebSocket message handlers
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_messages_ping_returns_pong():
|
||||
"""Ping message type results in a pong response."""
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
ws.queue(json.dumps({"type": "ping"}))
|
||||
|
||||
await gateway._handle_client_messages(ws)
|
||||
|
||||
assert ws.messages[-1]["type"] == "pong"
|
||||
assert "timestamp" in ws.messages[-1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_messages_get_state_sends_initial_state():
|
||||
"""get_state message type triggers _send_initial_state."""
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
ws.queue(json.dumps({"type": "get_state"}))
|
||||
|
||||
with patch.object(gateway, "_send_initial_state", AsyncMock()) as mock_send:
|
||||
await gateway._handle_client_messages(ws)
|
||||
mock_send.assert_called_once_with(ws)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_messages_unknown_type_is_silently_ignored():
|
||||
"""Unknown message types are silently ignored without error."""
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
ws.queue(json.dumps({"type": "unknown_type"}))
|
||||
|
||||
# Should not raise
|
||||
await gateway._handle_client_messages(ws)
|
||||
assert len(ws.messages) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_messages_json_decode_error_is_silently_ignored():
|
||||
"""Invalid JSON messages are caught by the handler's except block."""
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
ws.queue("not valid json")
|
||||
|
||||
# Should not raise
|
||||
await gateway._handle_client_messages(ws)
|
||||
assert len(ws.messages) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Backtest handling
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_start_backtest_ignored_when_not_backtest_mode():
|
||||
gateway = make_gateway(config={"mode": "live"})
|
||||
# Should not raise - backtest is ignored in live mode
|
||||
await gateway._handle_start_backtest({"dates": ["2026-03-01", "2026-03-02"]})
|
||||
# Gateway should not have started a backtest task
|
||||
assert gateway._backtest_task is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_start_backtest_ignored_when_task_already_running():
|
||||
gateway = make_gateway(config={"mode": "backtest"})
|
||||
|
||||
# Pre-set a backtest task
|
||||
dummy_task = MagicMock()
|
||||
dummy_task.done.return_value = False
|
||||
gateway._backtest_task = dummy_task
|
||||
|
||||
# Should not start a new task
|
||||
await gateway._handle_start_backtest({"dates": ["2026-03-01"]})
|
||||
|
||||
assert gateway._backtest_task is dummy_task # unchanged
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Manual trigger (live/mock mode)
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_manual_trigger_rejected_in_backtest_mode():
|
||||
gateway = make_gateway(config={"mode": "backtest"})
|
||||
ws = DummyWebSocket()
|
||||
|
||||
await gateway._handle_manual_trigger(ws, {"date": "2026-03-16"})
|
||||
|
||||
assert any(m["type"] == "error" and "manual trigger" in m["message"].lower() for m in ws.messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_manual_trigger_rejected_when_cycle_already_running():
|
||||
gateway = make_gateway(config={"mode": "live"})
|
||||
ws = DummyWebSocket()
|
||||
|
||||
# Simulate a running cycle task
|
||||
dummy_task = MagicMock()
|
||||
dummy_task.done.return_value = False
|
||||
gateway._manual_cycle_task = dummy_task
|
||||
|
||||
await gateway._handle_manual_trigger(ws, {"date": "2026-03-16"})
|
||||
|
||||
assert any(m["type"] == "error" and "already running" in m["message"].lower() for m in ws.messages)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Normalization helpers
|
||||
# =============================================================================
|
||||
|
||||
def test_normalize_watchlist_filters_empty_and_dedupes():
|
||||
result = Gateway._normalize_watchlist(["aapl", " AAPL ", "", "msft", "MSFT", " "])
|
||||
assert result == ["AAPL", "MSFT"]
|
||||
|
||||
|
||||
def test_normalize_watchlist_handles_string_input():
|
||||
result = Gateway._normalize_watchlist("aapl, msft, aapl")
|
||||
assert result == ["AAPL", "MSFT"]
|
||||
|
||||
|
||||
def test_normalize_agent_workspace_filename_allows_editable_files():
|
||||
for filename in ["SOUL.md", "PROFILE.md", "AGENTS.md", "MEMORY.md", "POLICY.md"]:
|
||||
result = Gateway._normalize_agent_workspace_filename(filename)
|
||||
assert result == filename
|
||||
|
||||
|
||||
def test_normalize_agent_workspace_filename_rejects_non_editable_files():
|
||||
result = Gateway._normalize_agent_workspace_filename("README.md")
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_normalize_agent_workspace_filename_rejects_arbitrary_paths():
|
||||
result = Gateway._normalize_agent_workspace_filename("../etc/passwd")
|
||||
assert result is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Broadcast
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_skips_when_no_clients():
|
||||
gateway = make_gateway()
|
||||
gateway.connected_clients = set()
|
||||
|
||||
# Should not raise
|
||||
await gateway.broadcast({"type": "test"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_sends_to_all_connected_clients():
|
||||
gateway = make_gateway()
|
||||
ws1 = DummyWebSocket()
|
||||
ws2 = DummyWebSocket()
|
||||
gateway.connected_clients = {ws1, ws2}
|
||||
|
||||
await gateway.broadcast({"type": "market_update", "data": "test"})
|
||||
|
||||
assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
|
||||
assert ws1.messages[0]["data"] == "test"
|
||||
assert ws2.messages[0]["data"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_removes_closed_connections():
|
||||
"""Verify closed connections are removed from connected_clients set.
|
||||
|
||||
The broadcast method's _send_to_client helper removes a client
|
||||
when it raises websockets.ConnectionClosed.
|
||||
"""
|
||||
gateway = make_gateway()
|
||||
closed_ws = DummyWebSocket()
|
||||
open_ws = DummyWebSocket()
|
||||
gateway.connected_clients = {closed_ws, open_ws}
|
||||
|
||||
# Make closed_ws.send raise ConnectionClosed so the original
|
||||
# _send_to_client's except block triggers and removes it
|
||||
original_send = closed_ws.send
|
||||
async def raising_send(payload):
|
||||
raise gateway_module.websockets.ConnectionClosed(None, None)
|
||||
closed_ws.send = raising_send
|
||||
|
||||
try:
|
||||
await gateway.broadcast({"type": "test"})
|
||||
except gateway_module.websockets.ConnectionClosed:
|
||||
pass
|
||||
|
||||
# The closed client should have been removed, open client should remain
|
||||
assert closed_ws not in gateway.connected_clients
|
||||
assert open_ws in gateway.connected_clients
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_sends_to_all_connected_clients():
|
||||
"""Verify broadcast sends to all connected clients and collects results."""
|
||||
gateway = make_gateway()
|
||||
ws1 = DummyWebSocket()
|
||||
ws2 = DummyWebSocket()
|
||||
gateway.connected_clients = {ws1, ws2}
|
||||
|
||||
await gateway.broadcast({"type": "market_update", "data": "test"})
|
||||
|
||||
assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
|
||||
assert ws1.messages[0]["data"] == "test"
|
||||
assert ws2.messages[0]["data"] == "test"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Stop
|
||||
# =============================================================================
|
||||
|
||||
def test_stop_gateway_calls_cycle_support():
|
||||
gateway = make_gateway()
|
||||
with patch.object(gateway_module.gateway_cycle_support, "stop_gateway") as mock:
|
||||
gateway.stop()
|
||||
mock.assert_called_once_with(gateway)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# set_backtest_dates
|
||||
# =============================================================================
|
||||
|
||||
def test_set_backtest_dates_delegates_to_cycle_support():
|
||||
gateway = make_gateway()
|
||||
with patch.object(gateway_module.gateway_cycle_support, "set_backtest_dates") as mock:
|
||||
gateway.set_backtest_dates(["2026-03-01", "2026-03-02"])
|
||||
mock.assert_called_once_with(gateway, ["2026-03-01", "2026-03-02"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider usage change callback
|
||||
# =============================================================================
|
||||
|
||||
def test_on_provider_usage_changed_updates_state_sync():
|
||||
"""_on_provider_usage_changed updates state_sync with the provider snapshot."""
|
||||
gateway = make_gateway()
|
||||
gateway._loop = None # no loop set
|
||||
|
||||
snapshot = {"provider": "finnhub", "calls": 10}
|
||||
gateway._on_provider_usage_changed(snapshot)
|
||||
|
||||
# State sync should be updated
|
||||
assert gateway.state_sync.state.get("data_sources") == snapshot
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# handle_client lifecycle
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_adds_and_removes_client_from_connected_set():
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
|
||||
with patch.object(gateway, "_send_initial_state", AsyncMock()):
|
||||
with patch.object(gateway, "_handle_client_messages", AsyncMock()):
|
||||
await gateway.handle_client(ws)
|
||||
|
||||
# Client should be removed from connected set after handler returns
|
||||
assert ws not in gateway.connected_clients
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_adds_client_before_handler():
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
|
||||
with patch.object(gateway, "_send_initial_state", AsyncMock()):
|
||||
with patch.object(gateway, "_handle_client_messages", AsyncMock()):
|
||||
await gateway.handle_client(ws)
|
||||
|
||||
# Client was added at start
|
||||
# But removed at end (via lock)
|
||||
assert ws not in gateway.connected_clients
|
||||
@@ -1,14 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the extracted runtime service app surface."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.api import runtime as runtime_module
|
||||
from backend.apps.runtime_service import create_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_runtime_module_state():
|
||||
"""Reset module-level runtime_manager before each test."""
|
||||
runtime_module.runtime_manager = None
|
||||
# Also reset RuntimeState singleton's _runtime_manager
|
||||
rs = runtime_module.get_runtime_state()
|
||||
rs._runtime_manager = None
|
||||
yield
|
||||
runtime_module.runtime_manager = None
|
||||
rs = runtime_module.get_runtime_state()
|
||||
rs._runtime_manager = None
|
||||
|
||||
|
||||
def test_runtime_service_routes_are_exposed():
|
||||
app = create_app()
|
||||
paths = {route.path for route in app.routes}
|
||||
@@ -153,7 +170,9 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
||||
)
|
||||
|
||||
class _DummyContext:
|
||||
def __init__(self):
|
||||
def __init__(self, run_dir):
|
||||
self.config_name = "demo"
|
||||
self.run_dir = run_dir
|
||||
self.bootstrap_values = {
|
||||
"tickers": ["AAPL"],
|
||||
"schedule_mode": "daily",
|
||||
@@ -165,8 +184,17 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
||||
class _DummyManager:
|
||||
def __init__(self):
|
||||
self.config_name = "demo"
|
||||
self.bootstrap = dict(_DummyContext().bootstrap_values)
|
||||
self.context = _DummyContext()
|
||||
self.bootstrap = dict(_DummyContext(run_dir).bootstrap_values)
|
||||
self.context = _DummyContext(run_dir)
|
||||
|
||||
def build_snapshot(self):
|
||||
return {
|
||||
"context": {
|
||||
"config_name": self.context.config_name,
|
||||
"run_dir": str(self.context.run_dir),
|
||||
"bootstrap_values": self.context.bootstrap_values,
|
||||
}
|
||||
}
|
||||
|
||||
def _persist_snapshot(self):
|
||||
return None
|
||||
@@ -192,3 +220,385 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
||||
assert payload["bootstrap"]["schedule_mode"] == "intraday"
|
||||
assert payload["resolved"]["interval_minutes"] == 15
|
||||
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RuntimeState singleton unit tests
|
||||
# =============================================================================
|
||||
|
||||
def test_runtime_state_is_singleton():
|
||||
"""RuntimeState.__new__ returns the same instance across calls."""
|
||||
state1 = runtime_module.RuntimeState()
|
||||
state2 = runtime_module.RuntimeState()
|
||||
assert state1 is state2
|
||||
|
||||
|
||||
def test_runtime_state_get_runtime_state_returns_same_instance():
|
||||
"""get_runtime_state() returns the module singleton."""
|
||||
instance = runtime_module.get_runtime_state()
|
||||
assert instance is runtime_module._runtime_state
|
||||
|
||||
|
||||
def test_runtime_state_default_values():
|
||||
"""RuntimeState initializes with sensible defaults on first instantiation."""
|
||||
# Reset singleton to get fresh __init__ values
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
assert state._runtime_manager is None
|
||||
assert state._gateway_process is None
|
||||
assert state._gateway_port == 8765
|
||||
|
||||
|
||||
def test_runtime_state_gateway_port_property():
|
||||
"""gateway_port property getter and setter work correctly."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
state.gateway_port = 9999
|
||||
assert state.gateway_port == 9999
|
||||
state.gateway_port = 1234
|
||||
assert state.gateway_port == 1234
|
||||
|
||||
|
||||
def test_runtime_state_gateway_process_property():
|
||||
"""gateway_process property getter and setter work correctly."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
assert state.gateway_process is None
|
||||
|
||||
fake_process = object()
|
||||
state.gateway_process = fake_process
|
||||
assert state.gateway_process is fake_process
|
||||
|
||||
state.gateway_process = None
|
||||
assert state.gateway_process is None
|
||||
|
||||
|
||||
def test_runtime_state_runtime_manager_property():
|
||||
"""runtime_manager property getter and setter work correctly."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
assert state.runtime_manager is None
|
||||
|
||||
fake_manager = object()
|
||||
state.runtime_manager = fake_manager
|
||||
assert state.runtime_manager is fake_manager
|
||||
|
||||
state.runtime_manager = None
|
||||
assert state.runtime_manager is None
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
|
||||
def test_runtime_state_lock_property_is_async():
|
||||
"""lock is an async property that returns a coroutine producing an asyncio.Lock."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
lock_coro = state.lock
|
||||
assert asyncio.iscoroutine(lock_coro)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_state_async_set_get_gateway_port():
|
||||
"""Async setters and getters for gateway_port with lock protection."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
await state.set_gateway_port(8888)
|
||||
assert await state.get_gateway_port() == 8888
|
||||
await state.set_gateway_port(7777)
|
||||
assert await state.get_gateway_port() == 7777
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_state_async_set_get_gateway_process():
|
||||
"""Async setters and getters for gateway_process with lock protection."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
await state.set_gateway_process(None)
|
||||
assert await state.get_gateway_process() is None
|
||||
|
||||
fake_process = object()
|
||||
await state.set_gateway_process(fake_process)
|
||||
assert await state.get_gateway_process() is fake_process
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_state_async_set_get_runtime_manager():
|
||||
"""Async setters and getters for runtime_manager with lock protection."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
await state.set_runtime_manager(None)
|
||||
assert await state.get_runtime_manager() is None
|
||||
|
||||
fake_manager = object()
|
||||
await state.set_runtime_manager(fake_manager)
|
||||
assert await state.get_runtime_manager() is fake_manager
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _is_gateway_running helper tests
|
||||
# =============================================================================
|
||||
|
||||
def test_is_gateway_running_returns_false_when_process_is_none():
|
||||
"""_is_gateway_running returns False when gateway_process is None."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
new_state = runtime_module.RuntimeState()
|
||||
new_state._gateway_process = None
|
||||
runtime_module._runtime_state = new_state
|
||||
assert runtime_module._is_gateway_running() is False
|
||||
|
||||
|
||||
def test_is_gateway_running_returns_false_when_process_exited():
|
||||
"""_is_gateway_running returns False when process has terminated."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
runtime_module._runtime_state = state
|
||||
|
||||
mock_process = MagicMock()
|
||||
mock_process.poll.return_value = 1 # non-None = process has exited
|
||||
|
||||
state._gateway_process = mock_process
|
||||
assert runtime_module._is_gateway_running() is False
|
||||
|
||||
|
||||
def test_is_gateway_running_returns_true_when_process_running():
|
||||
"""_is_gateway_running returns True when process is alive."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
runtime_module._runtime_state = state
|
||||
|
||||
mock_process = MagicMock()
|
||||
mock_process.poll.return_value = None # None = still running
|
||||
|
||||
state._gateway_process = mock_process
|
||||
assert runtime_module._is_gateway_running() is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _stop_gateway helper tests
|
||||
# =============================================================================
|
||||
|
||||
def test_stop_gateway_returns_false_when_no_process():
|
||||
"""_stop_gateway returns False if no gateway process exists."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
new_state = runtime_module.RuntimeState()
|
||||
new_state._gateway_process = None
|
||||
runtime_module._runtime_state = new_state
|
||||
|
||||
result = runtime_module._stop_gateway()
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_stop_gateway_sets_process_to_none_after_stop():
|
||||
"""_stop_gateway sets _gateway_process to None after stopping."""
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
runtime_module._runtime_state = state
|
||||
|
||||
mock_process = MagicMock()
|
||||
mock_process.poll.return_value = None
|
||||
mock_process.wait.return_value = 0
|
||||
|
||||
state._gateway_process = mock_process
|
||||
|
||||
result = runtime_module._stop_gateway()
|
||||
|
||||
assert result is True
|
||||
assert state._gateway_process is None
|
||||
mock_process.terminate.assert_called_once()
|
||||
mock_process.wait.assert_called_once()
|
||||
|
||||
|
||||
def test_stop_gateway_kills_when_terminate_times_out():
|
||||
"""_stop_gateway kills the process if terminate times out."""
|
||||
import subprocess
|
||||
runtime_module.RuntimeState._instance = None
|
||||
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||
state = runtime_module.RuntimeState()
|
||||
runtime_module._runtime_state = state
|
||||
|
||||
mock_process = MagicMock()
|
||||
mock_process.poll.return_value = None
|
||||
mock_process.wait.side_effect = subprocess.TimeoutExpired("cmd", 5)
|
||||
mock_process.kill.return_value = None
|
||||
|
||||
state._gateway_process = mock_process
|
||||
|
||||
result = runtime_module._stop_gateway()
|
||||
|
||||
assert result is True
|
||||
assert state._gateway_process is None
|
||||
mock_process.kill.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _build_gateway_ws_url helper tests
|
||||
# =============================================================================
|
||||
|
||||
def test_build_gateway_ws_url_defaults_to_ws():
|
||||
from fastapi import Request
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers.get.side_effect = lambda k, d="": d
|
||||
mock_request.url.scheme = "http"
|
||||
mock_request.url.hostname = "localhost"
|
||||
|
||||
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
|
||||
assert url == "ws://localhost:8765"
|
||||
|
||||
|
||||
def test_build_gateway_ws_url_uses_wss_for_https():
|
||||
from fastapi import Request
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers.get.side_effect = lambda k, d="": d
|
||||
mock_request.url.scheme = "https"
|
||||
mock_request.url.hostname = "example.com"
|
||||
|
||||
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
|
||||
assert url == "wss://example.com:8765"
|
||||
|
||||
|
||||
def test_build_gateway_ws_url_respects_forwarded_proto():
|
||||
from fastapi import Request
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
def header_get(key, default=""):
|
||||
if key == "x-forwarded-proto":
|
||||
return "https"
|
||||
return default
|
||||
mock_request.headers.get.side_effect = header_get
|
||||
mock_request.url.scheme = "http"
|
||||
mock_request.url.hostname = "internal.example"
|
||||
|
||||
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
|
||||
assert url == "wss://internal.example:8765"
|
||||
|
||||
|
||||
def test_build_gateway_ws_url_respects_forwarded_host():
|
||||
from fastapi import Request
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.headers.get.side_effect = lambda k, d="": {
|
||||
"x-forwarded-host": "external.example.com"
|
||||
}.get(k, d)
|
||||
mock_request.url.scheme = "http"
|
||||
mock_request.url.hostname = "internal.example"
|
||||
|
||||
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
|
||||
assert url == "ws://external.example.com:8765"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _normalize_runtime_config_updates tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_validates_schedule_mode():
|
||||
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode="invalid")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
runtime_module._normalize_runtime_config_updates(req)
|
||||
assert "schedule_mode" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_validates_schedule_mode_values():
|
||||
for invalid in ["weekly", "monthly", "once"]:
|
||||
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode=invalid)
|
||||
with pytest.raises(HTTPException):
|
||||
runtime_module._normalize_runtime_config_updates(req)
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_accepts_daily_and_intraday():
|
||||
for valid in ["daily", "intraday", "DAILY", "IntraDay"]:
|
||||
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode=valid)
|
||||
result = runtime_module._normalize_runtime_config_updates(req)
|
||||
assert "schedule_mode" in result
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_validates_trigger_time_format():
|
||||
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time="25:99")
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
runtime_module._normalize_runtime_config_updates(req)
|
||||
assert "trigger_time" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_accepts_now_trigger_time():
|
||||
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time="now")
|
||||
result = runtime_module._normalize_runtime_config_updates(req)
|
||||
assert result["trigger_time"] == "now"
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_defaults_empty_trigger_time():
|
||||
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time=" ")
|
||||
result = runtime_module._normalize_runtime_config_updates(req)
|
||||
assert result["trigger_time"] == "09:30"
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_rejects_no_updates():
|
||||
req = runtime_module.UpdateRuntimeConfigRequest()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
runtime_module._normalize_runtime_config_updates(req)
|
||||
assert "no runtime config updates" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_coerces_types():
|
||||
req = runtime_module.UpdateRuntimeConfigRequest(
|
||||
schedule_mode="intraday",
|
||||
interval_minutes="30", # string from JSON
|
||||
initial_cash="50000.0", # string from JSON
|
||||
margin_requirement="0.25",
|
||||
)
|
||||
result = runtime_module._normalize_runtime_config_updates(req)
|
||||
assert result["schedule_mode"] == "intraday"
|
||||
assert result["interval_minutes"] == 30
|
||||
assert result["initial_cash"] == 50000.0
|
||||
assert result["margin_requirement"] == 0.25
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# register_runtime_manager / unregister_runtime_manager tests
|
||||
# =============================================================================
|
||||
|
||||
def test_register_runtime_manager_sets_module_and_singleton():
|
||||
runtime_module._runtime_state._initialized = True # prevent re-init
|
||||
fake_manager = object()
|
||||
|
||||
runtime_module.register_runtime_manager(fake_manager)
|
||||
|
||||
assert runtime_module.runtime_manager is fake_manager
|
||||
assert runtime_module._runtime_state.runtime_manager is fake_manager
|
||||
|
||||
|
||||
def test_unregister_runtime_manager_clears_module_and_singleton():
|
||||
runtime_module._runtime_state._initialized = True # prevent re-init
|
||||
runtime_module._runtime_state.runtime_manager = object()
|
||||
runtime_module.runtime_manager = object()
|
||||
|
||||
runtime_module.unregister_runtime_manager()
|
||||
|
||||
assert runtime_module.runtime_manager is None
|
||||
assert runtime_module._runtime_state.runtime_manager is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _generate_run_id tests
|
||||
# =============================================================================
|
||||
|
||||
def test_generate_run_id_returns_timestamp_format():
|
||||
run_id = runtime_module._generate_run_id()
|
||||
# Format: YYYYMMDD_HHMMSS - length is 15
|
||||
assert len(run_id) == 15
|
||||
assert run_id[8] == "_" # separator between date and time
|
||||
assert run_id[:8].isdigit() # YYYYMMDD
|
||||
assert run_id[9:].isdigit() # HHMMSS
|
||||
|
||||
@@ -1,47 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for the trading domain helpers."""
|
||||
"""Unit tests for data_tools functions (replaces the deleted trading_domain)."""
|
||||
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.tools.data_tools import (
|
||||
get_company_news,
|
||||
get_financial_metrics,
|
||||
get_insider_trades,
|
||||
get_market_cap,
|
||||
get_prices,
|
||||
search_line_items,
|
||||
)
|
||||
|
||||
|
||||
def test_trading_domain_payload_wrappers(monkeypatch):
|
||||
monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}])
|
||||
monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12)
|
||||
|
||||
assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == {
|
||||
"ticker": "AAPL",
|
||||
"prices": [{"close": 1}],
|
||||
}
|
||||
assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"financial_metrics": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"news": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"insider_trades": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"ticker": "AAPL",
|
||||
"end_date": "2026-03-16",
|
||||
"market_cap": 2.5e12,
|
||||
}
|
||||
|
||||
|
||||
def test_get_market_status_payload_uses_market_service(monkeypatch):
|
||||
class _FakeMarketService:
|
||||
def __init__(self, tickers):
|
||||
self.tickers = tickers
|
||||
|
||||
def get_market_status(self):
|
||||
return {"status": "open", "status_text": "Open"}
|
||||
|
||||
monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)
|
||||
|
||||
assert trading_domain.get_market_status_payload() == {
|
||||
"status": "open",
|
||||
"status_text": "Open",
|
||||
}
|
||||
def test_data_tools_functions_exist():
|
||||
"""Verify that all data_tools functions are importable and callable."""
|
||||
assert callable(get_prices)
|
||||
assert callable(get_financial_metrics)
|
||||
assert callable(get_company_news)
|
||||
assert callable(get_insider_trades)
|
||||
assert callable(get_market_cap)
|
||||
assert callable(search_line_items)
|
||||
|
||||
@@ -24,20 +24,17 @@ def test_trading_service_routes_are_exposed():
|
||||
|
||||
def test_trading_service_prices_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_prices_payload",
|
||||
lambda ticker, start_date, end_date: {
|
||||
"ticker": ticker,
|
||||
"prices": [
|
||||
Price(
|
||||
open=1.0,
|
||||
close=2.0,
|
||||
high=2.5,
|
||||
low=0.5,
|
||||
volume=100,
|
||||
time="2026-03-20",
|
||||
)
|
||||
],
|
||||
},
|
||||
"backend.apps.trading_service.get_prices",
|
||||
lambda ticker, start_date, end_date: [
|
||||
Price(
|
||||
open=1.0,
|
||||
close=2.0,
|
||||
high=2.5,
|
||||
low=0.5,
|
||||
volume=100,
|
||||
time="2026-03-20",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -57,56 +54,54 @@ def test_trading_service_prices_endpoint(monkeypatch):
|
||||
|
||||
def test_trading_service_financials_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_financials_payload",
|
||||
lambda ticker, end_date, period, limit: {
|
||||
"financial_metrics": [
|
||||
FinancialMetrics(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
market_cap=123.0,
|
||||
enterprise_value=None,
|
||||
price_to_earnings_ratio=None,
|
||||
price_to_book_ratio=None,
|
||||
price_to_sales_ratio=None,
|
||||
enterprise_value_to_ebitda_ratio=None,
|
||||
enterprise_value_to_revenue_ratio=None,
|
||||
free_cash_flow_yield=None,
|
||||
peg_ratio=None,
|
||||
gross_margin=None,
|
||||
operating_margin=None,
|
||||
net_margin=None,
|
||||
return_on_equity=None,
|
||||
return_on_assets=None,
|
||||
return_on_invested_capital=None,
|
||||
asset_turnover=None,
|
||||
inventory_turnover=None,
|
||||
receivables_turnover=None,
|
||||
days_sales_outstanding=None,
|
||||
operating_cycle=None,
|
||||
working_capital_turnover=None,
|
||||
current_ratio=None,
|
||||
quick_ratio=None,
|
||||
cash_ratio=None,
|
||||
operating_cash_flow_ratio=None,
|
||||
debt_to_equity=None,
|
||||
debt_to_assets=None,
|
||||
interest_coverage=None,
|
||||
revenue_growth=None,
|
||||
earnings_growth=None,
|
||||
book_value_growth=None,
|
||||
earnings_per_share_growth=None,
|
||||
free_cash_flow_growth=None,
|
||||
operating_income_growth=None,
|
||||
ebitda_growth=None,
|
||||
payout_ratio=None,
|
||||
earnings_per_share=None,
|
||||
book_value_per_share=None,
|
||||
free_cash_flow_per_share=None,
|
||||
)
|
||||
]
|
||||
},
|
||||
"backend.apps.trading_service.get_financial_metrics",
|
||||
lambda ticker, end_date, period, limit: [
|
||||
FinancialMetrics(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
market_cap=123.0,
|
||||
enterprise_value=None,
|
||||
price_to_earnings_ratio=None,
|
||||
price_to_book_ratio=None,
|
||||
price_to_sales_ratio=None,
|
||||
enterprise_value_to_ebitda_ratio=None,
|
||||
enterprise_value_to_revenue_ratio=None,
|
||||
free_cash_flow_yield=None,
|
||||
peg_ratio=None,
|
||||
gross_margin=None,
|
||||
operating_margin=None,
|
||||
net_margin=None,
|
||||
return_on_equity=None,
|
||||
return_on_assets=None,
|
||||
return_on_invested_capital=None,
|
||||
asset_turnover=None,
|
||||
inventory_turnover=None,
|
||||
receivables_turnover=None,
|
||||
days_sales_outstanding=None,
|
||||
operating_cycle=None,
|
||||
working_capital_turnover=None,
|
||||
current_ratio=None,
|
||||
quick_ratio=None,
|
||||
cash_ratio=None,
|
||||
operating_cash_flow_ratio=None,
|
||||
debt_to_equity=None,
|
||||
debt_to_assets=None,
|
||||
interest_coverage=None,
|
||||
revenue_growth=None,
|
||||
earnings_growth=None,
|
||||
book_value_growth=None,
|
||||
earnings_per_share_growth=None,
|
||||
free_cash_flow_growth=None,
|
||||
operating_income_growth=None,
|
||||
ebitda_growth=None,
|
||||
payout_ratio=None,
|
||||
earnings_per_share=None,
|
||||
book_value_per_share=None,
|
||||
free_cash_flow_per_share=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -121,26 +116,22 @@ def test_trading_service_financials_endpoint(monkeypatch):
|
||||
|
||||
def test_trading_service_news_and_insider_endpoints(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_news_payload",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: {
|
||||
"news": [
|
||||
CompanyNews(
|
||||
ticker=ticker,
|
||||
title="News title",
|
||||
source="polygon",
|
||||
url="https://example.com/news",
|
||||
date=end_date,
|
||||
)
|
||||
]
|
||||
},
|
||||
"backend.apps.trading_service.get_company_news",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: [
|
||||
CompanyNews(
|
||||
ticker=ticker,
|
||||
title="News title",
|
||||
source="polygon",
|
||||
url="https://example.com/news",
|
||||
date=end_date,
|
||||
)
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_insider_trades_payload",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: {
|
||||
"insider_trades": [
|
||||
InsiderTrade(ticker=ticker, filing_date=end_date)
|
||||
]
|
||||
},
|
||||
"backend.apps.trading_service.get_insider_trades",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: [
|
||||
InsiderTrade(ticker=ticker, filing_date=end_date)
|
||||
],
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -165,8 +156,8 @@ def test_trading_service_market_status_endpoint(monkeypatch):
|
||||
return {"status": "open", "status_text": "Open"}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_market_status_payload",
|
||||
lambda: _FakeMarketService().get_market_status(),
|
||||
"backend.apps.trading_service.MarketService",
|
||||
lambda tickers: _FakeMarketService(),
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -178,12 +169,8 @@ def test_trading_service_market_status_endpoint(monkeypatch):
|
||||
|
||||
def test_trading_service_market_cap_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_market_cap_payload",
|
||||
lambda ticker, end_date: {
|
||||
"ticker": ticker,
|
||||
"end_date": end_date,
|
||||
"market_cap": 3.5e12,
|
||||
},
|
||||
"backend.apps.trading_service.get_market_cap",
|
||||
lambda ticker, end_date: 3.5e12,
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -202,18 +189,16 @@ def test_trading_service_market_cap_endpoint(monkeypatch):
|
||||
|
||||
def test_trading_service_line_items_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_line_items_payload",
|
||||
lambda ticker, line_items, end_date, period, limit: {
|
||||
"search_results": [
|
||||
LineItem(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
free_cash_flow=123.0,
|
||||
)
|
||||
]
|
||||
},
|
||||
"backend.apps.trading_service.search_line_items",
|
||||
lambda ticker, line_items, end_date, period, limit: [
|
||||
LineItem(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
free_cash_flow=123.0,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
|
||||
2931
frontend/src/App.jsx
2931
frontend/src/App.jsx
File diff suppressed because it is too large
Load Diff
18
frontend/src/components/ChartTabs.jsx
Normal file
18
frontend/src/components/ChartTabs.jsx
Normal file
@@ -0,0 +1,18 @@
|
||||
import React from 'react';
|
||||
|
||||
export default function ChartTabs({
|
||||
chartTab,
|
||||
setChartTab,
|
||||
isLiveEnabled
|
||||
}) {
|
||||
return (
|
||||
<div className="chart-tabs-floating">
|
||||
<button
|
||||
className={`chart-tab ${chartTab === 'all' ? 'active' : ''}`}
|
||||
onClick={() => setChartTab('all')}
|
||||
>
|
||||
日线
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
293
frontend/src/components/HeaderRight.jsx
Normal file
293
frontend/src/components/HeaderRight.jsx
Normal file
@@ -0,0 +1,293 @@
|
||||
import React from 'react';
|
||||
import RuntimeSettingsPanel from './RuntimeSettingsPanel.jsx';
|
||||
|
||||
export default function HeaderRight({
|
||||
// Connection state
|
||||
isConnected,
|
||||
// Virtual time
|
||||
virtualTime,
|
||||
now,
|
||||
// Market & server
|
||||
marketStatus,
|
||||
marketStatusLabel,
|
||||
serverMode,
|
||||
// Labels
|
||||
runtimeSummaryLabel,
|
||||
livePriceSourceLabel,
|
||||
historicalPriceSourceLabel,
|
||||
// Settings state
|
||||
isRuntimeSettingsOpen,
|
||||
isRuntimeConfigSaving,
|
||||
isWatchlistSaving,
|
||||
runtimeConfigFeedback,
|
||||
watchlistFeedback,
|
||||
// Settings panel props
|
||||
scheduleModeDraft,
|
||||
intervalMinutesDraft,
|
||||
triggerTimeDraft,
|
||||
maxCommCyclesDraft,
|
||||
initialCashDraft,
|
||||
marginRequirementDraft,
|
||||
enableMemoryDraft,
|
||||
modeDraft,
|
||||
pollIntervalDraft,
|
||||
startDateDraft,
|
||||
endDateDraft,
|
||||
enableMockDraft,
|
||||
watchlistDraftSymbols,
|
||||
watchlistInputValue,
|
||||
watchlistSuggestions,
|
||||
// Callbacks
|
||||
onRuntimeSettingsToggle,
|
||||
onCloseSettings,
|
||||
onScheduleModeChange,
|
||||
onIntervalMinutesChange,
|
||||
onTriggerTimeChange,
|
||||
onMaxCommCyclesChange,
|
||||
onInitialCashChange,
|
||||
onMarginRequirementChange,
|
||||
onEnableMemoryChange,
|
||||
onModeChange,
|
||||
onPollIntervalChange,
|
||||
onStartDateChange,
|
||||
onEndDateChange,
|
||||
onEnableMockChange,
|
||||
onWatchlistInputChange,
|
||||
onWatchlistInputKeyDown,
|
||||
onWatchlistAdd,
|
||||
onWatchlistRemove,
|
||||
onWatchlistRestoreCurrent,
|
||||
onWatchlistRestoreDefault,
|
||||
onWatchlistSuggestionClick,
|
||||
onLaunchConfigSave,
|
||||
onRestoreDefaults,
|
||||
onManualTrigger,
|
||||
clientRef
|
||||
}) {
|
||||
return (
|
||||
<div className="header-right" style={{ display: 'flex', alignItems: 'center', gap: 24, marginLeft: 'auto', flexWrap: 'wrap', minWidth: 0 }}>
|
||||
{/* Mock Mode Indicator */}
|
||||
{virtualTime && (
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 6,
|
||||
padding: '4px 10px',
|
||||
borderRadius: 4,
|
||||
background: '#FF9800',
|
||||
border: '1px solid #FFB74D'
|
||||
}}>
|
||||
<span style={{
|
||||
fontSize: '11px',
|
||||
fontWeight: 600,
|
||||
color: '#FFFFFF',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
letterSpacing: '0.5px'
|
||||
}}>
|
||||
模拟模式
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Clock Display (only in Mock mode) */}
|
||||
{virtualTime && (
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 8
|
||||
}}>
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'flex-end',
|
||||
gap: 2,
|
||||
padding: '4px 12px',
|
||||
borderRadius: 4,
|
||||
background: '#1A237E',
|
||||
border: '1px solid #3F51B5'
|
||||
}}>
|
||||
<span style={{
|
||||
fontSize: '11px',
|
||||
color: '#999',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
textTransform: 'uppercase',
|
||||
letterSpacing: '0.5px'
|
||||
}}>
|
||||
虚拟时间
|
||||
</span>
|
||||
<span style={{
|
||||
fontSize: '14px',
|
||||
fontWeight: 700,
|
||||
color: '#FFFFFF',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
letterSpacing: '1px'
|
||||
}}>
|
||||
{now.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false })}
|
||||
</span>
|
||||
<span style={{
|
||||
fontSize: '10px',
|
||||
color: '#999',
|
||||
fontFamily: '"Courier New", monospace'
|
||||
}}>
|
||||
{now.toLocaleDateString('en-US', { month: 'short', day: 'numeric', year: 'numeric' })}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Fast Forward Button (only in Mock mode) */}
|
||||
<button
|
||||
onClick={() => {
|
||||
if (clientRef.current) {
|
||||
const success = clientRef.current.send({
|
||||
type: 'fast_forward_time',
|
||||
minutes: 30
|
||||
});
|
||||
if (!success) {
|
||||
console.error('Failed to send fast forward request');
|
||||
}
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
padding: '6px 12px',
|
||||
borderRadius: 4,
|
||||
background: '#3F51B5',
|
||||
border: '1px solid #5C6BC0',
|
||||
color: '#FFFFFF',
|
||||
fontSize: '12px',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
fontWeight: 600,
|
||||
cursor: 'pointer',
|
||||
transition: 'all 0.2s',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 4,
|
||||
textTransform: 'uppercase',
|
||||
letterSpacing: '0.5px'
|
||||
}}
|
||||
onMouseEnter={(e) => {
|
||||
e.target.style.background = '#5C6BC0';
|
||||
e.target.style.borderColor = '#7986CB';
|
||||
}}
|
||||
onMouseLeave={(e) => {
|
||||
e.target.style.background = '#3F51B5';
|
||||
e.target.style.borderColor = '#5C6BC0';
|
||||
}}
|
||||
title="快进30分钟 (Mock模式)"
|
||||
>
|
||||
+30min
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Unified Status Indicator */}
|
||||
<div className="header-status-inline">
|
||||
<span className={`status-dot ${isConnected ? 'live' : 'offline'}`} />
|
||||
<span className={`status-text ${isConnected ? 'live' : 'offline'}`}>
|
||||
{isConnected ? '在线' : '离线'}
|
||||
</span>
|
||||
{marketStatus && (
|
||||
<>
|
||||
<span className="status-sep">·</span>
|
||||
<span className={`market-text ${serverMode === 'backtest' ? 'backtest' : (marketStatus.status === 'open' ? 'open' : 'closed')}`}>
|
||||
{marketStatusLabel}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
{livePriceSourceLabel && (
|
||||
<>
|
||||
<span className="status-sep">·</span>
|
||||
<span className="market-text backtest">
|
||||
{livePriceSourceLabel}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
{historicalPriceSourceLabel && (
|
||||
<>
|
||||
<span className="status-sep">·</span>
|
||||
<span className="market-text backtest">
|
||||
{historicalPriceSourceLabel}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
{runtimeSummaryLabel && (
|
||||
<>
|
||||
<span className="status-sep">·</span>
|
||||
<span className="market-text backtest" title="当前运行配置">
|
||||
{runtimeSummaryLabel}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
<span className="status-sep">·</span>
|
||||
<span className="time-text">{now.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false })}</span>
|
||||
</div>
|
||||
|
||||
{serverMode !== 'backtest' && (
|
||||
<button
|
||||
onClick={onManualTrigger}
|
||||
disabled={!isConnected}
|
||||
style={{
|
||||
padding: '6px 12px',
|
||||
borderRadius: 4,
|
||||
background: isConnected ? '#111111' : '#8a8a8a',
|
||||
border: '1px solid #111111',
|
||||
color: '#FFFFFF',
|
||||
fontSize: '11px',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
fontWeight: 700,
|
||||
cursor: isConnected ? 'pointer' : 'not-allowed',
|
||||
letterSpacing: '0.4px',
|
||||
textTransform: 'uppercase'
|
||||
}}
|
||||
title="手动触发一轮分析与交易决策"
|
||||
>
|
||||
手动运行
|
||||
</button>
|
||||
)}
|
||||
|
||||
<RuntimeSettingsPanel
|
||||
showTrigger={false}
|
||||
isOpen={isRuntimeSettingsOpen}
|
||||
isConnected={isConnected}
|
||||
isSaving={isRuntimeConfigSaving || isWatchlistSaving}
|
||||
feedback={runtimeConfigFeedback || watchlistFeedback}
|
||||
scheduleMode={scheduleModeDraft}
|
||||
intervalMinutes={intervalMinutesDraft}
|
||||
triggerTime={triggerTimeDraft}
|
||||
maxCommCycles={maxCommCyclesDraft}
|
||||
initialCash={initialCashDraft}
|
||||
marginRequirement={marginRequirementDraft}
|
||||
enableMemory={enableMemoryDraft}
|
||||
mode={modeDraft}
|
||||
pollInterval={pollIntervalDraft}
|
||||
startDate={startDateDraft}
|
||||
endDate={endDateDraft}
|
||||
enableMock={enableMockDraft}
|
||||
watchlistSymbols={watchlistDraftSymbols}
|
||||
watchlistInputValue={watchlistInputValue}
|
||||
watchlistSuggestions={watchlistSuggestions}
|
||||
onToggle={onRuntimeSettingsToggle}
|
||||
onClose={onCloseSettings}
|
||||
onScheduleModeChange={onScheduleModeChange}
|
||||
onIntervalMinutesChange={onIntervalMinutesChange}
|
||||
onTriggerTimeChange={onTriggerTimeChange}
|
||||
onMaxCommCyclesChange={onMaxCommCyclesChange}
|
||||
onInitialCashChange={onInitialCashChange}
|
||||
onMarginRequirementChange={onMarginRequirementChange}
|
||||
onEnableMemoryChange={onEnableMemoryChange}
|
||||
onModeChange={onModeChange}
|
||||
onPollIntervalChange={onPollIntervalChange}
|
||||
onStartDateChange={onStartDateChange}
|
||||
onEndDateChange={onEndDateChange}
|
||||
onEnableMockChange={onEnableMockChange}
|
||||
onWatchlistInputChange={onWatchlistInputChange}
|
||||
onWatchlistInputKeyDown={onWatchlistInputKeyDown}
|
||||
onWatchlistAdd={onWatchlistAdd}
|
||||
onWatchlistRemove={onWatchlistRemove}
|
||||
onWatchlistRestoreCurrent={onWatchlistRestoreCurrent}
|
||||
onWatchlistRestoreDefault={onWatchlistRestoreDefault}
|
||||
onWatchlistSuggestionClick={onWatchlistSuggestionClick}
|
||||
onSave={onLaunchConfigSave}
|
||||
onRestoreDefaults={onRestoreDefaults}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
52
frontend/src/components/TickerBar.jsx
Normal file
52
frontend/src/components/TickerBar.jsx
Normal file
@@ -0,0 +1,52 @@
|
||||
import React from 'react';
|
||||
import StockLogo from './StockLogo';
|
||||
import { formatNumber, formatTickerPrice } from '../utils/formatters';
|
||||
|
||||
export default function TickerBar({
|
||||
displayTickers,
|
||||
rollingTickers,
|
||||
portfolioData,
|
||||
onTickerSelect
|
||||
}) {
|
||||
return (
|
||||
<div className="ticker-bar">
|
||||
<div className="ticker-track">
|
||||
{[0, 1].map((groupIdx) => (
|
||||
<div key={groupIdx} className="ticker-group">
|
||||
{displayTickers.map(ticker => (
|
||||
<div
|
||||
key={`${ticker.symbol}-${groupIdx}`}
|
||||
className="ticker-item"
|
||||
onClick={() => onTickerSelect && onTickerSelect(ticker.symbol)}
|
||||
style={{ cursor: onTickerSelect ? 'pointer' : 'default' }}
|
||||
>
|
||||
<StockLogo ticker={ticker.symbol} size={16} />
|
||||
<span className="ticker-symbol">{ticker.symbol}</span>
|
||||
<span className="ticker-price">
|
||||
<span className={`ticker-price-value ${rollingTickers[ticker.symbol] ? 'rolling' : ''}`}>
|
||||
{ticker.price !== null && ticker.price !== undefined
|
||||
? `$${formatTickerPrice(ticker.price)}`
|
||||
: '-'}
|
||||
</span>
|
||||
</span>
|
||||
<span className={`ticker-change ${
|
||||
ticker.change === null || ticker.change === undefined
|
||||
? ''
|
||||
: ticker.change >= 0 ? 'positive' : 'negative'
|
||||
}`}>
|
||||
{ticker.change !== null && ticker.change !== undefined
|
||||
? `${ticker.change >= 0 ? '+' : ''}${ticker.change.toFixed(2)}%`
|
||||
: '-'}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<div className="portfolio-value">
|
||||
<span className="portfolio-label">投资组合</span>
|
||||
<span className="portfolio-amount">${formatNumber(portfolioData.netValue)}</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
308
frontend/src/hooks/useAgentCallbacks.js
Normal file
308
frontend/src/hooks/useAgentCallbacks.js
Normal file
@@ -0,0 +1,308 @@
|
||||
import { useCallback, useEffect } from 'react';
|
||||
import { uploadAgentSkillZip } from '../services/runtimeApi';
|
||||
|
||||
/**
|
||||
* Extracts agent/skill-related callbacks from App.jsx into a single hook.
|
||||
*/
|
||||
export function useAgentCallbacks({
|
||||
clientRef,
|
||||
selectedSkillAgentId,
|
||||
selectedWorkspaceFile,
|
||||
workspaceDraftContent,
|
||||
agentProfilesByAgent,
|
||||
agentSkillsByAgent,
|
||||
workspaceFilesByAgent,
|
||||
AGENTS,
|
||||
setters
|
||||
}) {
|
||||
const {
|
||||
setIsAgentSkillsLoading,
|
||||
setAgentSkillsFeedback,
|
||||
setSkillDetailLoadingKey,
|
||||
setAgentSkillsSavingKey,
|
||||
setIsWorkspaceFileLoading,
|
||||
setWorkspaceFileSavingKey,
|
||||
setWorkspaceFileFeedback,
|
||||
setLocalSkillDraftsByKey,
|
||||
setAgentSkillsByAgent,
|
||||
setAgentProfilesByAgent,
|
||||
setSkillDetailsByName,
|
||||
setWorkspaceFilesByAgent,
|
||||
setSelectedSkillAgentId,
|
||||
setSelectedWorkspaceFile,
|
||||
setWorkspaceDraftContent
|
||||
} = setters;
|
||||
|
||||
const requestAgentSkills = useCallback((agentId) => {
|
||||
const normalized = typeof agentId === 'string' ? agentId.trim() : '';
|
||||
if (!normalized || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
setIsAgentSkillsLoading(true);
|
||||
setAgentSkillsFeedback(null);
|
||||
return clientRef.current.send({
|
||||
type: 'get_agent_skills',
|
||||
agent_id: normalized
|
||||
});
|
||||
}, [clientRef, setIsAgentSkillsLoading, setAgentSkillsFeedback]);
|
||||
|
||||
const requestAgentProfile = useCallback((agentId) => {
|
||||
const normalized = typeof agentId === 'string' ? agentId.trim() : '';
|
||||
if (!normalized || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
return clientRef.current.send({
|
||||
type: 'get_agent_profile',
|
||||
agent_id: normalized
|
||||
});
|
||||
}, [clientRef]);
|
||||
|
||||
const requestSkillDetail = useCallback((skillName) => {
|
||||
const normalized = typeof skillName === 'string' ? skillName.trim() : '';
|
||||
if (!normalized || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
const detailKey = `${selectedSkillAgentId}:${normalized}`;
|
||||
setSkillDetailLoadingKey(detailKey);
|
||||
return clientRef.current.send({
|
||||
type: 'get_skill_detail',
|
||||
agent_id: selectedSkillAgentId,
|
||||
skill_name: normalized
|
||||
});
|
||||
}, [clientRef, selectedSkillAgentId, setSkillDetailLoadingKey]);
|
||||
|
||||
const requestWorkspaceFile = useCallback((agentId, filename) => {
|
||||
const normalizedAgentId = typeof agentId === 'string' ? agentId.trim() : '';
|
||||
const normalizedFilename = typeof filename === 'string' ? filename.trim() : '';
|
||||
if (!normalizedAgentId || !normalizedFilename || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
setIsWorkspaceFileLoading(true);
|
||||
setWorkspaceFileFeedback(null);
|
||||
return clientRef.current.send({
|
||||
type: 'get_agent_workspace_file',
|
||||
agent_id: normalizedAgentId,
|
||||
filename: normalizedFilename
|
||||
});
|
||||
}, [clientRef, setIsWorkspaceFileLoading, setWorkspaceFileFeedback]);
|
||||
|
||||
const handleCreateLocalSkill = useCallback((skillName) => {
|
||||
const normalized = typeof skillName === 'string' ? skillName.trim() : '';
|
||||
if (!normalized) {
|
||||
setAgentSkillsFeedback({ type: 'error', text: '技能名称不能为空' });
|
||||
return;
|
||||
}
|
||||
if (!clientRef.current) {
|
||||
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||
return;
|
||||
}
|
||||
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${normalized}:create`);
|
||||
setAgentSkillsFeedback(null);
|
||||
const success = clientRef.current.send({
|
||||
type: 'create_agent_local_skill',
|
||||
agent_id: selectedSkillAgentId,
|
||||
skill_name: normalized
|
||||
});
|
||||
if (!success) {
|
||||
setAgentSkillsSavingKey(null);
|
||||
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||
}
|
||||
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||
|
||||
const handleLocalSkillDraftChange = useCallback((skillName, content) => {
|
||||
const detailKey = `${selectedSkillAgentId}:${skillName}`;
|
||||
setLocalSkillDraftsByKey((prev) => ({
|
||||
...prev,
|
||||
[detailKey]: content
|
||||
}));
|
||||
}, [selectedSkillAgentId, setLocalSkillDraftsByKey]);
|
||||
|
||||
const handleLocalSkillSave = useCallback((skillName) => {
|
||||
if (!clientRef.current) {
|
||||
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||
return;
|
||||
}
|
||||
const detailKey = `${selectedSkillAgentId}:${skillName}`;
|
||||
const content = setters.localSkillDraftsByKey[detailKey];
|
||||
if (typeof content !== 'string') {
|
||||
return;
|
||||
}
|
||||
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:content`);
|
||||
setAgentSkillsFeedback(null);
|
||||
const success = clientRef.current.send({
|
||||
type: 'update_agent_local_skill',
|
||||
agent_id: selectedSkillAgentId,
|
||||
skill_name: skillName,
|
||||
content
|
||||
});
|
||||
if (!success) {
|
||||
setAgentSkillsSavingKey(null);
|
||||
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||
}
|
||||
}, [clientRef, selectedSkillAgentId, setters.localSkillDraftsByKey, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||
|
||||
const handleLocalSkillDelete = useCallback((skillName) => {
|
||||
if (!clientRef.current) {
|
||||
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||
return;
|
||||
}
|
||||
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:delete`);
|
||||
setAgentSkillsFeedback(null);
|
||||
const success = clientRef.current.send({
|
||||
type: 'delete_agent_local_skill',
|
||||
agent_id: selectedSkillAgentId,
|
||||
skill_name: skillName
|
||||
});
|
||||
if (!success) {
|
||||
setAgentSkillsSavingKey(null);
|
||||
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||
}
|
||||
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||
|
||||
const handleRemoveSharedSkill = useCallback((skillName) => {
|
||||
if (!clientRef.current) {
|
||||
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||
return;
|
||||
}
|
||||
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:remove`);
|
||||
setAgentSkillsFeedback(null);
|
||||
const success = clientRef.current.send({
|
||||
type: 'remove_agent_skill',
|
||||
agent_id: selectedSkillAgentId,
|
||||
skill_name: skillName
|
||||
});
|
||||
if (!success) {
|
||||
setAgentSkillsSavingKey(null);
|
||||
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||
}
|
||||
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||
|
||||
const handleAgentSkillToggle = useCallback((skillName, enabled) => {
|
||||
if (!clientRef.current) {
|
||||
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||
return;
|
||||
}
|
||||
|
||||
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}`);
|
||||
setAgentSkillsFeedback(null);
|
||||
const success = clientRef.current.send({
|
||||
type: 'update_agent_skill',
|
||||
agent_id: selectedSkillAgentId,
|
||||
skill_name: skillName,
|
||||
enabled
|
||||
});
|
||||
|
||||
if (!success) {
|
||||
setAgentSkillsSavingKey(null);
|
||||
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||
}
|
||||
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||
|
||||
const handleSkillAgentChange = useCallback((agentId) => {
|
||||
setSelectedSkillAgentId(agentId);
|
||||
requestAgentProfile(agentId);
|
||||
requestAgentSkills(agentId);
|
||||
requestWorkspaceFile(agentId, selectedWorkspaceFile);
|
||||
}, [requestAgentProfile, requestAgentSkills, requestWorkspaceFile, selectedWorkspaceFile, setSelectedSkillAgentId]);
|
||||
|
||||
const handleWorkspaceFileChange = useCallback((filename) => {
|
||||
setSelectedWorkspaceFile(filename);
|
||||
requestWorkspaceFile(selectedSkillAgentId, filename);
|
||||
}, [requestWorkspaceFile, selectedSkillAgentId, setSelectedWorkspaceFile]);
|
||||
|
||||
const handleWorkspaceFileSave = useCallback(() => {
|
||||
if (!clientRef.current) {
|
||||
setWorkspaceFileFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||
return;
|
||||
}
|
||||
const key = `${selectedSkillAgentId}:${selectedWorkspaceFile}`;
|
||||
setWorkspaceFileSavingKey(key);
|
||||
setWorkspaceFileFeedback(null);
|
||||
const success = clientRef.current.send({
|
||||
type: 'update_agent_workspace_file',
|
||||
agent_id: selectedSkillAgentId,
|
||||
filename: selectedWorkspaceFile,
|
||||
content: workspaceDraftContent
|
||||
});
|
||||
if (!success) {
|
||||
setWorkspaceFileSavingKey(null);
|
||||
setWorkspaceFileFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||
}
|
||||
}, [clientRef, selectedSkillAgentId, selectedWorkspaceFile, workspaceDraftContent, setWorkspaceFileSavingKey, setWorkspaceFileFeedback]);
|
||||
|
||||
const handleUploadExternalSkill = useCallback(async (file) => {
|
||||
if (!(file instanceof File)) {
|
||||
setAgentSkillsFeedback({ type: 'error', text: '请选择 zip 文件后再上传' });
|
||||
return;
|
||||
}
|
||||
if (!selectedSkillAgentId) {
|
||||
setAgentSkillsFeedback({ type: 'error', text: '未选择目标 Agent' });
|
||||
return;
|
||||
}
|
||||
setAgentSkillsSavingKey(`${selectedSkillAgentId}:__upload__`);
|
||||
setAgentSkillsFeedback(null);
|
||||
try {
|
||||
const result = await uploadAgentSkillZip({
|
||||
agentId: selectedSkillAgentId,
|
||||
file,
|
||||
activate: true
|
||||
});
|
||||
setAgentSkillsFeedback({
|
||||
type: 'success',
|
||||
text: `已上传并安装技能 ${result.skill_name || ''}`.trim()
|
||||
});
|
||||
requestAgentSkills(selectedSkillAgentId);
|
||||
} catch (error) {
|
||||
setAgentSkillsFeedback({
|
||||
type: 'error',
|
||||
text: `上传失败: ${error.message || '未知错误'}`
|
||||
});
|
||||
} finally {
|
||||
setAgentSkillsSavingKey(null);
|
||||
}
|
||||
}, [selectedSkillAgentId, requestAgentSkills, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||
|
||||
// Sync workspace draft content when selected content changes
|
||||
useEffect(() => {
|
||||
const selectedWorkspaceContent = workspaceFilesByAgent[selectedSkillAgentId]?.[selectedWorkspaceFile] || '';
|
||||
setWorkspaceDraftContent(selectedWorkspaceContent);
|
||||
}, [selectedWorkspaceFile, selectedSkillAgentId, workspaceFilesByAgent, setWorkspaceDraftContent]);
|
||||
|
||||
// Load agent profiles and skills when view changes
|
||||
const currentView = setters.currentView;
|
||||
const isConnected = setters.isConnected;
|
||||
|
||||
useEffect(() => {
|
||||
if (currentView !== 'traders' || !isConnected) {
|
||||
return;
|
||||
}
|
||||
AGENTS.forEach((agent) => {
|
||||
if (!agentProfilesByAgent[agent.id]) {
|
||||
requestAgentProfile(agent.id);
|
||||
}
|
||||
if (!agentSkillsByAgent[agent.id]) {
|
||||
requestAgentSkills(agent.id);
|
||||
}
|
||||
if (!workspaceFilesByAgent[agent.id]?.['MEMORY.md']) {
|
||||
requestWorkspaceFile(agent.id, 'MEMORY.md');
|
||||
}
|
||||
});
|
||||
}, [agentProfilesByAgent, agentSkillsByAgent, currentView, isConnected, requestAgentProfile, requestAgentSkills, requestWorkspaceFile, workspaceFilesByAgent, AGENTS]);
|
||||
|
||||
return {
|
||||
requestAgentSkills,
|
||||
requestAgentProfile,
|
||||
requestSkillDetail,
|
||||
requestWorkspaceFile,
|
||||
handleCreateLocalSkill,
|
||||
handleLocalSkillDraftChange,
|
||||
handleLocalSkillSave,
|
||||
handleLocalSkillDelete,
|
||||
handleRemoveSharedSkill,
|
||||
handleAgentSkillToggle,
|
||||
handleSkillAgentChange,
|
||||
handleWorkspaceFileChange,
|
||||
handleWorkspaceFileSave,
|
||||
handleUploadExternalSkill
|
||||
};
|
||||
}
|
||||
257
frontend/src/hooks/useRuntimeCallbacks.js
Normal file
257
frontend/src/hooks/useRuntimeCallbacks.js
Normal file
@@ -0,0 +1,257 @@
|
||||
import { useCallback } from 'react';
|
||||
import { startRuntime } from '../services/runtimeApi';
|
||||
|
||||
/**
|
||||
* Extracts runtime config callbacks from App.jsx into a single hook.
|
||||
*/
|
||||
export function useRuntimeCallbacks({
|
||||
clientRef,
|
||||
addSystemMessage,
|
||||
parseWatchlistInput,
|
||||
setters
|
||||
}) {
|
||||
const {
|
||||
setScheduleModeDraft,
|
||||
setIntervalMinutesDraft,
|
||||
setTriggerTimeDraft,
|
||||
setMaxCommCyclesDraft,
|
||||
setInitialCashDraft,
|
||||
setMarginRequirementDraft,
|
||||
setEnableMemoryDraft,
|
||||
setModeDraft,
|
||||
setPollIntervalDraft,
|
||||
setStartDateDraft,
|
||||
setEndDateDraft,
|
||||
setEnableMockDraft,
|
||||
setRuntimeConfigFeedback,
|
||||
setIsRuntimeConfigSaving,
|
||||
setIsWatchlistSaving,
|
||||
setIsRuntimeSettingsOpen,
|
||||
watchlistDraftSymbols,
|
||||
watchlistInputValue,
|
||||
scheduleModeDraft,
|
||||
intervalMinutesDraft,
|
||||
maxCommCyclesDraft,
|
||||
initialCashDraft,
|
||||
marginRequirementDraft,
|
||||
enableMemoryDraft,
|
||||
modeDraft,
|
||||
pollIntervalDraft,
|
||||
startDateDraft,
|
||||
endDateDraft,
|
||||
enableMockDraft
|
||||
} = setters;
|
||||
|
||||
const handleRuntimeConfigSave = useCallback(() => {
|
||||
if (!clientRef.current) {
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||
return;
|
||||
}
|
||||
|
||||
const interval = Number(intervalMinutesDraft);
|
||||
const maxCommCycles = Number(maxCommCyclesDraft);
|
||||
if (!Number.isInteger(interval) || interval <= 0) {
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '间隔必须是正整数分钟' });
|
||||
return;
|
||||
}
|
||||
if (!Number.isInteger(maxCommCycles) || maxCommCycles <= 0) {
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '讨论轮数必须是正整数' });
|
||||
return;
|
||||
}
|
||||
|
||||
setIsRuntimeConfigSaving(true);
|
||||
setRuntimeConfigFeedback(null);
|
||||
const success = clientRef.current.send({
|
||||
type: 'update_runtime_config',
|
||||
schedule_mode: scheduleModeDraft,
|
||||
interval_minutes: interval,
|
||||
trigger_time: triggerTimeDraft,
|
||||
max_comm_cycles: maxCommCycles,
|
||||
initial_cash: Number(initialCashDraft),
|
||||
margin_requirement: Number(marginRequirementDraft),
|
||||
enable_memory: Boolean(enableMemoryDraft)
|
||||
});
|
||||
|
||||
if (!success) {
|
||||
setIsRuntimeConfigSaving(false);
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||
}
|
||||
}, [
|
||||
clientRef,
|
||||
intervalMinutesDraft,
|
||||
maxCommCyclesDraft,
|
||||
scheduleModeDraft,
|
||||
triggerTimeDraft,
|
||||
initialCashDraft,
|
||||
marginRequirementDraft,
|
||||
enableMemoryDraft,
|
||||
setIsRuntimeConfigSaving,
|
||||
setRuntimeConfigFeedback
|
||||
]);
|
||||
|
||||
const handleLaunchConfigSave = useCallback(async () => {
|
||||
const pendingTickers = parseWatchlistInput(watchlistInputValue);
|
||||
const nextTickers = Array.from(new Set([...watchlistDraftSymbols, ...pendingTickers]));
|
||||
if (nextTickers.length === 0) {
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '至少输入 1 个有效股票代码' });
|
||||
return;
|
||||
}
|
||||
|
||||
const interval = Number(intervalMinutesDraft);
|
||||
const maxCommCycles = Number(maxCommCyclesDraft);
|
||||
const initialCash = Number(initialCashDraft);
|
||||
const marginRequirement = Number(marginRequirementDraft);
|
||||
if (!Number.isInteger(interval) || interval <= 0) {
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '间隔必须是正整数分钟' });
|
||||
return;
|
||||
}
|
||||
if (!Number.isInteger(maxCommCycles) || maxCommCycles <= 0) {
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '讨论轮数必须是正整数' });
|
||||
return;
|
||||
}
|
||||
if (!Number.isFinite(initialCash) || initialCash <= 0) {
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '初始资金必须是正数' });
|
||||
return;
|
||||
}
|
||||
if (!Number.isFinite(marginRequirement) || marginRequirement < 0) {
|
||||
setRuntimeConfigFeedback({ type: 'error', text: '保证金要求不能为负数' });
|
||||
return;
|
||||
}
|
||||
|
||||
setIsRuntimeConfigSaving(true);
|
||||
setIsWatchlistSaving(true);
|
||||
setRuntimeConfigFeedback(null);
|
||||
setters.setWatchlistFeedback(null);
|
||||
setters.setWatchlistDraftSymbols(nextTickers);
|
||||
setters.setWatchlistInputValue('');
|
||||
|
||||
try {
|
||||
const result = await startRuntime({
|
||||
tickers: nextTickers,
|
||||
schedule_mode: scheduleModeDraft,
|
||||
interval_minutes: interval,
|
||||
trigger_time: triggerTimeDraft,
|
||||
max_comm_cycles: maxCommCycles,
|
||||
initial_cash: initialCash,
|
||||
margin_requirement: marginRequirement,
|
||||
enable_memory: Boolean(enableMemoryDraft),
|
||||
mode: modeDraft || 'live',
|
||||
poll_interval: Number(pollIntervalDraft) || 10,
|
||||
start_date: startDateDraft || null,
|
||||
end_date: endDateDraft || null,
|
||||
enable_mock: Boolean(enableMockDraft)
|
||||
});
|
||||
|
||||
setIsRuntimeConfigSaving(false);
|
||||
setIsWatchlistSaving(false);
|
||||
setIsRuntimeSettingsOpen(false);
|
||||
setRuntimeConfigFeedback({
|
||||
type: 'success',
|
||||
text: `任务已启动: ${result.run_id}`
|
||||
});
|
||||
addSystemMessage(`新任务已启动: ${result.run_id}`);
|
||||
} catch (error) {
|
||||
setIsRuntimeConfigSaving(false);
|
||||
setIsWatchlistSaving(false);
|
||||
setRuntimeConfigFeedback({
|
||||
type: 'error',
|
||||
text: `启动失败: ${error.message}`
|
||||
});
|
||||
}
|
||||
}, [
|
||||
parseWatchlistInput,
|
||||
watchlistInputValue,
|
||||
watchlistDraftSymbols,
|
||||
intervalMinutesDraft,
|
||||
maxCommCyclesDraft,
|
||||
initialCashDraft,
|
||||
marginRequirementDraft,
|
||||
enableMemoryDraft,
|
||||
scheduleModeDraft,
|
||||
triggerTimeDraft,
|
||||
modeDraft,
|
||||
pollIntervalDraft,
|
||||
startDateDraft,
|
||||
endDateDraft,
|
||||
enableMockDraft,
|
||||
setters,
|
||||
setIsRuntimeConfigSaving,
|
||||
setIsWatchlistSaving,
|
||||
setRuntimeConfigFeedback,
|
||||
setIsRuntimeSettingsOpen,
|
||||
addSystemMessage
|
||||
]);
|
||||
|
||||
const handleRuntimeDefaultsRestore = useCallback(() => {
|
||||
setScheduleModeDraft('daily');
|
||||
setIntervalMinutesDraft('60');
|
||||
setTriggerTimeDraft('09:30');
|
||||
setMaxCommCyclesDraft('2');
|
||||
setInitialCashDraft('100000');
|
||||
setMarginRequirementDraft('0');
|
||||
setEnableMemoryDraft(false);
|
||||
setModeDraft('live');
|
||||
setPollIntervalDraft('10');
|
||||
setStartDateDraft('');
|
||||
setEndDateDraft('');
|
||||
setEnableMockDraft(false);
|
||||
setRuntimeConfigFeedback(null);
|
||||
}, [
|
||||
setScheduleModeDraft,
|
||||
setIntervalMinutesDraft,
|
||||
setTriggerTimeDraft,
|
||||
setMaxCommCyclesDraft,
|
||||
setInitialCashDraft,
|
||||
setMarginRequirementDraft,
|
||||
setEnableMemoryDraft,
|
||||
setModeDraft,
|
||||
setPollIntervalDraft,
|
||||
setStartDateDraft,
|
||||
setEndDateDraft,
|
||||
setEnableMockDraft,
|
||||
setRuntimeConfigFeedback
|
||||
]);
|
||||
|
||||
const handleRuntimeSettingsToggle = useCallback(() => {
|
||||
setRuntimeConfigFeedback(null);
|
||||
setters.setAgentSkillsFeedback(null);
|
||||
setters.setWorkspaceFileFeedback(null);
|
||||
setIsRuntimeSettingsOpen((prev) => {
|
||||
const nextOpen = !prev;
|
||||
if (nextOpen) {
|
||||
// Initialize watchlist draft when opening settings
|
||||
setters.setWatchlistDraftSymbols(settlers.runtimeWatchlistSymbols);
|
||||
setters.setWatchlistInputValue('');
|
||||
setters.setWatchlistFeedback(null);
|
||||
}
|
||||
return nextOpen;
|
||||
});
|
||||
setters.setIsWatchlistPanelOpen(false);
|
||||
}, [setRuntimeConfigFeedback, setters, setIsRuntimeSettingsOpen]);
|
||||
|
||||
const handleManualTrigger = useCallback(() => {
|
||||
if (!clientRef.current) {
|
||||
addSystemMessage('连接未就绪,无法手动触发');
|
||||
return;
|
||||
}
|
||||
|
||||
const success = clientRef.current.send({
|
||||
type: 'trigger_strategy'
|
||||
});
|
||||
|
||||
if (!success) {
|
||||
addSystemMessage('手动触发发送失败,请检查连接状态');
|
||||
return;
|
||||
}
|
||||
|
||||
addSystemMessage('已发送手动触发请求');
|
||||
}, [clientRef, addSystemMessage]);
|
||||
|
||||
return {
|
||||
handleRuntimeConfigSave,
|
||||
handleLaunchConfigSave,
|
||||
handleRuntimeDefaultsRestore,
|
||||
handleRuntimeSettingsToggle,
|
||||
handleManualTrigger
|
||||
};
|
||||
}
|
||||
584
frontend/src/hooks/useStockRequestCallbacks.js
Normal file
584
frontend/src/hooks/useStockRequestCallbacks.js
Normal file
@@ -0,0 +1,584 @@
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
fetchNewsCategoriesDirect,
|
||||
fetchNewsForDateDirect,
|
||||
fetchRangeExplainDirect,
|
||||
fetchSimilarDaysDirect,
|
||||
fetchStockStoryDirect,
|
||||
hasDirectNewsService
|
||||
} from '../services/newsApi';
|
||||
import {
|
||||
fetchInsiderTradesDirect,
|
||||
fetchStockHistoryDirect,
|
||||
hasDirectTradingService
|
||||
} from '../services/tradingApi';
|
||||
|
||||
/**
|
||||
* Extracts all requestStock* callbacks from App.jsx into a single hook.
|
||||
*/
|
||||
export function useStockRequestCallbacks({
|
||||
clientRef,
|
||||
currentDate,
|
||||
requestedStockHistoryRef,
|
||||
setters,
|
||||
apiHelpers
|
||||
}) {
|
||||
const {
|
||||
setOhlcHistoryByTicker,
|
||||
setHistorySourceByTicker,
|
||||
setExplainEventsByTicker,
|
||||
setNewsByTicker,
|
||||
setInsiderTradesByTicker,
|
||||
setTechnicalIndicatorsByTicker,
|
||||
setPriceHistoryByTicker
|
||||
} = setters;
|
||||
|
||||
const {
|
||||
hasDirectTradingService: _hasDirectTradingService,
|
||||
fetchStockHistoryDirect: _fetchStockHistoryDirect,
|
||||
hasDirectNewsService: _hasDirectNewsService,
|
||||
fetchNewsForDateDirect: _fetchNewsForDateDirect,
|
||||
fetchNewsCategoriesDirect: _fetchNewsCategoriesDirect,
|
||||
fetchInsiderTradesDirect: _fetchInsiderTradesDirect,
|
||||
fetchRangeExplainDirect: _fetchRangeExplainDirect,
|
||||
fetchStockStoryDirect: _fetchStockStoryDirect,
|
||||
fetchSimilarDaysDirect: _fetchSimilarDaysDirect
|
||||
} = apiHelpers;
|
||||
|
||||
const buildTickersFromSymbols = useCallback((symbols, previousTickers = []) => {
|
||||
if (!Array.isArray(symbols) || symbols.length === 0) {
|
||||
return previousTickers;
|
||||
}
|
||||
|
||||
return symbols
|
||||
.filter((symbol) => typeof symbol === 'string' && symbol.trim())
|
||||
.map((symbol) => {
|
||||
const normalized = symbol.trim().toUpperCase();
|
||||
const existing = previousTickers.find((ticker) => ticker.symbol === normalized);
|
||||
return existing || {
|
||||
symbol: normalized,
|
||||
price: null,
|
||||
change: null
|
||||
};
|
||||
});
|
||||
}, []);
|
||||
|
||||
const normalizePriceHistory = useCallback((payload) => {
|
||||
if (!payload || typeof payload !== 'object') {
|
||||
return {};
|
||||
}
|
||||
|
||||
const normalized = {};
|
||||
Object.entries(payload).forEach(([symbol, points]) => {
|
||||
const ticker = String(symbol || '').trim().toUpperCase();
|
||||
if (!ticker || !Array.isArray(points)) {
|
||||
return;
|
||||
}
|
||||
|
||||
normalized[ticker] = points
|
||||
.map((point) => {
|
||||
if (Array.isArray(point) && point.length >= 2) {
|
||||
const [label, value] = point;
|
||||
const price = Number(value);
|
||||
if (!label || !Number.isFinite(price)) return null;
|
||||
return {
|
||||
timestamp: String(label),
|
||||
label: String(label),
|
||||
price
|
||||
};
|
||||
}
|
||||
|
||||
if (point && typeof point === 'object') {
|
||||
const rawTimestamp = point.timestamp ?? point.t ?? point.date ?? point.label;
|
||||
const price = Number(point.price ?? point.v ?? point.value ?? point.close);
|
||||
if (!rawTimestamp || !Number.isFinite(price)) return null;
|
||||
return {
|
||||
timestamp: String(rawTimestamp),
|
||||
label: String(rawTimestamp),
|
||||
price
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
})
|
||||
.filter(Boolean)
|
||||
.slice(-120);
|
||||
});
|
||||
|
||||
return normalized;
|
||||
}, []);
|
||||
|
||||
const requestStockHistory = useCallback((symbol, { force = false } = {}) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!force && requestedStockHistoryRef.current.has(normalized)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const endDate = currentDate
|
||||
? String(currentDate).slice(0, 10)
|
||||
: new Date().toISOString().slice(0, 10);
|
||||
const end = new Date(`${endDate}T00:00:00`);
|
||||
const start = new Date(end);
|
||||
start.setDate(start.getDate() - 120);
|
||||
const startDate = start.toISOString().slice(0, 10);
|
||||
|
||||
if (_hasDirectTradingService()) {
|
||||
void _fetchStockHistoryDirect(normalized, startDate, endDate)
|
||||
.then((payload) => {
|
||||
const prices = Array.isArray(payload?.prices) ? payload.prices : [];
|
||||
setOhlcHistoryByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: prices
|
||||
}));
|
||||
setPriceHistoryByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: prices
|
||||
.map((point) => {
|
||||
const price = Number(point?.close);
|
||||
const timestamp = point?.time;
|
||||
if (!timestamp || !Number.isFinite(price)) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
timestamp: String(timestamp),
|
||||
label: String(timestamp),
|
||||
price
|
||||
};
|
||||
})
|
||||
.filter(Boolean)
|
||||
}));
|
||||
setHistorySourceByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: 'trading_service'
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Direct stock-history fetch failed, falling back to websocket:', error);
|
||||
if (clientRef.current) {
|
||||
const success = clientRef.current.send({
|
||||
type: 'get_stock_history',
|
||||
ticker: normalized,
|
||||
lookback_days: 120
|
||||
});
|
||||
if (success) {
|
||||
requestedStockHistoryRef.current.add(normalized);
|
||||
}
|
||||
}
|
||||
});
|
||||
requestedStockHistoryRef.current.add(normalized);
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const success = clientRef.current.send({
|
||||
type: 'get_stock_history',
|
||||
ticker: normalized,
|
||||
lookback_days: 120
|
||||
});
|
||||
|
||||
if (success) {
|
||||
requestedStockHistoryRef.current.add(normalized);
|
||||
}
|
||||
|
||||
return success;
|
||||
}, [currentDate, _hasDirectTradingService, _fetchStockHistoryDirect, clientRef, requestedStockHistoryRef, setOhlcHistoryByTicker, setPriceHistoryByTicker, setHistorySourceByTicker]);
|
||||
|
||||
const requestStockExplainEvents = useCallback((symbol) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_explain_events',
|
||||
ticker: normalized
|
||||
});
|
||||
}, [clientRef]);
|
||||
|
||||
const requestStockNews = useCallback((symbol) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_news',
|
||||
ticker: normalized,
|
||||
lookback_days: 45,
|
||||
limit: 12
|
||||
});
|
||||
}, [clientRef]);
|
||||
|
||||
const requestStockNewsForDate = useCallback((symbol, date) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized || !date) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (_hasDirectNewsService()) {
|
||||
void _fetchNewsForDateDirect(normalized, date, 20)
|
||||
.then((payload) => {
|
||||
const targetDate = typeof payload?.date === 'string' ? payload.date.trim() : date;
|
||||
const news = Array.isArray(payload?.news) ? payload.news : [];
|
||||
const freshness = payload?.freshness || null;
|
||||
setNewsByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: {
|
||||
...(prev[normalized] || {}),
|
||||
byDate: {
|
||||
...((prev[normalized] && prev[normalized].byDate) || {}),
|
||||
[targetDate]: news
|
||||
},
|
||||
byDateFreshness: {
|
||||
...((prev[normalized] && prev[normalized].byDateFreshness) || {}),
|
||||
[targetDate]: freshness
|
||||
}
|
||||
}
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Direct news-for-date fetch failed, falling back to websocket:', error);
|
||||
if (clientRef.current) {
|
||||
clientRef.current.send({
|
||||
type: 'get_stock_news_for_date',
|
||||
ticker: normalized,
|
||||
date,
|
||||
limit: 20
|
||||
});
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_news_for_date',
|
||||
ticker: normalized,
|
||||
date,
|
||||
limit: 20
|
||||
});
|
||||
}, [clientRef, _hasDirectNewsService, _fetchNewsForDateDirect, setNewsByTicker]);
|
||||
|
||||
const requestStockNewsTimeline = useCallback((symbol) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_news_timeline',
|
||||
ticker: normalized,
|
||||
lookback_days: 90
|
||||
});
|
||||
}, [clientRef]);
|
||||
|
||||
const requestStockNewsCategories = useCallback((symbol) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const endDate = currentDate
|
||||
? String(currentDate).slice(0, 10)
|
||||
: new Date().toISOString().slice(0, 10);
|
||||
const end = new Date(`${endDate}T00:00:00`);
|
||||
const start = new Date(end);
|
||||
start.setDate(start.getDate() - 90);
|
||||
const startDate = start.toISOString().slice(0, 10);
|
||||
|
||||
if (_hasDirectNewsService()) {
|
||||
void _fetchNewsCategoriesDirect(normalized, startDate, endDate, 200)
|
||||
.then((payload) => {
|
||||
const freshness = payload?.freshness || null;
|
||||
setNewsByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: {
|
||||
...(prev[normalized] || {}),
|
||||
categories: payload?.categories || {},
|
||||
categoriesStartDate: startDate,
|
||||
categoriesEndDate: endDate,
|
||||
categoriesFreshness: freshness
|
||||
}
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Direct news-categories fetch failed, falling back to websocket:', error);
|
||||
if (clientRef.current) {
|
||||
clientRef.current.send({
|
||||
type: 'get_stock_news_categories',
|
||||
ticker: normalized,
|
||||
lookback_days: 90
|
||||
});
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_news_categories',
|
||||
ticker: normalized,
|
||||
lookback_days: 90
|
||||
});
|
||||
}, [currentDate, clientRef, _hasDirectNewsService, _fetchNewsCategoriesDirect, setNewsByTicker]);
|
||||
|
||||
const requestStockInsiderTrades = useCallback((symbol, startDate = null, endDate = null) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (_hasDirectTradingService()) {
|
||||
void _fetchInsiderTradesDirect(normalized, startDate, endDate, 50)
|
||||
.then((payload) => {
|
||||
const rows = Array.isArray(payload?.insider_trades) ? payload.insider_trades : [];
|
||||
setInsiderTradesByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: {
|
||||
ticker: normalized,
|
||||
startDate: startDate || null,
|
||||
endDate: endDate || null,
|
||||
trades: rows
|
||||
}
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Direct insider-trades fetch failed, falling back to websocket:', error);
|
||||
if (clientRef.current) {
|
||||
clientRef.current.send({
|
||||
type: 'get_stock_insider_trades',
|
||||
ticker: normalized,
|
||||
start_date: startDate,
|
||||
end_date: endDate,
|
||||
limit: 50
|
||||
});
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_insider_trades',
|
||||
ticker: normalized,
|
||||
start_date: startDate,
|
||||
end_date: endDate,
|
||||
limit: 50
|
||||
});
|
||||
}, [clientRef, _hasDirectTradingService, _fetchInsiderTradesDirect, setInsiderTradesByTicker]);
|
||||
|
||||
const requestStockTechnicalIndicators = useCallback((symbol) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_technical_indicators',
|
||||
ticker: normalized
|
||||
});
|
||||
}, [clientRef]);
|
||||
|
||||
const requestStockRangeExplain = useCallback((symbol, startDate, endDate, articleIds = []) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized || !startDate || !endDate) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (_hasDirectNewsService()) {
|
||||
void _fetchRangeExplainDirect(normalized, startDate, endDate, articleIds)
|
||||
.then((payload) => {
|
||||
const result = payload?.result && typeof payload.result === 'object' ? payload.result : null;
|
||||
const freshness = payload?.freshness || null;
|
||||
if (!result?.start_date || !result?.end_date) {
|
||||
return;
|
||||
}
|
||||
const cacheKey = `${result.start_date}:${result.end_date}`;
|
||||
setNewsByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: {
|
||||
...(prev[normalized] || {}),
|
||||
rangeExplainCache: {
|
||||
...((prev[normalized] && prev[normalized].rangeExplainCache) || {}),
|
||||
[cacheKey]: {
|
||||
...result,
|
||||
freshness
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Direct range explain fetch failed, falling back to websocket:', error);
|
||||
if (clientRef.current) {
|
||||
clientRef.current.send({
|
||||
type: 'get_stock_range_explain',
|
||||
ticker: normalized,
|
||||
start_date: startDate,
|
||||
end_date: endDate,
|
||||
article_ids: Array.isArray(articleIds) ? articleIds : []
|
||||
});
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_range_explain',
|
||||
ticker: normalized,
|
||||
start_date: startDate,
|
||||
end_date: endDate,
|
||||
article_ids: Array.isArray(articleIds) ? articleIds : []
|
||||
});
|
||||
}, [clientRef, _hasDirectNewsService, _fetchRangeExplainDirect, setNewsByTicker]);
|
||||
|
||||
const requestStockStory = useCallback((symbol, asOfDate) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
const date = typeof asOfDate === 'string' ? asOfDate.trim() : '';
|
||||
if (!normalized || !date) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (_hasDirectNewsService()) {
|
||||
void _fetchStockStoryDirect(normalized, date)
|
||||
.then((payload) => {
|
||||
setNewsByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: {
|
||||
...(prev[normalized] || {}),
|
||||
storyCache: {
|
||||
...((prev[normalized] && prev[normalized].storyCache) || {}),
|
||||
[date]: {
|
||||
story: payload?.story || '',
|
||||
source: payload?.source || null,
|
||||
asOfDate: date,
|
||||
freshness: payload?.freshness || null
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Direct story fetch failed, falling back to websocket:', error);
|
||||
if (clientRef.current) {
|
||||
clientRef.current.send({
|
||||
type: 'get_stock_story',
|
||||
ticker: normalized,
|
||||
as_of_date: date
|
||||
});
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_story',
|
||||
ticker: normalized,
|
||||
as_of_date: date
|
||||
});
|
||||
}, [clientRef, _hasDirectNewsService, _fetchStockStoryDirect, setNewsByTicker]);
|
||||
|
||||
const requestStockSimilarDays = useCallback((symbol, targetDate, lookbackDays = 365) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
const date = typeof targetDate === 'string' ? targetDate.trim() : '';
|
||||
if (!normalized || !date) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (_hasDirectNewsService()) {
|
||||
void _fetchSimilarDaysDirect(normalized, date, lookbackDays)
|
||||
.then((payload) => {
|
||||
setNewsByTicker((prev) => ({
|
||||
...prev,
|
||||
[normalized]: {
|
||||
...(prev[normalized] || {}),
|
||||
similarDaysCache: {
|
||||
...((prev[normalized] && prev[normalized].similarDaysCache) || {}),
|
||||
[date]: {
|
||||
target_features: payload?.target_features || {},
|
||||
items: Array.isArray(payload?.items) ? payload?.items : [],
|
||||
error: payload?.error || null,
|
||||
freshness: payload?.freshness || null
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Direct similar-days fetch failed, falling back to websocket:', error);
|
||||
if (clientRef.current) {
|
||||
clientRef.current.send({
|
||||
type: 'get_stock_similar_days',
|
||||
ticker: normalized,
|
||||
target_date: date,
|
||||
lookback_days: lookbackDays
|
||||
});
|
||||
}
|
||||
});
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return clientRef.current.send({
|
||||
type: 'get_stock_similar_days',
|
||||
ticker: normalized,
|
||||
target_date: date,
|
||||
lookback_days: lookbackDays
|
||||
});
|
||||
}, [clientRef, _hasDirectNewsService, _fetchSimilarDaysDirect, setNewsByTicker]);
|
||||
|
||||
const requestStockEnrich = useCallback((symbol, startDate, endDate, { force = false, onlyLocalToLlm = false } = {}) => {
|
||||
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||
if (!normalized || !clientRef.current) {
|
||||
return false;
|
||||
}
|
||||
return clientRef.current.send({
|
||||
type: 'enrich_stock_news',
|
||||
ticker: normalized,
|
||||
start_date: startDate,
|
||||
end_date: endDate,
|
||||
force: Boolean(force),
|
||||
only_local_to_llm: Boolean(onlyLocalToLlm)
|
||||
});
|
||||
}, [clientRef]);
|
||||
|
||||
return {
|
||||
buildTickersFromSymbols,
|
||||
normalizePriceHistory,
|
||||
requestStockHistory,
|
||||
requestStockExplainEvents,
|
||||
requestStockNews,
|
||||
requestStockNewsForDate,
|
||||
requestStockNewsTimeline,
|
||||
requestStockNewsCategories,
|
||||
requestStockInsiderTrades,
|
||||
requestStockTechnicalIndicators,
|
||||
requestStockRangeExplain,
|
||||
requestStockStory,
|
||||
requestStockSimilarDays,
|
||||
requestStockEnrich
|
||||
};
|
||||
}
|
||||
144
frontend/src/hooks/useWatchlistCallbacks.js
Normal file
144
frontend/src/hooks/useWatchlistCallbacks.js
Normal file
@@ -0,0 +1,144 @@
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { INITIAL_TICKERS } from '../config/constants';
|
||||
|
||||
/**
|
||||
* Extracts watchlist-related callbacks from App.jsx into a single hook.
|
||||
*/
|
||||
export function useWatchlistCallbacks({
|
||||
clientRef,
|
||||
runtimeWatchlistSymbols,
|
||||
watchlistDraftSymbols,
|
||||
watchlistInputValue,
|
||||
watchlistFeedback,
|
||||
setters
|
||||
}) {
|
||||
const {
|
||||
setWatchlistDraftSymbols,
|
||||
setWatchlistInputValue,
|
||||
setWatchlistFeedback
|
||||
} = setters;
|
||||
|
||||
const parseWatchlistInput = useCallback((value) => {
|
||||
if (typeof value !== 'string') {
|
||||
return [];
|
||||
}
|
||||
|
||||
return Array.from(
|
||||
new Set(
|
||||
value
|
||||
.split(/[\s,]+/)
|
||||
.map((symbol) => symbol.trim().toUpperCase())
|
||||
.filter(Boolean)
|
||||
)
|
||||
);
|
||||
}, []);
|
||||
|
||||
const commitWatchlistInput = useCallback((value) => {
|
||||
const parsed = parseWatchlistInput(value);
|
||||
if (parsed.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
setWatchlistDraftSymbols((prev) => Array.from(new Set([...prev, ...parsed])));
|
||||
setWatchlistInputValue('');
|
||||
if (watchlistFeedback) {
|
||||
setWatchlistFeedback(null);
|
||||
}
|
||||
return parsed;
|
||||
}, [parseWatchlistInput, watchlistFeedback, setWatchlistDraftSymbols, setWatchlistInputValue, setWatchlistFeedback, setters]);
|
||||
|
||||
const handleWatchlistRemove = useCallback((symbolToRemove) => {
|
||||
setWatchlistDraftSymbols((prev) => prev.filter((symbol) => symbol !== symbolToRemove));
|
||||
if (watchlistFeedback) {
|
||||
setWatchlistFeedback(null);
|
||||
}
|
||||
}, [watchlistFeedback, setWatchlistDraftSymbols, setWatchlistFeedback]);
|
||||
|
||||
const handleWatchlistInputChange = useCallback((value) => {
|
||||
setWatchlistInputValue(value);
|
||||
if (watchlistFeedback) {
|
||||
setWatchlistFeedback(null);
|
||||
}
|
||||
}, [watchlistFeedback, setWatchlistInputValue, setWatchlistFeedback]);
|
||||
|
||||
const handleWatchlistInputKeyDown = useCallback((e) => {
|
||||
if (e.key === 'Enter' || e.key === ',') {
|
||||
e.preventDefault();
|
||||
commitWatchlistInput(watchlistInputValue);
|
||||
}
|
||||
}, [commitWatchlistInput, watchlistInputValue]);
|
||||
|
||||
const handleWatchlistSuggestionClick = useCallback((symbol) => {
|
||||
if (watchlistDraftSymbols.includes(symbol)) {
|
||||
return;
|
||||
}
|
||||
setWatchlistDraftSymbols((prev) => [...prev, symbol]);
|
||||
if (watchlistFeedback) {
|
||||
setWatchlistFeedback(null);
|
||||
}
|
||||
}, [watchlistDraftSymbols, watchlistFeedback, setWatchlistDraftSymbols, setWatchlistFeedback]);
|
||||
|
||||
const handleWatchlistRestoreCurrent = useCallback(() => {
|
||||
setWatchlistDraftSymbols(runtimeWatchlistSymbols);
|
||||
setWatchlistInputValue('');
|
||||
setWatchlistFeedback(null);
|
||||
}, [runtimeWatchlistSymbols, setWatchlistDraftSymbols, setWatchlistInputValue, setWatchlistFeedback]);
|
||||
|
||||
const handleWatchlistSave = useCallback(() => {
|
||||
const pendingTickers = parseWatchlistInput(watchlistInputValue);
|
||||
const nextTickers = Array.from(new Set([...watchlistDraftSymbols, ...pendingTickers]));
|
||||
if (nextTickers.length === 0) {
|
||||
setWatchlistFeedback({ type: 'error', text: '至少输入 1 个有效股票代码' });
|
||||
return;
|
||||
}
|
||||
|
||||
if (!clientRef.current) {
|
||||
setWatchlistFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||
return;
|
||||
}
|
||||
|
||||
setters.setIsWatchlistSaving(true);
|
||||
setWatchlistFeedback(null);
|
||||
setWatchlistDraftSymbols(nextTickers);
|
||||
setWatchlistInputValue('');
|
||||
const success = clientRef.current.send({
|
||||
type: 'update_watchlist',
|
||||
tickers: nextTickers
|
||||
});
|
||||
|
||||
if (!success) {
|
||||
setters.setIsWatchlistSaving(false);
|
||||
setWatchlistFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||
}
|
||||
}, [parseWatchlistInput, watchlistDraftSymbols, watchlistInputValue, clientRef, setters.setIsWatchlistSaving, setWatchlistFeedback, setWatchlistDraftSymbols, setWatchlistInputValue]);
|
||||
|
||||
const watchlistSuggestions = useMemo(
|
||||
() => INITIAL_TICKERS.map((ticker) => ticker.symbol).filter((symbol, index, list) => list.indexOf(symbol) === index),
|
||||
[]
|
||||
);
|
||||
|
||||
const isWatchlistDraftDirty = useMemo(() => {
|
||||
if (watchlistInputValue.trim()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (watchlistDraftSymbols.length !== runtimeWatchlistSymbols.length) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return watchlistDraftSymbols.some((symbol, index) => symbol !== runtimeWatchlistSymbols[index]);
|
||||
}, [runtimeWatchlistSymbols, watchlistDraftSymbols, watchlistInputValue]);
|
||||
|
||||
return {
|
||||
parseWatchlistInput,
|
||||
commitWatchlistInput,
|
||||
handleWatchlistRemove,
|
||||
handleWatchlistInputChange,
|
||||
handleWatchlistInputKeyDown,
|
||||
handleWatchlistSuggestionClick,
|
||||
handleWatchlistRestoreCurrent,
|
||||
handleWatchlistSave,
|
||||
watchlistSuggestions,
|
||||
isWatchlistDraftDirty
|
||||
};
|
||||
}
|
||||
1057
frontend/src/hooks/useWebSocketHandler.js
Normal file
1057
frontend/src/hooks/useWebSocketHandler.js
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user