# -*- coding: utf-8 -*-
"""Direct tests for Gateway support modules."""

from types import SimpleNamespace

import pytest

from backend.core.state_sync import StateSync
from backend.services import gateway_cycle_support, gateway_runtime_support


class _DummyScheduler:
    """Scheduler stand-in that records the kwargs of every ``reconfigure`` call."""

    def __init__(self):
        self.calls = []

    def reconfigure(self, **kwargs):
        self.calls.append(kwargs)


class _DummyStateSync:
    """In-memory state-sync stand-in.

    Records state updates, whether ``save_state`` was called, every system
    message passed to ``on_system_message``, and the last backtest dates.
    """

    def __init__(self):
        self.updated = []
        self.saved = False
        self.system_messages = []
        self.backtest_dates = []
        self.state = {}

    def update_state(self, key, value):
        self.updated.append((key, value))
        self.state[key] = value

    def save_state(self):
        self.saved = True

    async def on_system_message(self, message):
        self.system_messages.append(message)

    def set_backtest_dates(self, dates):
        self.backtest_dates = list(dates)


class _DummyStorage:
    """Storage stand-in with a 100k starting balance and an empty feed DB.

    ``_persisted_server_state`` can be set by a test to control what
    ``read_persisted_server_state`` returns (a shallow copy is handed out).
    """

    def __init__(self):
        self.initial_cash = 100000.0
        self.is_live_session_active = False
        self.server_state_updates = []
        self.max_feed_history = 200
        # Both feed queries always return no events.
        self.runtime_db = SimpleNamespace(
            get_recent_feed_events=lambda limit=200: [],
            get_last_day_feed_events=lambda current_date=None, limit=200: [],
        )
        self._persisted_server_state = {}

    def can_apply_initial_cash(self):
        return True

    def apply_initial_cash(self, value):
        self.initial_cash = value
        return True

    def update_server_state_from_dashboard(self, state):
        self.server_state_updates.append(state)

    def read_persisted_server_state(self):
        return dict(self._persisted_server_state)

    def load_file(self, name):
        if name == "summary":
            return {"totalAssetValue": self.initial_cash}
        return []

    def build_dashboard_snapshot_from_state(self, state):
        # The snapshot ignores ``state`` and reports only the cash balance.
        return {
            "summary": {"totalAssetValue": self.initial_cash},
            "holdings": [],
            "stats": {},
            "trades": [],
            "leaderboard": [],
        }


class _DummyPM:
    """Portfolio-manager stand-in that only tracks ``margin_requirement``."""

    def __init__(self):
        self.portfolio = {"margin_requirement": 0.0}

    def apply_runtime_portfolio_config(self, margin_requirement=None, initial_cash=None):
        # ``initial_cash`` is accepted for signature compatibility but ignored.
        if margin_requirement is not None:
            self.portfolio["margin_requirement"] = margin_requirement
        return {"margin_requirement": True}

    def can_apply_initial_cash(self):
        return True


class _DummyMarketService:
    """Market-service stand-in that echoes ticker updates and records ``stop``."""

    def __init__(self):
        self.updated = None
        self.stopped = False

    def update_tickers(self, tickers):
        self.updated = list(tickers)
        return {"active": list(tickers), "added": list(tickers), "removed": []}

    def stop(self):
        self.stopped = True


def make_gateway_stub():
    """Return a ``SimpleNamespace`` gateway wired with the dummy collaborators.

    The config starts with a single ``AAPL`` ticker on a daily 09:30 schedule,
    and every background-task slot is ``None``.
    """
    return SimpleNamespace(
        market_service=_DummyMarketService(),
        pipeline=SimpleNamespace(max_comm_cycles=0, pm=_DummyPM()),
        scheduler=_DummyScheduler(),
        config={
            "tickers": ["AAPL"],
            "schedule_mode": "daily",
            "interval_minutes": 60,
            "trigger_time": "09:30",
            "enable_memory": False,
        },
        storage=_DummyStorage(),
        state_sync=_DummyStateSync(),
        _watchlist_ingest_task=None,
        _market_status_task=None,
        _backtest_task=None,
        _backtest_start_date=None,
        _backtest_end_date=None,
        _manual_cycle_task=None,
    )


def test_normalize_watchlist_filters_invalid_and_dedupes():
    normalize = gateway_runtime_support.normalize_watchlist
    assert normalize(["aapl", " AAPL ", "", "msft"]) == ["AAPL", "MSFT"]
    assert normalize("aapl,msft") == ["AAPL", "MSFT"]


def test_normalize_agent_workspace_filename_obeys_allowlist():
    allowlist = {"SOUL.md", "PROFILE.md"}
    normalize = gateway_runtime_support.normalize_agent_workspace_filename
    assert normalize("SOUL.md", allowlist=allowlist) == "SOUL.md"
    assert normalize("README.md", allowlist=allowlist) is None


