# -*- coding: utf-8 -*-
"""Tests for the ``Gateway`` websocket handlers.

The gateway is built from lightweight in-memory fakes (websocket, state sync,
market store), so each handler can be driven directly. Tests then inspect the
messages it sends and the store calls it makes.
"""
import json
from types import SimpleNamespace

import pytest

from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module


class DummyWebSocket:
    """Websocket stand-in that records every JSON payload sent to it."""

    def __init__(self):
        # Decoded payloads, in the order they were sent.
        self.messages = []

    async def send(self, payload: str):
        """Decode *payload* as JSON and append it to ``messages``."""
        self.messages.append(json.loads(payload))


class DummyStateSync:
    """Minimal state-sync fake exposing ``state`` and capturing system messages."""

    def __init__(self, current_date="2026-03-16"):
        # The lookback tests below expect windows ending on this date, so the
        # handlers presumably read ``state["current_date"]`` — confirm in Gateway.
        self.state = {"current_date": current_date}
        self.system_messages = []

    def set_broadcast_fn(self, _fn):
        return None

    def update_state(self, *_args, **_kwargs):
        return None

    async def on_system_message(self, message):
        """Record *message* so tests can assert on it."""
        self.system_messages.append(message)


class FakeMarketStore:
    """In-memory market store that records each call as a tuple in ``calls``.

    Every method appends ``(method_name, symbol, ...)`` to ``calls`` and
    returns a small canned payload. Tests can therefore assert on the exact
    arguments the gateway passed and on how it shaped its response.
    """

    def __init__(self):
        self.calls = []

    def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
        self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
        return [
            {
                "date": end_date,
                "count": 2,
                "source_count": 1,
                "top_title": "Top",
                "positive_count": 1,
            }
        ]

    def get_news_items(self, symbol, *, start_date=None, end_date=None, limit=100):
        self.calls.append(("get_news_items", symbol, start_date, end_date, limit))
        return [
            {
                "id": "news-1",
                "ticker": symbol,
                "date": end_date,
                "trade_date": end_date,
                "title": "Title",
                "summary": "Summary",
                "source": "polygon",
            }
        ]

    def get_news_items_enriched(
        self, symbol, *, start_date=None, end_date=None, trade_date=None, limit=100
    ):
        self.calls.append(
            ("get_news_items_enriched", symbol, start_date, end_date, trade_date, limit)
        )
        # A trade-date lookup takes precedence over the range end date.
        target_date = trade_date or end_date
        return [
            {
                "id": "news-1",
                "ticker": symbol,
                "date": target_date,
                "trade_date": target_date,
                "title": "Title",
                "summary": "Summary",
                "source": "polygon",
                "sentiment": "negative",
                "relevance": "high",
                "key_discussion": "Key discussion",
            }
        ]

    def get_news_by_ids_enriched(self, symbol, article_ids):
        # Materialize once: any iterable then works, and the recorded call
        # matches what is returned. An empty id list yields no rows rather
        # than raising IndexError.
        ids = list(article_ids)
        self.calls.append(("get_news_by_ids_enriched", symbol, ids))
        return [
            {"id": article_id, "ticker": symbol, "date": "2026-03-16", "sentiment": "negative"}
            for article_id in ids[:1]
        ]

    def get_news_categories_enriched(self, symbol, *, start_date=None, end_date=None, limit=200):
        self.calls.append(("get_news_categories_enriched", symbol, start_date, end_date, limit))
        return {
            "macro": {
                "label": "宏观",
                "count": 1,
                "article_ids": ["news-1"],
                "positive_ids": [],
                "negative_ids": ["news-1"],
                "neutral_ids": [],
            }
        }

    def get_story_cache(self, symbol, *, as_of_date):
        self.calls.append(("get_story_cache", symbol, as_of_date))
        # Always a cache miss, so the handler has to build the story.
        return None

    def upsert_story_cache(self, symbol, *, as_of_date, content, source="local"):
        self.calls.append(("upsert_story_cache", symbol, as_of_date, source))

    def delete_story_cache(self, symbol, *, as_of_date=None):
        self.calls.append(("delete_story_cache", symbol, as_of_date))
        return 1

    def get_similar_day_cache(self, symbol, *, target_date):
        self.calls.append(("get_similar_day_cache", symbol, target_date))
        # Always a cache miss.
        return None

    def upsert_similar_day_cache(self, symbol, *, target_date, payload, source="local"):
        self.calls.append(("upsert_similar_day_cache", symbol, target_date, source))

    def delete_similar_day_cache(self, symbol, *, target_date=None):
        self.calls.append(("delete_similar_day_cache", symbol, target_date))
        return 1

    def get_ohlc(self, symbol, start_date, end_date):
        self.calls.append(("get_ohlc", symbol, start_date, end_date))
        return [
            {"date": start_date, "open": 100, "high": 105, "low": 99, "close": 103},
            {"date": end_date, "open": 103, "high": 108, "low": 102, "close": 107},
        ]


