# -*- coding: utf-8 -*-
"""Tests for the Gateway main class - core behavior and fallback paths."""
import json
from collections import deque
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module


class DummyWebSocket:
    """In-memory websocket stand-in.

    Incoming frames are queued with :meth:`queue` and yielded by async
    iteration; outgoing frames passed to :meth:`send` are JSON-decoded and
    stored in ``messages`` for assertions.
    """

    def __init__(self):
        self.messages = []
        self.closed = False
        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        self._queue = deque()

    def queue(self, data: str):
        """Queue a raw message string to be yielded by the async iterator."""
        self._queue.append(data)

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self._queue:
            raise StopAsyncIteration
        return self._queue.popleft()

    async def send(self, payload: str):
        """Record an outgoing frame as its decoded JSON value."""
        self.messages.append(json.loads(payload))

    async def close(self):
        """Mark the socket as closed."""
        self.closed = True


class DummyStateSync:
    """Minimal state-sync double that stores state in a plain dict."""

    def __init__(self, current_date="2026-03-16"):
        self.state = {"current_date": current_date}
        self.system_messages = []
        self.saved = False
        self.initial_state_payload = {}

    def set_broadcast_fn(self, _fn):
        return None

    def update_state(self, key, value):
        self.state[key] = value

    def save_state(self):
        self.saved = True

    async def on_system_message(self, message):
        self.system_messages.append(message)

    def get_initial_state_payload(self, include_dashboard=True):
        """Return a fixed-shape payload; ``include_dashboard`` is ignored."""
        return {
            "status": "running",
            "current_date": self.state.get("current_date", ""),
            "portfolio": {},
            "holdings": [],
            "trades": [],
        }


class DummyMarketService:
    """Market service double that records the broadcast function it was started with."""

    def __init__(self):
        self.broadcast_func = None
        self.market_status = {"is_open": True, "session": "regular"}

    def set_price_recorder(self, _fn):
        return None

    async def start(self, broadcast_func=None):
        self.broadcast_func = broadcast_func

    def get_market_status(self):
        return self.market_status


class DummyStorage:
    """Storage double serving canned summary, holdings, and trades data."""

    def __init__(self, initial_cash=100000.0, live_session=False):
        self.initial_cash = initial_cash
        self.is_live_session_active = live_session
        self._market_store = SimpleNamespace()

    @property
    def market_store(self):
        return self._market_store

    def load_file(self, name):
        """Return canned content for known names, ``None`` otherwise."""
        if name == "summary":
            return {"totalAssetValue": self.initial_cash}
        if name in ("holdings", "trades"):
            return []
        return None

    def get_live_returns(self):
        return {"session_pnl": 0.0, "session_return": 0.0}


def make_gateway(market_service=None, storage=None, state_sync=None, config=None):
    """Build a Gateway wired to dummy collaborators.

    Any collaborator left as ``None`` is replaced by its dummy counterpart.
    A falsy ``config`` (including ``{}``) falls back to ``{"mode": "live"}``.
    """
    storage = storage or DummyStorage()
    state_sync = state_sync or DummyStateSync()
    market_service = market_service or DummyMarketService()
    pipeline = SimpleNamespace(
        state_sync=state_sync,
        max_comm_cycles=0,
        pm=SimpleNamespace(portfolio={"margin_requirement": 0.0}),
    )
    return Gateway(
        market_service=market_service,
        storage_service=storage,
        pipeline=pipeline,
        state_sync=state_sync,
        config=config or {"mode": "live"},
    )


# =============================================================================
# Gateway initialization and core properties
# =============================================================================


def test_gateway_init_sets_live_mode():
    gateway = make_gateway(config={"mode": "live"})
    assert gateway.mode == "live"
    assert gateway.is_backtest is False


def test_gateway_init_sets_backtest_mode_from_config():
    gateway = make_gateway(config={"mode": "backtest"})
    assert gateway.mode == "backtest"
    assert gateway.is_backtest is True


def test_gateway_init_sets_backtest_mode_from_flag():
    gateway = make_gateway(config={"backtest_mode": True, "mode": "live"})
    assert gateway.is_backtest is True


def test_gateway_init_defaults_to_live_mode():
    gateway = make_gateway(config={})
    assert gateway.mode == "live"
    assert gateway.is_backtest is False


