Compare commits
2 Commits
456748b01e
...
3926a6bd07
| Author | SHA1 | Date | |
|---|---|---|---|
| 3926a6bd07 | |||
| 80256a4079 |
@@ -302,36 +302,28 @@ def _start_gateway_process(
|
|||||||
|
|
||||||
@router.get("/context", response_model=RunContextResponse)
|
@router.get("/context", response_model=RunContextResponse)
|
||||||
async def get_run_context() -> RunContextResponse:
|
async def get_run_context() -> RunContextResponse:
|
||||||
"""Return the most recent run context."""
|
"""Return the current run context from in-memory state (avoids glob race condition)."""
|
||||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
manager = _runtime_state.runtime_manager
|
||||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
if manager is None or manager.context is None:
|
||||||
|
|
||||||
if not snapshots:
|
|
||||||
raise HTTPException(status_code=404, detail="No run context available")
|
raise HTTPException(status_code=404, detail="No run context available")
|
||||||
|
|
||||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
context = manager.context
|
||||||
context = latest.get("context")
|
|
||||||
if context is None:
|
|
||||||
raise HTTPException(status_code=404, detail="Run context is not ready")
|
|
||||||
|
|
||||||
return RunContextResponse(
|
return RunContextResponse(
|
||||||
config_name=context["config_name"],
|
config_name=context.config_name,
|
||||||
run_dir=context["run_dir"],
|
run_dir=str(context.run_dir),
|
||||||
bootstrap_values=context["bootstrap_values"],
|
bootstrap_values=context.bootstrap_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/agents", response_model=RuntimeAgentsResponse)
|
@router.get("/agents", response_model=RuntimeAgentsResponse)
|
||||||
async def get_runtime_agents() -> RuntimeAgentsResponse:
|
async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||||
"""Return agent states from the most recent run."""
|
"""Return agent states from the in-memory runtime manager (avoids glob race condition)."""
|
||||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
manager = _runtime_state.runtime_manager
|
||||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
if manager is None:
|
||||||
|
|
||||||
if not snapshots:
|
|
||||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||||
|
|
||||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
snapshot = manager.build_snapshot()
|
||||||
agents = latest.get("agents", [])
|
agents = snapshot.get("agents", [])
|
||||||
|
|
||||||
return RuntimeAgentsResponse(
|
return RuntimeAgentsResponse(
|
||||||
agents=[RuntimeAgentState(**a) for a in agents]
|
agents=[RuntimeAgentState(**a) for a in agents]
|
||||||
@@ -340,15 +332,13 @@ async def get_runtime_agents() -> RuntimeAgentsResponse:
|
|||||||
|
|
||||||
@router.get("/events", response_model=RuntimeEventsResponse)
|
@router.get("/events", response_model=RuntimeEventsResponse)
|
||||||
async def get_runtime_events() -> RuntimeEventsResponse:
|
async def get_runtime_events() -> RuntimeEventsResponse:
|
||||||
"""Return events from the most recent run."""
|
"""Return events from the in-memory runtime manager (avoids glob race condition)."""
|
||||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
manager = _runtime_state.runtime_manager
|
||||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
if manager is None:
|
||||||
|
|
||||||
if not snapshots:
|
|
||||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||||
|
|
||||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
snapshot = manager.build_snapshot()
|
||||||
events = latest.get("events", [])
|
events = snapshot.get("events", [])
|
||||||
|
|
||||||
return RuntimeEventsResponse(
|
return RuntimeEventsResponse(
|
||||||
events=[RuntimeEvent(**e) for e in events]
|
events=[RuntimeEvent(**e) for e in events]
|
||||||
@@ -362,15 +352,10 @@ async def get_gateway_status() -> GatewayStatusResponse:
|
|||||||
run_id = None
|
run_id = None
|
||||||
|
|
||||||
if is_running:
|
if is_running:
|
||||||
# Try to find run_id from runtime state
|
# Get run_id from in-memory runtime manager (avoids glob race condition)
|
||||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
manager = _runtime_state.runtime_manager
|
||||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
if manager is not None and manager.context is not None:
|
||||||
if snapshots:
|
run_id = manager.context.config_name
|
||||||
try:
|
|
||||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
|
||||||
run_id = latest.get("context", {}).get("config_name")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to parse latest snapshot: {e}")
|
|
||||||
|
|
||||||
return GatewayStatusResponse(
|
return GatewayStatusResponse(
|
||||||
is_running=is_running,
|
is_running=is_running,
|
||||||
@@ -404,8 +389,28 @@ def _build_gateway_ws_url(request: Request, port: int) -> str:
|
|||||||
return f"{ws_scheme}://{host}:{port}"
|
return f"{ws_scheme}://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||||
"""Load the latest persisted runtime snapshot."""
|
"""Return the active runtime context from the in-memory manager (avoids glob race condition).
|
||||||
|
|
||||||
|
Falls back to file-based lookup only when the in-memory manager is not available
|
||||||
|
(e.g., after a service restart). File-based lookup is deprecated and exists
|
||||||
|
only for backward compatibility.
|
||||||
|
"""
|
||||||
|
if not _is_gateway_running():
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||||
|
|
||||||
|
# Primary: use in-memory manager (always correct for current process)
|
||||||
|
manager = _runtime_state.runtime_manager
|
||||||
|
if manager is not None and manager.context is not None:
|
||||||
|
ctx = manager.context
|
||||||
|
return {
|
||||||
|
"config_name": ctx.config_name,
|
||||||
|
"run_dir": str(ctx.run_dir),
|
||||||
|
"bootstrap_values": ctx.bootstrap_values,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Deprecated fallback: scan filesystem (only for backward compatibility
|
||||||
|
# after service restart without a restart of the runtime itself)
|
||||||
snapshots = sorted(
|
snapshots = sorted(
|
||||||
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
|
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
|
||||||
key=lambda p: p.stat().st_mtime,
|
key=lambda p: p.stat().st_mtime,
|
||||||
@@ -413,14 +418,7 @@ def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
|||||||
)
|
)
|
||||||
if not snapshots:
|
if not snapshots:
|
||||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||||
return json.loads(snapshots[0].read_text(encoding="utf-8"))
|
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
|
||||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
|
||||||
"""Return the active runtime context from the latest snapshot."""
|
|
||||||
if not _is_gateway_running():
|
|
||||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
|
||||||
latest = _load_latest_runtime_snapshot()
|
|
||||||
context = latest.get("context") or {}
|
context = latest.get("context") or {}
|
||||||
if not context.get("config_name"):
|
if not context.get("config_name"):
|
||||||
raise HTTPException(status_code=404, detail="No runtime context available")
|
raise HTTPException(status_code=404, detail="No runtime context available")
|
||||||
@@ -663,15 +661,8 @@ async def get_current_runtime():
|
|||||||
if not _is_gateway_running():
|
if not _is_gateway_running():
|
||||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||||
|
|
||||||
# Find latest runtime state
|
# Get context from in-memory manager (avoids glob race condition)
|
||||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
context = _get_current_runtime_context()
|
||||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
|
||||||
|
|
||||||
if not snapshots:
|
|
||||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
|
||||||
|
|
||||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
|
||||||
context = latest.get("context", {})
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"run_id": context.get("config_name"),
|
"run_id": context.get("config_name"),
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
|
|
||||||
from backend.api import agents_router, guard_router, workspaces_router
|
from backend.api import agents_router, guard_router, workspaces_router
|
||||||
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||||
|
from backend.config.env_config import get_cors_origins
|
||||||
|
|
||||||
# Global instances (initialized on startup)
|
# Global instances (initialized on startup)
|
||||||
agent_factory: AgentFactory | None = None
|
agent_factory: AgentFactory | None = None
|
||||||
@@ -49,7 +50,7 @@ def create_app(project_root: Path | None = None) -> FastAPI:
|
|||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=get_cors_origins(),
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
|
|
||||||
from backend.data.market_store import MarketStore
|
from backend.data.market_store import MarketStore
|
||||||
from backend.domains import news as news_domain
|
from backend.domains import news as news_domain
|
||||||
|
from backend.config.env_config import get_cors_origins
|
||||||
|
|
||||||
|
|
||||||
def get_market_store() -> MarketStore:
|
def get_market_store() -> MarketStore:
|
||||||
@@ -27,7 +28,7 @@ def create_app() -> FastAPI:
|
|||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=get_cors_origins(),
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
|
|
||||||
from backend.api import runtime_router
|
from backend.api import runtime_router
|
||||||
from backend.api.runtime import get_runtime_state
|
from backend.api.runtime import get_runtime_state
|
||||||
|
from backend.config.env_config import get_cors_origins
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
@@ -20,7 +21,7 @@ def create_app() -> FastAPI:
|
|||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=get_cors_origins(),
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
|
|||||||
@@ -8,7 +8,16 @@ from typing import Any
|
|||||||
from fastapi import FastAPI, Query
|
from fastapi import FastAPI, Query
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from backend.domains import trading as trading_domain
|
from backend.config.env_config import get_cors_origins
|
||||||
|
from backend.services.market import MarketService
|
||||||
|
from backend.tools.data_tools import (
|
||||||
|
get_company_news,
|
||||||
|
get_financial_metrics,
|
||||||
|
get_insider_trades,
|
||||||
|
get_market_cap,
|
||||||
|
get_prices,
|
||||||
|
search_line_items,
|
||||||
|
)
|
||||||
from shared.schema import (
|
from shared.schema import (
|
||||||
CompanyNewsResponse,
|
CompanyNewsResponse,
|
||||||
FinancialMetricsResponse,
|
FinancialMetricsResponse,
|
||||||
@@ -28,7 +37,7 @@ def create_app() -> FastAPI:
|
|||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=get_cors_origins(),
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
@@ -45,12 +54,8 @@ def create_app() -> FastAPI:
|
|||||||
start_date: str = Query(...),
|
start_date: str = Query(...),
|
||||||
end_date: str = Query(...),
|
end_date: str = Query(...),
|
||||||
) -> PriceResponse:
|
) -> PriceResponse:
|
||||||
payload = trading_domain.get_prices_payload(
|
prices = get_prices(ticker=ticker, start_date=start_date, end_date=end_date)
|
||||||
ticker=ticker,
|
return PriceResponse(ticker=ticker, prices=prices)
|
||||||
start_date=start_date,
|
|
||||||
end_date=end_date,
|
|
||||||
)
|
|
||||||
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
|
|
||||||
|
|
||||||
@app.get("/api/financials", response_model=FinancialMetricsResponse)
|
@app.get("/api/financials", response_model=FinancialMetricsResponse)
|
||||||
async def api_get_financials(
|
async def api_get_financials(
|
||||||
@@ -59,13 +64,13 @@ def create_app() -> FastAPI:
|
|||||||
period: str = Query("ttm"),
|
period: str = Query("ttm"),
|
||||||
limit: int = Query(10, ge=1, le=100),
|
limit: int = Query(10, ge=1, le=100),
|
||||||
) -> FinancialMetricsResponse:
|
) -> FinancialMetricsResponse:
|
||||||
payload = trading_domain.get_financials_payload(
|
metrics = get_financial_metrics(
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
period=period,
|
period=period,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
|
return FinancialMetricsResponse(financial_metrics=metrics)
|
||||||
|
|
||||||
@app.get("/api/news", response_model=CompanyNewsResponse)
|
@app.get("/api/news", response_model=CompanyNewsResponse)
|
||||||
async def api_get_news(
|
async def api_get_news(
|
||||||
@@ -74,13 +79,13 @@ def create_app() -> FastAPI:
|
|||||||
start_date: str | None = Query(None),
|
start_date: str | None = Query(None),
|
||||||
limit: int = Query(1000, ge=1, le=5000),
|
limit: int = Query(1000, ge=1, le=5000),
|
||||||
) -> CompanyNewsResponse:
|
) -> CompanyNewsResponse:
|
||||||
payload = trading_domain.get_news_payload(
|
news = get_company_news(
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
return CompanyNewsResponse(news=payload["news"])
|
return CompanyNewsResponse(news=news)
|
||||||
|
|
||||||
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
|
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
|
||||||
async def api_get_insider_trades(
|
async def api_get_insider_trades(
|
||||||
@@ -89,18 +94,19 @@ def create_app() -> FastAPI:
|
|||||||
start_date: str | None = Query(None),
|
start_date: str | None = Query(None),
|
||||||
limit: int = Query(1000, ge=1, le=5000),
|
limit: int = Query(1000, ge=1, le=5000),
|
||||||
) -> InsiderTradeResponse:
|
) -> InsiderTradeResponse:
|
||||||
payload = trading_domain.get_insider_trades_payload(
|
trades = get_insider_trades(
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
|
return InsiderTradeResponse(insider_trades=trades)
|
||||||
|
|
||||||
@app.get("/api/market/status")
|
@app.get("/api/market/status")
|
||||||
async def api_get_market_status() -> dict[str, Any]:
|
async def api_get_market_status() -> dict[str, Any]:
|
||||||
"""Return current market status using the existing market service logic."""
|
"""Return current market status using the existing market service logic."""
|
||||||
return trading_domain.get_market_status_payload()
|
service = MarketService(tickers=[])
|
||||||
|
return service.get_market_status()
|
||||||
|
|
||||||
@app.get("/api/market-cap")
|
@app.get("/api/market-cap")
|
||||||
async def api_get_market_cap(
|
async def api_get_market_cap(
|
||||||
@@ -108,10 +114,12 @@ def create_app() -> FastAPI:
|
|||||||
end_date: str = Query(...),
|
end_date: str = Query(...),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return market cap for one ticker/date."""
|
"""Return market cap for one ticker/date."""
|
||||||
return trading_domain.get_market_cap_payload(
|
market_cap = get_market_cap(ticker=ticker, end_date=end_date)
|
||||||
ticker=ticker,
|
return {
|
||||||
end_date=end_date,
|
"ticker": ticker,
|
||||||
)
|
"end_date": end_date,
|
||||||
|
"market_cap": market_cap,
|
||||||
|
}
|
||||||
|
|
||||||
@app.get("/api/line-items", response_model=LineItemResponse)
|
@app.get("/api/line-items", response_model=LineItemResponse)
|
||||||
async def api_get_line_items(
|
async def api_get_line_items(
|
||||||
@@ -121,14 +129,14 @@ def create_app() -> FastAPI:
|
|||||||
period: str = Query("ttm"),
|
period: str = Query("ttm"),
|
||||||
limit: int = Query(10, ge=1, le=100),
|
limit: int = Query(10, ge=1, le=100),
|
||||||
) -> LineItemResponse:
|
) -> LineItemResponse:
|
||||||
payload = trading_domain.get_line_items_payload(
|
items = search_line_items(
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
line_items=line_items,
|
line_items=line_items,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
period=period,
|
period=period,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
return LineItemResponse(search_results=payload["search_results"])
|
return LineItemResponse(search_results=items)
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
"""Environment config helpers with light validation and normalization."""
|
"""Environment config helpers with light validation and normalization."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -16,6 +17,36 @@ PROVIDER_ALIASES = {
|
|||||||
"vertexai": "GEMINI",
|
"vertexai": "GEMINI",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Default dev CORS origins (localhost variants used by common dev servers)
|
||||||
|
_LOCALHOST_ORIGINS = [
|
||||||
|
"http://localhost:5173",
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:8000",
|
||||||
|
"http://127.0.0.1:5173",
|
||||||
|
"http://127.0.0.1:3000",
|
||||||
|
"http://127.0.0.1:8000",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_cors_origins() -> list[str]:
|
||||||
|
"""Get CORS allowed origins from environment.
|
||||||
|
|
||||||
|
Reads CORS_ALLOWED_ORIGINS env var (comma-separated).
|
||||||
|
Falls back to localhost dev origins if not set.
|
||||||
|
Warns if "*" is configured (only acceptable for local dev).
|
||||||
|
"""
|
||||||
|
origins = get_env_list("CORS_ALLOWED_ORIGINS", default=[])
|
||||||
|
if origins:
|
||||||
|
if "*" in origins:
|
||||||
|
warnings.warn(
|
||||||
|
"CORS_ALLOWED_ORIGINS contains '*' — this allows any origin. "
|
||||||
|
"Only use in local development, never in production.",
|
||||||
|
UserWarning,
|
||||||
|
)
|
||||||
|
return origins
|
||||||
|
# Fallback: local dev only
|
||||||
|
return _LOCALHOST_ORIGINS
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class AgentModelConfig:
|
class AgentModelConfig:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from backend.data.market_ingest import ingest_symbols
|
from backend.data.market_ingest import ingest_symbols
|
||||||
from backend.domains import trading as trading_domain
|
from backend.tools.data_tools import get_market_cap
|
||||||
from backend.utils.msg_adapter import FrontendAdapter
|
from backend.utils.msg_adapter import FrontendAdapter
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -265,8 +265,7 @@ async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[s
|
|||||||
if response is not None:
|
if response is not None:
|
||||||
market_cap = response.get("market_cap")
|
market_cap = response.get("market_cap")
|
||||||
if market_cap is None:
|
if market_cap is None:
|
||||||
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
|
market_cap = get_market_cap(ticker=ticker, end_date=date)
|
||||||
market_cap = payload.get("market_cap")
|
|
||||||
market_caps[ticker] = market_cap if market_cap else 1e9
|
market_caps[ticker] = market_cap if market_cap else 1e9
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
|
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
|
||||||
|
|||||||
@@ -11,10 +11,9 @@ from typing import Any
|
|||||||
|
|
||||||
from backend.data.provider_utils import normalize_symbol
|
from backend.data.provider_utils import normalize_symbol
|
||||||
from backend.domains import news as news_domain
|
from backend.domains import news as news_domain
|
||||||
from backend.domains import trading as trading_domain
|
|
||||||
from backend.enrich.news_enricher import enrich_news_for_symbol
|
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||||
from backend.enrich.llm_enricher import llm_enrichment_enabled
|
from backend.enrich.llm_enricher import llm_enrichment_enabled
|
||||||
from backend.tools.data_tools import prices_to_df
|
from backend.tools.data_tools import get_insider_trades, get_prices, prices_to_df
|
||||||
from shared.client import NewsServiceClient, TradingServiceClient
|
from shared.client import NewsServiceClient, TradingServiceClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -59,13 +58,12 @@ async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str,
|
|||||||
if not prices:
|
if not prices:
|
||||||
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
|
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
|
||||||
if not prices:
|
if not prices:
|
||||||
payload = await asyncio.to_thread(
|
prices = await asyncio.to_thread(
|
||||||
trading_domain.get_prices_payload,
|
get_prices,
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
start_date=start_date,
|
start_date=start_date,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
)
|
)
|
||||||
prices = payload.get("prices") or []
|
|
||||||
usage_snapshot = gateway._provider_router.get_usage_snapshot()
|
usage_snapshot = gateway._provider_router.get_usage_snapshot()
|
||||||
source = usage_snapshot.get("last_success", {}).get("prices")
|
source = usage_snapshot.get("last_success", {}).get("prices")
|
||||||
if prices:
|
if prices:
|
||||||
@@ -400,14 +398,13 @@ async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: di
|
|||||||
trades = response.insider_trades
|
trades = response.insider_trades
|
||||||
|
|
||||||
if not trades:
|
if not trades:
|
||||||
payload = await asyncio.to_thread(
|
trades = await asyncio.to_thread(
|
||||||
trading_domain.get_insider_trades_payload,
|
get_insider_trades,
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
end_date=end_date,
|
end_date=end_date,
|
||||||
start_date=start_date if start_date else None,
|
start_date=start_date if start_date else None,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
trades = payload.get("insider_trades") or []
|
|
||||||
|
|
||||||
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
|
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
|
||||||
formatted_trades = [{
|
formatted_trades = [{
|
||||||
@@ -540,12 +537,11 @@ async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, da
|
|||||||
prices = response.prices
|
prices = response.prices
|
||||||
|
|
||||||
if prices is None:
|
if prices is None:
|
||||||
payload = trading_domain.get_prices_payload(
|
prices = get_prices(
|
||||||
ticker=ticker,
|
ticker=ticker,
|
||||||
start_date=start_date.strftime("%Y-%m-%d"),
|
start_date=start_date.strftime("%Y-%m-%d"),
|
||||||
end_date=end_date.strftime("%Y-%m-%d"),
|
end_date=end_date.strftime("%Y-%m-%d"),
|
||||||
)
|
)
|
||||||
prices = payload.get("prices") or []
|
|
||||||
|
|
||||||
if not prices or len(prices) < 20:
|
if not prices or len(prices) < 20:
|
||||||
await websocket.send(json.dumps({
|
await websocket.send(json.dumps({
|
||||||
|
|||||||
549
backend/tests/test_gateway.py
Normal file
549
backend/tests/test_gateway.py
Normal file
@@ -0,0 +1,549 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Tests for the Gateway main class - core behavior and fallback paths."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from backend.services.gateway import Gateway
|
||||||
|
import backend.services.gateway as gateway_module
|
||||||
|
|
||||||
|
|
||||||
|
class DummyWebSocket:
|
||||||
|
def __init__(self):
|
||||||
|
self.messages = []
|
||||||
|
self.closed = False
|
||||||
|
self._queue = []
|
||||||
|
|
||||||
|
def queue(self, data: str):
|
||||||
|
"""Queue a raw message string to be yielded by the async iterator."""
|
||||||
|
self._queue.append(data)
|
||||||
|
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
if not self._queue:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
return self._queue.pop(0)
|
||||||
|
|
||||||
|
async def send(self, payload: str):
|
||||||
|
self.messages.append(json.loads(payload))
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
|
||||||
|
class DummyStateSync:
|
||||||
|
def __init__(self, current_date="2026-03-16"):
|
||||||
|
self.state = {"current_date": current_date}
|
||||||
|
self.system_messages = []
|
||||||
|
self.saved = False
|
||||||
|
self.initial_state_payload = {}
|
||||||
|
|
||||||
|
def set_broadcast_fn(self, _fn):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_state(self, key, value):
|
||||||
|
self.state[key] = value
|
||||||
|
|
||||||
|
def save_state(self):
|
||||||
|
self.saved = True
|
||||||
|
|
||||||
|
async def on_system_message(self, message):
|
||||||
|
self.system_messages.append(message)
|
||||||
|
|
||||||
|
def get_initial_state_payload(self, include_dashboard=True):
|
||||||
|
return {
|
||||||
|
"status": "running",
|
||||||
|
"current_date": self.state.get("current_date", ""),
|
||||||
|
"portfolio": {},
|
||||||
|
"holdings": [],
|
||||||
|
"trades": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DummyMarketService:
|
||||||
|
def __init__(self):
|
||||||
|
self.broadcast_func = None
|
||||||
|
self.market_status = {"is_open": True, "session": "regular"}
|
||||||
|
|
||||||
|
def set_price_recorder(self, _fn):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def start(self, broadcast_func=None):
|
||||||
|
self.broadcast_func = broadcast_func
|
||||||
|
|
||||||
|
def get_market_status(self):
|
||||||
|
return self.market_status
|
||||||
|
|
||||||
|
|
||||||
|
class DummyStorage:
|
||||||
|
def __init__(self, initial_cash=100000.0, live_session=False):
|
||||||
|
self.initial_cash = initial_cash
|
||||||
|
self.is_live_session_active = live_session
|
||||||
|
self._market_store = SimpleNamespace()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def market_store(self):
|
||||||
|
return self._market_store
|
||||||
|
|
||||||
|
def load_file(self, name):
|
||||||
|
if name == "summary":
|
||||||
|
return {"totalAssetValue": self.initial_cash}
|
||||||
|
if name in ("holdings", "trades"):
|
||||||
|
return []
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_live_returns(self):
|
||||||
|
return {"session_pnl": 0.0, "session_return": 0.0}
|
||||||
|
|
||||||
|
|
||||||
|
def make_gateway(market_service=None, storage=None, state_sync=None, config=None):
|
||||||
|
storage = storage or DummyStorage()
|
||||||
|
state_sync = state_sync or DummyStateSync()
|
||||||
|
market_service = market_service or DummyMarketService()
|
||||||
|
pipeline = SimpleNamespace(state_sync=state_sync, max_comm_cycles=0, pm=SimpleNamespace(portfolio={"margin_requirement": 0.0}))
|
||||||
|
return Gateway(
|
||||||
|
market_service=market_service,
|
||||||
|
storage_service=storage,
|
||||||
|
pipeline=pipeline,
|
||||||
|
state_sync=state_sync,
|
||||||
|
config=config or {"mode": "live"},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Gateway initialization and core properties
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_gateway_init_sets_live_mode():
|
||||||
|
gateway = make_gateway(config={"mode": "live"})
|
||||||
|
assert gateway.mode == "live"
|
||||||
|
assert gateway.is_backtest is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_init_sets_backtest_mode_from_config():
|
||||||
|
gateway = make_gateway(config={"mode": "backtest"})
|
||||||
|
assert gateway.mode == "backtest"
|
||||||
|
assert gateway.is_backtest is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_init_sets_backtest_mode_from_flag():
|
||||||
|
gateway = make_gateway(config={"backtest_mode": True, "mode": "live"})
|
||||||
|
assert gateway.is_backtest is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_init_defaults_to_live_mode():
|
||||||
|
gateway = make_gateway(config={})
|
||||||
|
assert gateway.mode == "live"
|
||||||
|
assert gateway.is_backtest is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_state_property_returns_state_sync_state():
|
||||||
|
state_sync = DummyStateSync()
|
||||||
|
state_sync.state["foo"] = "bar"
|
||||||
|
gateway = make_gateway(state_sync=state_sync)
|
||||||
|
assert gateway.state["foo"] == "bar"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_news_rows_need_enrichment_delegates_to_news_domain():
|
||||||
|
rows = [{"id": "1"}, {"id": "2"}]
|
||||||
|
with patch.object(gateway_module.news_domain, "news_rows_need_enrichment", return_value=True) as mock:
|
||||||
|
result = Gateway._news_rows_need_enrichment(rows)
|
||||||
|
mock.assert_called_once_with(rows)
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Service URL helpers and fallback paths
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_news_service_url_returns_config_value(monkeypatch):
|
||||||
|
gateway = make_gateway(config={"news_service_url": "http://custom-news:9000"})
|
||||||
|
assert gateway._news_service_url() == "http://custom-news:9000"
|
||||||
|
|
||||||
|
|
||||||
|
def test_news_service_url_falls_back_to_env(monkeypatch):
|
||||||
|
monkeypatch.setenv("NEWS_SERVICE_URL", "http://env-news:9001")
|
||||||
|
gateway = make_gateway(config={})
|
||||||
|
assert gateway._news_service_url() == "http://env-news:9001"
|
||||||
|
|
||||||
|
|
||||||
|
def test_news_service_url_returns_none_when_unset(monkeypatch):
|
||||||
|
monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
|
||||||
|
gateway = make_gateway(config={})
|
||||||
|
assert gateway._news_service_url() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_news_service_url_strips_whitespace(monkeypatch):
|
||||||
|
gateway = make_gateway(config={"news_service_url": " http://whitespace-news:9000 "})
|
||||||
|
assert gateway._news_service_url() == "http://whitespace-news:9000"
|
||||||
|
|
||||||
|
|
||||||
|
def test_trading_service_url_returns_config_value(monkeypatch):
|
||||||
|
gateway = make_gateway(config={"trading_service_url": "http://custom-trading:9000"})
|
||||||
|
assert gateway._trading_service_url() == "http://custom-trading:9000"
|
||||||
|
|
||||||
|
|
||||||
|
def test_trading_service_url_falls_back_to_env(monkeypatch):
|
||||||
|
monkeypatch.setenv("TRADING_SERVICE_URL", "http://env-trading:9001")
|
||||||
|
gateway = make_gateway(config={})
|
||||||
|
assert gateway._trading_service_url() == "http://env-trading:9001"
|
||||||
|
|
||||||
|
|
||||||
|
def test_trading_service_url_returns_none_when_unset(monkeypatch):
|
||||||
|
monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
|
||||||
|
gateway = make_gateway(config={})
|
||||||
|
assert gateway._trading_service_url() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_trading_service_url_strips_whitespace(monkeypatch):
|
||||||
|
gateway = make_gateway(config={"trading_service_url": " http://whitespace-trading:9000 "})
|
||||||
|
assert gateway._trading_service_url() == "http://whitespace-trading:9000"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_news_service_returns_none_when_url_not_set(monkeypatch):
|
||||||
|
monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
|
||||||
|
gateway = make_gateway(config={})
|
||||||
|
|
||||||
|
async def dummy_callback(client):
|
||||||
|
return "should not be called"
|
||||||
|
|
||||||
|
result = await gateway._call_news_service("test_action", dummy_callback)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_news_service_calls_callback_and_returns():
|
||||||
|
gateway = make_gateway(config={"news_service_url": "http://news:9000"})
|
||||||
|
|
||||||
|
async def callback(client):
|
||||||
|
return {"result": "ok"}
|
||||||
|
|
||||||
|
result = await gateway._call_news_service("test_action", callback)
|
||||||
|
assert result == {"result": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_news_service_returns_none_on_exception():
|
||||||
|
gateway = make_gateway(config={"news_service_url": "http://news:9000"})
|
||||||
|
|
||||||
|
async def failing_callback(client):
|
||||||
|
raise RuntimeError("connection failed")
|
||||||
|
|
||||||
|
result = await gateway._call_news_service("test_action", failing_callback)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_trading_service_returns_none_when_url_not_set(monkeypatch):
|
||||||
|
monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
|
||||||
|
gateway = make_gateway(config={})
|
||||||
|
|
||||||
|
result = await gateway._call_trading_service("test_action", lambda c: None)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_trading_service_calls_callback_and_returns():
|
||||||
|
gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})
|
||||||
|
|
||||||
|
async def callback(client):
|
||||||
|
return {"result": "ok"}
|
||||||
|
result = await gateway._call_trading_service("test_action", callback)
|
||||||
|
assert result == {"result": "ok"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_trading_service_returns_none_on_exception():
|
||||||
|
gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})
|
||||||
|
|
||||||
|
async def failing_callback(client):
|
||||||
|
raise RuntimeError("connection failed")
|
||||||
|
|
||||||
|
result = await gateway._call_trading_service("test_action", failing_callback)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# WebSocket message handlers
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_client_messages_ping_returns_pong():
|
||||||
|
"""Ping message type results in a pong response."""
|
||||||
|
gateway = make_gateway()
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
ws.queue(json.dumps({"type": "ping"}))
|
||||||
|
|
||||||
|
await gateway._handle_client_messages(ws)
|
||||||
|
|
||||||
|
assert ws.messages[-1]["type"] == "pong"
|
||||||
|
assert "timestamp" in ws.messages[-1]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_client_messages_get_state_sends_initial_state():
|
||||||
|
"""get_state message type triggers _send_initial_state."""
|
||||||
|
gateway = make_gateway()
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
ws.queue(json.dumps({"type": "get_state"}))
|
||||||
|
|
||||||
|
with patch.object(gateway, "_send_initial_state", AsyncMock()) as mock_send:
|
||||||
|
await gateway._handle_client_messages(ws)
|
||||||
|
mock_send.assert_called_once_with(ws)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_client_messages_unknown_type_is_silently_ignored():
|
||||||
|
"""Unknown message types are silently ignored without error."""
|
||||||
|
gateway = make_gateway()
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
ws.queue(json.dumps({"type": "unknown_type"}))
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
await gateway._handle_client_messages(ws)
|
||||||
|
assert len(ws.messages) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_client_messages_json_decode_error_is_silently_ignored():
|
||||||
|
"""Invalid JSON messages are caught by the handler's except block."""
|
||||||
|
gateway = make_gateway()
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
ws.queue("not valid json")
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
await gateway._handle_client_messages(ws)
|
||||||
|
assert len(ws.messages) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Backtest handling
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_start_backtest_ignored_when_not_backtest_mode():
|
||||||
|
gateway = make_gateway(config={"mode": "live"})
|
||||||
|
# Should not raise - backtest is ignored in live mode
|
||||||
|
await gateway._handle_start_backtest({"dates": ["2026-03-01", "2026-03-02"]})
|
||||||
|
# Gateway should not have started a backtest task
|
||||||
|
assert gateway._backtest_task is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_start_backtest_ignored_when_task_already_running():
|
||||||
|
gateway = make_gateway(config={"mode": "backtest"})
|
||||||
|
|
||||||
|
# Pre-set a backtest task
|
||||||
|
dummy_task = MagicMock()
|
||||||
|
dummy_task.done.return_value = False
|
||||||
|
gateway._backtest_task = dummy_task
|
||||||
|
|
||||||
|
# Should not start a new task
|
||||||
|
await gateway._handle_start_backtest({"dates": ["2026-03-01"]})
|
||||||
|
|
||||||
|
assert gateway._backtest_task is dummy_task # unchanged
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Manual trigger (live/mock mode)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_manual_trigger_rejected_in_backtest_mode():
|
||||||
|
gateway = make_gateway(config={"mode": "backtest"})
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
|
||||||
|
await gateway._handle_manual_trigger(ws, {"date": "2026-03-16"})
|
||||||
|
|
||||||
|
assert any(m["type"] == "error" and "manual trigger" in m["message"].lower() for m in ws.messages)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_manual_trigger_rejected_when_cycle_already_running():
|
||||||
|
gateway = make_gateway(config={"mode": "live"})
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
|
||||||
|
# Simulate a running cycle task
|
||||||
|
dummy_task = MagicMock()
|
||||||
|
dummy_task.done.return_value = False
|
||||||
|
gateway._manual_cycle_task = dummy_task
|
||||||
|
|
||||||
|
await gateway._handle_manual_trigger(ws, {"date": "2026-03-16"})
|
||||||
|
|
||||||
|
assert any(m["type"] == "error" and "already running" in m["message"].lower() for m in ws.messages)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Normalization helpers
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_normalize_watchlist_filters_empty_and_dedupes():
|
||||||
|
result = Gateway._normalize_watchlist(["aapl", " AAPL ", "", "msft", "MSFT", " "])
|
||||||
|
assert result == ["AAPL", "MSFT"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_watchlist_handles_string_input():
|
||||||
|
result = Gateway._normalize_watchlist("aapl, msft, aapl")
|
||||||
|
assert result == ["AAPL", "MSFT"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_agent_workspace_filename_allows_editable_files():
|
||||||
|
for filename in ["SOUL.md", "PROFILE.md", "AGENTS.md", "MEMORY.md", "POLICY.md"]:
|
||||||
|
result = Gateway._normalize_agent_workspace_filename(filename)
|
||||||
|
assert result == filename
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_agent_workspace_filename_rejects_non_editable_files():
|
||||||
|
result = Gateway._normalize_agent_workspace_filename("README.md")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_agent_workspace_filename_rejects_arbitrary_paths():
|
||||||
|
result = Gateway._normalize_agent_workspace_filename("../etc/passwd")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Broadcast
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broadcast_skips_when_no_clients():
|
||||||
|
gateway = make_gateway()
|
||||||
|
gateway.connected_clients = set()
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
await gateway.broadcast({"type": "test"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broadcast_sends_to_all_connected_clients():
|
||||||
|
gateway = make_gateway()
|
||||||
|
ws1 = DummyWebSocket()
|
||||||
|
ws2 = DummyWebSocket()
|
||||||
|
gateway.connected_clients = {ws1, ws2}
|
||||||
|
|
||||||
|
await gateway.broadcast({"type": "market_update", "data": "test"})
|
||||||
|
|
||||||
|
assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
|
||||||
|
assert ws1.messages[0]["data"] == "test"
|
||||||
|
assert ws2.messages[0]["data"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broadcast_removes_closed_connections():
|
||||||
|
"""Verify closed connections are removed from connected_clients set.
|
||||||
|
|
||||||
|
The broadcast method's _send_to_client helper removes a client
|
||||||
|
when it raises websockets.ConnectionClosed.
|
||||||
|
"""
|
||||||
|
gateway = make_gateway()
|
||||||
|
closed_ws = DummyWebSocket()
|
||||||
|
open_ws = DummyWebSocket()
|
||||||
|
gateway.connected_clients = {closed_ws, open_ws}
|
||||||
|
|
||||||
|
# Make closed_ws.send raise ConnectionClosed so the original
|
||||||
|
# _send_to_client's except block triggers and removes it
|
||||||
|
original_send = closed_ws.send
|
||||||
|
async def raising_send(payload):
|
||||||
|
raise gateway_module.websockets.ConnectionClosed(None, None)
|
||||||
|
closed_ws.send = raising_send
|
||||||
|
|
||||||
|
try:
|
||||||
|
await gateway.broadcast({"type": "test"})
|
||||||
|
except gateway_module.websockets.ConnectionClosed:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# The closed client should have been removed, open client should remain
|
||||||
|
assert closed_ws not in gateway.connected_clients
|
||||||
|
assert open_ws in gateway.connected_clients
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_broadcast_sends_to_all_connected_clients():
|
||||||
|
"""Verify broadcast sends to all connected clients and collects results."""
|
||||||
|
gateway = make_gateway()
|
||||||
|
ws1 = DummyWebSocket()
|
||||||
|
ws2 = DummyWebSocket()
|
||||||
|
gateway.connected_clients = {ws1, ws2}
|
||||||
|
|
||||||
|
await gateway.broadcast({"type": "market_update", "data": "test"})
|
||||||
|
|
||||||
|
assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
|
||||||
|
assert ws1.messages[0]["data"] == "test"
|
||||||
|
assert ws2.messages[0]["data"] == "test"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Stop
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_stop_gateway_calls_cycle_support():
|
||||||
|
gateway = make_gateway()
|
||||||
|
with patch.object(gateway_module.gateway_cycle_support, "stop_gateway") as mock:
|
||||||
|
gateway.stop()
|
||||||
|
mock.assert_called_once_with(gateway)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# set_backtest_dates
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_set_backtest_dates_delegates_to_cycle_support():
|
||||||
|
gateway = make_gateway()
|
||||||
|
with patch.object(gateway_module.gateway_cycle_support, "set_backtest_dates") as mock:
|
||||||
|
gateway.set_backtest_dates(["2026-03-01", "2026-03-02"])
|
||||||
|
mock.assert_called_once_with(gateway, ["2026-03-01", "2026-03-02"])
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Provider usage change callback
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_on_provider_usage_changed_updates_state_sync():
|
||||||
|
"""_on_provider_usage_changed updates state_sync with the provider snapshot."""
|
||||||
|
gateway = make_gateway()
|
||||||
|
gateway._loop = None # no loop set
|
||||||
|
|
||||||
|
snapshot = {"provider": "finnhub", "calls": 10}
|
||||||
|
gateway._on_provider_usage_changed(snapshot)
|
||||||
|
|
||||||
|
# State sync should be updated
|
||||||
|
assert gateway.state_sync.state.get("data_sources") == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# handle_client lifecycle
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_client_adds_and_removes_client_from_connected_set():
|
||||||
|
gateway = make_gateway()
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
|
||||||
|
with patch.object(gateway, "_send_initial_state", AsyncMock()):
|
||||||
|
with patch.object(gateway, "_handle_client_messages", AsyncMock()):
|
||||||
|
await gateway.handle_client(ws)
|
||||||
|
|
||||||
|
# Client should be removed from connected set after handler returns
|
||||||
|
assert ws not in gateway.connected_clients
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_client_adds_client_before_handler():
|
||||||
|
gateway = make_gateway()
|
||||||
|
ws = DummyWebSocket()
|
||||||
|
|
||||||
|
with patch.object(gateway, "_send_initial_state", AsyncMock()):
|
||||||
|
with patch.object(gateway, "_handle_client_messages", AsyncMock()):
|
||||||
|
await gateway.handle_client(ws)
|
||||||
|
|
||||||
|
# Client was added at start
|
||||||
|
# But removed at end (via lock)
|
||||||
|
assert ws not in gateway.connected_clients
|
||||||
@@ -1,14 +1,31 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Tests for the extracted runtime service app surface."""
|
"""Tests for the extracted runtime service app surface."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from backend.api import runtime as runtime_module
|
from backend.api import runtime as runtime_module
|
||||||
from backend.apps.runtime_service import create_app
|
from backend.apps.runtime_service import create_app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_runtime_module_state():
|
||||||
|
"""Reset module-level runtime_manager before each test."""
|
||||||
|
runtime_module.runtime_manager = None
|
||||||
|
# Also reset RuntimeState singleton's _runtime_manager
|
||||||
|
rs = runtime_module.get_runtime_state()
|
||||||
|
rs._runtime_manager = None
|
||||||
|
yield
|
||||||
|
runtime_module.runtime_manager = None
|
||||||
|
rs = runtime_module.get_runtime_state()
|
||||||
|
rs._runtime_manager = None
|
||||||
|
|
||||||
|
|
||||||
def test_runtime_service_routes_are_exposed():
|
def test_runtime_service_routes_are_exposed():
|
||||||
app = create_app()
|
app = create_app()
|
||||||
paths = {route.path for route in app.routes}
|
paths = {route.path for route in app.routes}
|
||||||
@@ -153,7 +170,9 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
|||||||
)
|
)
|
||||||
|
|
||||||
class _DummyContext:
|
class _DummyContext:
|
||||||
def __init__(self):
|
def __init__(self, run_dir):
|
||||||
|
self.config_name = "demo"
|
||||||
|
self.run_dir = run_dir
|
||||||
self.bootstrap_values = {
|
self.bootstrap_values = {
|
||||||
"tickers": ["AAPL"],
|
"tickers": ["AAPL"],
|
||||||
"schedule_mode": "daily",
|
"schedule_mode": "daily",
|
||||||
@@ -165,8 +184,17 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
|||||||
class _DummyManager:
|
class _DummyManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.config_name = "demo"
|
self.config_name = "demo"
|
||||||
self.bootstrap = dict(_DummyContext().bootstrap_values)
|
self.bootstrap = dict(_DummyContext(run_dir).bootstrap_values)
|
||||||
self.context = _DummyContext()
|
self.context = _DummyContext(run_dir)
|
||||||
|
|
||||||
|
def build_snapshot(self):
|
||||||
|
return {
|
||||||
|
"context": {
|
||||||
|
"config_name": self.context.config_name,
|
||||||
|
"run_dir": str(self.context.run_dir),
|
||||||
|
"bootstrap_values": self.context.bootstrap_values,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
def _persist_snapshot(self):
|
def _persist_snapshot(self):
|
||||||
return None
|
return None
|
||||||
@@ -192,3 +220,385 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
|||||||
assert payload["bootstrap"]["schedule_mode"] == "intraday"
|
assert payload["bootstrap"]["schedule_mode"] == "intraday"
|
||||||
assert payload["resolved"]["interval_minutes"] == 15
|
assert payload["resolved"]["interval_minutes"] == 15
|
||||||
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")
|
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# RuntimeState singleton unit tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_runtime_state_is_singleton():
|
||||||
|
"""RuntimeState.__new__ returns the same instance across calls."""
|
||||||
|
state1 = runtime_module.RuntimeState()
|
||||||
|
state2 = runtime_module.RuntimeState()
|
||||||
|
assert state1 is state2
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_state_get_runtime_state_returns_same_instance():
|
||||||
|
"""get_runtime_state() returns the module singleton."""
|
||||||
|
instance = runtime_module.get_runtime_state()
|
||||||
|
assert instance is runtime_module._runtime_state
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_state_default_values():
|
||||||
|
"""RuntimeState initializes with sensible defaults on first instantiation."""
|
||||||
|
# Reset singleton to get fresh __init__ values
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
assert state._runtime_manager is None
|
||||||
|
assert state._gateway_process is None
|
||||||
|
assert state._gateway_port == 8765
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_state_gateway_port_property():
|
||||||
|
"""gateway_port property getter and setter work correctly."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
state.gateway_port = 9999
|
||||||
|
assert state.gateway_port == 9999
|
||||||
|
state.gateway_port = 1234
|
||||||
|
assert state.gateway_port == 1234
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_state_gateway_process_property():
|
||||||
|
"""gateway_process property getter and setter work correctly."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
assert state.gateway_process is None
|
||||||
|
|
||||||
|
fake_process = object()
|
||||||
|
state.gateway_process = fake_process
|
||||||
|
assert state.gateway_process is fake_process
|
||||||
|
|
||||||
|
state.gateway_process = None
|
||||||
|
assert state.gateway_process is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_state_runtime_manager_property():
|
||||||
|
"""runtime_manager property getter and setter work correctly."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
assert state.runtime_manager is None
|
||||||
|
|
||||||
|
fake_manager = object()
|
||||||
|
state.runtime_manager = fake_manager
|
||||||
|
assert state.runtime_manager is fake_manager
|
||||||
|
|
||||||
|
state.runtime_manager = None
|
||||||
|
assert state.runtime_manager is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
|
||||||
|
def test_runtime_state_lock_property_is_async():
|
||||||
|
"""lock is an async property that returns a coroutine producing an asyncio.Lock."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
lock_coro = state.lock
|
||||||
|
assert asyncio.iscoroutine(lock_coro)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_state_async_set_get_gateway_port():
|
||||||
|
"""Async setters and getters for gateway_port with lock protection."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
await state.set_gateway_port(8888)
|
||||||
|
assert await state.get_gateway_port() == 8888
|
||||||
|
await state.set_gateway_port(7777)
|
||||||
|
assert await state.get_gateway_port() == 7777
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_state_async_set_get_gateway_process():
|
||||||
|
"""Async setters and getters for gateway_process with lock protection."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
await state.set_gateway_process(None)
|
||||||
|
assert await state.get_gateway_process() is None
|
||||||
|
|
||||||
|
fake_process = object()
|
||||||
|
await state.set_gateway_process(fake_process)
|
||||||
|
assert await state.get_gateway_process() is fake_process
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_state_async_set_get_runtime_manager():
|
||||||
|
"""Async setters and getters for runtime_manager with lock protection."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
await state.set_runtime_manager(None)
|
||||||
|
assert await state.get_runtime_manager() is None
|
||||||
|
|
||||||
|
fake_manager = object()
|
||||||
|
await state.set_runtime_manager(fake_manager)
|
||||||
|
assert await state.get_runtime_manager() is fake_manager
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# _is_gateway_running helper tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_is_gateway_running_returns_false_when_process_is_none():
|
||||||
|
"""_is_gateway_running returns False when gateway_process is None."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
new_state = runtime_module.RuntimeState()
|
||||||
|
new_state._gateway_process = None
|
||||||
|
runtime_module._runtime_state = new_state
|
||||||
|
assert runtime_module._is_gateway_running() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_gateway_running_returns_false_when_process_exited():
|
||||||
|
"""_is_gateway_running returns False when process has terminated."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
runtime_module._runtime_state = state
|
||||||
|
|
||||||
|
mock_process = MagicMock()
|
||||||
|
mock_process.poll.return_value = 1 # non-None = process has exited
|
||||||
|
|
||||||
|
state._gateway_process = mock_process
|
||||||
|
assert runtime_module._is_gateway_running() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_gateway_running_returns_true_when_process_running():
|
||||||
|
"""_is_gateway_running returns True when process is alive."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
runtime_module._runtime_state = state
|
||||||
|
|
||||||
|
mock_process = MagicMock()
|
||||||
|
mock_process.poll.return_value = None # None = still running
|
||||||
|
|
||||||
|
state._gateway_process = mock_process
|
||||||
|
assert runtime_module._is_gateway_running() is True
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# _stop_gateway helper tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_stop_gateway_returns_false_when_no_process():
|
||||||
|
"""_stop_gateway returns False if no gateway process exists."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
new_state = runtime_module.RuntimeState()
|
||||||
|
new_state._gateway_process = None
|
||||||
|
runtime_module._runtime_state = new_state
|
||||||
|
|
||||||
|
result = runtime_module._stop_gateway()
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_gateway_sets_process_to_none_after_stop():
|
||||||
|
"""_stop_gateway sets _gateway_process to None after stopping."""
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
runtime_module._runtime_state = state
|
||||||
|
|
||||||
|
mock_process = MagicMock()
|
||||||
|
mock_process.poll.return_value = None
|
||||||
|
mock_process.wait.return_value = 0
|
||||||
|
|
||||||
|
state._gateway_process = mock_process
|
||||||
|
|
||||||
|
result = runtime_module._stop_gateway()
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert state._gateway_process is None
|
||||||
|
mock_process.terminate.assert_called_once()
|
||||||
|
mock_process.wait.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stop_gateway_kills_when_terminate_times_out():
|
||||||
|
"""_stop_gateway kills the process if terminate times out."""
|
||||||
|
import subprocess
|
||||||
|
runtime_module.RuntimeState._instance = None
|
||||||
|
runtime_module.RuntimeState._lock = asyncio.Lock()
|
||||||
|
state = runtime_module.RuntimeState()
|
||||||
|
runtime_module._runtime_state = state
|
||||||
|
|
||||||
|
mock_process = MagicMock()
|
||||||
|
mock_process.poll.return_value = None
|
||||||
|
mock_process.wait.side_effect = subprocess.TimeoutExpired("cmd", 5)
|
||||||
|
mock_process.kill.return_value = None
|
||||||
|
|
||||||
|
state._gateway_process = mock_process
|
||||||
|
|
||||||
|
result = runtime_module._stop_gateway()
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
assert state._gateway_process is None
|
||||||
|
mock_process.kill.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# _build_gateway_ws_url helper tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_build_gateway_ws_url_defaults_to_ws():
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
mock_request.headers.get.side_effect = lambda k, d="": d
|
||||||
|
mock_request.url.scheme = "http"
|
||||||
|
mock_request.url.hostname = "localhost"
|
||||||
|
|
||||||
|
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
|
||||||
|
assert url == "ws://localhost:8765"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_gateway_ws_url_uses_wss_for_https():
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
mock_request.headers.get.side_effect = lambda k, d="": d
|
||||||
|
mock_request.url.scheme = "https"
|
||||||
|
mock_request.url.hostname = "example.com"
|
||||||
|
|
||||||
|
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
|
||||||
|
assert url == "wss://example.com:8765"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_gateway_ws_url_respects_forwarded_proto():
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
def header_get(key, default=""):
|
||||||
|
if key == "x-forwarded-proto":
|
||||||
|
return "https"
|
||||||
|
return default
|
||||||
|
mock_request.headers.get.side_effect = header_get
|
||||||
|
mock_request.url.scheme = "http"
|
||||||
|
mock_request.url.hostname = "internal.example"
|
||||||
|
|
||||||
|
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
|
||||||
|
assert url == "wss://internal.example:8765"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_gateway_ws_url_respects_forwarded_host():
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
mock_request = MagicMock(spec=Request)
|
||||||
|
mock_request.headers.get.side_effect = lambda k, d="": {
|
||||||
|
"x-forwarded-host": "external.example.com"
|
||||||
|
}.get(k, d)
|
||||||
|
mock_request.url.scheme = "http"
|
||||||
|
mock_request.url.hostname = "internal.example"
|
||||||
|
|
||||||
|
url = runtime_module._build_gateway_ws_url(mock_request, 8765)
|
||||||
|
assert url == "ws://external.example.com:8765"
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# _normalize_runtime_config_updates tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_runtime_config_updates_validates_schedule_mode():
|
||||||
|
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode="invalid")
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
runtime_module._normalize_runtime_config_updates(req)
|
||||||
|
assert "schedule_mode" in str(exc_info.value.detail).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_runtime_config_updates_validates_schedule_mode_values():
|
||||||
|
for invalid in ["weekly", "monthly", "once"]:
|
||||||
|
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode=invalid)
|
||||||
|
with pytest.raises(HTTPException):
|
||||||
|
runtime_module._normalize_runtime_config_updates(req)
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_runtime_config_updates_accepts_daily_and_intraday():
|
||||||
|
for valid in ["daily", "intraday", "DAILY", "IntraDay"]:
|
||||||
|
req = runtime_module.UpdateRuntimeConfigRequest(schedule_mode=valid)
|
||||||
|
result = runtime_module._normalize_runtime_config_updates(req)
|
||||||
|
assert "schedule_mode" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_runtime_config_updates_validates_trigger_time_format():
|
||||||
|
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time="25:99")
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
runtime_module._normalize_runtime_config_updates(req)
|
||||||
|
assert "trigger_time" in str(exc_info.value.detail).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_runtime_config_updates_accepts_now_trigger_time():
|
||||||
|
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time="now")
|
||||||
|
result = runtime_module._normalize_runtime_config_updates(req)
|
||||||
|
assert result["trigger_time"] == "now"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_runtime_config_updates_defaults_empty_trigger_time():
|
||||||
|
req = runtime_module.UpdateRuntimeConfigRequest(trigger_time=" ")
|
||||||
|
result = runtime_module._normalize_runtime_config_updates(req)
|
||||||
|
assert result["trigger_time"] == "09:30"
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_runtime_config_updates_rejects_no_updates():
|
||||||
|
req = runtime_module.UpdateRuntimeConfigRequest()
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
runtime_module._normalize_runtime_config_updates(req)
|
||||||
|
assert "no runtime config updates" in str(exc_info.value.detail).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_normalize_runtime_config_updates_coerces_types():
|
||||||
|
req = runtime_module.UpdateRuntimeConfigRequest(
|
||||||
|
schedule_mode="intraday",
|
||||||
|
interval_minutes="30", # string from JSON
|
||||||
|
initial_cash="50000.0", # string from JSON
|
||||||
|
margin_requirement="0.25",
|
||||||
|
)
|
||||||
|
result = runtime_module._normalize_runtime_config_updates(req)
|
||||||
|
assert result["schedule_mode"] == "intraday"
|
||||||
|
assert result["interval_minutes"] == 30
|
||||||
|
assert result["initial_cash"] == 50000.0
|
||||||
|
assert result["margin_requirement"] == 0.25
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# register_runtime_manager / unregister_runtime_manager tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_register_runtime_manager_sets_module_and_singleton():
|
||||||
|
runtime_module._runtime_state._initialized = True # prevent re-init
|
||||||
|
fake_manager = object()
|
||||||
|
|
||||||
|
runtime_module.register_runtime_manager(fake_manager)
|
||||||
|
|
||||||
|
assert runtime_module.runtime_manager is fake_manager
|
||||||
|
assert runtime_module._runtime_state.runtime_manager is fake_manager
|
||||||
|
|
||||||
|
|
||||||
|
def test_unregister_runtime_manager_clears_module_and_singleton():
|
||||||
|
runtime_module._runtime_state._initialized = True # prevent re-init
|
||||||
|
runtime_module._runtime_state.runtime_manager = object()
|
||||||
|
runtime_module.runtime_manager = object()
|
||||||
|
|
||||||
|
runtime_module.unregister_runtime_manager()
|
||||||
|
|
||||||
|
assert runtime_module.runtime_manager is None
|
||||||
|
assert runtime_module._runtime_state.runtime_manager is None
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# _generate_run_id tests
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def test_generate_run_id_returns_timestamp_format():
|
||||||
|
run_id = runtime_module._generate_run_id()
|
||||||
|
# Format: YYYYMMDD_HHMMSS - length is 15
|
||||||
|
assert len(run_id) == 15
|
||||||
|
assert run_id[8] == "_" # separator between date and time
|
||||||
|
assert run_id[:8].isdigit() # YYYYMMDD
|
||||||
|
assert run_id[9:].isdigit() # HHMMSS
|
||||||
|
|||||||
@@ -1,47 +1,21 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Unit tests for the trading domain helpers."""
|
"""Unit tests for data_tools functions (replaces the deleted trading_domain)."""
|
||||||
|
|
||||||
from backend.domains import trading as trading_domain
|
from backend.tools.data_tools import (
|
||||||
|
get_company_news,
|
||||||
|
get_financial_metrics,
|
||||||
|
get_insider_trades,
|
||||||
|
get_market_cap,
|
||||||
|
get_prices,
|
||||||
|
search_line_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_trading_domain_payload_wrappers(monkeypatch):
|
def test_data_tools_functions_exist():
|
||||||
monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}])
|
"""Verify that all data_tools functions are importable and callable."""
|
||||||
monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}])
|
assert callable(get_prices)
|
||||||
monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
|
assert callable(get_financial_metrics)
|
||||||
monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
|
assert callable(get_company_news)
|
||||||
monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12)
|
assert callable(get_insider_trades)
|
||||||
|
assert callable(get_market_cap)
|
||||||
assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == {
|
assert callable(search_line_items)
|
||||||
"ticker": "AAPL",
|
|
||||||
"prices": [{"close": 1}],
|
|
||||||
}
|
|
||||||
assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
|
|
||||||
"financial_metrics": [{"ticker": "AAPL"}],
|
|
||||||
}
|
|
||||||
assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
|
|
||||||
"news": [{"ticker": "AAPL"}],
|
|
||||||
}
|
|
||||||
assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
|
|
||||||
"insider_trades": [{"ticker": "AAPL"}],
|
|
||||||
}
|
|
||||||
assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
|
|
||||||
"ticker": "AAPL",
|
|
||||||
"end_date": "2026-03-16",
|
|
||||||
"market_cap": 2.5e12,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_market_status_payload_uses_market_service(monkeypatch):
|
|
||||||
class _FakeMarketService:
|
|
||||||
def __init__(self, tickers):
|
|
||||||
self.tickers = tickers
|
|
||||||
|
|
||||||
def get_market_status(self):
|
|
||||||
return {"status": "open", "status_text": "Open"}
|
|
||||||
|
|
||||||
monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)
|
|
||||||
|
|
||||||
assert trading_domain.get_market_status_payload() == {
|
|
||||||
"status": "open",
|
|
||||||
"status_text": "Open",
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -24,20 +24,17 @@ def test_trading_service_routes_are_exposed():
|
|||||||
|
|
||||||
def test_trading_service_prices_endpoint(monkeypatch):
|
def test_trading_service_prices_endpoint(monkeypatch):
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"backend.domains.trading.get_prices_payload",
|
"backend.apps.trading_service.get_prices",
|
||||||
lambda ticker, start_date, end_date: {
|
lambda ticker, start_date, end_date: [
|
||||||
"ticker": ticker,
|
Price(
|
||||||
"prices": [
|
open=1.0,
|
||||||
Price(
|
close=2.0,
|
||||||
open=1.0,
|
high=2.5,
|
||||||
close=2.0,
|
low=0.5,
|
||||||
high=2.5,
|
volume=100,
|
||||||
low=0.5,
|
time="2026-03-20",
|
||||||
volume=100,
|
)
|
||||||
time="2026-03-20",
|
],
|
||||||
)
|
|
||||||
],
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with TestClient(create_app()) as client:
|
with TestClient(create_app()) as client:
|
||||||
@@ -57,56 +54,54 @@ def test_trading_service_prices_endpoint(monkeypatch):
|
|||||||
|
|
||||||
def test_trading_service_financials_endpoint(monkeypatch):
|
def test_trading_service_financials_endpoint(monkeypatch):
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"backend.domains.trading.get_financials_payload",
|
"backend.apps.trading_service.get_financial_metrics",
|
||||||
lambda ticker, end_date, period, limit: {
|
lambda ticker, end_date, period, limit: [
|
||||||
"financial_metrics": [
|
FinancialMetrics(
|
||||||
FinancialMetrics(
|
ticker=ticker,
|
||||||
ticker=ticker,
|
report_period=end_date,
|
||||||
report_period=end_date,
|
period=period,
|
||||||
period=period,
|
currency="USD",
|
||||||
currency="USD",
|
market_cap=123.0,
|
||||||
market_cap=123.0,
|
enterprise_value=None,
|
||||||
enterprise_value=None,
|
price_to_earnings_ratio=None,
|
||||||
price_to_earnings_ratio=None,
|
price_to_book_ratio=None,
|
||||||
price_to_book_ratio=None,
|
price_to_sales_ratio=None,
|
||||||
price_to_sales_ratio=None,
|
enterprise_value_to_ebitda_ratio=None,
|
||||||
enterprise_value_to_ebitda_ratio=None,
|
enterprise_value_to_revenue_ratio=None,
|
||||||
enterprise_value_to_revenue_ratio=None,
|
free_cash_flow_yield=None,
|
||||||
free_cash_flow_yield=None,
|
peg_ratio=None,
|
||||||
peg_ratio=None,
|
gross_margin=None,
|
||||||
gross_margin=None,
|
operating_margin=None,
|
||||||
operating_margin=None,
|
net_margin=None,
|
||||||
net_margin=None,
|
return_on_equity=None,
|
||||||
return_on_equity=None,
|
return_on_assets=None,
|
||||||
return_on_assets=None,
|
return_on_invested_capital=None,
|
||||||
return_on_invested_capital=None,
|
asset_turnover=None,
|
||||||
asset_turnover=None,
|
inventory_turnover=None,
|
||||||
inventory_turnover=None,
|
receivables_turnover=None,
|
||||||
receivables_turnover=None,
|
days_sales_outstanding=None,
|
||||||
days_sales_outstanding=None,
|
operating_cycle=None,
|
||||||
operating_cycle=None,
|
working_capital_turnover=None,
|
||||||
working_capital_turnover=None,
|
current_ratio=None,
|
||||||
current_ratio=None,
|
quick_ratio=None,
|
||||||
quick_ratio=None,
|
cash_ratio=None,
|
||||||
cash_ratio=None,
|
operating_cash_flow_ratio=None,
|
||||||
operating_cash_flow_ratio=None,
|
debt_to_equity=None,
|
||||||
debt_to_equity=None,
|
debt_to_assets=None,
|
||||||
debt_to_assets=None,
|
interest_coverage=None,
|
||||||
interest_coverage=None,
|
revenue_growth=None,
|
||||||
revenue_growth=None,
|
earnings_growth=None,
|
||||||
earnings_growth=None,
|
book_value_growth=None,
|
||||||
book_value_growth=None,
|
earnings_per_share_growth=None,
|
||||||
earnings_per_share_growth=None,
|
free_cash_flow_growth=None,
|
||||||
free_cash_flow_growth=None,
|
operating_income_growth=None,
|
||||||
operating_income_growth=None,
|
ebitda_growth=None,
|
||||||
ebitda_growth=None,
|
payout_ratio=None,
|
||||||
payout_ratio=None,
|
earnings_per_share=None,
|
||||||
earnings_per_share=None,
|
book_value_per_share=None,
|
||||||
book_value_per_share=None,
|
free_cash_flow_per_share=None,
|
||||||
free_cash_flow_per_share=None,
|
)
|
||||||
)
|
],
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with TestClient(create_app()) as client:
|
with TestClient(create_app()) as client:
|
||||||
@@ -121,26 +116,22 @@ def test_trading_service_financials_endpoint(monkeypatch):
|
|||||||
|
|
||||||
def test_trading_service_news_and_insider_endpoints(monkeypatch):
|
def test_trading_service_news_and_insider_endpoints(monkeypatch):
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"backend.domains.trading.get_news_payload",
|
"backend.apps.trading_service.get_company_news",
|
||||||
lambda ticker, end_date, start_date=None, limit=1000: {
|
lambda ticker, end_date, start_date=None, limit=1000: [
|
||||||
"news": [
|
CompanyNews(
|
||||||
CompanyNews(
|
ticker=ticker,
|
||||||
ticker=ticker,
|
title="News title",
|
||||||
title="News title",
|
source="polygon",
|
||||||
source="polygon",
|
url="https://example.com/news",
|
||||||
url="https://example.com/news",
|
date=end_date,
|
||||||
date=end_date,
|
)
|
||||||
)
|
],
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"backend.domains.trading.get_insider_trades_payload",
|
"backend.apps.trading_service.get_insider_trades",
|
||||||
lambda ticker, end_date, start_date=None, limit=1000: {
|
lambda ticker, end_date, start_date=None, limit=1000: [
|
||||||
"insider_trades": [
|
InsiderTrade(ticker=ticker, filing_date=end_date)
|
||||||
InsiderTrade(ticker=ticker, filing_date=end_date)
|
],
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with TestClient(create_app()) as client:
|
with TestClient(create_app()) as client:
|
||||||
@@ -165,8 +156,8 @@ def test_trading_service_market_status_endpoint(monkeypatch):
|
|||||||
return {"status": "open", "status_text": "Open"}
|
return {"status": "open", "status_text": "Open"}
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"backend.domains.trading.get_market_status_payload",
|
"backend.apps.trading_service.MarketService",
|
||||||
lambda: _FakeMarketService().get_market_status(),
|
lambda tickers: _FakeMarketService(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with TestClient(create_app()) as client:
|
with TestClient(create_app()) as client:
|
||||||
@@ -178,12 +169,8 @@ def test_trading_service_market_status_endpoint(monkeypatch):
|
|||||||
|
|
||||||
def test_trading_service_market_cap_endpoint(monkeypatch):
|
def test_trading_service_market_cap_endpoint(monkeypatch):
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"backend.domains.trading.get_market_cap_payload",
|
"backend.apps.trading_service.get_market_cap",
|
||||||
lambda ticker, end_date: {
|
lambda ticker, end_date: 3.5e12,
|
||||||
"ticker": ticker,
|
|
||||||
"end_date": end_date,
|
|
||||||
"market_cap": 3.5e12,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with TestClient(create_app()) as client:
|
with TestClient(create_app()) as client:
|
||||||
@@ -202,18 +189,16 @@ def test_trading_service_market_cap_endpoint(monkeypatch):
|
|||||||
|
|
||||||
def test_trading_service_line_items_endpoint(monkeypatch):
|
def test_trading_service_line_items_endpoint(monkeypatch):
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"backend.domains.trading.get_line_items_payload",
|
"backend.apps.trading_service.search_line_items",
|
||||||
lambda ticker, line_items, end_date, period, limit: {
|
lambda ticker, line_items, end_date, period, limit: [
|
||||||
"search_results": [
|
LineItem(
|
||||||
LineItem(
|
ticker=ticker,
|
||||||
ticker=ticker,
|
report_period=end_date,
|
||||||
report_period=end_date,
|
period=period,
|
||||||
period=period,
|
currency="USD",
|
||||||
currency="USD",
|
free_cash_flow=123.0,
|
||||||
free_cash_flow=123.0,
|
)
|
||||||
)
|
],
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with TestClient(create_app()) as client:
|
with TestClient(create_app()) as client:
|
||||||
|
|||||||
2931
frontend/src/App.jsx
2931
frontend/src/App.jsx
File diff suppressed because it is too large
Load Diff
18
frontend/src/components/ChartTabs.jsx
Normal file
18
frontend/src/components/ChartTabs.jsx
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import React from 'react';
|
||||||
|
|
||||||
|
export default function ChartTabs({
|
||||||
|
chartTab,
|
||||||
|
setChartTab,
|
||||||
|
isLiveEnabled
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="chart-tabs-floating">
|
||||||
|
<button
|
||||||
|
className={`chart-tab ${chartTab === 'all' ? 'active' : ''}`}
|
||||||
|
onClick={() => setChartTab('all')}
|
||||||
|
>
|
||||||
|
日线
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
293
frontend/src/components/HeaderRight.jsx
Normal file
293
frontend/src/components/HeaderRight.jsx
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import RuntimeSettingsPanel from './RuntimeSettingsPanel.jsx';
|
||||||
|
|
||||||
|
export default function HeaderRight({
|
||||||
|
// Connection state
|
||||||
|
isConnected,
|
||||||
|
// Virtual time
|
||||||
|
virtualTime,
|
||||||
|
now,
|
||||||
|
// Market & server
|
||||||
|
marketStatus,
|
||||||
|
marketStatusLabel,
|
||||||
|
serverMode,
|
||||||
|
// Labels
|
||||||
|
runtimeSummaryLabel,
|
||||||
|
livePriceSourceLabel,
|
||||||
|
historicalPriceSourceLabel,
|
||||||
|
// Settings state
|
||||||
|
isRuntimeSettingsOpen,
|
||||||
|
isRuntimeConfigSaving,
|
||||||
|
isWatchlistSaving,
|
||||||
|
runtimeConfigFeedback,
|
||||||
|
watchlistFeedback,
|
||||||
|
// Settings panel props
|
||||||
|
scheduleModeDraft,
|
||||||
|
intervalMinutesDraft,
|
||||||
|
triggerTimeDraft,
|
||||||
|
maxCommCyclesDraft,
|
||||||
|
initialCashDraft,
|
||||||
|
marginRequirementDraft,
|
||||||
|
enableMemoryDraft,
|
||||||
|
modeDraft,
|
||||||
|
pollIntervalDraft,
|
||||||
|
startDateDraft,
|
||||||
|
endDateDraft,
|
||||||
|
enableMockDraft,
|
||||||
|
watchlistDraftSymbols,
|
||||||
|
watchlistInputValue,
|
||||||
|
watchlistSuggestions,
|
||||||
|
// Callbacks
|
||||||
|
onRuntimeSettingsToggle,
|
||||||
|
onCloseSettings,
|
||||||
|
onScheduleModeChange,
|
||||||
|
onIntervalMinutesChange,
|
||||||
|
onTriggerTimeChange,
|
||||||
|
onMaxCommCyclesChange,
|
||||||
|
onInitialCashChange,
|
||||||
|
onMarginRequirementChange,
|
||||||
|
onEnableMemoryChange,
|
||||||
|
onModeChange,
|
||||||
|
onPollIntervalChange,
|
||||||
|
onStartDateChange,
|
||||||
|
onEndDateChange,
|
||||||
|
onEnableMockChange,
|
||||||
|
onWatchlistInputChange,
|
||||||
|
onWatchlistInputKeyDown,
|
||||||
|
onWatchlistAdd,
|
||||||
|
onWatchlistRemove,
|
||||||
|
onWatchlistRestoreCurrent,
|
||||||
|
onWatchlistRestoreDefault,
|
||||||
|
onWatchlistSuggestionClick,
|
||||||
|
onLaunchConfigSave,
|
||||||
|
onRestoreDefaults,
|
||||||
|
onManualTrigger,
|
||||||
|
clientRef
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="header-right" style={{ display: 'flex', alignItems: 'center', gap: 24, marginLeft: 'auto', flexWrap: 'wrap', minWidth: 0 }}>
|
||||||
|
{/* Mock Mode Indicator */}
|
||||||
|
{virtualTime && (
|
||||||
|
<div style={{
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: 6,
|
||||||
|
padding: '4px 10px',
|
||||||
|
borderRadius: 4,
|
||||||
|
background: '#FF9800',
|
||||||
|
border: '1px solid #FFB74D'
|
||||||
|
}}>
|
||||||
|
<span style={{
|
||||||
|
fontSize: '11px',
|
||||||
|
fontWeight: 600,
|
||||||
|
color: '#FFFFFF',
|
||||||
|
fontFamily: '"Courier New", monospace',
|
||||||
|
letterSpacing: '0.5px'
|
||||||
|
}}>
|
||||||
|
模拟模式
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Clock Display (only in Mock mode) */}
|
||||||
|
{virtualTime && (
|
||||||
|
<div style={{
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: 8
|
||||||
|
}}>
|
||||||
|
<div style={{
|
||||||
|
display: 'flex',
|
||||||
|
flexDirection: 'column',
|
||||||
|
alignItems: 'flex-end',
|
||||||
|
gap: 2,
|
||||||
|
padding: '4px 12px',
|
||||||
|
borderRadius: 4,
|
||||||
|
background: '#1A237E',
|
||||||
|
border: '1px solid #3F51B5'
|
||||||
|
}}>
|
||||||
|
<span style={{
|
||||||
|
fontSize: '11px',
|
||||||
|
color: '#999',
|
||||||
|
fontFamily: '"Courier New", monospace',
|
||||||
|
textTransform: 'uppercase',
|
||||||
|
letterSpacing: '0.5px'
|
||||||
|
}}>
|
||||||
|
虚拟时间
|
||||||
|
</span>
|
||||||
|
<span style={{
|
||||||
|
fontSize: '14px',
|
||||||
|
fontWeight: 700,
|
||||||
|
color: '#FFFFFF',
|
||||||
|
fontFamily: '"Courier New", monospace',
|
||||||
|
letterSpacing: '1px'
|
||||||
|
}}>
|
||||||
|
{now.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false })}
|
||||||
|
</span>
|
||||||
|
<span style={{
|
||||||
|
fontSize: '10px',
|
||||||
|
color: '#999',
|
||||||
|
fontFamily: '"Courier New", monospace'
|
||||||
|
}}>
|
||||||
|
{now.toLocaleDateString('en-US', { month: 'short', day: 'numeric', year: 'numeric' })}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Fast Forward Button (only in Mock mode) */}
|
||||||
|
<button
|
||||||
|
onClick={() => {
|
||||||
|
if (clientRef.current) {
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'fast_forward_time',
|
||||||
|
minutes: 30
|
||||||
|
});
|
||||||
|
if (!success) {
|
||||||
|
console.error('Failed to send fast forward request');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
style={{
|
||||||
|
padding: '6px 12px',
|
||||||
|
borderRadius: 4,
|
||||||
|
background: '#3F51B5',
|
||||||
|
border: '1px solid #5C6BC0',
|
||||||
|
color: '#FFFFFF',
|
||||||
|
fontSize: '12px',
|
||||||
|
fontFamily: '"Courier New", monospace',
|
||||||
|
fontWeight: 600,
|
||||||
|
cursor: 'pointer',
|
||||||
|
transition: 'all 0.2s',
|
||||||
|
display: 'flex',
|
||||||
|
alignItems: 'center',
|
||||||
|
gap: 4,
|
||||||
|
textTransform: 'uppercase',
|
||||||
|
letterSpacing: '0.5px'
|
||||||
|
}}
|
||||||
|
onMouseEnter={(e) => {
|
||||||
|
e.target.style.background = '#5C6BC0';
|
||||||
|
e.target.style.borderColor = '#7986CB';
|
||||||
|
}}
|
||||||
|
onMouseLeave={(e) => {
|
||||||
|
e.target.style.background = '#3F51B5';
|
||||||
|
e.target.style.borderColor = '#5C6BC0';
|
||||||
|
}}
|
||||||
|
title="快进30分钟 (Mock模式)"
|
||||||
|
>
|
||||||
|
+30min
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Unified Status Indicator */}
|
||||||
|
<div className="header-status-inline">
|
||||||
|
<span className={`status-dot ${isConnected ? 'live' : 'offline'}`} />
|
||||||
|
<span className={`status-text ${isConnected ? 'live' : 'offline'}`}>
|
||||||
|
{isConnected ? '在线' : '离线'}
|
||||||
|
</span>
|
||||||
|
{marketStatus && (
|
||||||
|
<>
|
||||||
|
<span className="status-sep">·</span>
|
||||||
|
<span className={`market-text ${serverMode === 'backtest' ? 'backtest' : (marketStatus.status === 'open' ? 'open' : 'closed')}`}>
|
||||||
|
{marketStatusLabel}
|
||||||
|
</span>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{livePriceSourceLabel && (
|
||||||
|
<>
|
||||||
|
<span className="status-sep">·</span>
|
||||||
|
<span className="market-text backtest">
|
||||||
|
{livePriceSourceLabel}
|
||||||
|
</span>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{historicalPriceSourceLabel && (
|
||||||
|
<>
|
||||||
|
<span className="status-sep">·</span>
|
||||||
|
<span className="market-text backtest">
|
||||||
|
{historicalPriceSourceLabel}
|
||||||
|
</span>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
{runtimeSummaryLabel && (
|
||||||
|
<>
|
||||||
|
<span className="status-sep">·</span>
|
||||||
|
<span className="market-text backtest" title="当前运行配置">
|
||||||
|
{runtimeSummaryLabel}
|
||||||
|
</span>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
<span className="status-sep">·</span>
|
||||||
|
<span className="time-text">{now.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false })}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{serverMode !== 'backtest' && (
|
||||||
|
<button
|
||||||
|
onClick={onManualTrigger}
|
||||||
|
disabled={!isConnected}
|
||||||
|
style={{
|
||||||
|
padding: '6px 12px',
|
||||||
|
borderRadius: 4,
|
||||||
|
background: isConnected ? '#111111' : '#8a8a8a',
|
||||||
|
border: '1px solid #111111',
|
||||||
|
color: '#FFFFFF',
|
||||||
|
fontSize: '11px',
|
||||||
|
fontFamily: '"Courier New", monospace',
|
||||||
|
fontWeight: 700,
|
||||||
|
cursor: isConnected ? 'pointer' : 'not-allowed',
|
||||||
|
letterSpacing: '0.4px',
|
||||||
|
textTransform: 'uppercase'
|
||||||
|
}}
|
||||||
|
title="手动触发一轮分析与交易决策"
|
||||||
|
>
|
||||||
|
手动运行
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<RuntimeSettingsPanel
|
||||||
|
showTrigger={false}
|
||||||
|
isOpen={isRuntimeSettingsOpen}
|
||||||
|
isConnected={isConnected}
|
||||||
|
isSaving={isRuntimeConfigSaving || isWatchlistSaving}
|
||||||
|
feedback={runtimeConfigFeedback || watchlistFeedback}
|
||||||
|
scheduleMode={scheduleModeDraft}
|
||||||
|
intervalMinutes={intervalMinutesDraft}
|
||||||
|
triggerTime={triggerTimeDraft}
|
||||||
|
maxCommCycles={maxCommCyclesDraft}
|
||||||
|
initialCash={initialCashDraft}
|
||||||
|
marginRequirement={marginRequirementDraft}
|
||||||
|
enableMemory={enableMemoryDraft}
|
||||||
|
mode={modeDraft}
|
||||||
|
pollInterval={pollIntervalDraft}
|
||||||
|
startDate={startDateDraft}
|
||||||
|
endDate={endDateDraft}
|
||||||
|
enableMock={enableMockDraft}
|
||||||
|
watchlistSymbols={watchlistDraftSymbols}
|
||||||
|
watchlistInputValue={watchlistInputValue}
|
||||||
|
watchlistSuggestions={watchlistSuggestions}
|
||||||
|
onToggle={onRuntimeSettingsToggle}
|
||||||
|
onClose={onCloseSettings}
|
||||||
|
onScheduleModeChange={onScheduleModeChange}
|
||||||
|
onIntervalMinutesChange={onIntervalMinutesChange}
|
||||||
|
onTriggerTimeChange={onTriggerTimeChange}
|
||||||
|
onMaxCommCyclesChange={onMaxCommCyclesChange}
|
||||||
|
onInitialCashChange={onInitialCashChange}
|
||||||
|
onMarginRequirementChange={onMarginRequirementChange}
|
||||||
|
onEnableMemoryChange={onEnableMemoryChange}
|
||||||
|
onModeChange={onModeChange}
|
||||||
|
onPollIntervalChange={onPollIntervalChange}
|
||||||
|
onStartDateChange={onStartDateChange}
|
||||||
|
onEndDateChange={onEndDateChange}
|
||||||
|
onEnableMockChange={onEnableMockChange}
|
||||||
|
onWatchlistInputChange={onWatchlistInputChange}
|
||||||
|
onWatchlistInputKeyDown={onWatchlistInputKeyDown}
|
||||||
|
onWatchlistAdd={onWatchlistAdd}
|
||||||
|
onWatchlistRemove={onWatchlistRemove}
|
||||||
|
onWatchlistRestoreCurrent={onWatchlistRestoreCurrent}
|
||||||
|
onWatchlistRestoreDefault={onWatchlistRestoreDefault}
|
||||||
|
onWatchlistSuggestionClick={onWatchlistSuggestionClick}
|
||||||
|
onSave={onLaunchConfigSave}
|
||||||
|
onRestoreDefaults={onRestoreDefaults}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
52
frontend/src/components/TickerBar.jsx
Normal file
52
frontend/src/components/TickerBar.jsx
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
import React from 'react';
|
||||||
|
import StockLogo from './StockLogo';
|
||||||
|
import { formatNumber, formatTickerPrice } from '../utils/formatters';
|
||||||
|
|
||||||
|
export default function TickerBar({
|
||||||
|
displayTickers,
|
||||||
|
rollingTickers,
|
||||||
|
portfolioData,
|
||||||
|
onTickerSelect
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="ticker-bar">
|
||||||
|
<div className="ticker-track">
|
||||||
|
{[0, 1].map((groupIdx) => (
|
||||||
|
<div key={groupIdx} className="ticker-group">
|
||||||
|
{displayTickers.map(ticker => (
|
||||||
|
<div
|
||||||
|
key={`${ticker.symbol}-${groupIdx}`}
|
||||||
|
className="ticker-item"
|
||||||
|
onClick={() => onTickerSelect && onTickerSelect(ticker.symbol)}
|
||||||
|
style={{ cursor: onTickerSelect ? 'pointer' : 'default' }}
|
||||||
|
>
|
||||||
|
<StockLogo ticker={ticker.symbol} size={16} />
|
||||||
|
<span className="ticker-symbol">{ticker.symbol}</span>
|
||||||
|
<span className="ticker-price">
|
||||||
|
<span className={`ticker-price-value ${rollingTickers[ticker.symbol] ? 'rolling' : ''}`}>
|
||||||
|
{ticker.price !== null && ticker.price !== undefined
|
||||||
|
? `$${formatTickerPrice(ticker.price)}`
|
||||||
|
: '-'}
|
||||||
|
</span>
|
||||||
|
</span>
|
||||||
|
<span className={`ticker-change ${
|
||||||
|
ticker.change === null || ticker.change === undefined
|
||||||
|
? ''
|
||||||
|
: ticker.change >= 0 ? 'positive' : 'negative'
|
||||||
|
}`}>
|
||||||
|
{ticker.change !== null && ticker.change !== undefined
|
||||||
|
? `${ticker.change >= 0 ? '+' : ''}${ticker.change.toFixed(2)}%`
|
||||||
|
: '-'}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<div className="portfolio-value">
|
||||||
|
<span className="portfolio-label">投资组合</span>
|
||||||
|
<span className="portfolio-amount">${formatNumber(portfolioData.netValue)}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
308
frontend/src/hooks/useAgentCallbacks.js
Normal file
308
frontend/src/hooks/useAgentCallbacks.js
Normal file
@@ -0,0 +1,308 @@
|
|||||||
|
import { useCallback, useEffect } from 'react';
|
||||||
|
import { uploadAgentSkillZip } from '../services/runtimeApi';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts agent/skill-related callbacks from App.jsx into a single hook.
|
||||||
|
*/
|
||||||
|
export function useAgentCallbacks({
|
||||||
|
clientRef,
|
||||||
|
selectedSkillAgentId,
|
||||||
|
selectedWorkspaceFile,
|
||||||
|
workspaceDraftContent,
|
||||||
|
agentProfilesByAgent,
|
||||||
|
agentSkillsByAgent,
|
||||||
|
workspaceFilesByAgent,
|
||||||
|
AGENTS,
|
||||||
|
setters
|
||||||
|
}) {
|
||||||
|
const {
|
||||||
|
setIsAgentSkillsLoading,
|
||||||
|
setAgentSkillsFeedback,
|
||||||
|
setSkillDetailLoadingKey,
|
||||||
|
setAgentSkillsSavingKey,
|
||||||
|
setIsWorkspaceFileLoading,
|
||||||
|
setWorkspaceFileSavingKey,
|
||||||
|
setWorkspaceFileFeedback,
|
||||||
|
setLocalSkillDraftsByKey,
|
||||||
|
setAgentSkillsByAgent,
|
||||||
|
setAgentProfilesByAgent,
|
||||||
|
setSkillDetailsByName,
|
||||||
|
setWorkspaceFilesByAgent,
|
||||||
|
setSelectedSkillAgentId,
|
||||||
|
setSelectedWorkspaceFile,
|
||||||
|
setWorkspaceDraftContent
|
||||||
|
} = setters;
|
||||||
|
|
||||||
|
const requestAgentSkills = useCallback((agentId) => {
|
||||||
|
const normalized = typeof agentId === 'string' ? agentId.trim() : '';
|
||||||
|
if (!normalized || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
setIsAgentSkillsLoading(true);
|
||||||
|
setAgentSkillsFeedback(null);
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_agent_skills',
|
||||||
|
agent_id: normalized
|
||||||
|
});
|
||||||
|
}, [clientRef, setIsAgentSkillsLoading, setAgentSkillsFeedback]);
|
||||||
|
|
||||||
|
const requestAgentProfile = useCallback((agentId) => {
|
||||||
|
const normalized = typeof agentId === 'string' ? agentId.trim() : '';
|
||||||
|
if (!normalized || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_agent_profile',
|
||||||
|
agent_id: normalized
|
||||||
|
});
|
||||||
|
}, [clientRef]);
|
||||||
|
|
||||||
|
const requestSkillDetail = useCallback((skillName) => {
|
||||||
|
const normalized = typeof skillName === 'string' ? skillName.trim() : '';
|
||||||
|
if (!normalized || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const detailKey = `${selectedSkillAgentId}:${normalized}`;
|
||||||
|
setSkillDetailLoadingKey(detailKey);
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_skill_detail',
|
||||||
|
agent_id: selectedSkillAgentId,
|
||||||
|
skill_name: normalized
|
||||||
|
});
|
||||||
|
}, [clientRef, selectedSkillAgentId, setSkillDetailLoadingKey]);
|
||||||
|
|
||||||
|
const requestWorkspaceFile = useCallback((agentId, filename) => {
|
||||||
|
const normalizedAgentId = typeof agentId === 'string' ? agentId.trim() : '';
|
||||||
|
const normalizedFilename = typeof filename === 'string' ? filename.trim() : '';
|
||||||
|
if (!normalizedAgentId || !normalizedFilename || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
setIsWorkspaceFileLoading(true);
|
||||||
|
setWorkspaceFileFeedback(null);
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_agent_workspace_file',
|
||||||
|
agent_id: normalizedAgentId,
|
||||||
|
filename: normalizedFilename
|
||||||
|
});
|
||||||
|
}, [clientRef, setIsWorkspaceFileLoading, setWorkspaceFileFeedback]);
|
||||||
|
|
||||||
|
const handleCreateLocalSkill = useCallback((skillName) => {
|
||||||
|
const normalized = typeof skillName === 'string' ? skillName.trim() : '';
|
||||||
|
if (!normalized) {
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '技能名称不能为空' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!clientRef.current) {
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${normalized}:create`);
|
||||||
|
setAgentSkillsFeedback(null);
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'create_agent_local_skill',
|
||||||
|
agent_id: selectedSkillAgentId,
|
||||||
|
skill_name: normalized
|
||||||
|
});
|
||||||
|
if (!success) {
|
||||||
|
setAgentSkillsSavingKey(null);
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||||
|
}
|
||||||
|
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||||
|
|
||||||
|
const handleLocalSkillDraftChange = useCallback((skillName, content) => {
|
||||||
|
const detailKey = `${selectedSkillAgentId}:${skillName}`;
|
||||||
|
setLocalSkillDraftsByKey((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[detailKey]: content
|
||||||
|
}));
|
||||||
|
}, [selectedSkillAgentId, setLocalSkillDraftsByKey]);
|
||||||
|
|
||||||
|
const handleLocalSkillSave = useCallback((skillName) => {
|
||||||
|
if (!clientRef.current) {
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const detailKey = `${selectedSkillAgentId}:${skillName}`;
|
||||||
|
const content = setters.localSkillDraftsByKey[detailKey];
|
||||||
|
if (typeof content !== 'string') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:content`);
|
||||||
|
setAgentSkillsFeedback(null);
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'update_agent_local_skill',
|
||||||
|
agent_id: selectedSkillAgentId,
|
||||||
|
skill_name: skillName,
|
||||||
|
content
|
||||||
|
});
|
||||||
|
if (!success) {
|
||||||
|
setAgentSkillsSavingKey(null);
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||||
|
}
|
||||||
|
}, [clientRef, selectedSkillAgentId, setters.localSkillDraftsByKey, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||||
|
|
||||||
|
const handleLocalSkillDelete = useCallback((skillName) => {
|
||||||
|
if (!clientRef.current) {
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:delete`);
|
||||||
|
setAgentSkillsFeedback(null);
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'delete_agent_local_skill',
|
||||||
|
agent_id: selectedSkillAgentId,
|
||||||
|
skill_name: skillName
|
||||||
|
});
|
||||||
|
if (!success) {
|
||||||
|
setAgentSkillsSavingKey(null);
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||||
|
}
|
||||||
|
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||||
|
|
||||||
|
const handleRemoveSharedSkill = useCallback((skillName) => {
|
||||||
|
if (!clientRef.current) {
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}:remove`);
|
||||||
|
setAgentSkillsFeedback(null);
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'remove_agent_skill',
|
||||||
|
agent_id: selectedSkillAgentId,
|
||||||
|
skill_name: skillName
|
||||||
|
});
|
||||||
|
if (!success) {
|
||||||
|
setAgentSkillsSavingKey(null);
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||||
|
}
|
||||||
|
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||||
|
|
||||||
|
const handleAgentSkillToggle = useCallback((skillName, enabled) => {
|
||||||
|
if (!clientRef.current) {
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setAgentSkillsSavingKey(`${selectedSkillAgentId}:${skillName}`);
|
||||||
|
setAgentSkillsFeedback(null);
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'update_agent_skill',
|
||||||
|
agent_id: selectedSkillAgentId,
|
||||||
|
skill_name: skillName,
|
||||||
|
enabled
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
setAgentSkillsSavingKey(null);
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||||
|
}
|
||||||
|
}, [clientRef, selectedSkillAgentId, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||||
|
|
||||||
|
const handleSkillAgentChange = useCallback((agentId) => {
|
||||||
|
setSelectedSkillAgentId(agentId);
|
||||||
|
requestAgentProfile(agentId);
|
||||||
|
requestAgentSkills(agentId);
|
||||||
|
requestWorkspaceFile(agentId, selectedWorkspaceFile);
|
||||||
|
}, [requestAgentProfile, requestAgentSkills, requestWorkspaceFile, selectedWorkspaceFile, setSelectedSkillAgentId]);
|
||||||
|
|
||||||
|
const handleWorkspaceFileChange = useCallback((filename) => {
|
||||||
|
setSelectedWorkspaceFile(filename);
|
||||||
|
requestWorkspaceFile(selectedSkillAgentId, filename);
|
||||||
|
}, [requestWorkspaceFile, selectedSkillAgentId, setSelectedWorkspaceFile]);
|
||||||
|
|
||||||
|
const handleWorkspaceFileSave = useCallback(() => {
|
||||||
|
if (!clientRef.current) {
|
||||||
|
setWorkspaceFileFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const key = `${selectedSkillAgentId}:${selectedWorkspaceFile}`;
|
||||||
|
setWorkspaceFileSavingKey(key);
|
||||||
|
setWorkspaceFileFeedback(null);
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'update_agent_workspace_file',
|
||||||
|
agent_id: selectedSkillAgentId,
|
||||||
|
filename: selectedWorkspaceFile,
|
||||||
|
content: workspaceDraftContent
|
||||||
|
});
|
||||||
|
if (!success) {
|
||||||
|
setWorkspaceFileSavingKey(null);
|
||||||
|
setWorkspaceFileFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||||
|
}
|
||||||
|
}, [clientRef, selectedSkillAgentId, selectedWorkspaceFile, workspaceDraftContent, setWorkspaceFileSavingKey, setWorkspaceFileFeedback]);
|
||||||
|
|
||||||
|
const handleUploadExternalSkill = useCallback(async (file) => {
|
||||||
|
if (!(file instanceof File)) {
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '请选择 zip 文件后再上传' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!selectedSkillAgentId) {
|
||||||
|
setAgentSkillsFeedback({ type: 'error', text: '未选择目标 Agent' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setAgentSkillsSavingKey(`${selectedSkillAgentId}:__upload__`);
|
||||||
|
setAgentSkillsFeedback(null);
|
||||||
|
try {
|
||||||
|
const result = await uploadAgentSkillZip({
|
||||||
|
agentId: selectedSkillAgentId,
|
||||||
|
file,
|
||||||
|
activate: true
|
||||||
|
});
|
||||||
|
setAgentSkillsFeedback({
|
||||||
|
type: 'success',
|
||||||
|
text: `已上传并安装技能 ${result.skill_name || ''}`.trim()
|
||||||
|
});
|
||||||
|
requestAgentSkills(selectedSkillAgentId);
|
||||||
|
} catch (error) {
|
||||||
|
setAgentSkillsFeedback({
|
||||||
|
type: 'error',
|
||||||
|
text: `上传失败: ${error.message || '未知错误'}`
|
||||||
|
});
|
||||||
|
} finally {
|
||||||
|
setAgentSkillsSavingKey(null);
|
||||||
|
}
|
||||||
|
}, [selectedSkillAgentId, requestAgentSkills, setAgentSkillsSavingKey, setAgentSkillsFeedback]);
|
||||||
|
|
||||||
|
// Sync workspace draft content when selected content changes
|
||||||
|
useEffect(() => {
|
||||||
|
const selectedWorkspaceContent = workspaceFilesByAgent[selectedSkillAgentId]?.[selectedWorkspaceFile] || '';
|
||||||
|
setWorkspaceDraftContent(selectedWorkspaceContent);
|
||||||
|
}, [selectedWorkspaceFile, selectedSkillAgentId, workspaceFilesByAgent, setWorkspaceDraftContent]);
|
||||||
|
|
||||||
|
// Load agent profiles and skills when view changes
|
||||||
|
const currentView = setters.currentView;
|
||||||
|
const isConnected = setters.isConnected;
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (currentView !== 'traders' || !isConnected) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
AGENTS.forEach((agent) => {
|
||||||
|
if (!agentProfilesByAgent[agent.id]) {
|
||||||
|
requestAgentProfile(agent.id);
|
||||||
|
}
|
||||||
|
if (!agentSkillsByAgent[agent.id]) {
|
||||||
|
requestAgentSkills(agent.id);
|
||||||
|
}
|
||||||
|
if (!workspaceFilesByAgent[agent.id]?.['MEMORY.md']) {
|
||||||
|
requestWorkspaceFile(agent.id, 'MEMORY.md');
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}, [agentProfilesByAgent, agentSkillsByAgent, currentView, isConnected, requestAgentProfile, requestAgentSkills, requestWorkspaceFile, workspaceFilesByAgent, AGENTS]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
requestAgentSkills,
|
||||||
|
requestAgentProfile,
|
||||||
|
requestSkillDetail,
|
||||||
|
requestWorkspaceFile,
|
||||||
|
handleCreateLocalSkill,
|
||||||
|
handleLocalSkillDraftChange,
|
||||||
|
handleLocalSkillSave,
|
||||||
|
handleLocalSkillDelete,
|
||||||
|
handleRemoveSharedSkill,
|
||||||
|
handleAgentSkillToggle,
|
||||||
|
handleSkillAgentChange,
|
||||||
|
handleWorkspaceFileChange,
|
||||||
|
handleWorkspaceFileSave,
|
||||||
|
handleUploadExternalSkill
|
||||||
|
};
|
||||||
|
}
|
||||||
257
frontend/src/hooks/useRuntimeCallbacks.js
Normal file
257
frontend/src/hooks/useRuntimeCallbacks.js
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
import { useCallback } from 'react';
|
||||||
|
import { startRuntime } from '../services/runtimeApi';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts runtime config callbacks from App.jsx into a single hook.
|
||||||
|
*/
|
||||||
|
export function useRuntimeCallbacks({
|
||||||
|
clientRef,
|
||||||
|
addSystemMessage,
|
||||||
|
parseWatchlistInput,
|
||||||
|
setters
|
||||||
|
}) {
|
||||||
|
const {
|
||||||
|
setScheduleModeDraft,
|
||||||
|
setIntervalMinutesDraft,
|
||||||
|
setTriggerTimeDraft,
|
||||||
|
setMaxCommCyclesDraft,
|
||||||
|
setInitialCashDraft,
|
||||||
|
setMarginRequirementDraft,
|
||||||
|
setEnableMemoryDraft,
|
||||||
|
setModeDraft,
|
||||||
|
setPollIntervalDraft,
|
||||||
|
setStartDateDraft,
|
||||||
|
setEndDateDraft,
|
||||||
|
setEnableMockDraft,
|
||||||
|
setRuntimeConfigFeedback,
|
||||||
|
setIsRuntimeConfigSaving,
|
||||||
|
setIsWatchlistSaving,
|
||||||
|
setIsRuntimeSettingsOpen,
|
||||||
|
watchlistDraftSymbols,
|
||||||
|
watchlistInputValue,
|
||||||
|
scheduleModeDraft,
|
||||||
|
intervalMinutesDraft,
|
||||||
|
maxCommCyclesDraft,
|
||||||
|
initialCashDraft,
|
||||||
|
marginRequirementDraft,
|
||||||
|
enableMemoryDraft,
|
||||||
|
modeDraft,
|
||||||
|
pollIntervalDraft,
|
||||||
|
startDateDraft,
|
||||||
|
endDateDraft,
|
||||||
|
enableMockDraft
|
||||||
|
} = setters;
|
||||||
|
|
||||||
|
const handleRuntimeConfigSave = useCallback(() => {
|
||||||
|
if (!clientRef.current) {
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const interval = Number(intervalMinutesDraft);
|
||||||
|
const maxCommCycles = Number(maxCommCyclesDraft);
|
||||||
|
if (!Number.isInteger(interval) || interval <= 0) {
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '间隔必须是正整数分钟' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!Number.isInteger(maxCommCycles) || maxCommCycles <= 0) {
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '讨论轮数必须是正整数' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsRuntimeConfigSaving(true);
|
||||||
|
setRuntimeConfigFeedback(null);
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'update_runtime_config',
|
||||||
|
schedule_mode: scheduleModeDraft,
|
||||||
|
interval_minutes: interval,
|
||||||
|
trigger_time: triggerTimeDraft,
|
||||||
|
max_comm_cycles: maxCommCycles,
|
||||||
|
initial_cash: Number(initialCashDraft),
|
||||||
|
margin_requirement: Number(marginRequirementDraft),
|
||||||
|
enable_memory: Boolean(enableMemoryDraft)
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
setIsRuntimeConfigSaving(false);
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||||
|
}
|
||||||
|
}, [
|
||||||
|
clientRef,
|
||||||
|
intervalMinutesDraft,
|
||||||
|
maxCommCyclesDraft,
|
||||||
|
scheduleModeDraft,
|
||||||
|
triggerTimeDraft,
|
||||||
|
initialCashDraft,
|
||||||
|
marginRequirementDraft,
|
||||||
|
enableMemoryDraft,
|
||||||
|
setIsRuntimeConfigSaving,
|
||||||
|
setRuntimeConfigFeedback
|
||||||
|
]);
|
||||||
|
|
||||||
|
const handleLaunchConfigSave = useCallback(async () => {
|
||||||
|
const pendingTickers = parseWatchlistInput(watchlistInputValue);
|
||||||
|
const nextTickers = Array.from(new Set([...watchlistDraftSymbols, ...pendingTickers]));
|
||||||
|
if (nextTickers.length === 0) {
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '至少输入 1 个有效股票代码' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const interval = Number(intervalMinutesDraft);
|
||||||
|
const maxCommCycles = Number(maxCommCyclesDraft);
|
||||||
|
const initialCash = Number(initialCashDraft);
|
||||||
|
const marginRequirement = Number(marginRequirementDraft);
|
||||||
|
if (!Number.isInteger(interval) || interval <= 0) {
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '间隔必须是正整数分钟' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!Number.isInteger(maxCommCycles) || maxCommCycles <= 0) {
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '讨论轮数必须是正整数' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!Number.isFinite(initialCash) || initialCash <= 0) {
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '初始资金必须是正数' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!Number.isFinite(marginRequirement) || marginRequirement < 0) {
|
||||||
|
setRuntimeConfigFeedback({ type: 'error', text: '保证金要求不能为负数' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setIsRuntimeConfigSaving(true);
|
||||||
|
setIsWatchlistSaving(true);
|
||||||
|
setRuntimeConfigFeedback(null);
|
||||||
|
setters.setWatchlistFeedback(null);
|
||||||
|
setters.setWatchlistDraftSymbols(nextTickers);
|
||||||
|
setters.setWatchlistInputValue('');
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result = await startRuntime({
|
||||||
|
tickers: nextTickers,
|
||||||
|
schedule_mode: scheduleModeDraft,
|
||||||
|
interval_minutes: interval,
|
||||||
|
trigger_time: triggerTimeDraft,
|
||||||
|
max_comm_cycles: maxCommCycles,
|
||||||
|
initial_cash: initialCash,
|
||||||
|
margin_requirement: marginRequirement,
|
||||||
|
enable_memory: Boolean(enableMemoryDraft),
|
||||||
|
mode: modeDraft || 'live',
|
||||||
|
poll_interval: Number(pollIntervalDraft) || 10,
|
||||||
|
start_date: startDateDraft || null,
|
||||||
|
end_date: endDateDraft || null,
|
||||||
|
enable_mock: Boolean(enableMockDraft)
|
||||||
|
});
|
||||||
|
|
||||||
|
setIsRuntimeConfigSaving(false);
|
||||||
|
setIsWatchlistSaving(false);
|
||||||
|
setIsRuntimeSettingsOpen(false);
|
||||||
|
setRuntimeConfigFeedback({
|
||||||
|
type: 'success',
|
||||||
|
text: `任务已启动: ${result.run_id}`
|
||||||
|
});
|
||||||
|
addSystemMessage(`新任务已启动: ${result.run_id}`);
|
||||||
|
} catch (error) {
|
||||||
|
setIsRuntimeConfigSaving(false);
|
||||||
|
setIsWatchlistSaving(false);
|
||||||
|
setRuntimeConfigFeedback({
|
||||||
|
type: 'error',
|
||||||
|
text: `启动失败: ${error.message}`
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [
|
||||||
|
parseWatchlistInput,
|
||||||
|
watchlistInputValue,
|
||||||
|
watchlistDraftSymbols,
|
||||||
|
intervalMinutesDraft,
|
||||||
|
maxCommCyclesDraft,
|
||||||
|
initialCashDraft,
|
||||||
|
marginRequirementDraft,
|
||||||
|
enableMemoryDraft,
|
||||||
|
scheduleModeDraft,
|
||||||
|
triggerTimeDraft,
|
||||||
|
modeDraft,
|
||||||
|
pollIntervalDraft,
|
||||||
|
startDateDraft,
|
||||||
|
endDateDraft,
|
||||||
|
enableMockDraft,
|
||||||
|
setters,
|
||||||
|
setIsRuntimeConfigSaving,
|
||||||
|
setIsWatchlistSaving,
|
||||||
|
setRuntimeConfigFeedback,
|
||||||
|
setIsRuntimeSettingsOpen,
|
||||||
|
addSystemMessage
|
||||||
|
]);
|
||||||
|
|
||||||
|
const handleRuntimeDefaultsRestore = useCallback(() => {
|
||||||
|
setScheduleModeDraft('daily');
|
||||||
|
setIntervalMinutesDraft('60');
|
||||||
|
setTriggerTimeDraft('09:30');
|
||||||
|
setMaxCommCyclesDraft('2');
|
||||||
|
setInitialCashDraft('100000');
|
||||||
|
setMarginRequirementDraft('0');
|
||||||
|
setEnableMemoryDraft(false);
|
||||||
|
setModeDraft('live');
|
||||||
|
setPollIntervalDraft('10');
|
||||||
|
setStartDateDraft('');
|
||||||
|
setEndDateDraft('');
|
||||||
|
setEnableMockDraft(false);
|
||||||
|
setRuntimeConfigFeedback(null);
|
||||||
|
}, [
|
||||||
|
setScheduleModeDraft,
|
||||||
|
setIntervalMinutesDraft,
|
||||||
|
setTriggerTimeDraft,
|
||||||
|
setMaxCommCyclesDraft,
|
||||||
|
setInitialCashDraft,
|
||||||
|
setMarginRequirementDraft,
|
||||||
|
setEnableMemoryDraft,
|
||||||
|
setModeDraft,
|
||||||
|
setPollIntervalDraft,
|
||||||
|
setStartDateDraft,
|
||||||
|
setEndDateDraft,
|
||||||
|
setEnableMockDraft,
|
||||||
|
setRuntimeConfigFeedback
|
||||||
|
]);
|
||||||
|
|
||||||
|
const handleRuntimeSettingsToggle = useCallback(() => {
|
||||||
|
setRuntimeConfigFeedback(null);
|
||||||
|
setters.setAgentSkillsFeedback(null);
|
||||||
|
setters.setWorkspaceFileFeedback(null);
|
||||||
|
setIsRuntimeSettingsOpen((prev) => {
|
||||||
|
const nextOpen = !prev;
|
||||||
|
if (nextOpen) {
|
||||||
|
// Initialize watchlist draft when opening settings
|
||||||
|
setters.setWatchlistDraftSymbols(settlers.runtimeWatchlistSymbols);
|
||||||
|
setters.setWatchlistInputValue('');
|
||||||
|
setters.setWatchlistFeedback(null);
|
||||||
|
}
|
||||||
|
return nextOpen;
|
||||||
|
});
|
||||||
|
setters.setIsWatchlistPanelOpen(false);
|
||||||
|
}, [setRuntimeConfigFeedback, setters, setIsRuntimeSettingsOpen]);
|
||||||
|
|
||||||
|
const handleManualTrigger = useCallback(() => {
|
||||||
|
if (!clientRef.current) {
|
||||||
|
addSystemMessage('连接未就绪,无法手动触发');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'trigger_strategy'
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
addSystemMessage('手动触发发送失败,请检查连接状态');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
addSystemMessage('已发送手动触发请求');
|
||||||
|
}, [clientRef, addSystemMessage]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
handleRuntimeConfigSave,
|
||||||
|
handleLaunchConfigSave,
|
||||||
|
handleRuntimeDefaultsRestore,
|
||||||
|
handleRuntimeSettingsToggle,
|
||||||
|
handleManualTrigger
|
||||||
|
};
|
||||||
|
}
|
||||||
584
frontend/src/hooks/useStockRequestCallbacks.js
Normal file
584
frontend/src/hooks/useStockRequestCallbacks.js
Normal file
@@ -0,0 +1,584 @@
|
|||||||
|
import { useCallback } from 'react';
|
||||||
|
import {
|
||||||
|
fetchNewsCategoriesDirect,
|
||||||
|
fetchNewsForDateDirect,
|
||||||
|
fetchRangeExplainDirect,
|
||||||
|
fetchSimilarDaysDirect,
|
||||||
|
fetchStockStoryDirect,
|
||||||
|
hasDirectNewsService
|
||||||
|
} from '../services/newsApi';
|
||||||
|
import {
|
||||||
|
fetchInsiderTradesDirect,
|
||||||
|
fetchStockHistoryDirect,
|
||||||
|
hasDirectTradingService
|
||||||
|
} from '../services/tradingApi';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts all requestStock* callbacks from App.jsx into a single hook.
|
||||||
|
*/
|
||||||
|
export function useStockRequestCallbacks({
|
||||||
|
clientRef,
|
||||||
|
currentDate,
|
||||||
|
requestedStockHistoryRef,
|
||||||
|
setters,
|
||||||
|
apiHelpers
|
||||||
|
}) {
|
||||||
|
const {
|
||||||
|
setOhlcHistoryByTicker,
|
||||||
|
setHistorySourceByTicker,
|
||||||
|
setExplainEventsByTicker,
|
||||||
|
setNewsByTicker,
|
||||||
|
setInsiderTradesByTicker,
|
||||||
|
setTechnicalIndicatorsByTicker,
|
||||||
|
setPriceHistoryByTicker
|
||||||
|
} = setters;
|
||||||
|
|
||||||
|
const {
|
||||||
|
hasDirectTradingService: _hasDirectTradingService,
|
||||||
|
fetchStockHistoryDirect: _fetchStockHistoryDirect,
|
||||||
|
hasDirectNewsService: _hasDirectNewsService,
|
||||||
|
fetchNewsForDateDirect: _fetchNewsForDateDirect,
|
||||||
|
fetchNewsCategoriesDirect: _fetchNewsCategoriesDirect,
|
||||||
|
fetchInsiderTradesDirect: _fetchInsiderTradesDirect,
|
||||||
|
fetchRangeExplainDirect: _fetchRangeExplainDirect,
|
||||||
|
fetchStockStoryDirect: _fetchStockStoryDirect,
|
||||||
|
fetchSimilarDaysDirect: _fetchSimilarDaysDirect
|
||||||
|
} = apiHelpers;
|
||||||
|
|
||||||
|
const buildTickersFromSymbols = useCallback((symbols, previousTickers = []) => {
|
||||||
|
if (!Array.isArray(symbols) || symbols.length === 0) {
|
||||||
|
return previousTickers;
|
||||||
|
}
|
||||||
|
|
||||||
|
return symbols
|
||||||
|
.filter((symbol) => typeof symbol === 'string' && symbol.trim())
|
||||||
|
.map((symbol) => {
|
||||||
|
const normalized = symbol.trim().toUpperCase();
|
||||||
|
const existing = previousTickers.find((ticker) => ticker.symbol === normalized);
|
||||||
|
return existing || {
|
||||||
|
symbol: normalized,
|
||||||
|
price: null,
|
||||||
|
change: null
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const normalizePriceHistory = useCallback((payload) => {
|
||||||
|
if (!payload || typeof payload !== 'object') {
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
const normalized = {};
|
||||||
|
Object.entries(payload).forEach(([symbol, points]) => {
|
||||||
|
const ticker = String(symbol || '').trim().toUpperCase();
|
||||||
|
if (!ticker || !Array.isArray(points)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized[ticker] = points
|
||||||
|
.map((point) => {
|
||||||
|
if (Array.isArray(point) && point.length >= 2) {
|
||||||
|
const [label, value] = point;
|
||||||
|
const price = Number(value);
|
||||||
|
if (!label || !Number.isFinite(price)) return null;
|
||||||
|
return {
|
||||||
|
timestamp: String(label),
|
||||||
|
label: String(label),
|
||||||
|
price
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (point && typeof point === 'object') {
|
||||||
|
const rawTimestamp = point.timestamp ?? point.t ?? point.date ?? point.label;
|
||||||
|
const price = Number(point.price ?? point.v ?? point.value ?? point.close);
|
||||||
|
if (!rawTimestamp || !Number.isFinite(price)) return null;
|
||||||
|
return {
|
||||||
|
timestamp: String(rawTimestamp),
|
||||||
|
label: String(rawTimestamp),
|
||||||
|
price
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
})
|
||||||
|
.filter(Boolean)
|
||||||
|
.slice(-120);
|
||||||
|
});
|
||||||
|
|
||||||
|
return normalized;
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const requestStockHistory = useCallback((symbol, { force = false } = {}) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!force && requestedStockHistoryRef.current.has(normalized)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const endDate = currentDate
|
||||||
|
? String(currentDate).slice(0, 10)
|
||||||
|
: new Date().toISOString().slice(0, 10);
|
||||||
|
const end = new Date(`${endDate}T00:00:00`);
|
||||||
|
const start = new Date(end);
|
||||||
|
start.setDate(start.getDate() - 120);
|
||||||
|
const startDate = start.toISOString().slice(0, 10);
|
||||||
|
|
||||||
|
if (_hasDirectTradingService()) {
|
||||||
|
void _fetchStockHistoryDirect(normalized, startDate, endDate)
|
||||||
|
.then((payload) => {
|
||||||
|
const prices = Array.isArray(payload?.prices) ? payload.prices : [];
|
||||||
|
setOhlcHistoryByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: prices
|
||||||
|
}));
|
||||||
|
setPriceHistoryByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: prices
|
||||||
|
.map((point) => {
|
||||||
|
const price = Number(point?.close);
|
||||||
|
const timestamp = point?.time;
|
||||||
|
if (!timestamp || !Number.isFinite(price)) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
timestamp: String(timestamp),
|
||||||
|
label: String(timestamp),
|
||||||
|
price
|
||||||
|
};
|
||||||
|
})
|
||||||
|
.filter(Boolean)
|
||||||
|
}));
|
||||||
|
setHistorySourceByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: 'trading_service'
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Direct stock-history fetch failed, falling back to websocket:', error);
|
||||||
|
if (clientRef.current) {
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'get_stock_history',
|
||||||
|
ticker: normalized,
|
||||||
|
lookback_days: 120
|
||||||
|
});
|
||||||
|
if (success) {
|
||||||
|
requestedStockHistoryRef.current.add(normalized);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
requestedStockHistoryRef.current.add(normalized);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'get_stock_history',
|
||||||
|
ticker: normalized,
|
||||||
|
lookback_days: 120
|
||||||
|
});
|
||||||
|
|
||||||
|
if (success) {
|
||||||
|
requestedStockHistoryRef.current.add(normalized);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success;
|
||||||
|
}, [currentDate, _hasDirectTradingService, _fetchStockHistoryDirect, clientRef, requestedStockHistoryRef, setOhlcHistoryByTicker, setPriceHistoryByTicker, setHistorySourceByTicker]);
|
||||||
|
|
||||||
|
const requestStockExplainEvents = useCallback((symbol) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_explain_events',
|
||||||
|
ticker: normalized
|
||||||
|
});
|
||||||
|
}, [clientRef]);
|
||||||
|
|
||||||
|
const requestStockNews = useCallback((symbol) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_news',
|
||||||
|
ticker: normalized,
|
||||||
|
lookback_days: 45,
|
||||||
|
limit: 12
|
||||||
|
});
|
||||||
|
}, [clientRef]);
|
||||||
|
|
||||||
|
const requestStockNewsForDate = useCallback((symbol, date) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized || !date) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_hasDirectNewsService()) {
|
||||||
|
void _fetchNewsForDateDirect(normalized, date, 20)
|
||||||
|
.then((payload) => {
|
||||||
|
const targetDate = typeof payload?.date === 'string' ? payload.date.trim() : date;
|
||||||
|
const news = Array.isArray(payload?.news) ? payload.news : [];
|
||||||
|
const freshness = payload?.freshness || null;
|
||||||
|
setNewsByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: {
|
||||||
|
...(prev[normalized] || {}),
|
||||||
|
byDate: {
|
||||||
|
...((prev[normalized] && prev[normalized].byDate) || {}),
|
||||||
|
[targetDate]: news
|
||||||
|
},
|
||||||
|
byDateFreshness: {
|
||||||
|
...((prev[normalized] && prev[normalized].byDateFreshness) || {}),
|
||||||
|
[targetDate]: freshness
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Direct news-for-date fetch failed, falling back to websocket:', error);
|
||||||
|
if (clientRef.current) {
|
||||||
|
clientRef.current.send({
|
||||||
|
type: 'get_stock_news_for_date',
|
||||||
|
ticker: normalized,
|
||||||
|
date,
|
||||||
|
limit: 20
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_news_for_date',
|
||||||
|
ticker: normalized,
|
||||||
|
date,
|
||||||
|
limit: 20
|
||||||
|
});
|
||||||
|
}, [clientRef, _hasDirectNewsService, _fetchNewsForDateDirect, setNewsByTicker]);
|
||||||
|
|
||||||
|
const requestStockNewsTimeline = useCallback((symbol) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_news_timeline',
|
||||||
|
ticker: normalized,
|
||||||
|
lookback_days: 90
|
||||||
|
});
|
||||||
|
}, [clientRef]);
|
||||||
|
|
||||||
|
const requestStockNewsCategories = useCallback((symbol) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const endDate = currentDate
|
||||||
|
? String(currentDate).slice(0, 10)
|
||||||
|
: new Date().toISOString().slice(0, 10);
|
||||||
|
const end = new Date(`${endDate}T00:00:00`);
|
||||||
|
const start = new Date(end);
|
||||||
|
start.setDate(start.getDate() - 90);
|
||||||
|
const startDate = start.toISOString().slice(0, 10);
|
||||||
|
|
||||||
|
if (_hasDirectNewsService()) {
|
||||||
|
void _fetchNewsCategoriesDirect(normalized, startDate, endDate, 200)
|
||||||
|
.then((payload) => {
|
||||||
|
const freshness = payload?.freshness || null;
|
||||||
|
setNewsByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: {
|
||||||
|
...(prev[normalized] || {}),
|
||||||
|
categories: payload?.categories || {},
|
||||||
|
categoriesStartDate: startDate,
|
||||||
|
categoriesEndDate: endDate,
|
||||||
|
categoriesFreshness: freshness
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Direct news-categories fetch failed, falling back to websocket:', error);
|
||||||
|
if (clientRef.current) {
|
||||||
|
clientRef.current.send({
|
||||||
|
type: 'get_stock_news_categories',
|
||||||
|
ticker: normalized,
|
||||||
|
lookback_days: 90
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_news_categories',
|
||||||
|
ticker: normalized,
|
||||||
|
lookback_days: 90
|
||||||
|
});
|
||||||
|
}, [currentDate, clientRef, _hasDirectNewsService, _fetchNewsCategoriesDirect, setNewsByTicker]);
|
||||||
|
|
||||||
|
const requestStockInsiderTrades = useCallback((symbol, startDate = null, endDate = null) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_hasDirectTradingService()) {
|
||||||
|
void _fetchInsiderTradesDirect(normalized, startDate, endDate, 50)
|
||||||
|
.then((payload) => {
|
||||||
|
const rows = Array.isArray(payload?.insider_trades) ? payload.insider_trades : [];
|
||||||
|
setInsiderTradesByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: {
|
||||||
|
ticker: normalized,
|
||||||
|
startDate: startDate || null,
|
||||||
|
endDate: endDate || null,
|
||||||
|
trades: rows
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Direct insider-trades fetch failed, falling back to websocket:', error);
|
||||||
|
if (clientRef.current) {
|
||||||
|
clientRef.current.send({
|
||||||
|
type: 'get_stock_insider_trades',
|
||||||
|
ticker: normalized,
|
||||||
|
start_date: startDate,
|
||||||
|
end_date: endDate,
|
||||||
|
limit: 50
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_insider_trades',
|
||||||
|
ticker: normalized,
|
||||||
|
start_date: startDate,
|
||||||
|
end_date: endDate,
|
||||||
|
limit: 50
|
||||||
|
});
|
||||||
|
}, [clientRef, _hasDirectTradingService, _fetchInsiderTradesDirect, setInsiderTradesByTicker]);
|
||||||
|
|
||||||
|
const requestStockTechnicalIndicators = useCallback((symbol) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_technical_indicators',
|
||||||
|
ticker: normalized
|
||||||
|
});
|
||||||
|
}, [clientRef]);
|
||||||
|
|
||||||
|
const requestStockRangeExplain = useCallback((symbol, startDate, endDate, articleIds = []) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized || !startDate || !endDate) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_hasDirectNewsService()) {
|
||||||
|
void _fetchRangeExplainDirect(normalized, startDate, endDate, articleIds)
|
||||||
|
.then((payload) => {
|
||||||
|
const result = payload?.result && typeof payload.result === 'object' ? payload.result : null;
|
||||||
|
const freshness = payload?.freshness || null;
|
||||||
|
if (!result?.start_date || !result?.end_date) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const cacheKey = `${result.start_date}:${result.end_date}`;
|
||||||
|
setNewsByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: {
|
||||||
|
...(prev[normalized] || {}),
|
||||||
|
rangeExplainCache: {
|
||||||
|
...((prev[normalized] && prev[normalized].rangeExplainCache) || {}),
|
||||||
|
[cacheKey]: {
|
||||||
|
...result,
|
||||||
|
freshness
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Direct range explain fetch failed, falling back to websocket:', error);
|
||||||
|
if (clientRef.current) {
|
||||||
|
clientRef.current.send({
|
||||||
|
type: 'get_stock_range_explain',
|
||||||
|
ticker: normalized,
|
||||||
|
start_date: startDate,
|
||||||
|
end_date: endDate,
|
||||||
|
article_ids: Array.isArray(articleIds) ? articleIds : []
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_range_explain',
|
||||||
|
ticker: normalized,
|
||||||
|
start_date: startDate,
|
||||||
|
end_date: endDate,
|
||||||
|
article_ids: Array.isArray(articleIds) ? articleIds : []
|
||||||
|
});
|
||||||
|
}, [clientRef, _hasDirectNewsService, _fetchRangeExplainDirect, setNewsByTicker]);
|
||||||
|
|
||||||
|
const requestStockStory = useCallback((symbol, asOfDate) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
const date = typeof asOfDate === 'string' ? asOfDate.trim() : '';
|
||||||
|
if (!normalized || !date) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_hasDirectNewsService()) {
|
||||||
|
void _fetchStockStoryDirect(normalized, date)
|
||||||
|
.then((payload) => {
|
||||||
|
setNewsByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: {
|
||||||
|
...(prev[normalized] || {}),
|
||||||
|
storyCache: {
|
||||||
|
...((prev[normalized] && prev[normalized].storyCache) || {}),
|
||||||
|
[date]: {
|
||||||
|
story: payload?.story || '',
|
||||||
|
source: payload?.source || null,
|
||||||
|
asOfDate: date,
|
||||||
|
freshness: payload?.freshness || null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Direct story fetch failed, falling back to websocket:', error);
|
||||||
|
if (clientRef.current) {
|
||||||
|
clientRef.current.send({
|
||||||
|
type: 'get_stock_story',
|
||||||
|
ticker: normalized,
|
||||||
|
as_of_date: date
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_story',
|
||||||
|
ticker: normalized,
|
||||||
|
as_of_date: date
|
||||||
|
});
|
||||||
|
}, [clientRef, _hasDirectNewsService, _fetchStockStoryDirect, setNewsByTicker]);
|
||||||
|
|
||||||
|
const requestStockSimilarDays = useCallback((symbol, targetDate, lookbackDays = 365) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
const date = typeof targetDate === 'string' ? targetDate.trim() : '';
|
||||||
|
if (!normalized || !date) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (_hasDirectNewsService()) {
|
||||||
|
void _fetchSimilarDaysDirect(normalized, date, lookbackDays)
|
||||||
|
.then((payload) => {
|
||||||
|
setNewsByTicker((prev) => ({
|
||||||
|
...prev,
|
||||||
|
[normalized]: {
|
||||||
|
...(prev[normalized] || {}),
|
||||||
|
similarDaysCache: {
|
||||||
|
...((prev[normalized] && prev[normalized].similarDaysCache) || {}),
|
||||||
|
[date]: {
|
||||||
|
target_features: payload?.target_features || {},
|
||||||
|
items: Array.isArray(payload?.items) ? payload?.items : [],
|
||||||
|
error: payload?.error || null,
|
||||||
|
freshness: payload?.freshness || null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}));
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('Direct similar-days fetch failed, falling back to websocket:', error);
|
||||||
|
if (clientRef.current) {
|
||||||
|
clientRef.current.send({
|
||||||
|
type: 'get_stock_similar_days',
|
||||||
|
ticker: normalized,
|
||||||
|
target_date: date,
|
||||||
|
lookback_days: lookbackDays
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'get_stock_similar_days',
|
||||||
|
ticker: normalized,
|
||||||
|
target_date: date,
|
||||||
|
lookback_days: lookbackDays
|
||||||
|
});
|
||||||
|
}, [clientRef, _hasDirectNewsService, _fetchSimilarDaysDirect, setNewsByTicker]);
|
||||||
|
|
||||||
|
const requestStockEnrich = useCallback((symbol, startDate, endDate, { force = false, onlyLocalToLlm = false } = {}) => {
|
||||||
|
const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : '';
|
||||||
|
if (!normalized || !clientRef.current) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return clientRef.current.send({
|
||||||
|
type: 'enrich_stock_news',
|
||||||
|
ticker: normalized,
|
||||||
|
start_date: startDate,
|
||||||
|
end_date: endDate,
|
||||||
|
force: Boolean(force),
|
||||||
|
only_local_to_llm: Boolean(onlyLocalToLlm)
|
||||||
|
});
|
||||||
|
}, [clientRef]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
buildTickersFromSymbols,
|
||||||
|
normalizePriceHistory,
|
||||||
|
requestStockHistory,
|
||||||
|
requestStockExplainEvents,
|
||||||
|
requestStockNews,
|
||||||
|
requestStockNewsForDate,
|
||||||
|
requestStockNewsTimeline,
|
||||||
|
requestStockNewsCategories,
|
||||||
|
requestStockInsiderTrades,
|
||||||
|
requestStockTechnicalIndicators,
|
||||||
|
requestStockRangeExplain,
|
||||||
|
requestStockStory,
|
||||||
|
requestStockSimilarDays,
|
||||||
|
requestStockEnrich
|
||||||
|
};
|
||||||
|
}
|
||||||
144
frontend/src/hooks/useWatchlistCallbacks.js
Normal file
144
frontend/src/hooks/useWatchlistCallbacks.js
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import { useCallback, useMemo } from 'react';
|
||||||
|
import { INITIAL_TICKERS } from '../config/constants';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts watchlist-related callbacks from App.jsx into a single hook.
|
||||||
|
*/
|
||||||
|
export function useWatchlistCallbacks({
|
||||||
|
clientRef,
|
||||||
|
runtimeWatchlistSymbols,
|
||||||
|
watchlistDraftSymbols,
|
||||||
|
watchlistInputValue,
|
||||||
|
watchlistFeedback,
|
||||||
|
setters
|
||||||
|
}) {
|
||||||
|
const {
|
||||||
|
setWatchlistDraftSymbols,
|
||||||
|
setWatchlistInputValue,
|
||||||
|
setWatchlistFeedback
|
||||||
|
} = setters;
|
||||||
|
|
||||||
|
const parseWatchlistInput = useCallback((value) => {
|
||||||
|
if (typeof value !== 'string') {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
return Array.from(
|
||||||
|
new Set(
|
||||||
|
value
|
||||||
|
.split(/[\s,]+/)
|
||||||
|
.map((symbol) => symbol.trim().toUpperCase())
|
||||||
|
.filter(Boolean)
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const commitWatchlistInput = useCallback((value) => {
|
||||||
|
const parsed = parseWatchlistInput(value);
|
||||||
|
if (parsed.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
|
||||||
|
setWatchlistDraftSymbols((prev) => Array.from(new Set([...prev, ...parsed])));
|
||||||
|
setWatchlistInputValue('');
|
||||||
|
if (watchlistFeedback) {
|
||||||
|
setWatchlistFeedback(null);
|
||||||
|
}
|
||||||
|
return parsed;
|
||||||
|
}, [parseWatchlistInput, watchlistFeedback, setWatchlistDraftSymbols, setWatchlistInputValue, setWatchlistFeedback, setters]);
|
||||||
|
|
||||||
|
const handleWatchlistRemove = useCallback((symbolToRemove) => {
|
||||||
|
setWatchlistDraftSymbols((prev) => prev.filter((symbol) => symbol !== symbolToRemove));
|
||||||
|
if (watchlistFeedback) {
|
||||||
|
setWatchlistFeedback(null);
|
||||||
|
}
|
||||||
|
}, [watchlistFeedback, setWatchlistDraftSymbols, setWatchlistFeedback]);
|
||||||
|
|
||||||
|
const handleWatchlistInputChange = useCallback((value) => {
|
||||||
|
setWatchlistInputValue(value);
|
||||||
|
if (watchlistFeedback) {
|
||||||
|
setWatchlistFeedback(null);
|
||||||
|
}
|
||||||
|
}, [watchlistFeedback, setWatchlistInputValue, setWatchlistFeedback]);
|
||||||
|
|
||||||
|
const handleWatchlistInputKeyDown = useCallback((e) => {
|
||||||
|
if (e.key === 'Enter' || e.key === ',') {
|
||||||
|
e.preventDefault();
|
||||||
|
commitWatchlistInput(watchlistInputValue);
|
||||||
|
}
|
||||||
|
}, [commitWatchlistInput, watchlistInputValue]);
|
||||||
|
|
||||||
|
const handleWatchlistSuggestionClick = useCallback((symbol) => {
|
||||||
|
if (watchlistDraftSymbols.includes(symbol)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setWatchlistDraftSymbols((prev) => [...prev, symbol]);
|
||||||
|
if (watchlistFeedback) {
|
||||||
|
setWatchlistFeedback(null);
|
||||||
|
}
|
||||||
|
}, [watchlistDraftSymbols, watchlistFeedback, setWatchlistDraftSymbols, setWatchlistFeedback]);
|
||||||
|
|
||||||
|
const handleWatchlistRestoreCurrent = useCallback(() => {
|
||||||
|
setWatchlistDraftSymbols(runtimeWatchlistSymbols);
|
||||||
|
setWatchlistInputValue('');
|
||||||
|
setWatchlistFeedback(null);
|
||||||
|
}, [runtimeWatchlistSymbols, setWatchlistDraftSymbols, setWatchlistInputValue, setWatchlistFeedback]);
|
||||||
|
|
||||||
|
const handleWatchlistSave = useCallback(() => {
|
||||||
|
const pendingTickers = parseWatchlistInput(watchlistInputValue);
|
||||||
|
const nextTickers = Array.from(new Set([...watchlistDraftSymbols, ...pendingTickers]));
|
||||||
|
if (nextTickers.length === 0) {
|
||||||
|
setWatchlistFeedback({ type: 'error', text: '至少输入 1 个有效股票代码' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!clientRef.current) {
|
||||||
|
setWatchlistFeedback({ type: 'error', text: '连接未就绪,稍后重试' });
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setters.setIsWatchlistSaving(true);
|
||||||
|
setWatchlistFeedback(null);
|
||||||
|
setWatchlistDraftSymbols(nextTickers);
|
||||||
|
setWatchlistInputValue('');
|
||||||
|
const success = clientRef.current.send({
|
||||||
|
type: 'update_watchlist',
|
||||||
|
tickers: nextTickers
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!success) {
|
||||||
|
setters.setIsWatchlistSaving(false);
|
||||||
|
setWatchlistFeedback({ type: 'error', text: '发送失败,请检查连接状态' });
|
||||||
|
}
|
||||||
|
}, [parseWatchlistInput, watchlistDraftSymbols, watchlistInputValue, clientRef, setters.setIsWatchlistSaving, setWatchlistFeedback, setWatchlistDraftSymbols, setWatchlistInputValue]);
|
||||||
|
|
||||||
|
const watchlistSuggestions = useMemo(
|
||||||
|
() => INITIAL_TICKERS.map((ticker) => ticker.symbol).filter((symbol, index, list) => list.indexOf(symbol) === index),
|
||||||
|
[]
|
||||||
|
);
|
||||||
|
|
||||||
|
const isWatchlistDraftDirty = useMemo(() => {
|
||||||
|
if (watchlistInputValue.trim()) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (watchlistDraftSymbols.length !== runtimeWatchlistSymbols.length) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return watchlistDraftSymbols.some((symbol, index) => symbol !== runtimeWatchlistSymbols[index]);
|
||||||
|
}, [runtimeWatchlistSymbols, watchlistDraftSymbols, watchlistInputValue]);
|
||||||
|
|
||||||
|
return {
|
||||||
|
parseWatchlistInput,
|
||||||
|
commitWatchlistInput,
|
||||||
|
handleWatchlistRemove,
|
||||||
|
handleWatchlistInputChange,
|
||||||
|
handleWatchlistInputKeyDown,
|
||||||
|
handleWatchlistSuggestionClick,
|
||||||
|
handleWatchlistRestoreCurrent,
|
||||||
|
handleWatchlistSave,
|
||||||
|
watchlistSuggestions,
|
||||||
|
isWatchlistDraftDirty
|
||||||
|
};
|
||||||
|
}
|
||||||
1057
frontend/src/hooks/useWebSocketHandler.js
Normal file
1057
frontend/src/hooks/useWebSocketHandler.js
Normal file
File diff suppressed because it is too large
Load Diff
@@ -87,4 +87,8 @@ export const useRuntimeStore = create((set) => ({
|
|||||||
isRuntimeConfigSaving: false,
|
isRuntimeConfigSaving: false,
|
||||||
setRuntimeConfigFeedback: (runtimeConfigFeedback) => set({ runtimeConfigFeedback }),
|
setRuntimeConfigFeedback: (runtimeConfigFeedback) => set({ runtimeConfigFeedback }),
|
||||||
setIsRuntimeConfigSaving: (isRuntimeConfigSaving) => set({ isRuntimeConfigSaving }),
|
setIsRuntimeConfigSaving: (isRuntimeConfigSaving) => set({ isRuntimeConfigSaving }),
|
||||||
|
|
||||||
|
// Last day history (for replay)
|
||||||
|
lastDayHistory: [],
|
||||||
|
setLastDayHistory: (lastDayHistory) => set({ lastDayHistory }),
|
||||||
}));
|
}));
|
||||||
|
|||||||
Reference in New Issue
Block a user