Initial commit of integrated agent system

This commit is contained in:
cillin
2026-03-30 17:46:44 +08:00
commit 0fa413380c
337 changed files with 75268 additions and 0 deletions

View File

View File

@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted agent service surface."""
from pathlib import Path
from fastapi.testclient import TestClient
from backend.apps.agent_service import create_app
from backend.api import agents as agents_module
def test_agent_service_routes_include_control_plane_endpoints(tmp_path):
app = create_app(project_root=tmp_path)
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/status" in paths
assert "/api/workspaces" in paths
assert "/api/guard/pending" in paths
def test_agent_service_excludes_runtime_routes(tmp_path):
app = create_app(project_root=tmp_path)
paths = {route.path for route in app.routes}
assert "/api/runtime/start" not in paths
assert "/api/runtime/gateway/port" not in paths
def test_agent_service_read_routes(monkeypatch, tmp_path):
class _FakeSkillsManager:
project_root = tmp_path
def get_agent_asset_dir(self, config_name, agent_id):
return tmp_path / "runs" / config_name / "agents" / agent_id
def resolve_agent_skill_names(self, config_name, agent_id, default_skills=None):
return ["demo_skill"]
def list_agent_skill_catalog(self, config_name, agent_id):
return [
type(
"Skill",
(),
{
"skill_name": "demo_skill",
"name": "Demo Skill",
"description": "demo",
"version": "1.0.0",
"source": "builtin",
"tools": [],
},
)()
]
def load_agent_skill_document(self, config_name, agent_id, skill_name):
return {"skill_name": skill_name, "content": "# demo"}
class _FakeWorkspaceManager:
def load_agent_file(self, config_name, agent_id, filename):
return f"{config_name}:{agent_id}:{filename}"
monkeypatch.setattr(agents_module, "load_agent_profiles", lambda: {"portfolio_manager": {"skills": ["demo_skill"]}})
monkeypatch.setattr(agents_module, "get_agent_model_info", lambda agent_id: ("deepseek-v3.2", "DASHSCOPE"))
monkeypatch.setattr(
agents_module,
"load_agent_workspace_config",
lambda path: type(
"Cfg",
(),
{
"active_tool_groups": ["portfolio_ops"],
"disabled_tool_groups": [],
"enabled_skills": [],
"disabled_skills": [],
"prompt_files": ["SOUL.md", "MEMORY.md"],
},
)(),
)
monkeypatch.setattr(
agents_module,
"get_bootstrap_config_for_run",
lambda project_root, config_name: type("Bootstrap", (), {"agent_override": lambda self, agent_id: {}})(),
)
app = create_app(project_root=tmp_path)
app.dependency_overrides[agents_module.get_skills_manager] = lambda: _FakeSkillsManager()
app.dependency_overrides[agents_module.get_workspace_manager] = lambda: _FakeWorkspaceManager()
with TestClient(app) as client:
profile = client.get("/api/workspaces/demo/agents/portfolio_manager/profile")
skills = client.get("/api/workspaces/demo/agents/portfolio_manager/skills")
detail = client.get("/api/workspaces/demo/agents/portfolio_manager/skills/demo_skill")
workspace_file = client.get("/api/workspaces/demo/agents/portfolio_manager/files/MEMORY.md")
assert profile.status_code == 200
assert profile.json()["profile"]["model_name"] == "deepseek-v3.2"
assert skills.status_code == 200
assert skills.json()["skills"][0]["skill_name"] == "demo_skill"
assert detail.status_code == 200
assert detail.json()["skill"]["content"] == "# demo"
assert workspace_file.status_code == 200
assert workspace_file.json()["content"] == "demo:portfolio_manager:MEMORY.md"

View File

@@ -0,0 +1,233 @@
# -*- coding: utf-8 -*-
from backend.agents.prompt_factory import build_agent_system_prompt
from backend.agents.skills_manager import SkillsManager
from backend.agents.workspace_manager import WorkspaceManager
class _DummyToolkit:
def get_agent_skill_prompt(self):
return ""
def get_activated_notes(self):
return ""
def test_workspace_manager_creates_core_agent_files(tmp_path):
manager = WorkspaceManager(project_root=tmp_path)
manager.initialize_default_assets(
config_name="demo",
agent_ids=["risk_manager"],
analyst_personas={},
)
asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
assert (asset_dir / "SOUL.md").exists()
assert (asset_dir / "PROFILE.md").exists()
assert (asset_dir / "AGENTS.md").exists()
assert (asset_dir / "MEMORY.md").exists()
assert (asset_dir / "POLICY.md").exists()
assert (asset_dir / "agent.yaml").exists()
assert (asset_dir / "skills" / "installed").is_dir()
assert (asset_dir / "skills" / "active").is_dir()
assert (asset_dir / "skills" / "disabled").is_dir()
assert (asset_dir / "skills" / "local").is_dir()
def test_workspace_manager_seeds_risk_prompt_content(tmp_path):
manager = WorkspaceManager(project_root=tmp_path)
manager.initialize_default_assets(
config_name="demo",
agent_ids=["risk_manager"],
analyst_personas={},
)
asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
soul = (asset_dir / "SOUL.md").read_text(encoding="utf-8")
guide = (asset_dir / "AGENTS.md").read_text(encoding="utf-8")
assert "风险管理经理" in soul
assert "优先使用可用的风险工具量化集中度" in guide
def test_agent_workspace_config_controls_prompt_files(tmp_path, monkeypatch):
manager = WorkspaceManager(project_root=tmp_path)
manager.initialize_default_assets(
config_name="demo",
agent_ids=["risk_manager"],
analyst_personas={},
)
asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
(asset_dir / "SOUL.md").write_text("soul-line", encoding="utf-8")
(asset_dir / "PROFILE.md").write_text("profile-line", encoding="utf-8")
(asset_dir / "MEMORY.md").write_text("memory-line", encoding="utf-8")
(asset_dir / "agent.yaml").write_text(
"prompt_files:\n"
" - SOUL.md\n"
" - MEMORY.md\n",
encoding="utf-8",
)
from backend.agents import prompt_factory
monkeypatch.setattr(
prompt_factory,
"SkillsManager",
lambda: SkillsManager(project_root=tmp_path),
)
prompt = build_agent_system_prompt(
agent_id="risk_manager",
config_name="demo",
toolkit=_DummyToolkit(),
)
assert "soul-line" in prompt
assert "memory-line" in prompt
assert "profile-line" not in prompt
def test_prompt_is_built_from_workspace_defaults_without_system_templates(tmp_path, monkeypatch):
manager = WorkspaceManager(project_root=tmp_path)
manager.initialize_default_assets(
config_name="demo",
agent_ids=["portfolio_manager"],
analyst_personas={},
)
from backend.agents import prompt_factory
monkeypatch.setattr(
prompt_factory,
"SkillsManager",
lambda: SkillsManager(project_root=tmp_path),
)
prompt = build_agent_system_prompt(
agent_id="portfolio_manager",
config_name="demo",
toolkit=_DummyToolkit(),
)
assert "投资组合经理" in prompt
assert "使用 `make_decision` 工具记录每个股票的最终决策" in prompt
def test_skills_manager_applies_agent_level_skill_toggles(tmp_path):
builtin_root = tmp_path / "backend" / "skills" / "builtin"
for skill_name in ("risk_review", "extra_guard"):
skill_dir = builtin_root / skill_name
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
f"# {skill_name}\n",
encoding="utf-8",
)
manager = WorkspaceManager(project_root=tmp_path)
manager.initialize_default_assets(
config_name="demo",
agent_ids=["risk_manager"],
analyst_personas={},
)
asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
(asset_dir / "agent.yaml").write_text(
"enabled_skills:\n"
" - extra_guard\n"
"disabled_skills:\n"
" - risk_review\n",
encoding="utf-8",
)
skills_manager = SkillsManager(project_root=tmp_path)
active_map = skills_manager.prepare_active_skills(
config_name="demo",
agent_defaults={"risk_manager": ["risk_review"]},
)
active_dirs = active_map["risk_manager"]
assert [path.name for path in active_dirs] == ["extra_guard"]
assert (asset_dir / "skills" / "installed" / "extra_guard" / "SKILL.md").exists()
assert (asset_dir / "skills" / "active" / "extra_guard" / "SKILL.md").exists()
assert (asset_dir / "skills" / "disabled" / "risk_review" / "SKILL.md").exists()
assert not (asset_dir / "skills" / "active" / "risk_review").exists()
def test_agent_local_skill_is_activated_from_agent_workspace(tmp_path):
manager = WorkspaceManager(project_root=tmp_path)
manager.initialize_default_assets(
config_name="demo",
agent_ids=["risk_manager"],
analyst_personas={},
)
asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
local_skill = asset_dir / "skills" / "local" / "local_guard"
local_skill.mkdir(parents=True, exist_ok=True)
(local_skill / "SKILL.md").write_text(
"---\nname: 本地风控\ndescription: local skill\nversion: 1.0.0\n---\n",
encoding="utf-8",
)
skills_manager = SkillsManager(project_root=tmp_path)
active_map = skills_manager.prepare_active_skills(
config_name="demo",
agent_defaults={"risk_manager": []},
)
assert [path.name for path in active_map["risk_manager"]] == ["local_guard"]
assert (asset_dir / "skills" / "active" / "local_guard" / "SKILL.md").exists()
def test_prompt_includes_active_skill_metadata_summary(tmp_path, monkeypatch):
builtin_root = tmp_path / "backend" / "skills" / "builtin"
skill_dir = builtin_root / "extra_guard"
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: extra_guard\n"
"description: This skill should be used when the user asks to \"run a risk check\".\n"
"version: 1.0.0\n"
"tools:\n"
" - risk_ops\n"
"---\n\n"
"# Extra Guard\n",
encoding="utf-8",
)
manager = WorkspaceManager(project_root=tmp_path)
manager.initialize_default_assets(
config_name="demo",
agent_ids=["risk_manager"],
analyst_personas={},
)
asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
(asset_dir / "agent.yaml").write_text(
"enabled_skills:\n"
" - extra_guard\n",
encoding="utf-8",
)
skills_manager = SkillsManager(project_root=tmp_path)
skills_manager.prepare_active_skills(
config_name="demo",
agent_defaults={"risk_manager": []},
)
from backend.agents import prompt_factory
monkeypatch.setattr(
prompt_factory,
"SkillsManager",
lambda: SkillsManager(project_root=tmp_path),
)
prompt = build_agent_system_prompt(
agent_id="risk_manager",
config_name="demo",
toolkit=_DummyToolkit(),
)
assert "Active Skill Catalog" in prompt
assert "This skill should be used when the user asks to \"run a risk check\"." in prompt
assert "version: 1.0.0" in prompt
assert "risk_ops" not in prompt

View File

@@ -0,0 +1,591 @@
# -*- coding: utf-8 -*-
# pylint: disable=W0212
import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from agentscope.message import Msg
class TestAnalystAgent:
def test_init_valid_analyst_type(self):
from backend.agents.analyst import AnalystAgent
mock_toolkit = MagicMock()
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = AnalystAgent(
analyst_type="technical_analyst",
toolkit=mock_toolkit,
model=mock_model,
formatter=mock_formatter,
)
assert agent.analyst_type_key == "technical_analyst"
assert agent.name == "technical_analyst"
assert agent.analyst_persona == "Technical Analyst"
def test_init_invalid_analyst_type(self):
from backend.agents.analyst import AnalystAgent
mock_toolkit = MagicMock()
mock_model = MagicMock()
mock_formatter = MagicMock()
with pytest.raises(ValueError) as excinfo:
AnalystAgent(
analyst_type="invalid_type",
toolkit=mock_toolkit,
model=mock_model,
formatter=mock_formatter,
)
assert "Unknown analyst type" in str(excinfo.value)
def test_init_custom_agent_id(self):
from backend.agents.analyst import AnalystAgent
mock_toolkit = MagicMock()
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = AnalystAgent(
analyst_type="fundamentals_analyst",
toolkit=mock_toolkit,
model=mock_model,
formatter=mock_formatter,
agent_id="custom_analyst_id",
)
assert agent.name == "custom_analyst_id"
def test_load_system_prompt(self):
from backend.agents.analyst import AnalystAgent
mock_toolkit = MagicMock()
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = AnalystAgent(
analyst_type="sentiment_analyst",
toolkit=mock_toolkit,
model=mock_model,
formatter=mock_formatter,
)
prompt = agent._load_system_prompt()
assert isinstance(prompt, str)
assert len(prompt) > 0
class TestPMAgent:
def test_init_default(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
)
assert agent.name == "portfolio_manager"
assert agent.portfolio["cash"] == 100000.0
assert agent.portfolio["positions"] == {}
assert agent.portfolio["margin_requirement"] == 0.25
def test_init_custom_cash(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
initial_cash=50000.0,
margin_requirement=0.5,
)
assert agent.portfolio["cash"] == 50000.0
assert agent.portfolio["margin_requirement"] == 0.5
def test_get_portfolio_state(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
initial_cash=75000.0,
)
state = agent.get_portfolio_state()
assert state["cash"] == 75000.0
assert state is not agent.portfolio # Should be a copy
def test_load_portfolio_state(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
)
new_portfolio = {
"cash": 50000.0,
"positions": {
"AAPL": {"long": 100, "short": 0, "long_cost_basis": 150.0},
},
"margin_used": 1000.0,
}
agent.load_portfolio_state(new_portfolio)
assert agent.portfolio["cash"] == 50000.0
assert "AAPL" in agent.portfolio["positions"]
def test_update_portfolio(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
)
agent.update_portfolio({"cash": 80000.0})
assert agent.portfolio["cash"] == 80000.0
def _get_text_from_tool_response(self, result):
"""Helper to extract text from ToolResponse content"""
content = result.content[0]
if hasattr(content, "text"):
return content.text
elif isinstance(content, dict):
return content.get("text", "")
return str(content)
def test_make_decision_long(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
)
result = agent._make_decision(
ticker="AAPL",
action="long",
quantity=100,
confidence=80,
reasoning="Strong fundamentals",
)
text = self._get_text_from_tool_response(result)
assert "Decision recorded" in text
assert agent._decisions["AAPL"]["action"] == "long"
assert agent._decisions["AAPL"]["quantity"] == 100
def test_make_decision_hold(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
)
result = agent._make_decision(
ticker="GOOGL",
action="hold",
quantity=0,
confidence=50,
reasoning="Neutral outlook",
)
text = self._get_text_from_tool_response(result)
assert "Decision recorded" in text
assert agent._decisions["GOOGL"]["action"] == "hold"
assert agent._decisions["GOOGL"]["quantity"] == 0
def test_make_decision_invalid_action(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
)
result = agent._make_decision(
ticker="AAPL",
action="invalid",
quantity=10,
)
text = self._get_text_from_tool_response(result)
assert "Invalid action" in text
def test_get_decisions(self):
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = PMAgent(
model=mock_model,
formatter=mock_formatter,
)
agent._make_decision("AAPL", "long", 100)
agent._make_decision("GOOGL", "short", 50)
decisions = agent.get_decisions()
assert len(decisions) == 2
assert decisions["AAPL"]["action"] == "long"
assert decisions["GOOGL"]["action"] == "short"
class TestRiskAgent:
def test_init_default(self):
from backend.agents.risk_manager import RiskAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = RiskAgent(
model=mock_model,
formatter=mock_formatter,
)
assert agent.name == "risk_manager"
def test_init_custom_name(self):
from backend.agents.risk_manager import RiskAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = RiskAgent(
model=mock_model,
formatter=mock_formatter,
name="custom_risk_manager",
)
assert agent.name == "custom_risk_manager"
def test_load_system_prompt(self):
from backend.agents.risk_manager import RiskAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
agent = RiskAgent(
model=mock_model,
formatter=mock_formatter,
)
prompt = agent._load_system_prompt()
assert isinstance(prompt, str)
assert len(prompt) > 0
class TestStorageService:
def test_storage_service_defaults_to_live_config(self):
from backend.services.storage import StorageService
with tempfile.TemporaryDirectory() as tmpdir:
storage = StorageService(
dashboard_dir=Path(tmpdir),
initial_cash=100000.0,
)
assert storage.config_name == "live"
def test_calculate_portfolio_value_cash_only(self):
from backend.services.storage import StorageService
with tempfile.TemporaryDirectory() as tmpdir:
storage = StorageService(
dashboard_dir=Path(tmpdir),
initial_cash=100000.0,
)
portfolio = {"cash": 100000.0, "positions": {}, "margin_used": 0.0}
prices = {}
value = storage.calculate_portfolio_value(portfolio, prices)
assert value == 100000.0
def test_calculate_portfolio_value_with_positions(self):
from backend.services.storage import StorageService
with tempfile.TemporaryDirectory() as tmpdir:
storage = StorageService(
dashboard_dir=Path(tmpdir),
initial_cash=100000.0,
)
portfolio = {
"cash": 50000.0,
"positions": {
"AAPL": {"long": 100, "short": 0},
"GOOGL": {"long": 0, "short": 10},
},
"margin_used": 5000.0,
}
prices = {"AAPL": 150.0, "GOOGL": 100.0}
value = storage.calculate_portfolio_value(portfolio, prices)
assert value == 69000.0
def test_update_dashboard_after_cycle(self):
from backend.services.storage import StorageService
with tempfile.TemporaryDirectory() as tmpdir:
storage = StorageService(
dashboard_dir=Path(tmpdir),
initial_cash=100000.0,
)
portfolio = {
"cash": 90000.0,
"positions": {"AAPL": {"long": 50, "short": 0}},
"margin_used": 0.0,
}
prices = {"AAPL": 200.0}
storage.update_dashboard_after_cycle(
portfolio=portfolio,
prices=prices,
date="2024-01-15",
executed_trades=[
{
"ticker": "AAPL",
"action": "long",
"quantity": 50,
"price": 200.0,
},
],
)
summary = storage.load_file("summary")
assert summary is not None
assert summary["totalAssetValue"] == 100000.0 # 90000 + 50*200
holdings = storage.load_file("holdings")
assert holdings is not None
assert len(holdings) > 0
trades = storage.load_file("trades")
assert trades is not None
assert len(trades) == 1
assert trades[0]["ticker"] == "AAPL"
assert trades[0]["qty"] == 50
assert trades[0]["price"] == 200.0
def test_generate_summary(self):
from backend.services.storage import StorageService
with tempfile.TemporaryDirectory() as tmpdir:
storage = StorageService(
dashboard_dir=Path(tmpdir),
initial_cash=100000.0,
)
state = {
"portfolio_state": {
"cash": 50000.0,
"positions": {"AAPL": {"long": 100, "short": 0}},
"margin_used": 0.0,
},
"equity_history": [{"t": 1000, "v": 100000}],
"all_trades": [],
}
prices = {"AAPL": 500.0}
storage._generate_summary(state, 100000.0, prices)
summary = storage.load_file("summary")
assert summary["totalAssetValue"] == 100000.0
assert summary["totalReturn"] == 0.0
def test_generate_holdings(self):
from backend.services.storage import StorageService
with tempfile.TemporaryDirectory() as tmpdir:
storage = StorageService(
dashboard_dir=Path(tmpdir),
initial_cash=100000.0,
)
state = {
"portfolio_state": {
"cash": 50000.0,
"positions": {"AAPL": {"long": 100, "short": 0}},
"margin_used": 0.0,
},
}
prices = {"AAPL": 500.0}
storage._generate_holdings(state, prices)
holdings = storage.load_file("holdings")
assert len(holdings) == 2 # AAPL + CASH
aapl_holding = next(
(h for h in holdings if h["ticker"] == "AAPL"),
None,
)
assert aapl_holding is not None
assert aapl_holding["quantity"] == 100
assert aapl_holding["currentPrice"] == 500.0
class TestTradeExecutor:
def test_execute_trade_long(self):
from backend.utils.trade_executor import PortfolioTradeExecutor
executor = PortfolioTradeExecutor(
initial_portfolio={
"cash": 100000.0,
"positions": {},
"margin_requirement": 0.25,
"margin_used": 0.0,
},
)
result = executor.execute_trade(
ticker="AAPL",
action="long",
quantity=10,
price=150.0,
)
assert result["status"] == "success"
assert executor.portfolio["positions"]["AAPL"]["long"] == 10
assert executor.portfolio["cash"] == 98500.0 # 100000 - 10*150
def test_execute_trade_short(self):
from backend.utils.trade_executor import PortfolioTradeExecutor
executor = PortfolioTradeExecutor(
initial_portfolio={
"cash": 100000.0,
"positions": {
"AAPL": {
"long": 50,
"short": 0,
"long_cost_basis": 100.0,
"short_cost_basis": 0.0,
},
},
"margin_requirement": 0.25,
"margin_used": 0.0,
},
)
result = executor.execute_trade(
ticker="AAPL",
action="short",
quantity=30,
price=150.0,
)
assert result["status"] == "success"
assert executor.portfolio["positions"]["AAPL"]["long"] == 20 # 50 - 30
def test_execute_trade_hold(self):
from backend.utils.trade_executor import PortfolioTradeExecutor
executor = PortfolioTradeExecutor()
result = executor.execute_trade(
ticker="AAPL",
action="hold",
quantity=0,
price=150.0,
)
assert result["status"] == "success"
assert result["message"] == "No trade needed"
class TestPipelineExecution:
def test_execute_decisions(self):
from backend.core.pipeline import TradingPipeline
from backend.agents.portfolio_manager import PMAgent
mock_model = MagicMock()
mock_formatter = MagicMock()
pm = PMAgent(
model=mock_model,
formatter=mock_formatter,
initial_cash=100000.0,
)
pipeline = TradingPipeline(
analysts=[],
risk_manager=MagicMock(),
portfolio_manager=pm,
max_comm_cycles=0,
)
decisions = {
"AAPL": {"action": "long", "quantity": 10},
"GOOGL": {"action": "short", "quantity": 5},
}
prices = {"AAPL": 150.0, "GOOGL": 100.0}
result = pipeline._execute_decisions(decisions, prices, "2024-01-15")
assert len(result["executed_trades"]) == 2
assert result["executed_trades"][0]["ticker"] == "AAPL"
assert result["executed_trades"][0]["quantity"] == 10
assert pm.portfolio["positions"]["AAPL"]["long"] == 10
class TestMsgContentIsString:
def test_msg_content_string(self):
msg = Msg(name="test", content="simple string", role="user")
assert isinstance(msg.content, str)
def test_msg_content_json_string(self):
data = {"key": "value", "nested": {"a": 1}}
msg = Msg(name="test", content=json.dumps(data), role="user")
assert isinstance(msg.content, str)
parsed = json.loads(msg.content)
assert parsed["key"] == "value"
def test_msg_content_should_not_be_dict(self):
data = {"key": "value"}
msg = Msg(name="test", content=json.dumps(data), role="assistant")
assert not isinstance(msg.content, dict)
assert isinstance(msg.content, str)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from datetime import datetime, timedelta
from backend.tools.analysis_tools import _resolved_date
def test_resolved_date_clamps_future_date():
future_date = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")
assert _resolved_date(future_date) == datetime.today().strftime("%Y-%m-%d")

235
backend/tests/test_cli.py Normal file
View File

@@ -0,0 +1,235 @@
# -*- coding: utf-8 -*-
from pathlib import Path
from backend import cli
def test_live_runs_incremental_market_store_update_before_start(monkeypatch, tmp_path):
project_root = tmp_path
(project_root / ".env").write_text("FINNHUB_API_KEY=test\n", encoding="utf-8")
calls = []
monkeypatch.setattr(cli, "get_project_root", lambda: project_root)
monkeypatch.setattr(cli, "handle_history_cleanup", lambda config_name, auto_clean=False: None)
monkeypatch.setattr(cli, "run_data_updater", lambda project_root: calls.append(("run_data_updater", project_root)))
monkeypatch.setattr(
cli,
"auto_update_market_store",
lambda config_name, end_date=None: calls.append(("auto_update_market_store", config_name, end_date)),
)
monkeypatch.setattr(
cli,
"auto_enrich_market_store",
lambda config_name, end_date=None, lookback_days=120, force=False: calls.append(
("auto_enrich_market_store", config_name, end_date, lookback_days, force)
),
)
monkeypatch.setattr(cli.os, "chdir", lambda path: calls.append(("chdir", Path(path))))
def fake_run(cmd, check=True, **kwargs):
calls.append(("subprocess.run", cmd, check))
return 0
monkeypatch.setattr(cli.subprocess, "run", fake_run)
cli.live(
config_name="smoke_fullstack",
host="0.0.0.0",
port=8765,
trigger_time="now",
poll_interval=10,
clean=False,
enable_memory=False,
)
assert any(item[0] == "run_data_updater" for item in calls)
assert any(
item[0] == "auto_update_market_store" and item[1] == "smoke_fullstack"
for item in calls
)
assert any(
item[0] == "auto_enrich_market_store" and item[1] == "smoke_fullstack"
for item in calls
)
run_call = next(item for item in calls if item[0] == "subprocess.run")
assert run_call[1][:6] == [
cli.sys.executable,
"-u",
"-m",
"backend.main",
"--mode",
"live",
]
def test_backtest_runs_full_market_store_prepare_before_start(monkeypatch, tmp_path):
project_root = tmp_path
calls = []
monkeypatch.setattr(cli, "get_project_root", lambda: project_root)
monkeypatch.setattr(cli, "handle_history_cleanup", lambda config_name, auto_clean=False: None)
monkeypatch.setattr(cli, "run_data_updater", lambda project_root: calls.append(("run_data_updater", project_root)))
monkeypatch.setattr(
cli,
"auto_prepare_backtest_market_store",
lambda config_name, start_date, end_date: calls.append(
("auto_prepare_backtest_market_store", config_name, start_date, end_date)
),
)
monkeypatch.setattr(
cli,
"auto_enrich_market_store",
lambda config_name, end_date=None, lookback_days=120, force=False: calls.append(
("auto_enrich_market_store", config_name, end_date, lookback_days, force)
),
)
monkeypatch.setattr(cli.os, "chdir", lambda path: calls.append(("chdir", Path(path))))
def fake_run(cmd, check=True, **kwargs):
calls.append(("subprocess.run", cmd, check))
return 0
monkeypatch.setattr(cli.subprocess, "run", fake_run)
cli.backtest(
start="2026-03-01",
end="2026-03-10",
config_name="smoke_fullstack",
host="0.0.0.0",
port=8765,
poll_interval=10,
clean=False,
enable_memory=False,
)
assert any(item[0] == "run_data_updater" for item in calls)
assert any(
item[0] == "auto_prepare_backtest_market_store"
and item[1:] == ("smoke_fullstack", "2026-03-01", "2026-03-10")
for item in calls
)
assert any(
item[0] == "auto_enrich_market_store"
and item[1] == "smoke_fullstack"
and item[2] == "2026-03-10"
for item in calls
)
run_call = next(item for item in calls if item[0] == "subprocess.run")
assert run_call[1][:6] == [
cli.sys.executable,
"-u",
"-m",
"backend.main",
"--mode",
"backtest",
]
def test_ingest_enrich_runs_batch_enrichment(monkeypatch):
calls = []
monkeypatch.setattr(cli, "_resolve_symbols", lambda raw_tickers, config_name=None: ["AAPL", "MSFT"])
class DummyStore:
pass
monkeypatch.setattr(cli, "MarketStore", lambda: DummyStore())
monkeypatch.setattr(
cli,
"enrich_symbols",
lambda store, symbols, start_date=None, end_date=None, limit=200, analysis_source="local", skip_existing=True: calls.append(
("enrich_symbols", symbols, start_date, end_date, limit, analysis_source, skip_existing)
) or [
{
"symbol": symbol,
"news_count": 3,
"queued_count": 3,
"analyzed": 3,
"skipped_existing_count": 0,
"deduped_count": 0,
"llm_count": 0,
"local_count": 3,
}
for symbol in symbols
],
)
cli.ingest_enrich(
tickers=None,
start="2026-03-01",
end="2026-03-10",
limit=150,
force=False,
config_name="smoke_fullstack",
)
assert calls == [
("enrich_symbols", ["AAPL", "MSFT"], "2026-03-01", "2026-03-10", 150, "local", True)
]
def test_ingest_report_reads_market_store_report(monkeypatch):
calls = []
printed = []
monkeypatch.setattr(cli, "_resolve_symbols", lambda raw_tickers, config_name=None: ["AAPL"])
class DummyStore:
def get_enrich_report(self, symbols=None, start_date=None, end_date=None):
calls.append(("get_enrich_report", symbols, start_date, end_date))
return [
{
"symbol": "AAPL",
"raw_news_count": 10,
"analyzed_news_count": 8,
"coverage_pct": 80.0,
"llm_count": 5,
"local_count": 3,
"latest_trade_date": "2026-03-16",
"latest_analysis_at": "2026-03-16T09:00:00",
}
]
monkeypatch.setattr(cli, "MarketStore", lambda: DummyStore())
monkeypatch.setattr(cli, "get_explain_model_info", lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"})
monkeypatch.setattr(cli, "llm_enrichment_enabled", lambda: True)
monkeypatch.setattr(cli.console, "print", lambda value: printed.append(value))
cli.ingest_report(
tickers=None,
start="2026-03-01",
end="2026-03-16",
config_name="smoke_fullstack",
only_problematic=False,
)
assert calls == [
("get_enrich_report", ["AAPL"], "2026-03-01", "2026-03-16")
]
assert printed
assert getattr(printed[0], "caption", "") == "Explain LLM: DASHSCOPE:qwen-max"
def test_filter_problematic_report_rows_keeps_low_coverage_and_no_llm():
rows = [
{
"symbol": "AAPL",
"coverage_pct": 100.0,
"llm_count": 2,
},
{
"symbol": "MSFT",
"coverage_pct": 80.0,
"llm_count": 1,
},
{
"symbol": "NVDA",
"coverage_pct": 100.0,
"llm_count": 0,
},
]
filtered = cli._filter_problematic_report_rows(rows)
assert [row["symbol"] for row in filtered] == ["MSFT", "NVDA"]

View File

@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
"""Tests for data source config ordering."""
from backend.config.data_config import get_config, reset_config
def test_data_config_prefers_env_source(monkeypatch):
monkeypatch.setenv("FIN_DATA_SOURCE", "financial_datasets")
monkeypatch.setenv("FINNHUB_API_KEY", "fh")
monkeypatch.setenv("FINANCIAL_DATASETS_API_KEY", "fd")
reset_config()
config = get_config()
assert config.sources[0] == "financial_datasets"
assert "local_csv" in config.sources
def test_enabled_data_sources_filters_available_sources(monkeypatch):
monkeypatch.setenv("FINNHUB_API_KEY", "fh-key")
monkeypatch.setenv("FINANCIAL_DATASETS_API_KEY", "fd-key")
monkeypatch.setenv("ENABLED_DATA_SOURCES", "financial_datasets,local_csv")
monkeypatch.delenv("FIN_DATA_SOURCE", raising=False)
reset_config()
config = get_config()
assert config.sources == ["financial_datasets", "local_csv"]
assert config.source == "financial_datasets"
def test_preferred_source_reorders_enabled_sources(monkeypatch):
monkeypatch.setenv("FINNHUB_API_KEY", "fh-key")
monkeypatch.setenv("FINANCIAL_DATASETS_API_KEY", "fd-key")
monkeypatch.setenv("ENABLED_DATA_SOURCES", "financial_datasets,finnhub,local_csv")
monkeypatch.setenv("FIN_DATA_SOURCE", "finnhub")
reset_config()
config = get_config()
assert config.sources == ["finnhub", "financial_datasets", "local_csv"]
assert config.source == "finnhub"
def test_yfinance_can_be_enabled_without_api_key(monkeypatch):
monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
monkeypatch.setenv("FIN_DATA_SOURCE", "yfinance")
monkeypatch.setenv("ENABLED_DATA_SOURCES", "yfinance,local_csv")
reset_config()
config = get_config()
assert config.sources == ["yfinance", "local_csv"]
assert config.source == "yfinance"

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""Tests for data_tools preferring split services when configured."""
from backend.tools import data_tools
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price
def test_data_tools_prefers_trading_service(monkeypatch):
monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
monkeypatch.setenv("SERVICE_NAME", "agent_service")
monkeypatch.setattr(data_tools._cache, "get_prices", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_financial_metrics", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_insider_trades", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_company_news", lambda key: None)
def fake_service_get_json(base_url, path, *, params):
if path == "/api/prices":
return {
"ticker": "AAPL",
"prices": [
Price(
open=1,
close=2,
high=3,
low=1,
volume=10,
time="2026-03-16",
).model_dump()
],
}
if path == "/api/financials":
return {
"financial_metrics": [
FinancialMetrics(
ticker="AAPL",
report_period="2026-03-16",
period="ttm",
currency="USD",
market_cap=123.0,
enterprise_value=None,
price_to_earnings_ratio=None,
price_to_book_ratio=None,
price_to_sales_ratio=None,
enterprise_value_to_ebitda_ratio=None,
enterprise_value_to_revenue_ratio=None,
free_cash_flow_yield=None,
peg_ratio=None,
gross_margin=None,
operating_margin=None,
net_margin=None,
return_on_equity=None,
return_on_assets=None,
return_on_invested_capital=None,
asset_turnover=None,
inventory_turnover=None,
receivables_turnover=None,
days_sales_outstanding=None,
operating_cycle=None,
working_capital_turnover=None,
current_ratio=None,
quick_ratio=None,
cash_ratio=None,
operating_cash_flow_ratio=None,
debt_to_equity=None,
debt_to_assets=None,
interest_coverage=None,
revenue_growth=None,
earnings_growth=None,
book_value_growth=None,
earnings_per_share_growth=None,
free_cash_flow_growth=None,
operating_income_growth=None,
ebitda_growth=None,
payout_ratio=None,
earnings_per_share=None,
book_value_per_share=None,
free_cash_flow_per_share=None,
).model_dump()
]
}
if path == "/api/insider-trades":
return {
"insider_trades": [
InsiderTrade(ticker="AAPL", filing_date="2026-03-16").model_dump()
]
}
if path == "/api/news":
return {
"news": [
CompanyNews(
ticker="AAPL",
title="Title",
source="polygon",
url="https://example.com",
).model_dump()
]
}
if path == "/api/market-cap":
return {"ticker": "AAPL", "end_date": "2026-03-16", "market_cap": 2.5e12}
if path == "/api/line-items":
return {
"search_results": [
LineItem(
ticker="AAPL",
report_period="2026-03-16",
period="ttm",
currency="USD",
free_cash_flow=321.0,
).model_dump()
]
}
raise AssertionError(path)
monkeypatch.setattr(data_tools, "_service_get_json", fake_service_get_json)
prices = data_tools.get_prices("AAPL", "2026-03-01", "2026-03-16")
metrics = data_tools.get_financial_metrics("AAPL", "2026-03-16")
trades = data_tools.get_insider_trades("AAPL", "2026-03-16")
news = data_tools.get_company_news("AAPL", "2026-03-16")
market_cap = data_tools.get_market_cap("AAPL", "2026-03-16")
line_items = data_tools.search_line_items(
"AAPL",
["free_cash_flow"],
"2026-03-16",
)
assert prices[0].close == 2
assert metrics[0].ticker == "AAPL"
assert trades[0].ticker == "AAPL"
assert news[0].ticker == "AAPL"
assert market_cap == 2.5e12
assert line_items[0].free_cash_flow == 321.0
def test_data_tools_skips_self_recursion_for_trading_service(monkeypatch):
monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
monkeypatch.setenv("SERVICE_NAME", "trading_service")
assert data_tools._trading_service_url() is None

View File

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
"""Tests for normalized env config helpers."""
from backend.config.env_config import (
canonicalize_model_provider,
get_agent_model_config,
)
def test_canonicalize_model_provider_aliases():
assert canonicalize_model_provider("claude") == "ANTHROPIC"
assert canonicalize_model_provider("openai_compatible") == "OPENAI"
assert canonicalize_model_provider("google") == "GEMINI"
def test_get_agent_model_config_fallback(monkeypatch):
monkeypatch.delenv("AGENT_RISK_MANAGER_MODEL_NAME", raising=False)
monkeypatch.delenv("AGENT_RISK_MANAGER_MODEL_PROVIDER", raising=False)
monkeypatch.setenv("MODEL_NAME", "gpt-4o-mini")
monkeypatch.setenv("MODEL_PROVIDER", "openai")
config = get_agent_model_config("risk_manager")
assert config.model_name == "gpt-4o-mini"
assert config.provider == "OPENAI"

View File

@@ -0,0 +1,934 @@
# -*- coding: utf-8 -*-
import json
from types import SimpleNamespace
import pytest
from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module
from shared.schema import InsiderTrade, InsiderTradeResponse, Price, PriceResponse
class DummyWebSocket:
def __init__(self):
self.messages = []
async def send(self, payload: str):
self.messages.append(json.loads(payload))
class DummyStateSync:
def __init__(self, current_date="2026-03-16"):
self.state = {"current_date": current_date}
self.system_messages = []
def set_broadcast_fn(self, _fn):
return None
def update_state(self, *_args, **_kwargs):
return None
async def on_system_message(self, message):
self.system_messages.append(message)
class FakeMarketStore:
def __init__(self):
self.calls = []
def get_ticker_watermarks(self, symbol):
self.calls.append(("get_ticker_watermarks", symbol))
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}]
def get_news_items(self, symbol, *, start_date=None, end_date=None, limit=100):
self.calls.append(("get_news_items", symbol, start_date, end_date, limit))
return [
{
"id": "news-1",
"ticker": symbol,
"date": end_date,
"trade_date": end_date,
"title": "Title",
"summary": "Summary",
"source": "polygon",
}
]
def get_news_items_enriched(self, symbol, *, start_date=None, end_date=None, trade_date=None, limit=100):
self.calls.append(("get_news_items_enriched", symbol, start_date, end_date, trade_date, limit))
target_date = trade_date or end_date
return [
{
"id": "news-1",
"ticker": symbol,
"date": target_date,
"trade_date": target_date,
"title": "Title",
"summary": "Summary",
"source": "polygon",
"sentiment": "negative",
"relevance": "high",
"key_discussion": "Key discussion",
}
]
def get_news_by_ids_enriched(self, symbol, article_ids):
self.calls.append(("get_news_by_ids_enriched", symbol, list(article_ids)))
return [{"id": article_ids[0], "ticker": symbol, "date": "2026-03-16", "sentiment": "negative"}]
def get_news_categories_enriched(self, symbol, *, start_date=None, end_date=None, limit=200):
self.calls.append(("get_news_categories_enriched", symbol, start_date, end_date, limit))
return {"macro": {"label": "宏观", "count": 1, "article_ids": ["news-1"], "positive_ids": [], "negative_ids": ["news-1"], "neutral_ids": []}}
def get_story_cache(self, symbol, *, as_of_date):
self.calls.append(("get_story_cache", symbol, as_of_date))
return None
def upsert_story_cache(self, symbol, *, as_of_date, content, source="local"):
self.calls.append(("upsert_story_cache", symbol, as_of_date, source))
def delete_story_cache(self, symbol, *, as_of_date=None):
self.calls.append(("delete_story_cache", symbol, as_of_date))
return 1
def get_similar_day_cache(self, symbol, *, target_date):
self.calls.append(("get_similar_day_cache", symbol, target_date))
return None
def upsert_similar_day_cache(self, symbol, *, target_date, payload, source="local"):
self.calls.append(("upsert_similar_day_cache", symbol, target_date, source))
def delete_similar_day_cache(self, symbol, *, target_date=None):
self.calls.append(("delete_similar_day_cache", symbol, target_date))
return 1
def get_ohlc(self, symbol, start_date, end_date):
self.calls.append(("get_ohlc", symbol, start_date, end_date))
return [
{"date": start_date, "open": 100, "high": 105, "low": 99, "close": 103},
{"date": end_date, "open": 103, "high": 108, "low": 102, "close": 107},
]
def make_gateway(market_store=None):
storage = SimpleNamespace(market_store=market_store or FakeMarketStore())
pipeline = SimpleNamespace(state_sync=None)
market_service = SimpleNamespace()
state_sync = DummyStateSync()
return Gateway(
market_service=market_service,
storage_service=storage,
pipeline=pipeline,
state_sync=state_sync,
config={"mode": "live"},
)
class FakeNewsClient:
def __init__(self, base_url):
self.base_url = base_url
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return None
async def get_categories(self, ticker, start_date=None, end_date=None, limit=200):
return {"ticker": ticker, "categories": {"remote": {"count": 2}}}
async def get_enriched_news(self, ticker, start_date=None, end_date=None, limit=None):
return {
"ticker": ticker,
"news": [
{
"id": "remote-news-1",
"ticker": ticker,
"title": "Remote Title",
"date": end_date,
}
],
}
async def get_story(self, ticker, as_of_date):
return {"symbol": ticker, "as_of_date": as_of_date, "story": "remote story", "source": "news_service"}
class FakeTradingClient:
def __init__(self, base_url):
self.base_url = base_url
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return None
async def get_insider_trades(self, ticker, end_date=None, start_date=None, limit=None):
return InsiderTradeResponse(
insider_trades=[
InsiderTrade(
ticker=ticker,
name="Remote Insider",
filing_date=end_date or "2026-03-16",
)
]
)
async def get_prices(self, ticker, start_date=None, end_date=None):
prices = [
Price(
open=float(100 + idx),
close=float(101 + idx),
high=float(102 + idx),
low=float(99 + idx),
volume=1000 + idx,
time=f"2026-01-{idx + 1:02d}",
)
for idx in range(30)
]
return PriceResponse(ticker=ticker, prices=prices)
async def get_market_cap(self, ticker, end_date):
return {"ticker": ticker, "end_date": end_date, "market_cap": 2.5e12}
@pytest.mark.asyncio
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
await gateway._handle_get_stock_news_timeline(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")
]
assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
@pytest.mark.asyncio
async def test_handle_get_stock_news_categories_uses_market_store_symbol_argument(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
await gateway._handle_get_stock_news_categories(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200)
]
assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
assert websocket.messages[-1]["categories"]["macro"]["count"] == 1
@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
def fake_build_range_explanation(*, ticker, start_date, end_date, news_rows):
return {
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"news_count": len(news_rows),
}
monkeypatch.setattr(
gateway_module.news_domain,
"build_range_explanation",
fake_build_range_explanation,
)
await gateway._handle_get_stock_range_explain(
websocket,
{"ticker": "AAPL", "start_date": "2026-03-10", "end_date": "2026-03-16"},
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100)
]
assert websocket.messages[-1] == {
"type": "stock_range_explain_loaded",
"ticker": "AAPL",
"result": {
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"news_count": 1,
},
}
@pytest.mark.asyncio
async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module.news_domain,
"build_range_explanation",
lambda **kwargs: {"news_count": len(kwargs["news_rows"])},
)
await gateway._handle_get_stock_range_explain(
websocket,
{
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"article_ids": ["news-99"],
},
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_by_ids_enriched", "AAPL", ["news-99"])
]
assert websocket.messages[-1]["result"]["news_count"] == 1
@pytest.mark.asyncio
async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
await gateway._handle_get_stock_news_for_date(
websocket,
{"ticker": "AAPL", "date": "2026-03-16", "limit": 10},
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
]
assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded"
assert websocket.messages[-1]["date"] == "2026-03-16"
@pytest.mark.asyncio
async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
)
await gateway._handle_get_stock_story(
websocket,
{"ticker": "AAPL", "as_of_date": "2026-03-16"},
)
assert websocket.messages[-1]["type"] == "stock_story_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
assert "AAPL Story" in websocket.messages[-1]["story"]
@pytest.mark.asyncio
async def test_handle_get_stock_news_categories_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_news_categories(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
assert websocket.messages[-1]["categories"]["remote"]["count"] == 2
@pytest.mark.asyncio
async def test_handle_get_stock_story_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_story(
websocket,
{"ticker": "AAPL", "as_of_date": "2026-03-16"},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_story_loaded"
assert websocket.messages[-1]["story"] == "remote story"
@pytest.mark.asyncio
async def test_handle_get_stock_news_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_news(
websocket,
{"ticker": "AAPL", "lookback_days": 30, "limit": 5},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_news_loaded"
assert websocket.messages[-1]["source"] == "news_service"
assert websocket.messages[-1]["news"][0]["title"] == "Remote Title"
@pytest.mark.asyncio
async def test_handle_get_stock_insider_trades_uses_trading_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_insider_trades(
websocket,
{"ticker": "AAPL", "end_date": "2026-03-16", "limit": 10},
)
assert websocket.messages[-1]["type"] == "stock_insider_trades_loaded"
assert websocket.messages[-1]["trades"][0]["name"] == "Remote Insider"
@pytest.mark.asyncio
async def test_handle_get_stock_history_uses_trading_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_history(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_history_loaded"
assert websocket.messages[-1]["source"] == "trading_service"
assert len(websocket.messages[-1]["prices"]) == 30
@pytest.mark.asyncio
async def test_handle_get_stock_technical_indicators_uses_trading_service_client_when_configured(monkeypatch):
gateway = make_gateway(FakeMarketStore())
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_technical_indicators(
websocket,
{"ticker": "AAPL"},
)
assert websocket.messages[-1]["type"] == "stock_technical_indicators_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
assert websocket.messages[-1]["indicators"] is not None
@pytest.mark.asyncio
async def test_get_market_caps_uses_trading_service_client_when_configured(monkeypatch):
gateway = make_gateway(FakeMarketStore())
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
market_caps = await gateway._get_market_caps(["AAPL", "MSFT"], "2026-03-16")
assert market_caps == {"AAPL": 2.5e12, "MSFT": 2.5e12}
@pytest.mark.asyncio
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
)
await gateway._handle_get_stock_similar_days(
websocket,
{"ticker": "AAPL", "date": "2026-03-16", "top_k": 5},
)
assert websocket.messages[-1]["type"] == "stock_similar_days_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
assert isinstance(websocket.messages[-1]["items"], list)
@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module.gateway_stock_handlers,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
)
monkeypatch.setattr(
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
)
await gateway._handle_run_stock_enrich(
websocket,
{
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"force": True,
"rebuild_story": True,
"rebuild_similar_days": True,
"story_date": "2026-03-16",
"target_date": "2026-03-16",
},
)
assert ("delete_story_cache", "AAPL", "2026-03-16") in market_store.calls
assert ("delete_similar_day_cache", "AAPL", "2026-03-16") in market_store.calls
assert websocket.messages[-1]["type"] == "stock_enrich_completed"
assert websocket.messages[-1]["stats"]["analyzed"] == 2
@pytest.mark.asyncio
async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypatch):
gateway = make_gateway(FakeMarketStore())
websocket = DummyWebSocket()
monkeypatch.setattr(gateway_module.gateway_stock_handlers, "llm_enrichment_enabled", lambda: False)
await gateway._handle_run_stock_enrich(
websocket,
{
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"only_local_to_llm": True,
},
)
assert websocket.messages[-1]["type"] == "stock_enrich_completed"
assert "requires EXPLAIN_ENRICH_USE_LLM=true" in websocket.messages[-1]["error"]
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
gateway = make_gateway()
captured = {}
class DummyTask:
def done(self):
return False
def cancel(self):
captured["cancelled"] = True
def fake_create_task(coro):
captured["coro_name"] = coro.cr_code.co_name
coro.close()
return DummyTask()
monkeypatch.setattr(gateway_module.asyncio, "create_task", fake_create_task)
gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])
assert captured["coro_name"] == "refresh_market_store_for_watchlist"
@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
gateway = make_gateway()
monkeypatch.setattr(
gateway_module.gateway_cycle_support,
"ingest_symbols",
lambda symbols, mode="incremental": [
{"symbol": symbol, "prices": 3, "news": 4, "aligned": 4}
for symbol in symbols
],
)
await gateway._refresh_market_store_for_watchlist(["AAPL", "MSFT"])
assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]
assert "AAPL prices=3 news=4" in gateway.state_sync.system_messages[1]
@pytest.mark.asyncio
async def test_handle_get_agent_skills_returns_statuses(tmp_path):
builtin_root = tmp_path / "backend" / "skills" / "builtin"
for name in ("risk_review", "extra_guard"):
skill_dir = builtin_root / name
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
f"---\nname: {name}\ndescription: {name} desc\n---\n",
encoding="utf-8",
)
agent_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
agent_dir.mkdir(parents=True, exist_ok=True)
(agent_dir / "agent.yaml").write_text(
"enabled_skills:\n"
" - extra_guard\n"
"disabled_skills:\n"
" - risk_review\n",
encoding="utf-8",
)
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
await gateway._handle_get_agent_skills(
websocket,
{"agent_id": "risk_manager"},
)
assert websocket.messages[-1]["type"] == "agent_skills_loaded"
statuses = {
row["skill_name"]: row["status"]
for row in websocket.messages[-1]["skills"]
}
assert statuses["extra_guard"] == "enabled"
assert statuses["risk_review"] == "disabled"
@pytest.mark.asyncio
async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatch, tmp_path):
agent_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
agent_dir.mkdir(parents=True, exist_ok=True)
(agent_dir / "agent.yaml").write_text(
"prompt_files:\n"
" - SOUL.md\n"
" - MEMORY.md\n"
"active_tool_groups:\n"
" - risk_ops\n"
"disabled_tool_groups:\n"
" - legacy_group\n",
encoding="utf-8",
)
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module.gateway_admin_handlers,
"load_agent_profiles",
lambda: {"risk_manager": {"skills": ["risk_review"], "active_tool_groups": ["risk_ops", "legacy_group"]}},
)
monkeypatch.setattr(
gateway_module.gateway_admin_handlers,
"get_agent_model_info",
lambda agent_id: ("gpt-4o-mini", "OPENAI"),
)
class _Bootstrap:
@staticmethod
def agent_override(_agent_id):
return {}
monkeypatch.setattr(
gateway_module.gateway_admin_handlers,
"get_bootstrap_config_for_run",
lambda project_root, config_name: _Bootstrap(),
)
await gateway._handle_get_agent_profile(
websocket,
{"agent_id": "risk_manager"},
)
assert websocket.messages[-1]["type"] == "agent_profile_loaded"
profile = websocket.messages[-1]["profile"]
assert profile["model_name"] == "gpt-4o-mini"
assert profile["model_provider"] == "OPENAI"
assert profile["prompt_files"] == ["SOUL.md", "MEMORY.md"]
assert profile["active_tool_groups"] == ["risk_ops"]
assert profile["disabled_tool_groups"] == ["legacy_group"]
@pytest.mark.asyncio
async def test_handle_get_skill_detail_returns_markdown_body(tmp_path):
skill_dir = tmp_path / "backend" / "skills" / "builtin" / "risk_review"
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
"---\nname: 风险审查\ndescription: 说明\nversion: 1.0.0\n---\n# 风险审查\n\n完整正文\n",
encoding="utf-8",
)
gateway = make_gateway()
gateway._project_root = tmp_path
websocket = DummyWebSocket()
await gateway._handle_get_skill_detail(
websocket,
{"skill_name": "risk_review"},
)
assert websocket.messages[-1]["type"] == "skill_detail_loaded"
assert websocket.messages[-1]["skill"]["name"] == "风险审查"
assert websocket.messages[-1]["skill"]["version"] == "1.0.0"
assert websocket.messages[-1]["skill"]["content"] == "# 风险审查\n\n完整正文"
@pytest.mark.asyncio
async def test_handle_get_skill_detail_prefers_agent_local_skill(tmp_path):
skill_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "skills" / "local" / "local_guard"
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
"---\nname: 本地风控\ndescription: 本地说明\nversion: 1.0.0\n---\n# 本地风控\n\n本地正文\n",
encoding="utf-8",
)
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
await gateway._handle_get_skill_detail(
websocket,
{"agent_id": "risk_manager", "skill_name": "local_guard"},
)
assert websocket.messages[-1]["type"] == "skill_detail_loaded"
assert websocket.messages[-1]["agent_id"] == "risk_manager"
assert websocket.messages[-1]["skill"]["source"] == "local"
assert websocket.messages[-1]["skill"]["content"] == "# 本地风控\n\n本地正文"
@pytest.mark.asyncio
async def test_handle_update_agent_skill_persists_and_returns_refresh(monkeypatch, tmp_path):
skill_dir = tmp_path / "backend" / "skills" / "builtin" / "extra_guard"
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
"---\nname: extra_guard\ndescription: desc\n---\n",
encoding="utf-8",
)
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
async def _noop_reload():
return None
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
await gateway._handle_update_agent_skill(
websocket,
{
"agent_id": "risk_manager",
"skill_name": "extra_guard",
"enabled": True,
},
)
assert websocket.messages[0]["type"] == "agent_skill_updated"
assert websocket.messages[-1]["type"] == "agent_skills_loaded"
agent_yaml = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "agent.yaml"
assert "extra_guard" in agent_yaml.read_text(encoding="utf-8")
@pytest.mark.asyncio
async def test_handle_create_and_update_agent_local_skill(monkeypatch, tmp_path):
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
async def _noop_reload():
return None
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
await gateway._handle_create_agent_local_skill(
websocket,
{"agent_id": "risk_manager", "skill_name": "local_guard"},
)
assert websocket.messages[0]["type"] == "agent_local_skill_created"
assert websocket.messages[1]["type"] == "agent_skills_loaded"
assert websocket.messages[2]["type"] == "skill_detail_loaded"
target = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "skills" / "local" / "local_guard" / "SKILL.md"
assert target.exists()
websocket.messages.clear()
await gateway._handle_update_agent_local_skill(
websocket,
{
"agent_id": "risk_manager",
"skill_name": "local_guard",
"content": "---\nname: 本地风控\ndescription: 更新后\nversion: 1.0.0\n---\n# 本地风控\n\n更新正文\n",
},
)
assert websocket.messages[0]["type"] == "agent_local_skill_updated"
assert websocket.messages[1]["type"] == "skill_detail_loaded"
assert "更新正文" in target.read_text(encoding="utf-8")
@pytest.mark.asyncio
async def test_handle_delete_agent_local_skill(monkeypatch, tmp_path):
skill_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "skills" / "local" / "local_guard"
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
"---\nname: 本地风控\ndescription: desc\nversion: 1.0.0\n---\n",
encoding="utf-8",
)
agent_yaml = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "agent.yaml"
agent_yaml.parent.mkdir(parents=True, exist_ok=True)
agent_yaml.write_text(
"enabled_skills:\n"
" - local_guard\n"
"disabled_skills:\n"
" - local_guard\n",
encoding="utf-8",
)
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
async def _noop_reload():
return None
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
await gateway._handle_delete_agent_local_skill(
websocket,
{"agent_id": "risk_manager", "skill_name": "local_guard"},
)
assert websocket.messages[0]["type"] == "agent_local_skill_deleted"
assert websocket.messages[1]["type"] == "agent_skills_loaded"
assert not skill_dir.exists()
assert "local_guard" not in agent_yaml.read_text(encoding="utf-8")
@pytest.mark.asyncio
async def test_handle_remove_agent_skill_marks_disabled(monkeypatch, tmp_path):
skill_dir = tmp_path / "backend" / "skills" / "builtin" / "risk_review"
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
"---\nname: 风险审查\ndescription: desc\nversion: 1.0.0\n---\n",
encoding="utf-8",
)
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
async def _noop_reload():
return None
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
await gateway._handle_remove_agent_skill(
websocket,
{"agent_id": "risk_manager", "skill_name": "risk_review"},
)
assert websocket.messages[0]["type"] == "agent_skill_removed"
assert websocket.messages[1]["type"] == "agent_skills_loaded"
agent_yaml = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "agent.yaml"
assert "risk_review" in agent_yaml.read_text(encoding="utf-8")
@pytest.mark.asyncio
async def test_handle_get_agent_workspace_file_returns_content(tmp_path):
file_path = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "SOUL.md"
file_path.parent.mkdir(parents=True, exist_ok=True)
file_path.write_text("soul content", encoding="utf-8")
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
await gateway._handle_get_agent_workspace_file(
websocket,
{"agent_id": "risk_manager", "filename": "SOUL.md"},
)
assert websocket.messages[-1] == {
"type": "agent_workspace_file_loaded",
"config_name": "demo",
"agent_id": "risk_manager",
"filename": "SOUL.md",
"content": "soul content",
}
@pytest.mark.asyncio
async def test_handle_update_agent_workspace_file_persists_and_returns_refresh(monkeypatch, tmp_path):
gateway = make_gateway()
gateway.config["config_name"] = "demo"
gateway._project_root = tmp_path
websocket = DummyWebSocket()
async def _noop_reload():
return None
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
await gateway._handle_update_agent_workspace_file(
websocket,
{
"agent_id": "risk_manager",
"filename": "SOUL.md",
"content": "updated soul",
},
)
assert websocket.messages[0]["type"] == "agent_workspace_file_updated"
assert websocket.messages[-1]["type"] == "agent_workspace_file_loaded"
target = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "SOUL.md"
assert target.read_text(encoding="utf-8") == "updated soul"