def test_gateway_state_property_returns_state_sync_state():
    state_sync = DummyStateSync()
    state_sync.state["foo"] = "bar"
    gateway = make_gateway(state_sync=state_sync)
    assert gateway.state["foo"] == "bar"


def test_gateway_news_rows_need_enrichment_delegates_to_news_domain():
    rows = [{"id": "1"}, {"id": "2"}]
    with patch.object(
        gateway_module.news_domain,
        "news_rows_need_enrichment",
        return_value=True,
    ) as mock:
        result = Gateway._news_rows_need_enrichment(rows)
    mock.assert_called_once_with(rows)
    assert result is True


# =============================================================================
# Service URL helpers and fallback paths
# =============================================================================


def test_news_service_url_returns_config_value():
    gateway = make_gateway(config={"news_service_url": "http://custom-news:9000"})
    assert gateway._news_service_url() == "http://custom-news:9000"


def test_news_service_url_falls_back_to_env(monkeypatch):
    monkeypatch.setenv("NEWS_SERVICE_URL", "http://env-news:9001")
    gateway = make_gateway(config={})
    assert gateway._news_service_url() == "http://env-news:9001"


def test_news_service_url_returns_none_when_unset(monkeypatch):
    monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
    gateway = make_gateway(config={})
    assert gateway._news_service_url() is None


def test_news_service_url_strips_whitespace():
    gateway = make_gateway(config={"news_service_url": " http://whitespace-news:9000 "})
    assert gateway._news_service_url() == "http://whitespace-news:9000"


def test_trading_service_url_returns_config_value():
    gateway = make_gateway(config={"trading_service_url": "http://custom-trading:9000"})
    assert gateway._trading_service_url() == "http://custom-trading:9000"


def test_trading_service_url_falls_back_to_env(monkeypatch):
    monkeypatch.setenv("TRADING_SERVICE_URL", "http://env-trading:9001")
    gateway = make_gateway(config={})
    assert gateway._trading_service_url() == "http://env-trading:9001"


def test_trading_service_url_returns_none_when_unset(monkeypatch):
    monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
    gateway = make_gateway(config={})
    assert gateway._trading_service_url() is None


def test_trading_service_url_strips_whitespace():
    gateway = make_gateway(config={"trading_service_url": " http://whitespace-trading:9000 "})
    assert gateway._trading_service_url() == "http://whitespace-trading:9000"


@pytest.mark.asyncio
async def test_call_news_service_returns_none_when_url_not_set(monkeypatch):
    monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
    gateway = make_gateway(config={})

    async def dummy_callback(client):
        return "should not be called"

    result = await gateway._call_news_service("test_action", dummy_callback)
    assert result is None


@pytest.mark.asyncio
async def test_call_news_service_calls_callback_and_returns():
    gateway = make_gateway(config={"news_service_url": "http://news:9000"})

    async def callback(client):
        return {"result": "ok"}

    result = await gateway._call_news_service("test_action", callback)
    assert result == {"result": "ok"}


@pytest.mark.asyncio
async def test_call_news_service_returns_none_on_exception():
    gateway = make_gateway(config={"news_service_url": "http://news:9000"})

    async def failing_callback(client):
        raise RuntimeError("connection failed")

    result = await gateway._call_news_service("test_action", failing_callback)
    assert result is None


@pytest.mark.asyncio
async def test_call_trading_service_returns_none_when_url_not_set(monkeypatch):
    monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
    gateway = make_gateway(config={})
    result = await gateway._call_trading_service("test_action", lambda c: None)
    assert result is None


@pytest.mark.asyncio
async def test_call_trading_service_calls_callback_and_returns():
    gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})

    async def callback(client):
        return {"result": "ok"}

    result = await gateway._call_trading_service("test_action", callback)
    assert result == {"result": "ok"}


@pytest.mark.asyncio
async def test_call_trading_service_returns_none_on_exception():
    gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})

    async def failing_callback(client):
        raise RuntimeError("connection failed")

    result = await gateway._call_trading_service("test_action", failing_callback)
    assert result is None


# =============================================================================
# WebSocket message handlers
# =============================================================================