def make_gateway(market_store=None):
    """Build a ``Gateway`` wired to in-memory fakes.

    Args:
        market_store: Object exposed as ``storage_service.market_store``.
            A fresh ``FakeMarketStore`` is used when omitted.

    Returns:
        A ``Gateway`` constructed with ``config={"mode": "live"}`` and a
        ``DummyStateSync`` as its ``state_sync``.
    """
    if market_store is None:
        # ``is None`` rather than ``or``: a caller-supplied store that happens
        # to be falsy (e.g. defines ``__len__``) must not be swapped out.
        market_store = FakeMarketStore()
    storage = SimpleNamespace(market_store=market_store)
    pipeline = SimpleNamespace(state_sync=None)
    market_service = SimpleNamespace()
    state_sync = DummyStateSync()
    return Gateway(
        market_service=market_service,
        storage_service=storage,
        pipeline=pipeline,
        state_sync=state_sync,
        config={"mode": "live"},
    )


@pytest.mark.asyncio
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
    """Timeline lookup passes the ticker as the store symbol over a 30-day window."""
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()
    await gateway._handle_get_stock_news_timeline(
        websocket,
        {"ticker": "AAPL", "lookback_days": 30},
    )
    # The window is 30 days back from DummyStateSync's default current_date.
    assert market_store.calls == [
        ("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")
    ]
    assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded"
    assert websocket.messages[-1]["ticker"] == "AAPL"