View File

@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
"""Direct tests for Gateway support modules."""
from types import SimpleNamespace
import pytest
from backend.services import gateway_cycle_support, gateway_runtime_support
class _DummyScheduler:
def __init__(self):
self.calls = []
def reconfigure(self, **kwargs):
self.calls.append(kwargs)
class _DummyStateSync:
def __init__(self):
self.updated = []
self.saved = False
self.system_messages = []
self.backtest_dates = []
self.state = {}
def update_state(self, key, value):
self.updated.append((key, value))
self.state[key] = value
def save_state(self):
self.saved = True
async def on_system_message(self, message):
self.system_messages.append(message)
def set_backtest_dates(self, dates):
self.backtest_dates = list(dates)
class _DummyStorage:
def __init__(self):
self.initial_cash = 100000.0
self.is_live_session_active = False
self.server_state_updates = []
def can_apply_initial_cash(self):
return True
def apply_initial_cash(self, value):
self.initial_cash = value
return True
def update_server_state_from_dashboard(self, state):
self.server_state_updates.append(state)
def load_file(self, name):
if name == "summary":
return {"totalAssetValue": self.initial_cash}
return []
def build_dashboard_snapshot_from_state(self, state):
return {
"summary": {"totalAssetValue": self.initial_cash},
"holdings": [],
"stats": {},
"trades": [],
"leaderboard": [],
}
class _DummyPM:
def __init__(self):
self.portfolio = {"margin_requirement": 0.0}
def apply_runtime_portfolio_config(self, margin_requirement=None, initial_cash=None):
if margin_requirement is not None:
self.portfolio["margin_requirement"] = margin_requirement
return {"margin_requirement": True}
def can_apply_initial_cash(self):
return True
class _DummyMarketService:
def __init__(self):
self.updated = None
self.stopped = False
def update_tickers(self, tickers):
self.updated = list(tickers)
return {"active": list(tickers), "added": list(tickers), "removed": []}
def stop(self):
self.stopped = True
def make_gateway_stub():
pipeline = SimpleNamespace(max_comm_cycles=0, pm=_DummyPM())
gateway = SimpleNamespace(
market_service=_DummyMarketService(),
pipeline=pipeline,
scheduler=_DummyScheduler(),
config={
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"enable_memory": False,
},
storage=_DummyStorage(),
state_sync=_DummyStateSync(),
_watchlist_ingest_task=None,
_market_status_task=None,
_backtest_task=None,
_backtest_start_date=None,
_backtest_end_date=None,
_manual_cycle_task=None,
)
return gateway
def test_normalize_watchlist_filters_invalid_and_dedupes():
assert gateway_runtime_support.normalize_watchlist(["aapl", " AAPL ", "", "msft"]) == ["AAPL", "MSFT"]
assert gateway_runtime_support.normalize_watchlist("aapl,msft") == ["AAPL", "MSFT"]
def test_normalize_agent_workspace_filename_obeys_allowlist():
allowlist = {"SOUL.md", "PROFILE.md"}
assert gateway_runtime_support.normalize_agent_workspace_filename("SOUL.md", allowlist=allowlist) == "SOUL.md"
assert gateway_runtime_support.normalize_agent_workspace_filename("README.md", allowlist=allowlist) is None
def test_apply_runtime_config_updates_gateway_state():
gateway = make_gateway_stub()
result = gateway_runtime_support.apply_runtime_config(
gateway,
{
"tickers": ["MSFT", "NVDA"],
"schedule_mode": "intraday",
"interval_minutes": 30,
"trigger_time": "10:30",
"initial_cash": 150000.0,
"margin_requirement": 0.5,
"max_comm_cycles": 4,
"enable_memory": False,
},
)
assert gateway.config["tickers"] == ["MSFT", "NVDA"]
assert gateway.config["schedule_mode"] == "intraday"
assert gateway.storage.initial_cash == 150000.0
assert result["runtime_config_applied"]["max_comm_cycles"] == 4
assert gateway.scheduler.calls[-1] == {
"mode": "intraday",
"trigger_time": "10:30",
"interval_minutes": 30,
}
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
gateway = make_gateway_stub()
captured = {}
class DummyTask:
def done(self):
return False
def cancel(self):
captured["cancelled"] = True
def fake_create_task(coro):
captured["name"] = coro.cr_code.co_name
coro.close()
return DummyTask()
monkeypatch.setattr(gateway_cycle_support.asyncio, "create_task", fake_create_task)
gateway_cycle_support.schedule_watchlist_market_store_refresh(gateway, ["AAPL", "MSFT"])
assert captured["name"] == "refresh_market_store_for_watchlist"
@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
gateway = make_gateway_stub()
monkeypatch.setattr(
gateway_cycle_support,
"ingest_symbols",
lambda symbols, mode="incremental": [
{"symbol": symbol, "prices": 3, "news": 4}
for symbol in symbols
],
)
await gateway_cycle_support.refresh_market_store_for_watchlist(gateway, ["AAPL", "MSFT"])
assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
from unittest.mock import patch
import pandas as pd
from backend.data.historical_price_manager import HistoricalPriceManager
def test_preload_data_prefers_market_db():
manager = HistoricalPriceManager()
manager.subscribe(["AAPL"])
market_rows = [
{
"symbol": "AAPL",
"date": "2026-03-09",
"open": 100.0,
"high": 103.0,
"low": 99.0,
"close": 102.0,
"volume": 10_000,
"vwap": 101.0,
"transactions": 500,
"source": "polygon",
}
]
with (
patch.object(manager._market_store, "get_ohlc", return_value=market_rows),
patch.object(manager._router, "load_local_price_frame") as load_csv,
):
manager.preload_data("2026-03-01", "2026-03-10")
load_csv.assert_not_called()
assert "AAPL" in manager._price_cache
assert float(manager._price_cache["AAPL"].iloc[0]["close"]) == 102.0
def test_preload_data_falls_back_to_csv():
manager = HistoricalPriceManager()
manager.subscribe(["MSFT"])
csv_df = pd.DataFrame(
{
"time": ["2026-03-09"],
"open": [200.0],
"high": [205.0],
"low": [198.0],
"close": [204.0],
"volume": [20_000],
}
)
csv_df["time"] = pd.to_datetime(csv_df["time"])
csv_df["Date"] = csv_df["time"]
csv_df.set_index("Date", inplace=True)
with (
patch.object(manager._market_store, "get_ohlc", return_value=[]),
patch.object(manager._router, "load_local_price_frame", return_value=csv_df) as load_csv,
):
manager.preload_data("2026-03-01", "2026-03-10")
load_csv.assert_called_once_with("MSFT")
assert "MSFT" in manager._price_cache
assert float(manager._price_cache["MSFT"].iloc[0]["close"]) == 204.0

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
from backend.enrich import llm_enricher
class DummyResponse:
def __init__(self, metadata):
self.metadata = metadata
class DummyModel:
def __init__(self, metadata):
self.metadata = metadata
self.calls = []
async def __call__(self, messages, structured_model=None, **kwargs):
self.calls.append(
{
"messages": messages,
"structured_model": structured_model,
"kwargs": kwargs,
}
)
return DummyResponse(self.metadata)
def test_analyze_news_row_with_llm_uses_agentscope_model(monkeypatch):
model = DummyModel(
{
"id": "news-1",
"relevance": "high",
"sentiment": "positive",
"key_discussion": "Demand remains resilient",
"summary": "Structured summary",
"reason_growth": "Orders improved",
"reason_decrease": "",
}
)
monkeypatch.setattr(llm_enricher, "llm_enrichment_enabled", lambda: True)
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
monkeypatch.setattr(
llm_enricher,
"get_explain_model_info",
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
)
result = llm_enricher.analyze_news_row_with_llm(
{
"id": "news-1",
"title": "Apple expands AI features",
"summary": "New devices and software updates were announced.",
}
)
assert result["sentiment"] == "positive"
assert result["summary"] == "Structured summary"
assert result["raw_json"]["model_label"] == "DASHSCOPE:qwen-max"
assert model.calls
assert model.calls[0]["structured_model"] is llm_enricher.EnrichedNewsItem
def test_analyze_news_rows_with_llm_uses_agentscope_structured_batch(monkeypatch):
model = DummyModel(
{
"items": [
{
"id": "news-1",
"relevance": "high",
"sentiment": "negative",
"key_discussion": "Margin pressure",
"summary": "Batch summary",
"reason_growth": "",
"reason_decrease": "Costs rose",
}
]
}
)
monkeypatch.setattr(llm_enricher, "llm_enrichment_enabled", lambda: True)
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
monkeypatch.setattr(
llm_enricher,
"get_explain_model_info",
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
)
result = llm_enricher.analyze_news_rows_with_llm(
[
{
"id": "news-1",
"title": "Apple margins pressured",
"summary": "Costs increased this quarter.",
}
]
)
assert result["news-1"]["sentiment"] == "negative"
assert result["news-1"]["reason_decrease"] == "Costs rose"
assert result["news-1"]["raw_json"]["model_label"] == "DASHSCOPE:qwen-max"
assert model.calls
assert model.calls[0]["structured_model"] is llm_enricher.EnrichedNewsBatch
def test_analyze_range_with_llm_uses_agentscope_structured_output(monkeypatch):
model = DummyModel(
{
"summary": "该股在区间内震荡下行,相关新闻主要集中在盈利预期和供应链扰动。",
"trend_analysis": "前半段受利空新闻压制,后半段跌幅收敛。",
"bullish_factors": ["估值消化后出现部分承接"],
"bearish_factors": ["盈利预期下修", "供应链扰动持续"],
}
)
monkeypatch.setattr(llm_enricher, "llm_range_analysis_enabled", lambda: True)
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
monkeypatch.setattr(
llm_enricher,
"get_explain_model_info",
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
)
result = llm_enricher.analyze_range_with_llm(
{
"ticker": "AAPL",
"start_date": "2026-03-10",
"end_date": "2026-03-16",
"price_change_pct": -3.42,
}
)
assert result["summary"].startswith("该股在区间内震荡下行")
assert result["model_label"] == "DASHSCOPE:qwen-max"
assert result["bearish_factors"] == ["盈利预期下修", "供应链扰动持续"]
assert model.calls
assert model.calls[0]["structured_model"] is llm_enricher.RangeAnalysisPayload

