Files
evotraders/backend/tests/test_gateway.py
cillin 3926a6bd07 feat: 架构修复 - P0/P1 问题全面修复
P0 修复:
- runtimeStore: 添加缺失的 lastDayHistory 字段
- Gateway/RuntimeService: 状态同步改为内存优先,消除 glob 竞态
- App.jsx: 从 3075 行重构到 ~500 行,提取 8 个独立文件

P1 修复:
- CORS: 4 个服务改为从环境变量读取允许 origins
- MarketStore: 改为模块级单例模式
- Domain 层: 删除 trading thin wrapper,保留 news 真实逻辑
- 测试: 补齐 77 个 gateway/runtime 测试

新增文件:
- backend/tests/test_gateway.py (43 tests)
- frontend/src/hooks/useWebSocketHandler.js
- frontend/src/hooks/useStockRequestCallbacks.js
- frontend/src/hooks/useAgentCallbacks.js
- frontend/src/hooks/useRuntimeCallbacks.js
- frontend/src/hooks/useWatchlistCallbacks.js
- frontend/src/components/TickerBar.jsx
- frontend/src/components/HeaderRight.jsx
- frontend/src/components/ChartTabs.jsx

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 18:45:57 +08:00

550 lines
18 KiB
Python

# -*- coding: utf-8 -*-
"""Tests for the Gateway main class - core behavior and fallback paths."""
import json
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module
class DummyWebSocket:
    """In-memory websocket double.

    Frames passed to ``send`` are JSON-decoded and appended to ``messages``;
    strings registered through ``queue`` are replayed in FIFO order when the
    object is consumed with ``async for``.
    """

    def __init__(self):
        self._queue = []
        self.closed = False
        self.messages = []

    def queue(self, data: str):
        """Queue a raw message string to be yielded by the async iterator."""
        self._queue.append(data)

    def __aiter__(self):
        # The socket is its own async iterator.
        return self

    async def __anext__(self):
        """Return the oldest queued message, or end iteration when empty."""
        if self._queue:
            return self._queue.pop(0)
        raise StopAsyncIteration

    async def send(self, payload: str):
        """Decode the JSON payload and record it."""
        decoded = json.loads(payload)
        self.messages.append(decoded)

    async def close(self):
        """Mark the socket as closed."""
        self.closed = True
class DummyStateSync:
    """State-sync double that keeps state in a plain dict and records calls."""

    def __init__(self, current_date="2026-03-16"):
        self.state = dict(current_date=current_date)
        self.system_messages = []
        self.saved = False
        self.initial_state_payload = {}

    def set_broadcast_fn(self, _fn):
        """Accept and discard the broadcast function."""
        return None

    def update_state(self, key, value):
        """Store ``value`` under ``key`` in the state dict."""
        self.state[key] = value

    def save_state(self):
        """Record that a save was requested."""
        self.saved = True

    async def on_system_message(self, message):
        """Append the message to ``system_messages``."""
        self.system_messages.append(message)

    def get_initial_state_payload(self, include_dashboard=True):
        """Return a fixed payload carrying the current date from state.

        ``include_dashboard`` is accepted but not consulted.
        """
        payload = {"status": "running"}
        payload["current_date"] = self.state.get("current_date", "")
        payload["portfolio"] = {}
        payload["holdings"] = []
        payload["trades"] = []
        return payload
class DummyMarketService:
    """Market-service double that reports an open regular session."""

    def __init__(self):
        self.market_status = {"is_open": True, "session": "regular"}
        self.broadcast_func = None

    def set_price_recorder(self, _fn):
        """Accept and discard the price recorder."""
        return None

    async def start(self, broadcast_func=None):
        """Remember the broadcast function handed over at start-up."""
        self.broadcast_func = broadcast_func

    def get_market_status(self):
        """Return the stored market status dict."""
        status = self.market_status
        return status