@pytest.mark.asyncio
async def test_handle_get_stock_news_categories_uses_market_store_symbol_argument():
    """Category lookup loads enriched items, then categories, for the same window."""
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()
    await gateway._handle_get_stock_news_categories(
        websocket,
        {"ticker": "AAPL", "lookback_days": 30},
    )
    assert market_store.calls == [
        ("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
        ("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200),
    ]
    assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
    assert websocket.messages[-1]["categories"]["macro"]["count"] == 1


@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch):
    """Without article ids, range explain feeds the range's enriched rows to the builder."""
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()

    def fake_build_range_explanation(*, ticker, start_date, end_date, news_rows):
        return {
            "ticker": ticker,
            "start_date": start_date,
            "end_date": end_date,
            "news_count": len(news_rows),
        }

    monkeypatch.setattr(
        gateway_module,
        "build_range_explanation",
        fake_build_range_explanation,
    )
    await gateway._handle_get_stock_range_explain(
        websocket,
        {"ticker": "AAPL", "start_date": "2026-03-10", "end_date": "2026-03-16"},
    )
    assert market_store.calls == [
        ("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100)
    ]
    assert websocket.messages[-1] == {
        "type": "stock_range_explain_loaded",
        "ticker": "AAPL",
        "result": {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "news_count": 1,
        },
    }


@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch):
    """With article ids, range explain fetches exactly those articles."""
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()
    monkeypatch.setattr(
        gateway_module,
        "build_range_explanation",
        lambda **kwargs: {"news_count": len(kwargs["news_rows"])},
    )
    await gateway._handle_get_stock_range_explain(
        websocket,
        {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "article_ids": ["news-99"],
        },
    )
    assert market_store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-99"])]
    assert websocket.messages[-1]["result"]["news_count"] == 1


@pytest.mark.asyncio
async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
    """A single-date request queries by trade_date with no start/end bounds."""
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()
    await gateway._handle_get_stock_news_for_date(
        websocket,
        {"ticker": "AAPL", "date": "2026-03-16", "limit": 10},
    )
    assert market_store.calls == [
        ("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
    ]
    assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded"
    assert websocket.messages[-1]["date"] == "2026-03-16"


@pytest.mark.asyncio
async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
    """The story handler emits a stock_story_loaded message mentioning the ticker's story."""
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()
    monkeypatch.setattr(
        gateway_module,
        "enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
    )
    await gateway._handle_get_stock_story(
        websocket,
        {"ticker": "AAPL", "as_of_date": "2026-03-16"},
    )
    assert websocket.messages[-1]["type"] == "stock_story_loaded"
    assert websocket.messages[-1]["ticker"] == "AAPL"
    assert "AAPL Story" in websocket.messages[-1]["story"]


@pytest.mark.asyncio
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
    """The similar-days handler emits a list of items for the ticker."""
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()
    monkeypatch.setattr(
        gateway_module,
        "enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
    )
    await gateway._handle_get_stock_similar_days(
        websocket,
        {"ticker": "AAPL", "date": "2026-03-16", "top_k": 5},
    )
    assert websocket.messages[-1]["type"] == "stock_similar_days_loaded"
    assert websocket.messages[-1]["ticker"] == "AAPL"
    assert isinstance(websocket.messages[-1]["items"], list)


@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
    """A forced enrich with rebuild flags clears the story and similar-day caches."""
    market_store = FakeMarketStore()
    gateway = make_gateway(market_store)
    websocket = DummyWebSocket()
    monkeypatch.setattr(
        gateway_module,
        "enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
    )
    await gateway._handle_run_stock_enrich(
        websocket,
        {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "force": True,
            "rebuild_story": True,
            "rebuild_similar_days": True,
            "story_date": "2026-03-16",
            "target_date": "2026-03-16",
        },
    )
    assert ("delete_story_cache", "AAPL", "2026-03-16") in market_store.calls
    assert ("delete_similar_day_cache", "AAPL", "2026-03-16") in market_store.calls
    assert websocket.messages[-1]["type"] == "stock_enrich_completed"
    assert websocket.messages[-1]["stats"]["analyzed"] == 2


@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypatch):
    """only_local_to_llm is rejected with an explanatory error when LLM enrichment is off."""
    gateway = make_gateway(FakeMarketStore())
    websocket = DummyWebSocket()
    monkeypatch.setattr(gateway_module, "llm_enrichment_enabled", lambda: False)
    await gateway._handle_run_stock_enrich(
        websocket,
        {
            "ticker": "AAPL",
            "start_date": "2026-03-10",
            "end_date": "2026-03-16",
            "only_local_to_llm": True,
        },
    )
    assert websocket.messages[-1]["type"] == "stock_enrich_completed"
    assert "requires EXPLAIN_ENRICH_USE_LLM=true" in websocket.messages[-1]["error"]


def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
    """Scheduling a watchlist refresh hands the refresh coroutine to create_task."""
    gateway = make_gateway()
    captured = {}

    class DummyTask:
        def done(self):
            return False

        def cancel(self):
            captured["cancelled"] = True

    def fake_create_task(coro):
        captured["coro_name"] = coro.cr_code.co_name
        # Close the coroutine we never schedule, so no "never awaited" warning.
        coro.close()
        return DummyTask()

    # ``gateway_module.asyncio`` is presumably the stdlib module imported by
    # the gateway; monkeypatch restores ``create_task`` after the test.
    monkeypatch.setattr(gateway_module.asyncio, "create_task", fake_create_task)
    gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])
    # ``.get`` turns "create_task was never called" into an assertion failure
    # instead of a KeyError.
    assert captured.get("coro_name") == "_refresh_market_store_for_watchlist"


@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
    """A watchlist refresh emits a start message and then a per-symbol summary."""
    gateway = make_gateway()
    monkeypatch.setattr(
        gateway_module,
        "ingest_symbols",
        lambda symbols, mode="incremental": [
            {"symbol": symbol, "prices": 3, "news": 4, "aligned": 4}
            for symbol in symbols
        ],
    )
    await gateway._refresh_market_store_for_watchlist(["AAPL", "MSFT"])
    assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
    assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]
    assert "AAPL prices=3 news=4" in gateway.state_sync.system_messages[1]


