P0 fixes: - runtimeStore: add the missing lastDayHistory field - Gateway/RuntimeService: state sync is now memory-first, eliminating the glob race condition - App.jsx: refactored from 3075 lines down to ~500 lines, extracting 8 standalone files P1 fixes: - CORS: 4 services now read allowed origins from environment variables - MarketStore: switched to a module-level singleton - Domain layer: removed the trading thin wrapper, kept the real news logic - Tests: added 77 gateway/runtime tests New files: - backend/tests/test_gateway.py (43 tests) - frontend/src/hooks/useWebSocketHandler.js - frontend/src/hooks/useStockRequestCallbacks.js - frontend/src/hooks/useAgentCallbacks.js - frontend/src/hooks/useRuntimeCallbacks.js - frontend/src/hooks/useWatchlistCallbacks.js - frontend/src/components/TickerBar.jsx - frontend/src/components/HeaderRight.jsx - frontend/src/components/ChartTabs.jsx Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
550 lines
18 KiB
Python
550 lines
18 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Tests for the Gateway main class - core behavior and fallback paths."""
|
|
|
|
import json
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from backend.services.gateway import Gateway
|
|
import backend.services.gateway as gateway_module
|
|
|
|
|
|
class DummyWebSocket:
    """In-memory stand-in for a websocket connection.

    Frames pre-loaded with ``queue`` are replayed in FIFO order by
    ``async for``; payloads passed to ``send`` are JSON-decoded and kept in
    ``messages`` so tests can inspect what the gateway sent.
    """

    def __init__(self):
        self._queue = []
        self.closed = False
        self.messages = []

    def queue(self, data: str):
        """Queue a raw message string to be yielded by the async iterator."""
        self._queue.append(data)

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Iteration stops once every queued frame has been consumed.
        if self._queue:
            return self._queue.pop(0)
        raise StopAsyncIteration

    async def send(self, payload: str):
        """Decode *payload* from JSON and record it."""
        decoded = json.loads(payload)
        self.messages.append(decoded)

    async def close(self):
        """Mark the connection as closed."""
        self.closed = True
|
|
|
|
|
|
class DummyStateSync:
    """State-sync double exposing only the hooks the Gateway uses in these tests.

    ``state`` is a plain dict seeded with ``current_date``; system messages
    are collected in ``system_messages`` and ``save_state`` just flips ``saved``.
    """

    def __init__(self, current_date="2026-03-16"):
        self.initial_state_payload = {}
        self.saved = False
        self.system_messages = []
        self.state = {"current_date": current_date}

    def set_broadcast_fn(self, _fn):
        """Accept and ignore the broadcast callback."""
        return None

    def update_state(self, key, value):
        """Store *value* under *key* in the in-memory state."""
        self.state.update({key: value})

    def save_state(self):
        """Record that a save was requested."""
        self.saved = True

    async def on_system_message(self, message):
        """Collect *message* for later assertions."""
        self.system_messages.append(message)

    def get_initial_state_payload(self, include_dashboard=True):
        """Return a fixed running-state payload echoing the current date."""
        current = self.state.get("current_date", "")
        return {
            "status": "running",
            "current_date": current,
            "portfolio": {},
            "holdings": [],
            "trades": [],
        }
|
|
|
|
|
|
class DummyMarketService:
    """Market-service double reporting an always-open regular session."""

    def __init__(self):
        self.market_status = {"is_open": True, "session": "regular"}
        self.broadcast_func = None

    def set_price_recorder(self, _fn):
        """Accept and ignore the price recorder callback."""
        return None

    async def start(self, broadcast_func=None):
        """Remember the broadcast callback the gateway hands over on start."""
        self.broadcast_func = broadcast_func

    def get_market_status(self):
        """Return the canned market status dict."""
        return self.market_status
|
|
|
|
|
|
class DummyStorage:
    """Storage-service double that serves canned portfolio files."""

    def __init__(self, initial_cash=100000.0, live_session=False):
        self._market_store = SimpleNamespace()
        self.is_live_session_active = live_session
        self.initial_cash = initial_cash

    @property
    def market_store(self):
        """Placeholder market store, created once per instance."""
        return self._market_store

    def load_file(self, name):
        """Return canned content for the named file.

        ``"summary"`` reports ``initial_cash`` as the total asset value,
        ``"holdings"`` and ``"trades"`` are empty lists, anything else is None.
        """
        if name in ("holdings", "trades"):
            return []
        if name == "summary":
            return {"totalAssetValue": self.initial_cash}
        return None

    def get_live_returns(self):
        """Return zeroed live-session returns."""
        return {"session_pnl": 0.0, "session_return": 0.0}
|
|
|
|
|
|
def make_gateway(market_service=None, storage=None, state_sync=None, config=None):
    """Build a ``Gateway`` wired to in-memory test doubles.

    Any collaborator left as ``None`` is replaced with a fresh dummy. ``config``
    falls back to ``{"mode": "live"}`` only when it is omitted (``None``). An
    explicit empty dict is passed through unchanged, so tests can exercise the
    Gateway's own defaults; a truthiness check would silently replace ``{}``.

    The same ``state_sync`` instance is given to both the pipeline and the
    gateway, so tests can inspect state written through either path.
    """
    if storage is None:
        storage = DummyStorage()
    if state_sync is None:
        state_sync = DummyStateSync()
    if market_service is None:
        market_service = DummyMarketService()
    if config is None:
        config = {"mode": "live"}

    pipeline = SimpleNamespace(
        state_sync=state_sync,
        max_comm_cycles=0,
        pm=SimpleNamespace(portfolio={"margin_requirement": 0.0}),
    )
    return Gateway(
        market_service=market_service,
        storage_service=storage,
        pipeline=pipeline,
        state_sync=state_sync,
        config=config,
    )
|
|
|
|
|
|
# =============================================================================
|
|
# Gateway initialization and core properties
|
|
# =============================================================================
|
|
|
|
def test_gateway_init_sets_live_mode():
    """mode="live" produces a live, non-backtest gateway."""
    live_gateway = make_gateway(config={"mode": "live"})
    assert live_gateway.mode == "live"
    assert live_gateway.is_backtest is False
|
|
|
|
|
|
def test_gateway_init_sets_backtest_mode_from_config():
    """mode="backtest" switches the gateway into backtest mode."""
    bt_gateway = make_gateway(config={"mode": "backtest"})
    assert bt_gateway.mode == "backtest"
    assert bt_gateway.is_backtest is True
|
|
|
|
|
|
def test_gateway_init_sets_backtest_mode_from_flag():
    """backtest_mode=True enables backtesting even when mode says "live"."""
    cfg = {"backtest_mode": True, "mode": "live"}
    assert make_gateway(config=cfg).is_backtest is True
|
|
|
|
|
|
def test_gateway_init_defaults_to_live_mode():
    """An empty config leaves the gateway in live mode."""
    default_gateway = make_gateway(config={})
    assert default_gateway.mode == "live"
    assert default_gateway.is_backtest is False
|
|
|
|
|
|
def test_gateway_state_property_returns_state_sync_state():
    """Gateway.state exposes the state dict held by its state_sync."""
    sync = DummyStateSync()
    sync.state["foo"] = "bar"
    assert make_gateway(state_sync=sync).state["foo"] == "bar"
|
|
|
|
|
|
def test_gateway_news_rows_need_enrichment_delegates_to_news_domain():
    """The static helper forwards rows to news_domain and returns its verdict."""
    news_rows = [{"id": "1"}, {"id": "2"}]
    with patch.object(
        gateway_module.news_domain, "news_rows_need_enrichment", return_value=True
    ) as domain_check:
        outcome = Gateway._news_rows_need_enrichment(news_rows)
    domain_check.assert_called_once_with(news_rows)
    assert outcome is True
|
|
|
|
|
|
# =============================================================================
|
|
# Service URL helpers and fallback paths
|
|
# =============================================================================
|
|
|
|
def test_news_service_url_returns_config_value(monkeypatch):
    """An explicit ``news_service_url`` in config takes precedence over the env."""
    # Set a conflicting env value so the test proves config-over-env precedence
    # and does not depend on whatever NEWS_SERVICE_URL the host happens to have.
    monkeypatch.setenv("NEWS_SERVICE_URL", "http://env-news:9001")
    gateway = make_gateway(config={"news_service_url": "http://custom-news:9000"})
    assert gateway._news_service_url() == "http://custom-news:9000"
|
|
|
|
|
|
def test_news_service_url_falls_back_to_env(monkeypatch):
    """Without a config entry the URL is read from NEWS_SERVICE_URL."""
    env_url = "http://env-news:9001"
    monkeypatch.setenv("NEWS_SERVICE_URL", env_url)
    assert make_gateway(config={})._news_service_url() == env_url
|
|
|
|
|
|
def test_news_service_url_returns_none_when_unset(monkeypatch):
    """No config entry and no env var means no news service URL."""
    monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
    assert make_gateway(config={})._news_service_url() is None
|
|
|
|
|
|
def test_news_service_url_strips_whitespace(monkeypatch):
    """Surrounding whitespace in the configured news URL is stripped."""
    padded = "  http://whitespace-news:9000  "
    gw = make_gateway(config={"news_service_url": padded})
    assert gw._news_service_url() == "http://whitespace-news:9000"
|
|
|
|
|
|
def test_trading_service_url_returns_config_value(monkeypatch):
    """An explicit ``trading_service_url`` in config takes precedence over the env."""
    # Set a conflicting env value so the test proves config-over-env precedence
    # and does not depend on whatever TRADING_SERVICE_URL the host happens to have.
    monkeypatch.setenv("TRADING_SERVICE_URL", "http://env-trading:9001")
    gateway = make_gateway(config={"trading_service_url": "http://custom-trading:9000"})
    assert gateway._trading_service_url() == "http://custom-trading:9000"
|
|
|
|
|
|
def test_trading_service_url_falls_back_to_env(monkeypatch):
    """Without a config entry the URL is read from TRADING_SERVICE_URL."""
    env_url = "http://env-trading:9001"
    monkeypatch.setenv("TRADING_SERVICE_URL", env_url)
    assert make_gateway(config={})._trading_service_url() == env_url
|
|
|
|
|
|
def test_trading_service_url_returns_none_when_unset(monkeypatch):
    """No config entry and no env var means no trading service URL."""
    monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
    assert make_gateway(config={})._trading_service_url() is None
|
|
|
|
|
|
def test_trading_service_url_strips_whitespace(monkeypatch):
    """Surrounding whitespace in the configured trading URL is stripped."""
    padded = "  http://whitespace-trading:9000  "
    gw = make_gateway(config={"trading_service_url": padded})
    assert gw._trading_service_url() == "http://whitespace-trading:9000"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_news_service_returns_none_when_url_not_set(monkeypatch):
    """With no news URL the call short-circuits and yields None."""
    monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
    gw = make_gateway(config={})

    # If this ever ran, the result would be a string instead of None.
    async def unreachable(client):
        return "should not be called"

    assert await gw._call_news_service("test_action", unreachable) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_news_service_calls_callback_and_returns():
    """With a URL configured the callback's return value is passed through."""
    gw = make_gateway(config={"news_service_url": "http://news:9000"})

    async def ok_callback(client):
        return {"result": "ok"}

    outcome = await gw._call_news_service("test_action", ok_callback)
    assert outcome == {"result": "ok"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_news_service_returns_none_on_exception():
    """A callback that raises is absorbed and reported as None."""
    gw = make_gateway(config={"news_service_url": "http://news:9000"})

    async def boom(client):
        raise RuntimeError("connection failed")

    assert await gw._call_news_service("test_action", boom) is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_trading_service_returns_none_when_url_not_set(monkeypatch):
    """With no trading URL the callback is never invoked and None is returned."""
    monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
    gateway = make_gateway(config={})
    invocations = []

    # A real coroutine callback, matching the other _call_*_service tests. A
    # non-awaitable lambda returning None could be invoked by mistake and the
    # result would still be None, so the test could not detect the call.
    async def callback(client):
        invocations.append(client)
        return "should not be called"

    result = await gateway._call_trading_service("test_action", callback)
    assert result is None
    assert invocations == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_trading_service_calls_callback_and_returns():
    """With a URL configured the callback's return value is passed through."""
    gw = make_gateway(config={"trading_service_url": "http://trading:9000"})

    async def ok_callback(client):
        return {"result": "ok"}

    outcome = await gw._call_trading_service("test_action", ok_callback)
    assert outcome == {"result": "ok"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_call_trading_service_returns_none_on_exception():
    """A callback that raises is absorbed and reported as None."""
    gw = make_gateway(config={"trading_service_url": "http://trading:9000"})

    async def boom(client):
        raise RuntimeError("connection failed")

    assert await gw._call_trading_service("test_action", boom) is None
|
|
|
|
|
|
# =============================================================================
|
|
# WebSocket message handlers
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_handle_client_messages_ping_returns_pong():
    """A ping frame is answered with a timestamped pong."""
    gw = make_gateway()
    socket = DummyWebSocket()
    socket.queue(json.dumps({"type": "ping"}))

    await gw._handle_client_messages(socket)

    reply = socket.messages[-1]
    assert reply["type"] == "pong"
    assert "timestamp" in reply
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_client_messages_get_state_sends_initial_state():
    """A get_state frame makes the gateway resend its initial state to the sender."""
    gw = make_gateway()
    socket = DummyWebSocket()
    socket.queue(json.dumps({"type": "get_state"}))
    send_initial_state = AsyncMock()

    with patch.object(gw, "_send_initial_state", send_initial_state):
        await gw._handle_client_messages(socket)
    send_initial_state.assert_called_once_with(socket)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_client_messages_unknown_type_is_silently_ignored():
    """Frames with an unrecognised type produce no reply and no error."""
    gw = make_gateway()
    socket = DummyWebSocket()
    socket.queue(json.dumps({"type": "unknown_type"}))

    await gw._handle_client_messages(socket)  # must not raise

    assert not len(socket.messages)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_client_messages_json_decode_error_is_silently_ignored():
    """Malformed JSON frames are swallowed by the handler without a reply."""
    gw = make_gateway()
    socket = DummyWebSocket()
    socket.queue("not valid json")

    await gw._handle_client_messages(socket)  # must not raise

    assert not len(socket.messages)
|
|
|
|
|
|
# =============================================================================
|
|
# Backtest handling
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_not_backtest_mode():
    """A start-backtest request in live mode is a no-op."""
    live_gw = make_gateway(config={"mode": "live"})
    request = {"dates": ["2026-03-01", "2026-03-02"]}

    await live_gw._handle_start_backtest(request)  # must not raise

    # No backtest task may have been scheduled.
    assert live_gw._backtest_task is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_task_already_running():
    """An in-flight backtest task is left untouched by a second start request."""
    bt_gw = make_gateway(config={"mode": "backtest"})
    running = MagicMock()
    running.done.return_value = False
    bt_gw._backtest_task = running

    await bt_gw._handle_start_backtest({"dates": ["2026-03-01"]})

    assert bt_gw._backtest_task is running
|
|
|
|
|
|
# =============================================================================
|
|
# Manual trigger (live/mock mode)
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_in_backtest_mode():
    """Manual triggers are refused with an error frame while backtesting."""
    bt_gw = make_gateway(config={"mode": "backtest"})
    socket = DummyWebSocket()

    await bt_gw._handle_manual_trigger(socket, {"date": "2026-03-16"})

    errors = [msg for msg in socket.messages if msg["type"] == "error"]
    assert any("manual trigger" in msg["message"].lower() for msg in errors)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_when_cycle_already_running():
    """A second manual trigger is refused while a cycle task is still pending."""
    live_gw = make_gateway(config={"mode": "live"})
    socket = DummyWebSocket()
    pending_cycle = MagicMock()
    pending_cycle.done.return_value = False
    live_gw._manual_cycle_task = pending_cycle

    await live_gw._handle_manual_trigger(socket, {"date": "2026-03-16"})

    errors = [msg for msg in socket.messages if msg["type"] == "error"]
    assert any("already running" in msg["message"].lower() for msg in errors)
|
|
|
|
|
|
# =============================================================================
|
|
# Normalization helpers
|
|
# =============================================================================
|
|
|
|
def test_normalize_watchlist_filters_empty_and_dedupes():
    """Symbols are upper-cased, trimmed, de-duplicated and blanks dropped."""
    raw = ["aapl", " AAPL ", "", "msft", "MSFT", " "]
    assert Gateway._normalize_watchlist(raw) == ["AAPL", "MSFT"]
|
|
|
|
|
|
def test_normalize_watchlist_handles_string_input():
    """A comma-separated string is split and normalised like a list."""
    assert Gateway._normalize_watchlist("aapl, msft, aapl") == ["AAPL", "MSFT"]
|
|
|
|
|
|
def test_normalize_agent_workspace_filename_allows_editable_files():
    """Each whitelisted workspace file name is returned unchanged."""
    editable = ("SOUL.md", "PROFILE.md", "AGENTS.md", "MEMORY.md", "POLICY.md")
    for name in editable:
        assert Gateway._normalize_agent_workspace_filename(name) == name, name
|
|
|
|
|
|
def test_normalize_agent_workspace_filename_rejects_non_editable_files():
    """File names outside the whitelist normalise to None."""
    assert Gateway._normalize_agent_workspace_filename("README.md") is None
|
|
|
|
|
|
def test_normalize_agent_workspace_filename_rejects_arbitrary_paths():
    """Path-traversal style names normalise to None."""
    assert Gateway._normalize_agent_workspace_filename("../etc/passwd") is None
|
|
|
|
|
|
# =============================================================================
|
|
# Broadcast
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_broadcast_skips_when_no_clients():
    """Broadcasting to an empty client set completes without raising."""
    gw = make_gateway()
    gw.connected_clients = set()
    await gw.broadcast({"type": "test"})  # must not raise
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_broadcast_sends_to_all_connected_clients():
    """Every connected client receives the broadcast payload."""
    gw = make_gateway()
    first, second = DummyWebSocket(), DummyWebSocket()
    gw.connected_clients = {first, second}

    await gw.broadcast({"type": "market_update", "data": "test"})

    received = first.messages + second.messages
    assert all(msg["type"] == "market_update" for msg in received)
    assert first.messages[0]["data"] == "test"
    assert second.messages[0]["data"] == "test"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_broadcast_removes_closed_connections():
    """Closed connections are pruned from ``connected_clients`` during broadcast.

    The broadcast method's ``_send_to_client`` helper removes a client when
    sending to it raises ``websockets.ConnectionClosed``. Healthy clients stay
    registered.
    """
    gateway = make_gateway()
    closed_ws = DummyWebSocket()
    open_ws = DummyWebSocket()
    gateway.connected_clients = {closed_ws, open_ws}

    # Make closed_ws.send raise ConnectionClosed so that _send_to_client's
    # except block fires and drops that client.
    async def raising_send(payload):
        raise gateway_module.websockets.ConnectionClosed(None, None)

    closed_ws.send = raising_send

    # This test only checks pruning, so tolerate the error if it propagates.
    try:
        await gateway.broadcast({"type": "test"})
    except gateway_module.websockets.ConnectionClosed:
        pass

    assert closed_ws not in gateway.connected_clients
    assert open_ws in gateway.connected_clients
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_broadcast_delivers_exactly_one_message_per_client():
    """Each connected client receives the broadcast payload exactly once.

    This test previously reused the name
    ``test_broadcast_sends_to_all_connected_clients``. That shadowed the
    earlier test of the same name, so pytest collected only one of the two.
    """
    gateway = make_gateway()
    ws1 = DummyWebSocket()
    ws2 = DummyWebSocket()
    gateway.connected_clients = {ws1, ws2}

    await gateway.broadcast({"type": "market_update", "data": "test"})

    for ws in (ws1, ws2):
        assert len(ws.messages) == 1
        assert ws.messages[0]["type"] == "market_update"
        assert ws.messages[0]["data"] == "test"
|
|
|
|
|
|
# =============================================================================
|
|
# Stop
|
|
# =============================================================================
|
|
|
|
def test_stop_gateway_calls_cycle_support():
    """Gateway.stop delegates to gateway_cycle_support.stop_gateway."""
    gw = make_gateway()
    support = gateway_module.gateway_cycle_support
    with patch.object(support, "stop_gateway") as stop_spy:
        gw.stop()
        stop_spy.assert_called_once_with(gw)
|
|
|
|
|
|
# =============================================================================
|
|
# set_backtest_dates
|
|
# =============================================================================
|
|
|
|
def test_set_backtest_dates_delegates_to_cycle_support():
    """set_backtest_dates forwards the gateway and dates to cycle support."""
    gw = make_gateway()
    dates = ["2026-03-01", "2026-03-02"]
    support = gateway_module.gateway_cycle_support
    with patch.object(support, "set_backtest_dates") as dates_spy:
        gw.set_backtest_dates(dates)
        dates_spy.assert_called_once_with(gw, ["2026-03-01", "2026-03-02"])
|
|
|
|
|
|
# =============================================================================
|
|
# Provider usage change callback
|
|
# =============================================================================
|
|
|
|
def test_on_provider_usage_changed_updates_state_sync():
    """A provider-usage snapshot is stored under state_sync's data_sources key."""
    gw = make_gateway()
    gw._loop = None  # no event loop attached

    usage = {"provider": "finnhub", "calls": 10}
    gw._on_provider_usage_changed(usage)

    assert gw.state_sync.state.get("data_sources") == usage
|
|
|
|
|
|
# =============================================================================
|
|
# handle_client lifecycle
|
|
# =============================================================================
|
|
|
|
@pytest.mark.asyncio
async def test_handle_client_adds_and_removes_client_from_connected_set():
    """After handle_client returns, the socket is no longer registered."""
    gw = make_gateway()
    socket = DummyWebSocket()

    with patch.object(gw, "_send_initial_state", AsyncMock()), patch.object(
        gw, "_handle_client_messages", AsyncMock()
    ):
        await gw.handle_client(socket)

    assert socket not in gw.connected_clients
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_handle_client_adds_client_before_handler():
    """The socket is registered while messages are handled, then removed."""
    gateway = make_gateway()
    ws = DummyWebSocket()
    membership_during_handler = []

    def record_membership(*_args, **_kwargs):
        # Runs in place of _handle_client_messages. It captures whether
        # handle_client had already registered the socket at that point.
        membership_during_handler.append(ws in gateway.connected_clients)

    with patch.object(gateway, "_send_initial_state", AsyncMock()):
        with patch.object(
            gateway, "_handle_client_messages", AsyncMock(side_effect=record_membership)
        ):
            await gateway.handle_client(ws)

    # Registered while the handler ran...
    assert membership_during_handler == [True]
    # ...and deregistered once handle_client returned.
    assert ws not in gateway.connected_clients
|