class DummyStorage:
    """Storage double serving canned summary, holdings and trades data."""

    def __init__(self, initial_cash=100000.0, live_session=False):
        self._market_store = SimpleNamespace()
        self.is_live_session_active = live_session
        self.initial_cash = initial_cash

    @property
    def market_store(self):
        """The placeholder market store created at construction."""
        return self._market_store

    def load_file(self, name):
        """Return canned content for ``name``; unknown names give ``None``."""
        if name == "summary":
            loaded = {"totalAssetValue": self.initial_cash}
        elif name in ("holdings", "trades"):
            loaded = []
        else:
            loaded = None
        return loaded

    def get_live_returns(self):
        """Return zeroed session PnL and return figures."""
        return dict(session_pnl=0.0, session_return=0.0)
def make_gateway(market_service=None, storage=None, state_sync=None, config=None):
    """Build a ``Gateway`` wired to in-memory test doubles.

    Each collaborator left as ``None`` is replaced by its ``Dummy*`` stand-in,
    and ``config=None`` falls back to ``{"mode": "live"}``.

    Defaults are applied with explicit ``is None`` checks rather than ``or``:
    an explicitly supplied empty config (``config={}``) is falsy, and with
    ``or`` it would be silently swapped for the live-mode default, so tests
    meant to exercise an empty config would never actually do so.
    """
    if storage is None:
        storage = DummyStorage()
    if state_sync is None:
        state_sync = DummyStateSync()
    if market_service is None:
        market_service = DummyMarketService()
    if config is None:
        config = {"mode": "live"}

    # Minimal pipeline stub exposing only the attributes set here.
    pipeline = SimpleNamespace(
        state_sync=state_sync,
        max_comm_cycles=0,
        pm=SimpleNamespace(portfolio={"margin_requirement": 0.0}),
    )
    return Gateway(
        market_service=market_service,
        storage_service=storage,
        pipeline=pipeline,
        state_sync=state_sync,
        config=config,
    )
# =============================================================================
# Gateway initialization and core properties
# =============================================================================
def test_gateway_init_sets_live_mode():
    """An explicit ``mode: live`` config yields a live, non-backtest gateway."""
    live_gateway = make_gateway(config={"mode": "live"})
    mode = live_gateway.mode
    assert mode == "live"
    backtest_flag = live_gateway.is_backtest
    assert backtest_flag is False
def test_gateway_init_sets_backtest_mode_from_config():
    """``mode: backtest`` in the config switches the gateway to backtesting."""
    backtest_gateway = make_gateway(config={"mode": "backtest"})
    mode = backtest_gateway.mode
    assert mode == "backtest"
    backtest_flag = backtest_gateway.is_backtest
    assert backtest_flag is True
def test_gateway_init_sets_backtest_mode_from_flag():
    """``backtest_mode: True`` enables backtesting even when mode is "live"."""
    cfg = {"backtest_mode": True, "mode": "live"}
    flagged_gateway = make_gateway(config=cfg)
    assert flagged_gateway.is_backtest is True
def test_gateway_init_defaults_to_live_mode():
    """A gateway built from an empty config argument reports live mode."""
    default_gateway = make_gateway(config={})
    mode = default_gateway.mode
    assert mode == "live"
    backtest_flag = default_gateway.is_backtest
    assert backtest_flag is False
def test_gateway_state_property_returns_state_sync_state():
    """``gateway.state`` exposes values stored in the state-sync's state."""
    sync = DummyStateSync()
    sync.update_state("foo", "bar")
    gw = make_gateway(state_sync=sync)
    exposed = gw.state
    assert exposed["foo"] == "bar"
def test_gateway_news_rows_need_enrichment_delegates_to_news_domain():
    """The static helper forwards its rows to the news domain and returns its answer."""
    news_rows = [{"id": "1"}, {"id": "2"}]
    domain_patch = patch.object(
        gateway_module.news_domain,
        "news_rows_need_enrichment",
        return_value=True,
    )
    with domain_patch as domain_fn:
        outcome = Gateway._news_rows_need_enrichment(news_rows)
    domain_fn.assert_called_once_with(news_rows)
    assert outcome is True
# =============================================================================
# Service URL helpers and fallback paths
# =============================================================================
def test_news_service_url_returns_config_value(monkeypatch):
    """A ``news_service_url`` config entry is what the helper returns."""
    configured = "http://custom-news:9000"
    gw = make_gateway(config={"news_service_url": configured})
    resolved = gw._news_service_url()
    assert resolved == configured