View File

@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
"""Tests for market ingest watermark handling."""
from backend.data import market_ingest
class _FakeStore:
def __init__(self, *, last_news_fetch=None, latest_news_date=None):
self._watermarks = {
"symbol": "AAPL",
"last_price_fetch": None,
"last_news_fetch": last_news_fetch,
}
self._latest_news_date = latest_news_date
self.updated = []
def get_ticker_watermarks(self, symbol):
return dict(self._watermarks)
def get_latest_news_date(self, symbol):
return self._latest_news_date
def upsert_ticker(self, **kwargs):
return None
def upsert_ohlc(self, symbol, rows, source="polygon"):
return len(rows)
def upsert_news(self, symbol, rows, source="polygon"):
return len(rows)
def update_fetch_watermark(self, **kwargs):
self.updated.append(kwargs)
def test_refresh_news_incremental_does_not_advance_watermark_without_news(monkeypatch):
store = _FakeStore(last_news_fetch="2026-03-28", latest_news_date="2026-03-28")
monkeypatch.setattr(market_ingest, "fetch_ticker_details", lambda ticker: {"name": ticker, "sic_description": None, "active": True})
class _Router:
def get_company_news(self, **kwargs):
return [], "polygon"
monkeypatch.setattr(market_ingest, "DataProviderRouter", lambda: _Router())
monkeypatch.setattr(market_ingest, "align_news_for_symbol", lambda store, ticker: 0)
result = market_ingest.refresh_news_incremental(
"AAPL",
end_date="2026-03-29",
store=store,
)
assert result["start_news_date"] == "2026-03-29"
assert result["news"] == 0
assert store.updated[-1]["news_date"] is None
def test_refresh_news_incremental_clamps_future_watermark_to_latest_stored_date(monkeypatch):
store = _FakeStore(last_news_fetch="2026-03-30", latest_news_date="2026-03-28")
captured = {}
monkeypatch.setattr(market_ingest, "fetch_ticker_details", lambda ticker: {"name": ticker, "sic_description": None, "active": True})
class _Router:
def get_company_news(self, **kwargs):
captured.update(kwargs)
return [], "polygon"
monkeypatch.setattr(market_ingest, "DataProviderRouter", lambda: _Router())
monkeypatch.setattr(market_ingest, "align_news_for_symbol", lambda store, ticker: 0)
result = market_ingest.refresh_news_incremental(
"AAPL",
end_date="2026-03-29",
store=store,
)
assert result["start_news_date"] == "2026-03-29"
assert captured["start_date"] == "2026-03-29"
assert captured["end_date"] == "2026-03-29"