@pytest.mark.asyncio
async def test_handle_get_agent_skills_returns_statuses(tmp_path):
    """Skill statuses follow the enabled/disabled lists in the agent's agent.yaml."""
    builtin_root = tmp_path / "backend" / "skills" / "builtin"
    for name in ("risk_review", "extra_guard"):
        skill_dir = builtin_root / name
        skill_dir.mkdir(parents=True, exist_ok=True)
        (skill_dir / "SKILL.md").write_text(
            f"---\nname: {name}\ndescription: {name} desc\n---\n",
            encoding="utf-8",
        )
    agent_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
    agent_dir.mkdir(parents=True, exist_ok=True)
    (agent_dir / "agent.yaml").write_text(
        "enabled_skills:\n"
        " - extra_guard\n"
        "disabled_skills:\n"
        " - risk_review\n",
        encoding="utf-8",
    )
    gateway = make_gateway()
    gateway.config["config_name"] = "demo"
    gateway._project_root = tmp_path
    websocket = DummyWebSocket()
    await gateway._handle_get_agent_skills(
        websocket,
        {"agent_id": "risk_manager"},
    )
    assert websocket.messages[-1]["type"] == "agent_skills_loaded"
    statuses = {
        row["skill_name"]: row["status"]
        for row in websocket.messages[-1]["skills"]
    }
    assert statuses["extra_guard"] == "enabled"
    assert statuses["risk_review"] == "disabled"


@pytest.mark.asyncio
async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatch, tmp_path):
    """The profile combines the patched model info with the agent workspace's tool-group settings."""
    agent_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
    agent_dir.mkdir(parents=True, exist_ok=True)
    (agent_dir / "agent.yaml").write_text(
        "prompt_files:\n"
        " - SOUL.md\n"
        " - MEMORY.md\n"
        "active_tool_groups:\n"
        " - risk_ops\n"
        "disabled_tool_groups:\n"
        " - legacy_group\n",
        encoding="utf-8",
    )
    gateway = make_gateway()
    gateway.config["config_name"] = "demo"
    gateway._project_root = tmp_path
    websocket = DummyWebSocket()
    monkeypatch.setattr(
        gateway_module,
        "load_agent_profiles",
        lambda: {
            "risk_manager": {
                "skills": ["risk_review"],
                "active_tool_groups": ["risk_ops", "legacy_group"],
            }
        },
    )
    monkeypatch.setattr(
        gateway_module,
        "get_agent_model_info",
        lambda agent_id: ("gpt-4o-mini", "OPENAI"),
    )

    class _Bootstrap:
        @staticmethod
        def agent_override(_agent_id):
            return {}

    monkeypatch.setattr(
        gateway_module,
        "get_bootstrap_config_for_run",
        lambda project_root, config_name: _Bootstrap(),
    )
    await gateway._handle_get_agent_profile(
        websocket,
        {"agent_id": "risk_manager"},
    )
    assert websocket.messages[-1]["type"] == "agent_profile_loaded"
    profile = websocket.messages[-1]["profile"]
    assert profile["model_name"] == "gpt-4o-mini"
    assert profile["model_provider"] == "OPENAI"
    assert profile["prompt_files"] == ["SOUL.md", "MEMORY.md"]
    assert profile["active_tool_groups"] == ["risk_ops"]
    assert profile["disabled_tool_groups"] == ["legacy_group"]


@pytest.mark.asyncio
async def test_handle_get_skill_detail_returns_markdown_body(tmp_path):
    """Skill detail exposes front-matter fields and the markdown body without its trailing newline."""
    skill_dir = tmp_path / "backend" / "skills" / "builtin" / "risk_review"
    skill_dir.mkdir(parents=True, exist_ok=True)
    (skill_dir / "SKILL.md").write_text(
        "---\nname: 风险审查\ndescription: 说明\nversion: 1.0.0\n---\n# 风险审查\n\n完整正文\n",
        encoding="utf-8",
    )
    gateway = make_gateway()
    gateway._project_root = tmp_path
    websocket = DummyWebSocket()
    await gateway._handle_get_skill_detail(
        websocket,
        {"skill_name": "risk_review"},
    )
    assert websocket.messages[-1]["type"] == "skill_detail_loaded"
    assert websocket.messages[-1]["skill"]["name"] == "风险审查"
    assert websocket.messages[-1]["skill"]["version"] == "1.0.0"
    assert websocket.messages[-1]["skill"]["content"] == "# 风险审查\n\n完整正文"