def test_news_service_url_falls_back_to_env(monkeypatch):
    """With no URL passed in the config, ``NEWS_SERVICE_URL`` is used."""
    env_url = "http://env-news:9001"
    monkeypatch.setenv("NEWS_SERVICE_URL", env_url)
    gw = make_gateway(config={})
    resolved = gw._news_service_url()
    assert resolved == env_url
def test_news_service_url_returns_none_when_unset(monkeypatch):
    """With no URL passed in the config and the env var removed, the helper gives ``None``."""
    monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
    gw = make_gateway(config={})
    resolved = gw._news_service_url()
    assert resolved is None
def test_news_service_url_strips_whitespace(monkeypatch):
    """Surrounding whitespace in the configured news URL is removed."""
    padded_cfg = {"news_service_url": " http://whitespace-news:9000 "}
    gw = make_gateway(config=padded_cfg)
    resolved = gw._news_service_url()
    assert resolved == "http://whitespace-news:9000"
def test_trading_service_url_returns_config_value(monkeypatch):
    """A ``trading_service_url`` config entry is what the helper returns."""
    configured = "http://custom-trading:9000"
    gw = make_gateway(config={"trading_service_url": configured})
    resolved = gw._trading_service_url()
    assert resolved == configured
def test_trading_service_url_falls_back_to_env(monkeypatch):
    """With no URL passed in the config, ``TRADING_SERVICE_URL`` is used."""
    env_url = "http://env-trading:9001"
    monkeypatch.setenv("TRADING_SERVICE_URL", env_url)
    gw = make_gateway(config={})
    resolved = gw._trading_service_url()
    assert resolved == env_url
def test_trading_service_url_returns_none_when_unset(monkeypatch):
    """With no URL passed in the config and the env var removed, the helper gives ``None``."""
    monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
    gw = make_gateway(config={})
    resolved = gw._trading_service_url()
    assert resolved is None
def test_trading_service_url_strips_whitespace(monkeypatch):
    """Surrounding whitespace in the configured trading URL is removed."""
    padded_cfg = {"trading_service_url": " http://whitespace-trading:9000 "}
    gw = make_gateway(config=padded_cfg)
    resolved = gw._trading_service_url()
    assert resolved == "http://whitespace-trading:9000"
@pytest.mark.asyncio
async def test_call_news_service_returns_none_when_url_not_set(monkeypatch):
    """Without a news service URL the call yields ``None``.

    The callback returns a non-``None`` sentinel, so the final assertion
    would fail if its value were ever handed back.
    """
    monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
    gw = make_gateway(config={})

    async def dummy_callback(client):
        return "should not be called"

    outcome = await gw._call_news_service("test_action", dummy_callback)
    assert outcome is None
@pytest.mark.asyncio
async def test_call_news_service_calls_callback_and_returns():
    """With a configured URL the callback's return value is passed through."""
    gw = make_gateway(config={"news_service_url": "http://news:9000"})

    async def callback(client):
        return {"result": "ok"}

    outcome = await gw._call_news_service("test_action", callback)
    assert outcome == {"result": "ok"}
@pytest.mark.asyncio
async def test_call_news_service_returns_none_on_exception():
    """A callback that raises results in ``None`` rather than a propagated error."""
    gw = make_gateway(config={"news_service_url": "http://news:9000"})

    async def failing_callback(client):
        raise RuntimeError("connection failed")

    outcome = await gw._call_news_service("test_action", failing_callback)
    assert outcome is None
@pytest.mark.asyncio
async def test_call_trading_service_returns_none_when_url_not_set(monkeypatch):
    """Without a trading service URL the call yields ``None``.

    The callback is an async function returning a non-``None`` sentinel.
    A callback that itself produced ``None`` could not distinguish
    "callback skipped" from "callback ran", so the assertion would pass
    either way.
    """
    monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
    gateway = make_gateway(config={})

    async def dummy_callback(client):
        return "should not be called"

    result = await gateway._call_trading_service("test_action", dummy_callback)
    assert result is None