View File

@@ -0,0 +1,285 @@
# -*- coding: utf-8 -*-
# pylint: disable=W0212
import asyncio
import time
import logging
from unittest.mock import MagicMock, AsyncMock, patch
import pytest
from backend.services.market import MarketService
from backend.data.polling_price_manager import PollingPriceManager
from backend.llm.models import RetryChatModel
class TestPollingPriceManager:
def test_init(self):
manager = PollingPriceManager(api_key="test_key", poll_interval=30)
assert manager.api_key == "test_key"
assert manager.poll_interval == 30
assert manager.provider == "finnhub"
assert manager.running is False
def test_init_yfinance(self):
manager = PollingPriceManager(provider="yfinance", poll_interval=15)
assert manager.api_key is None
assert manager.poll_interval == 15
assert manager.provider == "yfinance"
assert manager.running is False
def test_subscribe(self):
manager = PollingPriceManager(api_key="test_key")
manager.subscribe(["AAPL", "MSFT"])
assert "AAPL" in manager.subscribed_symbols
assert "MSFT" in manager.subscribed_symbols
def test_unsubscribe(self):
manager = PollingPriceManager(api_key="test_key")
manager.subscribe(["AAPL", "MSFT"])
manager.unsubscribe(["AAPL"])
assert "AAPL" not in manager.subscribed_symbols
assert "MSFT" in manager.subscribed_symbols
def test_add_price_callback(self):
manager = PollingPriceManager(api_key="test_key")
callback = MagicMock()
manager.add_price_callback(callback)
assert callback in manager.price_callbacks
@patch.object(PollingPriceManager, "_fetch_prices")
def test_start_stop(self, _mock_fetch_prices):
manager = PollingPriceManager(api_key="test_key", poll_interval=1)
manager.subscribe(["AAPL"])
manager.start()
assert manager.running is True
time.sleep(0.1)
manager.stop()
assert manager.running is False
def test_start_without_subscription(self):
manager = PollingPriceManager(api_key="test_key")
manager.start()
assert manager.running is False
def test_get_latest_price(self):
manager = PollingPriceManager(api_key="test_key")
manager.latest_prices["AAPL"] = 150.0
price = manager.get_latest_price("AAPL")
assert price == 150.0
def test_get_open_price(self):
manager = PollingPriceManager(api_key="test_key")
manager.open_prices["AAPL"] = 148.0
price = manager.get_open_price("AAPL")
assert price == 148.0
def test_reset_open_prices(self):
manager = PollingPriceManager(api_key="test_key")
manager.open_prices["AAPL"] = 150.0
manager.reset_open_prices()
assert len(manager.open_prices) == 0
def test_fetch_prices_suppresses_repeated_failures(self, caplog):
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
manager.subscribe(["AAPL"])
with patch.object(manager, "_fetch_quote", side_effect=ValueError("empty quote")):
with caplog.at_level(logging.DEBUG):
for _ in range(3):
manager._fetch_prices()
assert manager._failure_counts["AAPL"] == 3
warning_messages = [record.message for record in caplog.records if record.levelno >= logging.WARNING]
assert any("Failed to fetch AAPL price: empty quote" in message for message in warning_messages)
def test_fetch_prices_logs_recovery_after_failure(self, caplog):
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
manager.subscribe(["AAPL"])
with patch.object(
manager,
"_fetch_quote",
side_effect=[
ValueError("temporary outage"),
{"c": 100.0, "o": 99.0, "h": 101.0, "l": 98.0, "pc": 99.5, "d": 0.5, "dp": 0.5, "t": 1},
],
):
with caplog.at_level(logging.INFO):
manager._fetch_prices()
manager._fetch_prices()
assert "AAPL" not in manager._failure_counts
assert any("recovered after 1 consecutive failures" in record.message for record in caplog.records)
class TestRetryChatModel:
@pytest.mark.asyncio
async def test_async_retry_recovers_from_disconnect(self):
attempts = {"count": 0}
class FakeAsyncModel:
model_name = "fake-async-model"
async def __call__(self, *args, **kwargs):
attempts["count"] += 1
if attempts["count"] < 2:
raise RuntimeError("Server disconnected")
return {"ok": True}
wrapped = RetryChatModel(FakeAsyncModel(), max_retries=2, initial_delay=0.01)
result = await wrapped("hello")
assert result == {"ok": True}
assert attempts["count"] == 2
class TestMarketService:
@patch("backend.services.market.get_data_sources", return_value=["yfinance", "local_csv"])
@patch.object(PollingPriceManager, "start")
def test_start_real_mode_with_yfinance(self, _mock_start, _mock_sources):
service = MarketService(
tickers=["AAPL"],
poll_interval=10,
)
service._start_real_mode()
assert isinstance(service._price_manager, PollingPriceManager)
assert service._price_manager.provider == "yfinance"
@patch("backend.services.market.get_data_sources", return_value=["financial_datasets", "yfinance", "local_csv"])
@patch.object(PollingPriceManager, "start")
def test_start_real_mode_uses_first_supported_live_provider(self, _mock_start, _mock_sources):
service = MarketService(
tickers=["AAPL"],
poll_interval=10,
)
service._start_real_mode()
assert isinstance(service._price_manager, PollingPriceManager)
assert service._price_manager.provider == "yfinance"
@patch("backend.services.market.get_data_sources", return_value=["finnhub", "yfinance"])
@pytest.mark.asyncio
async def test_start_real_mode_without_api_key(self, _mock_sources):
service = MarketService(
tickers=["AAPL"],
api_key=None,
)
broadcast_func = AsyncMock()
with pytest.raises(ValueError) as excinfo:
await service.start(broadcast_func)
assert "API key required" in str(excinfo.value)
@pytest.mark.asyncio
async def test_start_already_running(self):
service = MarketService(
tickers=["AAPL"],
backtest_mode=True,
)
broadcast_func = AsyncMock()
# First start with backtest mode
await service.start(broadcast_func)
assert service.running is True
# Start again should not fail
await service.start(broadcast_func)
service.stop()
def test_stop(self):
service = MarketService(
tickers=["AAPL"],
backtest_mode=True,
)
service.running = True
service._price_manager = MagicMock()
service.stop()
assert service.running is False
assert service._price_manager is None
def test_stop_when_not_running(self):
service = MarketService(
tickers=["AAPL"],
backtest_mode=True,
)
# Should not raise
service.stop()
assert service.running is False
def test_get_price_sync(self):
service = MarketService(tickers=["AAPL"], backtest_mode=True)
service.cache["AAPL"] = {"price": 150.0, "open": 148.0}
price = service.get_price_sync("AAPL")
assert price == 150.0
def test_get_price_sync_not_found(self):
service = MarketService(tickers=["AAPL"], backtest_mode=True)
price = service.get_price_sync("MSFT")
assert price is None
def test_get_all_prices(self):
service = MarketService(tickers=["AAPL", "MSFT"], backtest_mode=True)
service.cache["AAPL"] = {"price": 150.0}
service.cache["MSFT"] = {"price": 400.0}
prices = service.get_all_prices()
assert prices["AAPL"] == 150.0
assert prices["MSFT"] == 400.0
@pytest.mark.asyncio
async def test_broadcast_price_update(self):
service = MarketService(tickers=["AAPL"], backtest_mode=True)
service._broadcast_func = AsyncMock()
price_data = {
"symbol": "AAPL",
"price": 150.0,
"open": 148.0,
"timestamp": 1234567890,
}
await service._broadcast_price_update(price_data)
service._broadcast_func.assert_called_once()
call_args = service._broadcast_func.call_args[0][0]
assert call_args["type"] == "price_update"
assert call_args["symbol"] == "AAPL"
assert call_args["price"] == 150.0
@pytest.mark.asyncio
async def test_broadcast_price_update_no_func(self):
service = MarketService(tickers=["AAPL"], backtest_mode=True)
service._broadcast_func = None
price_data = {"symbol": "AAPL", "price": 150.0, "open": 148.0}
# Should not raise
await service._broadcast_price_update(price_data)
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from pathlib import Path
from backend.data.market_store import MarketStore
def test_get_enrich_report_summarizes_coverage(tmp_path: Path):
store = MarketStore(tmp_path / "market_research.db")
store.upsert_news(
"AAPL",
[
{
"id": "news-1",
"published_utc": "2026-03-10T12:00:00Z",
"title": "Apple earnings beat",
"summary": "Revenue topped expectations",
"tickers": ["AAPL"],
},
{
"id": "news-2",
"published_utc": "2026-03-11T12:00:00Z",
"title": "Apple supply chain warning",
"summary": "Outlook softened",
"tickers": ["AAPL"],
},
],
)
store.set_trade_dates(
[
{"news_id": "news-1", "symbol": "AAPL", "trade_date": "2026-03-10"},
{"news_id": "news-2", "symbol": "AAPL", "trade_date": "2026-03-11"},
]
)
store.upsert_news_analysis(
"AAPL",
[
{
"news_id": "news-1",
"trade_date": "2026-03-10",
"summary": "LLM enriched",
"analysis_source": "llm",
}
],
analysis_source="llm",
)
rows = store.get_enrich_report(["AAPL"])
assert len(rows) == 1
assert rows[0]["symbol"] == "AAPL"
assert rows[0]["raw_news_count"] == 2
assert rows[0]["analyzed_news_count"] == 1
assert rows[0]["coverage_pct"] == 50.0
assert rows[0]["llm_count"] == 1

