Initial commit of integrated agent system
backend/tests/__init__.py  (new file, 0 lines)
backend/tests/test_agent_service_app.py  (new file, 104 lines)
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted agent service surface."""

from pathlib import Path

from fastapi.testclient import TestClient

from backend.apps.agent_service import create_app
from backend.api import agents as agents_module


def test_agent_service_routes_include_control_plane_endpoints(tmp_path):
    app = create_app(project_root=tmp_path)

    paths = {route.path for route in app.routes}

    assert "/health" in paths
    assert "/api/status" in paths
    assert "/api/workspaces" in paths
    assert "/api/guard/pending" in paths


def test_agent_service_excludes_runtime_routes(tmp_path):
    app = create_app(project_root=tmp_path)
    paths = {route.path for route in app.routes}

    assert "/api/runtime/start" not in paths
    assert "/api/runtime/gateway/port" not in paths


def test_agent_service_read_routes(monkeypatch, tmp_path):
    class _FakeSkillsManager:
        project_root = tmp_path

        def get_agent_asset_dir(self, config_name, agent_id):
            return tmp_path / "runs" / config_name / "agents" / agent_id

        def resolve_agent_skill_names(self, config_name, agent_id, default_skills=None):
            return ["demo_skill"]

        def list_agent_skill_catalog(self, config_name, agent_id):
            return [
                type(
                    "Skill",
                    (),
                    {
                        "skill_name": "demo_skill",
                        "name": "Demo Skill",
                        "description": "demo",
                        "version": "1.0.0",
                        "source": "builtin",
                        "tools": [],
                    },
                )()
            ]

        def load_agent_skill_document(self, config_name, agent_id, skill_name):
            return {"skill_name": skill_name, "content": "# demo"}

    class _FakeWorkspaceManager:
        def load_agent_file(self, config_name, agent_id, filename):
            return f"{config_name}:{agent_id}:{filename}"

    monkeypatch.setattr(agents_module, "load_agent_profiles", lambda: {"portfolio_manager": {"skills": ["demo_skill"]}})
    monkeypatch.setattr(agents_module, "get_agent_model_info", lambda agent_id: ("deepseek-v3.2", "DASHSCOPE"))
    monkeypatch.setattr(
        agents_module,
        "load_agent_workspace_config",
        lambda path: type(
            "Cfg",
            (),
            {
                "active_tool_groups": ["portfolio_ops"],
                "disabled_tool_groups": [],
                "enabled_skills": [],
                "disabled_skills": [],
                "prompt_files": ["SOUL.md", "MEMORY.md"],
            },
        )(),
    )
    monkeypatch.setattr(
        agents_module,
        "get_bootstrap_config_for_run",
        lambda project_root, config_name: type("Bootstrap", (), {"agent_override": lambda self, agent_id: {}})(),
    )

    app = create_app(project_root=tmp_path)
    app.dependency_overrides[agents_module.get_skills_manager] = lambda: _FakeSkillsManager()
    app.dependency_overrides[agents_module.get_workspace_manager] = lambda: _FakeWorkspaceManager()

    with TestClient(app) as client:
        profile = client.get("/api/workspaces/demo/agents/portfolio_manager/profile")
        skills = client.get("/api/workspaces/demo/agents/portfolio_manager/skills")
        detail = client.get("/api/workspaces/demo/agents/portfolio_manager/skills/demo_skill")
        workspace_file = client.get("/api/workspaces/demo/agents/portfolio_manager/files/MEMORY.md")

    assert profile.status_code == 200
    assert profile.json()["profile"]["model_name"] == "deepseek-v3.2"
    assert skills.status_code == 200
    assert skills.json()["skills"][0]["skill_name"] == "demo_skill"
    assert detail.status_code == 200
    assert detail.json()["skill"]["content"] == "# demo"
    assert workspace_file.status_code == 200
    assert workspace_file.json()["content"] == "demo:portfolio_manager:MEMORY.md"
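Note: backend/apps/agent_service.create_app itself is not part of this commit. As a rough sketch of the shape these route assertions imply (the router names below are assumptions, not the actual implementation), the factory would register the control-plane endpoints and deliberately omit the runtime router:

# Hypothetical sketch only; inferred from the assertions above, not taken from the repo.
from pathlib import Path

from fastapi import FastAPI


def create_app(project_root: Path) -> FastAPI:
    app = FastAPI(title="agent_service")

    @app.get("/health")
    def health():
        return {"status": "ok"}

    # Control-plane routers (status, workspaces, guard) would be included here;
    # the runtime router (/api/runtime/*) is left out, which is exactly what
    # test_agent_service_excludes_runtime_routes checks.
    # app.include_router(status_router, prefix="/api")
    # app.include_router(workspaces_router, prefix="/api")
    # app.include_router(guard_router, prefix="/api")
    return app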
backend/tests/test_agent_workspace.py  (new file, 233 lines)
@@ -0,0 +1,233 @@
# -*- coding: utf-8 -*-

from backend.agents.prompt_factory import build_agent_system_prompt
from backend.agents.skills_manager import SkillsManager
from backend.agents.workspace_manager import WorkspaceManager


class _DummyToolkit:
    def get_agent_skill_prompt(self):
        return ""

    def get_activated_notes(self):
        return ""


def test_workspace_manager_creates_core_agent_files(tmp_path):
    manager = WorkspaceManager(project_root=tmp_path)

    manager.initialize_default_assets(
        config_name="demo",
        agent_ids=["risk_manager"],
        analyst_personas={},
    )

    asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
    assert (asset_dir / "SOUL.md").exists()
    assert (asset_dir / "PROFILE.md").exists()
    assert (asset_dir / "AGENTS.md").exists()
    assert (asset_dir / "MEMORY.md").exists()
    assert (asset_dir / "POLICY.md").exists()
    assert (asset_dir / "agent.yaml").exists()
    assert (asset_dir / "skills" / "installed").is_dir()
    assert (asset_dir / "skills" / "active").is_dir()
    assert (asset_dir / "skills" / "disabled").is_dir()
    assert (asset_dir / "skills" / "local").is_dir()


def test_workspace_manager_seeds_risk_prompt_content(tmp_path):
    manager = WorkspaceManager(project_root=tmp_path)
    manager.initialize_default_assets(
        config_name="demo",
        agent_ids=["risk_manager"],
        analyst_personas={},
    )

    asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
    soul = (asset_dir / "SOUL.md").read_text(encoding="utf-8")
    guide = (asset_dir / "AGENTS.md").read_text(encoding="utf-8")

    assert "风险管理经理" in soul
    assert "优先使用可用的风险工具量化集中度" in guide


def test_agent_workspace_config_controls_prompt_files(tmp_path, monkeypatch):
    manager = WorkspaceManager(project_root=tmp_path)
    manager.initialize_default_assets(
        config_name="demo",
        agent_ids=["risk_manager"],
        analyst_personas={},
    )
    asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
    (asset_dir / "SOUL.md").write_text("soul-line", encoding="utf-8")
    (asset_dir / "PROFILE.md").write_text("profile-line", encoding="utf-8")
    (asset_dir / "MEMORY.md").write_text("memory-line", encoding="utf-8")
    (asset_dir / "agent.yaml").write_text(
        "prompt_files:\n"
        " - SOUL.md\n"
        " - MEMORY.md\n",
        encoding="utf-8",
    )

    from backend.agents import prompt_factory

    monkeypatch.setattr(
        prompt_factory,
        "SkillsManager",
        lambda: SkillsManager(project_root=tmp_path),
    )

    prompt = build_agent_system_prompt(
        agent_id="risk_manager",
        config_name="demo",
        toolkit=_DummyToolkit(),
    )

    assert "soul-line" in prompt
    assert "memory-line" in prompt
    assert "profile-line" not in prompt


def test_prompt_is_built_from_workspace_defaults_without_system_templates(tmp_path, monkeypatch):
    manager = WorkspaceManager(project_root=tmp_path)
    manager.initialize_default_assets(
        config_name="demo",
        agent_ids=["portfolio_manager"],
        analyst_personas={},
    )

    from backend.agents import prompt_factory

    monkeypatch.setattr(
        prompt_factory,
        "SkillsManager",
        lambda: SkillsManager(project_root=tmp_path),
    )

    prompt = build_agent_system_prompt(
        agent_id="portfolio_manager",
        config_name="demo",
        toolkit=_DummyToolkit(),
    )

    assert "投资组合经理" in prompt
    assert "使用 `make_decision` 工具记录每个股票的最终决策" in prompt


def test_skills_manager_applies_agent_level_skill_toggles(tmp_path):
    builtin_root = tmp_path / "backend" / "skills" / "builtin"
    for skill_name in ("risk_review", "extra_guard"):
        skill_dir = builtin_root / skill_name
        skill_dir.mkdir(parents=True, exist_ok=True)
        (skill_dir / "SKILL.md").write_text(
            f"# {skill_name}\n",
            encoding="utf-8",
        )

    manager = WorkspaceManager(project_root=tmp_path)
    manager.initialize_default_assets(
        config_name="demo",
        agent_ids=["risk_manager"],
        analyst_personas={},
    )
    asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
    (asset_dir / "agent.yaml").write_text(
        "enabled_skills:\n"
        " - extra_guard\n"
        "disabled_skills:\n"
        " - risk_review\n",
        encoding="utf-8",
    )

    skills_manager = SkillsManager(project_root=tmp_path)
    active_map = skills_manager.prepare_active_skills(
        config_name="demo",
        agent_defaults={"risk_manager": ["risk_review"]},
    )

    active_dirs = active_map["risk_manager"]
    assert [path.name for path in active_dirs] == ["extra_guard"]
    assert (asset_dir / "skills" / "installed" / "extra_guard" / "SKILL.md").exists()
    assert (asset_dir / "skills" / "active" / "extra_guard" / "SKILL.md").exists()
    assert (asset_dir / "skills" / "disabled" / "risk_review" / "SKILL.md").exists()
    assert not (asset_dir / "skills" / "active" / "risk_review").exists()


def test_agent_local_skill_is_activated_from_agent_workspace(tmp_path):
    manager = WorkspaceManager(project_root=tmp_path)
    manager.initialize_default_assets(
        config_name="demo",
        agent_ids=["risk_manager"],
        analyst_personas={},
    )
    asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
    local_skill = asset_dir / "skills" / "local" / "local_guard"
    local_skill.mkdir(parents=True, exist_ok=True)
    (local_skill / "SKILL.md").write_text(
        "---\nname: 本地风控\ndescription: local skill\nversion: 1.0.0\n---\n",
        encoding="utf-8",
    )

    skills_manager = SkillsManager(project_root=tmp_path)
    active_map = skills_manager.prepare_active_skills(
        config_name="demo",
        agent_defaults={"risk_manager": []},
    )

    assert [path.name for path in active_map["risk_manager"]] == ["local_guard"]
    assert (asset_dir / "skills" / "active" / "local_guard" / "SKILL.md").exists()


def test_prompt_includes_active_skill_metadata_summary(tmp_path, monkeypatch):
    builtin_root = tmp_path / "backend" / "skills" / "builtin"
    skill_dir = builtin_root / "extra_guard"
    skill_dir.mkdir(parents=True, exist_ok=True)
    (skill_dir / "SKILL.md").write_text(
        "---\n"
        "name: extra_guard\n"
        "description: This skill should be used when the user asks to \"run a risk check\".\n"
        "version: 1.0.0\n"
        "tools:\n"
        " - risk_ops\n"
        "---\n\n"
        "# Extra Guard\n",
        encoding="utf-8",
    )

    manager = WorkspaceManager(project_root=tmp_path)
    manager.initialize_default_assets(
        config_name="demo",
        agent_ids=["risk_manager"],
        analyst_personas={},
    )
    asset_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
    (asset_dir / "agent.yaml").write_text(
        "enabled_skills:\n"
        " - extra_guard\n",
        encoding="utf-8",
    )

    skills_manager = SkillsManager(project_root=tmp_path)
    skills_manager.prepare_active_skills(
        config_name="demo",
        agent_defaults={"risk_manager": []},
    )

    from backend.agents import prompt_factory

    monkeypatch.setattr(
        prompt_factory,
        "SkillsManager",
        lambda: SkillsManager(project_root=tmp_path),
    )

    prompt = build_agent_system_prompt(
        agent_id="risk_manager",
        config_name="demo",
        toolkit=_DummyToolkit(),
    )

    assert "Active Skill Catalog" in prompt
    assert "This skill should be used when the user asks to \"run a risk check\"." in prompt
    assert "version: 1.0.0" in prompt
    assert "risk_ops" not in prompt
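Note: taken together, these tests pin down the per-agent workspace layout that WorkspaceManager.initialize_default_assets is expected to create (paths inferred from the assertions above):

    runs/<config_name>/agents/<agent_id>/
        SOUL.md  PROFILE.md  AGENTS.md  MEMORY.md  POLICY.md  agent.yaml
        skills/installed/  skills/active/  skills/disabled/  skills/local/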
backend/tests/test_agents.py  (new file, 591 lines)
@@ -0,0 +1,591 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=W0212
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from agentscope.message import Msg
|
||||
|
||||
|
||||
class TestAnalystAgent:
|
||||
def test_init_valid_analyst_type(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="technical_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.analyst_type_key == "technical_analyst"
|
||||
assert agent.name == "technical_analyst"
|
||||
assert agent.analyst_persona == "Technical Analyst"
|
||||
|
||||
def test_init_invalid_analyst_type(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AnalystAgent(
|
||||
analyst_type="invalid_type",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert "Unknown analyst type" in str(excinfo.value)
|
||||
|
||||
def test_init_custom_agent_id(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="fundamentals_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
agent_id="custom_analyst_id",
|
||||
)
|
||||
|
||||
assert agent.name == "custom_analyst_id"
|
||||
|
||||
def test_load_system_prompt(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="sentiment_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
prompt = agent._load_system_prompt()
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 0
|
||||
|
||||
|
||||
class TestPMAgent:
|
||||
def test_init_default(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.name == "portfolio_manager"
|
||||
assert agent.portfolio["cash"] == 100000.0
|
||||
assert agent.portfolio["positions"] == {}
|
||||
assert agent.portfolio["margin_requirement"] == 0.25
|
||||
|
||||
def test_init_custom_cash(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=50000.0,
|
||||
margin_requirement=0.5,
|
||||
)
|
||||
|
||||
assert agent.portfolio["cash"] == 50000.0
|
||||
assert agent.portfolio["margin_requirement"] == 0.5
|
||||
|
||||
def test_get_portfolio_state(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=75000.0,
|
||||
)
|
||||
|
||||
state = agent.get_portfolio_state()
|
||||
|
||||
assert state["cash"] == 75000.0
|
||||
assert state is not agent.portfolio # Should be a copy
|
||||
|
||||
def test_load_portfolio_state(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
new_portfolio = {
|
||||
"cash": 50000.0,
|
||||
"positions": {
|
||||
"AAPL": {"long": 100, "short": 0, "long_cost_basis": 150.0},
|
||||
},
|
||||
"margin_used": 1000.0,
|
||||
}
|
||||
|
||||
agent.load_portfolio_state(new_portfolio)
|
||||
|
||||
assert agent.portfolio["cash"] == 50000.0
|
||||
assert "AAPL" in agent.portfolio["positions"]
|
||||
|
||||
def test_update_portfolio(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
agent.update_portfolio({"cash": 80000.0})
|
||||
assert agent.portfolio["cash"] == 80000.0
|
||||
|
||||
def _get_text_from_tool_response(self, result):
|
||||
"""Helper to extract text from ToolResponse content"""
|
||||
content = result.content[0]
|
||||
if hasattr(content, "text"):
|
||||
return content.text
|
||||
elif isinstance(content, dict):
|
||||
return content.get("text", "")
|
||||
return str(content)
|
||||
|
||||
def test_make_decision_long(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="AAPL",
|
||||
action="long",
|
||||
quantity=100,
|
||||
confidence=80,
|
||||
reasoning="Strong fundamentals",
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Decision recorded" in text
|
||||
assert agent._decisions["AAPL"]["action"] == "long"
|
||||
assert agent._decisions["AAPL"]["quantity"] == 100
|
||||
|
||||
def test_make_decision_hold(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="GOOGL",
|
||||
action="hold",
|
||||
quantity=0,
|
||||
confidence=50,
|
||||
reasoning="Neutral outlook",
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Decision recorded" in text
|
||||
assert agent._decisions["GOOGL"]["action"] == "hold"
|
||||
assert agent._decisions["GOOGL"]["quantity"] == 0
|
||||
|
||||
def test_make_decision_invalid_action(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="AAPL",
|
||||
action="invalid",
|
||||
quantity=10,
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Invalid action" in text
|
||||
|
||||
def test_get_decisions(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
agent._make_decision("AAPL", "long", 100)
|
||||
agent._make_decision("GOOGL", "short", 50)
|
||||
|
||||
decisions = agent.get_decisions()
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"]["action"] == "long"
|
||||
assert decisions["GOOGL"]["action"] == "short"
|
||||
|
||||
|
||||
class TestRiskAgent:
|
||||
def test_init_default(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.name == "risk_manager"
|
||||
|
||||
def test_init_custom_name(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
name="custom_risk_manager",
|
||||
)
|
||||
|
||||
assert agent.name == "custom_risk_manager"
|
||||
|
||||
def test_load_system_prompt(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
prompt = agent._load_system_prompt()
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 0
|
||||
|
||||
|
||||
class TestStorageService:
|
||||
def test_storage_service_defaults_to_live_config(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
assert storage.config_name == "live"
|
||||
|
||||
def test_calculate_portfolio_value_cash_only(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {"cash": 100000.0, "positions": {}, "margin_used": 0.0}
|
||||
prices = {}
|
||||
|
||||
value = storage.calculate_portfolio_value(portfolio, prices)
|
||||
assert value == 100000.0
|
||||
|
||||
def test_calculate_portfolio_value_with_positions(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {
|
||||
"cash": 50000.0,
|
||||
"positions": {
|
||||
"AAPL": {"long": 100, "short": 0},
|
||||
"GOOGL": {"long": 0, "short": 10},
|
||||
},
|
||||
"margin_used": 5000.0,
|
||||
}
|
||||
prices = {"AAPL": 150.0, "GOOGL": 100.0}
|
||||
|
||||
value = storage.calculate_portfolio_value(portfolio, prices)
|
||||
assert value == 69000.0
|
||||
|
||||
def test_update_dashboard_after_cycle(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {
|
||||
"cash": 90000.0,
|
||||
"positions": {"AAPL": {"long": 50, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
}
|
||||
prices = {"AAPL": 200.0}
|
||||
|
||||
storage.update_dashboard_after_cycle(
|
||||
portfolio=portfolio,
|
||||
prices=prices,
|
||||
date="2024-01-15",
|
||||
executed_trades=[
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"action": "long",
|
||||
"quantity": 50,
|
||||
"price": 200.0,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
summary = storage.load_file("summary")
|
||||
assert summary is not None
|
||||
assert summary["totalAssetValue"] == 100000.0 # 90000 + 50*200
|
||||
|
||||
holdings = storage.load_file("holdings")
|
||||
assert holdings is not None
|
||||
assert len(holdings) > 0
|
||||
|
||||
trades = storage.load_file("trades")
|
||||
assert trades is not None
|
||||
assert len(trades) == 1
|
||||
assert trades[0]["ticker"] == "AAPL"
|
||||
assert trades[0]["qty"] == 50
|
||||
assert trades[0]["price"] == 200.0
|
||||
|
||||
def test_generate_summary(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
state = {
|
||||
"portfolio_state": {
|
||||
"cash": 50000.0,
|
||||
"positions": {"AAPL": {"long": 100, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
"equity_history": [{"t": 1000, "v": 100000}],
|
||||
"all_trades": [],
|
||||
}
|
||||
prices = {"AAPL": 500.0}
|
||||
|
||||
storage._generate_summary(state, 100000.0, prices)
|
||||
|
||||
summary = storage.load_file("summary")
|
||||
assert summary["totalAssetValue"] == 100000.0
|
||||
assert summary["totalReturn"] == 0.0
|
||||
|
||||
def test_generate_holdings(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
state = {
|
||||
"portfolio_state": {
|
||||
"cash": 50000.0,
|
||||
"positions": {"AAPL": {"long": 100, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
}
|
||||
prices = {"AAPL": 500.0}
|
||||
|
||||
storage._generate_holdings(state, prices)
|
||||
|
||||
holdings = storage.load_file("holdings")
|
||||
assert len(holdings) == 2 # AAPL + CASH
|
||||
|
||||
aapl_holding = next(
|
||||
(h for h in holdings if h["ticker"] == "AAPL"),
|
||||
None,
|
||||
)
|
||||
assert aapl_holding is not None
|
||||
assert aapl_holding["quantity"] == 100
|
||||
assert aapl_holding["currentPrice"] == 500.0
|
||||
|
||||
|
||||
class TestTradeExecutor:
|
||||
def test_execute_trade_long(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor(
|
||||
initial_portfolio={
|
||||
"cash": 100000.0,
|
||||
"positions": {},
|
||||
"margin_requirement": 0.25,
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="long",
|
||||
quantity=10,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert executor.portfolio["positions"]["AAPL"]["long"] == 10
|
||||
assert executor.portfolio["cash"] == 98500.0 # 100000 - 10*150
|
||||
|
||||
def test_execute_trade_short(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor(
|
||||
initial_portfolio={
|
||||
"cash": 100000.0,
|
||||
"positions": {
|
||||
"AAPL": {
|
||||
"long": 50,
|
||||
"short": 0,
|
||||
"long_cost_basis": 100.0,
|
||||
"short_cost_basis": 0.0,
|
||||
},
|
||||
},
|
||||
"margin_requirement": 0.25,
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="short",
|
||||
quantity=30,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert executor.portfolio["positions"]["AAPL"]["long"] == 20 # 50 - 30
|
||||
|
||||
def test_execute_trade_hold(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor()
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="hold",
|
||||
quantity=0,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["message"] == "No trade needed"
|
||||
|
||||
|
||||
class TestPipelineExecution:
|
||||
def test_execute_decisions(self):
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
pm = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
pipeline = TradingPipeline(
|
||||
analysts=[],
|
||||
risk_manager=MagicMock(),
|
||||
portfolio_manager=pm,
|
||||
max_comm_cycles=0,
|
||||
)
|
||||
|
||||
decisions = {
|
||||
"AAPL": {"action": "long", "quantity": 10},
|
||||
"GOOGL": {"action": "short", "quantity": 5},
|
||||
}
|
||||
prices = {"AAPL": 150.0, "GOOGL": 100.0}
|
||||
|
||||
result = pipeline._execute_decisions(decisions, prices, "2024-01-15")
|
||||
|
||||
assert len(result["executed_trades"]) == 2
|
||||
assert result["executed_trades"][0]["ticker"] == "AAPL"
|
||||
assert result["executed_trades"][0]["quantity"] == 10
|
||||
assert pm.portfolio["positions"]["AAPL"]["long"] == 10
|
||||
|
||||
|
||||
class TestMsgContentIsString:
|
||||
def test_msg_content_string(self):
|
||||
msg = Msg(name="test", content="simple string", role="user")
|
||||
assert isinstance(msg.content, str)
|
||||
|
||||
def test_msg_content_json_string(self):
|
||||
data = {"key": "value", "nested": {"a": 1}}
|
||||
msg = Msg(name="test", content=json.dumps(data), role="user")
|
||||
assert isinstance(msg.content, str)
|
||||
|
||||
parsed = json.loads(msg.content)
|
||||
assert parsed["key"] == "value"
|
||||
|
||||
def test_msg_content_should_not_be_dict(self):
|
||||
data = {"key": "value"}
|
||||
msg = Msg(name="test", content=json.dumps(data), role="assistant")
|
||||
|
||||
assert not isinstance(msg.content, dict)
|
||||
assert isinstance(msg.content, str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
backend/tests/test_analysis_tools.py  (new file, 10 lines)
@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
from datetime import datetime, timedelta

from backend.tools.analysis_tools import _resolved_date


def test_resolved_date_clamps_future_date():
    future_date = (datetime.today() + timedelta(days=2)).strftime("%Y-%m-%d")

    assert _resolved_date(future_date) == datetime.today().strftime("%Y-%m-%d")
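Note: the helper itself is not included in this commit; a minimal sketch of the clamping behaviour the test above expects (assumed implementation, not the real one):

# Assumed shape of backend.tools.analysis_tools._resolved_date, inferred from the test.
from datetime import datetime


def _resolved_date(date_str: str) -> str:
    today = datetime.today().strftime("%Y-%m-%d")
    # ISO dates compare correctly as strings, so a future date is clamped back
    # to today and past dates pass through unchanged.
    return min(date_str, today)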
backend/tests/test_cli.py  (new file, 235 lines)
@@ -0,0 +1,235 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pathlib import Path
|
||||
|
||||
from backend import cli
|
||||
|
||||
|
||||
def test_live_runs_incremental_market_store_update_before_start(monkeypatch, tmp_path):
|
||||
project_root = tmp_path
|
||||
(project_root / ".env").write_text("FINNHUB_API_KEY=test\n", encoding="utf-8")
|
||||
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(cli, "get_project_root", lambda: project_root)
|
||||
monkeypatch.setattr(cli, "handle_history_cleanup", lambda config_name, auto_clean=False: None)
|
||||
monkeypatch.setattr(cli, "run_data_updater", lambda project_root: calls.append(("run_data_updater", project_root)))
|
||||
monkeypatch.setattr(
|
||||
cli,
|
||||
"auto_update_market_store",
|
||||
lambda config_name, end_date=None: calls.append(("auto_update_market_store", config_name, end_date)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
cli,
|
||||
"auto_enrich_market_store",
|
||||
lambda config_name, end_date=None, lookback_days=120, force=False: calls.append(
|
||||
("auto_enrich_market_store", config_name, end_date, lookback_days, force)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(cli.os, "chdir", lambda path: calls.append(("chdir", Path(path))))
|
||||
|
||||
def fake_run(cmd, check=True, **kwargs):
|
||||
calls.append(("subprocess.run", cmd, check))
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(cli.subprocess, "run", fake_run)
|
||||
|
||||
cli.live(
|
||||
config_name="smoke_fullstack",
|
||||
host="0.0.0.0",
|
||||
port=8765,
|
||||
trigger_time="now",
|
||||
poll_interval=10,
|
||||
clean=False,
|
||||
enable_memory=False,
|
||||
)
|
||||
|
||||
assert any(item[0] == "run_data_updater" for item in calls)
|
||||
assert any(
|
||||
item[0] == "auto_update_market_store" and item[1] == "smoke_fullstack"
|
||||
for item in calls
|
||||
)
|
||||
assert any(
|
||||
item[0] == "auto_enrich_market_store" and item[1] == "smoke_fullstack"
|
||||
for item in calls
|
||||
)
|
||||
run_call = next(item for item in calls if item[0] == "subprocess.run")
|
||||
assert run_call[1][:6] == [
|
||||
cli.sys.executable,
|
||||
"-u",
|
||||
"-m",
|
||||
"backend.main",
|
||||
"--mode",
|
||||
"live",
|
||||
]
|
||||
|
||||
|
||||
def test_backtest_runs_full_market_store_prepare_before_start(monkeypatch, tmp_path):
|
||||
project_root = tmp_path
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(cli, "get_project_root", lambda: project_root)
|
||||
monkeypatch.setattr(cli, "handle_history_cleanup", lambda config_name, auto_clean=False: None)
|
||||
monkeypatch.setattr(cli, "run_data_updater", lambda project_root: calls.append(("run_data_updater", project_root)))
|
||||
monkeypatch.setattr(
|
||||
cli,
|
||||
"auto_prepare_backtest_market_store",
|
||||
lambda config_name, start_date, end_date: calls.append(
|
||||
("auto_prepare_backtest_market_store", config_name, start_date, end_date)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
cli,
|
||||
"auto_enrich_market_store",
|
||||
lambda config_name, end_date=None, lookback_days=120, force=False: calls.append(
|
||||
("auto_enrich_market_store", config_name, end_date, lookback_days, force)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(cli.os, "chdir", lambda path: calls.append(("chdir", Path(path))))
|
||||
|
||||
def fake_run(cmd, check=True, **kwargs):
|
||||
calls.append(("subprocess.run", cmd, check))
|
||||
return 0
|
||||
|
||||
monkeypatch.setattr(cli.subprocess, "run", fake_run)
|
||||
|
||||
cli.backtest(
|
||||
start="2026-03-01",
|
||||
end="2026-03-10",
|
||||
config_name="smoke_fullstack",
|
||||
host="0.0.0.0",
|
||||
port=8765,
|
||||
poll_interval=10,
|
||||
clean=False,
|
||||
enable_memory=False,
|
||||
)
|
||||
|
||||
assert any(item[0] == "run_data_updater" for item in calls)
|
||||
assert any(
|
||||
item[0] == "auto_prepare_backtest_market_store"
|
||||
and item[1:] == ("smoke_fullstack", "2026-03-01", "2026-03-10")
|
||||
for item in calls
|
||||
)
|
||||
assert any(
|
||||
item[0] == "auto_enrich_market_store"
|
||||
and item[1] == "smoke_fullstack"
|
||||
and item[2] == "2026-03-10"
|
||||
for item in calls
|
||||
)
|
||||
run_call = next(item for item in calls if item[0] == "subprocess.run")
|
||||
assert run_call[1][:6] == [
|
||||
cli.sys.executable,
|
||||
"-u",
|
||||
"-m",
|
||||
"backend.main",
|
||||
"--mode",
|
||||
"backtest",
|
||||
]
|
||||
|
||||
|
||||
def test_ingest_enrich_runs_batch_enrichment(monkeypatch):
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(cli, "_resolve_symbols", lambda raw_tickers, config_name=None: ["AAPL", "MSFT"])
|
||||
|
||||
class DummyStore:
|
||||
pass
|
||||
|
||||
monkeypatch.setattr(cli, "MarketStore", lambda: DummyStore())
|
||||
monkeypatch.setattr(
|
||||
cli,
|
||||
"enrich_symbols",
|
||||
lambda store, symbols, start_date=None, end_date=None, limit=200, analysis_source="local", skip_existing=True: calls.append(
|
||||
("enrich_symbols", symbols, start_date, end_date, limit, analysis_source, skip_existing)
|
||||
) or [
|
||||
{
|
||||
"symbol": symbol,
|
||||
"news_count": 3,
|
||||
"queued_count": 3,
|
||||
"analyzed": 3,
|
||||
"skipped_existing_count": 0,
|
||||
"deduped_count": 0,
|
||||
"llm_count": 0,
|
||||
"local_count": 3,
|
||||
}
|
||||
for symbol in symbols
|
||||
],
|
||||
)
|
||||
|
||||
cli.ingest_enrich(
|
||||
tickers=None,
|
||||
start="2026-03-01",
|
||||
end="2026-03-10",
|
||||
limit=150,
|
||||
force=False,
|
||||
config_name="smoke_fullstack",
|
||||
)
|
||||
|
||||
assert calls == [
|
||||
("enrich_symbols", ["AAPL", "MSFT"], "2026-03-01", "2026-03-10", 150, "local", True)
|
||||
]
|
||||
|
||||
|
||||
def test_ingest_report_reads_market_store_report(monkeypatch):
|
||||
calls = []
|
||||
printed = []
|
||||
|
||||
monkeypatch.setattr(cli, "_resolve_symbols", lambda raw_tickers, config_name=None: ["AAPL"])
|
||||
|
||||
class DummyStore:
|
||||
def get_enrich_report(self, symbols=None, start_date=None, end_date=None):
|
||||
calls.append(("get_enrich_report", symbols, start_date, end_date))
|
||||
return [
|
||||
{
|
||||
"symbol": "AAPL",
|
||||
"raw_news_count": 10,
|
||||
"analyzed_news_count": 8,
|
||||
"coverage_pct": 80.0,
|
||||
"llm_count": 5,
|
||||
"local_count": 3,
|
||||
"latest_trade_date": "2026-03-16",
|
||||
"latest_analysis_at": "2026-03-16T09:00:00",
|
||||
}
|
||||
]
|
||||
|
||||
monkeypatch.setattr(cli, "MarketStore", lambda: DummyStore())
|
||||
monkeypatch.setattr(cli, "get_explain_model_info", lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"})
|
||||
monkeypatch.setattr(cli, "llm_enrichment_enabled", lambda: True)
|
||||
monkeypatch.setattr(cli.console, "print", lambda value: printed.append(value))
|
||||
|
||||
cli.ingest_report(
|
||||
tickers=None,
|
||||
start="2026-03-01",
|
||||
end="2026-03-16",
|
||||
config_name="smoke_fullstack",
|
||||
only_problematic=False,
|
||||
)
|
||||
|
||||
assert calls == [
|
||||
("get_enrich_report", ["AAPL"], "2026-03-01", "2026-03-16")
|
||||
]
|
||||
assert printed
|
||||
assert getattr(printed[0], "caption", "") == "Explain LLM: DASHSCOPE:qwen-max"
|
||||
|
||||
|
||||
def test_filter_problematic_report_rows_keeps_low_coverage_and_no_llm():
|
||||
rows = [
|
||||
{
|
||||
"symbol": "AAPL",
|
||||
"coverage_pct": 100.0,
|
||||
"llm_count": 2,
|
||||
},
|
||||
{
|
||||
"symbol": "MSFT",
|
||||
"coverage_pct": 80.0,
|
||||
"llm_count": 1,
|
||||
},
|
||||
{
|
||||
"symbol": "NVDA",
|
||||
"coverage_pct": 100.0,
|
||||
"llm_count": 0,
|
||||
},
|
||||
]
|
||||
|
||||
filtered = cli._filter_problematic_report_rows(rows)
|
||||
|
||||
assert [row["symbol"] for row in filtered] == ["MSFT", "NVDA"]
|
||||
backend/tests/test_data_config.py  (new file, 55 lines)
@@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
"""Tests for data source config ordering."""

from backend.config.data_config import get_config, reset_config


def test_data_config_prefers_env_source(monkeypatch):
    monkeypatch.setenv("FIN_DATA_SOURCE", "financial_datasets")
    monkeypatch.setenv("FINNHUB_API_KEY", "fh")
    monkeypatch.setenv("FINANCIAL_DATASETS_API_KEY", "fd")
    reset_config()

    config = get_config()

    assert config.sources[0] == "financial_datasets"
    assert "local_csv" in config.sources


def test_enabled_data_sources_filters_available_sources(monkeypatch):
    monkeypatch.setenv("FINNHUB_API_KEY", "fh-key")
    monkeypatch.setenv("FINANCIAL_DATASETS_API_KEY", "fd-key")
    monkeypatch.setenv("ENABLED_DATA_SOURCES", "financial_datasets,local_csv")
    monkeypatch.delenv("FIN_DATA_SOURCE", raising=False)
    reset_config()

    config = get_config()

    assert config.sources == ["financial_datasets", "local_csv"]
    assert config.source == "financial_datasets"


def test_preferred_source_reorders_enabled_sources(monkeypatch):
    monkeypatch.setenv("FINNHUB_API_KEY", "fh-key")
    monkeypatch.setenv("FINANCIAL_DATASETS_API_KEY", "fd-key")
    monkeypatch.setenv("ENABLED_DATA_SOURCES", "financial_datasets,finnhub,local_csv")
    monkeypatch.setenv("FIN_DATA_SOURCE", "finnhub")
    reset_config()

    config = get_config()

    assert config.sources == ["finnhub", "financial_datasets", "local_csv"]
    assert config.source == "finnhub"


def test_yfinance_can_be_enabled_without_api_key(monkeypatch):
    monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
    monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
    monkeypatch.setenv("FIN_DATA_SOURCE", "yfinance")
    monkeypatch.setenv("ENABLED_DATA_SOURCES", "yfinance,local_csv")
    reset_config()

    config = get_config()

    assert config.sources == ["yfinance", "local_csv"]
    assert config.source == "yfinance"
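Note: get_config/reset_config are not part of this diff. A sketch of the ordering rules these tests encode, under the assumption that ENABLED_DATA_SOURCES filters the candidate list, FIN_DATA_SOURCE moves the preferred source to the front, and local_csv remains available as a fallback (the helper name below is hypothetical):

# Hypothetical ordering helper, inferred only from the assertions above.
import os


def _ordered_sources() -> list[str]:
    enabled = os.getenv("ENABLED_DATA_SOURCES", "")
    sources = [s.strip() for s in enabled.split(",") if s.strip()] or ["finnhub", "local_csv"]
    preferred = os.getenv("FIN_DATA_SOURCE")
    if preferred:
        # The preferred source goes first; the remaining sources keep their order.
        sources = [preferred] + [s for s in sources if s != preferred]
    if "local_csv" not in sources:
        sources.append("local_csv")
    return sources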
backend/tests/test_data_tools_service_routing.py  (new file, 139 lines)
@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""Tests for data_tools preferring split services when configured."""

from backend.tools import data_tools
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price


def test_data_tools_prefers_trading_service(monkeypatch):
    monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
    monkeypatch.setenv("SERVICE_NAME", "agent_service")
    monkeypatch.setattr(data_tools._cache, "get_prices", lambda key: None)
    monkeypatch.setattr(data_tools._cache, "get_financial_metrics", lambda key: None)
    monkeypatch.setattr(data_tools._cache, "get_insider_trades", lambda key: None)
    monkeypatch.setattr(data_tools._cache, "get_company_news", lambda key: None)

    def fake_service_get_json(base_url, path, *, params):
        if path == "/api/prices":
            return {
                "ticker": "AAPL",
                "prices": [
                    Price(
                        open=1,
                        close=2,
                        high=3,
                        low=1,
                        volume=10,
                        time="2026-03-16",
                    ).model_dump()
                ],
            }
        if path == "/api/financials":
            return {
                "financial_metrics": [
                    FinancialMetrics(
                        ticker="AAPL",
                        report_period="2026-03-16",
                        period="ttm",
                        currency="USD",
                        market_cap=123.0,
                        enterprise_value=None,
                        price_to_earnings_ratio=None,
                        price_to_book_ratio=None,
                        price_to_sales_ratio=None,
                        enterprise_value_to_ebitda_ratio=None,
                        enterprise_value_to_revenue_ratio=None,
                        free_cash_flow_yield=None,
                        peg_ratio=None,
                        gross_margin=None,
                        operating_margin=None,
                        net_margin=None,
                        return_on_equity=None,
                        return_on_assets=None,
                        return_on_invested_capital=None,
                        asset_turnover=None,
                        inventory_turnover=None,
                        receivables_turnover=None,
                        days_sales_outstanding=None,
                        operating_cycle=None,
                        working_capital_turnover=None,
                        current_ratio=None,
                        quick_ratio=None,
                        cash_ratio=None,
                        operating_cash_flow_ratio=None,
                        debt_to_equity=None,
                        debt_to_assets=None,
                        interest_coverage=None,
                        revenue_growth=None,
                        earnings_growth=None,
                        book_value_growth=None,
                        earnings_per_share_growth=None,
                        free_cash_flow_growth=None,
                        operating_income_growth=None,
                        ebitda_growth=None,
                        payout_ratio=None,
                        earnings_per_share=None,
                        book_value_per_share=None,
                        free_cash_flow_per_share=None,
                    ).model_dump()
                ]
            }
        if path == "/api/insider-trades":
            return {
                "insider_trades": [
                    InsiderTrade(ticker="AAPL", filing_date="2026-03-16").model_dump()
                ]
            }
        if path == "/api/news":
            return {
                "news": [
                    CompanyNews(
                        ticker="AAPL",
                        title="Title",
                        source="polygon",
                        url="https://example.com",
                    ).model_dump()
                ]
            }
        if path == "/api/market-cap":
            return {"ticker": "AAPL", "end_date": "2026-03-16", "market_cap": 2.5e12}
        if path == "/api/line-items":
            return {
                "search_results": [
                    LineItem(
                        ticker="AAPL",
                        report_period="2026-03-16",
                        period="ttm",
                        currency="USD",
                        free_cash_flow=321.0,
                    ).model_dump()
                ]
            }
        raise AssertionError(path)

    monkeypatch.setattr(data_tools, "_service_get_json", fake_service_get_json)

    prices = data_tools.get_prices("AAPL", "2026-03-01", "2026-03-16")
    metrics = data_tools.get_financial_metrics("AAPL", "2026-03-16")
    trades = data_tools.get_insider_trades("AAPL", "2026-03-16")
    news = data_tools.get_company_news("AAPL", "2026-03-16")
    market_cap = data_tools.get_market_cap("AAPL", "2026-03-16")
    line_items = data_tools.search_line_items(
        "AAPL",
        ["free_cash_flow"],
        "2026-03-16",
    )

    assert prices[0].close == 2
    assert metrics[0].ticker == "AAPL"
    assert trades[0].ticker == "AAPL"
    assert news[0].ticker == "AAPL"
    assert market_cap == 2.5e12
    assert line_items[0].free_cash_flow == 321.0


def test_data_tools_skips_self_recursion_for_trading_service(monkeypatch):
    monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
    monkeypatch.setenv("SERVICE_NAME", "trading_service")

    assert data_tools._trading_service_url() is None
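Note: the guard exercised by the second test is not shown in this commit; a sketch of the self-recursion check it implies (assumed implementation):

# Assumed shape of backend.tools.data_tools._trading_service_url, inferred from the test.
import os


def _trading_service_url() -> str | None:
    url = os.getenv("TRADING_SERVICE_URL")
    # When data_tools runs inside the trading service itself, calling the service
    # URL would loop back into this process, so fall back to direct data access.
    if not url or os.getenv("SERVICE_NAME") == "trading_service":
        return None
    return url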
backend/tests/test_env_config.py  (new file, 25 lines)
@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
"""Tests for normalized env config helpers."""

from backend.config.env_config import (
    canonicalize_model_provider,
    get_agent_model_config,
)


def test_canonicalize_model_provider_aliases():
    assert canonicalize_model_provider("claude") == "ANTHROPIC"
    assert canonicalize_model_provider("openai_compatible") == "OPENAI"
    assert canonicalize_model_provider("google") == "GEMINI"


def test_get_agent_model_config_fallback(monkeypatch):
    monkeypatch.delenv("AGENT_RISK_MANAGER_MODEL_NAME", raising=False)
    monkeypatch.delenv("AGENT_RISK_MANAGER_MODEL_PROVIDER", raising=False)
    monkeypatch.setenv("MODEL_NAME", "gpt-4o-mini")
    monkeypatch.setenv("MODEL_PROVIDER", "openai")

    config = get_agent_model_config("risk_manager")

    assert config.model_name == "gpt-4o-mini"
    assert config.provider == "OPENAI"
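Note: a minimal sketch of the alias normalisation these assertions imply; only the "claude", "openai_compatible", and "google" mappings are confirmed by the tests above, the rest of the table is an assumption:

# Hypothetical alias table for canonicalize_model_provider.
_PROVIDER_ALIASES = {
    "claude": "ANTHROPIC",
    "anthropic": "ANTHROPIC",
    "openai_compatible": "OPENAI",
    "openai": "OPENAI",
    "google": "GEMINI",
    "gemini": "GEMINI",
}


def canonicalize_model_provider(provider: str) -> str:
    key = provider.strip().lower()
    # Unknown providers pass through upper-cased rather than raising.
    return _PROVIDER_ALIASES.get(key, provider.strip().upper())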
backend/tests/test_gateway_explain_handlers.py  (new file, 934 lines)
@@ -0,0 +1,934 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.services.gateway import Gateway
|
||||
import backend.services.gateway as gateway_module
|
||||
from shared.schema import InsiderTrade, InsiderTradeResponse, Price, PriceResponse
|
||||
|
||||
|
||||
class DummyWebSocket:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
async def send(self, payload: str):
|
||||
self.messages.append(json.loads(payload))
|
||||
|
||||
|
||||
class DummyStateSync:
|
||||
def __init__(self, current_date="2026-03-16"):
|
||||
self.state = {"current_date": current_date}
|
||||
self.system_messages = []
|
||||
|
||||
def set_broadcast_fn(self, _fn):
|
||||
return None
|
||||
|
||||
def update_state(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
async def on_system_message(self, message):
|
||||
self.system_messages.append(message)
|
||||
|
||||
|
||||
class FakeMarketStore:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def get_ticker_watermarks(self, symbol):
|
||||
self.calls.append(("get_ticker_watermarks", symbol))
|
||||
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
|
||||
|
||||
def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
|
||||
self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
|
||||
return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}]
|
||||
|
||||
def get_news_items(self, symbol, *, start_date=None, end_date=None, limit=100):
|
||||
self.calls.append(("get_news_items", symbol, start_date, end_date, limit))
|
||||
return [
|
||||
{
|
||||
"id": "news-1",
|
||||
"ticker": symbol,
|
||||
"date": end_date,
|
||||
"trade_date": end_date,
|
||||
"title": "Title",
|
||||
"summary": "Summary",
|
||||
"source": "polygon",
|
||||
}
|
||||
]
|
||||
|
||||
def get_news_items_enriched(self, symbol, *, start_date=None, end_date=None, trade_date=None, limit=100):
|
||||
self.calls.append(("get_news_items_enriched", symbol, start_date, end_date, trade_date, limit))
|
||||
target_date = trade_date or end_date
|
||||
return [
|
||||
{
|
||||
"id": "news-1",
|
||||
"ticker": symbol,
|
||||
"date": target_date,
|
||||
"trade_date": target_date,
|
||||
"title": "Title",
|
||||
"summary": "Summary",
|
||||
"source": "polygon",
|
||||
"sentiment": "negative",
|
||||
"relevance": "high",
|
||||
"key_discussion": "Key discussion",
|
||||
}
|
||||
]
|
||||
|
||||
def get_news_by_ids_enriched(self, symbol, article_ids):
|
||||
self.calls.append(("get_news_by_ids_enriched", symbol, list(article_ids)))
|
||||
return [{"id": article_ids[0], "ticker": symbol, "date": "2026-03-16", "sentiment": "negative"}]
|
||||
|
||||
def get_news_categories_enriched(self, symbol, *, start_date=None, end_date=None, limit=200):
|
||||
self.calls.append(("get_news_categories_enriched", symbol, start_date, end_date, limit))
|
||||
return {"macro": {"label": "宏观", "count": 1, "article_ids": ["news-1"], "positive_ids": [], "negative_ids": ["news-1"], "neutral_ids": []}}
|
||||
|
||||
def get_story_cache(self, symbol, *, as_of_date):
|
||||
self.calls.append(("get_story_cache", symbol, as_of_date))
|
||||
return None
|
||||
|
||||
def upsert_story_cache(self, symbol, *, as_of_date, content, source="local"):
|
||||
self.calls.append(("upsert_story_cache", symbol, as_of_date, source))
|
||||
|
||||
def delete_story_cache(self, symbol, *, as_of_date=None):
|
||||
self.calls.append(("delete_story_cache", symbol, as_of_date))
|
||||
return 1
|
||||
|
||||
def get_similar_day_cache(self, symbol, *, target_date):
|
||||
self.calls.append(("get_similar_day_cache", symbol, target_date))
|
||||
return None
|
||||
|
||||
def upsert_similar_day_cache(self, symbol, *, target_date, payload, source="local"):
|
||||
self.calls.append(("upsert_similar_day_cache", symbol, target_date, source))
|
||||
|
||||
def delete_similar_day_cache(self, symbol, *, target_date=None):
|
||||
self.calls.append(("delete_similar_day_cache", symbol, target_date))
|
||||
return 1
|
||||
|
||||
def get_ohlc(self, symbol, start_date, end_date):
|
||||
self.calls.append(("get_ohlc", symbol, start_date, end_date))
|
||||
return [
|
||||
{"date": start_date, "open": 100, "high": 105, "low": 99, "close": 103},
|
||||
{"date": end_date, "open": 103, "high": 108, "low": 102, "close": 107},
|
||||
]
|
||||
|
||||
|
||||
def make_gateway(market_store=None):
|
||||
storage = SimpleNamespace(market_store=market_store or FakeMarketStore())
|
||||
pipeline = SimpleNamespace(state_sync=None)
|
||||
market_service = SimpleNamespace()
|
||||
state_sync = DummyStateSync()
|
||||
return Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage,
|
||||
pipeline=pipeline,
|
||||
state_sync=state_sync,
|
||||
config={"mode": "live"},
|
||||
)
|
||||
|
||||
|
||||
class FakeNewsClient:
|
||||
def __init__(self, base_url):
|
||||
self.base_url = base_url
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
async def get_categories(self, ticker, start_date=None, end_date=None, limit=200):
|
||||
return {"ticker": ticker, "categories": {"remote": {"count": 2}}}
|
||||
|
||||
async def get_enriched_news(self, ticker, start_date=None, end_date=None, limit=None):
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"news": [
|
||||
{
|
||||
"id": "remote-news-1",
|
||||
"ticker": ticker,
|
||||
"title": "Remote Title",
|
||||
"date": end_date,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
async def get_story(self, ticker, as_of_date):
|
||||
return {"symbol": ticker, "as_of_date": as_of_date, "story": "remote story", "source": "news_service"}
|
||||
|
||||
|
||||
class FakeTradingClient:
|
||||
def __init__(self, base_url):
|
||||
self.base_url = base_url
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
async def get_insider_trades(self, ticker, end_date=None, start_date=None, limit=None):
|
||||
return InsiderTradeResponse(
|
||||
insider_trades=[
|
||||
InsiderTrade(
|
||||
ticker=ticker,
|
||||
name="Remote Insider",
|
||||
filing_date=end_date or "2026-03-16",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def get_prices(self, ticker, start_date=None, end_date=None):
|
||||
prices = [
|
||||
Price(
|
||||
open=float(100 + idx),
|
||||
close=float(101 + idx),
|
||||
high=float(102 + idx),
|
||||
low=float(99 + idx),
|
||||
volume=1000 + idx,
|
||||
time=f"2026-01-{idx + 1:02d}",
|
||||
)
|
||||
for idx in range(30)
|
||||
]
|
||||
return PriceResponse(ticker=ticker, prices=prices)
|
||||
|
||||
async def get_market_cap(self, ticker, end_date):
|
||||
return {"ticker": ticker, "end_date": end_date, "market_cap": 2.5e12}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
await gateway._handle_get_stock_news_timeline(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "lookback_days": 30},
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")
|
||||
]
|
||||
assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded"
|
||||
assert websocket.messages[-1]["ticker"] == "AAPL"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_news_categories_uses_market_store_symbol_argument(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
await gateway._handle_get_stock_news_categories(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "lookback_days": 30},
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
|
||||
("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200)
|
||||
]
|
||||
assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
|
||||
assert websocket.messages[-1]["categories"]["macro"]["count"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
def fake_build_range_explanation(*, ticker, start_date, end_date, news_rows):
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"news_count": len(news_rows),
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module.news_domain,
|
||||
"build_range_explanation",
|
||||
fake_build_range_explanation,
|
||||
)
|
||||
|
||||
await gateway._handle_get_stock_range_explain(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "start_date": "2026-03-10", "end_date": "2026-03-16"},
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100)
|
||||
]
|
||||
assert websocket.messages[-1] == {
|
||||
"type": "stock_range_explain_loaded",
|
||||
"ticker": "AAPL",
|
||||
"result": {
|
||||
"ticker": "AAPL",
|
||||
"start_date": "2026-03-10",
|
||||
"end_date": "2026-03-16",
|
||||
"news_count": 1,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module.news_domain,
|
||||
"build_range_explanation",
|
||||
lambda **kwargs: {"news_count": len(kwargs["news_rows"])},
|
||||
)
|
||||
|
||||
await gateway._handle_get_stock_range_explain(
|
||||
websocket,
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"start_date": "2026-03-10",
|
||||
"end_date": "2026-03-16",
|
||||
"article_ids": ["news-99"],
|
||||
},
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_by_ids_enriched", "AAPL", ["news-99"])
|
||||
]
|
||||
assert websocket.messages[-1]["result"]["news_count"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
await gateway._handle_get_stock_news_for_date(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "date": "2026-03-16", "limit": 10},
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
|
||||
]
|
||||
assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded"
|
||||
assert websocket.messages[-1]["date"] == "2026-03-16"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module.news_domain,
|
||||
"enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
|
||||
)
|
||||
|
||||
await gateway._handle_get_stock_story(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "as_of_date": "2026-03-16"},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "stock_story_loaded"
|
||||
assert websocket.messages[-1]["ticker"] == "AAPL"
|
||||
assert "AAPL Story" in websocket.messages[-1]["story"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_news_categories_uses_news_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
|
||||
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
|
||||
|
||||
await gateway._handle_get_stock_news_categories(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "lookback_days": 30},
|
||||
)
|
||||
|
||||
assert market_store.calls == []
|
||||
assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
|
||||
assert websocket.messages[-1]["categories"]["remote"]["count"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_story_uses_news_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
|
||||
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
|
||||
|
||||
await gateway._handle_get_stock_story(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "as_of_date": "2026-03-16"},
|
||||
)
|
||||
|
||||
assert market_store.calls == []
|
||||
assert websocket.messages[-1]["type"] == "stock_story_loaded"
|
||||
assert websocket.messages[-1]["story"] == "remote story"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_news_uses_news_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
|
||||
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
|
||||
|
||||
await gateway._handle_get_stock_news(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "lookback_days": 30, "limit": 5},
|
||||
)
|
||||
|
||||
assert market_store.calls == []
|
||||
assert websocket.messages[-1]["type"] == "stock_news_loaded"
|
||||
assert websocket.messages[-1]["source"] == "news_service"
|
||||
assert websocket.messages[-1]["news"][0]["title"] == "Remote Title"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_insider_trades_uses_trading_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
|
||||
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
|
||||
|
||||
await gateway._handle_get_stock_insider_trades(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "end_date": "2026-03-16", "limit": 10},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "stock_insider_trades_loaded"
|
||||
assert websocket.messages[-1]["trades"][0]["name"] == "Remote Insider"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_history_uses_trading_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
|
||||
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
|
||||
|
||||
await gateway._handle_get_stock_history(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "lookback_days": 30},
|
||||
)
|
||||
|
||||
assert market_store.calls == []
|
||||
assert websocket.messages[-1]["type"] == "stock_history_loaded"
|
||||
assert websocket.messages[-1]["source"] == "trading_service"
|
||||
assert len(websocket.messages[-1]["prices"]) == 30
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_technical_indicators_uses_trading_service_client_when_configured(monkeypatch):
|
||||
gateway = make_gateway(FakeMarketStore())
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
|
||||
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
|
||||
|
||||
await gateway._handle_get_stock_technical_indicators(
|
||||
websocket,
|
||||
{"ticker": "AAPL"},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "stock_technical_indicators_loaded"
|
||||
assert websocket.messages[-1]["ticker"] == "AAPL"
|
||||
assert websocket.messages[-1]["indicators"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_market_caps_uses_trading_service_client_when_configured(monkeypatch):
|
||||
gateway = make_gateway(FakeMarketStore())
|
||||
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
|
||||
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
|
||||
|
||||
market_caps = await gateway._get_market_caps(["AAPL", "MSFT"], "2026-03-16")
|
||||
|
||||
assert market_caps == {"AAPL": 2.5e12, "MSFT": 2.5e12}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module.news_domain,
|
||||
"enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
|
||||
)
|
||||
|
||||
await gateway._handle_get_stock_similar_days(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "date": "2026-03-16", "top_k": 5},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "stock_similar_days_loaded"
|
||||
assert websocket.messages[-1]["ticker"] == "AAPL"
|
||||
assert isinstance(websocket.messages[-1]["items"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module.gateway_stock_handlers,
|
||||
"enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_module.news_domain,
|
||||
"enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
|
||||
)
|
||||
|
||||
await gateway._handle_run_stock_enrich(
|
||||
websocket,
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"start_date": "2026-03-10",
|
||||
"end_date": "2026-03-16",
|
||||
"force": True,
|
||||
"rebuild_story": True,
|
||||
"rebuild_similar_days": True,
|
||||
"story_date": "2026-03-16",
|
||||
"target_date": "2026-03-16",
|
||||
},
|
||||
)
|
||||
|
||||
assert ("delete_story_cache", "AAPL", "2026-03-16") in market_store.calls
|
||||
assert ("delete_similar_day_cache", "AAPL", "2026-03-16") in market_store.calls
|
||||
assert websocket.messages[-1]["type"] == "stock_enrich_completed"
|
||||
assert websocket.messages[-1]["stats"]["analyzed"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypatch):
|
||||
gateway = make_gateway(FakeMarketStore())
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(gateway_module.gateway_stock_handlers, "llm_enrichment_enabled", lambda: False)
|
||||
|
||||
await gateway._handle_run_stock_enrich(
|
||||
websocket,
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"start_date": "2026-03-10",
|
||||
"end_date": "2026-03-16",
|
||||
"only_local_to_llm": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "stock_enrich_completed"
|
||||
assert "requires EXPLAIN_ENRICH_USE_LLM=true" in websocket.messages[-1]["error"]
|
||||
|
||||
|
||||
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
|
||||
gateway = make_gateway()
|
||||
captured = {}
|
||||
|
||||
class DummyTask:
|
||||
def done(self):
|
||||
return False
|
||||
|
||||
def cancel(self):
|
||||
captured["cancelled"] = True
|
||||
|
||||
def fake_create_task(coro):
|
||||
captured["coro_name"] = coro.cr_code.co_name
|
||||
coro.close()
|
||||
return DummyTask()
|
||||
|
||||
monkeypatch.setattr(gateway_module.asyncio, "create_task", fake_create_task)
|
||||
|
||||
gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])
|
||||
|
||||
assert captured["coro_name"] == "refresh_market_store_for_watchlist"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
|
||||
gateway = make_gateway()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module.gateway_cycle_support,
|
||||
"ingest_symbols",
|
||||
lambda symbols, mode="incremental": [
|
||||
{"symbol": symbol, "prices": 3, "news": 4, "aligned": 4}
|
||||
for symbol in symbols
|
||||
],
|
||||
)
|
||||
|
||||
await gateway._refresh_market_store_for_watchlist(["AAPL", "MSFT"])
|
||||
|
||||
assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
|
||||
assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]
|
||||
assert "AAPL prices=3 news=4" in gateway.state_sync.system_messages[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_agent_skills_returns_statuses(tmp_path):
|
||||
builtin_root = tmp_path / "backend" / "skills" / "builtin"
|
||||
for name in ("risk_review", "extra_guard"):
|
||||
skill_dir = builtin_root / name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
f"---\nname: {name}\ndescription: {name} desc\n---\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
agent_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
(agent_dir / "agent.yaml").write_text(
|
||||
"enabled_skills:\n"
|
||||
" - extra_guard\n"
|
||||
"disabled_skills:\n"
|
||||
" - risk_review\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
await gateway._handle_get_agent_skills(
|
||||
websocket,
|
||||
{"agent_id": "risk_manager"},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "agent_skills_loaded"
|
||||
statuses = {
|
||||
row["skill_name"]: row["status"]
|
||||
for row in websocket.messages[-1]["skills"]
|
||||
}
|
||||
assert statuses["extra_guard"] == "enabled"
|
||||
assert statuses["risk_review"] == "disabled"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatch, tmp_path):
|
||||
agent_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager"
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
(agent_dir / "agent.yaml").write_text(
|
||||
"prompt_files:\n"
|
||||
" - SOUL.md\n"
|
||||
" - MEMORY.md\n"
|
||||
"active_tool_groups:\n"
|
||||
" - risk_ops\n"
|
||||
"disabled_tool_groups:\n"
|
||||
" - legacy_group\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module.gateway_admin_handlers,
|
||||
"load_agent_profiles",
|
||||
lambda: {"risk_manager": {"skills": ["risk_review"], "active_tool_groups": ["risk_ops", "legacy_group"]}},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_module.gateway_admin_handlers,
|
||||
"get_agent_model_info",
|
||||
lambda agent_id: ("gpt-4o-mini", "OPENAI"),
|
||||
)
|
||||
|
||||
class _Bootstrap:
|
||||
@staticmethod
|
||||
def agent_override(_agent_id):
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module.gateway_admin_handlers,
|
||||
"get_bootstrap_config_for_run",
|
||||
lambda project_root, config_name: _Bootstrap(),
|
||||
)
|
||||
|
||||
await gateway._handle_get_agent_profile(
|
||||
websocket,
|
||||
{"agent_id": "risk_manager"},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "agent_profile_loaded"
|
||||
profile = websocket.messages[-1]["profile"]
|
||||
assert profile["model_name"] == "gpt-4o-mini"
|
||||
assert profile["model_provider"] == "OPENAI"
|
||||
assert profile["prompt_files"] == ["SOUL.md", "MEMORY.md"]
|
||||
assert profile["active_tool_groups"] == ["risk_ops"]
|
||||
assert profile["disabled_tool_groups"] == ["legacy_group"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_skill_detail_returns_markdown_body(tmp_path):
|
||||
skill_dir = tmp_path / "backend" / "skills" / "builtin" / "risk_review"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: 风险审查\ndescription: 说明\nversion: 1.0.0\n---\n# 风险审查\n\n完整正文\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
await gateway._handle_get_skill_detail(
|
||||
websocket,
|
||||
{"skill_name": "risk_review"},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "skill_detail_loaded"
|
||||
assert websocket.messages[-1]["skill"]["name"] == "风险审查"
|
||||
assert websocket.messages[-1]["skill"]["version"] == "1.0.0"
|
||||
assert websocket.messages[-1]["skill"]["content"] == "# 风险审查\n\n完整正文"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_skill_detail_prefers_agent_local_skill(tmp_path):
|
||||
skill_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "skills" / "local" / "local_guard"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: 本地风控\ndescription: 本地说明\nversion: 1.0.0\n---\n# 本地风控\n\n本地正文\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
await gateway._handle_get_skill_detail(
|
||||
websocket,
|
||||
{"agent_id": "risk_manager", "skill_name": "local_guard"},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "skill_detail_loaded"
|
||||
assert websocket.messages[-1]["agent_id"] == "risk_manager"
|
||||
assert websocket.messages[-1]["skill"]["source"] == "local"
|
||||
assert websocket.messages[-1]["skill"]["content"] == "# 本地风控\n\n本地正文"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_update_agent_skill_persists_and_returns_refresh(monkeypatch, tmp_path):
|
||||
skill_dir = tmp_path / "backend" / "skills" / "builtin" / "extra_guard"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: extra_guard\ndescription: desc\n---\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
async def _noop_reload():
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
|
||||
|
||||
await gateway._handle_update_agent_skill(
|
||||
websocket,
|
||||
{
|
||||
"agent_id": "risk_manager",
|
||||
"skill_name": "extra_guard",
|
||||
"enabled": True,
|
||||
},
|
||||
)
|
||||
|
||||
assert websocket.messages[0]["type"] == "agent_skill_updated"
|
||||
assert websocket.messages[-1]["type"] == "agent_skills_loaded"
|
||||
agent_yaml = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "agent.yaml"
|
||||
assert "extra_guard" in agent_yaml.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_create_and_update_agent_local_skill(monkeypatch, tmp_path):
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
async def _noop_reload():
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
|
||||
|
||||
await gateway._handle_create_agent_local_skill(
|
||||
websocket,
|
||||
{"agent_id": "risk_manager", "skill_name": "local_guard"},
|
||||
)
|
||||
|
||||
assert websocket.messages[0]["type"] == "agent_local_skill_created"
|
||||
assert websocket.messages[1]["type"] == "agent_skills_loaded"
|
||||
assert websocket.messages[2]["type"] == "skill_detail_loaded"
|
||||
target = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "skills" / "local" / "local_guard" / "SKILL.md"
|
||||
assert target.exists()
|
||||
|
||||
websocket.messages.clear()
|
||||
await gateway._handle_update_agent_local_skill(
|
||||
websocket,
|
||||
{
|
||||
"agent_id": "risk_manager",
|
||||
"skill_name": "local_guard",
|
||||
"content": "---\nname: 本地风控\ndescription: 更新后\nversion: 1.0.0\n---\n# 本地风控\n\n更新正文\n",
|
||||
},
|
||||
)
|
||||
|
||||
assert websocket.messages[0]["type"] == "agent_local_skill_updated"
|
||||
assert websocket.messages[1]["type"] == "skill_detail_loaded"
|
||||
assert "更新正文" in target.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_delete_agent_local_skill(monkeypatch, tmp_path):
|
||||
skill_dir = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "skills" / "local" / "local_guard"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: 本地风控\ndescription: desc\nversion: 1.0.0\n---\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
agent_yaml = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "agent.yaml"
|
||||
agent_yaml.parent.mkdir(parents=True, exist_ok=True)
|
||||
agent_yaml.write_text(
|
||||
"enabled_skills:\n"
|
||||
" - local_guard\n"
|
||||
"disabled_skills:\n"
|
||||
" - local_guard\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
async def _noop_reload():
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
|
||||
|
||||
await gateway._handle_delete_agent_local_skill(
|
||||
websocket,
|
||||
{"agent_id": "risk_manager", "skill_name": "local_guard"},
|
||||
)
|
||||
|
||||
assert websocket.messages[0]["type"] == "agent_local_skill_deleted"
|
||||
assert websocket.messages[1]["type"] == "agent_skills_loaded"
|
||||
assert not skill_dir.exists()
|
||||
assert "local_guard" not in agent_yaml.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_remove_agent_skill_marks_disabled(monkeypatch, tmp_path):
|
||||
skill_dir = tmp_path / "backend" / "skills" / "builtin" / "risk_review"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\nname: 风险审查\ndescription: desc\nversion: 1.0.0\n---\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
async def _noop_reload():
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
|
||||
|
||||
await gateway._handle_remove_agent_skill(
|
||||
websocket,
|
||||
{"agent_id": "risk_manager", "skill_name": "risk_review"},
|
||||
)
|
||||
|
||||
assert websocket.messages[0]["type"] == "agent_skill_removed"
|
||||
assert websocket.messages[1]["type"] == "agent_skills_loaded"
|
||||
agent_yaml = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "agent.yaml"
|
||||
assert "risk_review" in agent_yaml.read_text(encoding="utf-8")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_agent_workspace_file_returns_content(tmp_path):
|
||||
file_path = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "SOUL.md"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
file_path.write_text("soul content", encoding="utf-8")
|
||||
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
await gateway._handle_get_agent_workspace_file(
|
||||
websocket,
|
||||
{"agent_id": "risk_manager", "filename": "SOUL.md"},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1] == {
|
||||
"type": "agent_workspace_file_loaded",
|
||||
"config_name": "demo",
|
||||
"agent_id": "risk_manager",
|
||||
"filename": "SOUL.md",
|
||||
"content": "soul content",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_update_agent_workspace_file_persists_and_returns_refresh(monkeypatch, tmp_path):
|
||||
gateway = make_gateway()
|
||||
gateway.config["config_name"] = "demo"
|
||||
gateway._project_root = tmp_path
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
async def _noop_reload():
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(gateway, "_handle_reload_runtime_assets", _noop_reload)
|
||||
|
||||
await gateway._handle_update_agent_workspace_file(
|
||||
websocket,
|
||||
{
|
||||
"agent_id": "risk_manager",
|
||||
"filename": "SOUL.md",
|
||||
"content": "updated soul",
|
||||
},
|
||||
)
|
||||
|
||||
assert websocket.messages[0]["type"] == "agent_workspace_file_updated"
|
||||
assert websocket.messages[-1]["type"] == "agent_workspace_file_loaded"
|
||||
target = tmp_path / "runs" / "demo" / "agents" / "risk_manager" / "SOUL.md"
|
||||
assert target.read_text(encoding="utf-8") == "updated soul"
|
||||
201
backend/tests/test_gateway_support_modules.py
Normal file
201
backend/tests/test_gateway_support_modules.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Direct tests for Gateway support modules."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.services import gateway_cycle_support, gateway_runtime_support
|
||||
|
||||
|
||||
class _DummyScheduler:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
|
||||
|
||||
class _DummyStateSync:
|
||||
def __init__(self):
|
||||
self.updated = []
|
||||
self.saved = False
|
||||
self.system_messages = []
|
||||
self.backtest_dates = []
|
||||
self.state = {}
|
||||
|
||||
def update_state(self, key, value):
|
||||
self.updated.append((key, value))
|
||||
self.state[key] = value
|
||||
|
||||
def save_state(self):
|
||||
self.saved = True
|
||||
|
||||
async def on_system_message(self, message):
|
||||
self.system_messages.append(message)
|
||||
|
||||
def set_backtest_dates(self, dates):
|
||||
self.backtest_dates = list(dates)
|
||||
|
||||
|
||||
class _DummyStorage:
|
||||
def __init__(self):
|
||||
self.initial_cash = 100000.0
|
||||
self.is_live_session_active = False
|
||||
self.server_state_updates = []
|
||||
|
||||
def can_apply_initial_cash(self):
|
||||
return True
|
||||
|
||||
def apply_initial_cash(self, value):
|
||||
self.initial_cash = value
|
||||
return True
|
||||
|
||||
def update_server_state_from_dashboard(self, state):
|
||||
self.server_state_updates.append(state)
|
||||
|
||||
def load_file(self, name):
|
||||
if name == "summary":
|
||||
return {"totalAssetValue": self.initial_cash}
|
||||
return []
|
||||
|
||||
def build_dashboard_snapshot_from_state(self, state):
|
||||
return {
|
||||
"summary": {"totalAssetValue": self.initial_cash},
|
||||
"holdings": [],
|
||||
"stats": {},
|
||||
"trades": [],
|
||||
"leaderboard": [],
|
||||
}
|
||||
|
||||
|
||||
class _DummyPM:
|
||||
def __init__(self):
|
||||
self.portfolio = {"margin_requirement": 0.0}
|
||||
|
||||
def apply_runtime_portfolio_config(self, margin_requirement=None, initial_cash=None):
|
||||
if margin_requirement is not None:
|
||||
self.portfolio["margin_requirement"] = margin_requirement
|
||||
return {"margin_requirement": True}
|
||||
|
||||
def can_apply_initial_cash(self):
|
||||
return True
|
||||
|
||||
|
||||
class _DummyMarketService:
|
||||
def __init__(self):
|
||||
self.updated = None
|
||||
self.stopped = False
|
||||
|
||||
def update_tickers(self, tickers):
|
||||
self.updated = list(tickers)
|
||||
return {"active": list(tickers), "added": list(tickers), "removed": []}
|
||||
|
||||
def stop(self):
|
||||
self.stopped = True
|
||||
|
||||
|
||||
def make_gateway_stub():
|
||||
pipeline = SimpleNamespace(max_comm_cycles=0, pm=_DummyPM())
|
||||
gateway = SimpleNamespace(
|
||||
market_service=_DummyMarketService(),
|
||||
pipeline=pipeline,
|
||||
scheduler=_DummyScheduler(),
|
||||
config={
|
||||
"tickers": ["AAPL"],
|
||||
"schedule_mode": "daily",
|
||||
"interval_minutes": 60,
|
||||
"trigger_time": "09:30",
|
||||
"enable_memory": False,
|
||||
},
|
||||
storage=_DummyStorage(),
|
||||
state_sync=_DummyStateSync(),
|
||||
_watchlist_ingest_task=None,
|
||||
_market_status_task=None,
|
||||
_backtest_task=None,
|
||||
_backtest_start_date=None,
|
||||
_backtest_end_date=None,
|
||||
_manual_cycle_task=None,
|
||||
)
|
||||
return gateway
|
||||
|
||||
|
||||
def test_normalize_watchlist_filters_invalid_and_dedupes():
|
||||
assert gateway_runtime_support.normalize_watchlist(["aapl", " AAPL ", "", "msft"]) == ["AAPL", "MSFT"]
|
||||
assert gateway_runtime_support.normalize_watchlist("aapl,msft") == ["AAPL", "MSFT"]
|
||||
|
||||
|
||||
def test_normalize_agent_workspace_filename_obeys_allowlist():
|
||||
allowlist = {"SOUL.md", "PROFILE.md"}
|
||||
assert gateway_runtime_support.normalize_agent_workspace_filename("SOUL.md", allowlist=allowlist) == "SOUL.md"
|
||||
assert gateway_runtime_support.normalize_agent_workspace_filename("README.md", allowlist=allowlist) is None
|
||||
|
||||
|
||||
def test_apply_runtime_config_updates_gateway_state():
|
||||
gateway = make_gateway_stub()
|
||||
|
||||
result = gateway_runtime_support.apply_runtime_config(
|
||||
gateway,
|
||||
{
|
||||
"tickers": ["MSFT", "NVDA"],
|
||||
"schedule_mode": "intraday",
|
||||
"interval_minutes": 30,
|
||||
"trigger_time": "10:30",
|
||||
"initial_cash": 150000.0,
|
||||
"margin_requirement": 0.5,
|
||||
"max_comm_cycles": 4,
|
||||
"enable_memory": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert gateway.config["tickers"] == ["MSFT", "NVDA"]
|
||||
assert gateway.config["schedule_mode"] == "intraday"
|
||||
assert gateway.storage.initial_cash == 150000.0
|
||||
assert result["runtime_config_applied"]["max_comm_cycles"] == 4
|
||||
assert gateway.scheduler.calls[-1] == {
|
||||
"mode": "intraday",
|
||||
"trigger_time": "10:30",
|
||||
"interval_minutes": 30,
|
||||
}
|
||||
|
||||
|
||||
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
|
||||
gateway = make_gateway_stub()
|
||||
captured = {}
|
||||
|
||||
class DummyTask:
|
||||
def done(self):
|
||||
return False
|
||||
|
||||
def cancel(self):
|
||||
captured["cancelled"] = True
|
||||
|
||||
def fake_create_task(coro):
|
||||
captured["name"] = coro.cr_code.co_name
|
||||
coro.close()
|
||||
return DummyTask()
|
||||
|
||||
monkeypatch.setattr(gateway_cycle_support.asyncio, "create_task", fake_create_task)
|
||||
|
||||
gateway_cycle_support.schedule_watchlist_market_store_refresh(gateway, ["AAPL", "MSFT"])
|
||||
|
||||
assert captured["name"] == "refresh_market_store_for_watchlist"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
|
||||
gateway = make_gateway_stub()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_cycle_support,
|
||||
"ingest_symbols",
|
||||
lambda symbols, mode="incremental": [
|
||||
{"symbol": symbol, "prices": 3, "news": 4}
|
||||
for symbol in symbols
|
||||
],
|
||||
)
|
||||
|
||||
await gateway_cycle_support.refresh_market_store_for_watchlist(gateway, ["AAPL", "MSFT"])
|
||||
|
||||
assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
|
||||
assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]
|
||||
65
backend/tests/test_historical_price_manager.py
Normal file
65
backend/tests/test_historical_price_manager.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from unittest.mock import patch
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from backend.data.historical_price_manager import HistoricalPriceManager
|
||||
|
||||
|
||||
def test_preload_data_prefers_market_db():
|
||||
manager = HistoricalPriceManager()
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
market_rows = [
|
||||
{
|
||||
"symbol": "AAPL",
|
||||
"date": "2026-03-09",
|
||||
"open": 100.0,
|
||||
"high": 103.0,
|
||||
"low": 99.0,
|
||||
"close": 102.0,
|
||||
"volume": 10_000,
|
||||
"vwap": 101.0,
|
||||
"transactions": 500,
|
||||
"source": "polygon",
|
||||
}
|
||||
]
|
||||
|
||||
with (
|
||||
patch.object(manager._market_store, "get_ohlc", return_value=market_rows),
|
||||
patch.object(manager._router, "load_local_price_frame") as load_csv,
|
||||
):
|
||||
manager.preload_data("2026-03-01", "2026-03-10")
|
||||
|
||||
load_csv.assert_not_called()
|
||||
assert "AAPL" in manager._price_cache
|
||||
assert float(manager._price_cache["AAPL"].iloc[0]["close"]) == 102.0
|
||||
|
||||
|
||||
def test_preload_data_falls_back_to_csv():
|
||||
manager = HistoricalPriceManager()
|
||||
manager.subscribe(["MSFT"])
|
||||
|
||||
csv_df = pd.DataFrame(
|
||||
{
|
||||
"time": ["2026-03-09"],
|
||||
"open": [200.0],
|
||||
"high": [205.0],
|
||||
"low": [198.0],
|
||||
"close": [204.0],
|
||||
"volume": [20_000],
|
||||
}
|
||||
)
|
||||
csv_df["time"] = pd.to_datetime(csv_df["time"])
|
||||
csv_df["Date"] = csv_df["time"]
|
||||
csv_df.set_index("Date", inplace=True)
|
||||
|
||||
with (
|
||||
patch.object(manager._market_store, "get_ohlc", return_value=[]),
|
||||
patch.object(manager._router, "load_local_price_frame", return_value=csv_df) as load_csv,
|
||||
):
|
||||
manager.preload_data("2026-03-01", "2026-03-10")
|
||||
|
||||
load_csv.assert_called_once_with("MSFT")
|
||||
assert "MSFT" in manager._price_cache
|
||||
assert float(manager._price_cache["MSFT"].iloc[0]["close"]) == 204.0
|
||||
133
backend/tests/test_llm_enricher.py
Normal file
133
backend/tests/test_llm_enricher.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from backend.enrich import llm_enricher
|
||||
|
||||
|
||||
class DummyResponse:
|
||||
def __init__(self, metadata):
|
||||
self.metadata = metadata
|
||||
|
||||
|
||||
class DummyModel:
|
||||
def __init__(self, metadata):
|
||||
self.metadata = metadata
|
||||
self.calls = []
|
||||
|
||||
async def __call__(self, messages, structured_model=None, **kwargs):
|
||||
self.calls.append(
|
||||
{
|
||||
"messages": messages,
|
||||
"structured_model": structured_model,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
)
|
||||
return DummyResponse(self.metadata)
|
||||
|
||||
|
||||
def test_analyze_news_row_with_llm_uses_agentscope_model(monkeypatch):
|
||||
model = DummyModel(
|
||||
{
|
||||
"id": "news-1",
|
||||
"relevance": "high",
|
||||
"sentiment": "positive",
|
||||
"key_discussion": "Demand remains resilient",
|
||||
"summary": "Structured summary",
|
||||
"reason_growth": "Orders improved",
|
||||
"reason_decrease": "",
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(llm_enricher, "llm_enrichment_enabled", lambda: True)
|
||||
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
|
||||
monkeypatch.setattr(
|
||||
llm_enricher,
|
||||
"get_explain_model_info",
|
||||
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
|
||||
)
|
||||
|
||||
result = llm_enricher.analyze_news_row_with_llm(
|
||||
{
|
||||
"id": "news-1",
|
||||
"title": "Apple expands AI features",
|
||||
"summary": "New devices and software updates were announced.",
|
||||
}
|
||||
)
|
||||
|
||||
assert result["sentiment"] == "positive"
|
||||
assert result["summary"] == "Structured summary"
|
||||
assert result["raw_json"]["model_label"] == "DASHSCOPE:qwen-max"
|
||||
assert model.calls
|
||||
assert model.calls[0]["structured_model"] is llm_enricher.EnrichedNewsItem
|
||||
|
||||
|
||||
def test_analyze_news_rows_with_llm_uses_agentscope_structured_batch(monkeypatch):
|
||||
model = DummyModel(
|
||||
{
|
||||
"items": [
|
||||
{
|
||||
"id": "news-1",
|
||||
"relevance": "high",
|
||||
"sentiment": "negative",
|
||||
"key_discussion": "Margin pressure",
|
||||
"summary": "Batch summary",
|
||||
"reason_growth": "",
|
||||
"reason_decrease": "Costs rose",
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(llm_enricher, "llm_enrichment_enabled", lambda: True)
|
||||
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
|
||||
monkeypatch.setattr(
|
||||
llm_enricher,
|
||||
"get_explain_model_info",
|
||||
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
|
||||
)
|
||||
|
||||
result = llm_enricher.analyze_news_rows_with_llm(
|
||||
[
|
||||
{
|
||||
"id": "news-1",
|
||||
"title": "Apple margins pressured",
|
||||
"summary": "Costs increased this quarter.",
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
assert result["news-1"]["sentiment"] == "negative"
|
||||
assert result["news-1"]["reason_decrease"] == "Costs rose"
|
||||
assert result["news-1"]["raw_json"]["model_label"] == "DASHSCOPE:qwen-max"
|
||||
assert model.calls
|
||||
assert model.calls[0]["structured_model"] is llm_enricher.EnrichedNewsBatch
|
||||
|
||||
|
||||
def test_analyze_range_with_llm_uses_agentscope_structured_output(monkeypatch):
|
||||
model = DummyModel(
|
||||
{
|
||||
"summary": "该股在区间内震荡下行,相关新闻主要集中在盈利预期和供应链扰动。",
|
||||
"trend_analysis": "前半段受利空新闻压制,后半段跌幅收敛。",
|
||||
"bullish_factors": ["估值消化后出现部分承接"],
|
||||
"bearish_factors": ["盈利预期下修", "供应链扰动持续"],
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr(llm_enricher, "llm_range_analysis_enabled", lambda: True)
|
||||
monkeypatch.setattr(llm_enricher, "_get_explain_model", lambda: model)
|
||||
monkeypatch.setattr(
|
||||
llm_enricher,
|
||||
"get_explain_model_info",
|
||||
lambda: {"provider": "DASHSCOPE", "model_name": "qwen-max", "label": "DASHSCOPE:qwen-max"},
|
||||
)
|
||||
|
||||
result = llm_enricher.analyze_range_with_llm(
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"start_date": "2026-03-10",
|
||||
"end_date": "2026-03-16",
|
||||
"price_change_pct": -3.42,
|
||||
}
|
||||
)
|
||||
|
||||
assert result["summary"].startswith("该股在区间内震荡下行")
|
||||
assert result["model_label"] == "DASHSCOPE:qwen-max"
|
||||
assert result["bearish_factors"] == ["盈利预期下修", "供应链扰动持续"]
|
||||
assert model.calls
|
||||
assert model.calls[0]["structured_model"] is llm_enricher.RangeAnalysisPayload
|
||||
81
backend/tests/test_market_ingest.py
Normal file
81
backend/tests/test_market_ingest.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for market ingest watermark handling."""
|
||||
|
||||
from backend.data import market_ingest
|
||||
|
||||
|
||||
class _FakeStore:
|
||||
def __init__(self, *, last_news_fetch=None, latest_news_date=None):
|
||||
self._watermarks = {
|
||||
"symbol": "AAPL",
|
||||
"last_price_fetch": None,
|
||||
"last_news_fetch": last_news_fetch,
|
||||
}
|
||||
self._latest_news_date = latest_news_date
|
||||
self.updated = []
|
||||
|
||||
def get_ticker_watermarks(self, symbol):
|
||||
return dict(self._watermarks)
|
||||
|
||||
def get_latest_news_date(self, symbol):
|
||||
return self._latest_news_date
|
||||
|
||||
def upsert_ticker(self, **kwargs):
|
||||
return None
|
||||
|
||||
def upsert_ohlc(self, symbol, rows, source="polygon"):
|
||||
return len(rows)
|
||||
|
||||
def upsert_news(self, symbol, rows, source="polygon"):
|
||||
return len(rows)
|
||||
|
||||
def update_fetch_watermark(self, **kwargs):
|
||||
self.updated.append(kwargs)
|
||||
|
||||
|
||||
def test_refresh_news_incremental_does_not_advance_watermark_without_news(monkeypatch):
|
||||
store = _FakeStore(last_news_fetch="2026-03-28", latest_news_date="2026-03-28")
|
||||
|
||||
monkeypatch.setattr(market_ingest, "fetch_ticker_details", lambda ticker: {"name": ticker, "sic_description": None, "active": True})
|
||||
|
||||
class _Router:
|
||||
def get_company_news(self, **kwargs):
|
||||
return [], "polygon"
|
||||
|
||||
monkeypatch.setattr(market_ingest, "DataProviderRouter", lambda: _Router())
|
||||
monkeypatch.setattr(market_ingest, "align_news_for_symbol", lambda store, ticker: 0)
|
||||
|
||||
result = market_ingest.refresh_news_incremental(
|
||||
"AAPL",
|
||||
end_date="2026-03-29",
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert result["start_news_date"] == "2026-03-29"
|
||||
assert result["news"] == 0
|
||||
assert store.updated[-1]["news_date"] is None
|
||||
|
||||
|
||||
def test_refresh_news_incremental_clamps_future_watermark_to_latest_stored_date(monkeypatch):
|
||||
store = _FakeStore(last_news_fetch="2026-03-30", latest_news_date="2026-03-28")
|
||||
captured = {}
|
||||
|
||||
monkeypatch.setattr(market_ingest, "fetch_ticker_details", lambda ticker: {"name": ticker, "sic_description": None, "active": True})
|
||||
|
||||
class _Router:
|
||||
def get_company_news(self, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return [], "polygon"
|
||||
|
||||
monkeypatch.setattr(market_ingest, "DataProviderRouter", lambda: _Router())
|
||||
monkeypatch.setattr(market_ingest, "align_news_for_symbol", lambda store, ticker: 0)
|
||||
|
||||
result = market_ingest.refresh_news_incremental(
|
||||
"AAPL",
|
||||
end_date="2026-03-29",
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert result["start_news_date"] == "2026-03-29"
|
||||
assert captured["start_date"] == "2026-03-29"
|
||||
assert captured["end_date"] == "2026-03-29"
|
||||
285
backend/tests/test_market_service.py
Normal file
285
backend/tests/test_market_service.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=W0212
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import pytest
|
||||
from backend.services.market import MarketService
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
from backend.llm.models import RetryChatModel
|
||||
|
||||
|
||||
class TestPollingPriceManager:
|
||||
def test_init(self):
|
||||
manager = PollingPriceManager(api_key="test_key", poll_interval=30)
|
||||
|
||||
assert manager.api_key == "test_key"
|
||||
assert manager.poll_interval == 30
|
||||
assert manager.provider == "finnhub"
|
||||
assert manager.running is False
|
||||
|
||||
def test_init_yfinance(self):
|
||||
manager = PollingPriceManager(provider="yfinance", poll_interval=15)
|
||||
|
||||
assert manager.api_key is None
|
||||
assert manager.poll_interval == 15
|
||||
assert manager.provider == "yfinance"
|
||||
assert manager.running is False
|
||||
|
||||
def test_subscribe(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.subscribe(["AAPL", "MSFT"])
|
||||
|
||||
assert "AAPL" in manager.subscribed_symbols
|
||||
assert "MSFT" in manager.subscribed_symbols
|
||||
|
||||
def test_unsubscribe(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.subscribe(["AAPL", "MSFT"])
|
||||
manager.unsubscribe(["AAPL"])
|
||||
|
||||
assert "AAPL" not in manager.subscribed_symbols
|
||||
assert "MSFT" in manager.subscribed_symbols
|
||||
|
||||
def test_add_price_callback(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
callback = MagicMock()
|
||||
manager.add_price_callback(callback)
|
||||
|
||||
assert callback in manager.price_callbacks
|
||||
|
||||
@patch.object(PollingPriceManager, "_fetch_prices")
|
||||
def test_start_stop(self, _mock_fetch_prices):
|
||||
manager = PollingPriceManager(api_key="test_key", poll_interval=1)
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
manager.start()
|
||||
assert manager.running is True
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
manager.stop()
|
||||
assert manager.running is False
|
||||
|
||||
def test_start_without_subscription(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.start()
|
||||
|
||||
assert manager.running is False
|
||||
|
||||
def test_get_latest_price(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.latest_prices["AAPL"] = 150.0
|
||||
|
||||
price = manager.get_latest_price("AAPL")
|
||||
assert price == 150.0
|
||||
|
||||
def test_get_open_price(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.open_prices["AAPL"] = 148.0
|
||||
|
||||
price = manager.get_open_price("AAPL")
|
||||
assert price == 148.0
|
||||
|
||||
def test_reset_open_prices(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.open_prices["AAPL"] = 150.0
|
||||
|
||||
manager.reset_open_prices()
|
||||
|
||||
assert len(manager.open_prices) == 0
|
||||
|
||||
def test_fetch_prices_suppresses_repeated_failures(self, caplog):
|
||||
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
with patch.object(manager, "_fetch_quote", side_effect=ValueError("empty quote")):
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
for _ in range(3):
|
||||
manager._fetch_prices()
|
||||
|
||||
assert manager._failure_counts["AAPL"] == 3
|
||||
warning_messages = [record.message for record in caplog.records if record.levelno >= logging.WARNING]
|
||||
assert any("Failed to fetch AAPL price: empty quote" in message for message in warning_messages)
|
||||
|
||||
def test_fetch_prices_logs_recovery_after_failure(self, caplog):
|
||||
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
with patch.object(
|
||||
manager,
|
||||
"_fetch_quote",
|
||||
side_effect=[
|
||||
ValueError("temporary outage"),
|
||||
{"c": 100.0, "o": 99.0, "h": 101.0, "l": 98.0, "pc": 99.5, "d": 0.5, "dp": 0.5, "t": 1},
|
||||
],
|
||||
):
|
||||
with caplog.at_level(logging.INFO):
|
||||
manager._fetch_prices()
|
||||
manager._fetch_prices()
|
||||
|
||||
assert "AAPL" not in manager._failure_counts
|
||||
assert any("recovered after 1 consecutive failures" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
class TestRetryChatModel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_retry_recovers_from_disconnect(self):
|
||||
attempts = {"count": 0}
|
||||
|
||||
class FakeAsyncModel:
|
||||
model_name = "fake-async-model"
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
attempts["count"] += 1
|
||||
if attempts["count"] < 2:
|
||||
raise RuntimeError("Server disconnected")
|
||||
return {"ok": True}
|
||||
|
||||
wrapped = RetryChatModel(FakeAsyncModel(), max_retries=2, initial_delay=0.01)
|
||||
result = await wrapped("hello")
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert attempts["count"] == 2
|
||||
|
||||
|
||||
class TestMarketService:
|
||||
@patch("backend.services.market.get_data_sources", return_value=["yfinance", "local_csv"])
|
||||
@patch.object(PollingPriceManager, "start")
|
||||
def test_start_real_mode_with_yfinance(self, _mock_start, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=10,
|
||||
)
|
||||
|
||||
service._start_real_mode()
|
||||
|
||||
assert isinstance(service._price_manager, PollingPriceManager)
|
||||
assert service._price_manager.provider == "yfinance"
|
||||
|
||||
@patch("backend.services.market.get_data_sources", return_value=["financial_datasets", "yfinance", "local_csv"])
|
||||
@patch.object(PollingPriceManager, "start")
|
||||
def test_start_real_mode_uses_first_supported_live_provider(self, _mock_start, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=10,
|
||||
)
|
||||
|
||||
service._start_real_mode()
|
||||
|
||||
assert isinstance(service._price_manager, PollingPriceManager)
|
||||
assert service._price_manager.provider == "yfinance"
|
||||
|
||||
@patch("backend.services.market.get_data_sources", return_value=["finnhub", "yfinance"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_real_mode_without_api_key(self, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
broadcast_func = AsyncMock()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
await service.start(broadcast_func)
|
||||
|
||||
assert "API key required" in str(excinfo.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_already_running(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
backtest_mode=True,
|
||||
)
|
||||
|
||||
broadcast_func = AsyncMock()
|
||||
|
||||
# First start with backtest mode
|
||||
await service.start(broadcast_func)
|
||||
assert service.running is True
|
||||
|
||||
# Start again should not fail
|
||||
await service.start(broadcast_func)
|
||||
|
||||
service.stop()
|
||||
|
||||
def test_stop(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
backtest_mode=True,
|
||||
)
|
||||
service.running = True
|
||||
service._price_manager = MagicMock()
|
||||
|
||||
service.stop()
|
||||
|
||||
assert service.running is False
|
||||
assert service._price_manager is None
|
||||
|
||||
def test_stop_when_not_running(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
backtest_mode=True,
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
service.stop()
|
||||
assert service.running is False
|
||||
|
||||
def test_get_price_sync(self):
|
||||
service = MarketService(tickers=["AAPL"], backtest_mode=True)
|
||||
service.cache["AAPL"] = {"price": 150.0, "open": 148.0}
|
||||
|
||||
price = service.get_price_sync("AAPL")
|
||||
assert price == 150.0
|
||||
|
||||
def test_get_price_sync_not_found(self):
|
||||
service = MarketService(tickers=["AAPL"], backtest_mode=True)
|
||||
|
||||
price = service.get_price_sync("MSFT")
|
||||
assert price is None
|
||||
|
||||
def test_get_all_prices(self):
|
||||
service = MarketService(tickers=["AAPL", "MSFT"], backtest_mode=True)
|
||||
service.cache["AAPL"] = {"price": 150.0}
|
||||
service.cache["MSFT"] = {"price": 400.0}
|
||||
|
||||
prices = service.get_all_prices()
|
||||
|
||||
assert prices["AAPL"] == 150.0
|
||||
assert prices["MSFT"] == 400.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_price_update(self):
|
||||
service = MarketService(tickers=["AAPL"], backtest_mode=True)
|
||||
service._broadcast_func = AsyncMock()
|
||||
|
||||
price_data = {
|
||||
"symbol": "AAPL",
|
||||
"price": 150.0,
|
||||
"open": 148.0,
|
||||
"timestamp": 1234567890,
|
||||
}
|
||||
|
||||
await service._broadcast_price_update(price_data)
|
||||
|
||||
service._broadcast_func.assert_called_once()
|
||||
call_args = service._broadcast_func.call_args[0][0]
|
||||
assert call_args["type"] == "price_update"
|
||||
assert call_args["symbol"] == "AAPL"
|
||||
assert call_args["price"] == 150.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_price_update_no_func(self):
|
||||
service = MarketService(tickers=["AAPL"], backtest_mode=True)
|
||||
service._broadcast_func = None
|
||||
|
||||
price_data = {"symbol": "AAPL", "price": 150.0, "open": 148.0}
|
||||
|
||||
# Should not raise
|
||||
await service._broadcast_price_update(price_data)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
54
backend/tests/test_market_store_report.py
Normal file
54
backend/tests/test_market_store_report.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
|
||||
|
||||
def test_get_enrich_report_summarizes_coverage(tmp_path: Path):
|
||||
store = MarketStore(tmp_path / "market_research.db")
|
||||
store.upsert_news(
|
||||
"AAPL",
|
||||
[
|
||||
{
|
||||
"id": "news-1",
|
||||
"published_utc": "2026-03-10T12:00:00Z",
|
||||
"title": "Apple earnings beat",
|
||||
"summary": "Revenue topped expectations",
|
||||
"tickers": ["AAPL"],
|
||||
},
|
||||
{
|
||||
"id": "news-2",
|
||||
"published_utc": "2026-03-11T12:00:00Z",
|
||||
"title": "Apple supply chain warning",
|
||||
"summary": "Outlook softened",
|
||||
"tickers": ["AAPL"],
|
||||
},
|
||||
],
|
||||
)
|
||||
store.set_trade_dates(
|
||||
[
|
||||
{"news_id": "news-1", "symbol": "AAPL", "trade_date": "2026-03-10"},
|
||||
{"news_id": "news-2", "symbol": "AAPL", "trade_date": "2026-03-11"},
|
||||
]
|
||||
)
|
||||
store.upsert_news_analysis(
|
||||
"AAPL",
|
||||
[
|
||||
{
|
||||
"news_id": "news-1",
|
||||
"trade_date": "2026-03-10",
|
||||
"summary": "LLM enriched",
|
||||
"analysis_source": "llm",
|
||||
}
|
||||
],
|
||||
analysis_source="llm",
|
||||
)
|
||||
|
||||
rows = store.get_enrich_report(["AAPL"])
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["symbol"] == "AAPL"
|
||||
assert rows[0]["raw_news_count"] == 2
|
||||
assert rows[0]["analyzed_news_count"] == 1
|
||||
assert rows[0]["coverage_pct"] == 50.0
|
||||
assert rows[0]["llm_count"] == 1
|
||||
197
backend/tests/test_news_domain.py
Normal file
197
backend/tests/test_news_domain.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for the news domain helpers."""
|
||||
|
||||
from backend.domains import news as news_domain
|
||||
|
||||
|
||||
class _FakeStore:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def get_ticker_watermarks(self, symbol):
|
||||
self.calls.append(("get_ticker_watermarks", symbol))
|
||||
return {"symbol": symbol, "last_news_fetch": "2026-03-10"}
|
||||
|
||||
def get_news_items_enriched(self, ticker, start_date=None, end_date=None, trade_date=None, limit=100):
|
||||
self.calls.append(("get_news_items_enriched", ticker, start_date, end_date, trade_date, limit))
|
||||
target = trade_date or end_date
|
||||
return [{"id": "n1", "ticker": ticker, "date": target, "trade_date": target}]
|
||||
|
||||
def get_news_timeline_enriched(self, ticker, start_date=None, end_date=None):
|
||||
self.calls.append(("get_news_timeline_enriched", ticker, start_date, end_date))
|
||||
return [{"date": end_date, "count": 1}]
|
||||
|
||||
def get_news_categories_enriched(self, ticker, start_date=None, end_date=None, limit=200):
|
||||
self.calls.append(("get_news_categories_enriched", ticker, start_date, end_date, limit))
|
||||
return {"macro": {"count": 1}}
|
||||
|
||||
def get_news_by_ids_enriched(self, ticker, article_ids):
|
||||
self.calls.append(("get_news_by_ids_enriched", ticker, list(article_ids)))
|
||||
return [{"id": article_ids[0], "ticker": ticker, "date": "2026-03-16"}]
|
||||
|
||||
|
||||
def test_news_rows_need_enrichment_detects_missing_fields():
|
||||
assert news_domain.news_rows_need_enrichment([]) is True
|
||||
assert news_domain.news_rows_need_enrichment([{"sentiment": "", "relevance": "", "key_discussion": ""}]) is True
|
||||
assert news_domain.news_rows_need_enrichment([{"sentiment": "positive"}]) is False
|
||||
|
||||
|
||||
def test_ensure_news_fresh_triggers_incremental_refresh_when_watermark_is_stale(monkeypatch):
|
||||
store = _FakeStore()
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"update_ticker_incremental",
|
||||
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
|
||||
)
|
||||
|
||||
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
|
||||
|
||||
assert calls == [("AAPL", "2026-03-16")]
|
||||
assert payload["target_date"] == "2026-03-16"
|
||||
assert payload["refreshed"] is True
|
||||
|
||||
|
||||
def test_ensure_news_fresh_skips_refresh_when_watermark_is_current(monkeypatch):
    store = _FakeStore()
    calls = []

    monkeypatch.setattr(
        store,
        "get_ticker_watermarks",
        lambda symbol: {"symbol": symbol, "last_news_fetch": "2026-03-16"},
    )
    monkeypatch.setattr(
        news_domain,
        "update_ticker_incremental",
        lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
    )

    payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")

    assert calls == []
    assert payload["refreshed"] is False


def test_get_enriched_news_returns_rows_without_enrichment_when_present(monkeypatch):
    store = _FakeStore()
    monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
    monkeypatch.setattr(
        news_domain,
        "ensure_news_fresh",
        lambda store, ticker, target_date=None, refresh_if_stale=False: {
            "ticker": ticker,
            "target_date": target_date,
            "last_news_fetch": target_date,
            "refreshed": False,
        },
    )

    payload = news_domain.get_enriched_news(
        store,
        ticker="AAPL",
        start_date="2026-03-01",
        end_date="2026-03-16",
        limit=20,
    )

    assert payload["ticker"] == "AAPL"
    assert payload["news"][0]["ticker"] == "AAPL"
    assert payload["freshness"]["target_date"] is None or payload["freshness"]["target_date"] == "2026-03-16"
    assert store.calls == [
        ("get_news_items_enriched", "AAPL", "2026-03-01", "2026-03-16", None, 20)
    ]


def test_get_story_and_similar_days_delegate(monkeypatch):
    store = _FakeStore()
    monkeypatch.setattr(
        news_domain,
        "ensure_news_fresh",
        lambda store, ticker, target_date=None, refresh_if_stale=False: {
            "ticker": ticker,
            "target_date": target_date,
            "last_news_fetch": target_date,
            "refreshed": False,
        },
    )
    monkeypatch.setattr(news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"analyzed": 1})
    monkeypatch.setattr(
        news_domain,
        "get_or_create_stock_story",
        lambda store, symbol, as_of_date: {"symbol": symbol, "as_of_date": as_of_date, "story": "story"},
    )
    monkeypatch.setattr(
        news_domain,
        "find_similar_days",
        lambda store, symbol, target_date, top_k: {"symbol": symbol, "target_date": target_date, "items": [{"score": 0.9}]},
    )

    story = news_domain.get_story_payload(store, ticker="AAPL", as_of_date="2026-03-16")
    similar = news_domain.get_similar_days_payload(store, ticker="AAPL", date="2026-03-16", n_similar=8)

    assert story["story"] == "story"
    assert "freshness" in story
    assert similar["items"][0]["score"] == 0.9
    assert "freshness" in similar


def test_get_enriched_news_defaults_to_read_only_freshness(monkeypatch):
    store = _FakeStore()
    ensure_calls = []

    def fake_ensure(store, ticker, target_date=None, refresh_if_stale=False):
        ensure_calls.append(refresh_if_stale)
        return {
            "ticker": ticker,
            "target_date": target_date,
            "last_news_fetch": target_date,
            "refreshed": False,
        }

    monkeypatch.setattr(news_domain, "ensure_news_fresh", fake_ensure)
    monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)

    payload = news_domain.get_enriched_news(
        store,
        ticker="AAPL",
        end_date="2026-03-16",
    )

    assert payload["ticker"] == "AAPL"
    assert ensure_calls == [False]


def test_get_range_explain_payload_uses_article_ids(monkeypatch):
    store = _FakeStore()
    monkeypatch.setattr(
        news_domain,
        "ensure_news_fresh",
        lambda store, ticker, target_date=None, refresh_if_stale=False: {
            "ticker": ticker,
            "target_date": target_date,
            "last_news_fetch": target_date,
            "refreshed": False,
        },
    )
    monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
    monkeypatch.setattr(
        news_domain,
        "build_range_explanation",
        lambda ticker, start_date, end_date, news_rows: {"ticker": ticker, "count": len(news_rows)},
    )

    payload = news_domain.get_range_explain_payload(
        store,
        ticker="AAPL",
        start_date="2026-03-10",
        end_date="2026-03-16",
        article_ids=["news-9"],
        limit=50,
    )

    assert payload["ticker"] == "AAPL"
    assert payload["result"] == {"ticker": "AAPL", "count": 1}
    assert "freshness" in payload
    assert store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-9"])]
174
backend/tests/test_news_enricher.py
Normal file
@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-

from backend.enrich import news_enricher


def test_classify_news_row_falls_back_to_local_rules(monkeypatch):
    monkeypatch.setattr(news_enricher, "analyze_news_row_with_llm", lambda row: None)
    result = news_enricher.classify_news_row(
        {
            "title": "Apple shares drop after weak guidance",
            "summary": "Investors reacted negatively to softer-than-expected outlook.",
        }
    )
    assert result["analysis_source"] == "local"
    assert result["sentiment"] == "negative"
    assert result["summary"]


def test_classify_news_row_prefers_llm_when_available(monkeypatch):
    monkeypatch.setattr(
        news_enricher,
        "analyze_news_row_with_llm",
        lambda row: {
            "relevance": "high",
            "sentiment": "positive",
            "key_discussion": "Demand resilience",
            "summary": "LLM summary",
            "reason_growth": "Orders remain strong",
            "reason_decrease": "",
            "raw_json": {"provider": "llm"},
        },
    )
    result = news_enricher.classify_news_row(
        {
            "title": "Apple expands AI features",
            "summary": "New devices and software updates were announced.",
        }
    )
    assert result["analysis_source"] == "llm"
    assert result["sentiment"] == "positive"
    assert result["summary"] == "LLM summary"


def test_build_analysis_rows_prefers_batch_llm_and_dedupes(monkeypatch):
    monkeypatch.setattr(news_enricher, "llm_enrichment_enabled", lambda: True)
    monkeypatch.setattr(news_enricher, "get_env_int", lambda key, default=0: 8)
    monkeypatch.setattr(
        news_enricher,
        "analyze_news_rows_with_llm",
        lambda rows: {
            "news-1": {
                "relevance": "high",
                "sentiment": "positive",
                "key_discussion": "Batch result",
                "summary": "Batch summary",
                "reason_growth": "Growth",
                "reason_decrease": "",
                "raw_json": {"provider": "batch"},
            }
        },
    )
    monkeypatch.setattr(news_enricher, "analyze_news_row_with_llm", lambda row: None)
    rows = news_enricher.build_analysis_rows(
        symbol="AAPL",
        news_rows=[
            {"id": "news-1", "trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
            {"id": "news-2", "trade_date": "2026-03-10", "title": "Same title", "summary": "Same summary"},
        ],
        ohlc_rows=[],
    )
    rows, stats = rows
    assert len(rows) == 1
    assert rows[0]["analysis_source"] == "llm"
    assert rows[0]["summary"] == "Batch summary"
    assert stats["deduped_count"] == 1
    assert stats["llm_count"] == 1


def test_enrich_news_for_symbol_skips_existing(monkeypatch):
    class DummyStore:
        def get_news_items(self, symbol, start_date=None, end_date=None, limit=200):
            return [
                {"id": "news-1", "trade_date": "2026-03-10", "title": "One", "summary": "One"},
                {"id": "news-2", "trade_date": "2026-03-11", "title": "Two", "summary": "Two"},
            ]

        def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
            return {"news-1"}

        def get_ohlc(self, symbol, start_date, end_date):
            return []

        def upsert_news_analysis(self, symbol, rows, analysis_source="local"):
            self.rows = rows
            return len(rows)

    monkeypatch.setattr(
        news_enricher,
        "build_analysis_rows",
        lambda symbol, news_rows, ohlc_rows: (
            [
                {
                    "news_id": row["id"],
                    "trade_date": row["trade_date"],
                    "summary": row["summary"],
                    "analysis_source": "local",
                }
                for row in news_rows
            ],
            {"deduped_count": 0, "llm_count": 0, "local_count": len(news_rows)},
        ),
    )
    store = DummyStore()
    result = news_enricher.enrich_news_for_symbol(store, "AAPL")
    assert result["news_count"] == 2
    assert result["queued_count"] == 1
    assert result["skipped_existing_count"] == 1
    assert len(store.rows) == 1
    assert store.rows[0]["news_id"] == "news-2"


def test_enrich_news_for_symbol_only_reanalyzes_local(monkeypatch):
    class DummyStore:
        def get_news_items(self, symbol, start_date=None, end_date=None, limit=200):
            return [
                {"id": "news-1", "trade_date": "2026-03-10", "title": "One", "summary": "One"},
                {"id": "news-2", "trade_date": "2026-03-11", "title": "Two", "summary": "Two"},
                {"id": "news-3", "trade_date": "2026-03-12", "title": "Three", "summary": "Three"},
            ]

        def get_analyzed_news_sources(self, symbol, start_date=None, end_date=None):
            return {"news-1": "local", "news-2": "llm"}

        def get_ohlc(self, symbol, start_date, end_date):
            return []

        def upsert_news_analysis(self, symbol, rows, analysis_source="local"):
            self.rows = rows
            return len(rows)

    monkeypatch.setattr(
        news_enricher,
        "build_analysis_rows",
        lambda symbol, news_rows, ohlc_rows: (
            [
                {
                    "news_id": row["id"],
                    "trade_date": row["trade_date"],
                    "summary": row["summary"],
                    "analysis_source": "llm" if row["id"] == "news-1" else "local",
                }
                for row in news_rows
            ],
            {"deduped_count": 0, "llm_count": 1, "local_count": 0},
        ),
    )

    store = DummyStore()
    result = news_enricher.enrich_news_for_symbol(
        store,
        "AAPL",
        only_reanalyze_local=True,
    )

    assert result["news_count"] == 3
    assert result["queued_count"] == 1
    assert result["skipped_existing_count"] == 2
    assert result["only_reanalyze_local"] is True
    assert result["upgraded_local_to_llm_count"] == 1
    assert result["execution_summary"]["upgraded_dates"] == ["2026-03-10"]
    assert result["execution_summary"]["remaining_local_titles"] == []
    assert result["execution_summary"]["skipped_missing_analysis_count"] == 1
    assert result["execution_summary"]["skipped_non_local_count"] == 1
    assert [row["news_id"] for row in store.rows] == ["news-1"]
180
backend/tests/test_news_service_app.py
Normal file
@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted news service app surface."""

from fastapi.testclient import TestClient

from backend.apps.news_service import create_app


class _FakeStore:
    def get_ticker_watermarks(self, symbol):
        return {"symbol": symbol, "last_news_fetch": "2026-12-31"}

    def get_news_timeline_enriched(self, symbol, start_date=None, end_date=None):
        return [{"date": end_date, "count": 1}]

    def get_news_items(self, symbol, start_date=None, end_date=None, limit=100):
        return [{"id": "news-raw-1", "ticker": symbol, "title": "Raw Title", "date": end_date}]

    def get_news_items_enriched(self, symbol, start_date=None, end_date=None, trade_date=None, limit=100):
        return [{"id": "news-1", "ticker": symbol, "title": "Title", "date": trade_date or end_date}]

    def upsert_news_analysis(self, symbol, rows):
        return len(rows)

    def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
        return set()

    def get_news_categories_enriched(self, symbol, start_date=None, end_date=None, limit=200):
        return {"market": {"label": "market", "count": 1, "article_ids": ["news-1"]}}

    def get_news_by_ids_enriched(self, symbol, article_ids):
        return [{"id": article_ids[0], "ticker": symbol, "title": "Picked"}]


def test_news_service_routes_are_exposed():
    app = create_app()
    paths = {route.path for route in app.routes}

    assert "/health" in paths
    assert "/api/enriched-news" in paths
    assert "/api/news-for-date" in paths
    assert "/api/news-timeline" in paths
    assert "/api/categories" in paths
    assert "/api/similar-days" in paths
    assert "/api/stories/{ticker}" in paths
    assert "/api/range-explain" in paths


def test_news_service_enriched_news_and_categories(monkeypatch):
    app = create_app()
    app.dependency_overrides.clear()
    from backend.apps import news_service as news_service_module

    app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
    monkeypatch.setattr(
        "backend.domains.news.enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
    )

    with TestClient(app) as client:
        news_response = client.get(
            "/api/enriched-news",
            params={"ticker": "AAPL", "end_date": "2026-03-23"},
        )
        categories_response = client.get(
            "/api/categories",
            params={"ticker": "AAPL", "end_date": "2026-03-23"},
        )

    assert news_response.status_code == 200
    assert news_response.json()["news"][0]["ticker"] == "AAPL"
    assert categories_response.status_code == 200
    assert categories_response.json()["categories"]["market"]["count"] == 1


def test_news_service_news_for_date_and_timeline(monkeypatch):
    app = create_app()
    from backend.apps import news_service as news_service_module

    app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
    monkeypatch.setattr(
        "backend.domains.news.enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
    )

    with TestClient(app) as client:
        date_response = client.get(
            "/api/news-for-date",
            params={"ticker": "AAPL", "date": "2026-03-23"},
        )
        timeline_response = client.get(
            "/api/news-timeline",
            params={
                "ticker": "AAPL",
                "start_date": "2026-03-01",
                "end_date": "2026-03-23",
            },
        )

    assert date_response.status_code == 200
    assert date_response.json()["date"] == "2026-03-23"
    assert timeline_response.status_code == 200
    assert timeline_response.json()["timeline"][0]["count"] == 1


def test_news_service_similar_days_and_story(monkeypatch):
    app = create_app()
    from backend.apps import news_service as news_service_module

    app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
    monkeypatch.setattr(
        "backend.domains.news.enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
    )
    monkeypatch.setattr(
        "backend.domains.news.find_similar_days",
        lambda store, symbol, target_date, top_k: {
            "symbol": symbol,
            "target_date": target_date,
            "items": [{"date": "2026-03-20", "score": 0.9}],
        },
    )
    monkeypatch.setattr(
        "backend.domains.news.get_or_create_stock_story",
        lambda store, symbol, as_of_date: {
            "symbol": symbol,
            "as_of_date": as_of_date,
            "story": "story body",
            "source": "local",
        },
    )

    with TestClient(app) as client:
        similar_response = client.get(
            "/api/similar-days",
            params={"ticker": "AAPL", "date": "2026-03-23", "n_similar": 3},
        )
        story_response = client.get(
            "/api/stories/AAPL",
            params={"as_of_date": "2026-03-23"},
        )

    assert similar_response.status_code == 200
    assert similar_response.json()["items"][0]["score"] == 0.9
    assert story_response.status_code == 200
    assert story_response.json()["story"] == "story body"


def test_news_service_range_explain(monkeypatch):
    app = create_app()
    from backend.apps import news_service as news_service_module

    app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
    monkeypatch.setattr(
        "backend.domains.news.enrich_news_for_symbol",
        lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
    )
    monkeypatch.setattr(
        "backend.domains.news.build_range_explanation",
        lambda ticker, start_date, end_date, news_rows: {
            "symbol": ticker,
            "news_count": len(news_rows),
            "start_date": start_date,
            "end_date": end_date,
        },
    )

    with TestClient(app) as client:
        response = client.get(
            "/api/range-explain",
            params={
                "ticker": "AAPL",
                "start_date": "2026-03-01",
                "end_date": "2026-03-23",
                "article_ids": ["news-7"],
            },
        )

    assert response.status_code == 200
    assert response.json()["result"]["news_count"] == 1
60
backend/tests/test_openclaw_cli_service.py
Normal file
@@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
"""Tests for the OpenClaw CLI service wrapper."""

from pathlib import Path

import pytest

from backend.services.openclaw_cli import OpenClawCliError, OpenClawCliService


class _Completed:
    def __init__(self, *, returncode=0, stdout="", stderr=""):
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr


def test_openclaw_cli_service_runs_json_command(monkeypatch, tmp_path):
    captured = {}

    def _fake_run(command, **kwargs):
        captured["command"] = command
        captured["cwd"] = kwargs["cwd"]
        return _Completed(stdout='{"sessions":[{"key":"main/session-1"}]}')

    monkeypatch.setattr("backend.services.openclaw_cli.subprocess.run", _fake_run)

    service = OpenClawCliService(base_command=["openclaw"], cwd=tmp_path, timeout_seconds=3)
    payload = service.list_sessions()

    assert payload["sessions"][0]["key"] == "main/session-1"
    assert captured["command"] == ["openclaw", "sessions", "--json"]
    assert captured["cwd"] == tmp_path


def test_openclaw_cli_service_raises_on_failure(monkeypatch, tmp_path):
    def _fake_run(command, **kwargs):
        return _Completed(returncode=7, stdout="", stderr="boom")

    monkeypatch.setattr("backend.services.openclaw_cli.subprocess.run", _fake_run)

    service = OpenClawCliService(base_command=["openclaw"], cwd=tmp_path, timeout_seconds=3)

    with pytest.raises(OpenClawCliError) as exc_info:
        service.list_cron_jobs()

    assert exc_info.value.exit_code == 7
    assert exc_info.value.stderr == "boom"


def test_openclaw_cli_service_can_extract_single_session(monkeypatch, tmp_path):
    def _fake_run(command, **kwargs):
        return _Completed(stdout='{"sessions":[{"key":"main/session-1","agentId":"main"}]}')

    monkeypatch.setattr("backend.services.openclaw_cli.subprocess.run", _fake_run)

    service = OpenClawCliService(base_command=["openclaw"], cwd=tmp_path, timeout_seconds=3)
    session = service.get_session("main/session-1")

    assert session["agentId"] == "main"
110
backend/tests/test_openclaw_service_app.py
Normal file
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted OpenClaw service app surface."""

from fastapi.testclient import TestClient

from backend.apps.openclaw_service import create_app
from backend.api import openclaw as openclaw_module


class _FakeOpenClawCliService:
    def health(self):
        return {
            "status": "healthy",
            "service": "openclaw-service",
            "base_command": ["openclaw"],
            "cwd": "/tmp/openclaw",
            "binary_resolved": True,
            "reference_entry_available": True,
            "timeout_seconds": 15.0,
        }

    def status(self):
        return {"runtimeVersion": "2026.3.24"}

    def list_sessions(self):
        return {
            "sessions": [
                {"key": "main/session-1", "agentId": "main"},
                {"key": "analyst/session-2", "agentId": "analyst"},
            ]
        }

    def get_session(self, session_key: str):
        for session in self.list_sessions()["sessions"]:
            if session["key"] == session_key:
                return session
        raise KeyError(session_key)

    def get_session_history(self, session_key: str, *, limit: int = 20):
        return {
            "sessionKey": session_key,
            "limit": limit,
            "items": [{"role": "assistant", "text": "hello"}],
        }

    def list_cron_jobs(self):
        return {"jobs": [{"id": "job-1", "name": "Daily sync"}]}

    def list_approvals(self):
        return {"approvals": [{"id": "ap-1", "status": "pending"}]}


def test_openclaw_service_routes_are_exposed():
    app = create_app()
    paths = {route.path for route in app.routes}

    assert "/health" in paths
    assert "/api/status" in paths
    assert "/api/openclaw/status" in paths
    assert "/api/openclaw/sessions" in paths
    assert "/api/openclaw/sessions/{session_key:path}" in paths
    assert "/api/openclaw/sessions/{session_key:path}/history" in paths
    assert "/api/openclaw/cron" in paths
    assert "/api/openclaw/approvals" in paths


def test_openclaw_service_read_routes():
    app = create_app()
    app.dependency_overrides[openclaw_module.get_openclaw_cli_service] = (
        lambda: _FakeOpenClawCliService()
    )

    with TestClient(app) as client:
        health = client.get("/health")
        status = client.get("/api/status")
        openclaw_status = client.get("/api/openclaw/status")
        sessions = client.get("/api/openclaw/sessions")
        session = client.get("/api/openclaw/sessions/main/session-1")
        history = client.get("/api/openclaw/sessions/main/session-1/history", params={"limit": 5})
        cron = client.get("/api/openclaw/cron")
        approvals = client.get("/api/openclaw/approvals")

    assert health.status_code == 200
    assert health.json()["service"] == "openclaw-service"
    assert status.status_code == 200
    assert status.json()["status"] == "operational"
    assert openclaw_status.status_code == 200
    assert openclaw_status.json()["runtimeVersion"] == "2026.3.24"
    assert sessions.status_code == 200
    assert len(sessions.json()["sessions"]) == 2
    assert session.status_code == 200
    assert session.json()["session"]["agentId"] == "main"
    assert history.status_code == 200
    assert history.json()["limit"] == 5
    assert cron.status_code == 200
    assert cron.json()["jobs"][0]["id"] == "job-1"
    assert approvals.status_code == 200
    assert approvals.json()["approvals"][0]["id"] == "ap-1"


def test_openclaw_service_session_404():
    app = create_app()
    app.dependency_overrides[openclaw_module.get_openclaw_cli_service] = (
        lambda: _FakeOpenClawCliService()
    )

    with TestClient(app) as client:
        response = client.get("/api/openclaw/sessions/missing")

    assert response.status_code == 404
74
backend/tests/test_openclaw_websocket_client.py
Normal file
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
"""Tests for the OpenClaw WebSocket client session helpers."""

import pytest

from shared.client.openclaw_websocket_client import OpenClawWebSocketClient


@pytest.mark.asyncio
async def test_resolve_session_parses_gateway_key_response():
    client = OpenClawWebSocketClient(gateway_token="test-token")

    async def fake_send_request(method, params=None, _allow_handshake=False):
        assert method == "sessions.resolve"
        assert params["agentId"] == "main"
        return {"ok": True, "key": "agent:main:main"}

    client._send_request = fake_send_request  # type: ignore[method-assign]

    resolved = await client.resolve_session(agent_id="main")

    assert resolved == "agent:main:main"


@pytest.mark.asyncio
async def test_send_message_uses_session_send_payload():
    client = OpenClawWebSocketClient(gateway_token="test-token")

    async def fake_send_request(method, params=None, _allow_handshake=False):
        assert method == "sessions.send"
        assert params == {
            "key": "agent:main:main",
            "message": "hello",
            "thinking": "medium",
        }
        return {"ok": True, "runId": "run-1"}

    client._send_request = fake_send_request  # type: ignore[method-assign]

    result = await client.send_message("agent:main:main", "hello", thinking="medium")

    assert result["runId"] == "run-1"


@pytest.mark.asyncio
async def test_get_session_history_uses_sessions_preview():
    client = OpenClawWebSocketClient(gateway_token="test-token")

    async def fake_send_request(method, params=None, _allow_handshake=False):
        assert method == "sessions.preview"
        assert params == {"keys": ["agent:main:main"], "limit": 12}
        return {"previews": []}

    client._send_request = fake_send_request  # type: ignore[method-assign]

    result = await client.get_session_history("agent:main:main", limit=12)

    assert result == {"previews": []}


@pytest.mark.asyncio
async def test_unsubscribe_uses_session_messages_unsubscribe():
    client = OpenClawWebSocketClient(gateway_token="test-token")

    async def fake_send_request(method, params=None, _allow_handshake=False):
        assert method == "sessions.messages.unsubscribe"
        assert params == {"key": "agent:main:main"}
        return {"subscribed": False}

    client._send_request = fake_send_request  # type: ignore[method-assign]

    result = await client.unsubscribe("agent:main:main")

    assert result == {"subscribed": False}
30
backend/tests/test_provider_router.py
Normal file
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""Tests for provider router fallback behavior."""

from backend.data.provider_router import DataProviderRouter
from backend.config.data_config import reset_config


def test_router_includes_local_csv_fallback(monkeypatch):
    monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
    monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
    monkeypatch.delenv("FIN_DATA_SOURCE", raising=False)
    monkeypatch.delenv("ENABLED_DATA_SOURCES", raising=False)
    reset_config()

    router = DataProviderRouter()

    assert router.price_sources() == ["local_csv"]


def test_router_allows_yfinance_when_enabled(monkeypatch):
    monkeypatch.setenv("FIN_DATA_SOURCE", "yfinance")
    monkeypatch.setenv("ENABLED_DATA_SOURCES", "yfinance,local_csv")
    monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
    monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
    reset_config()

    router = DataProviderRouter()

    assert router.price_sources() == ["yfinance", "local_csv"]
    assert router.api_sources() == ["yfinance"]
15
backend/tests/test_provider_utils.py
Normal file
@@ -0,0 +1,15 @@
# -*- coding: utf-8 -*-
"""Tests for market symbol normalization helpers."""

from backend.data.provider_utils import describe_symbol, normalize_symbol


def test_normalize_symbol_exchange_prefix():
    assert normalize_symbol("sh600519") == "600519"
    assert normalize_symbol("600519.SH") == "600519"


def test_normalize_symbol_us_ticker():
    symbol = describe_symbol("aapl")
    assert symbol.canonical == "AAPL"
    assert symbol.market == "us"
54
backend/tests/test_range_explainer.py
Normal file
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-

from types import SimpleNamespace

from backend.explain import range_explainer


def test_build_range_explanation_prefers_llm_text_when_available(monkeypatch):
    monkeypatch.setattr(
        range_explainer,
        "get_prices",
        lambda ticker, start_date, end_date: [
            SimpleNamespace(open=100, close=98, high=102, low=97, volume=1000),
            SimpleNamespace(open=98, close=96, high=99, low=95, volume=1100),
            SimpleNamespace(open=96, close=97, high=98, low=94, volume=1200),
        ],
    )
    monkeypatch.setattr(
        range_explainer,
        "analyze_range_with_llm",
        lambda payload: {
            "summary": "区间内整体偏弱,主题集中在盈利预期和供应链风险。",
            "trend_analysis": "前半段快速下探,后半段出现修复。",
            "bullish_factors": ["回调后出现承接"],
            "bearish_factors": ["盈利预期承压"],
            "model_label": "DASHSCOPE:qwen-max",
        },
    )

    result = range_explainer.build_range_explanation(
        ticker="AAPL",
        start_date="2026-03-10",
        end_date="2026-03-16",
        news_rows=[
            {
                "id": "news-1",
                "trade_date": "2026-03-10",
                "title": "Apple margin pressure concerns grow",
                "summary": "Investors focused on weaker margin outlook.",
                "sentiment": "negative",
                "relevance": "high",
                "ret_t0": -0.02,
                "reason_decrease": "盈利预期承压",
                "category": "earnings",
            }
        ],
    )

    assert result["analysis"]["summary"] == "区间内整体偏弱,主题集中在盈利预期和供应链风险。"
    assert result["analysis"]["trend_analysis"] == "前半段快速下探,后半段出现修复。"
    assert result["analysis"]["bullish_factors"] == ["回调后出现承接"]
    assert result["analysis"]["analysis_source"] == "llm"
    assert result["analysis"]["analysis_model_label"] == "DASHSCOPE:qwen-max"
    assert result["news_count"] == 1
364
backend/tests/test_runtime_service_app.py
Normal file
@@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted runtime service app surface."""

import json
from pathlib import Path

from fastapi.testclient import TestClient

from backend.api import runtime as runtime_module
from backend.apps.runtime_service import create_app


def test_runtime_service_routes_are_exposed():
    app = create_app()
    paths = {route.path for route in app.routes}

    assert "/health" in paths
    assert "/api/status" in paths
    assert "/api/runtime/start" in paths
    assert "/api/runtime/stop" in paths
    assert "/api/runtime/cleanup" in paths
    assert "/api/runtime/history" in paths
    assert "/api/runtime/current" in paths
    assert "/api/runtime/gateway/port" in paths


def test_runtime_service_health_and_status(monkeypatch):
    runtime_state = runtime_module.get_runtime_state()
    runtime_state.gateway_process = None
    runtime_state.gateway_port = 9876
    runtime_state.runtime_manager = object()

    with TestClient(create_app()) as client:
        health_response = client.get("/health")
        status_response = client.get("/api/status")

    assert health_response.status_code == 200
    assert health_response.json() == {
        "status": "healthy",
        "service": "runtime-service",
        "gateway_running": False,
        "gateway_port": 9876,
    }
    assert status_response.status_code == 200
    assert status_response.json() == {
        "status": "operational",
        "service": "runtime-service",
        "runtime": {
            "gateway_running": False,
            "gateway_port": 9876,
            "has_runtime_manager": True,
        },
    }


def test_runtime_service_gateway_port_endpoint_uses_runtime_router(monkeypatch):
    runtime_module.get_runtime_state().gateway_port = 9345
    monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)

    with TestClient(create_app()) as client:
        response = client.get(
            "/api/runtime/gateway/port",
            headers={"host": "runtime.example:8003", "x-forwarded-proto": "https"},
        )

    assert response.status_code == 200
    assert response.json() == {
        "port": 9345,
        "is_running": True,
        "ws_url": "wss://runtime.example:9345",
    }


def test_runtime_service_get_runtime_config(monkeypatch, tmp_path):
    run_dir = tmp_path / "runs" / "demo"
    state_dir = run_dir / "state"
    state_dir.mkdir(parents=True)
    (run_dir / "BOOTSTRAP.md").write_text(
        "---\n"
        "tickers:\n"
        " - AAPL\n"
        "schedule_mode: intraday\n"
        "interval_minutes: 30\n"
        "trigger_time: '10:00'\n"
        "max_comm_cycles: 3\n"
        "enable_memory: true\n"
        "---\n",
        encoding="utf-8",
    )
    (state_dir / "runtime_state.json").write_text(
        json.dumps(
            {
                "context": {
                    "config_name": "demo",
                    "run_dir": str(run_dir),
                    "bootstrap_values": {
                        "tickers": ["AAPL"],
                        "schedule_mode": "intraday",
                        "interval_minutes": 30,
                        "trigger_time": "10:00",
                        "max_comm_cycles": 3,
                        "enable_memory": True,
                    },
                }
            }
        ),
        encoding="utf-8",
    )
    monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
    monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
    runtime_module.get_runtime_state().gateway_port = 8765

    with TestClient(create_app()) as client:
        response = client.get("/api/runtime/config")

    assert response.status_code == 200
    payload = response.json()
    assert payload["run_id"] == "demo"
    assert payload["bootstrap"]["schedule_mode"] == "intraday"
    assert payload["resolved"]["interval_minutes"] == 30
    assert payload["resolved"]["enable_memory"] is True


def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, tmp_path):
    run_dir = tmp_path / "runs" / "demo"
    state_dir = run_dir / "state"
    state_dir.mkdir(parents=True)
    (run_dir / "BOOTSTRAP.md").write_text(
        "---\n"
        "tickers:\n"
        " - AAPL\n"
        "schedule_mode: daily\n"
        "interval_minutes: 60\n"
        "trigger_time: '09:30'\n"
        "max_comm_cycles: 2\n"
        "---\n",
        encoding="utf-8",
    )
    (state_dir / "runtime_state.json").write_text(
        json.dumps(
            {
                "context": {
                    "config_name": "demo",
                    "run_dir": str(run_dir),
                    "bootstrap_values": {
                        "tickers": ["AAPL"],
                        "schedule_mode": "daily",
                        "interval_minutes": 60,
                        "trigger_time": "09:30",
                        "max_comm_cycles": 2,
                    },
                }
            }
        ),
        encoding="utf-8",
    )

    class _DummyContext:
        def __init__(self):
            self.bootstrap_values = {
                "tickers": ["AAPL"],
                "schedule_mode": "daily",
                "interval_minutes": 60,
                "trigger_time": "09:30",
                "max_comm_cycles": 2,
            }

    class _DummyManager:
        def __init__(self):
            self.config_name = "demo"
            self.bootstrap = dict(_DummyContext().bootstrap_values)
            self.context = _DummyContext()

        def _persist_snapshot(self):
            return None

    monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
    monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
    runtime_module.get_runtime_state().runtime_manager = _DummyManager()
    runtime_module.get_runtime_state().gateway_port = 8765

    with TestClient(create_app()) as client:
        response = client.put(
            "/api/runtime/config",
            json={
                "schedule_mode": "intraday",
                "interval_minutes": 15,
                "trigger_time": "10:15",
                "max_comm_cycles": 4,
            },
        )

    assert response.status_code == 200
    payload = response.json()
    assert payload["bootstrap"]["schedule_mode"] == "intraday"
    assert payload["resolved"]["interval_minutes"] == 15
    assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")


def test_prune_old_timestamped_runs_keeps_named_runs(monkeypatch, tmp_path):
    runs_dir = tmp_path / "runs"
    runs_dir.mkdir()

    keep_dirs = ["20260324_110000", "20260324_120000"]
    prune_dir = "20260324_100000"
    named_dir = "smoke_fullstack"

    for name in [*keep_dirs, prune_dir, named_dir]:
        (runs_dir / name).mkdir(parents=True)

    monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)

    pruned = runtime_module._prune_old_timestamped_runs(keep=1, exclude_run_ids={"20260324_120000"})

    assert prune_dir in pruned
    assert (runs_dir / named_dir).exists()
    assert (runs_dir / "20260324_120000").exists()
    assert (runs_dir / "20260324_110000").exists()


def test_runtime_cleanup_endpoint_prunes_old_runs(monkeypatch, tmp_path):
    runs_dir = tmp_path / "runs"
    runs_dir.mkdir()

    for name in ["20260324_090000", "20260324_100000", "20260324_110000", "smoke_fullstack"]:
        (runs_dir / name).mkdir(parents=True)

    monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
    monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: False)

    with TestClient(create_app()) as client:
        response = client.post("/api/runtime/cleanup?keep=1")

    assert response.status_code == 200
    payload = response.json()
    assert payload["status"] == "ok"
    assert sorted(payload["pruned_run_ids"]) == ["20260324_090000", "20260324_100000"]
    assert (runs_dir / "20260324_110000").exists()
    assert (runs_dir / "smoke_fullstack").exists()


def test_runtime_history_lists_recent_runs(monkeypatch, tmp_path):
    run_dir = tmp_path / "runs" / "20260324_120000"
    (run_dir / "state").mkdir(parents=True)
    (run_dir / "team_dashboard").mkdir(parents=True)
    (run_dir / "state" / "runtime_state.json").write_text(
        json.dumps(
            {
                "context": {
                    "config_name": "20260324_120000",
                    "run_dir": str(run_dir),
                    "bootstrap_values": {"tickers": ["AAPL"]},
                },
                "events": [],
            }
        ),
        encoding="utf-8",
    )
    (run_dir / "team_dashboard" / "summary.json").write_text(
        json.dumps({"totalTrades": 3, "totalAssetValue": 123456.0}),
        encoding="utf-8",
    )

    monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)

    with TestClient(create_app()) as client:
        response = client.get("/api/runtime/history?limit=5")

    assert response.status_code == 200
    payload = response.json()
    assert payload["runs"][0]["run_id"] == "20260324_120000"
    assert payload["runs"][0]["total_trades"] == 3


def test_restore_run_assets_copies_state(monkeypatch, tmp_path):
    source_run = tmp_path / "runs" / "20260324_100000"
    (source_run / "team_dashboard").mkdir(parents=True)
    (source_run / "state").mkdir(parents=True)
    (source_run / "agents").mkdir(parents=True)
    (source_run / "team_dashboard" / "_internal_state.json").write_text("{}", encoding="utf-8")
    (source_run / "state" / "server_state.json").write_text("{}", encoding="utf-8")

    target_run = tmp_path / "runs" / "20260324_130000"

    monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)

    runtime_module._restore_run_assets("20260324_100000", target_run)

    assert (target_run / "team_dashboard" / "_internal_state.json").exists()
    assert (target_run / "state" / "server_state.json").exists()


def test_start_runtime_restore_reuses_historical_run_id(monkeypatch, tmp_path):
    run_dir = tmp_path / "runs" / "20260324_100000"
    (run_dir / "state").mkdir(parents=True)
    (run_dir / "state" / "runtime_state.json").write_text(
        json.dumps(
            {
                "context": {
                    "config_name": "20260324_100000",
                    "run_dir": str(run_dir),
                    "bootstrap_values": {
                        "tickers": ["AAPL"],
                        "schedule_mode": "intraday",
                        "interval_minutes": 30,
                        "trigger_time": "now",
                        "max_comm_cycles": 2,
                        "initial_cash": 100000.0,
                        "margin_requirement": 0.0,
                        "enable_memory": False,
                        "mode": "live",
                        "poll_interval": 10,
                    },
                }
            }
        ),
        encoding="utf-8",
    )

    class _DummyManager:
        def __init__(self, config_name, run_dir, bootstrap):
            self.config_name = config_name
            self.run_dir = Path(run_dir)
            self.bootstrap = bootstrap
            self.context = None

        def prepare_run(self):
            self.context = type(
                "Ctx",
                (),
                {
                    "config_name": self.config_name,
                    "run_dir": self.run_dir,
                    "bootstrap_values": self.bootstrap,
                },
            )()
            return self.context

    class _DummyProcess:
        def poll(self):
            return None

    monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
    monkeypatch.setattr(runtime_module, "_find_available_port", lambda start_port=8765, max_port=9000: 8765)
    monkeypatch.setattr(runtime_module, "_start_gateway_process", lambda **kwargs: _DummyProcess())
    monkeypatch.setattr(runtime_module, "_stop_gateway", lambda: True)
    monkeypatch.setattr("backend.runtime.manager.TradingRuntimeManager", _DummyManager)
    runtime_state = runtime_module.get_runtime_state()
    runtime_state.gateway_process = None

    with TestClient(create_app()) as client:
        response = client.post(
            "/api/runtime/start",
            json={
                "launch_mode": "restore",
                "restore_run_id": "20260324_100000",
                "tickers": [],
            },
        )

    assert response.status_code == 200
    payload = response.json()
    assert payload["run_id"] == "20260324_100000"
    assert payload["run_dir"] == str(run_dir)
130
backend/tests/test_service_clients.py
Normal file
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
"""Tests for split-aware shared service clients."""

import pytest

from shared.client.control_client import ControlPlaneClient
from shared.client.openclaw_client import OpenClawServiceClient
from shared.client.runtime_client import RuntimeServiceClient


class _DummyResponse:
    def __init__(self, payload):
        self._payload = payload

    def raise_for_status(self):
        return None

    def json(self):
        return self._payload


class _DummyAsyncClient:
    def __init__(self):
        self.calls = []

    async def get(self, path, params=None):
        self.calls.append(("get", path, params))
        return _DummyResponse({"path": path, "params": params})

    async def post(self, path, json=None):
        self.calls.append(("post", path, json))
        return _DummyResponse({"path": path, "json": json})

    async def put(self, path, json=None):
        self.calls.append(("put", path, json))
        return _DummyResponse({"path": path, "json": json})

    async def aclose(self):
        return None


@pytest.mark.asyncio
async def test_control_plane_client_hits_current_workspace_and_guard_routes():
    client = ControlPlaneClient()
    client._client = _DummyAsyncClient()

    await client.list_workspaces()
    await client.get_workspace("demo")
    await client.list_agents("demo")
    await client.get_agent("demo", "risk_manager")
    await client.fetch_pending_approvals()
    await client.approve_pending_approval("ap-1")
    await client.deny_pending_approval("ap-2", reason="nope")

    assert client._client.calls == [
        ("get", "/workspaces", None),
        ("get", "/workspaces/demo", None),
        ("get", "/workspaces/demo/agents", None),
        ("get", "/workspaces/demo/agents/risk_manager", None),
        ("get", "/guard/pending", None),
        (
            "post",
            "/guard/approve",
            {
                "approval_id": "ap-1",
                "one_time": True,
                "expires_in_minutes": 30,
            },
        ),
        (
            "post",
            "/guard/deny",
            {
                "approval_id": "ap-2",
                "reason": "nope",
            },
        ),
    ]


@pytest.mark.asyncio
async def test_runtime_service_client_hits_current_runtime_routes():
    client = RuntimeServiceClient()
    client._client = _DummyAsyncClient()

    await client.fetch_context()
    await client.fetch_agents()
    await client.fetch_events()
    await client.fetch_gateway_port()
    await client.start_runtime({"tickers": ["AAPL"]})
    await client.stop_runtime(force=True)
    await client.restart_runtime({"tickers": ["MSFT"]})
    await client.fetch_current_runtime()
    await client.get_runtime_config()
    await client.update_runtime_config({"schedule_mode": "intraday"})

    assert client._client.calls == [
        ("get", "/context", None),
        ("get", "/agents", None),
        ("get", "/events", None),
        ("get", "/gateway/port", None),
        ("post", "/start", {"tickers": ["AAPL"]}),
        ("post", "/stop?force=true", None),
        ("post", "/restart", {"tickers": ["MSFT"]}),
        ("get", "/current", None),
        ("get", "/config", None),
        ("put", "/config", {"schedule_mode": "intraday"}),
    ]


@pytest.mark.asyncio
async def test_openclaw_service_client_hits_current_openclaw_routes():
    client = OpenClawServiceClient()
    client._client = _DummyAsyncClient()

    await client.fetch_status()
    await client.list_sessions()
    await client.get_session("main/session-1")
    await client.get_session_history("main/session-1", limit=5)
    await client.list_cron_jobs()
    await client.list_approvals()

    assert client._client.calls == [
        ("get", "/status", None),
        ("get", "/sessions", None),
        ("get", "/sessions/main/session-1", None),
        ("get", "/sessions/main/session-1/history", {"limit": 5}),
        ("get", "/cron", None),
        ("get", "/approvals", None),
    ]
201
backend/tests/test_settlement.py
Normal file
@@ -0,0 +1,201 @@
# -*- coding: utf-8 -*-
"""
Test Settlement Coordinator and Baseline Calculations
"""

from backend.utils.baselines import (
    BaselineCalculator,
    calculate_momentum_scores,
)
from backend.utils.analyst_tracker import (
    AnalystPerformanceTracker,
    update_leaderboard_with_evaluations,
)


def test_baseline_equal_weight():
    """Test equal weight baseline calculation"""
    calculator = BaselineCalculator(initial_capital=100000.0)

    tickers = ["AAPL", "MSFT", "GOOGL"]
    prices = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
    openprices = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}
    value = calculator.calculate_equal_weight_value(
        tickers,
        openprices,
        prices,
    )

    assert value > 0
    assert calculator.equal_weight_initialized is True


def test_baseline_market_cap_weighted():
    """Test market cap weighted baseline calculation"""
    calculator = BaselineCalculator(initial_capital=100000.0)

    tickers = ["AAPL", "MSFT", "GOOGL"]
    prices = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
    openprices = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}
    market_caps = {"AAPL": 3e12, "MSFT": 2e12, "GOOGL": 1.5e12}

    value = calculator.calculate_market_cap_weighted_value(
        tickers,
        openprices,
        prices,
        market_caps,
    )

    assert value > 0
    assert calculator.market_cap_initialized is True


def test_momentum_scores():
    """Test momentum score calculation"""
    tickers = ["AAPL", "MSFT"]
    prices_history = {
        "AAPL": [
            ("2024-01-01", 100.0),
            ("2024-01-02", 105.0),
            ("2024-01-03", 110.0),
        ],
        "MSFT": [
            ("2024-01-01", 200.0),
            ("2024-01-02", 195.0),
            ("2024-01-03", 190.0),
        ],
    }

    scores = calculate_momentum_scores(
        tickers,
        prices_history,
        lookback_days=2,
    )

    assert scores["AAPL"] > 0
    assert scores["MSFT"] < 0


def test_analyst_tracker_predictions():
    """Test analyst prediction recording with structured format"""
    tracker = AnalystPerformanceTracker()

    final_predictions = [
        {
            "agent": "technical_analyst",
            "predictions": [
                {"ticker": "AAPL", "direction": "up", "confidence": 0.8},
                {"ticker": "MSFT", "direction": "down", "confidence": 0.7},
                {"ticker": "GOOGL", "direction": "neutral", "confidence": 0.5},
            ],
        },
        {
            "agent": "fundamentals_analyst",
            "predictions": [
                {"ticker": "AAPL", "direction": "up", "confidence": 0.9},
                {"ticker": "MSFT", "direction": "up", "confidence": 0.6},
                {"ticker": "GOOGL", "direction": "down", "confidence": 0.75},
            ],
        },
    ]

    tracker.record_analyst_predictions(final_predictions)

    assert "technical_analyst" in tracker.daily_predictions
    assert "fundamentals_analyst" in tracker.daily_predictions
    assert tracker.daily_predictions["technical_analyst"]["AAPL"] == "long"
    assert tracker.daily_predictions["technical_analyst"]["MSFT"] == "short"
    assert tracker.daily_predictions["technical_analyst"]["GOOGL"] == "hold"


def test_analyst_evaluation():
    """Test analyst prediction evaluation"""
    tracker = AnalystPerformanceTracker()

    tracker.daily_predictions = {
        "technical_analyst": {
            "AAPL": "long",
            "MSFT": "short",
        },
    }

    open_prices = {"AAPL": 100.0, "MSFT": 200.0}
    close_prices = {"AAPL": 105.0, "MSFT": 195.0}

    evaluations = tracker.evaluate_predictions(
        open_prices,
        close_prices,
        "2024-01-15",
    )

    assert "technical_analyst" in evaluations
    eval_result = evaluations["technical_analyst"]
    assert eval_result["correct_predictions"] == 2
    assert eval_result["win_rate"] == 1.0

    # Verify individual signals format
    assert "signals" in eval_result
    assert len(eval_result["signals"]) == 2
    for signal in eval_result["signals"]:
        assert "ticker" in signal
        assert "signal" in signal
        assert "date" in signal
        assert "is_correct" in signal
        assert signal["date"] == "2024-01-15"


def test_leaderboard_update():
    """Test leaderboard update with evaluations"""
    leaderboard = [
        {
            "agentId": "technical_analyst",
            "name": "Technical Analyst",
            "rank": 0,
            "winRate": None,
            "bull": {"n": 0, "win": 0, "unknown": 0},
            "bear": {"n": 0, "win": 0, "unknown": 0},
            "signals": [],
        },
    ]

    evaluations = {
        "technical_analyst": {
            "total_predictions": 2,
            "correct_predictions": 1,
            "win_rate": 0.5,
            "bull": {"n": 1, "win": 1, "unknown": 0},
            "bear": {"n": 1, "win": 0, "unknown": 0},
            "hold": 0,
            "signals": [
                {
                    "ticker": "AAPL",
                    "signal": "bull",
                    "date": "2024-01-01",
                    "is_correct": True,
                },
                {
                    "ticker": "MSFT",
                    "signal": "bear",
                    "date": "2024-01-01",
                    "is_correct": False,
                },
            ],
        },
    }

    updated = update_leaderboard_with_evaluations(
        leaderboard,
        evaluations,
    )

    assert updated[0]["bull"]["n"] == 1
    assert updated[0]["bull"]["win"] == 1
    assert updated[0]["winRate"] == 0.5
    assert len(updated[0]["signals"]) == 2

    # Verify signal format matches frontend expectations
    for signal in updated[0]["signals"]:
        assert "ticker" in signal
        assert "signal" in signal
        assert "date" in signal
        assert "is_correct" in signal
32
backend/tests/test_shared_schema_bridge.py
Normal file
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""Regression coverage for the shared schema bridge."""

from backend.data import schema as legacy_schema
from shared import schema as shared_schema


def test_backend_data_schema_reexports_shared_contracts():
    assert legacy_schema.Price is shared_schema.Price
    assert legacy_schema.PriceResponse is shared_schema.PriceResponse
    assert legacy_schema.FinancialMetrics is shared_schema.FinancialMetrics
    assert legacy_schema.FinancialMetricsResponse is (
        shared_schema.FinancialMetricsResponse
    )
    assert legacy_schema.LineItem is shared_schema.LineItem
    assert legacy_schema.LineItemResponse is shared_schema.LineItemResponse
    assert legacy_schema.InsiderTrade is shared_schema.InsiderTrade
    assert legacy_schema.InsiderTradeResponse is (
        shared_schema.InsiderTradeResponse
    )
    assert legacy_schema.CompanyNews is shared_schema.CompanyNews
    assert legacy_schema.CompanyNewsResponse is shared_schema.CompanyNewsResponse
    assert legacy_schema.CompanyFacts is shared_schema.CompanyFacts
    assert legacy_schema.CompanyFactsResponse is (
        shared_schema.CompanyFactsResponse
    )
    assert legacy_schema.Position is shared_schema.Position
    assert legacy_schema.Portfolio is shared_schema.Portfolio
    assert legacy_schema.AnalystSignal is shared_schema.AnalystSignal
    assert legacy_schema.TickerAnalysis is shared_schema.TickerAnalysis
    assert legacy_schema.AgentStateData is shared_schema.AgentStateData
    assert legacy_schema.AgentStateMetadata is shared_schema.AgentStateMetadata
119
backend/tests/test_skills_cli.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from backend import cli
|
||||
from backend.agents.skill_metadata import parse_skill_metadata
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.team_pipeline_config import (
|
||||
ensure_team_pipeline_config,
|
||||
load_team_pipeline_config,
|
||||
update_active_analysts,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_skill_metadata_extended_frontmatter(tmp_path):
|
||||
skill_dir = tmp_path / "demo_skill"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: demo_skill\n"
|
||||
"description: Demo description\n"
|
||||
"tools:\n"
|
||||
" - technical\n"
|
||||
"---\n\n"
|
||||
"# Demo Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
parsed = parse_skill_metadata(skill_dir, source="builtin")
|
||||
|
||||
assert parsed.skill_name == "demo_skill"
|
||||
assert parsed.description == "Demo description"
|
||||
assert parsed.tools == ["technical"]
|
||||
|
||||
|
||||
def test_update_agent_skill_overrides(tmp_path):
|
||||
manager = SkillsManager(project_root=tmp_path)
|
||||
asset_dir = manager.get_agent_asset_dir("demo", "risk_manager")
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "agent.yaml").write_text(
|
||||
"enabled_skills:\n"
|
||||
" - risk_review\n"
|
||||
"disabled_skills:\n"
|
||||
" - old_skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
result = manager.update_agent_skill_overrides(
|
||||
config_name="demo",
|
||||
agent_id="risk_manager",
|
||||
enable=["extra_guard"],
|
||||
disable=["risk_review"],
|
||||
)
|
||||
|
||||
assert result["enabled_skills"] == ["extra_guard"]
|
||||
assert result["disabled_skills"] == ["old_skill", "risk_review"]
|
||||
|
||||
|
||||
def test_skills_enable_disable_and_list(monkeypatch, tmp_path):
|
||||
builtin_root = tmp_path / "backend" / "skills" / "builtin"
|
||||
for name in ("risk_review", "extra_guard"):
|
||||
skill_dir = builtin_root / name
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
f"---\nname: {name}\ndescription: {name} desc\n---\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
printed = []
|
||||
monkeypatch.setattr(cli, "get_project_root", lambda: tmp_path)
|
||||
monkeypatch.setattr(cli.console, "print", lambda value: printed.append(value))
|
||||
|
||||
cli.skills_enable(agent_id="risk_manager", skill="extra_guard", config_name="demo")
|
||||
cli.skills_disable(agent_id="risk_manager", skill="risk_review", config_name="demo")
|
||||
cli.skills_list(config_name="demo", agent_id="risk_manager")
|
||||
|
||||
text_dump = "\n".join(str(item) for item in printed)
|
||||
assert "Enabled" in text_dump
|
||||
assert "Disabled" in text_dump
|
||||
assert any(getattr(item, "title", None) == "Skill Catalog" for item in printed)
|
||||
|
||||
|
||||
def test_install_external_skill_for_agent(tmp_path):
|
||||
manager = SkillsManager(project_root=tmp_path)
|
||||
skill_dir = tmp_path / "downloaded" / "new_skill"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: new_skill\n"
|
||||
"description: external skill\n"
|
||||
"---\n\n"
|
||||
"# New Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
result = manager.install_external_skill_for_agent(
|
||||
config_name="demo",
|
||||
agent_id="risk_manager",
|
||||
source=str(skill_dir),
|
||||
activate=True,
|
||||
)
|
||||
|
||||
assert result["skill_name"] == "new_skill"
|
||||
target = manager.get_agent_local_root("demo", "risk_manager") / "new_skill"
|
||||
assert target.exists()
|
||||
|
||||
|
||||
def test_team_pipeline_active_analyst_updates(tmp_path):
    project_root = tmp_path
    ensure_team_pipeline_config(
        project_root=project_root,
        config_name="demo",
        default_analysts=["fundamentals_analyst", "technical_analyst"],
    )
    update_active_analysts(
        project_root=project_root,
        config_name="demo",
        available_analysts=["fundamentals_analyst", "technical_analyst"],
        remove=["technical_analyst"],
    )
    config = load_team_pipeline_config(project_root, "demo")
    assert config["discussion"]["active_analysts"] == ["fundamentals_analyst"]
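The config shape implied by the assertion above, written out for orientation; only the discussion.active_analysts key is exercised here, and any other fields of the team pipeline config are not visible in this commit.

# Expected discussion section after removing technical_analyst (illustrative).
expected_after_update = {
    "discussion": {
        "active_analysts": ["fundamentals_analyst"],
    },
}
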
22
backend/tests/test_technical_signals.py
Normal file
22
backend/tests/test_technical_signals.py
Normal file
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""Tests for structured technical analyzer."""

import pandas as pd

from backend.tools.technical_signals import StockTechnicalAnalyzer


def test_technical_analyzer_detects_bullish_trend():
    df = pd.DataFrame(
        {
            "time": pd.date_range("2024-01-01", periods=40, freq="D"),
            "close": [100 + i for i in range(40)],
        },
    )
    analyzer = StockTechnicalAnalyzer()

    result = analyzer.analyze("AAPL", df)

    assert result.current_price == 139.0
    assert result.trend in {"BULLISH", "STRONG BULLISH"}
    assert result.momentum_20d_pct > 0
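The fixture arithmetic behind these assertions, worked out explicitly. The momentum formula below is an assumption (a plain 20-day return); the analyzer's actual calculation may differ, but any reasonable variant is positive on a monotonically rising series.

# Worked example for the fixture above (momentum formula assumed, not asserted).
closes = [100 + i for i in range(40)]                     # 100, 101, ..., 139
current_price = float(closes[-1])                         # 139.0, as asserted
momentum_20d_pct = (closes[-1] / closes[-21] - 1) * 100   # 139 / 119 - 1 -> ~16.8%
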
47
backend/tests/test_trading_domain.py
Normal file
47
backend/tests/test_trading_domain.py
Normal file
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""Unit tests for the trading domain helpers."""

from backend.domains import trading as trading_domain


def test_trading_domain_payload_wrappers(monkeypatch):
    monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}])
    monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}])
    monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
    monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
    monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12)

    assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == {
        "ticker": "AAPL",
        "prices": [{"close": 1}],
    }
    assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
        "financial_metrics": [{"ticker": "AAPL"}],
    }
    assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
        "news": [{"ticker": "AAPL"}],
    }
    assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
        "insider_trades": [{"ticker": "AAPL"}],
    }
    assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
        "ticker": "AAPL",
        "end_date": "2026-03-16",
        "market_cap": 2.5e12,
    }

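These wrappers appear to add nothing beyond a stable payload envelope around the underlying fetchers. A sketch of the implied shape for the prices case, illustrative only; the real backend.domains.trading helper may add validation or date handling.

# Hypothetical wrapper shape consistent with the first assertion above.
def get_prices_payload_sketch(ticker, start_date, end_date):
    return {
        "ticker": ticker,
        "prices": get_prices(ticker, start_date, end_date),  # the fetcher patched above
    }
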
def test_get_market_status_payload_uses_market_service(monkeypatch):
    class _FakeMarketService:
        def __init__(self, tickers):
            self.tickers = tickers

        def get_market_status(self):
            return {"status": "open", "status_text": "Open"}

    monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)

    assert trading_domain.get_market_status_payload() == {
        "status": "open",
        "status_text": "Open",
    }
231
backend/tests/test_trading_service_app.py
Normal file
231
backend/tests/test_trading_service_app.py
Normal file
@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted trading service app surface."""

from fastapi.testclient import TestClient

from backend.apps.trading_service import create_app
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price


def test_trading_service_routes_are_exposed():
    app = create_app()

    paths = {route.path for route in app.routes}

    assert "/health" in paths
    assert "/api/prices" in paths
    assert "/api/financials" in paths
    assert "/api/news" in paths
    assert "/api/insider-trades" in paths
    assert "/api/market/status" in paths
    assert "/api/market-cap" in paths
    assert "/api/line-items" in paths

def test_trading_service_prices_endpoint(monkeypatch):
    monkeypatch.setattr(
        "backend.domains.trading.get_prices_payload",
        lambda ticker, start_date, end_date: {
            "ticker": ticker,
            "prices": [
                Price(
                    open=1.0,
                    close=2.0,
                    high=2.5,
                    low=0.5,
                    volume=100,
                    time="2026-03-20",
                )
            ],
        },
    )

    with TestClient(create_app()) as client:
        response = client.get(
            "/api/prices",
            params={
                "ticker": "AAPL",
                "start_date": "2026-03-01",
                "end_date": "2026-03-20",
            },
        )

    assert response.status_code == 200
    assert response.json()["ticker"] == "AAPL"
    assert response.json()["prices"][0]["close"] == 2.0

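A route of roughly this shape would satisfy the request above. This is a sketch, not the actual backend.apps.trading_service wiring; looking the helper up on the module at request time is what lets the monkeypatch on backend.domains.trading take effect.

# Hypothetical route sketch, for illustration only.
from fastapi import APIRouter

from backend.domains import trading

router = APIRouter()


@router.get("/api/prices")
def read_prices(ticker: str, start_date: str, end_date: str):
    # Delegates to the domain helper, which the test replaces via monkeypatch.
    return trading.get_prices_payload(ticker=ticker, start_date=start_date, end_date=end_date)
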
def test_trading_service_financials_endpoint(monkeypatch):
    monkeypatch.setattr(
        "backend.domains.trading.get_financials_payload",
        lambda ticker, end_date, period, limit: {
            "financial_metrics": [
                FinancialMetrics(
                    ticker=ticker,
                    report_period=end_date,
                    period=period,
                    currency="USD",
                    market_cap=123.0,
                    enterprise_value=None,
                    price_to_earnings_ratio=None,
                    price_to_book_ratio=None,
                    price_to_sales_ratio=None,
                    enterprise_value_to_ebitda_ratio=None,
                    enterprise_value_to_revenue_ratio=None,
                    free_cash_flow_yield=None,
                    peg_ratio=None,
                    gross_margin=None,
                    operating_margin=None,
                    net_margin=None,
                    return_on_equity=None,
                    return_on_assets=None,
                    return_on_invested_capital=None,
                    asset_turnover=None,
                    inventory_turnover=None,
                    receivables_turnover=None,
                    days_sales_outstanding=None,
                    operating_cycle=None,
                    working_capital_turnover=None,
                    current_ratio=None,
                    quick_ratio=None,
                    cash_ratio=None,
                    operating_cash_flow_ratio=None,
                    debt_to_equity=None,
                    debt_to_assets=None,
                    interest_coverage=None,
                    revenue_growth=None,
                    earnings_growth=None,
                    book_value_growth=None,
                    earnings_per_share_growth=None,
                    free_cash_flow_growth=None,
                    operating_income_growth=None,
                    ebitda_growth=None,
                    payout_ratio=None,
                    earnings_per_share=None,
                    book_value_per_share=None,
                    free_cash_flow_per_share=None,
                )
            ]
        },
    )

    with TestClient(create_app()) as client:
        response = client.get(
            "/api/financials",
            params={"ticker": "AAPL", "end_date": "2026-03-20"},
        )

    assert response.status_code == 200
    assert response.json()["financial_metrics"][0]["ticker"] == "AAPL"

def test_trading_service_news_and_insider_endpoints(monkeypatch):
    monkeypatch.setattr(
        "backend.domains.trading.get_news_payload",
        lambda ticker, end_date, start_date=None, limit=1000: {
            "news": [
                CompanyNews(
                    ticker=ticker,
                    title="News title",
                    source="polygon",
                    url="https://example.com/news",
                    date=end_date,
                )
            ]
        },
    )
    monkeypatch.setattr(
        "backend.domains.trading.get_insider_trades_payload",
        lambda ticker, end_date, start_date=None, limit=1000: {
            "insider_trades": [
                InsiderTrade(ticker=ticker, filing_date=end_date)
            ]
        },
    )

    with TestClient(create_app()) as client:
        news_response = client.get(
            "/api/news",
            params={"ticker": "AAPL", "end_date": "2026-03-20"},
        )
        insider_response = client.get(
            "/api/insider-trades",
            params={"ticker": "AAPL", "end_date": "2026-03-20"},
        )

    assert news_response.status_code == 200
    assert news_response.json()["news"][0]["title"] == "News title"
    assert insider_response.status_code == 200
    assert insider_response.json()["insider_trades"][0]["ticker"] == "AAPL"

def test_trading_service_market_status_endpoint(monkeypatch):
    class _FakeMarketService:
        def get_market_status(self):
            return {"status": "open", "status_text": "Open"}

    monkeypatch.setattr(
        "backend.domains.trading.get_market_status_payload",
        lambda: _FakeMarketService().get_market_status(),
    )

    with TestClient(create_app()) as client:
        response = client.get("/api/market/status")

    assert response.status_code == 200
    assert response.json() == {"status": "open", "status_text": "Open"}

def test_trading_service_market_cap_endpoint(monkeypatch):
    monkeypatch.setattr(
        "backend.domains.trading.get_market_cap_payload",
        lambda ticker, end_date: {
            "ticker": ticker,
            "end_date": end_date,
            "market_cap": 3.5e12,
        },
    )

    with TestClient(create_app()) as client:
        response = client.get(
            "/api/market-cap",
            params={"ticker": "AAPL", "end_date": "2026-03-20"},
        )

    assert response.status_code == 200
    assert response.json() == {
        "ticker": "AAPL",
        "end_date": "2026-03-20",
        "market_cap": 3.5e12,
    }

def test_trading_service_line_items_endpoint(monkeypatch):
    monkeypatch.setattr(
        "backend.domains.trading.get_line_items_payload",
        lambda ticker, line_items, end_date, period, limit: {
            "search_results": [
                LineItem(
                    ticker=ticker,
                    report_period=end_date,
                    period=period,
                    currency="USD",
                    free_cash_flow=123.0,
                )
            ]
        },
    )

    with TestClient(create_app()) as client:
        response = client.get(
            "/api/line-items",
            params=[
                ("ticker", "AAPL"),
                ("line_items", "free_cash_flow"),
                ("end_date", "2026-03-20"),
            ],
        )

    assert response.status_code == 200
    assert response.json()["search_results"][0]["ticker"] == "AAPL"
    assert response.json()["search_results"][0]["free_cash_flow"] == 123.0
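The list of ("line_items", ...) tuples above is how TestClient encodes a repeated query parameter. On the server side that usually maps to a list-typed Query parameter, sketched below as a standalone demo rather than the actual /api/line-items handler.

# Hypothetical demo endpoint showing a multi-value query parameter.
from typing import List

from fastapi import FastAPI, Query

demo_app = FastAPI()


@demo_app.get("/line-items-demo")
def read_line_items_demo(line_items: List[str] = Query(...)):
    # Repeated ?line_items=... values arrive as a Python list.
    return {"line_items": line_items}
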
106
backend/tests/test_valuation_scripts.py
Normal file
106
backend/tests/test_valuation_scripts.py
Normal file
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-

from backend.agents.skills_manager import SkillsManager
from backend.skills.builtin.valuation_review.scripts.dcf_report import (
    build_dcf_report,
)
from backend.skills.builtin.valuation_review.scripts.multiple_valuation_report import (
    build_ev_ebitda_report,
    build_residual_income_report,
)
from backend.skills.builtin.valuation_review.scripts.owner_earnings_report import (
    build_owner_earnings_report,
)


def test_build_dcf_report_renders_assessment():
    report = build_dcf_report(
        [
            {
                "ticker": "AAPL",
                "current_fcf": 100.0,
                "growth_rate": 0.05,
                "market_cap": 900.0,
                "discount_rate": 0.10,
                "terminal_growth": 0.03,
                "num_years": 5,
            },
        ],
        "2026-03-17",
    )

    assert "DCF Valuation Analysis (2026-03-17)" in report
    assert "AAPL:" in report
    assert "Market Cap: $900" in report
    assert "Value Gap:" in report

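The test only pattern-matches the rendered text, so the numbers themselves are not pinned down. For orientation, a conventional five-year DCF with the fixture inputs works out as below; the script's actual formula may differ.

# Conventional DCF arithmetic for the fixture above (illustrative only).
fcf, growth, discount, terminal_growth, years = 100.0, 0.05, 0.10, 0.03, 5

pv_explicit = sum(
    fcf * (1 + growth) ** t / (1 + discount) ** t for t in range(1, years + 1)
)                                                            # ~435.8
terminal_fcf = fcf * (1 + growth) ** years                   # ~127.6
terminal_value = terminal_fcf * (1 + terminal_growth) / (discount - terminal_growth)
pv_terminal = terminal_value / (1 + discount) ** years       # ~1166
intrinsic_value = pv_explicit + pv_terminal                  # ~1602
value_gap = intrinsic_value / 900.0 - 1                      # ~0.78 vs. the 900 market cap
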
def test_build_owner_earnings_report_handles_errors():
    report = build_owner_earnings_report(
        [
            {
                "ticker": "MSFT",
                "error": "Negative owner earnings ($-50)",
            },
        ],
        "2026-03-17",
    )

    assert "MSFT: Negative owner earnings ($-50)" in report


def test_multiple_valuation_reports_render_expected_sections():
    ev_report = build_ev_ebitda_report(
        [
            {
                "ticker": "NVDA",
                "current_multiple": 18.0,
                "median_multiple": 20.0,
                "current_ebitda": 50.0,
                "market_cap": 800.0,
                "net_debt": 100.0,
            },
        ],
        "2026-03-17",
    )
    residual_report = build_residual_income_report(
        [
            {
                "ticker": "META",
                "book_value": 200.0,
                "initial_ri": 30.0,
                "market_cap": 300.0,
                "cost_of_equity": 0.10,
                "bv_growth": 0.03,
                "terminal_growth": 0.03,
                "num_years": 5,
                "margin_of_safety": 0.20,
            },
        ],
        "2026-03-17",
    )

    assert "EV/EBITDA Valuation (2026-03-17)" in ev_report
    assert "NVDA:" in ev_report
    assert "Residual Income Valuation (2026-03-17)" in residual_report
    assert "META:" in residual_report


def test_prepare_active_skills_copies_skill_scripts(tmp_path):
    builtin_skill = tmp_path / "backend" / "skills" / "builtin" / "valuation_review"
    scripts_dir = builtin_skill / "scripts"
    scripts_dir.mkdir(parents=True, exist_ok=True)
    (builtin_skill / "SKILL.md").write_text(
        "---\nname: 估值分析\ndescription: desc\nversion: 1.0.0\n---\n",
        encoding="utf-8",
    )
    (scripts_dir / "dcf_report.py").write_text("print('ok')\n", encoding="utf-8")

    manager = SkillsManager(project_root=tmp_path)
    active_map = manager.prepare_active_skills(
        config_name="demo",
        agent_defaults={"valuation_analyst": ["valuation_review"]},
    )

    active_dir = active_map["valuation_analyst"][0]
    assert (active_dir / "scripts" / "dcf_report.py").exists()