@pytest.mark.asyncio
async def test_handle_get_skill_detail_prefers_agent_local_skill(tmp_path):
    """With an agent_id, the skill is read from the agent's skills/local dir and tagged source='local'."""
    skill_dir = (
        tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "skills" / "local" / "local_guard"
    )
    skill_dir.mkdir(parents=True, exist_ok=True)
    (skill_dir / "SKILL.md").write_text(
        "---\nname: 本地风控\ndescription: 本地说明\nversion: 1.0.0\n---\n# 本地风控\n\n本地正文\n",
        encoding="utf-8",
    )
    gateway = make_gateway()
    gateway.config["config_name"] = "demo"
    gateway._project_root = tmp_path
    websocket = DummyWebSocket()
    await gateway._handle_get_skill_detail(
        websocket,
        {"agent_id": "risk_manager", "skill_name": "local_guard"},
    )
    assert websocket.messages[-1]["type"] == "skill_detail_loaded"
    assert websocket.messages[-1]["agent_id"] == "risk_manager"
    assert websocket.messages[-1]["skill"]["source"] == "local"
    assert websocket.messages[-1]["skill"]["content"] == "# 本地风控\n\n本地正文"


@pytest.mark.asyncio
async def test_handle_update_agent_skill_persists_and_returns_refresh(monkeypatch, tmp_path):
    """Enabling a skill writes it to agent.yaml; a skills-loaded refresh message follows."""
    skill_dir = tmp_path / "backend" / "skills" / "builtin" / "extra_guard"
    skill_dir.mkdir(parents=True, exist_ok=True)
    (skill_dir / "SKILL.md").write_text(
        "---\nname: extra_guard\ndescription: desc\n---\n",
        encoding="utf-8",
    )
    gateway = make_gateway()
    gateway.config["config_name"] = "demo"
    gateway._project_root = tmp_path
    websocket = DummyWebSocket()

    async def _noop_reload():
        return None

    # Skip the real runtime reload; only persistence and messages matter here.
    monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
    await gateway._handle_update_agent_skill(
        websocket,
        {
            "agent_id": "risk_manager",
            "skill_name": "extra_guard",
            "enabled": True,
        },
    )
    assert websocket.messages[0]["type"] == "agent_skill_updated"
    assert websocket.messages[-1]["type"] == "agent_skills_loaded"
    agent_yaml = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "agent.yaml"
    assert "extra_guard" in agent_yaml.read_text(encoding="utf-8")
def _risk_manager_dir(root):
    """Return the ``demo`` run's ``risk_manager`` agent directory under *root*."""
    return root.joinpath("runs", "demo", "agents", "risk_manager")


def _demo_gateway(root):
    """Build a gateway using the ``demo`` config with *root* as its project root."""
    gw = make_gateway()
    gw.config["config_name"] = "demo"
    gw._project_root = root
    return gw


def _stub_runtime_reload(monkeypatch, gw):
    """Replace the gateway's runtime-asset reload with an awaitable no-op."""

    async def _noop_reload():
        return None

    monkeypatch.setattr(gw, "_handle_reload_runtime_assets", _noop_reload)


@pytest.mark.asyncio
async def test_handle_create_and_update_agent_local_skill(monkeypatch, tmp_path):
    """Creating and then editing an agent-local skill writes SKILL.md and emits refresh messages."""
    gw = _demo_gateway(tmp_path)
    ws = DummyWebSocket()
    _stub_runtime_reload(monkeypatch, gw)

    await gw._handle_create_agent_local_skill(
        ws,
        {"agent_id": "risk_manager", "skill_name": "local_guard"},
    )
    created_types = [msg["type"] for msg in ws.messages[:3]]
    assert created_types == [
        "agent_local_skill_created",
        "agent_skills_loaded",
        "skill_detail_loaded",
    ]
    skill_file = _risk_manager_dir(tmp_path).joinpath(
        "skills", "local", "local_guard", "SKILL.md"
    )
    assert skill_file.exists()

    # Start the update phase with an empty message log.
    ws.messages.clear()
    await gw._handle_update_agent_local_skill(
        ws,
        {
            "agent_id": "risk_manager",
            "skill_name": "local_guard",
            "content": "---\nname: 本地风控\ndescription: 更新后\nversion: 1.0.0\n---\n# 本地风控\n\n更新正文\n",
        },
    )
    updated_types = [msg["type"] for msg in ws.messages[:2]]
    assert updated_types == ["agent_local_skill_updated", "skill_detail_loaded"]
    assert "更新正文" in skill_file.read_text(encoding="utf-8")