@pytest.mark.asyncio
async def test_handle_client_messages_ping_returns_pong():
    """Ping message type results in a pong response."""
    gateway = make_gateway()
    ws = DummyWebSocket()
    ws.queue(json.dumps({"type": "ping"}))
    await gateway._handle_client_messages(ws)
    assert ws.messages[-1]["type"] == "pong"
    assert "timestamp" in ws.messages[-1]


@pytest.mark.asyncio
async def test_handle_client_messages_get_state_sends_initial_state():
    """get_state message type triggers _send_initial_state."""
    gateway = make_gateway()
    ws = DummyWebSocket()
    ws.queue(json.dumps({"type": "get_state"}))
    with patch.object(gateway, "_send_initial_state", AsyncMock()) as mock_send:
        await gateway._handle_client_messages(ws)
    mock_send.assert_called_once_with(ws)


@pytest.mark.asyncio
async def test_handle_client_messages_unknown_type_is_silently_ignored():
    """Unknown message types are silently ignored without error."""
    gateway = make_gateway()
    ws = DummyWebSocket()
    ws.queue(json.dumps({"type": "unknown_type"}))
    # Should not raise
    await gateway._handle_client_messages(ws)
    assert len(ws.messages) == 0


@pytest.mark.asyncio
async def test_handle_client_messages_json_decode_error_is_silently_ignored():
    """Invalid JSON messages are caught by the handler's except block."""
    gateway = make_gateway()
    ws = DummyWebSocket()
    ws.queue("not valid json")
    # Should not raise
    await gateway._handle_client_messages(ws)
    assert len(ws.messages) == 0


# =============================================================================
# Backtest handling
# =============================================================================


@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_not_backtest_mode():
    gateway = make_gateway(config={"mode": "live"})
    # Should not raise - backtest is ignored in live mode
    await gateway._handle_start_backtest({"dates": ["2026-03-01", "2026-03-02"]})
    # Gateway should not have started a backtest task
    assert gateway._backtest_task is None


@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_task_already_running():
    gateway = make_gateway(config={"mode": "backtest"})
    # Pre-set a backtest task
    dummy_task = MagicMock()
    dummy_task.done.return_value = False
    gateway._backtest_task = dummy_task
    # Should not start a new task
    await gateway._handle_start_backtest({"dates": ["2026-03-01"]})
    assert gateway._backtest_task is dummy_task  # unchanged


# =============================================================================
# Manual trigger (live/mock mode)
# =============================================================================


@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_in_backtest_mode():
    gateway = make_gateway(config={"mode": "backtest"})
    ws = DummyWebSocket()
    await gateway._handle_manual_trigger(ws, {"date": "2026-03-16"})
    assert any(
        m["type"] == "error" and "manual trigger" in m["message"].lower()
        for m in ws.messages
    )


@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_when_cycle_already_running():
    gateway = make_gateway(config={"mode": "live"})
    ws = DummyWebSocket()
    # Simulate a running cycle task
    dummy_task = MagicMock()
    dummy_task.done.return_value = False
    gateway._manual_cycle_task = dummy_task
    await gateway._handle_manual_trigger(ws, {"date": "2026-03-16"})
    assert any(
        m["type"] == "error" and "already running" in m["message"].lower()
        for m in ws.messages
    )


# =============================================================================
# Normalization helpers
# =============================================================================


def test_normalize_watchlist_filters_empty_and_dedupes():
    result = Gateway._normalize_watchlist(["aapl", " AAPL ", "", "msft", "MSFT", " "])
    assert result == ["AAPL", "MSFT"]


def test_normalize_watchlist_handles_string_input():
    result = Gateway._normalize_watchlist("aapl, msft, aapl")
    assert result == ["AAPL", "MSFT"]


def test_normalize_agent_workspace_filename_allows_editable_files():
    for filename in ["SOUL.md", "PROFILE.md", "AGENTS.md", "MEMORY.md", "POLICY.md"]:
        result = Gateway._normalize_agent_workspace_filename(filename)
        assert result == filename


def test_normalize_agent_workspace_filename_rejects_non_editable_files():
    result = Gateway._normalize_agent_workspace_filename("README.md")
    assert result is None


def test_normalize_agent_workspace_filename_rejects_arbitrary_paths():
    result = Gateway._normalize_agent_workspace_filename("../etc/passwd")
    assert result is None


# =============================================================================
# Broadcast
# =============================================================================