@pytest.mark.asyncio
async def test_call_trading_service_calls_callback_and_returns():
    """With a configured URL the callback's return value is passed through."""
    gw = make_gateway(config={"trading_service_url": "http://trading:9000"})

    async def callback(client):
        return {"result": "ok"}

    outcome = await gw._call_trading_service("test_action", callback)
    assert outcome == {"result": "ok"}
@pytest.mark.asyncio
async def test_call_trading_service_returns_none_on_exception():
    """A callback that raises results in ``None`` rather than a propagated error."""
    gw = make_gateway(config={"trading_service_url": "http://trading:9000"})

    async def failing_callback(client):
        raise RuntimeError("connection failed")

    outcome = await gw._call_trading_service("test_action", failing_callback)
    assert outcome is None
# =============================================================================
# WebSocket message handlers
# =============================================================================
@pytest.mark.asyncio
async def test_handle_client_messages_ping_returns_pong():
    """A queued ping frame is answered with a timestamped pong."""
    gw = make_gateway()
    sock = DummyWebSocket()
    ping_frame = json.dumps({"type": "ping"})
    sock.queue(ping_frame)

    await gw._handle_client_messages(sock)

    assert sock.messages[-1]["type"] == "pong"
    assert "timestamp" in sock.messages[-1]
@pytest.mark.asyncio
async def test_handle_client_messages_get_state_sends_initial_state():
    """A ``get_state`` frame makes the handler call ``_send_initial_state`` once."""
    gw = make_gateway()
    sock = DummyWebSocket()
    request_frame = json.dumps({"type": "get_state"})
    sock.queue(request_frame)

    send_patch = patch.object(gw, "_send_initial_state", AsyncMock())
    with send_patch as send_initial:
        await gw._handle_client_messages(sock)
    send_initial.assert_called_once_with(sock)
@pytest.mark.asyncio
async def test_handle_client_messages_unknown_type_is_silently_ignored():
    """A frame with an unrecognised type raises nothing and gets no reply."""
    gw = make_gateway()
    sock = DummyWebSocket()
    unknown_frame = json.dumps({"type": "unknown_type"})
    sock.queue(unknown_frame)

    # Completing this await without an exception is part of the check.
    await gw._handle_client_messages(sock)

    reply_count = len(sock.messages)
    assert reply_count == 0
@pytest.mark.asyncio
async def test_handle_client_messages_json_decode_error_is_silently_ignored():
    """A frame that is not valid JSON raises nothing and gets no reply."""
    gw = make_gateway()
    sock = DummyWebSocket()
    sock.queue("not valid json")

    # Completing this await without an exception is part of the check.
    await gw._handle_client_messages(sock)

    reply_count = len(sock.messages)
    assert reply_count == 0
# =============================================================================
# Backtest handling
# =============================================================================
@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_not_backtest_mode():
    """In live mode a start-backtest request leaves ``_backtest_task`` unset."""
    gw = make_gateway(config={"mode": "live"})
    request = {"dates": ["2026-03-01", "2026-03-02"]}

    # The call must complete without raising.
    await gw._handle_start_backtest(request)

    # No backtest task may have been created.
    assert gw._backtest_task is None
@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_task_already_running():
    """A second start request does not replace an unfinished backtest task."""
    gw = make_gateway(config={"mode": "backtest"})

    # Install a task stub that reports it has not finished.
    running_task = MagicMock()
    running_task.done.return_value = False
    gw._backtest_task = running_task

    await gw._handle_start_backtest({"dates": ["2026-03-01"]})

    # The very same task object must still be installed.
    assert gw._backtest_task is running_task
# =============================================================================
# Manual trigger (live/mock mode)
# =============================================================================
@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_in_backtest_mode():
    """In backtest mode a manual trigger is answered with an error frame."""
    gw = make_gateway(config={"mode": "backtest"})
    sock = DummyWebSocket()

    await gw._handle_manual_trigger(sock, {"date": "2026-03-16"})

    def is_rejection(frame):
        return frame["type"] == "error" and "manual trigger" in frame["message"].lower()

    assert any(is_rejection(frame) for frame in sock.messages)