def test_apply_runtime_config_updates_gateway_state():
    gateway = make_gateway_stub()
    result = gateway_runtime_support.apply_runtime_config(
        gateway,
        {
            "tickers": ["MSFT", "NVDA"],
            "schedule_mode": "intraday",
            "interval_minutes": 30,
            "trigger_time": "10:30",
            "initial_cash": 150000.0,
            "margin_requirement": 0.5,
            "max_comm_cycles": 4,
            "enable_memory": False,
        },
    )
    assert gateway.config["tickers"] == ["MSFT", "NVDA"]
    # "intraday" is expected to be normalized to the "interval" mode.
    assert gateway.config["schedule_mode"] == "interval"
    assert gateway.storage.initial_cash == 150000.0
    assert result["runtime_config_applied"]["max_comm_cycles"] == 4
    assert gateway.scheduler.calls[-1] == {
        "mode": "interval",
        "trigger_time": "10:30",
        "interval_minutes": 30,
    }


def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
    gateway = make_gateway_stub()
    captured = {}

    class DummyTask:
        def done(self):
            return False

        def cancel(self):
            captured["cancelled"] = True

    def fake_create_task(coro, **kwargs):
        # Mirror asyncio.create_task(coro, *, name=None, context=None): accept
        # the keyword-only arguments so passing them does not raise TypeError.
        captured["name"] = coro.cr_code.co_name
        captured["kwargs"] = kwargs
        # Close the coroutine so Python does not warn it was never awaited.
        coro.close()
        return DummyTask()

    monkeypatch.setattr(gateway_cycle_support.asyncio, "create_task", fake_create_task)
    gateway_cycle_support.schedule_watchlist_market_store_refresh(gateway, ["AAPL", "MSFT"])
    # .get() turns "create_task was never called" into an assertion failure
    # rather than a KeyError.
    assert captured.get("name") == "refresh_market_store_for_watchlist"


@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
    gateway = make_gateway_stub()
    monkeypatch.setattr(
        gateway_cycle_support,
        "ingest_symbols",
        lambda symbols, mode="incremental": [
            {"symbol": symbol, "prices": 3, "news": 4} for symbol in symbols
        ],
    )
    await gateway_cycle_support.refresh_market_store_for_watchlist(gateway, ["AAPL", "MSFT"])

    messages = gateway.state_sync.system_messages
    # Check the count first so a missing message fails readably, not with IndexError.
    assert len(messages) >= 2, messages
    assert messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
    assert "自选股市场数据已同步:" in messages[1]


def test_initial_state_payload_prefers_dashboard_snapshot_for_top_level_views():
    storage = _DummyStorage()
    sync = StateSync(storage=storage)
    sync._state = {
        "holdings": [],
        "trades": [],
        "stats": {},
        "leaderboard": [],
        "portfolio": {"total_value": 100000.0},
    }
    payload = sync.get_initial_state_payload(include_dashboard=True)
    assert payload["holdings"] == []
    assert payload["trades"] == []
    assert payload["stats"] == {}
    assert payload["leaderboard"] == []
    assert payload["dashboard"]["summary"]["totalAssetValue"] == 100000.0


def test_initial_state_payload_uses_dashboard_snapshot_for_sparse_runtime_state():
    class SnapshotStorage(_DummyStorage):
        def build_dashboard_snapshot_from_state(self, state):
            return {
                "summary": {"totalAssetValue": 123456.0},
                "holdings": [{"ticker": "AAPL"}],
                "stats": {"totalTrades": 3},
                "trades": [{"ticker": "AAPL"}],
                "leaderboard": [{"agentId": "technical_analyst"}],
            }

    sync = StateSync(storage=SnapshotStorage())
    sync._state = {
        "holdings": [],
        "trades": [],
        "stats": {},
        "leaderboard": [],
    }
    payload = sync.get_initial_state_payload(include_dashboard=True)
    assert payload["holdings"][0]["ticker"] == "AAPL"
    assert payload["trades"][0]["ticker"] == "AAPL"
    assert payload["stats"]["totalTrades"] == 3
    assert payload["leaderboard"][0]["agentId"] == "technical_analyst"


def test_initial_state_payload_falls_back_to_persisted_portfolio():
    storage = _DummyStorage()
    storage._persisted_server_state = {
        "portfolio": {
            "total_value": 123456.0,
            "pnl_percent": 12.34,
            "equity": [{"t": 1, "v": 123456.0}],
        }
    }
    sync = StateSync(storage=storage)
    sync._state = {
        "portfolio": {},
    }
    payload = sync.get_initial_state_payload(include_dashboard=True)
    assert payload["portfolio"]["total_value"] == 123456.0
    assert payload["portfolio"]["pnl_percent"] == 12.34