@pytest.mark.asyncio
async def test_broadcast_skips_when_no_clients():
    gateway = make_gateway()
    gateway.connected_clients = set()
    # Should not raise
    await gateway.broadcast({"type": "test"})


# NOTE: this test used to be defined twice with identical bodies; the second
# definition silently replaced the first, so only one was ever collected.
@pytest.mark.asyncio
async def test_broadcast_sends_to_all_connected_clients():
    """Verify broadcast sends to all connected clients and collects results."""
    gateway = make_gateway()
    ws1 = DummyWebSocket()
    ws2 = DummyWebSocket()
    gateway.connected_clients = {ws1, ws2}
    await gateway.broadcast({"type": "market_update", "data": "test"})
    assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
    assert ws1.messages[0]["data"] == "test"
    assert ws2.messages[0]["data"] == "test"


@pytest.mark.asyncio
async def test_broadcast_removes_closed_connections():
    """Verify closed connections are removed from connected_clients set.

    The broadcast method's _send_to_client helper removes a client when it
    raises websockets.ConnectionClosed.
    """
    gateway = make_gateway()
    closed_ws = DummyWebSocket()
    open_ws = DummyWebSocket()
    gateway.connected_clients = {closed_ws, open_ws}

    # Make closed_ws.send raise ConnectionClosed so the original
    # _send_to_client's except block triggers and removes it
    async def raising_send(payload):
        raise gateway_module.websockets.ConnectionClosed(None, None)

    closed_ws.send = raising_send
    try:
        await gateway.broadcast({"type": "test"})
    except gateway_module.websockets.ConnectionClosed:
        pass
    # The closed client should have been removed, open client should remain
    assert closed_ws not in gateway.connected_clients
    assert open_ws in gateway.connected_clients


# =============================================================================
# Stop
# =============================================================================


def test_stop_gateway_calls_cycle_support():
    gateway = make_gateway()
    with patch.object(gateway_module.gateway_cycle_support, "stop_gateway") as mock:
        gateway.stop()
    mock.assert_called_once_with(gateway)


# =============================================================================
# set_backtest_dates
# =============================================================================


def test_set_backtest_dates_delegates_to_cycle_support():
    gateway = make_gateway()
    with patch.object(gateway_module.gateway_cycle_support, "set_backtest_dates") as mock:
        gateway.set_backtest_dates(["2026-03-01", "2026-03-02"])
    mock.assert_called_once_with(gateway, ["2026-03-01", "2026-03-02"])


# =============================================================================
# Provider usage change callback
# =============================================================================


def test_on_provider_usage_changed_updates_state_sync():
    """_on_provider_usage_changed updates state_sync with the provider snapshot."""
    gateway = make_gateway()
    gateway._loop = None  # no loop set
    snapshot = {"provider": "finnhub", "calls": 10}
    gateway._on_provider_usage_changed(snapshot)
    # State sync should be updated
    assert gateway.state_sync.state.get("data_sources") == snapshot


# =============================================================================
# handle_client lifecycle
# =============================================================================


@pytest.mark.asyncio
async def test_handle_client_adds_and_removes_client_from_connected_set():
    gateway = make_gateway()
    ws = DummyWebSocket()
    with patch.object(gateway, "_send_initial_state", AsyncMock()):
        with patch.object(gateway, "_handle_client_messages", AsyncMock()):
            await gateway.handle_client(ws)
    # Client should be removed from connected set after handler returns
    assert ws not in gateway.connected_clients


@pytest.mark.asyncio
async def test_handle_client_adds_client_before_handler():
    """The client is registered while the message handler runs, then removed."""
    gateway = make_gateway()
    ws = DummyWebSocket()
    seen_during_handler = []

    def record_membership(*_args, **_kwargs):
        # Runs in place of the real message handler; captures whether the
        # client was already registered at that moment.
        seen_during_handler.append(ws in gateway.connected_clients)

    with patch.object(gateway, "_send_initial_state", AsyncMock()):
        with patch.object(
            gateway,
            "_handle_client_messages",
            AsyncMock(side_effect=record_membership),
        ):
            await gateway.handle_client(ws)
    # Client was added at start
    assert seen_during_handler and all(seen_during_handler)
    # But removed at end (via lock)
    assert ws not in gateway.connected_clients