@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_when_cycle_already_running():
    """A manual trigger during an unfinished cycle is answered with an error frame."""
    gw = make_gateway(config={"mode": "live"})
    sock = DummyWebSocket()

    # Install a cycle-task stub that reports it has not finished.
    running_cycle = MagicMock()
    running_cycle.done.return_value = False
    gw._manual_cycle_task = running_cycle

    await gw._handle_manual_trigger(sock, {"date": "2026-03-16"})

    def is_rejection(frame):
        return frame["type"] == "error" and "already running" in frame["message"].lower()

    assert any(is_rejection(frame) for frame in sock.messages)
# =============================================================================
# Normalization helpers
# =============================================================================
def test_normalize_watchlist_filters_empty_and_dedupes():
    """Blank entries are dropped and case/whitespace variants collapse to one symbol."""
    raw_symbols = ["aapl", " AAPL ", "", "msft", "MSFT", " "]
    normalized = Gateway._normalize_watchlist(raw_symbols)
    assert normalized == ["AAPL", "MSFT"]
def test_normalize_watchlist_handles_string_input():
    """A comma-separated string is accepted and normalized like a list."""
    raw_symbols = "aapl, msft, aapl"
    normalized = Gateway._normalize_watchlist(raw_symbols)
    assert normalized == ["AAPL", "MSFT"]
def test_normalize_agent_workspace_filename_allows_editable_files():
    """Each of the listed workspace filenames is returned unchanged."""
    editable_names = ["SOUL.md", "PROFILE.md", "AGENTS.md", "MEMORY.md", "POLICY.md"]
    normalized_names = [
        Gateway._normalize_agent_workspace_filename(name) for name in editable_names
    ]
    assert normalized_names == editable_names
def test_normalize_agent_workspace_filename_rejects_non_editable_files():
    """``README.md`` is rejected with ``None``."""
    rejected_name = "README.md"
    normalized = Gateway._normalize_agent_workspace_filename(rejected_name)
    assert normalized is None
def test_normalize_agent_workspace_filename_rejects_arbitrary_paths():
    """A path-traversal style name is rejected with ``None``."""
    traversal_name = "../etc/passwd"
    normalized = Gateway._normalize_agent_workspace_filename(traversal_name)
    assert normalized is None
# =============================================================================
# Broadcast
# =============================================================================
@pytest.mark.asyncio
async def test_broadcast_skips_when_no_clients():
    """Broadcasting with an empty client set completes without raising."""
    gw = make_gateway()
    gw.connected_clients = set()
    payload = {"type": "test"}
    # Completing this await without an exception is the whole check.
    await gw.broadcast(payload)
@pytest.mark.asyncio
async def test_broadcast_sends_to_all_connected_clients():
    """Both connected sockets receive the broadcast market update."""
    gw = make_gateway()
    first_sock = DummyWebSocket()
    second_sock = DummyWebSocket()
    gw.connected_clients = {first_sock, second_sock}

    await gw.broadcast({"type": "market_update", "data": "test"})

    received = first_sock.messages + second_sock.messages
    assert all(frame["type"] == "market_update" for frame in received)
    assert first_sock.messages[0]["data"] == "test"
    assert second_sock.messages[0]["data"] == "test"
@pytest.mark.asyncio
async def test_broadcast_removes_closed_connections():
    """Verify closed connections are removed from the connected_clients set.

    One socket's ``send`` is replaced with a coroutine that raises
    ``websockets.ConnectionClosed``; after the broadcast that socket must be
    gone from ``connected_clients`` while the healthy socket remains.
    """
    gateway = make_gateway()
    closed_ws = DummyWebSocket()
    open_ws = DummyWebSocket()
    gateway.connected_clients = {closed_ws, open_ws}

    # Simulate a dead connection: every send attempt raises ConnectionClosed.
    async def raising_send(payload):
        raise gateway_module.websockets.ConnectionClosed(None, None)

    closed_ws.send = raising_send

    try:
        await gateway.broadcast({"type": "test"})
    except gateway_module.websockets.ConnectionClosed:
        # Tolerated on purpose: this test pins only the membership outcome
        # below, not whether broadcast re-raises.
        pass

    # The closed client should have been removed, open client should remain.
    assert closed_ws not in gateway.connected_clients
    assert open_ws in gateway.connected_clients