View File

@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
"""Unit tests for the news domain helpers."""
from backend.domains import news as news_domain
class _FakeStore:
def __init__(self):
self.calls = []
def get_ticker_watermarks(self, symbol):
self.calls.append(("get_ticker_watermarks", symbol))
return {"symbol": symbol, "last_news_fetch": "2026-03-10"}
def get_news_items_enriched(self, ticker, start_date=None, end_date=None, trade_date=None, limit=100):
self.calls.append(("get_news_items_enriched", ticker, start_date, end_date, trade_date, limit))
target = trade_date or end_date
return [{"id": "n1", "ticker": ticker, "date": target, "trade_date": target}]
def get_news_timeline_enriched(self, ticker, start_date=None, end_date=None):
self.calls.append(("get_news_timeline_enriched", ticker, start_date, end_date))
return [{"date": end_date, "count": 1}]
def get_news_categories_enriched(self, ticker, start_date=None, end_date=None, limit=200):
self.calls.append(("get_news_categories_enriched", ticker, start_date, end_date, limit))
return {"macro": {"count": 1}}
def get_news_by_ids_enriched(self, ticker, article_ids):
self.calls.append(("get_news_by_ids_enriched", ticker, list(article_ids)))
return [{"id": article_ids[0], "ticker": ticker, "date": "2026-03-16"}]
def test_news_rows_need_enrichment_detects_missing_fields():
assert news_domain.news_rows_need_enrichment([]) is True
assert news_domain.news_rows_need_enrichment([{"sentiment": "", "relevance": "", "key_discussion": ""}]) is True
assert news_domain.news_rows_need_enrichment([{"sentiment": "positive"}]) is False
def test_ensure_news_fresh_triggers_incremental_refresh_when_watermark_is_stale(monkeypatch):
store = _FakeStore()
calls = []
monkeypatch.setattr(
news_domain,
"update_ticker_incremental",
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
)
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
assert calls == [("AAPL", "2026-03-16")]
assert payload["target_date"] == "2026-03-16"
assert payload["refreshed"] is True
def test_ensure_news_fresh_skips_refresh_when_watermark_is_current(monkeypatch):
store = _FakeStore()
calls = []
monkeypatch.setattr(
store,
"get_ticker_watermarks",
lambda symbol: {"symbol": symbol, "last_news_fetch": "2026-03-16"},
)
monkeypatch.setattr(
news_domain,
"update_ticker_incremental",
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
)
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
assert calls == []
assert payload["refreshed"] is False
def test_get_enriched_news_returns_rows_without_enrichment_when_present(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None, refresh_if_stale=False: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
payload = news_domain.get_enriched_news(
store,
ticker="AAPL",
start_date="2026-03-01",
end_date="2026-03-16",
limit=20,
)
assert payload["ticker"] == "AAPL"
assert payload["news"][0]["ticker"] == "AAPL"
assert payload["freshness"]["target_date"] is None or payload["freshness"]["target_date"] == "2026-03-16"
assert store.calls == [
("get_news_items_enriched", "AAPL", "2026-03-01", "2026-03-16", None, 20)
]
def test_get_story_and_similar_days_delegate(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None, refresh_if_stale=False: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
monkeypatch.setattr(news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"analyzed": 1})
monkeypatch.setattr(
news_domain,
"get_or_create_stock_story",
lambda store, symbol, as_of_date: {"symbol": symbol, "as_of_date": as_of_date, "story": "story"},
)
monkeypatch.setattr(
news_domain,
"find_similar_days",
lambda store, symbol, target_date, top_k: {"symbol": symbol, "target_date": target_date, "items": [{"score": 0.9}]},
)
story = news_domain.get_story_payload(store, ticker="AAPL", as_of_date="2026-03-16")
similar = news_domain.get_similar_days_payload(store, ticker="AAPL", date="2026-03-16", n_similar=8)
assert story["story"] == "story"
assert "freshness" in story
assert similar["items"][0]["score"] == 0.9
assert "freshness" in similar
def test_get_enriched_news_defaults_to_read_only_freshness(monkeypatch):
store = _FakeStore()
ensure_calls = []
def fake_ensure(store, ticker, target_date=None, refresh_if_stale=False):
ensure_calls.append(refresh_if_stale)
return {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
}
monkeypatch.setattr(news_domain, "ensure_news_fresh", fake_ensure)
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
payload = news_domain.get_enriched_news(
store,
ticker="AAPL",
end_date="2026-03-16",
)
assert payload["ticker"] == "AAPL"
assert ensure_calls == [False]
def test_get_range_explain_payload_uses_article_ids(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None, refresh_if_stale=False: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
monkeypatch.setattr(
news_domain,
"build_range_explanation",
lambda ticker, start_date, end_date, news_rows: {"ticker": ticker, "count": len(news_rows)},
)
payload = news_domain.get_range_explain_payload(
store,
ticker="AAPL",
start_date="2026-03-10",
end_date="2026-03-16",
article_ids=["news-9"],
limit=50,
)
assert payload["ticker"] == "AAPL"
assert payload["result"] == {"ticker": "AAPL", "count": 1}
assert "freshness" in payload
assert store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-9"])]

View File

@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
from backend.enrich import news_enricher
def test_classify_news_row_falls_back_to_local_rules(monkeypatch):
monkeypatch.setattr(news_enricher, "analyze_news_row_with_llm", lambda row: None)
result = news_enricher.classify_news_row(
{
"title": "Apple shares drop after weak guidance",
"summary": "Investors reacted negatively to softer-than-expected outlook.",
}
)
assert result["analysis_source"] == "local"
assert result["sentiment"] == "negative"
assert result["summary"]
def test_classify_news_row_prefers_llm_when_available(monkeypatch):
monkeypatch.setattr(
news_enricher,
"analyze_news_row_with_llm",
lambda row: {
"relevance": "high",
"sentiment": "positive",
"key_discussion": "Demand resilience",
"summary": "LLM summary",
"reason_growth": "Orders remain strong",
"reason_decrease": "",
"raw_json": {"provider": "llm"},
},
)
result = news_enricher.classify_news_row(
{
"title": "Apple expands AI features",
"summary": "New devices and software updates were announced.",
}
)
assert result["analysis_source"] == "llm"
assert result["sentiment"] == "positive"
assert result["summary"] == "LLM summary"
def test_build_analysis_rows_prefers_batch_llm_and_dedupes(monkeypatch):
monkeypatch.setattr(news_enricher, "llm_enrichment_enabled", lambda: True)
monkeypatch.setattr(news_enricher, "get_env_int", lambda key, default=0: 8)
monkeypatch.setattr(
news_enricher,
"analyze_news_rows_with_llm",
lambda rows: {
"news-1": {
"relevance": "high",
"sentiment": "positive",
"key_discussion": "Batch result",
"summary": "Batch summary",
"reason_growth": "Growth",
"reason_decrease": "",
"raw_json": {"provider": "batch"},
}
},
)
monkeypatch.setattr(news_enricher, "analyze_news_row_with_llm", lambda row: None)
rows = news_enricher.build_analysis_rows(
symbol="AAPL",
news_rows=[
{"id": "news-1", "trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
{"id": "news-2", "trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
],
ohlc_rows=[],
)
rows, stats = rows
assert len(rows) == 1
assert rows[0]["analysis_source"] == "llm"
assert rows[0]["summary"] == "Batch summary"
assert stats["deduped_count"] == 1
assert stats["llm_count"] == 1
def test_enrich_news_for_symbol_skips_existing(monkeypatch):
class DummyStore:
def get_news_items(self, symbol, start_date=None, end_date=None, limit=200):
return [
{"id": "news-1", "trade_date": "2026-03-10", "title": "One", "summary": "One"},
{"id": "news-2", "trade_date": "2026-03-11", "title": "Two", "summary": "Two"},
]
def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
return {"news-1"}
def get_ohlc(self, symbol, start_date, end_date):
return []
def upsert_news_analysis(self, symbol, rows, analysis_source="local"):
self.rows = rows
return len(rows)
monkeypatch.setattr(
news_enricher,
"build_analysis_rows",
lambda symbol, news_rows, ohlc_rows: (
[
{
"news_id": row["id"],
"trade_date": row["trade_date"],
"summary": row["summary"],
"analysis_source": "local",
}
for row in news_rows
],
{"deduped_count": 0, "llm_count": 0, "local_count": len(news_rows)},
),
)
store = DummyStore()
result = news_enricher.enrich_news_for_symbol(store, "AAPL")
assert result["news_count"] == 2
assert result["queued_count"] == 1
assert result["skipped_existing_count"] == 1
assert len(store.rows) == 1
assert store.rows[0]["news_id"] == "news-2"
def test_enrich_news_for_symbol_only_reanalyzes_local(monkeypatch):
class DummyStore:
def get_news_items(self, symbol, start_date=None, end_date=None, limit=200):
return [
{"id": "news-1", "trade_date": "2026-03-10", "title": "One", "summary": "One"},
{"id": "news-2", "trade_date": "2026-03-11", "title": "Two", "summary": "Two"},
{"id": "news-3", "trade_date": "2026-03-12", "title": "Three", "summary": "Three"},
]
def get_analyzed_news_sources(self, symbol, start_date=None, end_date=None):
return {"news-1": "local", "news-2": "llm"}
def get_ohlc(self, symbol, start_date, end_date):
return []
def upsert_news_analysis(self, symbol, rows, analysis_source="local"):
self.rows = rows
return len(rows)
monkeypatch.setattr(
news_enricher,
"build_analysis_rows",
lambda symbol, news_rows, ohlc_rows: (
[
{
"news_id": row["id"],
"trade_date": row["trade_date"],
"summary": row["summary"],
"analysis_source": "llm" if row["id"] == "news-1" else "local",
}
for row in news_rows
],
{"deduped_count": 0, "llm_count": 1, "local_count": 0},
),
)
store = DummyStore()
result = news_enricher.enrich_news_for_symbol(
store,
"AAPL",
only_reanalyze_local=True,
)
assert result["news_count"] == 3
assert result["queued_count"] == 1
assert result["skipped_existing_count"] == 2
assert result["only_reanalyze_local"] is True
assert result["upgraded_local_to_llm_count"] == 1
assert result["execution_summary"]["upgraded_dates"] == ["2026-03-10"]
assert result["execution_summary"]["remaining_local_titles"] == []
assert result["execution_summary"]["skipped_missing_analysis_count"] == 1
assert result["execution_summary"]["skipped_non_local_count"] == 1
assert [row["news_id"] for row in store.rows] == ["news-1"]

View File

@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted news service app surface."""
from fastapi.testclient import TestClient
from backend.apps.news_service import create_app
class _FakeStore:
def get_ticker_watermarks(self, symbol):
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
def get_news_timeline_enriched(self, symbol, start_date=None, end_date=None):
return [{"date": end_date, "count": 1}]
def get_news_items(self, symbol, start_date=None, end_date=None, limit=100):
return [{"id": "news-raw-1", "ticker": symbol, "title": "Raw Title", "date": end_date}]
def get_news_items_enriched(self, symbol, start_date=None, end_date=None, trade_date=None, limit=100):
return [{"id": "news-1", "ticker": symbol, "title": "Title", "date": trade_date or end_date}]
def upsert_news_analysis(self, symbol, rows):
return len(rows)
def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
return set()
def get_news_categories_enriched(self, symbol, start_date=None, end_date=None, limit=200):
return {"market": {"label": "market", "count": 1, "article_ids": ["news-1"]}}
def get_news_by_ids_enriched(self, symbol, article_ids):
return [{"id": article_ids[0], "ticker": symbol, "title": "Picked"}]
def test_news_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/enriched-news" in paths
assert "/api/news-for-date" in paths
assert "/api/news-timeline" in paths
assert "/api/categories" in paths
assert "/api/similar-days" in paths
assert "/api/stories/{ticker}" in paths
assert "/api/range-explain" in paths
def test_news_service_enriched_news_and_categories(monkeypatch):
app = create_app()
app.dependency_overrides.clear()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
with TestClient(app) as client:
news_response = client.get(
"/api/enriched-news",
params={"ticker": "AAPL", "end_date": "2026-03-23"},
)
categories_response = client.get(
"/api/categories",
params={"ticker": "AAPL", "end_date": "2026-03-23"},
)
assert news_response.status_code == 200
assert news_response.json()["news"][0]["ticker"] == "AAPL"
assert categories_response.status_code == 200
assert categories_response.json()["categories"]["market"]["count"] == 1
def test_news_service_news_for_date_and_timeline(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
with TestClient(app) as client:
date_response = client.get(
"/api/news-for-date",
params={"ticker": "AAPL", "date": "2026-03-23"},
)
timeline_response = client.get(
"/api/news-timeline",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-23",
},
)
assert date_response.status_code == 200
assert date_response.json()["date"] == "2026-03-23"
assert timeline_response.status_code == 200
assert timeline_response.json()["timeline"][0]["count"] == 1
def test_news_service_similar_days_and_story(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
monkeypatch.setattr(
"backend.domains.news.find_similar_days",
lambda store, symbol, target_date, top_k: {
"symbol": symbol,
"target_date": target_date,
"items": [{"date": "2026-03-20", "score": 0.9}],
},
)
monkeypatch.setattr(
"backend.domains.news.get_or_create_stock_story",
lambda store, symbol, as_of_date: {
"symbol": symbol,
"as_of_date": as_of_date,
"story": "story body",
"source": "local",
},
)
with TestClient(app) as client:
similar_response = client.get(
"/api/similar-days",
params={"ticker": "AAPL", "date": "2026-03-23", "n_similar": 3},
)
story_response = client.get(
"/api/stories/AAPL",
params={"as_of_date": "2026-03-23"},
)
assert similar_response.status_code == 200
assert similar_response.json()["items"][0]["score"] == 0.9
assert story_response.status_code == 200
assert story_response.json()["story"] == "story body"
def test_news_service_range_explain(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
monkeypatch.setattr(
"backend.domains.news.build_range_explanation",
lambda ticker, start_date, end_date, news_rows: {
"symbol": ticker,
"news_count": len(news_rows),
"start_date": start_date,
"end_date": end_date,
},
)
with TestClient(app) as client:
response = client.get(
"/api/range-explain",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-23",
"article_ids": ["news-7"],
},
)
assert response.status_code == 200
assert response.json()["result"]["news_count"] == 1

View File

@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
"""Tests for the OpenClaw CLI service wrapper."""
from pathlib import Path
import pytest
from backend.services.openclaw_cli import OpenClawCliError, OpenClawCliService
class _Completed:
def __init__(self, *, returncode=0, stdout="", stderr=""):
self.returncode = returncode
self.stdout = stdout
self.stderr = stderr
def test_openclaw_cli_service_runs_json_command(monkeypatch, tmp_path):
captured = {}
def _fake_run(command, **kwargs):
captured["command"] = command
captured["cwd"] = kwargs["cwd"]
return _Completed(stdout='{"sessions":[{"key":"main/session-1"}]}')
monkeypatch.setattr("backend.services.openclaw_cli.subprocess.run", _fake_run)
service = OpenClawCliService(base_command=["openclaw"], cwd=tmp_path, timeout_seconds=3)
payload = service.list_sessions()
assert payload["sessions"][0]["key"] == "main/session-1"
assert captured["command"] == ["openclaw", "sessions", "--json"]
assert captured["cwd"] == tmp_path
def test_openclaw_cli_service_raises_on_failure(monkeypatch, tmp_path):
def _fake_run(command, **kwargs):
return _Completed(returncode=7, stdout="", stderr="boom")
monkeypatch.setattr("backend.services.openclaw_cli.subprocess.run", _fake_run)
service = OpenClawCliService(base_command=["openclaw"], cwd=tmp_path, timeout_seconds=3)
with pytest.raises(OpenClawCliError) as exc_info:
service.list_cron_jobs()
assert exc_info.value.exit_code == 7
assert exc_info.value.stderr == "boom"
def test_openclaw_cli_service_can_extract_single_session(monkeypatch, tmp_path):
def _fake_run(command, **kwargs):
return _Completed(stdout='{"sessions":[{"key":"main/session-1","agentId":"main"}]}')
monkeypatch.setattr("backend.services.openclaw_cli.subprocess.run", _fake_run)
service = OpenClawCliService(base_command=["openclaw"], cwd=tmp_path, timeout_seconds=3)
session = service.get_session("main/session-1")
assert session["agentId"] == "main"

View File

@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted OpenClaw service app surface."""
from fastapi.testclient import TestClient
from backend.apps.openclaw_service import create_app
from backend.api import openclaw as openclaw_module
class _FakeOpenClawCliService:
def health(self):
return {
"status": "healthy",
"service": "openclaw-service",
"base_command": ["openclaw"],
"cwd": "/tmp/openclaw",
"binary_resolved": True,
"reference_entry_available": True,
"timeout_seconds": 15.0,
}
def status(self):
return {"runtimeVersion": "2026.3.24"}
def list_sessions(self):
return {
"sessions": [
{"key": "main/session-1", "agentId": "main"},
{"key": "analyst/session-2", "agentId": "analyst"},
]
}
def get_session(self, session_key: str):
for session in self.list_sessions()["sessions"]:
if session["key"] == session_key:
return session
raise KeyError(session_key)
def get_session_history(self, session_key: str, *, limit: int = 20):
return {
"sessionKey": session_key,
"limit": limit,
"items": [{"role": "assistant", "text": "hello"}],
}
def list_cron_jobs(self):
return {"jobs": [{"id": "job-1", "name": "Daily sync"}]}
def list_approvals(self):
return {"approvals": [{"id": "ap-1", "status": "pending"}]}
def test_openclaw_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/status" in paths
assert "/api/openclaw/status" in paths
assert "/api/openclaw/sessions" in paths
assert "/api/openclaw/sessions/{session_key:path}" in paths
assert "/api/openclaw/sessions/{session_key:path}/history" in paths
assert "/api/openclaw/cron" in paths
assert "/api/openclaw/approvals" in paths
def test_openclaw_service_read_routes():
app = create_app()
app.dependency_overrides[openclaw_module.get_openclaw_cli_service] = (
lambda: _FakeOpenClawCliService()
)
with TestClient(app) as client:
health = client.get("/health")
status = client.get("/api/status")
openclaw_status = client.get("/api/openclaw/status")
sessions = client.get("/api/openclaw/sessions")
session = client.get("/api/openclaw/sessions/main/session-1")
history = client.get("/api/openclaw/sessions/main/session-1/history", params={"limit": 5})
cron = client.get("/api/openclaw/cron")
approvals = client.get("/api/openclaw/approvals")
assert health.status_code == 200
assert health.json()["service"] == "openclaw-service"
assert status.status_code == 200
assert status.json()["status"] == "operational"
assert openclaw_status.status_code == 200
assert openclaw_status.json()["runtimeVersion"] == "2026.3.24"
assert sessions.status_code == 200
assert len(sessions.json()["sessions"]) == 2
assert session.status_code == 200
assert session.json()["session"]["agentId"] == "main"
assert history.status_code == 200
assert history.json()["limit"] == 5
assert cron.status_code == 200
assert cron.json()["jobs"][0]["id"] == "job-1"
assert approvals.status_code == 200
assert approvals.json()["approvals"][0]["id"] == "ap-1"
def test_openclaw_service_session_404():
app = create_app()
app.dependency_overrides[openclaw_module.get_openclaw_cli_service] = (
lambda: _FakeOpenClawCliService()
)
with TestClient(app) as client:
response = client.get("/api/openclaw/sessions/missing")
assert response.status_code == 404

View File

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
"""Tests for the OpenClaw WebSocket client session helpers."""
import pytest
from shared.client.openclaw_websocket_client import OpenClawWebSocketClient
@pytest.mark.asyncio
async def test_resolve_session_parses_gateway_key_response():
client = OpenClawWebSocketClient(gateway_token="test-token")
async def fake_send_request(method, params=None, _allow_handshake=False):
assert method == "sessions.resolve"
assert params["agentId"] == "main"
return {"ok": True, "key": "agent:main:main"}
client._send_request = fake_send_request # type: ignore[method-assign]
resolved = await client.resolve_session(agent_id="main")
assert resolved == "agent:main:main"
@pytest.mark.asyncio
async def test_send_message_uses_session_send_payload():
client = OpenClawWebSocketClient(gateway_token="test-token")
async def fake_send_request(method, params=None, _allow_handshake=False):
assert method == "sessions.send"
assert params == {
"key": "agent:main:main",
"message": "hello",
"thinking": "medium",
}
return {"ok": True, "runId": "run-1"}
client._send_request = fake_send_request # type: ignore[method-assign]
result = await client.send_message("agent:main:main", "hello", thinking="medium")
assert result["runId"] == "run-1"
@pytest.mark.asyncio
async def test_get_session_history_uses_sessions_preview():
client = OpenClawWebSocketClient(gateway_token="test-token")
async def fake_send_request(method, params=None, _allow_handshake=False):
assert method == "sessions.preview"
assert params == {"keys": ["agent:main:main"], "limit": 12}
return {"previews": []}
client._send_request = fake_send_request # type: ignore[method-assign]
result = await client.get_session_history("agent:main:main", limit=12)
assert result == {"previews": []}
@pytest.mark.asyncio
async def test_unsubscribe_uses_session_messages_unsubscribe():
client = OpenClawWebSocketClient(gateway_token="test-token")
async def fake_send_request(method, params=None, _allow_handshake=False):
assert method == "sessions.messages.unsubscribe"
assert params == {"key": "agent:main:main"}
return {"subscribed": False}
client._send_request = fake_send_request # type: ignore[method-assign]
result = await client.unsubscribe("agent:main:main")
assert result == {"subscribed": False}

View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""Tests for provider router fallback behavior."""
from backend.data.provider_router import DataProviderRouter
from backend.config.data_config import reset_config
def test_router_includes_local_csv_fallback(monkeypatch):
monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
monkeypatch.delenv("FIN_DATA_SOURCE", raising=False)
monkeypatch.delenv("ENABLED_DATA_SOURCES", raising=False)
reset_config()
router = DataProviderRouter()
assert router.price_sources() == ["local_csv"]
def test_router_allows_yfinance_when_enabled(monkeypatch):
monkeypatch.setenv("FIN_DATA_SOURCE", "yfinance")
monkeypatch.setenv("ENABLED_DATA_SOURCES", "yfinance,local_csv")
monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
reset_config()
router = DataProviderRouter()
assert router.price_sources() == ["yfinance", "local_csv"]
assert router.api_sources() == ["yfinance"]

View File

@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
"""Tests for market symbol normalization helpers."""
from backend.data.provider_utils import describe_symbol, normalize_symbol
def test_normalize_symbol_exchange_prefix():
assert normalize_symbol("sh600519") == "600519"
assert normalize_symbol("600519.SH") == "600519"
def test_normalize_symbol_us_ticker():
symbol = describe_symbol("aapl")
assert symbol.canonical == "AAPL"
assert symbol.market == "us"

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from types import SimpleNamespace
from backend.explain import range_explainer
def test_build_range_explanation_prefers_llm_text_when_available(monkeypatch):
monkeypatch.setattr(
range_explainer,
"get_prices",
lambda ticker, start_date, end_date: [
SimpleNamespace(open=100, close=98, high=102, low=97, volume=1000),
SimpleNamespace(open=98, close=96, high=99, low=95, volume=1100),
SimpleNamespace(open=96, close=97, high=98, low=94, volume=1200),
],
)
monkeypatch.setattr(
range_explainer,
"analyze_range_with_llm",
lambda payload: {
"summary": "区间内整体偏弱,主题集中在盈利预期和供应链风险。",
"trend_analysis": "前半段快速下探,后半段出现修复。",
"bullish_factors": ["回调后出现承接"],
"bearish_factors": ["盈利预期承压"],
"model_label": "DASHSCOPE:qwen-max",
},
)
result = range_explainer.build_range_explanation(
ticker="AAPL",
start_date="2026-03-10",
end_date="2026-03-16",
news_rows=[
{
"id": "news-1",
"trade_date": "2026-03-10",
"title": "Apple margin pressure concerns grow",
"summary": "Investors focused on weaker margin outlook.",
"sentiment": "negative",
"relevance": "high",
"ret_t0": -0.02,
"reason_decrease": "盈利预期承压",
"category": "earnings",
}
],
)
assert result["analysis"]["summary"] == "区间内整体偏弱,主题集中在盈利预期和供应链风险。"
assert result["analysis"]["trend_analysis"] == "前半段快速下探,后半段出现修复。"
assert result["analysis"]["bullish_factors"] == ["回调后出现承接"]
assert result["analysis"]["analysis_source"] == "llm"
assert result["analysis"]["analysis_model_label"] == "DASHSCOPE:qwen-max"
assert result["news_count"] == 1

View File

@@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted runtime service app surface."""
import json
from pathlib import Path
from fastapi.testclient import TestClient
from backend.api import runtime as runtime_module
from backend.apps.runtime_service import create_app
def test_runtime_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/status" in paths
assert "/api/runtime/start" in paths
assert "/api/runtime/stop" in paths
assert "/api/runtime/cleanup" in paths
assert "/api/runtime/history" in paths
assert "/api/runtime/current" in paths
assert "/api/runtime/gateway/port" in paths
def test_runtime_service_health_and_status(monkeypatch):
runtime_state = runtime_module.get_runtime_state()
runtime_state.gateway_process = None
runtime_state.gateway_port = 9876
runtime_state.runtime_manager = object()
with TestClient(create_app()) as client:
health_response = client.get("/health")
status_response = client.get("/api/status")
assert health_response.status_code == 200
assert health_response.json() == {
"status": "healthy",
"service": "runtime-service",
"gateway_running": False,
"gateway_port": 9876,
}
assert status_response.status_code == 200
assert status_response.json() == {
"status": "operational",
"service": "runtime-service",
"runtime": {
"gateway_running": False,
"gateway_port": 9876,
"has_runtime_manager": True,
},
}
def test_runtime_service_gateway_port_endpoint_uses_runtime_router(monkeypatch):
runtime_module.get_runtime_state().gateway_port = 9345
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
with TestClient(create_app()) as client:
response = client.get(
"/api/runtime/gateway/port",
headers={"host": "runtime.example:8003", "x-forwarded-proto": "https"},
)
assert response.status_code == 200
assert response.json() == {
"port": 9345,
"is_running": True,
"ws_url": "wss://runtime.example:9345",
}
def test_runtime_service_get_runtime_config(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "demo"
state_dir = run_dir / "state"
state_dir.mkdir(parents=True)
(run_dir / "BOOTSTRAP.md").write_text(
"---\n"
"tickers:\n"
" - AAPL\n"
"schedule_mode: intraday\n"
"interval_minutes: 30\n"
"trigger_time: '10:00'\n"
"max_comm_cycles: 3\n"
"enable_memory: true\n"
"---\n",
encoding="utf-8",
)
(state_dir / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "demo",
"run_dir": str(run_dir),
"bootstrap_values": {
"tickers": ["AAPL"],
"schedule_mode": "intraday",
"interval_minutes": 30,
"trigger_time": "10:00",
"max_comm_cycles": 3,
"enable_memory": True,
},
}
}
),
encoding="utf-8",
)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
runtime_module.get_runtime_state().gateway_port = 8765
with TestClient(create_app()) as client:
response = client.get("/api/runtime/config")
assert response.status_code == 200
payload = response.json()
assert payload["run_id"] == "demo"
assert payload["bootstrap"]["schedule_mode"] == "intraday"
assert payload["resolved"]["interval_minutes"] == 30
assert payload["resolved"]["enable_memory"] is True
def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "demo"
state_dir = run_dir / "state"
state_dir.mkdir(parents=True)
(run_dir / "BOOTSTRAP.md").write_text(
"---\n"
"tickers:\n"
" - AAPL\n"
"schedule_mode: daily\n"
"interval_minutes: 60\n"
"trigger_time: '09:30'\n"
"max_comm_cycles: 2\n"
"---\n",
encoding="utf-8",
)
(state_dir / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "demo",
"run_dir": str(run_dir),
"bootstrap_values": {
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"max_comm_cycles": 2,
},
}
}
),
encoding="utf-8",
)
class _DummyContext:
def __init__(self):
self.bootstrap_values = {
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"max_comm_cycles": 2,
}
class _DummyManager:
def __init__(self):
self.config_name = "demo"
self.bootstrap = dict(_DummyContext().bootstrap_values)
self.context = _DummyContext()
def _persist_snapshot(self):
return None
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
runtime_module.get_runtime_state().runtime_manager = _DummyManager()
runtime_module.get_runtime_state().gateway_port = 8765
with TestClient(create_app()) as client:
response = client.put(
"/api/runtime/config",
json={
"schedule_mode": "intraday",
"interval_minutes": 15,
"trigger_time": "10:15",
"max_comm_cycles": 4,
},
)
assert response.status_code == 200
payload = response.json()
assert payload["bootstrap"]["schedule_mode"] == "intraday"
assert payload["resolved"]["interval_minutes"] == 15
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")
def test_prune_old_timestamped_runs_keeps_named_runs(monkeypatch, tmp_path):
runs_dir = tmp_path / "runs"
runs_dir.mkdir()
keep_dirs = ["20260324_110000", "20260324_120000"]
prune_dir = "20260324_100000"
named_dir = "smoke_fullstack"
for name in [*keep_dirs, prune_dir, named_dir]:
(runs_dir / name).mkdir(parents=True)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
pruned = runtime_module._prune_old_timestamped_runs(keep=1, exclude_run_ids={"20260324_120000"})
assert prune_dir in pruned
assert (runs_dir / named_dir).exists()
assert (runs_dir / "20260324_120000").exists()
assert (runs_dir / "20260324_110000").exists()
def test_runtime_cleanup_endpoint_prunes_old_runs(monkeypatch, tmp_path):
runs_dir = tmp_path / "runs"
runs_dir.mkdir()
for name in ["20260324_090000", "20260324_100000", "20260324_110000", "smoke_fullstack"]:
(runs_dir / name).mkdir(parents=True)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: False)
with TestClient(create_app()) as client:
response = client.post("/api/runtime/cleanup?keep=1")
assert response.status_code == 200
payload = response.json()
assert payload["status"] == "ok"
assert sorted(payload["pruned_run_ids"]) == ["20260324_090000", "20260324_100000"]
assert (runs_dir / "20260324_110000").exists()
assert (runs_dir / "smoke_fullstack").exists()
def test_runtime_history_lists_recent_runs(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "20260324_120000"
(run_dir / "state").mkdir(parents=True)
(run_dir / "team_dashboard").mkdir(parents=True)
(run_dir / "state" / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "20260324_120000",
"run_dir": str(run_dir),
"bootstrap_values": {"tickers": ["AAPL"]},
},
"events": [],
}
),
encoding="utf-8",
)
(run_dir / "team_dashboard" / "summary.json").write_text(
json.dumps({"totalTrades": 3, "totalAssetValue": 123456.0}),
encoding="utf-8",
)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
with TestClient(create_app()) as client:
response = client.get("/api/runtime/history?limit=5")
assert response.status_code == 200
payload = response.json()
assert payload["runs"][0]["run_id"] == "20260324_120000"
assert payload["runs"][0]["total_trades"] == 3
def test_restore_run_assets_copies_state(monkeypatch, tmp_path):
source_run = tmp_path / "runs" / "20260324_100000"
(source_run / "team_dashboard").mkdir(parents=True)
(source_run / "state").mkdir(parents=True)
(source_run / "agents").mkdir(parents=True)
(source_run / "team_dashboard" / "_internal_state.json").write_text("{}", encoding="utf-8")
(source_run / "state" / "server_state.json").write_text("{}", encoding="utf-8")
target_run = tmp_path / "runs" / "20260324_130000"
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
runtime_module._restore_run_assets("20260324_100000", target_run)
assert (target_run / "team_dashboard" / "_internal_state.json").exists()
assert (target_run / "state" / "server_state.json").exists()
def test_start_runtime_restore_reuses_historical_run_id(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "20260324_100000"
(run_dir / "state").mkdir(parents=True)
(run_dir / "state" / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "20260324_100000",
"run_dir": str(run_dir),
"bootstrap_values": {
"tickers": ["AAPL"],
"schedule_mode": "intraday",
"interval_minutes": 30,
"trigger_time": "now",
"max_comm_cycles": 2,
"initial_cash": 100000.0,
"margin_requirement": 0.0,
"enable_memory": False,
"mode": "live",
"poll_interval": 10,
},
}
}
),
encoding="utf-8",
)
class _DummyManager:
def __init__(self, config_name, run_dir, bootstrap):
self.config_name = config_name
self.run_dir = Path(run_dir)
self.bootstrap = bootstrap
self.context = None
def prepare_run(self):
self.context = type(
"Ctx",
(),
{
"config_name": self.config_name,
"run_dir": self.run_dir,
"bootstrap_values": self.bootstrap,
},
)()
return self.context
class _DummyProcess:
def poll(self):
return None
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_find_available_port", lambda start_port=8765, max_port=9000: 8765)
monkeypatch.setattr(runtime_module, "_start_gateway_process", lambda **kwargs: _DummyProcess())
monkeypatch.setattr(runtime_module, "_stop_gateway", lambda: True)
monkeypatch.setattr("backend.runtime.manager.TradingRuntimeManager", _DummyManager)
runtime_state = runtime_module.get_runtime_state()
runtime_state.gateway_process = None
with TestClient(create_app()) as client:
response = client.post(
"/api/runtime/start",
json={
"launch_mode": "restore",
"restore_run_id": "20260324_100000",
"tickers": [],
},
)
assert response.status_code == 200
payload = response.json()
assert payload["run_id"] == "20260324_100000"
assert payload["run_dir"] == str(run_dir)

