feat: initial commit - EvoTraders project
量化交易多智能体系统,包含: - 分析师、投资组合经理、风险经理等智能体 - 股票分析、投资组合管理、风险控制工具 - React 前端界面 - FastAPI 后端服务 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
580
backend/tests/test_agents.py
Normal file
580
backend/tests/test_agents.py
Normal file
@@ -0,0 +1,580 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=W0212
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from agentscope.message import Msg
|
||||
|
||||
|
||||
class TestAnalystAgent:
|
||||
def test_init_valid_analyst_type(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="technical_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.analyst_type_key == "technical_analyst"
|
||||
assert agent.name == "technical_analyst_analyst"
|
||||
assert agent.analyst_persona == "Technical Analyst"
|
||||
|
||||
def test_init_invalid_analyst_type(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AnalystAgent(
|
||||
analyst_type="invalid_type",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert "Unknown analyst type" in str(excinfo.value)
|
||||
|
||||
def test_init_custom_agent_id(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="fundamentals_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
agent_id="custom_analyst_id",
|
||||
)
|
||||
|
||||
assert agent.name == "custom_analyst_id"
|
||||
|
||||
def test_load_system_prompt(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="sentiment_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
prompt = agent._load_system_prompt()
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 0
|
||||
|
||||
|
||||
class TestPMAgent:
|
||||
def test_init_default(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.name == "portfolio_manager"
|
||||
assert agent.portfolio["cash"] == 100000.0
|
||||
assert agent.portfolio["positions"] == {}
|
||||
assert agent.portfolio["margin_requirement"] == 0.25
|
||||
|
||||
def test_init_custom_cash(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=50000.0,
|
||||
margin_requirement=0.5,
|
||||
)
|
||||
|
||||
assert agent.portfolio["cash"] == 50000.0
|
||||
assert agent.portfolio["margin_requirement"] == 0.5
|
||||
|
||||
def test_get_portfolio_state(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=75000.0,
|
||||
)
|
||||
|
||||
state = agent.get_portfolio_state()
|
||||
|
||||
assert state["cash"] == 75000.0
|
||||
assert state is not agent.portfolio # Should be a copy
|
||||
|
||||
def test_load_portfolio_state(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
new_portfolio = {
|
||||
"cash": 50000.0,
|
||||
"positions": {
|
||||
"AAPL": {"long": 100, "short": 0, "long_cost_basis": 150.0},
|
||||
},
|
||||
"margin_used": 1000.0,
|
||||
}
|
||||
|
||||
agent.load_portfolio_state(new_portfolio)
|
||||
|
||||
assert agent.portfolio["cash"] == 50000.0
|
||||
assert "AAPL" in agent.portfolio["positions"]
|
||||
|
||||
def test_update_portfolio(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
agent.update_portfolio({"cash": 80000.0})
|
||||
assert agent.portfolio["cash"] == 80000.0
|
||||
|
||||
def _get_text_from_tool_response(self, result):
|
||||
"""Helper to extract text from ToolResponse content"""
|
||||
content = result.content[0]
|
||||
if hasattr(content, "text"):
|
||||
return content.text
|
||||
elif isinstance(content, dict):
|
||||
return content.get("text", "")
|
||||
return str(content)
|
||||
|
||||
def test_make_decision_long(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="AAPL",
|
||||
action="long",
|
||||
quantity=100,
|
||||
confidence=80,
|
||||
reasoning="Strong fundamentals",
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Decision recorded" in text
|
||||
assert agent._decisions["AAPL"]["action"] == "long"
|
||||
assert agent._decisions["AAPL"]["quantity"] == 100
|
||||
|
||||
def test_make_decision_hold(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="GOOGL",
|
||||
action="hold",
|
||||
quantity=0,
|
||||
confidence=50,
|
||||
reasoning="Neutral outlook",
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Decision recorded" in text
|
||||
assert agent._decisions["GOOGL"]["action"] == "hold"
|
||||
assert agent._decisions["GOOGL"]["quantity"] == 0
|
||||
|
||||
def test_make_decision_invalid_action(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="AAPL",
|
||||
action="invalid",
|
||||
quantity=10,
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Invalid action" in text
|
||||
|
||||
def test_get_decisions(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
agent._make_decision("AAPL", "long", 100)
|
||||
agent._make_decision("GOOGL", "short", 50)
|
||||
|
||||
decisions = agent.get_decisions()
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"]["action"] == "long"
|
||||
assert decisions["GOOGL"]["action"] == "short"
|
||||
|
||||
|
||||
class TestRiskAgent:
|
||||
def test_init_default(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.name == "risk_manager"
|
||||
|
||||
def test_init_custom_name(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
name="custom_risk_manager",
|
||||
)
|
||||
|
||||
assert agent.name == "custom_risk_manager"
|
||||
|
||||
def test_load_system_prompt(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
prompt = agent._load_system_prompt()
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 0
|
||||
|
||||
|
||||
class TestStorageService:
|
||||
def test_calculate_portfolio_value_cash_only(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {"cash": 100000.0, "positions": {}, "margin_used": 0.0}
|
||||
prices = {}
|
||||
|
||||
value = storage.calculate_portfolio_value(portfolio, prices)
|
||||
assert value == 100000.0
|
||||
|
||||
def test_calculate_portfolio_value_with_positions(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {
|
||||
"cash": 50000.0,
|
||||
"positions": {
|
||||
"AAPL": {"long": 100, "short": 0},
|
||||
"GOOGL": {"long": 0, "short": 10},
|
||||
},
|
||||
"margin_used": 5000.0,
|
||||
}
|
||||
prices = {"AAPL": 150.0, "GOOGL": 100.0}
|
||||
|
||||
value = storage.calculate_portfolio_value(portfolio, prices)
|
||||
assert value == 69000.0
|
||||
|
||||
def test_update_dashboard_after_cycle(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {
|
||||
"cash": 90000.0,
|
||||
"positions": {"AAPL": {"long": 50, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
}
|
||||
prices = {"AAPL": 200.0}
|
||||
|
||||
storage.update_dashboard_after_cycle(
|
||||
portfolio=portfolio,
|
||||
prices=prices,
|
||||
date="2024-01-15",
|
||||
executed_trades=[
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"action": "long",
|
||||
"quantity": 50,
|
||||
"price": 200.0,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
summary = storage.load_file("summary")
|
||||
assert summary is not None
|
||||
assert summary["totalAssetValue"] == 100000.0 # 90000 + 50*200
|
||||
|
||||
holdings = storage.load_file("holdings")
|
||||
assert holdings is not None
|
||||
assert len(holdings) > 0
|
||||
|
||||
trades = storage.load_file("trades")
|
||||
assert trades is not None
|
||||
assert len(trades) == 1
|
||||
assert trades[0]["ticker"] == "AAPL"
|
||||
assert trades[0]["qty"] == 50
|
||||
assert trades[0]["price"] == 200.0
|
||||
|
||||
def test_generate_summary(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
state = {
|
||||
"portfolio_state": {
|
||||
"cash": 50000.0,
|
||||
"positions": {"AAPL": {"long": 100, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
"equity_history": [{"t": 1000, "v": 100000}],
|
||||
"all_trades": [],
|
||||
}
|
||||
prices = {"AAPL": 500.0}
|
||||
|
||||
storage._generate_summary(state, 100000.0, prices)
|
||||
|
||||
summary = storage.load_file("summary")
|
||||
assert summary["totalAssetValue"] == 100000.0
|
||||
assert summary["totalReturn"] == 0.0
|
||||
|
||||
def test_generate_holdings(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
state = {
|
||||
"portfolio_state": {
|
||||
"cash": 50000.0,
|
||||
"positions": {"AAPL": {"long": 100, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
}
|
||||
prices = {"AAPL": 500.0}
|
||||
|
||||
storage._generate_holdings(state, prices)
|
||||
|
||||
holdings = storage.load_file("holdings")
|
||||
assert len(holdings) == 2 # AAPL + CASH
|
||||
|
||||
aapl_holding = next(
|
||||
(h for h in holdings if h["ticker"] == "AAPL"),
|
||||
None,
|
||||
)
|
||||
assert aapl_holding is not None
|
||||
assert aapl_holding["quantity"] == 100
|
||||
assert aapl_holding["currentPrice"] == 500.0
|
||||
|
||||
|
||||
class TestTradeExecutor:
|
||||
def test_execute_trade_long(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor(
|
||||
initial_portfolio={
|
||||
"cash": 100000.0,
|
||||
"positions": {},
|
||||
"margin_requirement": 0.25,
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="long",
|
||||
quantity=10,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert executor.portfolio["positions"]["AAPL"]["long"] == 10
|
||||
assert executor.portfolio["cash"] == 98500.0 # 100000 - 10*150
|
||||
|
||||
def test_execute_trade_short(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor(
|
||||
initial_portfolio={
|
||||
"cash": 100000.0,
|
||||
"positions": {
|
||||
"AAPL": {
|
||||
"long": 50,
|
||||
"short": 0,
|
||||
"long_cost_basis": 100.0,
|
||||
"short_cost_basis": 0.0,
|
||||
},
|
||||
},
|
||||
"margin_requirement": 0.25,
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="short",
|
||||
quantity=30,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert executor.portfolio["positions"]["AAPL"]["long"] == 20 # 50 - 30
|
||||
|
||||
def test_execute_trade_hold(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor()
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="hold",
|
||||
quantity=0,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["message"] == "No trade needed"
|
||||
|
||||
|
||||
class TestPipelineExecution:
|
||||
def test_execute_decisions(self):
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
pm = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
pipeline = TradingPipeline(
|
||||
analysts=[],
|
||||
risk_manager=MagicMock(),
|
||||
portfolio_manager=pm,
|
||||
max_comm_cycles=0,
|
||||
)
|
||||
|
||||
decisions = {
|
||||
"AAPL": {"action": "long", "quantity": 10},
|
||||
"GOOGL": {"action": "short", "quantity": 5},
|
||||
}
|
||||
prices = {"AAPL": 150.0, "GOOGL": 100.0}
|
||||
|
||||
result = pipeline._execute_decisions(decisions, prices, "2024-01-15")
|
||||
|
||||
assert len(result["executed_trades"]) == 2
|
||||
assert result["executed_trades"][0]["ticker"] == "AAPL"
|
||||
assert result["executed_trades"][0]["quantity"] == 10
|
||||
assert pm.portfolio["positions"]["AAPL"]["long"] == 10
|
||||
|
||||
|
||||
class TestMsgContentIsString:
|
||||
def test_msg_content_string(self):
|
||||
msg = Msg(name="test", content="simple string", role="user")
|
||||
assert isinstance(msg.content, str)
|
||||
|
||||
def test_msg_content_json_string(self):
|
||||
data = {"key": "value", "nested": {"a": 1}}
|
||||
msg = Msg(name="test", content=json.dumps(data), role="user")
|
||||
assert isinstance(msg.content, str)
|
||||
|
||||
parsed = json.loads(msg.content)
|
||||
assert parsed["key"] == "value"
|
||||
|
||||
def test_msg_content_should_not_be_dict(self):
|
||||
data = {"key": "value"}
|
||||
msg = Msg(name="test", content=json.dumps(data), role="assistant")
|
||||
|
||||
assert not isinstance(msg.content, dict)
|
||||
assert isinstance(msg.content, str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
438
backend/tests/test_market_service.py
Normal file
438
backend/tests/test_market_service.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=W0212
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import pytest
|
||||
from backend.services.market import MarketService
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
|
||||
class TestMockPriceManager:
|
||||
def test_init_default(self):
|
||||
manager = MockPriceManager()
|
||||
|
||||
assert manager.poll_interval == 10
|
||||
assert manager.volatility == 0.5
|
||||
assert manager.running is False
|
||||
assert len(manager.subscribed_symbols) == 0
|
||||
|
||||
def test_init_custom(self):
|
||||
manager = MockPriceManager(poll_interval=5, volatility=1.0)
|
||||
|
||||
assert manager.poll_interval == 5
|
||||
assert manager.volatility == 1.0
|
||||
|
||||
def test_subscribe(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL", "MSFT"])
|
||||
|
||||
assert "AAPL" in manager.subscribed_symbols
|
||||
assert "MSFT" in manager.subscribed_symbols
|
||||
assert manager.base_prices["AAPL"] == 237.50 # default price
|
||||
assert manager.base_prices["MSFT"] == 425.30 # default price
|
||||
|
||||
def test_subscribe_with_base_prices(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
assert manager.base_prices["AAPL"] == 100.0
|
||||
assert manager.open_prices["AAPL"] == 100.0
|
||||
assert manager.latest_prices["AAPL"] == 100.0
|
||||
|
||||
def test_subscribe_unknown_symbol(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["UNKNOWN"])
|
||||
|
||||
assert "UNKNOWN" in manager.subscribed_symbols
|
||||
assert manager.base_prices["UNKNOWN"] > 0 # random price generated
|
||||
|
||||
def test_unsubscribe(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL", "MSFT"])
|
||||
manager.unsubscribe(["AAPL"])
|
||||
|
||||
assert "AAPL" not in manager.subscribed_symbols
|
||||
assert "MSFT" in manager.subscribed_symbols
|
||||
|
||||
def test_add_price_callback(self):
|
||||
manager = MockPriceManager()
|
||||
callback = MagicMock()
|
||||
manager.add_price_callback(callback)
|
||||
|
||||
assert callback in manager.price_callbacks
|
||||
|
||||
def test_generate_price_update_within_bounds(self):
|
||||
manager = MockPriceManager(volatility=0.5)
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
for _ in range(100):
|
||||
new_price = manager._generate_price_update("AAPL")
|
||||
# Should be within +/-10% of open
|
||||
assert 90.0 <= new_price <= 110.0
|
||||
|
||||
def test_update_prices_triggers_callback(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
callback = MagicMock()
|
||||
manager.add_price_callback(callback)
|
||||
|
||||
manager._update_prices()
|
||||
|
||||
callback.assert_called_once()
|
||||
call_args = callback.call_args[0][0]
|
||||
assert call_args["symbol"] == "AAPL"
|
||||
assert "price" in call_args
|
||||
assert "timestamp" in call_args
|
||||
|
||||
def test_start_stop(self):
|
||||
manager = MockPriceManager(poll_interval=1)
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
manager.start()
|
||||
assert manager.running is True
|
||||
|
||||
time.sleep(0.1) # let thread start
|
||||
|
||||
manager.stop()
|
||||
assert manager.running is False
|
||||
|
||||
def test_start_without_subscription(self):
|
||||
manager = MockPriceManager()
|
||||
manager.start()
|
||||
|
||||
assert (
|
||||
manager.running is False
|
||||
) # should not start without subscriptions
|
||||
|
||||
def test_get_latest_price(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
price = manager.get_latest_price("AAPL")
|
||||
assert price == 100.0
|
||||
|
||||
def test_get_latest_price_unknown(self):
|
||||
manager = MockPriceManager()
|
||||
price = manager.get_latest_price("UNKNOWN")
|
||||
assert price is None
|
||||
|
||||
def test_get_all_latest_prices(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(
|
||||
["AAPL", "MSFT"],
|
||||
base_prices={"AAPL": 100.0, "MSFT": 200.0},
|
||||
)
|
||||
|
||||
prices = manager.get_all_latest_prices()
|
||||
assert prices["AAPL"] == 100.0
|
||||
assert prices["MSFT"] == 200.0
|
||||
|
||||
def test_reset_open_prices(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
manager.latest_prices["AAPL"] = 105.0
|
||||
|
||||
manager.reset_open_prices()
|
||||
|
||||
# Open price should change (based on latest with small gap)
|
||||
assert manager.open_prices["AAPL"] != 100.0
|
||||
|
||||
def test_set_base_price(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
manager.set_base_price("AAPL", 150.0)
|
||||
|
||||
assert manager.base_prices["AAPL"] == 150.0
|
||||
assert manager.open_prices["AAPL"] == 150.0
|
||||
assert manager.latest_prices["AAPL"] == 150.0
|
||||
|
||||
|
||||
class TestPollingPriceManager:
|
||||
def test_init(self):
|
||||
manager = PollingPriceManager(api_key="test_key", poll_interval=30)
|
||||
|
||||
assert manager.api_key == "test_key"
|
||||
assert manager.poll_interval == 30
|
||||
assert manager.running is False
|
||||
|
||||
def test_subscribe(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.subscribe(["AAPL", "MSFT"])
|
||||
|
||||
assert "AAPL" in manager.subscribed_symbols
|
||||
assert "MSFT" in manager.subscribed_symbols
|
||||
|
||||
def test_unsubscribe(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.subscribe(["AAPL", "MSFT"])
|
||||
manager.unsubscribe(["AAPL"])
|
||||
|
||||
assert "AAPL" not in manager.subscribed_symbols
|
||||
assert "MSFT" in manager.subscribed_symbols
|
||||
|
||||
def test_add_price_callback(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
callback = MagicMock()
|
||||
manager.add_price_callback(callback)
|
||||
|
||||
assert callback in manager.price_callbacks
|
||||
|
||||
@patch.object(PollingPriceManager, "_fetch_prices")
|
||||
def test_start_stop(self):
|
||||
manager = PollingPriceManager(api_key="test_key", poll_interval=1)
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
manager.start()
|
||||
assert manager.running is True
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
manager.stop()
|
||||
assert manager.running is False
|
||||
|
||||
def test_start_without_subscription(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.start()
|
||||
|
||||
assert manager.running is False
|
||||
|
||||
def test_get_latest_price(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.latest_prices["AAPL"] = 150.0
|
||||
|
||||
price = manager.get_latest_price("AAPL")
|
||||
assert price == 150.0
|
||||
|
||||
def test_get_open_price(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.open_prices["AAPL"] = 148.0
|
||||
|
||||
price = manager.get_open_price("AAPL")
|
||||
assert price == 148.0
|
||||
|
||||
def test_reset_open_prices(self):
|
||||
manager = PollingPriceManager(api_key="test_key")
|
||||
manager.open_prices["AAPL"] = 150.0
|
||||
|
||||
manager.reset_open_prices()
|
||||
|
||||
assert len(manager.open_prices) == 0
|
||||
|
||||
|
||||
class TestMarketService:
|
||||
def test_init_mock_mode(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL", "MSFT"],
|
||||
poll_interval=10,
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
assert service.tickers == ["AAPL", "MSFT"]
|
||||
assert service.poll_interval == 10
|
||||
assert service.mock_mode is True
|
||||
assert service.running is False
|
||||
|
||||
def test_init_real_mode(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=False,
|
||||
api_key="test_key",
|
||||
)
|
||||
|
||||
assert service.mock_mode is False
|
||||
assert service.api_key == "test_key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_mock_mode(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=10,
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
broadcast_func = AsyncMock()
|
||||
|
||||
await service.start(broadcast_func)
|
||||
|
||||
assert service.running is True
|
||||
assert service._price_manager is not None
|
||||
assert isinstance(service._price_manager, MockPriceManager)
|
||||
|
||||
service.stop()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_real_mode_without_api_key(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=False,
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
broadcast_func = AsyncMock()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
await service.start(broadcast_func)
|
||||
|
||||
assert "API key required" in str(excinfo.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_already_running(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
broadcast_func = AsyncMock()
|
||||
|
||||
await service.start(broadcast_func)
|
||||
assert service.running is True
|
||||
|
||||
# Start again should not fail
|
||||
await service.start(broadcast_func)
|
||||
|
||||
service.stop()
|
||||
|
||||
def test_stop(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=True,
|
||||
)
|
||||
service.running = True
|
||||
service._price_manager = MagicMock()
|
||||
|
||||
service.stop()
|
||||
|
||||
assert service.running is False
|
||||
assert service._price_manager is None
|
||||
|
||||
def test_stop_when_not_running(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
service.stop()
|
||||
assert service.running is False
|
||||
|
||||
def test_get_price_sync(self):
|
||||
service = MarketService(tickers=["AAPL"], mock_mode=True)
|
||||
service.cache["AAPL"] = {"price": 150.0, "open": 148.0}
|
||||
|
||||
price = service.get_price_sync("AAPL")
|
||||
assert price == 150.0
|
||||
|
||||
def test_get_price_sync_not_found(self):
|
||||
service = MarketService(tickers=["AAPL"], mock_mode=True)
|
||||
|
||||
price = service.get_price_sync("MSFT")
|
||||
assert price is None
|
||||
|
||||
def test_get_all_prices(self):
|
||||
service = MarketService(tickers=["AAPL", "MSFT"], mock_mode=True)
|
||||
service.cache["AAPL"] = {"price": 150.0}
|
||||
service.cache["MSFT"] = {"price": 400.0}
|
||||
|
||||
prices = service.get_all_prices()
|
||||
|
||||
assert prices["AAPL"] == 150.0
|
||||
assert prices["MSFT"] == 400.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_price_update(self):
|
||||
service = MarketService(tickers=["AAPL"], mock_mode=True)
|
||||
service._broadcast_func = AsyncMock()
|
||||
|
||||
price_data = {
|
||||
"symbol": "AAPL",
|
||||
"price": 150.0,
|
||||
"open": 148.0,
|
||||
"timestamp": 1234567890,
|
||||
}
|
||||
|
||||
await service._broadcast_price_update(price_data)
|
||||
|
||||
service._broadcast_func.assert_called_once()
|
||||
call_args = service._broadcast_func.call_args[0][0]
|
||||
assert call_args["type"] == "price_update"
|
||||
assert call_args["symbol"] == "AAPL"
|
||||
assert call_args["price"] == 150.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_price_update_no_func(self):
|
||||
service = MarketService(tickers=["AAPL"], mock_mode=True)
|
||||
service._broadcast_func = None
|
||||
|
||||
price_data = {"symbol": "AAPL", "price": 150.0, "open": 148.0}
|
||||
|
||||
# Should not raise
|
||||
await service._broadcast_price_update(price_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_callback_thread_safety(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=1,
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
received_prices = []
|
||||
|
||||
async def capture_broadcast(msg):
|
||||
received_prices.append(msg)
|
||||
|
||||
await service.start(capture_broadcast)
|
||||
|
||||
# Wait for at least one price update
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
service.stop()
|
||||
|
||||
# Should have received at least one price update
|
||||
assert len(received_prices) >= 1
|
||||
assert received_prices[0]["type"] == "price_update"
|
||||
|
||||
|
||||
class TestMarketServiceIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_mock_cycle(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL", "MSFT"],
|
||||
poll_interval=1,
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
messages = []
|
||||
|
||||
async def collect_messages(msg):
|
||||
messages.append(msg)
|
||||
|
||||
await service.start(collect_messages)
|
||||
|
||||
# Wait for price updates
|
||||
await asyncio.sleep(2.5)
|
||||
|
||||
service.stop()
|
||||
|
||||
# Should have received multiple price updates
|
||||
assert len(messages) >= 2
|
||||
|
||||
# Check message structure
|
||||
symbols_seen = set()
|
||||
for msg in messages:
|
||||
assert msg["type"] == "price_update"
|
||||
assert "symbol" in msg
|
||||
assert "price" in msg
|
||||
assert "ret" in msg
|
||||
symbols_seen.add(msg["symbol"])
|
||||
|
||||
# Should have prices for both tickers
|
||||
assert "AAPL" in symbols_seen or "MSFT" in symbols_seen
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
201
backend/tests/test_settlement.py
Normal file
201
backend/tests/test_settlement.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Test Settlement Coordinator and Baseline Calculations
|
||||
"""
|
||||
|
||||
from backend.utils.baselines import (
|
||||
BaselineCalculator,
|
||||
calculate_momentum_scores,
|
||||
)
|
||||
from backend.utils.analyst_tracker import (
|
||||
AnalystPerformanceTracker,
|
||||
update_leaderboard_with_evaluations,
|
||||
)
|
||||
|
||||
|
||||
def test_baseline_equal_weight():
|
||||
"""Test equal weight baseline calculation"""
|
||||
calculator = BaselineCalculator(initial_capital=100000.0)
|
||||
|
||||
tickers = ["AAPL", "MSFT", "GOOGL"]
|
||||
prices = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
|
||||
openprices = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}
|
||||
value = calculator.calculate_equal_weight_value(
|
||||
tickers,
|
||||
openprices,
|
||||
prices,
|
||||
)
|
||||
|
||||
assert value > 0
|
||||
assert calculator.equal_weight_initialized is True
|
||||
|
||||
|
||||
def test_baseline_market_cap_weighted():
|
||||
"""Test market cap weighted baseline calculation"""
|
||||
calculator = BaselineCalculator(initial_capital=100000.0)
|
||||
|
||||
tickers = ["AAPL", "MSFT", "GOOGL"]
|
||||
prices = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
|
||||
openprices = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}
|
||||
market_caps = {"AAPL": 3e12, "MSFT": 2e12, "GOOGL": 1.5e12}
|
||||
|
||||
value = calculator.calculate_market_cap_weighted_value(
|
||||
tickers,
|
||||
openprices,
|
||||
prices,
|
||||
market_caps,
|
||||
)
|
||||
|
||||
assert value > 0
|
||||
assert calculator.market_cap_initialized is True
|
||||
|
||||
|
||||
def test_momentum_scores():
|
||||
"""Test momentum score calculation"""
|
||||
tickers = ["AAPL", "MSFT"]
|
||||
prices_history = {
|
||||
"AAPL": [
|
||||
("2024-01-01", 100.0),
|
||||
("2024-01-02", 105.0),
|
||||
("2024-01-03", 110.0),
|
||||
],
|
||||
"MSFT": [
|
||||
("2024-01-01", 200.0),
|
||||
("2024-01-02", 195.0),
|
||||
("2024-01-03", 190.0),
|
||||
],
|
||||
}
|
||||
|
||||
scores = calculate_momentum_scores(
|
||||
tickers,
|
||||
prices_history,
|
||||
lookback_days=2,
|
||||
)
|
||||
|
||||
assert scores["AAPL"] > 0
|
||||
assert scores["MSFT"] < 0
|
||||
|
||||
|
||||
def test_analyst_tracker_predictions():
|
||||
"""Test analyst prediction recording with structured format"""
|
||||
tracker = AnalystPerformanceTracker()
|
||||
|
||||
final_predictions = [
|
||||
{
|
||||
"agent": "technical_analyst",
|
||||
"predictions": [
|
||||
{"ticker": "AAPL", "direction": "up", "confidence": 0.8},
|
||||
{"ticker": "MSFT", "direction": "down", "confidence": 0.7},
|
||||
{"ticker": "GOOGL", "direction": "neutral", "confidence": 0.5},
|
||||
],
|
||||
},
|
||||
{
|
||||
"agent": "fundamentals_analyst",
|
||||
"predictions": [
|
||||
{"ticker": "AAPL", "direction": "up", "confidence": 0.9},
|
||||
{"ticker": "MSFT", "direction": "up", "confidence": 0.6},
|
||||
{"ticker": "GOOGL", "direction": "down", "confidence": 0.75},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
tracker.record_analyst_predictions(final_predictions)
|
||||
|
||||
assert "technical_analyst" in tracker.daily_predictions
|
||||
assert "fundamentals_analyst" in tracker.daily_predictions
|
||||
assert tracker.daily_predictions["technical_analyst"]["AAPL"] == "long"
|
||||
assert tracker.daily_predictions["technical_analyst"]["MSFT"] == "short"
|
||||
assert tracker.daily_predictions["technical_analyst"]["GOOGL"] == "hold"
|
||||
|
||||
|
||||
def test_analyst_evaluation():
|
||||
"""Test analyst prediction evaluation"""
|
||||
tracker = AnalystPerformanceTracker()
|
||||
|
||||
tracker.daily_predictions = {
|
||||
"technical_analyst": {
|
||||
"AAPL": "long",
|
||||
"MSFT": "short",
|
||||
},
|
||||
}
|
||||
|
||||
open_prices = {"AAPL": 100.0, "MSFT": 200.0}
|
||||
close_prices = {"AAPL": 105.0, "MSFT": 195.0}
|
||||
|
||||
evaluations = tracker.evaluate_predictions(
|
||||
open_prices,
|
||||
close_prices,
|
||||
"2024-01-15",
|
||||
)
|
||||
|
||||
assert "technical_analyst" in evaluations
|
||||
eval_result = evaluations["technical_analyst"]
|
||||
assert eval_result["correct_predictions"] == 2
|
||||
assert eval_result["win_rate"] == 1.0
|
||||
|
||||
# Verify individual signals format
|
||||
assert "signals" in eval_result
|
||||
assert len(eval_result["signals"]) == 2
|
||||
for signal in eval_result["signals"]:
|
||||
assert "ticker" in signal
|
||||
assert "signal" in signal
|
||||
assert "date" in signal
|
||||
assert "is_correct" in signal
|
||||
assert signal["date"] == "2024-01-15"
|
||||
|
||||
|
||||
def test_leaderboard_update():
|
||||
"""Test leaderboard update with evaluations"""
|
||||
leaderboard = [
|
||||
{
|
||||
"agentId": "technical_analyst",
|
||||
"name": "Technical Analyst",
|
||||
"rank": 0,
|
||||
"winRate": None,
|
||||
"bull": {"n": 0, "win": 0, "unknown": 0},
|
||||
"bear": {"n": 0, "win": 0, "unknown": 0},
|
||||
"signals": [],
|
||||
},
|
||||
]
|
||||
|
||||
evaluations = {
|
||||
"technical_analyst": {
|
||||
"total_predictions": 2,
|
||||
"correct_predictions": 1,
|
||||
"win_rate": 0.5,
|
||||
"bull": {"n": 1, "win": 1, "unknown": 0},
|
||||
"bear": {"n": 1, "win": 0, "unknown": 0},
|
||||
"hold": 0,
|
||||
"signals": [
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"signal": "bull",
|
||||
"date": "2024-01-01",
|
||||
"is_correct": True,
|
||||
},
|
||||
{
|
||||
"ticker": "MSFT",
|
||||
"signal": "bear",
|
||||
"date": "2024-01-01",
|
||||
"is_correct": False,
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
updated = update_leaderboard_with_evaluations(
|
||||
leaderboard,
|
||||
evaluations,
|
||||
)
|
||||
|
||||
assert updated[0]["bull"]["n"] == 1
|
||||
assert updated[0]["bull"]["win"] == 1
|
||||
assert updated[0]["winRate"] == 0.5
|
||||
assert len(updated[0]["signals"]) == 2
|
||||
|
||||
# Verify signal format matches frontend expectations
|
||||
for signal in updated[0]["signals"]:
|
||||
assert "ticker" in signal
|
||||
assert "signal" in signal
|
||||
assert "date" in signal
|
||||
assert "is_correct" in signal
|
||||
Reference in New Issue
Block a user