@pytest.mark.asyncio
async def test_broadcast_delivers_payload_to_every_client():
    """Verify broadcast delivers the payload to every connected client.

    This test previously reused the name of an earlier test in this module.
    Python keeps only the last definition of a module-level name, so one of
    the two was never collected; a unique name lets both run.
    """
    gateway = make_gateway()
    ws1 = DummyWebSocket()
    ws2 = DummyWebSocket()
    gateway.connected_clients = {ws1, ws2}

    await gateway.broadcast({"type": "market_update", "data": "test"})

    assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
    # Indexing [0] also proves each client received at least one frame.
    for client in (ws1, ws2):
        assert client.messages[0]["data"] == "test"
# =============================================================================
# Stop
# =============================================================================
def test_stop_gateway_calls_cycle_support():
    """``Gateway.stop`` calls ``gateway_cycle_support.stop_gateway`` with the gateway."""
    gw = make_gateway()
    stop_patch = patch.object(gateway_module.gateway_cycle_support, "stop_gateway")
    with stop_patch as stop_fn:
        gw.stop()
    stop_fn.assert_called_once_with(gw)
# =============================================================================
# set_backtest_dates
# =============================================================================
def test_set_backtest_dates_delegates_to_cycle_support():
    """``set_backtest_dates`` forwards the gateway and dates to cycle support."""
    gw = make_gateway()
    dates_patch = patch.object(
        gateway_module.gateway_cycle_support, "set_backtest_dates"
    )
    with dates_patch as set_dates_fn:
        gw.set_backtest_dates(["2026-03-01", "2026-03-02"])
    set_dates_fn.assert_called_once_with(gw, ["2026-03-01", "2026-03-02"])
# =============================================================================
# Provider usage change callback
# =============================================================================
def test_on_provider_usage_changed_updates_state_sync():
    """The provider snapshot ends up under ``data_sources`` in the state-sync state."""
    gw = make_gateway()
    gw._loop = None  # run the callback with no loop set
    usage_snapshot = {"provider": "finnhub", "calls": 10}

    gw._on_provider_usage_changed(usage_snapshot)

    stored = gw.state_sync.state.get("data_sources")
    assert stored == usage_snapshot
# =============================================================================
# handle_client lifecycle
# =============================================================================
@pytest.mark.asyncio
async def test_handle_client_adds_and_removes_client_from_connected_set():
    """After ``handle_client`` returns, the socket is absent from ``connected_clients``."""
    gw = make_gateway()
    sock = DummyWebSocket()

    # Both collaborators are stubbed so only the bookkeeping is exercised.
    with patch.object(gw, "_send_initial_state", AsyncMock()), patch.object(
        gw, "_handle_client_messages", AsyncMock()
    ):
        await gw.handle_client(sock)

    still_connected = sock in gw.connected_clients
    assert not still_connected
@pytest.mark.asyncio
async def test_handle_client_adds_client_before_handler():
    """The socket is registered while the message handler runs, then dropped.

    Asserting only the post-return state cannot show that the client was
    ever added. The stubbed message handler therefore snapshots membership
    in ``connected_clients`` at the moment it is invoked.
    """
    gateway = make_gateway()
    ws = DummyWebSocket()
    registered_during_handler = []

    def record_membership(*_args, **_kwargs):
        # Runs in place of the real message loop; AsyncMock awaits nothing
        # here and simply returns this function's result (None).
        registered_during_handler.append(ws in gateway.connected_clients)

    with patch.object(gateway, "_send_initial_state", AsyncMock()), patch.object(
        gateway,
        "_handle_client_messages",
        AsyncMock(side_effect=record_membership),
    ):
        await gateway.handle_client(ws)

    # Client was present when the handler started (exactly one invocation)...
    assert registered_during_handler == [True]
    # ...and removed once handle_client finished.
    assert ws not in gateway.connected_clients