View File

@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
"""Tests for split-aware shared service clients."""
import pytest
from shared.client.control_client import ControlPlaneClient
from shared.client.openclaw_client import OpenClawServiceClient
from shared.client.runtime_client import RuntimeServiceClient
class _DummyResponse:
def __init__(self, payload):
self._payload = payload
def raise_for_status(self):
return None
def json(self):
return self._payload
class _DummyAsyncClient:
def __init__(self):
self.calls = []
async def get(self, path, params=None):
self.calls.append(("get", path, params))
return _DummyResponse({"path": path, "params": params})
async def post(self, path, json=None):
self.calls.append(("post", path, json))
return _DummyResponse({"path": path, "json": json})
async def put(self, path, json=None):
self.calls.append(("put", path, json))
return _DummyResponse({"path": path, "json": json})
async def aclose(self):
return None
@pytest.mark.asyncio
async def test_control_plane_client_hits_current_workspace_and_guard_routes():
client = ControlPlaneClient()
client._client = _DummyAsyncClient()
await client.list_workspaces()
await client.get_workspace("demo")
await client.list_agents("demo")
await client.get_agent("demo", "risk_manager")
await client.fetch_pending_approvals()
await client.approve_pending_approval("ap-1")
await client.deny_pending_approval("ap-2", reason="nope")
assert client._client.calls == [
("get", "/workspaces", None),
("get", "/workspaces/demo", None),
("get", "/workspaces/demo/agents", None),
("get", "/workspaces/demo/agents/risk_manager", None),
("get", "/guard/pending", None),
(
"post",
"/guard/approve",
{
"approval_id": "ap-1",
"one_time": True,
"expires_in_minutes": 30,
},
),
(
"post",
"/guard/deny",
{
"approval_id": "ap-2",
"reason": "nope",
},
),
]
@pytest.mark.asyncio
async def test_runtime_service_client_hits_current_runtime_routes():
client = RuntimeServiceClient()
client._client = _DummyAsyncClient()
await client.fetch_context()
await client.fetch_agents()
await client.fetch_events()
await client.fetch_gateway_port()
await client.start_runtime({"tickers": ["AAPL"]})
await client.stop_runtime(force=True)
await client.restart_runtime({"tickers": ["MSFT"]})
await client.fetch_current_runtime()
await client.get_runtime_config()
await client.update_runtime_config({"schedule_mode": "intraday"})
assert client._client.calls == [
("get", "/context", None),
("get", "/agents", None),
("get", "/events", None),
("get", "/gateway/port", None),
("post", "/start", {"tickers": ["AAPL"]}),
("post", "/stop?force=true", None),
("post", "/restart", {"tickers": ["MSFT"]}),
("get", "/current", None),
("get", "/config", None),
("put", "/config", {"schedule_mode": "intraday"}),
]
@pytest.mark.asyncio
async def test_openclaw_service_client_hits_current_openclaw_routes():
client = OpenClawServiceClient()
client._client = _DummyAsyncClient()
await client.fetch_status()
await client.list_sessions()
await client.get_session("main/session-1")
await client.get_session_history("main/session-1", limit=5)
await client.list_cron_jobs()
await client.list_approvals()
assert client._client.calls == [
("get", "/status", None),
("get", "/sessions", None),
("get", "/sessions/main/session-1", None),
("get", "/sessions/main/session-1/history", {"limit": 5}),
("get", "/cron", None),
("get", "/approvals", None),
]