@pytest.mark.asyncio
async def test_handle_delete_agent_local_skill(monkeypatch, tmp_path):
    """Deleting an agent-local skill removes its directory and every agent.yaml mention."""
    agent_dir = _risk_manager_dir(tmp_path)
    local_skill_dir = agent_dir.joinpath("skills", "local", "local_guard")
    local_skill_dir.mkdir(parents=True, exist_ok=True)
    local_skill_dir.joinpath("SKILL.md").write_text(
        "---\nname: 本地风控\ndescription: desc\nversion: 1.0.0\n---\n",
        encoding="utf-8",
    )
    agent_yaml_path = agent_dir / "agent.yaml"
    agent_dir.mkdir(parents=True, exist_ok=True)
    # List the skill under both keys so the final assertion proves that
    # deletion scrubs it from each list.
    agent_yaml_path.write_text(
        "enabled_skills:\n"
        " - local_guard\n"
        "disabled_skills:\n"
        " - local_guard\n",
        encoding="utf-8",
    )
    gw = _demo_gateway(tmp_path)
    ws = DummyWebSocket()
    _stub_runtime_reload(monkeypatch, gw)

    await gw._handle_delete_agent_local_skill(
        ws,
        {"agent_id": "risk_manager", "skill_name": "local_guard"},
    )
    assert [msg["type"] for msg in ws.messages[:2]] == [
        "agent_local_skill_deleted",
        "agent_skills_loaded",
    ]
    assert not local_skill_dir.exists()
    assert "local_guard" not in agent_yaml_path.read_text(encoding="utf-8")


@pytest.mark.asyncio
async def test_handle_remove_agent_skill_marks_disabled(monkeypatch, tmp_path):
    """Removing a builtin skill from an agent records it in the agent's agent.yaml."""
    builtin_dir = tmp_path.joinpath("backend", "skills", "builtin", "risk_review")
    builtin_dir.mkdir(parents=True, exist_ok=True)
    builtin_dir.joinpath("SKILL.md").write_text(
        "---\nname: 风险审查\ndescription: desc\nversion: 1.0.0\n---\n",
        encoding="utf-8",
    )
    gw = _demo_gateway(tmp_path)
    ws = DummyWebSocket()
    _stub_runtime_reload(monkeypatch, gw)

    await gw._handle_remove_agent_skill(
        ws,
        {"agent_id": "risk_manager", "skill_name": "risk_review"},
    )
    assert [msg["type"] for msg in ws.messages[:2]] == [
        "agent_skill_removed",
        "agent_skills_loaded",
    ]
    # NOTE(review): this only checks that the name appears somewhere in
    # agent.yaml, not that it sits under a "disabled" key as the test name
    # suggests. Consider asserting against the parsed YAML.
    yaml_text = _risk_manager_dir(tmp_path).joinpath("agent.yaml").read_text(encoding="utf-8")
    assert "risk_review" in yaml_text


@pytest.mark.asyncio
async def test_handle_get_agent_workspace_file_returns_content(tmp_path):
    """Reading a workspace file returns its exact contents plus routing metadata."""
    soul_path = _risk_manager_dir(tmp_path) / "SOUL.md"
    soul_path.parent.mkdir(parents=True, exist_ok=True)
    soul_path.write_text("soul content", encoding="utf-8")
    gw = _demo_gateway(tmp_path)
    ws = DummyWebSocket()

    await gw._handle_get_agent_workspace_file(
        ws,
        {"agent_id": "risk_manager", "filename": "SOUL.md"},
    )
    expected = {
        "type": "agent_workspace_file_loaded",
        "config_name": "demo",
        "agent_id": "risk_manager",
        "filename": "SOUL.md",
        "content": "soul content",
    }
    assert ws.messages[-1] == expected


@pytest.mark.asyncio
async def test_handle_update_agent_workspace_file_persists_and_returns_refresh(monkeypatch, tmp_path):
    """Writing a workspace file persists the new content; a file-loaded message comes last."""
    gw = _demo_gateway(tmp_path)
    ws = DummyWebSocket()
    _stub_runtime_reload(monkeypatch, gw)

    await gw._handle_update_agent_workspace_file(
        ws,
        {
            "agent_id": "risk_manager",
            "filename": "SOUL.md",
            "content": "updated soul",
        },
    )
    assert ws.messages[0]["type"] == "agent_workspace_file_updated"
    assert ws.messages[-1]["type"] == "agent_workspace_file_loaded"
    soul_path = _risk_manager_dir(tmp_path) / "SOUL.md"
    assert soul_path.read_text(encoding="utf-8") == "updated soul"