221 lines
6.2 KiB
Python
221 lines
6.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Direct tests for Gateway support modules."""
|
|
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from backend.services import gateway_cycle_support, gateway_runtime_support
|
|
|
|
|
|
class _DummyDashboard:
    """Test double for the gateway dashboard.

    Holds placeholder attributes (tickers, initial_cash, enable_memory,
    days_total) and records every ``update(**kwargs)`` payload so tests can
    inspect what was pushed to it.
    """

    def __init__(self):
        self.tickers = []
        self.initial_cash = None
        self.days_total = 0
        self.enable_memory = False
        # Keyword payloads received by ``update``, oldest first.
        self.updated = []

    def update(self, **kwargs):
        """Remember one update payload."""
        self.updated += [kwargs]

    def stop(self):
        """Do nothing; this stub has nothing to shut down."""

    def print_final_summary(self):
        """Do nothing; this stub prints no summary."""
|
|
|
|
|
|
class _DummyScheduler:
    """Test double for the scheduler that logs every ``reconfigure`` call."""

    def __init__(self):
        # One dict of keyword arguments per reconfigure() invocation, in order.
        self.calls = []

    def reconfigure(self, **kwargs):
        """Store the keyword arguments of this reconfiguration request."""
        self.calls.extend((kwargs,))
|
|
|
|
|
|
class _DummyStateSync:
    """Test double for the state-sync service.

    Captures state updates, save requests, system messages and the most
    recent list of backtest dates so tests can assert on them.
    """

    def __init__(self):
        self.state = {}
        # (key, value) pairs in the order update_state() received them.
        self.updated = []
        self.system_messages = []
        self.backtest_dates = []
        self.saved = False

    def update_state(self, key, value):
        """Log the update and mirror it into ``self.state``."""
        self.updated.append((key, value))
        self.state.update({key: value})

    def save_state(self):
        """Flag that a save was requested."""
        self.saved = True

    async def on_system_message(self, message):
        """Collect a system message; declared async so callers await it."""
        self.system_messages.append(message)

    def set_backtest_dates(self, dates):
        """Replace the stored dates with a list copy of *dates*."""
        self.backtest_dates = [*dates]
|
|
|
|
|
|
class _DummyStorage:
    """Test double for gateway storage.

    Starts with 100000.0 initial cash, always permits changing it, and hands
    back a fixed, mostly-empty dashboard snapshot.
    """

    def __init__(self):
        self.server_state_updates = []
        self.is_live_session_active = False
        self.initial_cash = 100000.0

    def can_apply_initial_cash(self):
        """Always allow the initial cash to be changed."""
        return True

    def apply_initial_cash(self, value):
        """Store *value* as the new initial cash and report success."""
        self.initial_cash = value
        return True

    def update_server_state_from_dashboard(self, state):
        """Record a dashboard state pushed to the server."""
        self.server_state_updates.append(state)

    def load_file(self, name):
        """Return a summary dict for ``"summary"``; every other name yields ``[]``."""
        return {"totalAssetValue": self.initial_cash} if name == "summary" else []

    def build_dashboard_snapshot_from_state(self, state):
        """Build a snapshot carrying only the current asset value (*state* is unused)."""
        summary = {"totalAssetValue": self.initial_cash}
        return {
            "summary": summary,
            "holdings": [],
            "stats": {},
            "trades": [],
            "leaderboard": [],
        }
|
|
|
|
|
|
class _DummyPM:
    """Test double for the portfolio manager.

    Only ``portfolio["margin_requirement"]`` is tracked; ``initial_cash`` is
    accepted by ``apply_runtime_portfolio_config`` but not used by this stub.
    """

    def __init__(self):
        self.portfolio = dict(margin_requirement=0.0)

    def apply_runtime_portfolio_config(self, margin_requirement=None, initial_cash=None):
        """Set the margin requirement when one is given.

        NOTE(review): the result always reports the margin as applied, even
        when *margin_requirement* is None.
        """
        if margin_requirement is not None:
            self.portfolio.update(margin_requirement=margin_requirement)
        return {"margin_requirement": True}

    def can_apply_initial_cash(self):
        """Always allow the initial cash to be changed."""
        return True
|
|
|
|
|
|
class _DummyMarketService:
    """Test double for the market data service."""

    def __init__(self):
        self.stopped = False
        # Last ticker list passed to update_tickers(); None until the first call.
        self.updated = None

    def update_tickers(self, tickers):
        """Remember *tickers* and report every one of them as active and added.

        ``tickers`` is converted with ``list()`` three separate times, so a
        one-shot iterator only populates ``self.updated``.
        """
        self.updated = list(tickers)
        active = list(tickers)
        added = list(tickers)
        return {"active": active, "added": added, "removed": []}

    def stop(self):
        """Mark the service as stopped."""
        self.stopped = True
|
|
|
|
|
|
def make_gateway_stub():
    """Assemble a ``SimpleNamespace`` that stands in for a Gateway in unit tests.

    Collaborators are the dummy doubles defined in this module. The config
    describes a daily AAPL schedule, and the background-task slots and
    backtest date bounds all start as ``None``.
    """
    default_config = {
        "tickers": ["AAPL"],
        "schedule_mode": "daily",
        "interval_minutes": 60,
        "trigger_time": "09:30",
        "enable_memory": False,
    }
    pipeline = SimpleNamespace(max_comm_cycles=0, pm=_DummyPM())
    # Attributes that begin unset on a fresh gateway.
    unset_slots = dict.fromkeys(
        (
            "_watchlist_ingest_task",
            "_market_status_task",
            "_backtest_task",
            "_backtest_start_date",
            "_backtest_end_date",
            "_manual_cycle_task",
        )
    )
    return SimpleNamespace(
        market_service=_DummyMarketService(),
        pipeline=pipeline,
        scheduler=_DummyScheduler(),
        config=default_config,
        storage=_DummyStorage(),
        state_sync=_DummyStateSync(),
        _dashboard=_DummyDashboard(),
        **unset_slots,
    )
|
|
|
|
|
|
def test_normalize_watchlist_filters_invalid_and_dedupes():
    """Lowercase, padded, blank and duplicate entries collapse to unique upper-case symbols."""
    normalize = gateway_runtime_support.normalize_watchlist
    assert normalize(["aapl", " AAPL ", "", "msft"]) == ["AAPL", "MSFT"]
    # A comma-separated string is accepted as well as a list.
    assert normalize("aapl,msft") == ["AAPL", "MSFT"]
|
|
|
|
|
|
def test_normalize_agent_workspace_filename_obeys_allowlist():
    """Allow-listed names pass through unchanged; anything else maps to None."""
    allowed = {"PROFILE.md", "SOUL.md"}
    normalize = gateway_runtime_support.normalize_agent_workspace_filename
    assert normalize("SOUL.md", allowlist=allowed) == "SOUL.md"
    assert normalize("README.md", allowlist=allowed) is None
|
|
|
|
|
|
def test_apply_runtime_config_updates_gateway_state():
    """apply_runtime_config pushes new settings into config, storage and the scheduler."""
    gateway = make_gateway_stub()
    new_settings = {
        "tickers": ["MSFT", "NVDA"],
        "schedule_mode": "intraday",
        "interval_minutes": 30,
        "trigger_time": "10:30",
        "initial_cash": 150000.0,
        "margin_requirement": 0.5,
        "max_comm_cycles": 4,
        "enable_memory": False,
    }

    result = gateway_runtime_support.apply_runtime_config(gateway, new_settings)

    config = gateway.config
    assert config["tickers"] == ["MSFT", "NVDA"]
    assert config["schedule_mode"] == "intraday"
    assert gateway.storage.initial_cash == 150000.0
    assert result["runtime_config_applied"]["max_comm_cycles"] == 4

    # The most recent scheduler reconfiguration must reflect the new schedule.
    expected_schedule = {
        "mode": "intraday",
        "trigger_time": "10:30",
        "interval_minutes": 30,
    }
    assert gateway.scheduler.calls[-1] == expected_schedule
|
|
|
|
|
|
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
    """Scheduling a watchlist refresh hands the refresh coroutine to ``asyncio.create_task``."""
    gateway = make_gateway_stub()
    captured = {}

    class DummyTask:
        """Placeholder returned instead of a real task; it always reports as pending."""

        def done(self):
            return False

        def cancel(self):
            captured["cancelled"] = True

    def fake_create_task(coro):
        # Record which coroutine function was scheduled, then close the
        # coroutine so it is never run and no "never awaited" warning fires.
        captured["name"] = coro.cr_code.co_name
        coro.close()
        return DummyTask()

    # NOTE: gateway_cycle_support.asyncio is presumably the asyncio module
    # itself, so this patches create_task process-wide; monkeypatch restores
    # the original at teardown.
    monkeypatch.setattr(gateway_cycle_support.asyncio, "create_task", fake_create_task)

    gateway_cycle_support.schedule_watchlist_market_store_refresh(gateway, ["AAPL", "MSFT"])

    # .get() turns "create_task was never called" into a readable assertion
    # failure instead of a KeyError.
    assert captured.get("name") == "refresh_market_store_for_watchlist", (
        "expected schedule_watchlist_market_store_refresh to schedule "
        f"refresh_market_store_for_watchlist via asyncio.create_task; captured={captured!r}"
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
    """Refreshing the watchlist reports a start and a completion system message."""
    gateway = make_gateway_stub()

    def fake_ingest_symbols(symbols, mode="incremental"):
        # Fake per-symbol ingest summaries; the counts are arbitrary.
        return [{"symbol": sym, "prices": 3, "news": 4} for sym in symbols]

    monkeypatch.setattr(gateway_cycle_support, "ingest_symbols", fake_ingest_symbols)

    await gateway_cycle_support.refresh_market_store_for_watchlist(gateway, ["AAPL", "MSFT"])

    messages = gateway.state_sync.system_messages
    # First message: "Syncing watchlist market data: AAPL, MSFT".
    assert messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
    # Second message begins with "Watchlist market data synced:".
    assert "自选股市场数据已同步:" in messages[1]
|