View File

@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
"""
Test Settlement Coordinator and Baseline Calculations
"""
from backend.utils.baselines import (
BaselineCalculator,
calculate_momentum_scores,
)
from backend.utils.analyst_tracker import (
AnalystPerformanceTracker,
update_leaderboard_with_evaluations,
)
def test_baseline_equal_weight():
"""Test equal weight baseline calculation"""
calculator = BaselineCalculator(initial_capital=100000.0)
tickers = ["AAPL", "MSFT", "GOOGL"]
prices = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
openprices = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}
value = calculator.calculate_equal_weight_value(
tickers,
openprices,
prices,
)
assert value > 0
assert calculator.equal_weight_initialized is True
def test_baseline_market_cap_weighted():
"""Test market cap weighted baseline calculation"""
calculator = BaselineCalculator(initial_capital=100000.0)
tickers = ["AAPL", "MSFT", "GOOGL"]
prices = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
openprices = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}
market_caps = {"AAPL": 3e12, "MSFT": 2e12, "GOOGL": 1.5e12}
value = calculator.calculate_market_cap_weighted_value(
tickers,
openprices,
prices,
market_caps,
)
assert value > 0
assert calculator.market_cap_initialized is True
def test_momentum_scores():
"""Test momentum score calculation"""
tickers = ["AAPL", "MSFT"]
prices_history = {
"AAPL": [
("2024-01-01", 100.0),
("2024-01-02", 105.0),
("2024-01-03", 110.0),
],
"MSFT": [
("2024-01-01", 200.0),
("2024-01-02", 195.0),
("2024-01-03", 190.0),
],
}
scores = calculate_momentum_scores(
tickers,
prices_history,
lookback_days=2,
)
assert scores["AAPL"] > 0
assert scores["MSFT"] < 0
def test_analyst_tracker_predictions():
"""Test analyst prediction recording with structured format"""
tracker = AnalystPerformanceTracker()
final_predictions = [
{
"agent": "technical_analyst",
"predictions": [
{"ticker": "AAPL", "direction": "up", "confidence": 0.8},
{"ticker": "MSFT", "direction": "down", "confidence": 0.7},
{"ticker": "GOOGL", "direction": "neutral", "confidence": 0.5},
],
},
{
"agent": "fundamentals_analyst",
"predictions": [
{"ticker": "AAPL", "direction": "up", "confidence": 0.9},
{"ticker": "MSFT", "direction": "up", "confidence": 0.6},
{"ticker": "GOOGL", "direction": "down", "confidence": 0.75},
],
},
]
tracker.record_analyst_predictions(final_predictions)
assert "technical_analyst" in tracker.daily_predictions
assert "fundamentals_analyst" in tracker.daily_predictions
assert tracker.daily_predictions["technical_analyst"]["AAPL"] == "long"
assert tracker.daily_predictions["technical_analyst"]["MSFT"] == "short"
assert tracker.daily_predictions["technical_analyst"]["GOOGL"] == "hold"
def test_analyst_evaluation():
"""Test analyst prediction evaluation"""
tracker = AnalystPerformanceTracker()
tracker.daily_predictions = {
"technical_analyst": {
"AAPL": "long",
"MSFT": "short",
},
}
open_prices = {"AAPL": 100.0, "MSFT": 200.0}
close_prices = {"AAPL": 105.0, "MSFT": 195.0}
evaluations = tracker.evaluate_predictions(
open_prices,
close_prices,
"2024-01-15",
)
assert "technical_analyst" in evaluations
eval_result = evaluations["technical_analyst"]
assert eval_result["correct_predictions"] == 2
assert eval_result["win_rate"] == 1.0
# Verify individual signals format
assert "signals" in eval_result
assert len(eval_result["signals"]) == 2
for signal in eval_result["signals"]:
assert "ticker" in signal
assert "signal" in signal
assert "date" in signal
assert "is_correct" in signal
assert signal["date"] == "2024-01-15"
def test_leaderboard_update():
"""Test leaderboard update with evaluations"""
leaderboard = [
{
"agentId": "technical_analyst",
"name": "Technical Analyst",
"rank": 0,
"winRate": None,
"bull": {"n": 0, "win": 0, "unknown": 0},
"bear": {"n": 0, "win": 0, "unknown": 0},
"signals": [],
},
]
evaluations = {
"technical_analyst": {
"total_predictions": 2,
"correct_predictions": 1,
"win_rate": 0.5,
"bull": {"n": 1, "win": 1, "unknown": 0},
"bear": {"n": 1, "win": 0, "unknown": 0},
"hold": 0,
"signals": [
{
"ticker": "AAPL",
"signal": "bull",
"date": "2024-01-01",
"is_correct": True,
},
{
"ticker": "MSFT",
"signal": "bear",
"date": "2024-01-01",
"is_correct": False,
},
],
},
}
updated = update_leaderboard_with_evaluations(
leaderboard,
evaluations,
)
assert updated[0]["bull"]["n"] == 1
assert updated[0]["bull"]["win"] == 1
assert updated[0]["winRate"] == 0.5
assert len(updated[0]["signals"]) == 2
# Verify signal format matches frontend expectations
for signal in updated[0]["signals"]:
assert "ticker" in signal
assert "signal" in signal
assert "date" in signal
assert "is_correct" in signal

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""Regression coverage for the shared schema bridge."""
from backend.data import schema as legacy_schema
from shared import schema as shared_schema
def test_backend_data_schema_reexports_shared_contracts():
assert legacy_schema.Price is shared_schema.Price
assert legacy_schema.PriceResponse is shared_schema.PriceResponse
assert legacy_schema.FinancialMetrics is shared_schema.FinancialMetrics
assert legacy_schema.FinancialMetricsResponse is (
shared_schema.FinancialMetricsResponse
)
assert legacy_schema.LineItem is shared_schema.LineItem
assert legacy_schema.LineItemResponse is shared_schema.LineItemResponse
assert legacy_schema.InsiderTrade is shared_schema.InsiderTrade
assert legacy_schema.InsiderTradeResponse is (
shared_schema.InsiderTradeResponse
)
assert legacy_schema.CompanyNews is shared_schema.CompanyNews
assert legacy_schema.CompanyNewsResponse is shared_schema.CompanyNewsResponse
assert legacy_schema.CompanyFacts is shared_schema.CompanyFacts
assert legacy_schema.CompanyFactsResponse is (
shared_schema.CompanyFactsResponse
)
assert legacy_schema.Position is shared_schema.Position
assert legacy_schema.Portfolio is shared_schema.Portfolio
assert legacy_schema.AnalystSignal is shared_schema.AnalystSignal
assert legacy_schema.TickerAnalysis is shared_schema.TickerAnalysis
assert legacy_schema.AgentStateData is shared_schema.AgentStateData
assert legacy_schema.AgentStateMetadata is shared_schema.AgentStateMetadata

View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
from backend import cli
from backend.agents.skill_metadata import parse_skill_metadata
from backend.agents.skills_manager import SkillsManager
from backend.agents.team_pipeline_config import (
ensure_team_pipeline_config,
load_team_pipeline_config,
update_active_analysts,
)
def test_parse_skill_metadata_extended_frontmatter(tmp_path):
skill_dir = tmp_path / "demo_skill"
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: demo_skill\n"
"description: Demo description\n"
"tools:\n"
" - technical\n"
"---\n\n"
"# Demo Skill\n",
encoding="utf-8",
)
parsed = parse_skill_metadata(skill_dir, source="builtin")
assert parsed.skill_name == "demo_skill"
assert parsed.description == "Demo description"
assert parsed.tools == ["technical"]
def test_update_agent_skill_overrides(tmp_path):
manager = SkillsManager(project_root=tmp_path)
asset_dir = manager.get_agent_asset_dir("demo", "risk_manager")
asset_dir.mkdir(parents=True, exist_ok=True)
(asset_dir / "agent.yaml").write_text(
"enabled_skills:\n"
" - risk_review\n"
"disabled_skills:\n"
" - old_skill\n",
encoding="utf-8",
)
result = manager.update_agent_skill_overrides(
config_name="demo",
agent_id="risk_manager",
enable=["extra_guard"],
disable=["risk_review"],
)
assert result["enabled_skills"] == ["extra_guard"]
assert result["disabled_skills"] == ["old_skill", "risk_review"]
def test_skills_enable_disable_and_list(monkeypatch, tmp_path):
builtin_root = tmp_path / "backend" / "skills" / "builtin"
for name in ("risk_review", "extra_guard"):
skill_dir = builtin_root / name
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
f"---\nname: {name}\ndescription: {name} desc\n---\n",
encoding="utf-8",
)
printed = []
monkeypatch.setattr(cli, "get_project_root", lambda: tmp_path)
monkeypatch.setattr(cli.console, "print", lambda value: printed.append(value))
cli.skills_enable(agent_id="risk_manager", skill="extra_guard", config_name="demo")
cli.skills_disable(agent_id="risk_manager", skill="risk_review", config_name="demo")
cli.skills_list(config_name="demo", agent_id="risk_manager")
text_dump = "\n".join(str(item) for item in printed)
assert "Enabled" in text_dump
assert "Disabled" in text_dump
assert any(getattr(item, "title", None) == "Skill Catalog" for item in printed)
def test_install_external_skill_for_agent(tmp_path):
manager = SkillsManager(project_root=tmp_path)
skill_dir = tmp_path / "downloaded" / "new_skill"
skill_dir.mkdir(parents=True, exist_ok=True)
(skill_dir / "SKILL.md").write_text(
"---\n"
"name: new_skill\n"
"description: external skill\n"
"---\n\n"
"# New Skill\n",
encoding="utf-8",
)
result = manager.install_external_skill_for_agent(
config_name="demo",
agent_id="risk_manager",
source=str(skill_dir),
activate=True,
)
assert result["skill_name"] == "new_skill"
target = manager.get_agent_local_root("demo", "risk_manager") / "new_skill"
assert target.exists()
def test_team_pipeline_active_analyst_updates(tmp_path):
project_root = tmp_path
ensure_team_pipeline_config(
project_root=project_root,
config_name="demo",
default_analysts=["fundamentals_analyst", "technical_analyst"],
)
update_active_analysts(
project_root=project_root,
config_name="demo",
available_analysts=["fundamentals_analyst", "technical_analyst"],
remove=["technical_analyst"],
)
config = load_team_pipeline_config(project_root, "demo")
assert config["discussion"]["active_analysts"] == ["fundamentals_analyst"]

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""Tests for structured technical analyzer."""
import pandas as pd
from backend.tools.technical_signals import StockTechnicalAnalyzer
def test_technical_analyzer_detects_bullish_trend():
df = pd.DataFrame(
{
"time": pd.date_range("2024-01-01", periods=40, freq="D"),
"close": [100 + i for i in range(40)],
},
)
analyzer = StockTechnicalAnalyzer()
result = analyzer.analyze("AAPL", df)
assert result.current_price == 139.0
assert result.trend in {"BULLISH", "STRONG BULLISH"}
assert result.momentum_20d_pct > 0

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""Unit tests for the trading domain helpers."""
from backend.domains import trading as trading_domain
def test_trading_domain_payload_wrappers(monkeypatch):
monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}])
monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12)
assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == {
"ticker": "AAPL",
"prices": [{"close": 1}],
}
assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
"financial_metrics": [{"ticker": "AAPL"}],
}
assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
"news": [{"ticker": "AAPL"}],
}
assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
"insider_trades": [{"ticker": "AAPL"}],
}
assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
"ticker": "AAPL",
"end_date": "2026-03-16",
"market_cap": 2.5e12,
}
def test_get_market_status_payload_uses_market_service(monkeypatch):
class _FakeMarketService:
def __init__(self, tickers):
self.tickers = tickers
def get_market_status(self):
return {"status": "open", "status_text": "Open"}
monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)
assert trading_domain.get_market_status_payload() == {
"status": "open",
"status_text": "Open",
}

View File

@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted trading service app surface."""
from fastapi.testclient import TestClient
from backend.apps.trading_service import create_app
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price
def test_trading_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/prices" in paths
assert "/api/financials" in paths
assert "/api/news" in paths
assert "/api/insider-trades" in paths
assert "/api/market/status" in paths
assert "/api/market-cap" in paths
assert "/api/line-items" in paths
def test_trading_service_prices_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_prices_payload",
lambda ticker, start_date, end_date: {
"ticker": ticker,
"prices": [
Price(
open=1.0,
close=2.0,
high=2.5,
low=0.5,
volume=100,
time="2026-03-20",
)
],
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/prices",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-20",
},
)
assert response.status_code == 200
assert response.json()["ticker"] == "AAPL"
assert response.json()["prices"][0]["close"] == 2.0
def test_trading_service_financials_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_financials_payload",
lambda ticker, end_date, period, limit: {
"financial_metrics": [
FinancialMetrics(
ticker=ticker,
report_period=end_date,
period=period,
currency="USD",
market_cap=123.0,
enterprise_value=None,
price_to_earnings_ratio=None,
price_to_book_ratio=None,
price_to_sales_ratio=None,
enterprise_value_to_ebitda_ratio=None,
enterprise_value_to_revenue_ratio=None,
free_cash_flow_yield=None,
peg_ratio=None,
gross_margin=None,
operating_margin=None,
net_margin=None,
return_on_equity=None,
return_on_assets=None,
return_on_invested_capital=None,
asset_turnover=None,
inventory_turnover=None,
receivables_turnover=None,
days_sales_outstanding=None,
operating_cycle=None,
working_capital_turnover=None,
current_ratio=None,
quick_ratio=None,
cash_ratio=None,
operating_cash_flow_ratio=None,
debt_to_equity=None,
debt_to_assets=None,
interest_coverage=None,
revenue_growth=None,
earnings_growth=None,
book_value_growth=None,
earnings_per_share_growth=None,
free_cash_flow_growth=None,
operating_income_growth=None,
ebitda_growth=None,
payout_ratio=None,
earnings_per_share=None,
book_value_per_share=None,
free_cash_flow_per_share=None,
)
]
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/financials",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert response.status_code == 200
assert response.json()["financial_metrics"][0]["ticker"] == "AAPL"
def test_trading_service_news_and_insider_endpoints(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_news_payload",
lambda ticker, end_date, start_date=None, limit=1000: {
"news": [
CompanyNews(
ticker=ticker,
title="News title",
source="polygon",
url="https://example.com/news",
date=end_date,
)
]
},
)
monkeypatch.setattr(
"backend.domains.trading.get_insider_trades_payload",
lambda ticker, end_date, start_date=None, limit=1000: {
"insider_trades": [
InsiderTrade(ticker=ticker, filing_date=end_date)
]
},
)
with TestClient(create_app()) as client:
news_response = client.get(
"/api/news",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
insider_response = client.get(
"/api/insider-trades",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert news_response.status_code == 200
assert news_response.json()["news"][0]["title"] == "News title"
assert insider_response.status_code == 200
assert insider_response.json()["insider_trades"][0]["ticker"] == "AAPL"
def test_trading_service_market_status_endpoint(monkeypatch):
class _FakeMarketService:
def get_market_status(self):
return {"status": "open", "status_text": "Open"}
monkeypatch.setattr(
"backend.domains.trading.get_market_status_payload",
lambda: _FakeMarketService().get_market_status(),
)
with TestClient(create_app()) as client:
response = client.get("/api/market/status")
assert response.status_code == 200
assert response.json() == {"status": "open", "status_text": "Open"}
def test_trading_service_market_cap_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_market_cap_payload",
lambda ticker, end_date: {
"ticker": ticker,
"end_date": end_date,
"market_cap": 3.5e12,
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/market-cap",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert response.status_code == 200
assert response.json() == {
"ticker": "AAPL",
"end_date": "2026-03-20",
"market_cap": 3.5e12,
}
def test_trading_service_line_items_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_line_items_payload",
lambda ticker, line_items, end_date, period, limit: {
"search_results": [
LineItem(
ticker=ticker,
report_period=end_date,
period=period,
currency="USD",
free_cash_flow=123.0,
)
]
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/line-items",
params=[
("ticker", "AAPL"),
("line_items", "free_cash_flow"),
("end_date", "2026-03-20"),
],
)
assert response.status_code == 200
assert response.json()["search_results"][0]["ticker"] == "AAPL"
assert response.json()["search_results"][0]["free_cash_flow"] == 123.0

View File

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
from backend.agents.skills_manager import SkillsManager
from backend.skills.builtin.valuation_review.scripts.dcf_report import (
build_dcf_report,
)
from backend.skills.builtin.valuation_review.scripts.multiple_valuation_report import (
build_ev_ebitda_report,
build_residual_income_report,
)
from backend.skills.builtin.valuation_review.scripts.owner_earnings_report import (
build_owner_earnings_report,
)
def test_build_dcf_report_renders_assessment():
report = build_dcf_report(
[
{
"ticker": "AAPL",
"current_fcf": 100.0,
"growth_rate": 0.05,
"market_cap": 900.0,
"discount_rate": 0.10,
"terminal_growth": 0.03,
"num_years": 5,
},
],
"2026-03-17",
)
assert "DCF Valuation Analysis (2026-03-17)" in report
assert "AAPL:" in report
assert "Market Cap: $900" in report
assert "Value Gap:" in report
def test_build_owner_earnings_report_handles_errors():
report = build_owner_earnings_report(
[
{
"ticker": "MSFT",
"error": "Negative owner earnings ($-50)",
},
],
"2026-03-17",
)
assert "MSFT: Negative owner earnings ($-50)" in report
def test_multiple_valuation_reports_render_expected_sections():
ev_report = build_ev_ebitda_report(
[
{
"ticker": "NVDA",
"current_multiple": 18.0,
"median_multiple": 20.0,
"current_ebitda": 50.0,
"market_cap": 800.0,
"net_debt": 100.0,
},
],
"2026-03-17",
)
residual_report = build_residual_income_report(
[
{
"ticker": "META",
"book_value": 200.0,
"initial_ri": 30.0,
"market_cap": 300.0,
"cost_of_equity": 0.10,
"bv_growth": 0.03,
"terminal_growth": 0.03,
"num_years": 5,
"margin_of_safety": 0.20,
},
],
"2026-03-17",
)
assert "EV/EBITDA Valuation (2026-03-17)" in ev_report
assert "NVDA:" in ev_report
assert "Residual Income Valuation (2026-03-17)" in residual_report
assert "META:" in residual_report
def test_prepare_active_skills_copies_skill_scripts(tmp_path):
builtin_skill = tmp_path / "backend" / "skills" / "builtin" / "valuation_review"
scripts_dir = builtin_skill / "scripts"
scripts_dir.mkdir(parents=True, exist_ok=True)
(builtin_skill / "SKILL.md").write_text(
"---\nname: 估值分析\ndescription: desc\nversion: 1.0.0\n---\n",
encoding="utf-8",
)
(scripts_dir / "dcf_report.py").write_text("print('ok')\n", encoding="utf-8")
manager = SkillsManager(project_root=tmp_path)
active_map = manager.prepare_active_skills(
config_name="demo",
agent_defaults={"valuation_analyst": ["valuation_review"]},
)
active_dir = active_map["valuation_analyst"][0]
assert (active_dir / "scripts" / "dcf_report.py").exists()