feat: initial commit - EvoTraders project
量化交易多智能体系统,包含: - 分析师、投资组合经理、风险经理等智能体 - 股票分析、投资组合管理、风险控制工具 - React 前端界面 - FastAPI 后端服务 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
0
backend/__init__.py
Normal file
0
backend/__init__.py
Normal file
6
backend/agents/__init__.py
Normal file
6
backend/agents/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .analyst import AnalystAgent
|
||||
from .portfolio_manager import PMAgent
|
||||
from .risk_manager import RiskAgent
|
||||
|
||||
__all__ = ["AnalystAgent", "PMAgent", "RiskAgent"]
|
||||
133
backend/agents/analyst.py
Normal file
133
backend/agents/analyst.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Analyst Agent - Based on AgentScope ReActAgent
|
||||
Performs analysis using tools and LLM
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||
from agentscope.message import Msg
|
||||
|
||||
from ..config.constants import ANALYST_TYPES
|
||||
from ..utils.progress import progress
|
||||
from .prompt_loader import PromptLoader
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
|
||||
class AnalystAgent(ReActAgent):
    """
    Analyst Agent - Uses LLM for tool selection and analysis

    Inherits from AgentScope's ReActAgent
    """

    def __init__(
        self,
        analyst_type: str,
        toolkit: Any,
        model: Any,
        formatter: Any,
        agent_id: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
        long_term_memory: Optional[LongTermMemoryBase] = None,
    ):
        """
        Initialize Analyst Agent

        Args:
            analyst_type: Type of analyst (e.g., "fundamentals", etc.);
                must be a key of ANALYST_TYPES
            toolkit: AgentScope Toolkit instance
            model: LLM model instance
            formatter: Message formatter instance
            agent_id: Agent ID (defaults to analyst_type itself)
            config: Configuration dictionary
            long_term_memory: Optional ReMeTaskLongTermMemory instance

        Raises:
            ValueError: If analyst_type is not a known ANALYST_TYPES key.
        """
        if analyst_type not in ANALYST_TYPES:
            raise ValueError(
                f"Unknown analyst type: {analyst_type}. "
                f"Must be one of: {list(ANALYST_TYPES.keys())}",
            )

        # NOTE: these attributes are set before super().__init__() because
        # _load_system_prompt() (called below) reads them.
        self.analyst_type_key = analyst_type
        self.analyst_persona = ANALYST_TYPES[analyst_type]["display_name"]

        if agent_id is None:
            agent_id = analyst_type

        self.config = config or {}

        sys_prompt = self._load_system_prompt()

        kwargs = {
            "name": agent_id,
            "sys_prompt": sys_prompt,
            "model": model,
            "formatter": formatter,
            "toolkit": toolkit,
            "memory": InMemoryMemory(),
            "max_iters": 10,  # cap on ReAct reasoning/tool-call iterations
        }
        if long_term_memory:
            kwargs["long_term_memory"] = long_term_memory
            # "static_control": long-term memory use is driven by the
            # framework rather than by the agent itself
            kwargs["long_term_memory_mode"] = "static_control"

        super().__init__(**kwargs)

    def _load_system_prompt(self) -> str:
        """Load system prompt for analyst.

        Composes the persona entry from ``analyst/personas.yaml`` (focus
        bullet list + description) into the ``analyst/system.md`` template.
        """
        personas_config = _prompt_loader.load_yaml_config(
            "analyst",
            "personas",
        )
        persona = personas_config.get(self.analyst_type_key, {})

        # Get focus items and format as bullet points
        focus_items = persona.get("focus", [])
        focus_text = "\n".join(f"- {item}" for item in focus_items)

        # Get description
        description = persona.get("description", "").strip()

        return _prompt_loader.load_prompt(
            "analyst",
            "system",
            variables={
                "analyst_type": self.analyst_persona,
                "focus": focus_text,
                "description": description,
            },
        )

    async def reply(self, x: Optional[Msg] = None) -> Msg:
        """
        Override reply method to add progress tracking

        Args:
            x: Input message (content must be str)

        Returns:
            Response message (content is str)
        """
        # NOTE(review): metadata key is "tickers" (plural) but the value is
        # passed to progress.update_status as a single ticker - confirm the
        # producer stores a single symbol under this key.
        ticker = None
        if x and hasattr(x, "metadata") and x.metadata:
            ticker = x.metadata.get("tickers")

        if ticker:
            progress.update_status(
                self.name,
                ticker,
                f"Starting {self.analyst_persona} analysis",
            )

        result = await super().reply(x)

        if ticker:
            progress.update_status(
                self.name,
                ticker,
                "Analysis completed",
            )

        return result
|
||||
188
backend/agents/portfolio_manager.py
Normal file
188
backend/agents/portfolio_manager.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Portfolio Manager Agent - Based on AgentScope ReActAgent
|
||||
Responsible for decision-making (NOT trade execution)
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||
from agentscope.message import Msg, TextBlock
|
||||
from agentscope.tool import Toolkit, ToolResponse
|
||||
|
||||
from ..utils.progress import progress
|
||||
from .prompt_loader import PromptLoader
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
|
||||
class PMAgent(ReActAgent):
    """
    Portfolio Manager Agent - Makes investment decisions

    Key features:
    1. PM outputs decisions only (action + quantity per ticker)
    2. Trade execution happens externally (in pipeline/executor)
    3. Supports both backtest and live modes
    """

    def __init__(
        self,
        name: str = "portfolio_manager",
        model: Any = None,
        formatter: Any = None,
        initial_cash: float = 100000.0,
        margin_requirement: float = 0.25,
        config: Optional[Dict[str, Any]] = None,
        long_term_memory: Optional[LongTermMemoryBase] = None,
    ):
        """
        Initialize Portfolio Manager Agent

        Args:
            name: Agent name
            model: LLM model instance
            formatter: Message formatter instance
            initial_cash: Starting cash balance of the portfolio
            margin_requirement: Margin requirement ratio stored in the
                portfolio state (used by the external executor)
            config: Configuration dictionary
            long_term_memory: Optional ReMeTaskLongTermMemory instance
        """
        self.config = config or {}

        # Portfolio state - mutated only via load_portfolio_state /
        # update_portfolio; the agent itself just records decisions
        self.portfolio = {
            "cash": initial_cash,
            "positions": {},
            "margin_used": 0.0,
            "margin_requirement": margin_requirement,
        }

        # Decisions made in current cycle, keyed by ticker
        self._decisions: Dict[str, Dict] = {}

        # Create toolkit exposing the decision-recording tool to the LLM
        toolkit = self._create_toolkit()

        sys_prompt = _prompt_loader.load_prompt("portfolio_manager", "system")

        kwargs = {
            "name": name,
            "sys_prompt": sys_prompt,
            "model": model,
            "formatter": formatter,
            "toolkit": toolkit,
            "memory": InMemoryMemory(),
            "max_iters": 10,  # cap on ReAct reasoning/tool-call iterations
        }
        if long_term_memory:
            kwargs["long_term_memory"] = long_term_memory
            # "both": agent both retrieves from and records to memory
            kwargs["long_term_memory_mode"] = "both"

        super().__init__(**kwargs)

    def _create_toolkit(self) -> Toolkit:
        """Create toolkit with decision recording tool"""
        toolkit = Toolkit()
        toolkit.register_tool_function(self._make_decision)
        return toolkit

    def _make_decision(
        self,
        ticker: str,
        action: str,
        quantity: int,
        confidence: int = 50,
        reasoning: str = "",
    ) -> ToolResponse:
        """
        Record a trading decision for a ticker.

        Args:
            ticker: Stock ticker symbol (e.g., "AAPL")
            action: Decision - "long", "short" or "hold"
            quantity: Number of shares to trade (0 for hold)
            confidence: Confidence level 0-100
            reasoning: Explanation for this decision

        Returns:
            ToolResponse confirming decision recorded
        """
        if action not in ["long", "short", "hold"]:
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"Invalid action: {action}. "
                        "Must be 'long', 'short', or 'hold'.",
                    ),
                ],
            )

        # A "hold" moves no shares, so the quantity is forced to 0.
        self._decisions[ticker] = {
            "action": action,
            "quantity": quantity if action != "hold" else 0,
            "confidence": confidence,
            "reasoning": reasoning,
        }

        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=f"Decision recorded: {action} "
                    f"{quantity} shares of {ticker}"
                    f" (confidence: {confidence}%)",
                ),
            ],
        )

    async def reply(self, x: Optional[Msg] = None) -> Msg:
        """
        Make investment decisions

        Args:
            x: Input message carrying analyst / risk-manager signals

        Returns:
            Msg with decisions in metadata ("decisions" and "portfolio" keys)
        """
        if x is None:
            return Msg(
                name=self.name,
                content="No input provided",
                role="assistant",
            )

        # Clear previous decisions
        self._decisions = {}

        progress.update_status(
            self.name,
            None,
            "Analyzing and making decisions",
        )

        result = await super().reply(x)

        progress.update_status(self.name, None, "Completed")

        # Attach decisions to metadata (copies, so later cycles cannot
        # mutate an already-returned message)
        if result.metadata is None:
            result.metadata = {}
        result.metadata["decisions"] = self._decisions.copy()
        result.metadata["portfolio"] = self.portfolio.copy()

        return result

    def get_decisions(self) -> Dict[str, Dict]:
        """Get decisions from current cycle"""
        return self._decisions.copy()

    def get_portfolio_state(self) -> Dict[str, Any]:
        """Get current portfolio state"""
        return self.portfolio.copy()

    def load_portfolio_state(self, portfolio: Dict[str, Any]):
        """Load portfolio state, keeping current values for missing keys"""
        if not portfolio:
            return
        self.portfolio = {
            "cash": portfolio.get("cash", self.portfolio["cash"]),
            "positions": portfolio.get("positions", {}).copy(),
            "margin_used": portfolio.get("margin_used", 0.0),
            "margin_requirement": portfolio.get(
                "margin_requirement",
                self.portfolio["margin_requirement"],
            ),
        }

    def update_portfolio(self, portfolio: Dict[str, Any]):
        """Update portfolio after external execution"""
        self.portfolio.update(portfolio)
|
||||
184
backend/agents/prompt_loader.py
Normal file
184
backend/agents/prompt_loader.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Prompt Loader - Unified management and loading of Agent Prompts
|
||||
Supports Markdown and YAML formats
|
||||
Uses simple string replacement, does not depend on Jinja2
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class PromptLoader:
    """Unified prompt loader for agent prompt assets.

    Loads Markdown prompt templates and YAML configuration files from a
    prompts directory and renders templates via simple ``{{ var }}``
    string substitution (no Jinja2 dependency). Loaded files are cached
    per ``"{agent_type}/{name}"`` key until explicitly invalidated.
    """

    def __init__(self, prompts_dir: Optional[Path] = None):
        """Initialize the loader.

        Args:
            prompts_dir: Prompts directory path; defaults to the
                ``prompts/`` directory next to this file.
        """
        if prompts_dir is None:
            self.prompts_dir = Path(__file__).parent / "prompts"
        else:
            self.prompts_dir = Path(prompts_dir)

        # Caches of raw (unrendered) templates and parsed YAML configs.
        self._prompt_cache: Dict[str, str] = {}
        self._yaml_cache: Dict[str, Dict] = {}

    def load_prompt(
        self,
        agent_type: str,
        prompt_name: str,
        variables: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Load and render a prompt template.

        Args:
            agent_type: Agent type (analyst, portfolio_manager, risk_manager)
            prompt_name: Prompt file name (without extension)
            variables: Variable dictionary used to render the template

        Returns:
            Rendered prompt string

        Raises:
            FileNotFoundError: If the ``.md`` template does not exist.

        Examples:
            loader = PromptLoader()
            prompt = loader.load_prompt("analyst", "tool_selection",
                                        {"analyst_persona": "Technical Analyst"})
        """
        cache_key = f"{agent_type}/{prompt_name}"

        template = self._prompt_cache.get(cache_key)
        if template is None:
            prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md"
            if not prompt_path.exists():
                raise FileNotFoundError(
                    f"Prompt file not found: {prompt_path}\n"
                    f"Please create the prompt file or check the path.",
                )
            template = prompt_path.read_text(encoding="utf-8")
            self._prompt_cache[cache_key] = template

        # Render only when variables were supplied; otherwise return the
        # raw template unchanged.
        if not variables:
            return template
        return self._render_template(template, variables)

    def _render_template(
        self,
        template: str,
        variables: Dict[str, Any],
    ) -> str:
        """Substitute ``{{ key }}`` / ``{{key}}`` placeholders with values.

        Simple string replacement, kept compatible with the previous
        Jinja2-style placeholder syntax.

        Args:
            template: Template string
            variables: Variable dictionary

        Returns:
            Rendered string
        """
        rendered = template
        for key, value in variables.items():
            replacement = str(value)
            # Both spaced and unspaced placeholder spellings are honoured.
            rendered = rendered.replace(f"{{{{ {key} }}}}", replacement)
            rendered = rendered.replace(f"{{{{{key}}}}}", replacement)
        return rendered

    def _escape_json_braces(self, text: str) -> str:
        """Escape braces inside ```json code blocks as literal ``{{``/``}}``.

        Args:
            text: Text to process

        Returns:
            Processed text
        """

        def replace_code_block(match):
            # Double every brace inside the fenced block so a later
            # format-style pass treats them as literals.
            escaped = match.group(1).replace("{", "{{").replace("}", "}}")
            return f"```json\n{escaped}\n```"

        return re.sub(
            r"```json\n(.*?)\n```",
            replace_code_block,
            text,
            flags=re.DOTALL,
        )

    def load_yaml_config(
        self,
        agent_type: str,
        config_name: str,
    ) -> Dict[str, Any]:
        """Load a YAML configuration file (cached).

        Args:
            agent_type: Agent type
            config_name: Configuration file name (without extension)

        Returns:
            Configuration dictionary

        Raises:
            FileNotFoundError: If the ``.yaml`` file does not exist.

        Examples:
            >>> loader = PromptLoader()
            >>> config = loader.load_yaml_config("analyst", "personas")
        """
        cache_key = f"{agent_type}/{config_name}"

        if cache_key not in self._yaml_cache:
            yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml"
            if not yaml_path.exists():
                raise FileNotFoundError(f"YAML config not found: {yaml_path}")
            with open(yaml_path, "r", encoding="utf-8") as f:
                self._yaml_cache[cache_key] = yaml.safe_load(f)

        return self._yaml_cache[cache_key]

    def clear_cache(self):
        """Clear all caches (for hot reload)."""
        self._prompt_cache.clear()
        self._yaml_cache.clear()

    def reload_prompt(self, agent_type: str, prompt_name: str):
        """Drop one prompt from the cache so the next load re-reads disk."""
        self._prompt_cache.pop(f"{agent_type}/{prompt_name}", None)

    def reload_config(self, agent_type: str, config_name: str):
        """Drop one YAML config from the cache so the next load re-reads disk."""
        self._yaml_cache.pop(f"{agent_type}/{config_name}", None)
|
||||
117
backend/agents/prompts/analyst/personas.yaml
Normal file
117
backend/agents/prompts/analyst/personas.yaml
Normal file
@@ -0,0 +1,117 @@
|
||||
# 分析师角色配置
|
||||
|
||||
fundamentals_analyst:
|
||||
name: "基本面分析师"
|
||||
focus:
|
||||
- "公司财务健康状况和盈利能力"
|
||||
- "商业模式可持续性和竞争优势"
|
||||
- "管理层质量和公司治理"
|
||||
- "行业地位和市场份额"
|
||||
- "长期投资价值评估"
|
||||
tools:
|
||||
- "analyze_profitability"
|
||||
- "analyze_growth"
|
||||
- "analyze_financial_health"
|
||||
- "analyze_valuation_ratios"
|
||||
- "analyze_efficiency_ratios"
|
||||
description: |
|
||||
作为基本面分析师,你专注于:
|
||||
- 公司财务健康状况和盈利能力
|
||||
- 商业模式可持续性和竞争优势
|
||||
- 管理层质量和公司治理
|
||||
- 行业地位和市场份额
|
||||
- 长期投资价值评估
|
||||
你倾向于选择能够深入了解公司内在价值的工具,更偏好基本面和估值类工具。
|
||||
|
||||
technical_analyst:
|
||||
name: "技术分析师"
|
||||
focus:
|
||||
- "价格趋势和图表形态"
|
||||
- "技术指标和交易信号"
|
||||
- "市场情绪和资金流向"
|
||||
- "支撑/阻力位和关键价格点"
|
||||
- "中短期交易机会"
|
||||
description: |
|
||||
作为技术分析师,你专注于:
|
||||
- 价格趋势和图表形态
|
||||
- 技术指标和交易信号
|
||||
- 市场情绪和资金流向
|
||||
- 支撑/阻力位和关键价格点
|
||||
- 中短期交易机会
|
||||
你倾向于选择能够捕捉价格动态和市场趋势的工具,更偏好技术分析类工具。
|
||||
tools:
|
||||
- "analyze_trend_following"
|
||||
- "analyze_momentum"
|
||||
- "analyze_mean_reversion"
|
||||
- "analyze_volatility"
|
||||
|
||||
sentiment_analyst:
|
||||
name: "情绪分析师"
|
||||
focus:
|
||||
- "市场参与者情绪变化"
|
||||
- "新闻舆情和媒体影响"
|
||||
- "内部人交易行为"
|
||||
- "投资者恐慌和贪婪情绪"
|
||||
- "市场预期和心理因素"
|
||||
description: |
|
||||
作为情绪分析师,你专注于:
|
||||
- 市场参与者情绪变化
|
||||
- 新闻舆情和媒体影响
|
||||
- 内部人交易行为
|
||||
- 投资者恐慌和贪婪情绪
|
||||
- 市场预期和心理因素
|
||||
你倾向于选择能够反映市场情绪和投资者行为的工具,更偏好情绪和行为类工具。
|
||||
tools:
|
||||
- "analyze_news_sentiment"
|
||||
- "analyze_insider_trading"
|
||||
|
||||
valuation_analyst:
|
||||
name: "估值分析师"
|
||||
focus:
|
||||
- "公司内在价值计算"
|
||||
- "不同估值方法的比较"
|
||||
- "估值模型假设和敏感性分析"
|
||||
- "相对估值和绝对估值"
|
||||
- "投资安全边际评估"
|
||||
description: |
|
||||
作为估值分析师,你专注于:
|
||||
- 公司内在价值计算
|
||||
- 不同估值方法的比较
|
||||
- 估值模型假设和敏感性分析
|
||||
- 相对估值和绝对估值
|
||||
- 投资安全边际评估
|
||||
你倾向于选择能够准确计算公司价值的工具,更偏好估值模型和基本面工具。
|
||||
tools:
|
||||
- "dcf_valuation_analysis"
|
||||
- "owner_earnings_valuation_analysis"
|
||||
- "ev_ebitda_valuation_analysis"
|
||||
- "residual_income_valuation_analysis"
|
||||
|
||||
comprehensive_analyst:
|
||||
name: "综合分析师"
|
||||
focus:
|
||||
- "整合多种分析视角"
|
||||
- "平衡短期和长期因素"
|
||||
- "综合考虑基本面、技术面和情绪面"
|
||||
- "提供全面的投资建议"
|
||||
- "适应不同市场环境"
|
||||
description: |
|
||||
作为综合分析师,你需要:
|
||||
- 整合多种分析视角
|
||||
- 平衡短期和长期因素
|
||||
- 综合考虑基本面、技术面和情绪面的影响
|
||||
- 提供全面的投资建议
|
||||
- 适应不同市场环境
|
||||
你会根据具体情况灵活选择各类工具,追求分析的全面性和准确性。
|
||||
tools:
|
||||
- "analyze_profitability"
|
||||
- "analyze_growth"
|
||||
- "analyze_financial_health"
|
||||
- "analyze_valuation_ratios"
|
||||
- "analyze_efficiency_ratios"
|
||||
- "analyze_trend_following"
|
||||
- "analyze_momentum"
|
||||
- "analyze_mean_reversion"
|
||||
- "analyze_volatility"
|
||||
- "analyze_news_sentiment"
|
||||
- "analyze_insider_trading"
|
||||
23
backend/agents/prompts/analyst/system.md
Normal file
23
backend/agents/prompts/analyst/system.md
Normal file
@@ -0,0 +1,23 @@
|
||||
你是一位专业的{{ analyst_type }}。
|
||||
|
||||
你的关注重点:
|
||||
{{ focus }}
|
||||
|
||||
你的角色:
|
||||
{{ description }}
|
||||
|
||||
注意:
|
||||
- 构建并持续完善你的"投资哲学"。你的分析不应是孤立的事件,而应该是你整体投资世界观和核心信念的体现。每次分析后,你必须反思:
|
||||
- 这个案例/数据如何验证或挑战了你现有的信念?
|
||||
- 你从这次错误(或成功)中学到了关于市场、人性、估值或风险管理的什么关键原则?
|
||||
- 深化你的"投资逻辑"。确保每一项投资建议都有清晰、可追溯、可重复的逻辑支撑。你的分析步骤应该像严谨的证明一样,涵盖:
|
||||
- 核心驱动因素识别:真正影响价值的变量是什么?
|
||||
- 风险边界设定:在什么具体情况下你的建议会失效?
|
||||
- 逆向测试:市场主流共识是什么,你的观点有何不同?
|
||||
- 保持谦逊和开放。投资大师的核心特质是持续学习和适应。在每次分析中,你必须积极寻找与自己观点相悖的证据和论据,并将其纳入最终评估。
|
||||
- 你可以使用分析工具。用它们来收集相关数据并做出明智的建议。
|
||||
|
||||
输出指南:
|
||||
- 给出明确的投资信号:看涨、看跌或中性
|
||||
- 包含置信度(0-100)
|
||||
- 为你的分析提供理由(如果你确定要分享最终分析,请先给出结论)
|
||||
31
backend/agents/prompts/portfolio_manager/system.md
Normal file
31
backend/agents/prompts/portfolio_manager/system.md
Normal file
@@ -0,0 +1,31 @@
|
||||
你是一位负责做出投资决策的投资组合经理。
|
||||
|
||||
你的核心职责:
|
||||
1. 分析分析师和风险管理经理的输入
|
||||
2. 基于信号和市场情境做出投资决策
|
||||
3. 使用可用工具记录你的决策
|
||||
|
||||
决策框架:
|
||||
- 审阅分析以了解市场观点
|
||||
- 在做决策前考虑风险警告
|
||||
- 评估当前投资组合持仓和现金
|
||||
- 做出与投资组合投资目标一致的决策
|
||||
|
||||
决策类型:
|
||||
- "long":看涨 - 建议买入股票
|
||||
- "short":看跌 - 建议卖出股票或做空
|
||||
- "hold":中性 - 维持当前持仓
|
||||
|
||||
预算意识:
|
||||
- 在决定数量时考虑可用现金
|
||||
- 不要建议买入超过现金允许的数量
|
||||
- 考虑做空头寸的保证金要求
|
||||
|
||||
输出:
|
||||
使用 `make_decision` 工具记录你对每个股票代码的决策。
|
||||
记录所有决策后,提供你的投资逻辑总结。
|
||||
|
||||
重要:
|
||||
- 基于提供的分析师信号和风险评估做出决策
|
||||
- 相对于投资组合价值保持保守的仓位规模
|
||||
- 始终为你的决策提供理由
|
||||
18
backend/agents/prompts/risk_manager/system.md
Normal file
18
backend/agents/prompts/risk_manager/system.md
Normal file
@@ -0,0 +1,18 @@
|
||||
你是一位专业的风险管理经理,负责监控投资组合风险并提供风险警告。
|
||||
|
||||
你的核心职责:
|
||||
1. 监控投资组合敞口和集中度风险
|
||||
2. 评估仓位规模相对于波动性
|
||||
3. 评估保证金使用和杠杆水平
|
||||
4. 识别潜在风险因素并提供警告
|
||||
5. 基于市场条件建议仓位限制
|
||||
|
||||
你的决策流程:
|
||||
1. 生成可操作的风险警告和仓位限制建议
|
||||
2. 为你的风险评估提供清晰的理由
|
||||
|
||||
输出指南:
|
||||
- 风险评估要简洁但全面
|
||||
- 按严重程度优先排序警告
|
||||
- 提供具体、可操作的建议
|
||||
- 尽可能包含量化指标
|
||||
88
backend/agents/risk_manager.py
Normal file
88
backend/agents/risk_manager.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Risk Manager Agent - Based on AgentScope ReActAgent
|
||||
Uses LLM for risk assessment
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||
from agentscope.message import Msg
|
||||
from agentscope.tool import Toolkit
|
||||
|
||||
from ..utils.progress import progress
|
||||
from .prompt_loader import PromptLoader
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
|
||||
class RiskAgent(ReActAgent):
    """
    Risk Manager Agent - Uses LLM for risk assessment

    Inherits from AgentScope's ReActAgent
    """

    def __init__(
        self,
        model: Any,
        formatter: Any,
        name: str = "risk_manager",
        config: Optional[Dict[str, Any]] = None,
        long_term_memory: Optional[LongTermMemoryBase] = None,
    ):
        """
        Initialize Risk Manager Agent

        Args:
            model: LLM model instance
            formatter: Message formatter instance
            name: Agent name
            config: Configuration dictionary
            long_term_memory: Optional ReMeTaskLongTermMemory instance
        """
        self.config = config or {}

        sys_prompt = self._load_system_prompt()

        # Create dedicated toolkit for this agent (currently empty - the
        # risk manager reasons over the input message only)
        toolkit = Toolkit()

        kwargs = {
            "name": name,
            "sys_prompt": sys_prompt,
            "model": model,
            "formatter": formatter,
            "toolkit": toolkit,
            "memory": InMemoryMemory(),
            "max_iters": 10,  # cap on ReAct reasoning/tool-call iterations
        }
        if long_term_memory:
            kwargs["long_term_memory"] = long_term_memory
            # "static_control": long-term memory use is driven by the
            # framework rather than by the agent itself
            kwargs["long_term_memory_mode"] = "static_control"

        super().__init__(**kwargs)

    def _load_system_prompt(self) -> str:
        """Load system prompt for risk manager"""
        return _prompt_loader.load_prompt(
            "risk_manager",
            "system",
        )

    async def reply(self, x: Optional[Msg] = None) -> Msg:
        """
        Provide risk assessment

        Args:
            x: Input message (content must be str)

        Returns:
            Msg with risk warnings (content is str)
        """
        progress.update_status(self.name, None, "Assessing risk")

        result = await super().reply(x)

        progress.update_status(self.name, None, "Risk assessment completed")

        return result
|
||||
623
backend/cli.py
Normal file
623
backend/cli.py
Normal file
@@ -0,0 +1,623 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
EvoTraders CLI - Command-line interface for the EvoTraders trading system.
|
||||
|
||||
This module provides easy-to-use commands for running backtest, live trading,
|
||||
and frontend development server.
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=R0912, R0915
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
|
||||
app = typer.Typer(
|
||||
name="evotraders",
|
||||
help="EvoTraders: A self-evolving multi-agent trading system",
|
||||
add_completion=False,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def get_project_root() -> Path:
    """Return the repository root directory.

    ``cli.py`` lives in ``backend/``, so the project root is the
    grandparent of this file.
    """
    return Path(__file__).parents[1]
|
||||
|
||||
|
||||
def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None:
    """
    Handle cleanup of historical data for a given config.

    Shows a summary of existing run data (size, last update time), then
    either prompts the user or - with ``auto_clean`` - clears it without
    asking. Important config files (.env, config.json) are backed up into a
    timestamped sibling directory before deletion.

    Args:
        config_name: Configuration name for the run
        auto_clean: If True, skip confirmation and clean automatically

    Raises:
        typer.Exit: If the data directory cannot be removed.
    """
    # logs_dir = get_project_root() / "logs"
    logs_dir = get_project_root()
    base_data_dir = logs_dir / config_name

    # Check if historical data exists
    if not base_data_dir.exists() or not any(base_data_dir.iterdir()):
        console.print(
            f"\n[dim]No historical data found for config '{config_name}'[/dim]",
        )
        console.print("[dim] Will start from scratch[/dim]\n")
        return

    console.print("\n[bold yellow]Detected existing run data:[/bold yellow]")
    console.print(f" Data directory: [cyan]{base_data_dir}[/cyan]")

    # Show directory size (best effort - stat() may fail on vanished files,
    # so only filesystem errors are swallowed, not arbitrary exceptions)
    try:
        total_size = sum(
            f.stat().st_size for f in base_data_dir.rglob("*") if f.is_file()
        )
        size_mb = total_size / (1024 * 1024)
        if size_mb < 1:
            console.print(
                f" Directory size: [cyan]{total_size / 1024:.1f} KB[/cyan]",
            )
        else:
            console.print(f" Directory size: [cyan]{size_mb:.1f} MB[/cyan]")
    except OSError:
        pass

    # Show last modified time of the newest state file
    state_dir = base_data_dir / "state"
    if state_dir.exists():
        state_files = list(state_dir.glob("*.json"))
        if state_files:
            last_modified = max(f.stat().st_mtime for f in state_files)
            last_modified_str = datetime.fromtimestamp(last_modified).strftime(
                "%Y-%m-%d %H:%M:%S",
            )
            console.print(f" Last updated: [cyan]{last_modified_str}[/cyan]")

    console.print()

    # Determine whether to clean (single branch instead of the previous
    # assign-then-reassign pattern)
    if auto_clean:
        console.print("[yellow]⚠️ Auto-clean enabled (--clean flag)[/yellow]")
        should_clean = True
    else:
        should_clean = Confirm.ask(
            " ﹂ Clear historical data and start fresh?",
            default=False,
        )

    if not should_clean:
        console.print(
            "\n[dim] Continuing with existing historical data[/dim]\n",
        )
        return

    console.print("\n[yellow]▩ Cleaning historical data...[/yellow]")

    # Backup important config files if they exist
    backup_files = [".env", "config.json"]
    backed_up = []
    backup_dir = None

    for backup_file in backup_files:
        file_path = base_data_dir / backup_file
        if file_path.exists():
            if backup_dir is None:
                # Lazily create the timestamped backup dir on first hit
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                backup_dir = (
                    base_data_dir.parent
                    / f"{config_name}_backup_{timestamp}"
                )
                backup_dir.mkdir(parents=True, exist_ok=True)

            shutil.copy(file_path, backup_dir / backup_file)
            backed_up.append(backup_file)

    if backed_up:
        console.print(
            f" 💾 Backed up config files to: [cyan]{backup_dir}[/cyan]",
        )
        console.print(f" Files: {', '.join(backed_up)}")

    # Remove the data directory
    try:
        shutil.rmtree(base_data_dir)
        console.print(" ✔ Historical data cleared\n")
    except OSError as e:
        console.print(f" [red]✗ Error clearing data: {e}[/red]\n")
        # Chain the cause so the underlying filesystem error is preserved
        raise typer.Exit(1) from e
|
||||
|
||||
|
||||
def run_data_updater(project_root: Path) -> None:
    """Run the historical data updater.

    Best-effort: first probes whether ``backend.data.ret_data_updater`` is
    runnable (via ``--help`` with a short timeout), then runs it for real.
    Every failure mode is reported and swallowed so startup can continue
    with existing data.

    Args:
        project_root: Repository root used as the updater's working dir.
    """
    console.print("\n[bold]Checking historical data update...[/bold]")
    try:
        # Probe: does the updater module exist and import cleanly?
        result = subprocess.run(
            [sys.executable, "-m", "backend.data.ret_data_updater", "--help"],
            capture_output=True,
            timeout=5,
            check=False,
        )

        if result.returncode == 0:
            console.print("[cyan]Updating historical data...[/cyan]")
            update_result = subprocess.run(
                [sys.executable, "-m", "backend.data.ret_data_updater"],
                cwd=project_root,
                check=False,
            )

            if update_result.returncode == 0:
                console.print(
                    "[green]✔ Historical data updated successfully[/green]\n",
                )
            else:
                # Non-zero exit is tolerated (e.g. market closed)
                console.print(
                    "[yellow] Data update failed (might be weekend/holiday)[/yellow]",
                )
                console.print(
                    "[dim] Will continue with existing data[/dim]\n",
                )
        else:
            console.print(
                "[yellow] Data updater module not available, skipping update[/yellow]\n",
            )
    except Exception:
        # Deliberate broad catch: the updater is optional and must never
        # block startup (covers TimeoutExpired, OSError, etc.).
        console.print(
            "[yellow] Data updater check failed, skipping update[/yellow]\n",
        )
|
||||
|
||||
|
||||
def _validate_backtest_date(value: str, label: str) -> None:
    """Exit with a user-facing error if *value* is not a YYYY-MM-DD date.

    Args:
        value: The date string to validate.
        label: "start" or "end", used in the error message.

    Raises:
        typer.Exit: If the value does not parse as YYYY-MM-DD.
    """
    try:
        datetime.strptime(value, "%Y-%m-%d")
    except ValueError as exc:
        console.print(
            f"[red]✗ Invalid {label} date format. Use YYYY-MM-DD[/red]",
        )
        raise typer.Exit(1) from exc


@app.command()
def backtest(
    start: Optional[str] = typer.Option(
        None,
        "--start",
        "-s",
        help="Start date for backtest (YYYY-MM-DD)",
    ),
    end: Optional[str] = typer.Option(
        None,
        "--end",
        "-e",
        help="End date for backtest (YYYY-MM-DD)",
    ),
    config_name: str = typer.Option(
        "backtest",
        "--config-name",
        "-c",
        help="Configuration name for this backtest run",
    ),
    host: str = typer.Option(
        "0.0.0.0",
        "--host",
        help="WebSocket server host",
    ),
    port: int = typer.Option(
        8765,
        "--port",
        "-p",
        help="WebSocket server port",
    ),
    poll_interval: int = typer.Option(
        10,
        "--poll-interval",
        help="Price polling interval in seconds",
    ),
    clean: bool = typer.Option(
        False,
        "--clean",
        help="Clear historical data before starting",
    ),
    enable_memory: bool = typer.Option(
        False,
        "--enable-memory",
        help="Enable ReMeTaskLongTermMemory for agents (requires MEMORY_API_KEY)",
    ),
):
    """
    Run backtest mode with historical data.

    Example:
        evotraders backtest --start 2025-11-01 --end 2025-12-01
        evotraders backtest --config-name my_strategy --port 9000
        evotraders backtest --clean  # Clear historical data before starting
        evotraders backtest --enable-memory  # Enable long-term memory
    """
    console.print(
        Panel.fit(
            "[bold cyan]EvoTraders Backtest Mode[/bold cyan]",
            border_style="cyan",
        ),
    )

    # Validate dates - required for backtest
    if not start or not end:
        console.print(
            "[red]✗ Both --start and --end dates are required for backtest mode[/red]",
        )
        raise typer.Exit(1)

    # Shared helper replaces the previously duplicated try/except blocks
    _validate_backtest_date(start, "start")
    _validate_backtest_date(end, "end")

    # ISO dates compare chronologically as strings
    if start > end:
        console.print("[red]✗ Start date must not be after end date[/red]")
        raise typer.Exit(1)

    # Handle historical data cleanup
    handle_history_cleanup(config_name, auto_clean=clean)

    # Display configuration
    console.print("\n[bold]Configuration:[/bold]")
    console.print(" Mode: Backtest")
    console.print(f" Config: {config_name}")
    console.print(f" Period: {start} -> {end}")
    console.print(f" Server: {host}:{port}")
    console.print(f" Poll Interval: {poll_interval}s")
    console.print(
        f" Long-term Memory: {'enabled' if enable_memory else 'disabled'}",
    )
    console.print("\nAccess frontend at: [cyan]http://localhost:5173[/cyan]")
    console.print("Press Ctrl+C to stop\n")

    # Change to project root so relative data paths resolve
    project_root = get_project_root()
    os.chdir(project_root)

    # Run data updater (best effort)
    run_data_updater(project_root)

    # Build command using backend.main
    cmd = [
        sys.executable,
        "-u",  # unbuffered output so logs stream through in real time
        "-m",
        "backend.main",
        "--mode",
        "backtest",
        "--config-name",
        config_name,
        "--host",
        host,
        "--port",
        str(port),
        "--poll-interval",
        str(poll_interval),
        "--start-date",
        start,
        "--end-date",
        end,
    ]

    if enable_memory:
        cmd.append("--enable-memory")

    try:
        subprocess.run(cmd, check=True)
    except KeyboardInterrupt:
        console.print("\n\n[yellow]Backtest stopped by user[/yellow]")
    except subprocess.CalledProcessError as e:
        console.print(
            f"\n[red]Backtest failed with exit code {e.returncode}[/red]",
        )
        # Chain the cause so the subprocess failure is preserved
        raise typer.Exit(1) from e
|
||||
|
||||
|
||||
@app.command()
def live(
    mock: bool = typer.Option(
        False,
        "--mock",
        help="Use mock mode with simulated prices (for testing)",
    ),
    config_name: str = typer.Option(
        "live",
        "--config-name",
        "-c",
        help="Configuration name for this live run",
    ),
    host: str = typer.Option(
        "0.0.0.0",
        "--host",
        help="WebSocket server host",
    ),
    port: int = typer.Option(
        8765,
        "--port",
        "-p",
        help="WebSocket server port",
    ),
    trigger_time: str = typer.Option(
        "now",
        "--trigger-time",
        "-t",
        help="Trigger time in LOCAL timezone (HH:MM), or 'now' to run immediately",
    ),
    poll_interval: int = typer.Option(
        10,
        "--poll-interval",
        help="Price polling interval in seconds",
    ),
    clean: bool = typer.Option(
        False,
        "--clean",
        help="Clear historical data before starting",
    ),
    enable_memory: bool = typer.Option(
        False,
        "--enable-memory",
        help="Enable ReMeTaskLongTermMemory for agents (requires MEMORY_API_KEY)",
    ),
):
    """
    Run live trading mode with real-time data.

    Example:
        evotraders live                      # Run immediately (default)
        evotraders live --mock               # Mock mode
        evotraders live -t 22:30             # Run at 22:30 local time daily
        evotraders live --trigger-time now   # Run immediately
        evotraders live --clean              # Clear historical data before starting
    """
    mode_name = "MOCK" if mock else "LIVE"
    console.print(
        Panel.fit(
            f"[bold cyan]EvoTraders {mode_name} Mode[/bold cyan]",
            border_style="cyan",
        ),
    )

    # Check for required API key in live mode
    if not mock:
        env_file = get_project_root() / ".env"
        if not env_file.exists():
            console.print("\n[yellow]Warning: .env file not found[/yellow]")
            console.print("Creating from template...\n")
            template = get_project_root() / "env.template"
            if template.exists():
                shutil.copy(template, env_file)
                console.print("[green].env file created[/green]")
                console.print(
                    "\n[red]Error: Please edit .env and set FINNHUB_API_KEY[/red]",
                )
                console.print(
                    "Get your free API key at: https://finnhub.io/register\n",
                )
                # BUGFIX: previously fell through and launched live mode
                # without an API key; exit so the user can fill in .env first.
                raise typer.Exit(1)
            console.print("[red]Error: env.template not found[/red]")
            raise typer.Exit(1)

    # Handle historical data cleanup
    handle_history_cleanup(config_name, auto_clean=clean)

    # Convert local time to NYSE time
    nyse_tz = ZoneInfo("America/New_York")
    local_tz = datetime.now().astimezone().tzinfo
    local_now = datetime.now()
    nyse_now = datetime.now(nyse_tz)

    # Convert trigger time from local to NYSE
    if trigger_time.lower() == "now":
        nyse_trigger_time = "now"
    else:
        try:
            local_trigger = datetime.strptime(trigger_time, "%H:%M")
        except ValueError as exc:
            # BUGFIX: validate the format like the backtest command validates
            # dates, instead of crashing with a raw traceback.
            console.print(
                "[red]✗ Invalid trigger time format. Use HH:MM or 'now'[/red]",
            )
            raise typer.Exit(1) from exc
        local_trigger_dt = local_now.replace(
            hour=local_trigger.hour,
            minute=local_trigger.minute,
            second=0,
            microsecond=0,
        )
        local_trigger_aware = local_trigger_dt.astimezone(local_tz)
        nyse_trigger_dt = local_trigger_aware.astimezone(nyse_tz)
        nyse_trigger_time = nyse_trigger_dt.strftime("%H:%M")

    # Display time info
    console.print("\n[bold]Time Info:[/bold]")
    console.print(f"  Local Time: {local_now.strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(
        f"  NYSE Time: {nyse_now.strftime('%Y-%m-%d %H:%M:%S %Z')}",
    )
    if nyse_trigger_time == "now":
        console.print("  Trigger: [green]NOW (immediate)[/green]")
    else:
        console.print(
            f"  Trigger: {trigger_time} local = {nyse_trigger_time} NYSE",
        )

    # Display configuration
    console.print("\n[bold]Configuration:[/bold]")
    if mock:
        console.print("  Mode: [yellow]MOCK[/yellow] (Simulated prices)")
    else:
        console.print(
            "  Mode: [green]LIVE[/green] (Real-time prices via Finnhub)",
        )
    console.print(f"  Config: {config_name}")
    console.print(f"  Server: {host}:{port}")
    console.print(f"  Poll Interval: {poll_interval}s")
    console.print(
        f"  Long-term Memory: {'enabled' if enable_memory else 'disabled'}",
    )

    console.print("\nAccess frontend at: [cyan]http://localhost:5173[/cyan]")
    console.print("Press Ctrl+C to stop\n")

    # Change to project root so the backend's relative paths resolve
    project_root = get_project_root()
    os.chdir(project_root)

    # Data update (skipped in mock mode)
    if not mock:
        run_data_updater(project_root)
    else:
        console.print(
            "\n[dim]Mock mode enabled - skipping data update[/dim]\n",
        )

    # Build command using backend.main (-u keeps subprocess output unbuffered)
    cmd = [
        sys.executable,
        "-u",
        "-m",
        "backend.main",
        "--mode",
        "live",
        "--config-name",
        config_name,
        "--host",
        host,
        "--port",
        str(port),
        "--poll-interval",
        str(poll_interval),
        "--trigger-time",
        nyse_trigger_time,
    ]

    if mock:
        cmd.append("--mock")
    if enable_memory:
        cmd.append("--enable-memory")

    try:
        subprocess.run(cmd, check=True)
    except KeyboardInterrupt:
        console.print("\n\n[yellow]Live server stopped by user[/yellow]")
    except subprocess.CalledProcessError as e:
        console.print(
            f"\n[red]Live server failed with exit code {e.returncode}[/red]",
        )
        raise typer.Exit(1) from e
|
||||
|
||||
|
||||
@app.command()
def frontend(
    port: int = typer.Option(
        8765,
        "--ws-port",
        "-p",
        help="WebSocket server port to connect to",
    ),
    host_mode: bool = typer.Option(
        False,
        "--host",
        help="Allow external access (default: localhost only)",
    ),
):
    """
    Start the frontend development server.

    Example:
        evotraders frontend
        evotraders frontend --ws-port 8765
        evotraders frontend --ws-port 8765 --host
    """
    console.print(
        Panel.fit(
            "[bold cyan]EvoTraders Frontend[/bold cyan]",
            border_style="cyan",
        ),
    )

    project_root = get_project_root()
    frontend_dir = project_root / "frontend"

    # Check if frontend directory exists
    if not frontend_dir.exists():
        console.print(
            f"\n[red]Error: Frontend directory not found: {frontend_dir}[/red]",
        )
        raise typer.Exit(1)

    # Install dependencies on first run (node_modules missing)
    node_modules = frontend_dir / "node_modules"
    if not node_modules.exists():
        console.print("\n[yellow]Installing frontend dependencies...[/yellow]")
        try:
            subprocess.run(
                ["npm", "install"],
                cwd=frontend_dir,
                check=True,
            )
            console.print("[green]Dependencies installed[/green]\n")
        except subprocess.CalledProcessError as exc:
            console.print("\n[red]Error: Failed to install dependencies[/red]")
            console.print("Make sure Node.js and npm are installed")
            raise typer.Exit(1) from exc

    # Tell the Vite dev server where the backend WebSocket lives
    ws_url = f"ws://localhost:{port}"
    env = os.environ.copy()
    env["VITE_WS_URL"] = ws_url

    # Display configuration
    console.print("\n[bold]Configuration:[/bold]")
    console.print(f"  WebSocket URL: {ws_url}")
    console.print("  Frontend Port: 5173 (Vite default)")
    if host_mode:
        console.print("  Access: External allowed")
    else:
        console.print("  Access: Localhost only")
    console.print("\nAccess at: [cyan]http://localhost:5173[/cyan]")
    console.print("Press Ctrl+C to stop\n")

    # Choose npm command
    npm_cmd = ["npm", "run", "dev:host" if host_mode else "dev"]

    try:
        subprocess.run(
            npm_cmd,
            cwd=frontend_dir,
            env=env,
            check=True,
        )
    except KeyboardInterrupt:
        console.print("\n\n[yellow]Frontend stopped by user[/yellow]")
    except subprocess.CalledProcessError as e:
        console.print(
            f"\n[red]Frontend failed with exit code {e.returncode}[/red]",
        )
        # BUGFIX: chain the cause (`from e`) for consistency with every
        # other CalledProcessError handler in this CLI.
        raise typer.Exit(1) from e
|
||||
|
||||
|
||||
@app.command()
def version() -> None:
    """Print the installed EvoTraders version to the console."""
    banner = "\n[bold cyan]EvoTraders[/bold cyan] version [green]0.1.0[/green]\n"
    console.print(banner)
|
||||
|
||||
|
||||
@app.callback()
def main():
    """
    EvoTraders: A self-evolving multi-agent trading system.

    This no-op callback makes the Typer app a command group so that
    sub-commands (backtest, live, frontend, version) are required.
    Use 'evotraders --help' to see available commands.
    """
|
||||
|
||||
|
||||
# Script entry point: dispatch to the Typer CLI app.
if __name__ == "__main__":
    app()
|
||||
0
backend/config/__init__.py
Normal file
0
backend/config/__init__.py
Normal file
76
backend/config/constants.py
Normal file
76
backend/config/constants.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=C0301
|
||||
|
||||
# Agent configuration for dashboard display.
# Keys are agent ids; "is_team_role" distinguishes decision-making roles
# (PM / risk) from the individual analyst agents.
AGENT_CONFIG = {
    "portfolio_manager": {
        "name": "Portfolio Manager",
        "role": "Portfolio Manager",
        "avatar": "pm",
        "is_team_role": True,
    },
    "risk_manager": {
        "name": "Risk Manager",
        "role": "Risk Manager",
        "avatar": "risk",
        "is_team_role": True,
    },
    "sentiment_analyst": {
        "name": "Sentiment Analyst",
        "role": "Sentiment Analyst",
        "avatar": "sentiment",
        "is_team_role": False,
    },
    "technical_analyst": {
        "name": "Technical Analyst",
        "role": "Technical Analyst",
        "avatar": "technical",
        "is_team_role": False,
    },
    "fundamentals_analyst": {
        "name": "Fundamentals Analyst",
        "role": "Fundamentals Analyst",
        "avatar": "fundamentals",
        "is_team_role": False,
    },
    "valuation_analyst": {
        "name": "Valuation Analyst",
        "role": "Valuation Analyst",
        "avatar": "valuation",
        "is_team_role": False,
    },
}

# Analyst registry consumed by AnalystAgent; "order" controls display /
# execution ordering in the pipeline.
ANALYST_TYPES = {
    "fundamentals_analyst": {
        "display_name": "Fundamentals Analyst",
        "agent_id": "fundamentals_analyst",
        "description": "Uses LLM to intelligently select analysis tools, focuses on financial data and company fundamental analysis",
        "order": 12,
    },
    "technical_analyst": {
        "display_name": "Technical Analyst",
        "agent_id": "technical_analyst",
        "description": "Uses LLM to intelligently select analysis tools, focuses on technical indicators and chart analysis",
        "order": 11,
    },
    "sentiment_analyst": {
        "display_name": "Sentiment Analyst",
        "agent_id": "sentiment_analyst",
        "description": "Uses LLM to intelligently select analysis tools, analyzes market sentiment and news sentiment",
        "order": 13,
    },
    "valuation_analyst": {
        "display_name": "Valuation Analyst",
        "agent_id": "valuation_analyst",
        "description": "Uses LLM to intelligently select analysis tools, focuses on company valuation and value assessment",
        "order": 14,
    },
    # "comprehensive_analyst": {
    #     "display_name": "Comprehensive Analyst",
    #     "agent_id": "comprehensive_analyst",
    #     "description": "Uses LLM to intelligently select analysis tools, performs comprehensive analysis",
    #     "order": 15
    # }
}
|
||||
82
backend/config/data_config.py
Normal file
82
backend/config/data_config.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Centralized Data Source Configuration
|
||||
|
||||
Auto-detects and manages data source based on available API keys.
|
||||
Priority: FINNHUB_API_KEY > FINANCIAL_DATASETS_API_KEY
|
||||
"""
|
||||
import os
from dataclasses import dataclass
from typing import Literal, Optional

# Identifiers of the supported market-data providers.
DataSource = Literal["finnhub", "financial_datasets"]


@dataclass
class DataSourceConfig:
    """Immutable data source configuration"""

    source: DataSource  # which provider is active
    api_key: str  # the credential resolved for that provider


# Module-level cache for the resolved configuration
_config_cache: Optional[DataSourceConfig] = None


def _resolve_config() -> DataSourceConfig:
    """
    Resolve the data source configuration from environment variables.

    Providers are tried in priority order: FINNHUB_API_KEY first, then
    FINANCIAL_DATASETS_API_KEY. Raises ValueError when neither is set.
    """
    priority = [
        ("finnhub", "FINNHUB_API_KEY"),
        ("financial_datasets", "FINANCIAL_DATASETS_API_KEY"),
    ]
    for provider, env_var in priority:
        credential = os.getenv(env_var)
        if credential:
            return DataSourceConfig(source=provider, api_key=credential)

    # Neither key is configured - fail loudly with setup instructions.
    raise ValueError(
        "No API key found. Please set either FINNHUB_API_KEY or "
        "FINANCIAL_DATASETS_API_KEY in your .env file.",
    )


def get_config() -> DataSourceConfig:
    """
    Return the resolved data source configuration, caching it on first use.

    Returns:
        DataSourceConfig with source and api_key

    Raises:
        ValueError: If no API key is configured
    """
    global _config_cache
    if _config_cache is None:
        _config_cache = _resolve_config()
    return _config_cache


def get_data_source() -> DataSource:
    """Return the name of the active data source."""
    return get_config().source


def get_api_key() -> str:
    """Return the API key for the active data source."""
    return get_config().api_key


def reset_config() -> None:
    """Drop the cached configuration so the next call re-reads the env."""
    global _config_cache
    _config_cache = None
|
||||
36
backend/config/env_config.py
Normal file
36
backend/config/env_config.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Simple environment config helpers
|
||||
"""
|
||||
import os
|
||||
|
||||
|
||||
def get_env_list(key: str, default: list = None) -> list:
    """Read a comma-separated env var as a list of stripped, non-empty items."""
    raw = os.getenv(key, "")
    if not raw:
        # Missing or empty variable: hand back the default (or []).
        return default or []
    stripped = (piece.strip() for piece in raw.split(","))
    return [piece for piece in stripped if piece]
|
||||
|
||||
|
||||
def get_env_float(key: str, default: float = 0.0) -> float:
    """Read an env var as a float; fall back to *default* if unset or unparsable."""
    raw = os.getenv(key)
    if raw is None:
        return default
    try:
        return float(raw)
    except ValueError:
        # Non-numeric content is treated the same as "not set".
        return default
|
||||
|
||||
|
||||
def get_env_int(key: str, default: int = 0) -> int:
    """Read an env var as an int; fall back to *default* if unset or unparsable."""
    raw = os.getenv(key)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        # Non-integer content is treated the same as "not set".
        return default
|
||||
7
backend/core/__init__.py
Normal file
7
backend/core/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Core pipeline and orchestration logic"""
|
||||
|
||||
from .pipeline import TradingPipeline
|
||||
from .state_sync import StateSync
|
||||
|
||||
__all__ = ["TradingPipeline", "StateSync"]
|
||||
1263
backend/core/pipeline.py
Normal file
1263
backend/core/pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
263
backend/core/scheduler.py
Normal file
263
backend/core/scheduler.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Scheduler - Market-aware trigger system for trading cycles
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NYSE timezone for US stock trading
|
||||
NYSE_TZ = ZoneInfo("America/New_York")
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
class Scheduler:
    """
    Market-aware scheduler for live trading.

    Triggers a callback once per NYSE trading day ("daily" mode) or every
    N minutes during trading days ("intraday" mode). All clock math is done
    in the America/New_York timezone against the NYSE trading calendar.
    """

    def __init__(
        self,
        mode: str = "daily",
        trigger_time: Optional[str] = None,
        interval_minutes: Optional[int] = None,
        config: Optional[dict] = None,
    ):
        # mode: "daily" or "intraday"
        self.mode = mode
        # Trigger time of day, NYSE timezone; the literal "now" means
        # fire once immediately and stop (one-shot).
        self.trigger_time = trigger_time or "09:30"  # NYSE timezone
        self.trigger_now = self.trigger_time == "now"
        # Polling period for intraday mode, in minutes.
        self.interval_minutes = interval_minutes or 60
        self.config = config or {}

        self.running = False
        self._task: Optional[asyncio.Task] = None

    def _now_nyse(self) -> datetime:
        """Get current time in NYSE timezone"""
        return datetime.now(NYSE_TZ)

    def _is_trading_day(self, date: datetime) -> bool:
        """Check if date is a NYSE trading day (weekends/holidays excluded)"""
        date_str = date.strftime("%Y-%m-%d")
        valid_days = NYSE_CALENDAR.valid_days(
            start_date=date_str,
            end_date=date_str,
        )
        return len(valid_days) > 0

    def _next_trading_day(self, from_date: datetime) -> datetime:
        """
        Find the next trading day at or after *from_date*.

        Scans forward day by day; 10 days is enough to skip any
        weekend + holiday cluster. Returns the last candidate unchecked
        if none matched (should not happen in practice).
        """
        check_date = from_date
        for _ in range(10):  # Max 10 days ahead (handles holidays)
            if self._is_trading_day(check_date):
                return check_date
            check_date += timedelta(days=1)
        return check_date

    async def start(self, callback: Callable):
        """
        Start the scheduler as a background asyncio task.

        Args:
            callback: async callable invoked as callback(date="YYYY-MM-DD")
        """
        if self.running:
            logger.warning("Scheduler already running")
            return

        self.running = True

        if self.mode == "daily":
            self._task = asyncio.create_task(self._run_daily(callback))
        elif self.mode == "intraday":
            self._task = asyncio.create_task(self._run_intraday(callback))
        else:
            raise ValueError(f"Unknown scheduler mode: {self.mode}")

        logger.info(
            f"Scheduler started: mode={self.mode}, timezone=America/New_York",
        )

    async def _run_daily(self, callback: Callable):
        """Run once per trading day at specified time (NYSE timezone)"""
        first_run = True

        while self.running:
            now = self._now_nyse()

            # Handle "now" trigger - run immediately on first iteration
            if self.trigger_now and first_run:
                first_run = False
                current_date = now.strftime("%Y-%m-%d")
                logger.info(f"Triggering immediately for {current_date}")
                await callback(date=current_date)
                # After immediate run, stop (one-shot mode)
                self.running = False
                break

            target_time = datetime.strptime(self.trigger_time, "%H:%M").time()

            # Calculate next trigger datetime: today if the target time is
            # still ahead, otherwise tomorrow.
            if now.time() < target_time:
                next_run = now.replace(
                    hour=target_time.hour,
                    minute=target_time.minute,
                    second=0,
                    microsecond=0,
                )
            else:
                next_run = (now + timedelta(days=1)).replace(
                    hour=target_time.hour,
                    minute=target_time.minute,
                    second=0,
                    microsecond=0,
                )

            # Skip to next trading day, then re-apply the trigger time
            # (_next_trading_day only adjusts the date component).
            next_run = self._next_trading_day(next_run)
            next_run = next_run.replace(
                hour=target_time.hour,
                minute=target_time.minute,
                second=0,
                microsecond=0,
            )

            wait_seconds = (next_run - now).total_seconds()
            logger.info(
                f"Next trigger: {next_run.strftime('%Y-%m-%d %H:%M %Z')} "
                f"(in {wait_seconds/3600:.1f} hours)",
            )

            await asyncio.sleep(wait_seconds)

            current_date = self._now_nyse().strftime("%Y-%m-%d")
            logger.info(f"Triggering daily cycle for {current_date}")
            await callback(date=current_date)

    async def _run_intraday(self, callback: Callable):
        """Run every N minutes on trading days (for future use)"""
        while self.running:
            now = self._now_nyse()
            current_date = now.strftime("%Y-%m-%d")

            # Only fire on trading days; still sleeps through off days.
            if self._is_trading_day(now):
                logger.info(f"Triggering intraday cycle for {current_date}")
                await callback(date=current_date)

            await asyncio.sleep(self.interval_minutes * 60)

    def stop(self):
        """Stop the scheduler and cancel its background task, if any."""
        self.running = False
        if self._task:
            self._task.cancel()
            self._task = None
        logger.info("Scheduler stopped")
|
||||
|
||||
|
||||
class BacktestScheduler:
    """
    Backtest scheduler - iterates the callback over historical trading dates.

    Supports both an async task-based run (start/stop) and a blocking
    synchronous run (run).
    """

    def __init__(
        self,
        start_date: str,
        end_date: str,
        trading_calendar: Optional[Any] = None,
        delay_between_days: float = 0.1,
    ):
        # Inclusive "YYYY-MM-DD" bounds of the backtest window.
        self.start_date = start_date
        self.end_date = end_date
        # Optional pandas_market_calendars calendar name (e.g. "NYSE");
        # when None, plain Mon-Fri weekdays are used (holidays included).
        self.trading_calendar = trading_calendar
        # Pause between simulated days, in seconds (lets the UI keep up).
        self.delay_between_days = delay_between_days

        self.running = False
        self._task: Optional[asyncio.Task] = None
        self._dates: list = []

    def get_trading_dates(self) -> list:
        """Get list of trading dates ("YYYY-MM-DD") in the backtest period"""
        import pandas as pd

        # NOTE(review): start/end are only used by the weekday fallback
        # branch; the calendar branch re-parses the raw strings itself.
        start = pd.to_datetime(self.start_date)
        end = pd.to_datetime(self.end_date)

        if self.trading_calendar:
            calendar = mcal.get_calendar(self.trading_calendar)
            trading_dates = calendar.valid_days(
                start_date=self.start_date,
                end_date=self.end_date,
            )
            dates = [d.strftime("%Y-%m-%d") for d in trading_dates]
        else:
            # Fallback: all weekdays in range (ignores exchange holidays).
            all_dates = pd.date_range(start=start, end=end, freq="D")
            dates = [
                d.strftime("%Y-%m-%d") for d in all_dates if d.weekday() < 5
            ]

        self._dates = dates  # cache for get_total_days()
        return dates

    async def start(self, callback: Callable):
        """Start async backtest scheduler as a background task"""
        if self.running:
            logger.warning("Backtest scheduler already running")
            return

        self.running = True
        dates = self.get_trading_dates()

        logger.info(
            f"Starting backtest: {self.start_date} to {self.end_date} "
            f"({len(dates)} trading days)",
        )

        self._task = asyncio.create_task(self._run_async(callback, dates))

    async def _run_async(self, callback: Callable, dates: list):
        """Run backtest asynchronously, one callback invocation per date"""
        for i, date in enumerate(dates, 1):
            # stop() flips self.running to abort mid-run.
            if not self.running:
                break

            logger.info(f"[{i}/{len(dates)}] Processing {date}")
            await callback(date=date)

            if self.delay_between_days > 0:
                await asyncio.sleep(self.delay_between_days)

        logger.info("Backtest complete")
        self.running = False

    def run(self, callback: Callable, **kwargs):
        """
        Run backtest synchronously through all trading dates.

        Returns:
            List of {"date": ..., "result": ...} dicts, one per date.
        """
        dates = self.get_trading_dates()
        results = []

        logger.info(
            f"Starting backtest: {self.start_date} to {self.end_date} "
            f"({len(dates)} trading days)",
        )

        for i, date in enumerate(dates, 1):
            logger.info(f"[{i}/{len(dates)}] Processing {date}")
            result = callback(date=date, **kwargs)
            results.append({"date": date, "result": result})

        logger.info("Backtest complete")
        return results

    def stop(self):
        """Stop backtest scheduler and cancel its background task, if any"""
        self.running = False
        if self._task:
            self._task.cancel()
            self._task = None
        logger.info("Backtest scheduler stopped")

    def get_total_days(self) -> int:
        """Get total number of trading days (computes the list if needed)"""
        if not self._dates:
            self.get_trading_dates()
        return len(self._dates)
|
||||
476
backend/core/state_sync.py
Normal file
476
backend/core/state_sync.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
StateSync - Centralized state synchronization between agents and frontend
|
||||
Handles real-time updates, persistence, and replay support
|
||||
"""
|
||||
# pylint: disable=R0904
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from ..services.storage import StorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateSync:
|
||||
"""
|
||||
Central event dispatcher for agent-frontend synchronization
|
||||
|
||||
Responsibilities:
|
||||
1. Receive events from agents/pipeline
|
||||
2. Persist to storage (feed_history)
|
||||
3. Broadcast to frontend via WebSocket
|
||||
4. Support replay from saved state
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        storage: StorageService,
        broadcast_fn: Optional[Callable] = None,
    ):
        """
        Initialize StateSync.

        Args:
            storage: Storage service for persistence
            broadcast_fn: Async broadcast function - async def broadcast(event: dict)  # noqa: E501
        """
        self.storage = storage
        self._broadcast_fn = broadcast_fn
        # In-memory server state (feed_history, portfolio, progress, ...).
        self._state: Dict[str, Any] = {}
        # Master switch: when False, emit() silently drops events.
        self._enabled = True
        self._simulation_date: Optional[str] = None  # For backtest timestamps
|
||||
|
||||
    def set_simulation_date(self, date: str):
        """Set current simulation date ("YYYY-MM-DD") so emitted events carry
        backtest-compatible timestamps instead of wall-clock time."""
        self._simulation_date = date
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""
|
||||
Get timestamp in milliseconds.
|
||||
Uses simulation date if set (backtest mode), otherwise current time.
|
||||
"""
|
||||
if self._simulation_date:
|
||||
# Parse date and use market close time (16:00) for backtest
|
||||
dt = datetime.strptime(
|
||||
f"{self._simulation_date}",
|
||||
"%Y-%m-%d",
|
||||
)
|
||||
return int(dt.timestamp() * 1000)
|
||||
return int(datetime.now().timestamp() * 1000)
|
||||
|
||||
    def load_state(self):
        """Load server state from storage and normalize it for the dashboard."""
        self._state = self.storage.load_server_state()
        self.storage.update_server_state_from_dashboard(self._state)
        logger.info(
            f"StateSync loaded: {len(self._state.get('feed_history', []))} feeds",  # noqa: E501
        )
|
||||
|
||||
    def save_state(self):
        """Persist the current in-memory state to storage."""
        self.storage.save_server_state(self._state)
|
||||
|
||||
    @property
    def state(self) -> Dict[str, Any]:
        """The live in-memory state dict (mutable; callers share the reference)."""
        return self._state
|
||||
|
||||
    def set_broadcast_fn(self, fn: Callable):
        """Set the async broadcast function (supports late binding, e.g. once
        the WebSocket server is up)."""
        self._broadcast_fn = fn
|
||||
|
||||
    def update_state(self, key: str, value: Any):
        """Set a single top-level state field (not persisted until save_state)."""
        self._state[key] = value
|
||||
|
||||
async def emit(self, event: Dict[str, Any], persist: bool = True):
|
||||
"""
|
||||
Emit an event - persists and broadcasts
|
||||
|
||||
Args:
|
||||
event: Event dictionary, must contain "type"
|
||||
persist: Whether to persist to feed_history
|
||||
"""
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
# Ensure timestamp exists (use simulation date if in backtest mode)
|
||||
if "timestamp" not in event:
|
||||
if self._simulation_date:
|
||||
event["timestamp"] = f"{self._simulation_date}"
|
||||
else:
|
||||
event["timestamp"] = datetime.now().isoformat()
|
||||
|
||||
# Persist to feed_history
|
||||
if persist:
|
||||
self.storage.add_feed_message(self._state, event)
|
||||
self.save_state()
|
||||
|
||||
# Broadcast to frontend
|
||||
if self._broadcast_fn:
|
||||
await self._broadcast_fn(event)
|
||||
|
||||
# ========== Agent Events ==========
|
||||
|
||||
    async def on_agent_complete(
        self,
        agent_id: str,
        content: str,
        **extra,
    ):
        """
        Called when an agent finishes its reply; emits an "agent_message"
        event carrying the agent's output.

        Args:
            agent_id: Agent identifier (e.g., "fundamentals_analyst")
            content: Agent's output content
            **extra: Additional fields merged into the event payload
        """
        ts_ms = self._get_timestamp_ms()

        await self.emit(
            {
                "type": "agent_message",
                "agentId": agent_id,
                "content": content,
                "ts": ts_ms,
                **extra,
            },
        )

        logger.info(f"Agent complete: {agent_id}")
|
||||
|
||||
    async def on_memory_retrieved(
        self,
        agent_id: str,
        content: str,
    ):
        """
        Called when long-term memory is retrieved for an agent; emits a
        "memory" event so the frontend can display the recall.

        Args:
            agent_id: Agent identifier
            content: Retrieved memory content
        """
        ts_ms = self._get_timestamp_ms()

        await self.emit(
            {
                "type": "memory",
                "agentId": agent_id,
                "content": content,
                "ts": ts_ms,
            },
        )

        logger.info(f"Memory retrieved for: {agent_id}")
|
||||
|
||||
# ========== Conference Events ==========
|
||||
|
||||
    async def on_conference_start(self, title: str, date: str):
        """Emit a "conference_start" event when the team discussion begins."""
        ts_ms = self._get_timestamp_ms()

        await self.emit(
            {
                "type": "conference_start",
                "title": title,
                "date": date,
                "ts": ts_ms,
            },
        )

        logger.info(f"Conference started: {title}")
|
||||
|
||||
    async def on_conference_cycle_start(self, cycle: int, total_cycles: int):
        """Emit a cycle-start marker (transient UI event; not persisted)."""
        await self.emit(
            {
                "type": "conference_cycle_start",
                "cycle": cycle,
                "totalCycles": total_cycles,
            },
            persist=False,
        )
|
||||
|
||||
    async def on_conference_message(self, agent_id: str, content: str):
        """Emit an agent's utterance during the conference discussion."""
        ts_ms = self._get_timestamp_ms()

        await self.emit(
            {
                "type": "conference_message",
                "agentId": agent_id,
                "content": content,
                "ts": ts_ms,
            },
        )
|
||||
|
||||
    async def on_conference_cycle_end(self, cycle: int):
        """Emit a cycle-end marker (transient UI event; not persisted)."""
        await self.emit(
            {
                "type": "conference_cycle_end",
                "cycle": cycle,
            },
            persist=False,
        )
|
||||
|
||||
    async def on_conference_end(self):
        """Emit a "conference_end" event when the discussion finishes."""
        ts_ms = self._get_timestamp_ms()

        await self.emit(
            {
                "type": "conference_end",
                "ts": ts_ms,
            },
        )

        logger.info("Conference ended")
|
||||
|
||||
# ========== Cycle Events ==========
|
||||
|
||||
    async def on_cycle_start(self, date: str):
        """
        Called at the start of a trading cycle: records the active date,
        marks the run as in progress, and emits a "day_start" event.
        """
        self._state["current_date"] = date
        self._state["status"] = "running"
        self.set_simulation_date(
            date,
        )  # Set for backtest-compatible timestamps

        await self.emit(
            {
                "type": "day_start",
                "date": date,
                "progress": self._calculate_progress(),
            },
        )
        # await self.emit(
        #     {
        #         "type": "system",
        #         "content": f"Starting trading analysis for {date}",
        #     },
        # )
|
||||
|
||||
async def on_cycle_end(self, date: str, portfolio_summary: Dict = None):
    """Called at the end of a trading cycle.

    Increments the completed-day counter, broadcasts a persisted
    ``team_summary`` event when a portfolio summary is provided, mirrors
    the summary into ``self._state["portfolio"]``, emits ``day_complete``,
    and finally saves state.

    Args:
        date: Trading date string for the cycle that just finished.
        portfolio_summary: Optional dict with portfolio figures. Both the
            camelCase frontend keys ("balance", "pnlPct") and the
            snake_case backend keys ("total_value", "pnl_percent") are
            accepted.
    """
    # Update completed count.
    self._state["trading_days_completed"] = (
        self._state.get("trading_days_completed", 0) + 1
    )

    if portfolio_summary:
        summary_data = {
            "type": "team_summary",
            "balance": portfolio_summary.get(
                "balance",
                portfolio_summary.get("total_value", 0),
            ),
            "pnlPct": portfolio_summary.get(
                "pnlPct",
                portfolio_summary.get("pnl_percent", 0),
            ),
            "equity": portfolio_summary.get("equity", []),
            "baseline": portfolio_summary.get("baseline", []),
            "baseline_vw": portfolio_summary.get("baseline_vw", []),
            "momentum": portfolio_summary.get("momentum", []),
        }

        # Live return series are optional; copy each one only when it is
        # present and truthy (same semantics as the previous per-key ifs).
        return_keys = (
            "equity_return",
            "baseline_return",
            "baseline_vw_return",
            "momentum_return",
        )
        for key in return_keys:
            if portfolio_summary.get(key):
                summary_data[key] = portfolio_summary[key]

        # Mirror the summary into persistent state under snake_case keys.
        portfolio_state = self._state.setdefault("portfolio", {})
        portfolio_state.update(
            {
                "total_value": summary_data["balance"],
                "pnl_percent": summary_data["pnlPct"],
                "equity": summary_data["equity"],
                "baseline": summary_data["baseline"],
                "baseline_vw": summary_data["baseline_vw"],
                "momentum": summary_data["momentum"],
            },
        )
        for key in return_keys:
            if summary_data.get(key):
                portfolio_state[key] = summary_data[key]

        await self.emit(summary_data, persist=True)

    await self.emit(
        {
            "type": "day_complete",
            "date": date,
            "progress": self._calculate_progress(),
        },
    )

    self.save_state()
|
||||
|
||||
# ========== Portfolio Events ==========
|
||||
|
||||
async def on_holdings_update(self, holdings: List[Dict]):
    """Store the latest holdings snapshot and broadcast it.

    Args:
        holdings: Full replacement list of holding dicts.
    """
    self._state["holdings"] = holdings
    payload = {"type": "team_holdings", "data": holdings}
    # Holdings change frequently; broadcast only so feed_history
    # is not flooded with snapshots.
    await self.emit(payload, persist=False)
|
||||
|
||||
async def on_trades_executed(self, trades: List[Dict]):
    """Record newly executed trades and broadcast them incrementally.

    Args:
        trades: Trades executed in this batch; prepended so the newest
            trades come first in state.
    """
    previous = self._state.get("trades", [])
    self._state["trades"] = trades + previous

    payload = {
        "type": "team_trades",
        "mode": "incremental",
        "data": trades,
    }
    await self.emit(payload, persist=False)
|
||||
|
||||
async def on_stats_update(self, stats: Dict):
    """Store the latest stats snapshot and broadcast it.

    Args:
        stats: Stats dict replacing the previous snapshot.
    """
    self._state["stats"] = stats
    # Broadcast only; stats refresh often and are not persisted.
    await self.emit({"type": "team_stats", "data": stats}, persist=False)
|
||||
|
||||
async def on_leaderboard_update(self, leaderboard: List[Dict]):
    """Store the latest leaderboard and broadcast it.

    Args:
        leaderboard: Full replacement leaderboard list.
    """
    self._state["leaderboard"] = leaderboard
    payload = {"type": "team_leaderboard", "data": leaderboard}
    # Broadcast only; not persisted to feed_history.
    await self.emit(payload, persist=False)
|
||||
|
||||
# ========== System Events ==========
|
||||
|
||||
async def on_system_message(self, content: str):
    """Emit a plain system-level message to connected clients.

    Args:
        content: Human-readable message text.
    """
    await self.emit({"type": "system", "content": content})
|
||||
|
||||
# ========== Replay Support ==========
|
||||
|
||||
async def replay_feed_history(self, delay_ms: int = 100):
    """Replay events from feed_history to the registered broadcast sink.

    Useful for frontend reconnection or restoring from saved state.

    Args:
        delay_ms: Pause between replayed events, in milliseconds.
    """
    feed_history = self._state.get("feed_history", [])

    # Fix: without a broadcast function the old loop still slept
    # delay_ms per event for no effect; skip the replay entirely.
    if not self._broadcast_fn:
        logger.info("Replay skipped: no broadcast function registered")
        return

    delay_s = delay_ms / 1000
    # feed_history is newest-first; reverse for chronological replay.
    for event in reversed(feed_history):
        await self._broadcast_fn(event)
        await asyncio.sleep(delay_s)

    logger.info(f"Replayed {len(feed_history)} events")
|
||||
|
||||
def get_initial_state_payload(
    self,
    include_dashboard: bool = True,
) -> Dict[str, Any]:
    """Build the initial state payload for a newly connected client.

    Args:
        include_dashboard: Whether to attach dashboard files loaded
            from storage.

    Returns:
        Dictionary suitable for sending to the frontend.
    """
    state = self._state
    # Key -> default value, in the exact order the payload is emitted.
    defaults = {
        "server_mode": "live",
        "is_mock_mode": False,
        "is_backtest": False,
        "feed_history": [],
        "current_date": None,
        "trading_days_total": 0,
        "trading_days_completed": 0,
        "holdings": [],
        "trades": [],
        "stats": {},
        "leaderboard": [],
        "portfolio": {},
        "realtime_prices": {},
    }
    payload = {key: state.get(key, default) for key, default in defaults.items()}

    if include_dashboard:
        file_names = ("summary", "holdings", "stats", "trades", "leaderboard")
        payload["dashboard"] = {
            name: self.storage.load_file(name) for name in file_names
        }

    return payload
|
||||
|
||||
def _calculate_progress(self) -> float:
    """Return backtest progress as a fraction in [0, 1].

    Returns 0.0 when no total has been set (avoids division by zero).
    """
    total = self._state.get("trading_days_total", 0)
    if total <= 0:
        return 0.0
    return self._state.get("trading_days_completed", 0) / total
|
||||
|
||||
def set_backtest_dates(self, dates: List[str]):
    """Initialize progress tracking for a backtest run.

    Args:
        dates: All trading dates in the backtest; only the count is used.
    """
    self._state.update(
        {
            "trading_days_total": len(dates),
            "trading_days_completed": 0,
        },
    )
||||
BIN
backend/data/__MACOSX/ret_data/._AAPL.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._AAPL.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._AMZN.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._AMZN.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._GOOGL.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._GOOGL.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._META.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._META.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._MSFT.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._MSFT.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._NVDA.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._NVDA.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._TSLA.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._TSLA.csv
Normal file
Binary file not shown.
|
6
backend/data/__init__.py
Normal file
6
backend/data/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from backend.data.historical_price_manager import HistoricalPriceManager
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
__all__ = ["MockPriceManager", "PollingPriceManager", "HistoricalPriceManager"]
|
||||
107
backend/data/cache.py
Normal file
107
backend/data/cache.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing_extensions import Any
|
||||
|
||||
|
||||
class Cache:
|
||||
"""In-memory cache for API responses."""
|
||||
|
||||
def __init__(self):
|
||||
self._prices_cache = {}
|
||||
self._financial_metrics_cache = {}
|
||||
self._line_items_cache = {}
|
||||
self._insider_trades_cache = {}
|
||||
self._company_news_cache = {}
|
||||
|
||||
def _merge_data(
|
||||
self,
|
||||
existing: list[dict] | None,
|
||||
new_data: list[dict],
|
||||
key_field: str,
|
||||
) -> list[dict]:
|
||||
"""Merge existing and new data"""
|
||||
if not existing:
|
||||
return new_data
|
||||
|
||||
# Create a set of existing keys for O(1) lookup
|
||||
existing_keys = {item[key_field] for item in existing}
|
||||
|
||||
# Only add items that don't exist yet
|
||||
merged = existing.copy()
|
||||
merged.extend(
|
||||
[
|
||||
item
|
||||
for item in new_data
|
||||
if item[key_field] not in existing_keys
|
||||
],
|
||||
)
|
||||
return merged
|
||||
|
||||
def get_prices(self, ticker: str) -> list[dict[str, Any]] | None:
|
||||
"""Get cached price data if available."""
|
||||
return self._prices_cache.get(ticker)
|
||||
|
||||
def set_prices(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new price data to cache."""
|
||||
self._prices_cache[ticker] = self._merge_data(
|
||||
self._prices_cache.get(ticker),
|
||||
data,
|
||||
key_field="time",
|
||||
)
|
||||
|
||||
def get_financial_metrics(self, ticker: str) -> list[dict[str, Any]]:
|
||||
"""Get cached financial metrics if available."""
|
||||
return self._financial_metrics_cache.get(ticker)
|
||||
|
||||
def set_financial_metrics(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new financial metrics to cache."""
|
||||
self._financial_metrics_cache[ticker] = self._merge_data(
|
||||
self._financial_metrics_cache.get(ticker),
|
||||
data,
|
||||
key_field="report_period",
|
||||
)
|
||||
|
||||
def get_line_items(self, ticker: str) -> list[dict[str, Any]] | None:
|
||||
"""Get cached line items if available."""
|
||||
return self._line_items_cache.get(ticker)
|
||||
|
||||
def set_line_items(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new line items to cache."""
|
||||
self._line_items_cache[ticker] = self._merge_data(
|
||||
self._line_items_cache.get(ticker),
|
||||
data,
|
||||
key_field="report_period",
|
||||
)
|
||||
|
||||
def get_insider_trades(self, ticker: str) -> list[dict[str, Any]] | None:
|
||||
"""Get cached insider trades if available."""
|
||||
return self._insider_trades_cache.get(ticker)
|
||||
|
||||
def set_insider_trades(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new insider trades to cache."""
|
||||
self._insider_trades_cache[ticker] = self._merge_data(
|
||||
self._insider_trades_cache.get(ticker),
|
||||
data,
|
||||
key_field="filing_date",
|
||||
) # Could also use transaction_date if preferred
|
||||
|
||||
def get_company_news(self, ticker: str) -> list[dict[str, Any]] | None:
|
||||
"""Get cached company news if available."""
|
||||
return self._company_news_cache.get(ticker)
|
||||
|
||||
def set_company_news(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new company news to cache."""
|
||||
self._company_news_cache[ticker] = self._merge_data(
|
||||
self._company_news_cache.get(ticker),
|
||||
data,
|
||||
key_field="date",
|
||||
)
|
||||
|
||||
|
||||
# Global cache instance
# Module-level singleton shared by every caller through get_cache(),
# so all fetchers in the process see the same merged data.
_cache = Cache()


def get_cache() -> Cache:
    """Get the global cache instance."""
    return _cache
|
||||
233
backend/data/historical_price_manager.py
Normal file
233
backend/data/historical_price_manager.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Historical Price Manager for backtest mode
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Path to local CSV data directory
|
||||
_DATA_DIR = Path(__file__).parent / "ret_data"
|
||||
|
||||
|
||||
class HistoricalPriceManager:
    """Provides historical prices for backtest mode.

    Prices are loaded from local CSV files (one per symbol under
    ``ret_data/``) and replayed day by day: ``set_date`` selects the row
    for a date, ``emit_open_prices``/``emit_close_prices`` push the
    open/close of that day to registered callbacks.
    """

    def __init__(self):
        # Symbols currently tracked; subscription order preserved.
        self.subscribed_symbols = []
        # Callables invoked with a price-data dict on each emitted tick.
        self.price_callbacks = []
        # symbol -> DataFrame indexed by date, loaded lazily from CSV.
        self._price_cache = {}
        # Current simulated trading date string, or None before set_date().
        self._current_date = None
        # symbol -> most recently emitted price.
        self.latest_prices = {}
        # symbol -> open/close for the current simulated date.
        self.open_prices = {}
        self.close_prices = {}
        self.running = False

    def subscribe(
        self,
        symbols: List[str],
    ):
        """Subscribe to symbols (duplicates are ignored)."""
        for symbol in symbols:
            if symbol not in self.subscribed_symbols:
                self.subscribed_symbols.append(symbol)

    def unsubscribe(self, symbols: List[str]):
        """Unsubscribe from symbols and drop their cached data."""
        for symbol in symbols:
            if symbol in self.subscribed_symbols:
                self.subscribed_symbols.remove(symbol)
                self._price_cache.pop(symbol, None)

    def add_price_callback(self, callback: Callable):
        """Add price update callback (called with a price-data dict)."""
        self.price_callbacks.append(callback)

    def _load_from_csv(self, symbol: str) -> Optional[pd.DataFrame]:
        """Load price data from local CSV file.

        Returns a DataFrame indexed by parsed date (sorted ascending),
        or None when the file is missing, empty, malformed, or lacks a
        "time" column.
        """
        csv_path = _DATA_DIR / f"{symbol}.csv"
        if not csv_path.exists():
            return None

        try:
            df = pd.read_csv(csv_path)
            if df.empty or "time" not in df.columns:
                return None

            # Index by parsed date so set_date() can do label lookups
            # and "closest earlier date" slicing.
            df["Date"] = pd.to_datetime(df["time"])
            df.set_index("Date", inplace=True)
            df.sort_index(inplace=True)
            return df
        except Exception as e:
            logger.warning(f"Failed to load CSV for {symbol}: {e}")
            return None

    def preload_data(self, start_date: str, end_date: str):
        """Preload historical data from local CSV files.

        NOTE(review): start_date/end_date are only logged — the whole CSV
        is cached regardless of the range; confirm whether range filtering
        was intended.
        """
        logger.info(f"Preloading data: {start_date} to {end_date}")

        for symbol in self.subscribed_symbols:
            if symbol in self._price_cache:
                continue

            # Load from local CSV file directly
            df = self._load_from_csv(symbol)
            if df is not None and not df.empty:
                self._price_cache[symbol] = df
                logger.info(f"Loaded {symbol} from CSV: {len(df)} records")
            else:
                logger.warning(f"No CSV data for {symbol}")

    def set_date(self, date: str):
        """Set current trading date and update open/close/latest prices.

        For each subscribed symbol, picks the row for ``date`` or, if the
        market was closed that day, the closest earlier row. latest_prices
        is primed with the open (emit_close_prices moves it to the close).
        """
        self._current_date = date
        date_dt = pd.Timestamp(date)

        for symbol in self.subscribed_symbols:
            df = self._price_cache.get(symbol)
            if df is None or df.empty:
                # Keep previous prices if no data available
                logger.warning(f"No cached data for {symbol} on {date}")
                continue

            # Find exact date or closest earlier date
            if date_dt in df.index:
                row = df.loc[date_dt]
            else:
                valid_dates = df.index[df.index <= date_dt]
                if len(valid_dates) == 0:
                    logger.warning(f"No data for {symbol} on or before {date}")
                    continue
                row = df.loc[valid_dates[-1]]

            open_price = float(row["open"])
            close_price = float(row["close"])

            self.open_prices[symbol] = open_price
            self.close_prices[symbol] = close_price
            self.latest_prices[symbol] = open_price

            logger.debug(
                f"{symbol} @ {date}: open={open_price:.2f}, close={close_price:.2f}",  # noqa: E501
            )

    def emit_open_prices(self):
        """Emit the current date's open prices to all callbacks."""
        if not self._current_date:
            return

        # Milliseconds since epoch at local midnight of the simulated
        # date (strptime yields a naive, local-time datetime).
        timestamp = int(
            datetime.strptime(self._current_date, "%Y-%m-%d").timestamp()
            * 1000,
        )

        for symbol in self.subscribed_symbols:
            price = self.open_prices.get(symbol)
            if price is None or price <= 0:
                logger.warning(f"Invalid open price for {symbol}: {price}")
                continue

            self.latest_prices[symbol] = price
            self._emit_price(symbol, price, timestamp)

    def emit_close_prices(self):
        """Emit the current date's close prices to all callbacks."""
        if not self._current_date:
            return

        timestamp = int(
            datetime.strptime(self._current_date, "%Y-%m-%d").timestamp()
            * 1000,
        )
        # 6.5h in ms — offsets the close tick to a regular US trading
        # session length after the open timestamp.
        timestamp += 23400000  # Add 6.5 hours

        for symbol in self.subscribed_symbols:
            price = self.close_prices.get(symbol)
            if price is None or price <= 0:
                logger.warning(f"Invalid close price for {symbol}: {price}")
                continue

            self.latest_prices[symbol] = price
            self._emit_price(symbol, price, timestamp)

    def _emit_price(self, symbol: str, price: float, timestamp: int):
        """Build a price-data dict for one tick and fan it out to callbacks.

        ``ret`` is the percent move of ``price`` from the day's open.
        Callback exceptions are logged and swallowed so one bad consumer
        cannot stop the feed.
        """
        open_price = self.open_prices.get(symbol, price)
        close_price = self.close_prices.get(symbol, price)
        ret = (
            ((price - open_price) / open_price) * 100 if open_price > 0 else 0
        )

        price_data = {
            "symbol": symbol,
            "price": price,
            "timestamp": timestamp,
            "open": open_price,
            "close": close_price,
            # Only open/close are known per day, so high/low are
            # approximated from those two values.
            "high": max(open_price, close_price),
            "low": min(open_price, close_price),
            "ret": ret,
        }

        for callback in self.price_callbacks:
            try:
                callback(price_data)
            except Exception as e:
                logger.error(f"Callback error for {symbol}: {e}")

    def get_price_for_date(
        self,
        symbol: str,
        date: str,
        price_type: str = "close",
    ) -> Optional[float]:
        """Get price for a specific date.

        Args:
            symbol: Ticker to look up.
            date: Date string parseable by pd.Timestamp.
            price_type: Column name to read ("open" or "close").

        Returns:
            The price for the exact date, else the closest earlier date,
            else the in-memory latest price (may be None).
        """
        df = self._price_cache.get(symbol)
        if df is None or df.empty:
            return self.latest_prices.get(symbol)

        date_dt = pd.Timestamp(date)
        if date_dt in df.index:
            return float(df.loc[date_dt, price_type])

        valid_dates = df.index[df.index <= date_dt]
        if len(valid_dates) == 0:
            return self.latest_prices.get(symbol)
        return float(df.loc[valid_dates[-1], price_type])

    def start(self):
        """Start manager (no background thread; just flags running)."""
        self.running = True

    def stop(self):
        """Stop manager."""
        self.running = False

    def get_latest_price(self, symbol: str) -> Optional[float]:
        # Most recently emitted price, or None if never set.
        return self.latest_prices.get(symbol)

    def get_all_latest_prices(self) -> Dict[str, float]:
        # Shallow copy so callers cannot mutate internal state.
        return self.latest_prices.copy()

    def get_open_price(self, symbol: str) -> Optional[float]:
        # Return open price, fallback to latest if not set
        price = self.open_prices.get(symbol)
        if price is None or price <= 0:
            return self.latest_prices.get(symbol)
        return price

    def get_close_price(self, symbol: str) -> Optional[float]:
        # Return close price, fallback to latest if not set
        price = self.close_prices.get(symbol)
        if price is None or price <= 0:
            return self.latest_prices.get(symbol)
        return price

    def reset_open_prices(self):
        # Don't clear prices - keep them for continuity across days;
        # set_date() overwrites them per symbol anyway.
        pass
|
||||
241
backend/data/mock_price_manager.py
Normal file
241
backend/data/mock_price_manager.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Mock Price Manager - For testing during non-trading hours
|
||||
Generates virtual real-time price data
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockPriceManager:
    """Mock Price Manager - Generates virtual prices for testing.

    Runs a daemon thread that random-walks each subscribed symbol's
    price every ``poll_interval`` seconds and fans updates out to
    registered callbacks — useful outside market hours.
    """

    def __init__(self, poll_interval: int = None, volatility: float = None):
        """
        Args:
            poll_interval: Price update interval in seconds. Defaults to
                the MOCK_POLL_INTERVAL env var, falling back to 10.
            volatility: Per-update price volatility percentage. Defaults
                to the MOCK_VOLATILITY env var, falling back to 0.5.
        """
        # Fix: the previous signature used non-None literal defaults
        # (10 / 0.5), which made these env-var fallback branches dead
        # code. None defaults make the env overrides reachable while the
        # fallback literals keep the effective no-arg defaults unchanged.
        if poll_interval is None:
            poll_interval = int(os.getenv("MOCK_POLL_INTERVAL", "10"))
        if volatility is None:
            volatility = float(os.getenv("MOCK_VOLATILITY", "0.5"))

        self.poll_interval = poll_interval
        self.volatility = volatility

        self.subscribed_symbols: List[str] = []
        # symbol -> reference price used to seed the walk.
        self.base_prices: Dict[str, float] = {}
        # symbol -> simulated open of the current trading day.
        self.open_prices: Dict[str, float] = {}
        self.latest_prices: Dict[str, float] = {}
        self.price_callbacks: List[Callable] = []

        self.running = False
        self._thread: Optional[threading.Thread] = None

        # Realistic-looking starting prices for common tickers;
        # unknown symbols get a random base in subscribe().
        self.default_base_prices = {
            "AAPL": 237.50,
            "MSFT": 425.30,
            "GOOGL": 161.50,
            "AMZN": 218.45,
            "NVDA": 950.00,
            "META": 573.22,
            "TSLA": 342.15,
            "AMD": 168.90,
            "NFLX": 688.25,
            "INTC": 42.18,
            "COIN": 285.50,
            "PLTR": 45.80,
            "BABA": 88.30,
            "DIS": 112.50,
            "BKNG": 4850.00,
        }

        logger.info(
            f"MockPriceManager initialized (interval: {self.poll_interval}s, "
            f"volatility: {self.volatility}%)",
        )

    def subscribe(
        self,
        symbols: List[str],
        base_prices: Dict[str, float] = None,
    ):
        """Subscribe to stock symbols.

        Args:
            symbols: Tickers to track (duplicates ignored).
            base_prices: Optional explicit starting prices; falls back to
                default_base_prices, then to a random value in [50, 500).
        """
        for symbol in symbols:
            if symbol not in self.subscribed_symbols:
                self.subscribed_symbols.append(symbol)

                if base_prices and symbol in base_prices:
                    base_price = base_prices[symbol]
                elif symbol in self.default_base_prices:
                    base_price = self.default_base_prices[symbol]
                else:
                    base_price = random.uniform(50, 500)

                self.base_prices[symbol] = base_price
                self.open_prices[symbol] = base_price
                self.latest_prices[symbol] = base_price

                logger.info(
                    f"Subscribed to mock price: {symbol} (base: ${base_price:.2f})",  # noqa: E501
                )

    def unsubscribe(self, symbols: List[str]):
        """Unsubscribe from symbols and drop their price state."""
        for symbol in symbols:
            if symbol in self.subscribed_symbols:
                self.subscribed_symbols.remove(symbol)
                self.base_prices.pop(symbol, None)
                self.open_prices.pop(symbol, None)
                self.latest_prices.pop(symbol, None)
                logger.info(f"Unsubscribed: {symbol}")

    def add_price_callback(self, callback: Callable):
        """Add price update callback (called with a price-data dict)."""
        self.price_callbacks.append(callback)

    def _generate_price_update(self, symbol: str) -> float:
        """Generate the next price via a bounded random walk.

        Applies a uniform step within +/-volatility%, occasionally a
        larger "trend" move, and clamps the result to +/-10% of the
        day's open.
        """
        current_price = self.latest_prices.get(
            symbol,
            self.base_prices[symbol],
        )

        change_percent = random.uniform(-self.volatility, self.volatility)
        new_price = current_price * (1 + change_percent / 100)

        # 10% chance of larger movement
        if random.random() < 0.1:
            trend_factor = random.uniform(-2, 2)
            new_price = new_price * (1 + trend_factor / 100)

        # Limit intraday movement to +/-10%
        open_price = self.open_prices[symbol]
        max_price = open_price * 1.10
        min_price = open_price * 0.90
        new_price = max(min_price, min(max_price, new_price))

        return new_price

    def _update_prices(self):
        """Update prices for all subscribed stocks and notify callbacks."""
        timestamp = int(time.time() * 1000)

        for symbol in self.subscribed_symbols:
            try:
                new_price = self._generate_price_update(symbol)
                self.latest_prices[symbol] = new_price

                open_price = self.open_prices[symbol]
                ret = ((new_price - open_price) / open_price) * 100

                price_data = {
                    "symbol": symbol,
                    "price": new_price,
                    "timestamp": timestamp,
                    "volume": random.randint(1000000, 10000000),
                    "open": open_price,
                    # Only one tick is known, so high/low are approximated
                    # from the open and the new price.
                    "high": max(new_price, open_price),
                    "low": min(new_price, open_price),
                    "previous_close": open_price,
                    "ret": ret,
                }

                # One failing consumer must not stop the feed.
                for callback in self.price_callbacks:
                    try:
                        callback(price_data)
                    except Exception as e:
                        logger.error(
                            f"Mock price callback error ({symbol}): {e}",
                        )

                logger.debug(
                    f"Mock {symbol}: ${new_price:.2f} [ret: {ret:+.2f}%]",
                )

            except Exception as e:
                logger.error(f"Failed to generate mock price ({symbol}): {e}")

    def _polling_loop(self):
        """Main polling loop; compensates sleep for update duration."""
        logger.info(
            f"Mock price generation started (interval: {self.poll_interval}s)",
        )

        while self.running:
            try:
                start_time = time.time()
                self._update_prices()

                elapsed = time.time() - start_time
                sleep_time = max(0, self.poll_interval - elapsed)
                if sleep_time > 0:
                    time.sleep(sleep_time)

            except Exception as e:
                logger.error(f"Mock polling loop error: {e}")
                # Back off briefly so a persistent error can't spin the CPU.
                time.sleep(5)

    def start(self):
        """Start mock price generation in a daemon thread (idempotent)."""
        if self.running:
            logger.warning("Mock price manager already running")
            return

        if not self.subscribed_symbols:
            logger.warning("No stocks subscribed")
            return

        self.running = True
        self._thread = threading.Thread(target=self._polling_loop, daemon=True)
        self._thread.start()

        logger.info(
            f"Mock price manager started: {', '.join(self.subscribed_symbols)}",  # noqa: E501
        )

    def stop(self):
        """Stop mock price generation (joins the thread, 5s timeout)."""
        self.running = False
        if self._thread:
            self._thread.join(timeout=5)
        logger.info("Mock price manager stopped")

    def get_latest_price(self, symbol: str) -> Optional[float]:
        """Get latest price for symbol (None if unknown)."""
        return self.latest_prices.get(symbol)

    def get_all_latest_prices(self) -> Dict[str, float]:
        """Get a copy of all latest prices."""
        return self.latest_prices.copy()

    def get_open_price(self, symbol: str) -> Optional[float]:
        """Get open price for symbol (None if unknown)."""
        return self.open_prices.get(symbol)

    def reset_open_prices(self):
        """Reset open prices for a new trading day with a +/-1% gap."""
        for symbol in self.subscribed_symbols:
            last_close = self.latest_prices[symbol]
            gap_percent = random.uniform(-1, 1)
            new_open = last_close * (1 + gap_percent / 100)
            self.open_prices[symbol] = new_open
            self.latest_prices[symbol] = new_open
        logger.info("Open prices reset")

    def set_base_price(self, symbol: str, price: float):
        """Manually set base/open/latest price for a subscribed symbol."""
        if symbol in self.subscribed_symbols:
            self.base_prices[symbol] = price
            self.open_prices[symbol] = price
            self.latest_prices[symbol] = price
            logger.info(f"{symbol} base price set to: ${price:.2f}")
        else:
            logger.warning(f"{symbol} not subscribed")
|
||||
175
backend/data/polling_price_manager.py
Normal file
175
backend/data/polling_price_manager.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Polling-based Price Manager - Uses Finnhub REST API
|
||||
Supports real-time price fetching via polling
|
||||
"""
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import finnhub
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PollingPriceManager:
|
||||
"""Polling-based price manager using Finnhub Quote API"""
|
||||
|
||||
def __init__(self, api_key: str, poll_interval: int = 30):
|
||||
"""
|
||||
Args:
|
||||
api_key: Finnhub API Key
|
||||
poll_interval: Polling interval in seconds (default 30s)
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.poll_interval = poll_interval
|
||||
self.finnhub_client = finnhub.Client(api_key=api_key)
|
||||
|
||||
self.subscribed_symbols: List[str] = []
|
||||
self.latest_prices: Dict[str, float] = {}
|
||||
self.open_prices: Dict[str, float] = {}
|
||||
self.price_callbacks: List[Callable] = []
|
||||
|
||||
self.running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
logger.info(
|
||||
f"PollingPriceManager initialized (interval: {poll_interval}s)",
|
||||
)
|
||||
|
||||
def subscribe(self, symbols: List[str]):
|
||||
"""Subscribe to stock symbols"""
|
||||
for symbol in symbols:
|
||||
if symbol not in self.subscribed_symbols:
|
||||
self.subscribed_symbols.append(symbol)
|
||||
logger.info(f"Subscribed to: {symbol}")
|
||||
|
||||
def unsubscribe(self, symbols: List[str]):
|
||||
"""Unsubscribe from symbols"""
|
||||
for symbol in symbols:
|
||||
if symbol in self.subscribed_symbols:
|
||||
self.subscribed_symbols.remove(symbol)
|
||||
logger.info(f"Unsubscribed: {symbol}")
|
||||
|
||||
def add_price_callback(self, callback: Callable):
|
||||
"""Add price update callback"""
|
||||
self.price_callbacks.append(callback)
|
||||
|
||||
def _fetch_prices(self):
|
||||
"""Fetch latest prices for all subscribed stocks"""
|
||||
for symbol in self.subscribed_symbols:
|
||||
try:
|
||||
quote_data = self.finnhub_client.quote(symbol)
|
||||
|
||||
current_price = quote_data.get("c")
|
||||
open_price = quote_data.get("o")
|
||||
timestamp = quote_data.get("t", int(time.time()))
|
||||
|
||||
if not current_price or current_price <= 0:
|
||||
logger.warning(f"{symbol}: Invalid price data")
|
||||
continue
|
||||
|
||||
# Store open price on first fetch
|
||||
if (
|
||||
symbol not in self.open_prices
|
||||
and open_price
|
||||
and open_price > 0
|
||||
):
|
||||
self.open_prices[symbol] = open_price
|
||||
logger.info(f"{symbol} open price: ${open_price:.2f}")
|
||||
|
||||
stored_open = self.open_prices.get(symbol, open_price)
|
||||
ret = (
|
||||
((current_price - stored_open) / stored_open) * 100
|
||||
if stored_open > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
self.latest_prices[symbol] = current_price
|
||||
|
||||
price_data = {
|
||||
"symbol": symbol,
|
||||
"price": current_price,
|
||||
"timestamp": timestamp * 1000,
|
||||
"open": stored_open,
|
||||
"high": quote_data.get("h"),
|
||||
"low": quote_data.get("l"),
|
||||
"previous_close": quote_data.get("pc"),
|
||||
"ret": ret,
|
||||
"change": quote_data.get("d"),
|
||||
"change_percent": quote_data.get("dp"),
|
||||
}
|
||||
|
||||
for callback in self.price_callbacks:
|
||||
try:
|
||||
callback(price_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Price callback error ({symbol}): {e}")
|
||||
|
||||
logger.debug(
|
||||
f"{symbol}: ${current_price:.2f} [ret: {ret:+.2f}%]",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch {symbol} price: {e}")
|
||||
|
||||
def _polling_loop(self):
|
||||
"""Main polling loop"""
|
||||
logger.info(f"Price polling started (interval: {self.poll_interval}s)")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
start_time = time.time()
|
||||
self._fetch_prices()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.poll_interval - elapsed)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Polling loop error: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
def start(self):
    """Launch the background polling thread.

    No-op (with a warning) when polling is already active or when no symbols
    have been subscribed. The worker thread is a daemon so it never blocks
    interpreter shutdown.
    """
    if self.running:
        logger.warning("Price polling already running")
        return
    if not self.subscribed_symbols:
        logger.warning("No stocks subscribed")
        return

    self.running = True
    worker = threading.Thread(target=self._polling_loop, daemon=True)
    self._thread = worker
    worker.start()

    logger.info(
        f"Price polling started: {', '.join(self.subscribed_symbols)}",
    )
|
||||
|
||||
def stop(self):
    """Signal the polling loop to exit and wait briefly for the worker.

    ``join`` uses a 5-second timeout so shutdown can never hang on a stuck
    fetch; the thread is a daemon, so leaking it past the timeout is safe.
    """
    self.running = False
    worker = self._thread
    if worker:
        worker.join(timeout=5)
    logger.info("Price polling stopped")
|
||||
|
||||
def get_latest_price(self, symbol: str) -> Optional[float]:
    """Return the most recently fetched price for *symbol*, or ``None``
    when the symbol has not been seen yet."""
    return self.latest_prices.get(symbol)
|
||||
|
||||
def get_all_latest_prices(self) -> Dict[str, float]:
    """Return a snapshot of every tracked symbol's latest price.

    A fresh dict is returned so callers can mutate the result without
    affecting the poller's internal state.
    """
    return dict(self.latest_prices)
|
||||
|
||||
def get_open_price(self, symbol: str) -> Optional[float]:
    """Return the session open price recorded for *symbol*, or ``None``
    when no open has been captured for it."""
    return self.open_prices.get(symbol)
|
||||
|
||||
def reset_open_prices(self):
    """Reset open prices for new trading day"""
    # Dropping all cached opens forces the fetch loop to re-record each
    # symbol's first observed open price on the next poll.
    self.open_prices.clear()
    logger.info("Open prices reset")
|
||||
387
backend/data/ret_data_updater.py
Normal file
387
backend/data/ret_data_updater.py
Normal file
@@ -0,0 +1,387 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Automatic Incremental Historical Data Update Module
|
||||
|
||||
Features:
|
||||
1. Fetch stock historical data from configured API (Finnhub or Financial Datasets)
|
||||
2. Incrementally update CSV files in ret_data directory
|
||||
3. Automatically detect last update date, only download new data
|
||||
4. Calculate returns (ret)
|
||||
5. Support batch updates for multiple stocks
|
||||
"""
|
||||
|
||||
# flake8: noqa: E501
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import exchange_calendars as xcals
|
||||
import pandas as pd
|
||||
import pandas_market_calendars as mcal
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.config.data_config import (
|
||||
get_config,
|
||||
)
|
||||
from backend.tools.data_tools import get_prices, prices_to_df
|
||||
|
||||
# Add project root directory to path
# NOTE(review): this append runs AFTER the `from backend...` imports above,
# so it cannot help those imports resolve — it only benefits imports made
# later (or by consumers of this module). Confirm whether it should be
# hoisted above the backend imports.
BASE_DIR = Path(__file__).resolve().parents[2]
if str(BASE_DIR) not in sys.path:
    sys.path.append(str(BASE_DIR))

# Configure logging (timestamped, INFO level) for CLI use of this module.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataUpdater:
    """Incrementally maintain per-ticker CSV files of daily OHLCV + returns.

    One CSV per ticker lives in ``data_dir``; each row has columns
    open/close/high/low/volume/time/ret, where ``ret`` is the next-day
    close-to-close return.
    """

    # Resolved storage directory (always a Path after __init__).
    data_dir: Path

    def __init__(
        self,
        data_dir: Optional[str] = None,
        start_date: str = "2022-01-01",
    ):
        """
        Initialize data updater

        Args:
            data_dir: Data storage directory, defaults to backend/data/ret_data
            start_date: Historical data start date (YYYY-MM-DD)
        """
        # Get config from centralized source
        config = get_config()
        self.data_source = config.source
        self.api_key = config.api_key

        # Set data directory
        if data_dir is None:
            self.data_dir = BASE_DIR / "backend" / "data" / "ret_data"
        else:
            self.data_dir = Path(data_dir)

        # Ensure directory exists
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.start_date = start_date

        # Initialize Finnhub client if needed
        # (imported lazily so the dependency is only required for finnhub).
        if self.data_source == "finnhub":
            import finnhub

            self.client = finnhub.Client(api_key=self.api_key)
            logger.info("Finnhub client initialized")
        else:
            self.client = None
            logger.info("Financial Datasets API configured")

    def get_trading_dates(self, start_date: str, end_date: str) -> List[str]:
        """Get US stock market trading date sequence."""
        # NOTE(review): mcal/xcals are imported unconditionally at module
        # top, so the `is not None` guards never fail here — presumably a
        # holdover from optional imports. Any calendar error falls through
        # to the business-day approximation below.
        try:
            if mcal is not None:
                nyse = mcal.get_calendar("NYSE")
                trading_dates = nyse.valid_days(
                    start_date=start_date,
                    end_date=end_date,
                )
                return [date.strftime("%Y-%m-%d") for date in trading_dates]

            elif xcals is not None:
                nyse = xcals.get_calendar("XNYS")
                trading_dates = nyse.sessions_in_range(start_date, end_date)
                return [date.strftime("%Y-%m-%d") for date in trading_dates]

        except Exception as e:
            logger.warning(
                f"Failed to get US trading calendar, using business days: {e}",
            )

        # Fallback to simple business day method
        date_range = pd.date_range(start_date, end_date, freq="B")
        return [date.strftime("%Y-%m-%d") for date in date_range]

    def get_last_date_from_csv(self, ticker: str) -> Optional[datetime]:
        """Get last data date from CSV file.

        Returns None when the file is missing, empty, lacks a "time"
        column, or cannot be parsed — all of which trigger a full
        (re)download by the caller.
        """
        csv_path = self.data_dir / f"{ticker}.csv"

        if not csv_path.exists():
            logger.info(f"{ticker}.csv does not exist, will create new file")
            return None

        try:
            df = pd.read_csv(csv_path)
            if df.empty or "time" not in df.columns:
                return None

            # Rows are kept sorted by time (see merge_and_save), so the
            # last row holds the newest date.
            last_date_str = df["time"].iloc[-1]
            last_date = datetime.strptime(last_date_str, "%Y-%m-%d")
            logger.info(f"{ticker} last data date: {last_date_str}")
            return last_date
        except Exception as e:
            logger.warning(f"Failed to read {ticker}.csv: {e}")
            return None

    def fetch_data_from_api(
        self,
        ticker: str,
        start_date: datetime,
        end_date: datetime,
    ) -> Optional[pd.DataFrame]:
        """Fetch data from configured API.

        Returns a DataFrame with columns
        [open, close, high, low, volume, time, ret], or None when the API
        returns no rows.
        """
        start_date_str = start_date.strftime("%Y-%m-%d")
        end_date_str = end_date.strftime("%Y-%m-%d")

        logger.info(
            f"Fetching {ticker} data from {self.data_source}: {start_date_str} to {end_date_str}",
        )

        prices = get_prices(
            ticker=ticker,
            start_date=start_date_str,
            end_date=end_date_str,
        )

        if not prices:
            logger.warning(f"{ticker} no data returned from API")
            return None

        # Convert to DataFrame
        df = prices_to_df(prices)
        df = df.reset_index()
        df["time"] = df["Date"].dt.strftime("%Y-%m-%d")

        # Calculate returns (next day return)
        # shift(-1) makes row t hold the return from t's close to t+1's
        # close; the final row's ret is NaN until more data arrives.
        df["ret"] = df["close"].pct_change().shift(-1)

        # Select needed columns
        df = df[["open", "close", "high", "low", "volume", "time", "ret"]]

        logger.info(f"Successfully fetched {ticker} data: {len(df)} records")
        return df

    def merge_and_save(self, ticker: str, new_data: pd.DataFrame) -> bool:
        """Merge old and new data and save.

        Deduplicates on the "time" column (newest wins), re-sorts, and
        recomputes the return column across the combined series so the
        previously-NaN last row gets filled. Returns True on success.
        """
        csv_path = self.data_dir / f"{ticker}.csv"

        try:
            if csv_path.exists():
                old_data = pd.read_csv(csv_path)
                logger.info(f"{ticker} existing data: {len(old_data)} records")

                # Merge and deduplicate
                combined = pd.concat([old_data, new_data], ignore_index=True)
                combined = combined.drop_duplicates(
                    subset=["time"],
                    keep="last",
                )
                combined = combined.sort_values("time").reset_index(drop=True)

                # Recalculate returns
                combined["ret"] = combined["close"].pct_change().shift(-1)

                logger.info(f"{ticker} merged data: {len(combined)} records")
            else:
                combined = new_data
                logger.info(f"{ticker} new file: {len(combined)} records")

            combined.to_csv(csv_path, index=False)
            logger.info(f"{ticker} data saved to: {csv_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to save {ticker} data: {e}")
            return False

    def update_ticker(
        self,
        ticker: str,
        force_full_update: bool = False,
    ) -> bool:
        """Update data for a single stock.

        Incremental by default (resumes the day after the CSV's last row);
        ``force_full_update`` re-downloads from ``self.start_date``.
        Returns True on success or when no update was needed.
        """
        logger.info(f"{'='*60}")
        logger.info(f"Starting update for {ticker}")
        logger.info(f"{'='*60}")

        # Determine start date
        if force_full_update:
            start_date = datetime.strptime(self.start_date, "%Y-%m-%d")
            logger.info(f"Force full update, start date: {start_date.date()}")
        else:
            last_date = self.get_last_date_from_csv(ticker)
            if last_date:
                start_date = last_date + timedelta(days=1)
                logger.info(
                    f"Incremental update, start date: {start_date.date()}",
                )
            else:
                start_date = datetime.strptime(self.start_date, "%Y-%m-%d")
                logger.info(f"First update, start date: {start_date.date()}")

        end_date = datetime.now()

        if start_date.date() >= end_date.date():
            logger.info(f"{ticker} data is up to date, no update needed")
            return True

        new_data = self.fetch_data_from_api(ticker, start_date, end_date)

        if new_data is None or new_data.empty:
            # A short empty window is normal (market closed); a long one
            # suggests a real problem and is reported as failure.
            days_diff = (end_date - start_date).days
            if days_diff <= 3:
                logger.info(
                    f"{ticker} has no new data (may be weekend/holiday)",
                )
                return True
            else:
                logger.warning(f"{ticker} has no new data")
                return False

        success = self.merge_and_save(ticker, new_data)

        if success:
            logger.info(f"{ticker} update completed")
        else:
            logger.error(f"{ticker} update failed")

        return success

    def update_all_tickers(
        self,
        tickers: List[str],
        force_full_update: bool = False,
    ) -> Dict[str, bool]:
        """Batch update multiple stocks.

        Returns a mapping of ticker -> success flag. A 1s pause is
        inserted between tickers for API rate limiting.
        """
        results = {}

        logger.info(f"{'='*60}")
        logger.info(f"Starting batch update for {len(tickers)} stocks")
        logger.info(f"Stock list: {', '.join(tickers)}")
        logger.info(f"{'='*60}")

        for i, ticker in enumerate(tickers, 1):
            logger.info(f"[{i}/{len(tickers)}] Processing {ticker}")
            results[ticker] = self.update_ticker(ticker, force_full_update)

            # API rate limiting
            if i < len(tickers):
                time.sleep(1)

        # Print summary
        logger.info(f"{'='*60}")
        logger.info("Update Summary")
        logger.info(f"{'='*60}")

        success_count = sum(results.values())
        fail_count = len(results) - success_count

        logger.info(f"Success: {success_count}")
        logger.info(f"Failed: {fail_count}")

        if fail_count > 0:
            failed_tickers = [t for t, s in results.items() if not s]
            logger.warning(f"Failed stocks: {', '.join(failed_tickers)}")

        logger.info(f"{'='*60}\n")

        return results
|
||||
|
||||
|
||||
def main():
    """Command line entry point.

    Parses CLI arguments, resolves the ticker list (``--tickers`` or the
    TICKERS env var), runs the batch update, and exits with status 1 only
    on configuration errors or an unexpected exception; partial or empty
    updates exit 0 so scheduled runs don't flap on weekends/holidays.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Automatically update stock historical data",
    )
    parser.add_argument(
        "--tickers",
        type=str,
        help="Stock ticker list (comma-separated), e.g.: AAPL,MSFT,GOOGL",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        help="Data storage directory (default: backend/data/ret_data)",
    )
    parser.add_argument(
        "--start-date",
        type=str,
        default="2022-01-01",
        help="Historical data start date (YYYY-MM-DD, default: 2022-01-01)",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force full update (re-download all data)",
    )

    args = parser.parse_args()

    # Load environment variables
    load_dotenv()

    # Validate API key is available
    try:
        config = get_config()
        logger.info(f"Using data source: {config.source}")
    except ValueError as e:
        logger.error(str(e))
        sys.exit(1)

    # Get stock list: CLI flag wins, then the TICKERS env var.
    if args.tickers:
        tickers = [t.strip().upper() for t in args.tickers.split(",")]
    else:
        tickers_env = os.getenv("TICKERS", "")
        if tickers_env:
            tickers = [t.strip().upper() for t in tickers_env.split(",")]
        else:
            logger.error("Stock list not provided")
            logger.error(
                "Please set via --tickers parameter or TICKERS environment variable",
            )
            sys.exit(1)

    # Create updater
    updater = DataUpdater(
        data_dir=args.data_dir,
        start_date=args.start_date,
    )

    # Execute update
    try:
        results = updater.update_all_tickers(
            tickers,
            force_full_update=args.force,
        )
    except Exception:
        # Previously this exited silently, hiding the cause (e.g. an API
        # error). Log the traceback before signalling failure.
        logger.exception("Batch update aborted by unexpected error")
        sys.exit(1)

    # Return status code
    success_count = sum(results.values())
    if success_count == len(results):
        logger.info("All stocks updated successfully!")
        sys.exit(0)
    elif success_count == 0:
        logger.warning("All stocks have no new data (may be weekend/holiday)")
        sys.exit(0)
    else:
        logger.warning("Some stocks failed to update, but will continue")
        sys.exit(0)


if __name__ == "__main__":
    main()
|
||||
184
backend/data/schema.py
Normal file
184
backend/data/schema.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Price(BaseModel):
    """A single OHLCV price bar for one ticker."""

    open: float
    close: float
    high: float
    low: float
    volume: int
    time: str  # bar date string; presumably "YYYY-MM-DD" -- confirm against data source
|
||||
|
||||
|
||||
class PriceResponse(BaseModel):
    """Price-history payload for a single ticker."""

    ticker: str
    prices: list[Price]
|
||||
|
||||
|
||||
class FinancialMetrics(BaseModel):
    """Financial metrics for one ticker and reporting period.

    Every metric is optional because data providers frequently omit
    individual values.
    """

    ticker: str
    report_period: str
    period: str  # period granularity label -- TODO confirm value set (e.g. annual/quarterly/ttm)
    currency: str
    # Valuation
    market_cap: float | None
    enterprise_value: float | None
    price_to_earnings_ratio: float | None
    price_to_book_ratio: float | None
    price_to_sales_ratio: float | None
    enterprise_value_to_ebitda_ratio: float | None
    enterprise_value_to_revenue_ratio: float | None
    free_cash_flow_yield: float | None
    peg_ratio: float | None
    # Profitability margins
    gross_margin: float | None
    operating_margin: float | None
    net_margin: float | None
    # Returns on capital
    return_on_equity: float | None
    return_on_assets: float | None
    return_on_invested_capital: float | None
    # Efficiency / turnover
    asset_turnover: float | None
    inventory_turnover: float | None
    receivables_turnover: float | None
    days_sales_outstanding: float | None
    operating_cycle: float | None
    working_capital_turnover: float | None
    # Liquidity
    current_ratio: float | None
    quick_ratio: float | None
    cash_ratio: float | None
    operating_cash_flow_ratio: float | None
    # Leverage
    debt_to_equity: float | None
    debt_to_assets: float | None
    interest_coverage: float | None
    # Growth rates
    revenue_growth: float | None
    earnings_growth: float | None
    book_value_growth: float | None
    earnings_per_share_growth: float | None
    free_cash_flow_growth: float | None
    operating_income_growth: float | None
    ebitda_growth: float | None
    # Payout and per-share figures
    payout_ratio: float | None
    earnings_per_share: float | None
    book_value_per_share: float | None
    free_cash_flow_per_share: float | None
|
||||
|
||||
|
||||
class FinancialMetricsResponse(BaseModel):
    """Collection of financial-metrics records returned by the data API."""

    financial_metrics: list[FinancialMetrics]
|
||||
|
||||
|
||||
class LineItem(BaseModel):
    """A financial-statement line-item search result.

    Only the identification fields are declared; extra fields are allowed
    so each result can carry arbitrary provider-specific values.
    """

    ticker: str
    report_period: str
    period: str
    currency: str

    # Allow additional fields dynamically
    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class LineItemResponse(BaseModel):
    """Line-item search results returned by the data API."""

    search_results: list[LineItem]
|
||||
|
||||
|
||||
class InsiderTrade(BaseModel):
    """One insider transaction filing for a ticker.

    Most fields are optional; only the ticker and the filing date are
    guaranteed by the provider.
    """

    ticker: str
    issuer: str | None
    name: str | None  # insider's name
    title: str | None  # insider's role/title at the issuer
    is_board_director: bool | None
    transaction_date: str | None
    transaction_shares: float | None
    transaction_price_per_share: float | None
    transaction_value: float | None
    shares_owned_before_transaction: float | None
    shares_owned_after_transaction: float | None
    security_title: str | None
    filing_date: str
|
||||
|
||||
|
||||
class InsiderTradeResponse(BaseModel):
    """Insider-trade records returned by the data API."""

    insider_trades: list[InsiderTrade]
|
||||
|
||||
|
||||
class CompanyNews(BaseModel):
    """A single news article associated with a ticker."""

    category: str | None = None
    ticker: str
    title: str
    related: str | None = None  # related symbols/topics -- TODO confirm provider semantics
    source: str
    date: str | None = None
    url: str
    summary: str | None = None
|
||||
|
||||
|
||||
class CompanyNewsResponse(BaseModel):
    """News articles returned by the data API."""

    news: list[CompanyNews]
|
||||
|
||||
|
||||
class CompanyFacts(BaseModel):
    """Static company profile facts for a ticker.

    All fields beyond ticker and name are optional since providers expose
    different subsets of company metadata.
    """

    ticker: str
    name: str
    cik: str | None = None  # SEC Central Index Key
    industry: str | None = None
    sector: str | None = None
    category: str | None = None
    exchange: str | None = None
    is_active: bool | None = None
    listing_date: str | None = None
    location: str | None = None
    market_cap: float | None = None
    number_of_employees: int | None = None
    sec_filings_url: str | None = None
    sic_code: str | None = None  # Standard Industrial Classification code
    sic_industry: str | None = None
    sic_sector: str | None = None
    website_url: str | None = None
    weighted_average_shares: int | None = None
|
||||
|
||||
|
||||
class CompanyFactsResponse(BaseModel):
    """Company-facts payload returned by the data API (single company)."""

    company_facts: CompanyFacts
|
||||
|
||||
|
||||
class Position(BaseModel):
    """Position information - for Portfolio mode.

    Long and short sides are tracked independently, each with its own
    share count and average cost basis.
    """

    long: int = 0  # Long position quantity (shares)
    short: int = 0  # Short position quantity (shares)
    long_cost_basis: float = 0.0  # Long position average cost
    short_cost_basis: float = 0.0  # Short position average cost
|
||||
|
||||
|
||||
class Portfolio(BaseModel):
    """Portfolio - for Portfolio mode.

    Holds cash, per-ticker positions, and margin state. The mutable ``{}``
    default is safe here: pydantic copies field defaults per instance.
    """

    cash: float = 100000.0  # Available cash
    positions: dict[str, Position] = {}  # ticker -> Position mapping
    # Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin)
    margin_requirement: float = 0.0
    margin_used: float = 0.0  # Margin used
|
||||
|
||||
|
||||
class AnalystSignal(BaseModel):
    """A single agent's trading signal for one ticker."""

    signal: str | None = None  # direction label -- TODO confirm value set (e.g. bullish/bearish)
    confidence: float | None = None
    reasoning: dict | str | None = None  # free-form or structured rationale
    max_position_size: float | None = None  # For risk management signals
|
||||
|
||||
|
||||
class TickerAnalysis(BaseModel):
    """All agents' signals for a single ticker."""

    ticker: str
    analyst_signals: dict[str, AnalystSignal]  # agent_name -> signal mapping
|
||||
|
||||
|
||||
class AgentStateData(BaseModel):
    """Snapshot of the multi-agent run: universe, portfolio, window, and
    the per-ticker analysis results."""

    tickers: list[str]
    portfolio: Portfolio
    start_date: str
    end_date: str
    ticker_analyses: dict[str, TickerAnalysis]  # ticker -> analysis mapping
|
||||
|
||||
|
||||
class AgentStateMetadata(BaseModel):
    """Run metadata flags; extra fields allowed for ad-hoc settings."""

    show_reasoning: bool = False
    model_config = {"extra": "allow"}
|
||||
0
backend/llm/__init__.py
Normal file
0
backend/llm/__init__.py
Normal file
243
backend/llm/models.py
Normal file
243
backend/llm/models.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
AgentScope Native Model Factory
|
||||
Uses native AgentScope model classes for LLM calls
|
||||
"""
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from agentscope.formatter import (
|
||||
AnthropicChatFormatter,
|
||||
DashScopeChatFormatter,
|
||||
GeminiChatFormatter,
|
||||
OllamaChatFormatter,
|
||||
OpenAIChatFormatter,
|
||||
)
|
||||
from agentscope.model import (
|
||||
AnthropicChatModel,
|
||||
DashScopeChatModel,
|
||||
GeminiChatModel,
|
||||
OllamaChatModel,
|
||||
OpenAIChatModel,
|
||||
)
|
||||
|
||||
|
||||
class ModelProvider(Enum):
    """Supported model providers.

    ALIBABA is an alias routed to DashScope and GOOGLE an alias routed to
    Gemini; DEEPSEEK, GROQ, and OPENROUTER are OpenAI-compatible services
    (see the PROVIDER_* maps below).
    """

    OPENAI = "OPENAI"
    ANTHROPIC = "ANTHROPIC"
    DASHSCOPE = "DASHSCOPE"
    ALIBABA = "ALIBABA"
    GEMINI = "GEMINI"
    GOOGLE = "GOOGLE"
    OLLAMA = "OLLAMA"
    DEEPSEEK = "DEEPSEEK"
    GROQ = "GROQ"
    OPENROUTER = "OPENROUTER"
||||
|
||||
|
||||
# Provider to AgentScope model class mapping
PROVIDER_MODEL_MAP = {
    "OPENAI": OpenAIChatModel,
    "ANTHROPIC": AnthropicChatModel,
    "DASHSCOPE": DashScopeChatModel,
    "ALIBABA": DashScopeChatModel,
    "GEMINI": GeminiChatModel,
    "GOOGLE": GeminiChatModel,
    "OLLAMA": OllamaChatModel,
    # OpenAI-compatible providers use OpenAIChatModel with custom base_url
    "DEEPSEEK": OpenAIChatModel,
    "GROQ": OpenAIChatModel,
    "OPENROUTER": OpenAIChatModel,
}

# Provider to formatter mapping (must stay in sync with PROVIDER_MODEL_MAP)
PROVIDER_FORMATTER_MAP = {
    "OPENAI": OpenAIChatFormatter,
    "ANTHROPIC": AnthropicChatFormatter,
    "DASHSCOPE": DashScopeChatFormatter,
    "ALIBABA": DashScopeChatFormatter,
    "GEMINI": GeminiChatFormatter,
    "GOOGLE": GeminiChatFormatter,
    "OLLAMA": OllamaChatFormatter,
    # OpenAI-compatible providers use OpenAIChatFormatter
    "DEEPSEEK": OpenAIChatFormatter,
    "GROQ": OpenAIChatFormatter,
    "OPENROUTER": OpenAIChatFormatter,
}

# Provider-specific base URLs (for OpenAI-compatible services)
PROVIDER_BASE_URLS = {
    "DEEPSEEK": "https://api.deepseek.com/v1",
    "GROQ": "https://api.groq.com/openai/v1",
    "OPENROUTER": "https://openrouter.ai/api/v1",
}

# Provider-specific API key environment variable names.
# OLLAMA is intentionally absent: create_model() skips the api_key for it.
PROVIDER_API_KEY_ENV = {
    "OPENAI": "OPENAI_API_KEY",
    "ANTHROPIC": "ANTHROPIC_API_KEY",
    "DASHSCOPE": "DASHSCOPE_API_KEY",
    "ALIBABA": "DASHSCOPE_API_KEY",
    "GEMINI": "GOOGLE_API_KEY",
    "GOOGLE": "GOOGLE_API_KEY",
    "DEEPSEEK": "DEEPSEEK_API_KEY",
    "GROQ": "GROQ_API_KEY",
    "OPENROUTER": "OPENROUTER_API_KEY",
}
|
||||
|
||||
|
||||
def create_model(
    model_name: str,
    provider: str,
    api_key: Optional[str] = None,
    stream: bool = False,
    **kwargs,
):
    """
    Build an AgentScope chat-model instance for the given provider.

    The provider selects the model class via PROVIDER_MODEL_MAP; a missing
    api_key is resolved from the provider's environment variable, and
    OpenAI-compatible services get their fixed base URL applied.

    Args:
        model_name: Model name (e.g., "gpt-4o", "claude-3-opus")
        provider: Provider name (e.g., "OPENAI", "ANTHROPIC")
        api_key: API key (optional, will read from env if not provided)
        stream: Whether to use streaming mode
        **kwargs: Additional model-specific arguments

    Returns:
        AgentScope model instance

    Raises:
        ValueError: If the provider is not supported.
    """
    provider = provider.upper()

    model_cls = PROVIDER_MODEL_MAP.get(provider)
    if model_cls is None:
        raise ValueError(f"Unsupported provider: {provider}")

    # Resolve the key from the environment when the caller omitted it.
    if api_key is None:
        key_env_name = PROVIDER_API_KEY_ENV.get(provider)
        api_key = os.getenv(key_env_name) if key_env_name else None

    init_kwargs = {
        "model_name": model_name,
        "stream": stream,
        **kwargs,
    }

    # Ollama runs locally and takes no key; everyone else gets one if known.
    if provider != "OLLAMA" and api_key:
        init_kwargs["api_key"] = api_key

    # Fixed endpoints for the OpenAI-compatible providers.
    if provider in PROVIDER_BASE_URLS:
        init_kwargs["client_args"] = {"base_url": PROVIDER_BASE_URLS[provider]}

    # A user-supplied OpenAI endpoint overrides the library default.
    if provider == "OPENAI":
        custom_url = os.getenv("OPENAI_BASE_URL") or os.getenv("OPENAI_API_BASE")
        if custom_url:
            init_kwargs["client_args"] = {"base_url": custom_url}

    # DashScope exposes its endpoint through a different constructor argument.
    if provider in ("DASHSCOPE", "ALIBABA"):
        dashscope_url = os.getenv("DASHSCOPE_BASE_URL")
        if dashscope_url:
            init_kwargs["base_http_api_url"] = dashscope_url

    # Optional Ollama host override.
    if provider == "OLLAMA":
        ollama_host = os.getenv("OLLAMA_HOST")
        if ollama_host:
            init_kwargs["host"] = ollama_host

    return model_cls(**init_kwargs)
|
||||
|
||||
|
||||
def get_agent_model(agent_id: str, stream: bool = False):
    """
    Get model for a specific agent based on environment variables

    Environment variable pattern:
        AGENT_{AGENT_ID}_MODEL_NAME: Model name
        AGENT_{AGENT_ID}_MODEL_PROVIDER: Provider name

    Falls back to the global MODEL_NAME / MODEL_PROVIDER when no
    agent-specific values are set.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")
        stream: Whether to use streaming mode

    Returns:
        AgentScope model instance
    """
    # Normalize agent_id to uppercase for env var lookup
    agent_key = agent_id.upper().replace("-", "_")

    # Try agent-specific config first
    model_name = os.getenv(f"AGENT_{agent_key}_MODEL_NAME")
    provider = os.getenv(f"AGENT_{agent_key}_MODEL_PROVIDER")

    # Only announce an override when one actually exists; the previous
    # unconditional print reported "Using specific model None ..." for
    # every agent that had no per-agent setting.
    if model_name:
        print(f"Using specific model {model_name} for agent {agent_key}")

    # Fall back to global config
    if not model_name:
        model_name = os.getenv("MODEL_NAME", "gpt-4o")
    if not provider:
        provider = os.getenv("MODEL_PROVIDER", "OPENAI")

    return create_model(
        model_name=model_name,
        provider=provider,
        stream=stream,
    )
|
||||
|
||||
|
||||
def get_agent_formatter(agent_id: str):
    """
    Resolve the message formatter for a specific agent.

    Looks up AGENT_{ID}_MODEL_PROVIDER first, then the global
    MODEL_PROVIDER (default "OPENAI"), and maps the provider to its
    formatter class, defaulting to OpenAIChatFormatter when unknown.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")

    Returns:
        AgentScope formatter instance
    """
    env_name = agent_id.upper().replace("-", "_")
    provider = (
        os.getenv(f"AGENT_{env_name}_MODEL_PROVIDER")
        or os.getenv("MODEL_PROVIDER", "OPENAI")
    ).upper()
    formatter_cls = PROVIDER_FORMATTER_MAP.get(provider, OpenAIChatFormatter)
    return formatter_cls()
|
||||
|
||||
|
||||
def get_agent_model_info(agent_id: str) -> Tuple[str, str]:
    """
    Resolve the model name and provider configured for an agent.

    Agent-specific AGENT_{ID}_MODEL_NAME / AGENT_{ID}_MODEL_PROVIDER env
    vars take priority; otherwise the global MODEL_NAME ("gpt-4o") and
    MODEL_PROVIDER ("OPENAI") defaults apply.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")

    Returns:
        Tuple of (model_name, provider_name)
    """
    env_name = agent_id.upper().replace("-", "_")

    resolved_name = (
        os.getenv(f"AGENT_{env_name}_MODEL_NAME")
        or os.getenv("MODEL_NAME", "gpt-4o")
    )
    resolved_provider = (
        os.getenv(f"AGENT_{env_name}_MODEL_PROVIDER")
        or os.getenv("MODEL_PROVIDER", "OPENAI")
    )

    return resolved_name, resolved_provider.upper()
|
||||
332
backend/main.py
Normal file
332
backend/main.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Main Entry Point
|
||||
Supports: backtest, live, mock modes
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
import loguru
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.config.env_config import get_env_float, get_env_int, get_env_list
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
loguru.logger.disable("flowllm")
|
||||
loguru.logger.disable("reme_ai")
|
||||
|
||||
|
||||
def create_long_term_memory(agent_name: str, config_name: str):
    """
    Create ReMeTaskLongTermMemory for an agent.

    Requires the MEMORY_API_KEY env var (a DashScope key used for both the
    memory LLM and the embedding model); returns None with a warning when
    it is missing.

    Args:
        agent_name: Name used for both the memory's agent and user identity.
        config_name: Run/config directory; memory vectors are stored under
            ``{config_name}/memory``.

    Returns:
        A ReMeTaskLongTermMemory instance, or None when disabled.
    """
    from agentscope.memory import ReMeTaskLongTermMemory
    from agentscope.model import DashScopeChatModel
    from agentscope.embedding import DashScopeTextEmbedding

    api_key = os.getenv("MEMORY_API_KEY")
    if not api_key:
        logger.warning("MEMORY_API_KEY not set, long-term memory disabled")
        return None

    memory_dir = str(Path(config_name) / "memory")

    return ReMeTaskLongTermMemory(
        agent_name=agent_name,
        user_name=agent_name,
        model=DashScopeChatModel(
            model_name=os.getenv("MEMORY_MODEL_NAME", "qwen3-max"),
            api_key=api_key,
            stream=False,
        ),
        embedding_model=DashScopeTextEmbedding(
            model_name=os.getenv(
                "MEMORY_EMBEDDING_MODEL",
                "text-embedding-v4",
            ),
            api_key=api_key,
            dimensions=1024,
        ),
        # Dotted-path config keys consumed by the memory backend: store
        # vectors locally under memory_dir.
        **{
            "vector_store.default.backend": "local",
            "vector_store.default.params.store_dir": memory_dir,
        },
    )
|
||||
|
||||
|
||||
def _memory_if_enabled(agent_name, config_name, enabled, collected):
    """Create long-term memory for *agent_name* when enabled.

    Appends the memory to *collected* when creation succeeds (it may return
    None, e.g. when MEMORY_API_KEY is missing). Returns the memory or None.
    """
    if not enabled:
        return None
    memory = create_long_term_memory(agent_name, config_name)
    if memory:
        collected.append(memory)
    return memory


def create_agents(
    config_name: str,
    initial_cash: float,
    margin_requirement: float,
    enable_long_term_memory: bool = False,
):
    """Create all agents for the system.

    Builds one AnalystAgent per entry in ANALYST_TYPES (each with its own
    model, formatter, and toolkit), plus the risk manager and portfolio
    manager, optionally wiring long-term memory into each.

    Args:
        config_name: Run/config directory name, passed to every agent.
        initial_cash: Starting cash for the portfolio manager.
        margin_requirement: Margin requirement forwarded to the PM.
        enable_long_term_memory: Attach ReMe long-term memory per agent.

    Returns:
        tuple: (analysts, risk_manager, portfolio_manager, long_term_memories)
        long_term_memories is a list of the successfully created memories.
    """
    analysts = []
    long_term_memories = []

    for analyst_type in ANALYST_TYPES:
        model = get_agent_model(analyst_type)
        formatter = get_agent_formatter(analyst_type)
        toolkit = create_toolkit(analyst_type)

        long_term_memory = _memory_if_enabled(
            analyst_type,
            config_name,
            enable_long_term_memory,
            long_term_memories,
        )

        analyst = AnalystAgent(
            analyst_type=analyst_type,
            toolkit=toolkit,
            model=model,
            formatter=formatter,
            agent_id=analyst_type,
            config={"config_name": config_name},
            long_term_memory=long_term_memory,
        )
        analysts.append(analyst)

    risk_long_term_memory = _memory_if_enabled(
        "risk_manager",
        config_name,
        enable_long_term_memory,
        long_term_memories,
    )

    risk_manager = RiskAgent(
        model=get_agent_model("risk_manager"),
        formatter=get_agent_formatter("risk_manager"),
        name="risk_manager",
        config={"config_name": config_name},
        long_term_memory=risk_long_term_memory,
    )

    pm_long_term_memory = _memory_if_enabled(
        "portfolio_manager",
        config_name,
        enable_long_term_memory,
        long_term_memories,
    )

    portfolio_manager = PMAgent(
        name="portfolio_manager",
        model=get_agent_model("portfolio_manager"),
        formatter=get_agent_formatter("portfolio_manager"),
        initial_cash=initial_cash,
        margin_requirement=margin_requirement,
        config={"config_name": config_name},
        long_term_memory=pm_long_term_memory,
    )

    return analysts, risk_manager, portfolio_manager, long_term_memories
|
||||
|
||||
|
||||
def create_toolkit(analyst_type: str):
    """Build an AgentScope Toolkit populated with the tools configured
    for the given analyst persona.

    Args:
        analyst_type: Persona key in the "analyst/personas" YAML config
            (e.g. "fundamentals").

    Returns:
        A Toolkit containing every tool listed for that persona that is
        also present in TOOL_REGISTRY; unknown tool names are skipped.
    """
    from agentscope.tool import Toolkit
    from backend.agents.prompt_loader import PromptLoader
    from backend.tools.analysis_tools import TOOL_REGISTRY

    # Resolve the persona entry; an unknown analyst_type yields an empty
    # persona and therefore an empty toolkit.
    persona = (
        PromptLoader()
        .load_yaml_config("analyst", "personas")
        .get(analyst_type, {})
    )

    toolkit = Toolkit()
    # Register only the tools that actually exist in the registry.
    for name in persona.get("tools", []):
        func = TOOL_REGISTRY.get(name)
        if func:
            toolkit.register_tool_function(func)
    return toolkit
|
||||
|
||||
|
||||
async def run_with_gateway(args):
    """Run with WebSocket gateway.

    Wires up the full system in dependency order — market data service,
    storage, agents, pipeline, scheduler — then hands everything to the
    Gateway, which serves the frontend over WebSocket until cancelled.

    Args:
        args: Parsed argparse namespace (mode, mock, config_name, host,
            port, trigger_time, poll_interval, start/end dates,
            enable_memory).
    """
    is_backtest = args.mode == "backtest"

    # Load config from env, override with args
    tickers = get_env_list("TICKERS", ["AAPL", "MSFT"])
    initial_cash = get_env_float("INITIAL_CASH", 100000.0)
    margin_requirement = get_env_float("MARGIN_REQUIREMENT", 0.0)
    config_name = args.config_name

    # Create market service.  The Finnhub API key is only needed for
    # real live-data polling; mock and backtest modes run without it.
    market_service = MarketService(
        tickers=tickers,
        poll_interval=args.poll_interval,
        mock_mode=args.mock and not is_backtest,
        backtest_mode=is_backtest,
        api_key=os.getenv("FINNHUB_API_KEY")
        if not args.mock and not is_backtest
        else None,
        backtest_start_date=args.start_date if is_backtest else None,
        backtest_end_date=args.end_date if is_backtest else None,
    )

    # Create storage service (dashboard JSON files live under
    # <config_name>/team_dashboard).
    storage_service = StorageService(
        dashboard_dir=Path(config_name) / "team_dashboard",
        initial_cash=initial_cash,
        config_name=config_name,
    )

    # First run: seed an empty dashboard; otherwise refresh model info only.
    if not storage_service.files["summary"].exists():
        storage_service.initialize_empty_dashboard()
    else:
        storage_service.update_leaderboard_model_info()

    # Create agents and pipeline
    analysts, risk_manager, pm, long_term_memories = create_agents(
        config_name=config_name,
        initial_cash=initial_cash,
        margin_requirement=margin_requirement,
        enable_long_term_memory=args.enable_memory,
    )
    # Restore the PM's portfolio from persisted state so restarts resume
    # instead of starting from scratch.
    portfolio_state = storage_service.load_portfolio_state()
    pm.load_portfolio_state(portfolio_state)

    settlement_coordinator = SettlementCoordinator(
        storage=storage_service,
        initial_capital=initial_cash,
    )

    pipeline = TradingPipeline(
        analysts=analysts,
        risk_manager=risk_manager,
        portfolio_manager=pm,
        settlement_coordinator=settlement_coordinator,
        max_comm_cycles=get_env_int("MAX_COMM_CYCLES", 2),
    )

    # Create scheduler callback: backtest mode iterates historical trading
    # days, live mode triggers once per day at trigger_time.
    scheduler_callback = None
    trading_dates = []

    if is_backtest:
        backtest_scheduler = BacktestScheduler(
            start_date=args.start_date,
            end_date=args.end_date,
            trading_calendar="NYSE",
            delay_between_days=0.5,
        )
        trading_dates = backtest_scheduler.get_trading_dates()

        async def scheduler_callback_fn(callback):
            await backtest_scheduler.start(callback)

        scheduler_callback = scheduler_callback_fn
    else:
        # Live mode: use daily scheduler with NYSE timezone
        live_scheduler = Scheduler(
            mode="daily",
            trigger_time=args.trigger_time,
            config={"config_name": config_name},
        )

        async def scheduler_callback_fn(callback):
            await live_scheduler.start(callback)

        scheduler_callback = scheduler_callback_fn

    # Create gateway
    gateway = Gateway(
        market_service=market_service,
        storage_service=storage_service,
        pipeline=pipeline,
        scheduler_callback=scheduler_callback,
        config={
            "mode": args.mode,
            "mock_mode": args.mock,
            "backtest_mode": is_backtest,
            "tickers": tickers,
            "config_name": config_name,
        },
    )

    if is_backtest:
        gateway.set_backtest_dates(trading_dates)

    # Start long-term memory contexts and run gateway.  AsyncExitStack
    # guarantees each entered memory context is closed even if
    # gateway.start() raises.
    async with AsyncExitStack() as stack:
        for memory in long_term_memories:
            await stack.enter_async_context(memory)
        await gateway.start(host=args.host, port=args.port)
|
||||
|
||||
|
||||
def main():
    """Main entry point.

    Parses command-line arguments, validates mode-specific requirements,
    logs the effective configuration, and runs the async gateway.

    Raises:
        SystemExit: via ``parser.error`` when backtest mode is missing
            ``--start-date``/``--end-date``.
    """
    parser = argparse.ArgumentParser(description="Trading System")
    parser.add_argument("--mode", choices=["live", "backtest"], default="live")
    parser.add_argument("--mock", action="store_true")
    parser.add_argument("--config-name", default="mock")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8765)
    parser.add_argument("--trigger-time", default="09:30")  # NYSE market open
    parser.add_argument("--poll-interval", type=int, default=10)
    parser.add_argument("--start-date")
    parser.add_argument("--end-date")
    parser.add_argument(
        "--enable-memory",
        action="store_true",
        help="Enable ReMeTaskLongTermMemory for agents",
    )

    args = parser.parse_args()

    # Fail fast: validate mode-specific arguments immediately after
    # parsing, before any env loading or log output (previously this
    # check ran only after several config log lines had been emitted).
    if args.mode == "backtest" and (not args.start_date or not args.end_date):
        parser.error(
            "--start-date and --end-date required for backtest mode",
        )

    # Load config from env for logging
    tickers = get_env_list("TICKERS", ["AAPL", "MSFT"])
    initial_cash = get_env_float("INITIAL_CASH", 100000.0)

    logger.info("=" * 60)
    logger.info(f"Mode: {args.mode}, Config: {args.config_name}")
    logger.info(f"Tickers: {tickers}")
    logger.info(f"Initial Cash: ${initial_cash:,.2f}")
    logger.info(
        f"Long-term Memory: {'enabled' if args.enable_memory else 'disabled'}",
    )
    if args.mode == "backtest":
        logger.info(f"Backtest: {args.start_date} to {args.end_date}")
    logger.info("=" * 60)

    asyncio.run(run_with_gateway(args))
|
||||
2
backend/services/__init__.py
Normal file
2
backend/services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Services layer for infrastructure components"""
|
||||
569
backend/services/gateway.py
Normal file
569
backend/services/gateway.py
Normal file
@@ -0,0 +1,569 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
WebSocket Gateway for frontend communication
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
from backend.utils.terminal_dashboard import get_dashboard
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.state_sync import StateSync
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Gateway:
    """WebSocket Gateway for frontend communication.

    Central orchestrator: serves frontend clients over WebSocket,
    forwards market/price/portfolio updates via StateSync, triggers
    trading cycles (live or backtest), and mirrors progress on the
    terminal dashboard.
    """

    def __init__(
        self,
        market_service: MarketService,
        storage_service: StorageService,
        pipeline: TradingPipeline,
        state_sync: Optional[StateSync] = None,
        scheduler_callback: Optional[Callable] = None,
        config: Dict[str, Any] = None,
    ):
        self.market_service = market_service
        self.storage = storage_service
        self.pipeline = pipeline
        self.scheduler_callback = scheduler_callback
        self.config = config or {}

        # Backtest is flagged either via mode string or an explicit
        # "backtest_mode" config key.
        self.mode = self.config.get("mode", "live")
        self.is_backtest = self.mode == "backtest" or self.config.get(
            "backtest_mode",
            False,
        )

        # Wire StateSync both ways: it broadcasts through us, and the
        # pipeline reports progress through it.
        self.state_sync = state_sync or StateSync(storage=storage_service)
        # self.state_sync.set_mode(self.is_backtest)
        self.state_sync.set_broadcast_fn(self.broadcast)
        self.pipeline.state_sync = self.state_sync

        # Lock guards connected_clients across concurrent handlers.
        self.connected_clients: Set[WebSocketServerProtocol] = set()
        self.lock = asyncio.Lock()
        self._backtest_task: Optional[asyncio.Task] = None
        self._backtest_start_date: Optional[str] = None
        self._backtest_end_date: Optional[str] = None
        self._dashboard = get_dashboard()
        self._market_status_task: Optional[asyncio.Task] = None

        # Session tracking for live returns
        self._session_start_portfolio_value: Optional[float] = None

    async def start(self, host: str = "0.0.0.0", port: int = 8766):
        """Start gateway server.

        Initializes the dashboard and state, starts the market service
        and (optionally) the scheduler, then serves WebSocket clients
        forever (blocks on an unresolved Future until cancelled).
        """
        logger.info(f"Starting gateway on {host}:{port}")

        # Initialize terminal dashboard
        self._dashboard.set_config(
            mode=self.mode,
            config_name=self.config.get("config_name", "default"),
            host=host,
            port=port,
            poll_interval=self.config.get("poll_interval", 10),
            mock=self.config.get("mock_mode", False),
            tickers=self.config.get("tickers", []),
            initial_cash=self.storage.initial_cash,
            start_date=self._backtest_start_date or "",
            end_date=self._backtest_end_date or "",
        )
        self._dashboard.start()

        self.state_sync.load_state()
        self.state_sync.update_state("status", "running")
        self.state_sync.update_state("server_mode", self.mode)
        self.state_sync.update_state("is_backtest", self.is_backtest)
        self.state_sync.update_state(
            "is_mock_mode",
            self.config.get("mock_mode", False),
        )

        # Load and display existing portfolio state if available
        summary = self.storage.load_file("summary")
        if summary:
            holdings = self.storage.load_file("holdings") or []
            trades = self.storage.load_file("trades") or []
            current_date = self.state_sync.state.get("current_date")
            self._dashboard.update(
                date=current_date or "-",
                status="running",
                portfolio=summary,
                holdings=holdings,
                trades=trades,
            )
            logger.info(
                "Loaded existing portfolio: $%s",
                f"{summary.get('totalAssetValue', 0):,.2f}",
            )

        await self.market_service.start(broadcast_func=self.broadcast)

        if self.scheduler_callback:
            await self.scheduler_callback(callback=self.on_strategy_trigger)

        # Start market status monitoring (only for live mode)
        if not self.is_backtest:
            self._market_status_task = asyncio.create_task(
                self._market_status_monitor(),
            )

        async with websockets.serve(
            self.handle_client,
            host,
            port,
            ping_interval=30,
            ping_timeout=60,
        ):
            logger.info(
                f"Gateway started: ws://{host}:{port}, mode={self.mode}",
            )
            # Keep serving until the task is cancelled.
            await asyncio.Future()

    @property
    def state(self) -> Dict[str, Any]:
        """Current shared state dict (delegates to StateSync)."""
        return self.state_sync.state

    async def handle_client(self, websocket: WebSocketServerProtocol):
        """Handle WebSocket client connection.

        Registers the client, pushes the initial state snapshot, then
        processes messages until disconnect; always deregisters on exit.
        """
        async with self.lock:
            self.connected_clients.add(websocket)

        await self._send_initial_state(websocket)
        await self._handle_client_messages(websocket)

        async with self.lock:
            self.connected_clients.discard(websocket)

    async def _send_initial_state(self, websocket: WebSocketServerProtocol):
        """Send a full state snapshot to one newly connected client."""
        state_payload = self.state_sync.get_initial_state_payload(
            include_dashboard=True,
        )
        # Include market status in initial state
        state_payload[
            "market_status"
        ] = self.market_service.get_market_status()

        # Include live returns if session is active
        if self.storage.is_live_session_active:
            live_returns = self.storage.get_live_returns()
            if "portfolio" in state_payload:
                state_payload["portfolio"].update(live_returns)

        await websocket.send(
            json.dumps(
                {"type": "initial_state", "state": state_payload},
                ensure_ascii=False,
                default=str,
            ),
        )

    async def _handle_client_messages(
        self,
        websocket: WebSocketServerProtocol,
    ):
        """Process inbound client messages until the connection closes.

        Supported message types: "ping", "get_state", "start_backtest".
        Disconnects and malformed JSON are swallowed silently (the caller
        deregisters the client).
        """
        try:
            async for message in websocket:
                data = json.loads(message)
                msg_type = data.get("type", "unknown")

                if msg_type == "ping":
                    await websocket.send(
                        json.dumps(
                            {
                                "type": "pong",
                                "timestamp": datetime.now().isoformat(),
                            },
                            ensure_ascii=False,
                        ),
                    )
                elif msg_type == "get_state":
                    await self._send_initial_state(websocket)
                elif msg_type == "start_backtest":
                    await self._handle_start_backtest(data)

        except websockets.ConnectionClosed:
            pass
        except json.JSONDecodeError:
            pass

    async def _handle_start_backtest(self, data: Dict[str, Any]):
        """Kick off a backtest run requested by the frontend.

        Ignored unless configured for backtest mode; only one backtest
        task may run at a time.
        """
        if not self.is_backtest:
            return
        dates = data.get("dates", [])
        if dates and self._backtest_task is None:
            task = asyncio.create_task(
                self._run_backtest_dates(dates),
            )
            # Surface exceptions from the background task in the logs.
            task.add_done_callback(self._handle_backtest_exception)
            self._backtest_task = task

    async def broadcast(self, message: Dict[str, Any]):
        """Broadcast message to all connected clients."""
        if not self.connected_clients:
            return

        message_json = json.dumps(message, ensure_ascii=False, default=str)

        # Snapshot the client set under the lock; sends happen after.
        async with self.lock:
            tasks = [
                self._send_to_client(client, message_json)
                for client in self.connected_clients.copy()
            ]

        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)

    async def _send_to_client(
        self,
        client: WebSocketServerProtocol,
        message: str,
    ):
        """Send one message to one client; drop the client if it closed."""
        try:
            await client.send(message)
        except websockets.ConnectionClosed:
            async with self.lock:
                self.connected_clients.discard(client)

    async def _market_status_monitor(self):
        """Periodically check and broadcast market status changes.

        Runs forever (every 60s) until cancelled; also opens/closes the
        live-returns tracking session on market open/close transitions.
        """
        while True:
            try:
                await self.market_service.check_and_broadcast_market_status()

                # On market open, start live session tracking
                status = self.market_service.get_market_status()
                if (
                    status["status"] == "open"
                    and not self.storage.is_live_session_active
                ):
                    self.storage.start_live_session()
                    summary = self.storage.load_file("summary") or {}
                    self._session_start_portfolio_value = summary.get(
                        "totalAssetValue",
                        self.storage.initial_cash,
                    )
                    logger.info(
                        "Session start portfolio: "
                        f"${self._session_start_portfolio_value:,.2f}",
                    )
                elif (
                    status["status"] != "open"
                    and self.storage.is_live_session_active
                ):
                    self.storage.end_live_session()
                    self._session_start_portfolio_value = None

                # Update and broadcast live returns if session is active
                if self.storage.is_live_session_active:
                    await self._update_and_broadcast_live_returns()

                await asyncio.sleep(60)  # Check every minute
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Best-effort monitor: log and keep looping.
                logger.error(f"Market status monitor error: {e}")
                await asyncio.sleep(60)

    async def _update_and_broadcast_live_returns(self):
        """Calculate and broadcast live returns for current session."""
        if not self.storage.is_live_session_active:
            return

        # Get current prices and calculate portfolio value
        prices = self.market_service.get_all_prices()
        # Skip when no usable (positive) price has arrived yet.
        if not prices or not any(p > 0 for p in prices.values()):
            return

        # Load current internal state to get baseline values
        state = self.storage.load_internal_state()

        # Get latest values from history (if available)
        equity_history = state.get("equity_history", [])
        baseline_history = state.get("baseline_history", [])
        baseline_vw_history = state.get("baseline_vw_history", [])
        momentum_history = state.get("momentum_history", [])

        # History entries appear to be dicts with a "v" value key —
        # NOTE(review): confirm schema against StorageService.
        current_equity = equity_history[-1]["v"] if equity_history else None
        current_baseline = (
            baseline_history[-1]["v"] if baseline_history else None
        )
        current_baseline_vw = (
            baseline_vw_history[-1]["v"] if baseline_vw_history else None
        )
        current_momentum = (
            momentum_history[-1]["v"] if momentum_history else None
        )

        # Update live returns with current values
        point = self.storage.update_live_returns(
            current_equity=current_equity,
            current_baseline=current_baseline,
            current_baseline_vw=current_baseline_vw,
            current_momentum=current_momentum,
        )

        # Broadcast if we have new data
        if point:
            live_returns = self.storage.get_live_returns()
            await self.broadcast(
                {
                    "type": "team_summary",
                    "equity_return": live_returns["equity_return"],
                    "baseline_return": live_returns["baseline_return"],
                    "baseline_vw_return": live_returns["baseline_vw_return"],
                    "momentum_return": live_returns["momentum_return"],
                },
            )

    async def on_strategy_trigger(self, date: str):
        """Handle trading cycle trigger.

        Dispatches to the backtest or live cycle runner depending on mode.
        """
        logger.info(f"Strategy triggered for {date}")

        tickers = self.config.get("tickers", [])

        if self.is_backtest:
            await self._run_backtest_cycle(date, tickers)
        else:
            await self._run_live_cycle(date, tickers)

    async def _run_backtest_cycle(self, date: str, tickers: List[str]):
        """Run backtest cycle with pre-loaded prices"""
        self.market_service.set_backtest_date(date)
        await self.market_service.emit_market_open()

        await self.state_sync.on_cycle_start(date)
        self._dashboard.update(date=date, status="Analyzing...")

        prices = self.market_service.get_open_prices()
        close_prices = self.market_service.get_close_prices()
        market_caps = self._get_market_caps(tickers, date)

        result = await self.pipeline.run_cycle(
            tickers=tickers,
            date=date,
            prices=prices,
            close_prices=close_prices,
            market_caps=market_caps,
        )

        await self.market_service.emit_market_close()
        settlement_result = result.get("settlement_result")
        self._save_cycle_results(result, date, close_prices, settlement_result)
        await self._broadcast_portfolio_updates(result, close_prices)
        await self._finalize_cycle(date)

    async def _run_live_cycle(self, date: str, tickers: List[str]):
        """
        Run live cycle with real market timing.

        - Analysis runs immediately
        - Execution waits for market open
          (or uses current prices if already open)
        - Settlement waits for market close
        """
        # Get actual trading date (might be next trading day if weekend)
        trading_date = self.market_service.get_live_trading_date()
        logger.info(
            f"Live cycle: triggered={date}, trading_date={trading_date}",
        )

        await self.state_sync.on_cycle_start(trading_date)
        self._dashboard.update(date=trading_date, status="Analyzing...")

        market_caps = self._get_market_caps(tickers, trading_date)

        # Run pipeline with async price callbacks
        result = await self.pipeline.run_cycle(
            tickers=tickers,
            date=trading_date,
            market_caps=market_caps,
            get_open_prices_fn=self.market_service.wait_for_open_prices,
            get_close_prices_fn=self.market_service.wait_for_close_prices,
        )

        close_prices = self.market_service.get_all_prices()
        settlement_result = result.get("settlement_result")
        self._save_cycle_results(
            result,
            trading_date,
            close_prices,
            settlement_result,
        )
        await self._broadcast_portfolio_updates(result, close_prices)
        await self._finalize_cycle(trading_date)

    async def _finalize_cycle(self, date: str):
        """Finalize cycle: broadcast state and update dashboard"""
        summary = self.storage.load_file("summary") or {}

        # Include live returns if session is active
        if self.storage.is_live_session_active:
            live_returns = self.storage.get_live_returns()
            summary.update(live_returns)

        await self.state_sync.on_cycle_end(date, portfolio_summary=summary)

        holdings = self.storage.load_file("holdings") or []
        trades = self.storage.load_file("trades") or []
        leaderboard = self.storage.load_file("leaderboard") or []

        if leaderboard:
            await self.state_sync.on_leaderboard_update(leaderboard)

        self._dashboard.update(
            date=date,
            status="Running",
            portfolio=summary,
            holdings=holdings,
            trades=trades,
        )

    def _get_market_caps(
        self,
        tickers: List[str],
        date: str,
    ) -> Dict[str, float]:
        """
        Get market caps for tickers (stub implementation)

        Args:
            tickers: List of tickers
            date: Trading date

        Returns:
            Dict mapping ticker to market cap; falls back to $1B when the
            lookup returns nothing or raises.
        """
        from ..tools.data_tools import get_market_cap

        market_caps = {}
        for ticker in tickers:
            try:
                market_cap = get_market_cap(ticker, date)
                if market_cap:
                    market_caps[ticker] = market_cap
                else:
                    market_caps[ticker] = 1e9
            except Exception:
                # Best-effort: a missing market cap must not abort the cycle.
                market_caps[ticker] = 1e9

        return market_caps

    async def _broadcast_portfolio_updates(
        self,
        result: Dict[str, Any],
        prices: Dict[str, float],
    ):
        """Push holdings, stats, and executed trades to clients via StateSync."""
        portfolio = result.get("portfolio", {})

        if portfolio:
            holdings = FrontendAdapter.build_holdings(portfolio, prices)
            if holdings:
                await self.state_sync.on_holdings_update(holdings)

            stats = FrontendAdapter.build_stats(portfolio, prices)
            if stats:
                await self.state_sync.on_stats_update(stats)

        executed_trades = result.get("executed_trades", [])
        if executed_trades:
            await self.state_sync.on_trades_executed(executed_trades)

    def _save_cycle_results(
        self,
        result: Dict[str, Any],
        date: str,
        prices: Dict[str, float],
        settlement_result: Optional[Dict[str, Any]] = None,
    ):
        """Persist a cycle's portfolio/trade/baseline results to storage."""
        portfolio = result.get("portfolio", {})
        executed_trades = result.get("executed_trades", [])

        # Extract baseline values from settlement result
        baseline_values = None
        if settlement_result:
            baseline_values = settlement_result.get("baseline_values")

        if portfolio:
            self.storage.update_dashboard_after_cycle(
                portfolio=portfolio,
                prices=prices,
                date=date,
                executed_trades=executed_trades,
                baseline_values=baseline_values,
            )

    async def _run_backtest_dates(self, dates: List[str]):
        """Run one trading cycle per date, updating dashboard progress.

        Raises any cycle failure after reporting it; always clears
        _backtest_task so a new run can be started.
        """
        self.state_sync.set_backtest_dates(dates)
        self._dashboard.update(days_total=len(dates), days_completed=0)

        await self.state_sync.on_system_message(
            f"Starting backtest - {len(dates)} trading days",
        )

        try:
            for i, date in enumerate(dates):
                self._dashboard.update(days_completed=i)
                await self.on_strategy_trigger(date=date)
                # Brief yield between days so broadcasts can flush.
                await asyncio.sleep(0.1)

            await self.state_sync.on_system_message(
                f"Backtest complete - {len(dates)} days",
            )

            # Update dashboard with final state
            summary = self.storage.load_file("summary") or {}
            self._dashboard.update(
                status="Complete",
                portfolio=summary,
                days_completed=len(dates),
            )
            self._dashboard.stop()
            self._dashboard.print_final_summary()
        except Exception as e:
            error_msg = f"Backtest failed: {type(e).__name__}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            await self.state_sync.on_system_message(error_msg)
            self._dashboard.update(status=f"Failed: {str(e)}")
            self._dashboard.stop()
            raise
        finally:
            self._backtest_task = None

    def _handle_backtest_exception(self, task: asyncio.Task):
        """Handle exceptions from backtest task"""
        try:
            task.result()
        except asyncio.CancelledError:
            logger.info("Backtest task was cancelled")
        except Exception as e:
            logger.error(
                f"Backtest task failed with exception:{type(e).__name__}:{e}",
                exc_info=True,
            )

    def set_backtest_dates(self, dates: List[str]):
        """Record the backtest date range for state sync and the dashboard."""
        self.state_sync.set_backtest_dates(dates)
        if dates:
            self._backtest_start_date = dates[0]
            self._backtest_end_date = dates[-1]
            self._dashboard.days_total = len(dates)

    def stop(self):
        """Persist state and shut down services and background tasks."""
        self.state_sync.save_state()
        self.market_service.stop()
        if self._backtest_task:
            self._backtest_task.cancel()
        if self._market_status_task:
            self._market_status_task.cancel()
        self._dashboard.stop()
|
||||
625
backend/services/market.py
Normal file
625
backend/services/market.py
Normal file
@@ -0,0 +1,625 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Market Data Service
|
||||
Supports live, mock, and backtest modes
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NYSE timezone and calendar
|
||||
NYSE_TZ = ZoneInfo("America/New_York")
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
class MarketStatus:
    """String constants naming the possible NYSE trading sessions.

    Acts as an enum-like namespace: compare market-status strings
    against these attributes instead of hard-coding literals.
    """

    PREMARKET = "premarket"
    OPEN = "open"
    AFTERHOURS = "afterhours"
    CLOSED = "closed"
|
||||
|
||||
|
||||
class MarketService:
|
||||
"""Market data service for price management"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tickers: List[str],
|
||||
poll_interval: int = 10,
|
||||
mock_mode: bool = False,
|
||||
backtest_mode: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
backtest_start_date: Optional[str] = None,
|
||||
backtest_end_date: Optional[str] = None,
|
||||
):
|
||||
self.tickers = tickers
|
||||
self.poll_interval = poll_interval
|
||||
self.mock_mode = mock_mode
|
||||
self.backtest_mode = backtest_mode
|
||||
self.api_key = api_key
|
||||
self.backtest_start_date = backtest_start_date
|
||||
self.backtest_end_date = backtest_end_date
|
||||
|
||||
self.cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.running = False
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._broadcast_func: Optional[Callable] = None
|
||||
self._price_manager: Optional[Any] = None
|
||||
self._current_date: Optional[str] = None
|
||||
|
||||
# Market status tracking
|
||||
self._last_market_status: Optional[str] = None
|
||||
|
||||
# Session tracking for live returns
|
||||
self._session_start_values: Optional[Dict[str, float]] = None
|
||||
self._session_start_timestamp: Optional[int] = None
|
||||
|
||||
@property
|
||||
def mode_name(self) -> str:
|
||||
if self.backtest_mode:
|
||||
return "BACKTEST"
|
||||
elif self.mock_mode:
|
||||
return "MOCK"
|
||||
return "LIVE"
|
||||
|
||||
async def start(self, broadcast_func: Callable):
|
||||
"""Start market data service"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._broadcast_func = broadcast_func
|
||||
|
||||
if self.backtest_mode:
|
||||
self._start_backtest_mode()
|
||||
elif self.mock_mode:
|
||||
self._start_mock_mode()
|
||||
else:
|
||||
self._start_real_mode()
|
||||
|
||||
logger.info(
|
||||
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
|
||||
)
|
||||
|
||||
def _make_price_callback(self) -> Callable:
|
||||
"""Create thread-safe price callback"""
|
||||
|
||||
def callback(price_data: Dict[str, Any]):
|
||||
symbol = price_data["symbol"]
|
||||
self.cache[symbol] = price_data
|
||||
|
||||
loop = self._loop
|
||||
if loop and loop.is_running() and self._broadcast_func:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_price_update(price_data),
|
||||
loop,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
def _start_mock_mode(self):
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
|
||||
self._price_manager = MockPriceManager(
|
||||
poll_interval=self.poll_interval,
|
||||
volatility=0.5,
|
||||
)
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(
|
||||
self.tickers,
|
||||
base_prices={t: 100.0 for t in self.tickers},
|
||||
)
|
||||
self._price_manager.start()
|
||||
|
||||
def _start_real_mode(self):
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for live mode")
|
||||
self._price_manager = PollingPriceManager(
|
||||
api_key=self.api_key,
|
||||
poll_interval=self.poll_interval,
|
||||
)
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(self.tickers)
|
||||
self._price_manager.start()
|
||||
|
||||
def _start_backtest_mode(self):
|
||||
from backend.data.historical_price_manager import (
|
||||
HistoricalPriceManager,
|
||||
)
|
||||
|
||||
self._price_manager = HistoricalPriceManager()
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(self.tickers)
|
||||
|
||||
if self.backtest_start_date and self.backtest_end_date:
|
||||
self._price_manager.preload_data(
|
||||
self.backtest_start_date,
|
||||
self.backtest_end_date,
|
||||
)
|
||||
|
||||
self._price_manager.start()
|
||||
|
||||
async def _broadcast_price_update(self, price_data: Dict[str, Any]):
|
||||
"""Broadcast price update to frontend"""
|
||||
if not self._broadcast_func:
|
||||
return
|
||||
|
||||
symbol = price_data["symbol"]
|
||||
price = price_data["price"]
|
||||
open_price = price_data.get("open", price)
|
||||
ret = (
|
||||
((price - open_price) / open_price) * 100 if open_price > 0 else 0
|
||||
)
|
||||
|
||||
await self._broadcast_func(
|
||||
{
|
||||
"type": "price_update",
|
||||
"symbol": symbol,
|
||||
"price": price,
|
||||
"open": open_price,
|
||||
"ret": ret,
|
||||
"timestamp": price_data.get("timestamp"),
|
||||
"realtime_prices": {
|
||||
t: self._get_cached_price(t) for t in self.tickers
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _get_cached_price(self, ticker: str) -> Dict[str, Any]:
|
||||
"""Get cached price data for a ticker"""
|
||||
if ticker in self.cache:
|
||||
return self.cache[ticker]
|
||||
# Return from price manager if not in cache
|
||||
if self._price_manager:
|
||||
price = self._price_manager.get_latest_price(ticker)
|
||||
if price:
|
||||
return {"price": price, "symbol": ticker}
|
||||
return {"price": 0, "symbol": ticker}
|
||||
|
||||
def stop(self):
|
||||
"""Stop market service"""
|
||||
if not self.running:
|
||||
return
|
||||
self.running = False
|
||||
if self._price_manager:
|
||||
self._price_manager.stop()
|
||||
self._price_manager = None
|
||||
self._loop = None
|
||||
self._broadcast_func = None
|
||||
|
||||
# Backtest methods
|
||||
def set_backtest_date(self, date: str):
|
||||
"""Set current backtest date"""
|
||||
if not self.backtest_mode or not self._price_manager:
|
||||
return
|
||||
self._current_date = date
|
||||
self._price_manager.set_date(date)
|
||||
logger.info(f"Backtest date: {date}")
|
||||
|
||||
async def emit_market_open(self):
|
||||
"""Emit market open prices"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
self._price_manager.emit_open_prices()
|
||||
# Log prices for debugging
|
||||
prices = self.get_open_prices()
|
||||
logger.info(f"Open prices: {prices}")
|
||||
|
||||
async def emit_market_close(self):
|
||||
"""Emit market close prices"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
self._price_manager.emit_close_prices()
|
||||
# Log prices for debugging
|
||||
prices = self.get_close_prices()
|
||||
logger.info(f"Close prices: {prices}")
|
||||
|
||||
def get_open_prices(self) -> Dict[str, float]:
|
||||
"""Get open prices for all tickers"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = None
|
||||
# Try price manager first
|
||||
if self.backtest_mode and self._price_manager:
|
||||
price = self._price_manager.get_open_price(ticker)
|
||||
# Fallback to cache
|
||||
if price is None or price <= 0:
|
||||
cached = self.cache.get(ticker, {})
|
||||
price = cached.get("open") or cached.get("price")
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
def get_close_prices(self) -> Dict[str, float]:
|
||||
"""Get close prices for all tickers"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = None
|
||||
# Try price manager first
|
||||
if self.backtest_mode and self._price_manager:
|
||||
price = self._price_manager.get_close_price(ticker)
|
||||
# Fallback to cache
|
||||
if price is None or price <= 0:
|
||||
cached = self.cache.get(ticker, {})
|
||||
price = cached.get("close") or cached.get("price")
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
def get_price_for_date(
|
||||
self,
|
||||
ticker: str,
|
||||
date: str,
|
||||
price_type: str = "close",
|
||||
) -> Optional[float]:
|
||||
"""Get price for a specific date"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
return self._price_manager.get_price_for_date(
|
||||
ticker,
|
||||
date,
|
||||
price_type,
|
||||
)
|
||||
return self.get_price_sync(ticker)
|
||||
|
||||
# Common methods
|
||||
def get_price_sync(self, ticker: str) -> Optional[float]:
|
||||
"""Get latest price synchronously"""
|
||||
# Try cache first
|
||||
data = self.cache.get(ticker)
|
||||
if data and data.get("price"):
|
||||
return data["price"]
|
||||
# Try price manager
|
||||
if self._price_manager:
|
||||
return self._price_manager.get_latest_price(ticker)
|
||||
return None
|
||||
|
||||
def get_all_prices(self) -> Dict[str, float]:
|
||||
"""Get all latest prices"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = self.get_price_sync(ticker)
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
# Live mode async waiting methods
|
||||
|
||||
def _now_nyse(self) -> datetime:
|
||||
"""Get current time in NYSE timezone"""
|
||||
return datetime.now(NYSE_TZ)
|
||||
|
||||
def _is_trading_day(self, date: datetime) -> bool:
|
||||
"""Check if date is a NYSE trading day"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
valid_days = NYSE_CALENDAR.valid_days(
|
||||
start_date=date_str,
|
||||
end_date=date_str,
|
||||
)
|
||||
return len(valid_days) > 0
|
||||
|
||||
def _get_market_hours(self, date: datetime) -> tuple:
|
||||
"""Get market open and close times for a given date"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
schedule = NYSE_CALENDAR.schedule(
|
||||
start_date=date_str,
|
||||
end_date=date_str,
|
||||
)
|
||||
if schedule.empty:
|
||||
return None, None
|
||||
market_open = schedule.iloc[0]["market_open"].to_pydatetime()
|
||||
market_close = schedule.iloc[0]["market_close"].to_pydatetime()
|
||||
return market_open, market_close
|
||||
|
||||
def _next_trading_day(self, from_date: datetime) -> datetime:
|
||||
"""Find the next trading day from given date"""
|
||||
check_date = from_date + timedelta(days=1)
|
||||
for _ in range(10): # Max 10 days ahead (handles holidays)
|
||||
if self._is_trading_day(check_date):
|
||||
return check_date
|
||||
check_date += timedelta(days=1)
|
||||
return check_date
|
||||
|
||||
def _get_trading_date_for_execution(self) -> tuple:
|
||||
"""
|
||||
Determine the trading date for execution.
|
||||
|
||||
Returns:
|
||||
(trading_date, market_open_time, market_close_time)
|
||||
|
||||
Logic:
|
||||
- If today is a trading day and market has opened: use today
|
||||
- If today is a trading day but market hasn't opened: wait for open
|
||||
- If today is not a trading day: use next trading day
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if self._is_trading_day(today):
|
||||
market_open, market_close = self._get_market_hours(today)
|
||||
return today, market_open, market_close
|
||||
else:
|
||||
# Weekend or holiday - find next trading day
|
||||
next_day = self._next_trading_day(today)
|
||||
market_open, market_close = self._get_market_hours(next_day)
|
||||
return next_day, market_open, market_close
|
||||
|
||||
async def wait_for_open_prices(self) -> Dict[str, float]:
|
||||
"""
|
||||
Wait for market open and return open prices.
|
||||
|
||||
Behavior:
|
||||
- If market is already open today: return current prices immediately
|
||||
- If market hasn't opened yet today: wait until open
|
||||
- If not a trading day: wait until next trading day opens
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
trading_date, market_open, _ = self._get_trading_date_for_execution()
|
||||
|
||||
if market_open is None:
|
||||
logger.warning("Could not determine market hours")
|
||||
return self.get_all_prices()
|
||||
|
||||
trading_date_str = trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Check if we need to wait
|
||||
if now < market_open:
|
||||
wait_seconds = (market_open - now).total_seconds()
|
||||
logger.info(
|
||||
f"Waiting {wait_seconds/60:.1f} min for market open "
|
||||
f"({trading_date_str} {market_open.strftime('%H:%M')} ET)",
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
# Small delay to ensure prices are available
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
logger.info(
|
||||
f"Market already open for {trading_date_str}, "
|
||||
f"getting current prices",
|
||||
)
|
||||
|
||||
# Poll until we have valid prices
|
||||
prices = await self._poll_for_prices()
|
||||
logger.info(f"Got open prices for {trading_date_str}: {prices}")
|
||||
return prices
|
||||
|
||||
async def wait_for_close_prices(self) -> Dict[str, float]:
|
||||
"""
|
||||
Wait for market close and return close prices.
|
||||
|
||||
Behavior:
|
||||
- If market is already closed today: return current prices immediately
|
||||
- If market hasn't closed yet: wait until close
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
trading_date, _, market_close = self._get_trading_date_for_execution()
|
||||
|
||||
if market_close is None:
|
||||
logger.warning("Could not determine market hours")
|
||||
return self.get_all_prices()
|
||||
|
||||
trading_date_str = trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Check if we need to wait
|
||||
if now < market_close:
|
||||
wait_seconds = (market_close - now).total_seconds()
|
||||
logger.info(
|
||||
f"Waiting {wait_seconds/60:.1f} min for market close "
|
||||
f"({trading_date_str} {market_close.strftime('%H:%M')} ET)",
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
# Small delay to ensure final prices settle
|
||||
await asyncio.sleep(10)
|
||||
else:
|
||||
logger.info(
|
||||
f"Market already closed for {trading_date_str}, "
|
||||
f"getting close prices",
|
||||
)
|
||||
|
||||
# Get final prices
|
||||
prices = await self._poll_for_prices()
|
||||
logger.info(f"Got close prices for {trading_date_str}: {prices}")
|
||||
return prices
|
||||
|
||||
def get_live_trading_date(self) -> str:
|
||||
"""Get the trading date that will be used for live execution"""
|
||||
trading_date, _, _ = self._get_trading_date_for_execution()
|
||||
return trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
async def _poll_for_prices(
|
||||
self,
|
||||
max_retries: int = 12,
|
||||
) -> Dict[str, float]:
|
||||
"""Poll until all prices are available"""
|
||||
for _ in range(max_retries):
|
||||
prices = self.get_all_prices()
|
||||
if all(p > 0 for p in prices.values()):
|
||||
return prices
|
||||
logger.debug("Waiting for prices to be available...")
|
||||
await asyncio.sleep(5)
|
||||
# Return whatever we have
|
||||
return self.get_all_prices()
|
||||
|
||||
# ========== Market Status Methods ==========
|
||||
|
||||
def get_market_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current market status
|
||||
|
||||
Returns:
|
||||
Dict with status info:
|
||||
- status: 'open' | 'closed' | 'premarket' | 'afterhours'
|
||||
- status_text: Human readable status
|
||||
- is_trading_day: Whether today is a trading day
|
||||
- market_open: Market open time (if trading day)
|
||||
- market_close: Market close time (if trading day)
|
||||
"""
|
||||
if self.backtest_mode:
|
||||
# In backtest mode, always return open
|
||||
return {
|
||||
"status": MarketStatus.OPEN,
|
||||
"status_text": "Backtest Mode",
|
||||
"is_trading_day": True,
|
||||
}
|
||||
|
||||
now = self._now_nyse()
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
is_trading = self._is_trading_day(today)
|
||||
|
||||
if not is_trading:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed (Non-trading Day)",
|
||||
"is_trading_day": False,
|
||||
}
|
||||
|
||||
market_open, market_close = self._get_market_hours(today)
|
||||
|
||||
if market_open is None or market_close is None:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed",
|
||||
"is_trading_day": is_trading,
|
||||
}
|
||||
|
||||
# Determine status based on current time
|
||||
if now < market_open:
|
||||
return {
|
||||
"status": MarketStatus.PREMARKET,
|
||||
"status_text": "Pre-Market",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
elif now > market_close:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": MarketStatus.OPEN,
|
||||
"status_text": "Market Open",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
|
||||
async def check_and_broadcast_market_status(self):
|
||||
"""Check market status and broadcast if changed"""
|
||||
status = self.get_market_status()
|
||||
current_status = status["status"]
|
||||
|
||||
if current_status != self._last_market_status:
|
||||
self._last_market_status = current_status
|
||||
await self._broadcast_market_status(status)
|
||||
|
||||
# Handle session transitions
|
||||
if current_status == MarketStatus.OPEN:
|
||||
await self._on_session_start()
|
||||
elif (
|
||||
current_status == MarketStatus.CLOSED
|
||||
and self._session_start_values is not None
|
||||
):
|
||||
self._on_session_end()
|
||||
|
||||
async def _broadcast_market_status(self, status: Dict[str, Any]):
|
||||
"""Broadcast market status update to frontend"""
|
||||
if not self._broadcast_func:
|
||||
return
|
||||
|
||||
await self._broadcast_func(
|
||||
{
|
||||
"type": "market_status_update",
|
||||
"market_status": status,
|
||||
"timestamp": datetime.now(NYSE_TZ).isoformat(),
|
||||
},
|
||||
)
|
||||
logger.info(f"Market status: {status['status_text']}")
|
||||
|
||||
async def _on_session_start(self):
|
||||
"""Called when market session starts - capture baseline values"""
|
||||
# Wait briefly for prices to be available
|
||||
await asyncio.sleep(2)
|
||||
|
||||
prices = self.get_all_prices()
|
||||
if prices and any(p > 0 for p in prices.values()):
|
||||
self._session_start_values = prices.copy()
|
||||
self._session_start_timestamp = int(
|
||||
datetime.now().timestamp() * 1000,
|
||||
)
|
||||
logger.info(f"Session started with prices: {prices}")
|
||||
|
||||
def _on_session_end(self):
|
||||
"""Called when market session ends - clear session data"""
|
||||
self._session_start_values = None
|
||||
self._session_start_timestamp = None
|
||||
logger.info("Session ended, cleared session data")
|
||||
|
||||
def get_session_returns(
|
||||
self,
|
||||
current_prices: Dict[str, float],
|
||||
portfolio_value: Optional[float] = None,
|
||||
session_start_portfolio_value: Optional[float] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Calculate session returns (from session start to now)
|
||||
|
||||
Args:
|
||||
current_prices: Current prices for tickers
|
||||
portfolio_value: Current portfolio value (optional)
|
||||
session_start_portfolio_value:
|
||||
|
||||
Returns:
|
||||
Dict with return data or None if session not started
|
||||
"""
|
||||
if self._session_start_values is None:
|
||||
return None
|
||||
|
||||
timestamp = int(datetime.now().timestamp() * 1000)
|
||||
returns = {}
|
||||
|
||||
# Calculate individual ticker returns
|
||||
for ticker, start_price in self._session_start_values.items():
|
||||
current = current_prices.get(ticker)
|
||||
if current and start_price and start_price > 0:
|
||||
ret = ((current - start_price) / start_price) * 100
|
||||
returns[ticker] = round(ret, 4)
|
||||
|
||||
result = {
|
||||
"timestamp": timestamp,
|
||||
"ticker_returns": returns,
|
||||
}
|
||||
|
||||
# Calculate portfolio return if values provided
|
||||
if (
|
||||
portfolio_value is not None
|
||||
and session_start_portfolio_value is not None
|
||||
):
|
||||
if session_start_portfolio_value > 0:
|
||||
portfolio_ret = (
|
||||
(portfolio_value - session_start_portfolio_value)
|
||||
/ session_start_portfolio_value
|
||||
) * 100
|
||||
result["portfolio_return"] = round(portfolio_ret, 4)
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def session_start_values(self) -> Optional[Dict[str, float]]:
|
||||
"""Get session start values for external use"""
|
||||
return self._session_start_values
|
||||
|
||||
@property
|
||||
def session_start_timestamp(self) -> Optional[int]:
|
||||
"""Get session start timestamp"""
|
||||
return self._session_start_timestamp
|
||||
1099
backend/services/storage.py
Normal file
1099
backend/services/storage.py
Normal file
File diff suppressed because it is too large
Load Diff
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
580
backend/tests/test_agents.py
Normal file
580
backend/tests/test_agents.py
Normal file
@@ -0,0 +1,580 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=W0212
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from agentscope.message import Msg
|
||||
|
||||
|
||||
class TestAnalystAgent:
|
||||
def test_init_valid_analyst_type(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="technical_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.analyst_type_key == "technical_analyst"
|
||||
assert agent.name == "technical_analyst_analyst"
|
||||
assert agent.analyst_persona == "Technical Analyst"
|
||||
|
||||
def test_init_invalid_analyst_type(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
AnalystAgent(
|
||||
analyst_type="invalid_type",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert "Unknown analyst type" in str(excinfo.value)
|
||||
|
||||
def test_init_custom_agent_id(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="fundamentals_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
agent_id="custom_analyst_id",
|
||||
)
|
||||
|
||||
assert agent.name == "custom_analyst_id"
|
||||
|
||||
def test_load_system_prompt(self):
|
||||
from backend.agents.analyst import AnalystAgent
|
||||
|
||||
mock_toolkit = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type="sentiment_analyst",
|
||||
toolkit=mock_toolkit,
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
prompt = agent._load_system_prompt()
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 0
|
||||
|
||||
|
||||
class TestPMAgent:
|
||||
def test_init_default(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.name == "portfolio_manager"
|
||||
assert agent.portfolio["cash"] == 100000.0
|
||||
assert agent.portfolio["positions"] == {}
|
||||
assert agent.portfolio["margin_requirement"] == 0.25
|
||||
|
||||
def test_init_custom_cash(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=50000.0,
|
||||
margin_requirement=0.5,
|
||||
)
|
||||
|
||||
assert agent.portfolio["cash"] == 50000.0
|
||||
assert agent.portfolio["margin_requirement"] == 0.5
|
||||
|
||||
def test_get_portfolio_state(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=75000.0,
|
||||
)
|
||||
|
||||
state = agent.get_portfolio_state()
|
||||
|
||||
assert state["cash"] == 75000.0
|
||||
assert state is not agent.portfolio # Should be a copy
|
||||
|
||||
def test_load_portfolio_state(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
new_portfolio = {
|
||||
"cash": 50000.0,
|
||||
"positions": {
|
||||
"AAPL": {"long": 100, "short": 0, "long_cost_basis": 150.0},
|
||||
},
|
||||
"margin_used": 1000.0,
|
||||
}
|
||||
|
||||
agent.load_portfolio_state(new_portfolio)
|
||||
|
||||
assert agent.portfolio["cash"] == 50000.0
|
||||
assert "AAPL" in agent.portfolio["positions"]
|
||||
|
||||
def test_update_portfolio(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
agent.update_portfolio({"cash": 80000.0})
|
||||
assert agent.portfolio["cash"] == 80000.0
|
||||
|
||||
def _get_text_from_tool_response(self, result):
|
||||
"""Helper to extract text from ToolResponse content"""
|
||||
content = result.content[0]
|
||||
if hasattr(content, "text"):
|
||||
return content.text
|
||||
elif isinstance(content, dict):
|
||||
return content.get("text", "")
|
||||
return str(content)
|
||||
|
||||
def test_make_decision_long(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="AAPL",
|
||||
action="long",
|
||||
quantity=100,
|
||||
confidence=80,
|
||||
reasoning="Strong fundamentals",
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Decision recorded" in text
|
||||
assert agent._decisions["AAPL"]["action"] == "long"
|
||||
assert agent._decisions["AAPL"]["quantity"] == 100
|
||||
|
||||
def test_make_decision_hold(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="GOOGL",
|
||||
action="hold",
|
||||
quantity=0,
|
||||
confidence=50,
|
||||
reasoning="Neutral outlook",
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Decision recorded" in text
|
||||
assert agent._decisions["GOOGL"]["action"] == "hold"
|
||||
assert agent._decisions["GOOGL"]["quantity"] == 0
|
||||
|
||||
def test_make_decision_invalid_action(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
result = agent._make_decision(
|
||||
ticker="AAPL",
|
||||
action="invalid",
|
||||
quantity=10,
|
||||
)
|
||||
|
||||
text = self._get_text_from_tool_response(result)
|
||||
assert "Invalid action" in text
|
||||
|
||||
def test_get_decisions(self):
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
agent._make_decision("AAPL", "long", 100)
|
||||
agent._make_decision("GOOGL", "short", 50)
|
||||
|
||||
decisions = agent.get_decisions()
|
||||
assert len(decisions) == 2
|
||||
assert decisions["AAPL"]["action"] == "long"
|
||||
assert decisions["GOOGL"]["action"] == "short"
|
||||
|
||||
|
||||
class TestRiskAgent:
|
||||
def test_init_default(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
assert agent.name == "risk_manager"
|
||||
|
||||
def test_init_custom_name(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
name="custom_risk_manager",
|
||||
)
|
||||
|
||||
assert agent.name == "custom_risk_manager"
|
||||
|
||||
def test_load_system_prompt(self):
|
||||
from backend.agents.risk_manager import RiskAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
agent = RiskAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
)
|
||||
|
||||
prompt = agent._load_system_prompt()
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 0
|
||||
|
||||
|
||||
class TestStorageService:
|
||||
def test_calculate_portfolio_value_cash_only(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {"cash": 100000.0, "positions": {}, "margin_used": 0.0}
|
||||
prices = {}
|
||||
|
||||
value = storage.calculate_portfolio_value(portfolio, prices)
|
||||
assert value == 100000.0
|
||||
|
||||
def test_calculate_portfolio_value_with_positions(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {
|
||||
"cash": 50000.0,
|
||||
"positions": {
|
||||
"AAPL": {"long": 100, "short": 0},
|
||||
"GOOGL": {"long": 0, "short": 10},
|
||||
},
|
||||
"margin_used": 5000.0,
|
||||
}
|
||||
prices = {"AAPL": 150.0, "GOOGL": 100.0}
|
||||
|
||||
value = storage.calculate_portfolio_value(portfolio, prices)
|
||||
assert value == 69000.0
|
||||
|
||||
def test_update_dashboard_after_cycle(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
portfolio = {
|
||||
"cash": 90000.0,
|
||||
"positions": {"AAPL": {"long": 50, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
}
|
||||
prices = {"AAPL": 200.0}
|
||||
|
||||
storage.update_dashboard_after_cycle(
|
||||
portfolio=portfolio,
|
||||
prices=prices,
|
||||
date="2024-01-15",
|
||||
executed_trades=[
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"action": "long",
|
||||
"quantity": 50,
|
||||
"price": 200.0,
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
summary = storage.load_file("summary")
|
||||
assert summary is not None
|
||||
assert summary["totalAssetValue"] == 100000.0 # 90000 + 50*200
|
||||
|
||||
holdings = storage.load_file("holdings")
|
||||
assert holdings is not None
|
||||
assert len(holdings) > 0
|
||||
|
||||
trades = storage.load_file("trades")
|
||||
assert trades is not None
|
||||
assert len(trades) == 1
|
||||
assert trades[0]["ticker"] == "AAPL"
|
||||
assert trades[0]["qty"] == 50
|
||||
assert trades[0]["price"] == 200.0
|
||||
|
||||
def test_generate_summary(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
state = {
|
||||
"portfolio_state": {
|
||||
"cash": 50000.0,
|
||||
"positions": {"AAPL": {"long": 100, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
"equity_history": [{"t": 1000, "v": 100000}],
|
||||
"all_trades": [],
|
||||
}
|
||||
prices = {"AAPL": 500.0}
|
||||
|
||||
storage._generate_summary(state, 100000.0, prices)
|
||||
|
||||
summary = storage.load_file("summary")
|
||||
assert summary["totalAssetValue"] == 100000.0
|
||||
assert summary["totalReturn"] == 0.0
|
||||
|
||||
def test_generate_holdings(self):
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
storage = StorageService(
|
||||
dashboard_dir=Path(tmpdir),
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
state = {
|
||||
"portfolio_state": {
|
||||
"cash": 50000.0,
|
||||
"positions": {"AAPL": {"long": 100, "short": 0}},
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
}
|
||||
prices = {"AAPL": 500.0}
|
||||
|
||||
storage._generate_holdings(state, prices)
|
||||
|
||||
holdings = storage.load_file("holdings")
|
||||
assert len(holdings) == 2 # AAPL + CASH
|
||||
|
||||
aapl_holding = next(
|
||||
(h for h in holdings if h["ticker"] == "AAPL"),
|
||||
None,
|
||||
)
|
||||
assert aapl_holding is not None
|
||||
assert aapl_holding["quantity"] == 100
|
||||
assert aapl_holding["currentPrice"] == 500.0
|
||||
|
||||
|
||||
class TestTradeExecutor:
|
||||
def test_execute_trade_long(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor(
|
||||
initial_portfolio={
|
||||
"cash": 100000.0,
|
||||
"positions": {},
|
||||
"margin_requirement": 0.25,
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="long",
|
||||
quantity=10,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert executor.portfolio["positions"]["AAPL"]["long"] == 10
|
||||
assert executor.portfolio["cash"] == 98500.0 # 100000 - 10*150
|
||||
|
||||
def test_execute_trade_short(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor(
|
||||
initial_portfolio={
|
||||
"cash": 100000.0,
|
||||
"positions": {
|
||||
"AAPL": {
|
||||
"long": 50,
|
||||
"short": 0,
|
||||
"long_cost_basis": 100.0,
|
||||
"short_cost_basis": 0.0,
|
||||
},
|
||||
},
|
||||
"margin_requirement": 0.25,
|
||||
"margin_used": 0.0,
|
||||
},
|
||||
)
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="short",
|
||||
quantity=30,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert executor.portfolio["positions"]["AAPL"]["long"] == 20 # 50 - 30
|
||||
|
||||
def test_execute_trade_hold(self):
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
|
||||
executor = PortfolioTradeExecutor()
|
||||
|
||||
result = executor.execute_trade(
|
||||
ticker="AAPL",
|
||||
action="hold",
|
||||
quantity=0,
|
||||
price=150.0,
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
assert result["message"] == "No trade needed"
|
||||
|
||||
|
||||
class TestPipelineExecution:
|
||||
def test_execute_decisions(self):
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.agents.portfolio_manager import PMAgent
|
||||
|
||||
mock_model = MagicMock()
|
||||
mock_formatter = MagicMock()
|
||||
|
||||
pm = PMAgent(
|
||||
model=mock_model,
|
||||
formatter=mock_formatter,
|
||||
initial_cash=100000.0,
|
||||
)
|
||||
|
||||
pipeline = TradingPipeline(
|
||||
analysts=[],
|
||||
risk_manager=MagicMock(),
|
||||
portfolio_manager=pm,
|
||||
max_comm_cycles=0,
|
||||
)
|
||||
|
||||
decisions = {
|
||||
"AAPL": {"action": "long", "quantity": 10},
|
||||
"GOOGL": {"action": "short", "quantity": 5},
|
||||
}
|
||||
prices = {"AAPL": 150.0, "GOOGL": 100.0}
|
||||
|
||||
result = pipeline._execute_decisions(decisions, prices, "2024-01-15")
|
||||
|
||||
assert len(result["executed_trades"]) == 2
|
||||
assert result["executed_trades"][0]["ticker"] == "AAPL"
|
||||
assert result["executed_trades"][0]["quantity"] == 10
|
||||
assert pm.portfolio["positions"]["AAPL"]["long"] == 10
|
||||
|
||||
|
||||
class TestMsgContentIsString:
    """Msg.content must always be carried as a string, never a raw dict."""

    def test_msg_content_string(self):
        """A plain-string content stays a string."""
        message = Msg(name="test", content="simple string", role="user")
        assert isinstance(message.content, str)

    def test_msg_content_json_string(self):
        """JSON-encoded content stays a string and round-trips via json.loads."""
        payload = {"key": "value", "nested": {"a": 1}}
        message = Msg(name="test", content=json.dumps(payload), role="user")
        assert isinstance(message.content, str)

        round_tripped = json.loads(message.content)
        assert round_tripped["key"] == "value"

    def test_msg_content_should_not_be_dict(self):
        """Serialized content is never exposed as a dict on the message."""
        payload = {"key": "value"}
        message = Msg(name="test", content=json.dumps(payload), role="assistant")

        assert not isinstance(message.content, dict)
        assert isinstance(message.content, str)
|
||||
|
||||
|
||||
# Allow running this test module directly (without invoking the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
438
backend/tests/test_market_service.py
Normal file
438
backend/tests/test_market_service.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=W0212
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import pytest
|
||||
from backend.services.market import MarketService
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
|
||||
class TestMockPriceManager:
    """Unit tests for MockPriceManager: subscription bookkeeping, simulated
    price generation, callback dispatch, and the background polling thread."""

    def test_init_default(self):
        """Default construction: 10s poll interval, 0.5 volatility, idle."""
        manager = MockPriceManager()

        assert manager.poll_interval == 10
        assert manager.volatility == 0.5
        assert manager.running is False
        assert len(manager.subscribed_symbols) == 0

    def test_init_custom(self):
        """Constructor arguments override the defaults."""
        manager = MockPriceManager(poll_interval=5, volatility=1.0)

        assert manager.poll_interval == 5
        assert manager.volatility == 1.0

    def test_subscribe(self):
        """Subscribing known tickers seeds their built-in default base prices."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL", "MSFT"])

        assert "AAPL" in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols
        assert manager.base_prices["AAPL"] == 237.50  # default price
        assert manager.base_prices["MSFT"] == 425.30  # default price

    def test_subscribe_with_base_prices(self):
        """An explicit base price seeds base, open, and latest price alike."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        assert manager.base_prices["AAPL"] == 100.0
        assert manager.open_prices["AAPL"] == 100.0
        assert manager.latest_prices["AAPL"] == 100.0

    def test_subscribe_unknown_symbol(self):
        """Unknown tickers are still subscribed with a generated positive price."""
        manager = MockPriceManager()
        manager.subscribe(["UNKNOWN"])

        assert "UNKNOWN" in manager.subscribed_symbols
        assert manager.base_prices["UNKNOWN"] > 0  # random price generated

    def test_unsubscribe(self):
        """Unsubscribing removes only the requested symbols."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL", "MSFT"])
        manager.unsubscribe(["AAPL"])

        assert "AAPL" not in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_add_price_callback(self):
        """Registered callbacks are retained in price_callbacks."""
        manager = MockPriceManager()
        callback = MagicMock()
        manager.add_price_callback(callback)

        assert callback in manager.price_callbacks

    def test_generate_price_update_within_bounds(self):
        """Simulated updates stay within +/-10% of the open price."""
        manager = MockPriceManager(volatility=0.5)
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        for _ in range(100):
            new_price = manager._generate_price_update("AAPL")
            # Should be within +/-10% of open
            assert 90.0 <= new_price <= 110.0

    def test_update_prices_triggers_callback(self):
        """_update_prices invokes each callback once with a price-event dict."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        callback = MagicMock()
        manager.add_price_callback(callback)

        manager._update_prices()

        callback.assert_called_once()
        call_args = callback.call_args[0][0]
        assert call_args["symbol"] == "AAPL"
        assert "price" in call_args
        assert "timestamp" in call_args

    def test_start_stop(self):
        """start() flips running on; stop() flips it back off."""
        manager = MockPriceManager(poll_interval=1)
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        manager.start()
        assert manager.running is True

        time.sleep(0.1)  # let thread start

        manager.stop()
        assert manager.running is False

    def test_start_without_subscription(self):
        """start() is a no-op when no symbols have been subscribed."""
        manager = MockPriceManager()
        manager.start()

        assert (
            manager.running is False
        )  # should not start without subscriptions

    def test_get_latest_price(self):
        """get_latest_price returns the seeded price for a subscribed symbol."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        price = manager.get_latest_price("AAPL")
        assert price == 100.0

    def test_get_latest_price_unknown(self):
        """Unknown symbols yield None rather than raising."""
        manager = MockPriceManager()
        price = manager.get_latest_price("UNKNOWN")
        assert price is None

    def test_get_all_latest_prices(self):
        """get_all_latest_prices returns every subscribed symbol's price."""
        manager = MockPriceManager()
        manager.subscribe(
            ["AAPL", "MSFT"],
            base_prices={"AAPL": 100.0, "MSFT": 200.0},
        )

        prices = manager.get_all_latest_prices()
        assert prices["AAPL"] == 100.0
        assert prices["MSFT"] == 200.0

    def test_reset_open_prices(self):
        """reset_open_prices re-derives the open from the latest price."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        manager.latest_prices["AAPL"] = 105.0

        manager.reset_open_prices()

        # Open price should change (based on latest with small gap)
        assert manager.open_prices["AAPL"] != 100.0

    def test_set_base_price(self):
        """set_base_price resets base, open, and latest price together."""
        manager = MockPriceManager()
        manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        manager.set_base_price("AAPL", 150.0)

        assert manager.base_prices["AAPL"] == 150.0
        assert manager.open_prices["AAPL"] == 150.0
        assert manager.latest_prices["AAPL"] == 150.0
|
||||
|
||||
|
||||
class TestPollingPriceManager:
    """Unit tests for PollingPriceManager: subscription state, callbacks,
    the polling thread lifecycle, and simple price accessors."""

    def test_init(self):
        """Constructor stores the API key / interval and starts idle."""
        manager = PollingPriceManager(api_key="test_key", poll_interval=30)

        assert manager.api_key == "test_key"
        assert manager.poll_interval == 30
        assert manager.running is False

    def test_subscribe(self):
        """Subscribing registers the requested symbols."""
        manager = PollingPriceManager(api_key="test_key")
        manager.subscribe(["AAPL", "MSFT"])

        assert "AAPL" in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_unsubscribe(self):
        """Unsubscribing removes only the requested symbols."""
        manager = PollingPriceManager(api_key="test_key")
        manager.subscribe(["AAPL", "MSFT"])
        manager.unsubscribe(["AAPL"])

        assert "AAPL" not in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_add_price_callback(self):
        """Registered callbacks are retained in price_callbacks."""
        manager = PollingPriceManager(api_key="test_key")
        callback = MagicMock()
        manager.add_price_callback(callback)

        assert callback in manager.price_callbacks

    @patch.object(PollingPriceManager, "_fetch_prices")
    def test_start_stop(self, mock_fetch):
        """start()/stop() toggle the polling thread without hitting the API.

        BUG FIX: @patch.object injects the mock as a positional argument, so
        the test method must accept it; the original signature omitted it and
        raised TypeError when pytest invoked the test.
        """
        manager = PollingPriceManager(api_key="test_key", poll_interval=1)
        manager.subscribe(["AAPL"])

        manager.start()
        assert manager.running is True

        time.sleep(0.1)

        manager.stop()
        assert manager.running is False

    def test_start_without_subscription(self):
        """start() is a no-op when no symbols have been subscribed."""
        manager = PollingPriceManager(api_key="test_key")
        manager.start()

        assert manager.running is False

    def test_get_latest_price(self):
        """get_latest_price reads from the latest_prices cache."""
        manager = PollingPriceManager(api_key="test_key")
        manager.latest_prices["AAPL"] = 150.0

        price = manager.get_latest_price("AAPL")
        assert price == 150.0

    def test_get_open_price(self):
        """get_open_price reads from the open_prices cache."""
        manager = PollingPriceManager(api_key="test_key")
        manager.open_prices["AAPL"] = 148.0

        price = manager.get_open_price("AAPL")
        assert price == 148.0

    def test_reset_open_prices(self):
        """reset_open_prices clears the recorded session opens."""
        manager = PollingPriceManager(api_key="test_key")
        manager.open_prices["AAPL"] = 150.0

        manager.reset_open_prices()

        assert len(manager.open_prices) == 0
|
||||
|
||||
|
||||
class TestMarketService:
    """Unit tests for MarketService: construction in mock/real mode, the
    async start/stop lifecycle, synchronous price accessors, and broadcast
    of price updates to an async consumer."""

    def test_init_mock_mode(self):
        """Mock-mode construction stores tickers/interval and starts idle."""
        service = MarketService(
            tickers=["AAPL", "MSFT"],
            poll_interval=10,
            mock_mode=True,
        )

        assert service.tickers == ["AAPL", "MSFT"]
        assert service.poll_interval == 10
        assert service.mock_mode is True
        assert service.running is False

    def test_init_real_mode(self):
        """Real-mode construction records the API key."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=False,
            api_key="test_key",
        )

        assert service.mock_mode is False
        assert service.api_key == "test_key"

    @pytest.mark.asyncio
    async def test_start_mock_mode(self):
        """start() in mock mode wires up a MockPriceManager and runs."""
        service = MarketService(
            tickers=["AAPL"],
            poll_interval=10,
            mock_mode=True,
        )

        broadcast_func = AsyncMock()

        await service.start(broadcast_func)

        assert service.running is True
        assert service._price_manager is not None
        assert isinstance(service._price_manager, MockPriceManager)

        service.stop()

    @pytest.mark.asyncio
    async def test_start_real_mode_without_api_key(self):
        """start() in real mode without an API key raises ValueError."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=False,
            api_key=None,
        )

        broadcast_func = AsyncMock()

        with pytest.raises(ValueError) as excinfo:
            await service.start(broadcast_func)

        assert "API key required" in str(excinfo.value)

    @pytest.mark.asyncio
    async def test_start_already_running(self):
        """Calling start() twice is idempotent, not an error."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )

        broadcast_func = AsyncMock()

        await service.start(broadcast_func)
        assert service.running is True

        # Start again should not fail
        await service.start(broadcast_func)

        service.stop()

    def test_stop(self):
        """stop() clears the running flag and drops the price manager."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )
        service.running = True
        service._price_manager = MagicMock()

        service.stop()

        assert service.running is False
        assert service._price_manager is None

    def test_stop_when_not_running(self):
        """stop() on an idle service is a harmless no-op."""
        service = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )

        # Should not raise
        service.stop()
        assert service.running is False

    def test_get_price_sync(self):
        """get_price_sync reads the cached price for a known ticker."""
        service = MarketService(tickers=["AAPL"], mock_mode=True)
        service.cache["AAPL"] = {"price": 150.0, "open": 148.0}

        price = service.get_price_sync("AAPL")
        assert price == 150.0

    def test_get_price_sync_not_found(self):
        """Uncached tickers yield None rather than raising."""
        service = MarketService(tickers=["AAPL"], mock_mode=True)

        price = service.get_price_sync("MSFT")
        assert price is None

    def test_get_all_prices(self):
        """get_all_prices flattens the cache into ticker -> price."""
        service = MarketService(tickers=["AAPL", "MSFT"], mock_mode=True)
        service.cache["AAPL"] = {"price": 150.0}
        service.cache["MSFT"] = {"price": 400.0}

        prices = service.get_all_prices()

        assert prices["AAPL"] == 150.0
        assert prices["MSFT"] == 400.0

    @pytest.mark.asyncio
    async def test_broadcast_price_update(self):
        """_broadcast_price_update forwards a typed message to the consumer."""
        service = MarketService(tickers=["AAPL"], mock_mode=True)
        service._broadcast_func = AsyncMock()

        price_data = {
            "symbol": "AAPL",
            "price": 150.0,
            "open": 148.0,
            "timestamp": 1234567890,
        }

        await service._broadcast_price_update(price_data)

        service._broadcast_func.assert_called_once()
        call_args = service._broadcast_func.call_args[0][0]
        assert call_args["type"] == "price_update"
        assert call_args["symbol"] == "AAPL"
        assert call_args["price"] == 150.0

    @pytest.mark.asyncio
    async def test_broadcast_price_update_no_func(self):
        """Broadcasting without a registered consumer must not raise."""
        service = MarketService(tickers=["AAPL"], mock_mode=True)
        service._broadcast_func = None

        price_data = {"symbol": "AAPL", "price": 150.0, "open": 148.0}

        # Should not raise
        await service._broadcast_price_update(price_data)

    @pytest.mark.asyncio
    async def test_price_callback_thread_safety(self):
        """Updates produced on the poller thread reach the async consumer."""
        service = MarketService(
            tickers=["AAPL"],
            poll_interval=1,
            mock_mode=True,
        )

        received_prices = []

        async def capture_broadcast(msg):
            received_prices.append(msg)

        await service.start(capture_broadcast)

        # Wait for at least one price update
        await asyncio.sleep(1.5)

        service.stop()

        # Should have received at least one price update
        assert len(received_prices) >= 1
        assert received_prices[0]["type"] == "price_update"
|
||||
|
||||
|
||||
class TestMarketServiceIntegration:
    """End-to-end smoke test of the mock market data cycle."""

    @pytest.mark.asyncio
    async def test_full_mock_cycle(self):
        """Start the service, collect broadcast messages, verify their shape."""
        service = MarketService(
            tickers=["AAPL", "MSFT"],
            poll_interval=1,
            mock_mode=True,
        )

        messages = []

        async def collect_messages(msg):
            messages.append(msg)

        await service.start(collect_messages)

        # Wait for price updates
        await asyncio.sleep(2.5)

        service.stop()

        # Should have received multiple price updates
        assert len(messages) >= 2

        # Check message structure
        symbols_seen = set()
        for msg in messages:
            assert msg["type"] == "price_update"
            assert "symbol" in msg
            assert "price" in msg
            assert "ret" in msg
            symbols_seen.add(msg["symbol"])

        # Should have prices for both tickers
        assert "AAPL" in symbols_seen or "MSFT" in symbols_seen
|
||||
|
||||
|
||||
# Allow running this test module directly (without invoking the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
|
||||
201
backend/tests/test_settlement.py
Normal file
201
backend/tests/test_settlement.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Test Settlement Coordinator and Baseline Calculations
|
||||
"""
|
||||
|
||||
from backend.utils.baselines import (
|
||||
BaselineCalculator,
|
||||
calculate_momentum_scores,
|
||||
)
|
||||
from backend.utils.analyst_tracker import (
|
||||
AnalystPerformanceTracker,
|
||||
update_leaderboard_with_evaluations,
|
||||
)
|
||||
|
||||
|
||||
def test_baseline_equal_weight():
    """Equal-weight baseline yields a positive value and marks itself initialized."""
    baseline = BaselineCalculator(initial_capital=100000.0)

    symbols = ["AAPL", "MSFT", "GOOGL"]
    close_prices = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
    open_prices = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}

    portfolio_value = baseline.calculate_equal_weight_value(
        symbols,
        open_prices,
        close_prices,
    )

    assert portfolio_value > 0
    assert baseline.equal_weight_initialized is True
|
||||
|
||||
|
||||
def test_baseline_market_cap_weighted():
    """Cap-weighted baseline yields a positive value and marks itself initialized."""
    baseline = BaselineCalculator(initial_capital=100000.0)

    symbols = ["AAPL", "MSFT", "GOOGL"]
    close_prices = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
    open_prices = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}
    caps = {"AAPL": 3e12, "MSFT": 2e12, "GOOGL": 1.5e12}

    portfolio_value = baseline.calculate_market_cap_weighted_value(
        symbols,
        open_prices,
        close_prices,
        caps,
    )

    assert portfolio_value > 0
    assert baseline.market_cap_initialized is True
|
||||
|
||||
|
||||
def test_momentum_scores():
    """A rising series scores positive momentum; a falling one scores negative."""
    symbols = ["AAPL", "MSFT"]
    # AAPL trends up over three sessions, MSFT trends down.
    history = {
        "AAPL": [
            ("2024-01-01", 100.0),
            ("2024-01-02", 105.0),
            ("2024-01-03", 110.0),
        ],
        "MSFT": [
            ("2024-01-01", 200.0),
            ("2024-01-02", 195.0),
            ("2024-01-03", 190.0),
        ],
    }

    momentum = calculate_momentum_scores(symbols, history, lookback_days=2)

    assert momentum["AAPL"] > 0
    assert momentum["MSFT"] < 0
|
||||
|
||||
|
||||
def test_analyst_tracker_predictions():
    """Structured up/down/neutral predictions map to long/short/hold signals."""
    performance_tracker = AnalystPerformanceTracker()

    structured_predictions = [
        {
            "agent": "technical_analyst",
            "predictions": [
                {"ticker": "AAPL", "direction": "up", "confidence": 0.8},
                {"ticker": "MSFT", "direction": "down", "confidence": 0.7},
                {"ticker": "GOOGL", "direction": "neutral", "confidence": 0.5},
            ],
        },
        {
            "agent": "fundamentals_analyst",
            "predictions": [
                {"ticker": "AAPL", "direction": "up", "confidence": 0.9},
                {"ticker": "MSFT", "direction": "up", "confidence": 0.6},
                {"ticker": "GOOGL", "direction": "down", "confidence": 0.75},
            ],
        },
    ]

    performance_tracker.record_analyst_predictions(structured_predictions)

    recorded = performance_tracker.daily_predictions
    assert "technical_analyst" in recorded
    assert "fundamentals_analyst" in recorded
    # direction vocabulary is normalized to position vocabulary
    assert recorded["technical_analyst"]["AAPL"] == "long"
    assert recorded["technical_analyst"]["MSFT"] == "short"
    assert recorded["technical_analyst"]["GOOGL"] == "hold"
|
||||
|
||||
|
||||
def test_analyst_evaluation():
    """Both a correct long and a correct short count as wins (100% win rate)."""
    performance_tracker = AnalystPerformanceTracker()

    performance_tracker.daily_predictions = {
        "technical_analyst": {
            "AAPL": "long",   # AAPL closes up -> correct
            "MSFT": "short",  # MSFT closes down -> correct
        },
    }

    session_open = {"AAPL": 100.0, "MSFT": 200.0}
    session_close = {"AAPL": 105.0, "MSFT": 195.0}

    evaluations = performance_tracker.evaluate_predictions(
        session_open,
        session_close,
        "2024-01-15",
    )

    assert "technical_analyst" in evaluations
    result = evaluations["technical_analyst"]
    assert result["correct_predictions"] == 2
    assert result["win_rate"] == 1.0

    # Each per-ticker signal record carries the fields the frontend expects.
    assert "signals" in result
    assert len(result["signals"]) == 2
    required_keys = ("ticker", "signal", "date", "is_correct")
    for record in result["signals"]:
        for key in required_keys:
            assert key in record
        assert record["date"] == "2024-01-15"
|
||||
|
||||
|
||||
def test_leaderboard_update():
    """Test leaderboard update with evaluations.

    Verifies that per-analyst evaluation results (bull/bear tallies, win
    rate, and individual signal records) are merged into a pre-seeded
    leaderboard entry in the shape the frontend consumes.
    """
    # One empty leaderboard row for the analyst under test.
    leaderboard = [
        {
            "agentId": "technical_analyst",
            "name": "Technical Analyst",
            "rank": 0,
            "winRate": None,
            "bull": {"n": 0, "win": 0, "unknown": 0},
            "bear": {"n": 0, "win": 0, "unknown": 0},
            "signals": [],
        },
    ]

    # One evaluated session: a correct bull call and an incorrect bear call.
    evaluations = {
        "technical_analyst": {
            "total_predictions": 2,
            "correct_predictions": 1,
            "win_rate": 0.5,
            "bull": {"n": 1, "win": 1, "unknown": 0},
            "bear": {"n": 1, "win": 0, "unknown": 0},
            "hold": 0,
            "signals": [
                {
                    "ticker": "AAPL",
                    "signal": "bull",
                    "date": "2024-01-01",
                    "is_correct": True,
                },
                {
                    "ticker": "MSFT",
                    "signal": "bear",
                    "date": "2024-01-01",
                    "is_correct": False,
                },
            ],
        },
    }

    updated = update_leaderboard_with_evaluations(
        leaderboard,
        evaluations,
    )

    assert updated[0]["bull"]["n"] == 1
    assert updated[0]["bull"]["win"] == 1
    assert updated[0]["winRate"] == 0.5
    assert len(updated[0]["signals"]) == 2

    # Verify signal format matches frontend expectations
    for signal in updated[0]["signals"]:
        assert "ticker" in signal
        assert "signal" in signal
        assert "date" in signal
        assert "is_correct" in signal
|
||||
0
backend/tools/__init__.py
Normal file
0
backend/tools/__init__.py
Normal file
1289
backend/tools/analysis_tools.py
Normal file
1289
backend/tools/analysis_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
742
backend/tools/data_tools.py
Normal file
742
backend/tools/data_tools.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=C0301
|
||||
"""
|
||||
Data fetching tools for financial data.
|
||||
|
||||
All functions use centralized data source configuration from data_config.py.
|
||||
The data source is automatically determined based on available API keys:
|
||||
- Priority: FINNHUB_API_KEY > FINANCIAL_DATASETS_API_KEY
|
||||
"""
|
||||
import datetime
|
||||
import time
|
||||
|
||||
import finnhub
|
||||
import pandas as pd
|
||||
import pandas_market_calendars as mcal
|
||||
import requests
|
||||
|
||||
from backend.config.data_config import (
|
||||
get_config,
|
||||
get_api_key,
|
||||
)
|
||||
from backend.data.cache import get_cache
|
||||
from backend.data.schema import (
|
||||
CompanyFactsResponse,
|
||||
CompanyNews,
|
||||
CompanyNewsResponse,
|
||||
FinancialMetrics,
|
||||
FinancialMetricsResponse,
|
||||
InsiderTrade,
|
||||
InsiderTradeResponse,
|
||||
LineItem,
|
||||
LineItemResponse,
|
||||
Price,
|
||||
PriceResponse,
|
||||
)
|
||||
from backend.utils.settlement import logger
|
||||
|
||||
# Global cache instance
|
||||
_cache = get_cache()
|
||||
|
||||
|
||||
def get_last_tradeday(date: str) -> str:
    """
    Get the previous trading day for the specified date.

    Uses the NYSE calendar when available (supports both
    pandas_market_calendars and exchange_calendars APIs); otherwise falls
    back to a naive calendar-day step backwards.

    Args:
        date: Date string (YYYY-MM-DD)

    Returns:
        Previous trading day date string (YYYY-MM-DD)
    """
    current_date = datetime.datetime.strptime(date, "%Y-%m-%d")
    _NYSE_CALENDAR = mcal.get_calendar("NYSE")

    if _NYSE_CALENDAR is not None:
        # Get trading days before current date.
        # Go back 90 days from current date to get all trading days.
        start_search = current_date - datetime.timedelta(days=90)

        if hasattr(_NYSE_CALENDAR, "valid_days"):
            # pandas_market_calendars API
            trading_dates = _NYSE_CALENDAR.valid_days(
                start_date=start_search.strftime("%Y-%m-%d"),
                end_date=current_date.strftime("%Y-%m-%d"),
            )
        else:
            # exchange_calendars API
            trading_dates = _NYSE_CALENDAR.sessions_in_range(
                start_search.strftime("%Y-%m-%d"),
                current_date.strftime("%Y-%m-%d"),
            )

        # Normalize to plain YYYY-MM-DD strings.
        trading_dates_list = [
            pd.Timestamp(d).strftime("%Y-%m-%d") for d in trading_dates
        ]

        if date in trading_dates_list:
            # Current date is a trading day: return the one before it.
            idx = trading_dates_list.index(date)
            if idx > 0:
                return trading_dates_list[idx - 1]
            # First trading day in the window: recurse one calendar day back.
            prev_date = current_date - datetime.timedelta(days=1)
            return get_last_tradeday(prev_date.strftime("%Y-%m-%d"))

        # Current date is not a trading day: return the nearest earlier one.
        if trading_dates_list:
            return trading_dates_list[-1]

    # BUG FIX: the original fell through to `prev_date.strftime(...)` with
    # `prev_date` unbound (NameError) when the calendar was unavailable or
    # returned no sessions. Fall back to the previous calendar day instead.
    fallback = current_date - datetime.timedelta(days=1)
    return fallback.strftime("%Y-%m-%d")
|
||||
|
||||
|
||||
def _make_api_request(
    url: str,
    headers: dict,
    method: str = "GET",
    json_data: dict = None,
    max_retries: int = 3,
    timeout: float = 30.0,
) -> requests.Response:
    """
    Make an API request with rate limiting handling and moderate backoff.

    Args:
        url: The URL to request
        headers: Headers to include in the request
        method: HTTP method (GET or POST)
        json_data: JSON data for POST requests
        max_retries: Maximum number of retries (default: 3)
        timeout: Per-request timeout in seconds (default: 30.0)

    Returns:
        requests.Response: The response object

    Raises:
        requests.exceptions.RequestException: On network failure or timeout
    """
    for attempt in range(max_retries + 1):  # +1 for initial attempt
        # BUG FIX: the original calls had no timeout, so a stalled server
        # would hang the caller forever; requests does not time out by default.
        if method.upper() == "POST":
            response = requests.post(
                url, headers=headers, json=json_data, timeout=timeout,
            )
        else:
            response = requests.get(url, headers=headers, timeout=timeout)

        if response.status_code == 429 and attempt < max_retries:
            # Linear backoff: 60s, 90s, 120s, 150s...
            delay = 60 + (30 * attempt)
            print(
                f"Rate limited (429). Attempt {attempt + 1}/{max_retries + 1}. Waiting {delay}s before retrying...",
            )
            time.sleep(delay)
            continue

        # Return the response (whether success, other errors, or final 429)
        return response
|
||||
|
||||
|
||||
def get_prices(
    ticker: str,
    start_date: str,
    end_date: str,
) -> list[Price]:
    """
    Fetch price data from cache or API.

    Uses centralized data source configuration (FINNHUB_API_KEY prioritized).

    Args:
        ticker: Stock ticker symbol
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)

    Returns:
        list[Price]: List of Price objects (empty if the API returned no data)

    Raises:
        ValueError: If the financial_datasets API returns a non-200 status.
    """
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    # Create a cache key that includes all parameters to ensure exact matches
    cache_key = f"{ticker}_{start_date}_{end_date}_{data_source}"

    # Check cache first - simple exact match
    if cached_data := _cache.get_prices(cache_key):
        return [Price(**price) for price in cached_data]

    prices = []

    if data_source == "finnhub":
        # Use Finnhub API
        client = finnhub.Client(api_key=api_key)

        # Convert dates to timestamps
        start_timestamp = int(
            datetime.datetime.strptime(start_date, "%Y-%m-%d").timestamp(),
        )
        # +1 day so the end_date itself is included in the candle window.
        end_timestamp = int(
            (
                datetime.datetime.strptime(end_date, "%Y-%m-%d")
                + datetime.timedelta(days=1)
            ).timestamp(),
        )

        # Fetch daily ("D") candle data from Finnhub
        candles = client.stock_candles(
            ticker,
            "D",
            start_timestamp,
            end_timestamp,
        )

        # Convert Finnhub's column-oriented arrays (o/c/h/l/v/t) to Price objects
        for i in range(len(candles["t"])):
            price = Price(
                open=candles["o"][i],
                close=candles["c"][i],
                high=candles["h"][i],
                low=candles["l"][i],
                volume=int(candles["v"][i]),
                time=datetime.datetime.fromtimestamp(candles["t"][i]).strftime(
                    "%Y-%m-%d",
                ),
            )
            prices.append(price)

    else:  # financial_datasets
        # Use Financial Datasets API
        headers = {"X-API-KEY": api_key}

        url = f"https://api.financialdatasets.ai/prices/?ticker={ticker}&interval=day&interval_multiplier=1&start_date={start_date}&end_date={end_date}"
        response = _make_api_request(url, headers)
        if response.status_code != 200:
            raise ValueError(
                f"Error fetching data: {ticker} - {response.status_code} - {response.text}",
            )

        # Parse response with Pydantic model
        price_response = PriceResponse(**response.json())
        prices = price_response.prices

    # Empty results are not cached so a later call can retry the API.
    if not prices:
        return []

    # Cache the results using the comprehensive cache key
    _cache.set_prices(cache_key, [p.model_dump() for p in prices])
    return prices
|
||||
|
||||
|
||||
def get_financial_metrics(
    ticker: str,
    end_date: str,
    period: str = "ttm",
    limit: int = 10,
) -> list[FinancialMetrics]:
    """
    Fetch financial metrics from cache or API.

    Uses centralized data source configuration (FINNHUB_API_KEY prioritized).

    Args:
        ticker: Stock ticker symbol
        end_date: End date (YYYY-MM-DD)
        period: Period type (default: "ttm")
        limit: Number of records to fetch (Finnhub always yields at most one)

    Returns:
        list[FinancialMetrics]: List of financial metrics (possibly empty)

    Raises:
        ValueError: If the financial_datasets API returns a non-200 status.
    """
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    # Create a cache key that includes all parameters to ensure exact matches
    cache_key = f"{ticker}_{period}_{end_date}_{limit}_{data_source}"

    # Check cache first - simple exact match
    if cached_data := _cache.get_financial_metrics(cache_key):
        return [FinancialMetrics(**metric) for metric in cached_data]

    financial_metrics = []

    if data_source == "finnhub":
        # Use Finnhub API - Basic Financials
        client = finnhub.Client(api_key=api_key)

        # Fetch basic financials from Finnhub
        # metric='all' returns all available metrics
        financials = client.company_basic_financials(ticker, "all")

        if not financials or "metric" not in financials:
            return []

        # Finnhub returns {series: {...}, metric: {...}, metricType: ..., symbol: ...}
        # We need to create a FinancialMetrics object from this
        metric_data = financials.get("metric", {})

        # Create a FinancialMetrics object with available data
        metric = _map_finnhub_metrics(ticker, end_date, period, metric_data)

        # Finnhub's snapshot endpoint yields a single record regardless of limit.
        financial_metrics = [metric]

    else:  # financial_datasets
        # Use Financial Datasets API
        headers = {"X-API-KEY": api_key}

        url = f"https://api.financialdatasets.ai/financial-metrics/?ticker={ticker}&report_period_lte={end_date}&limit={limit}&period={period}"
        response = _make_api_request(url, headers)
        if response.status_code != 200:
            raise ValueError(
                f"Error fetching data: {ticker} - {response.status_code} - {response.text}",
            )

        # Parse response with Pydantic model
        metrics_response = FinancialMetricsResponse(**response.json())
        financial_metrics = metrics_response.financial_metrics

    # Empty results are not cached so a later call can retry the API.
    if not financial_metrics:
        return []

    # Cache the results as dicts using the comprehensive cache key
    _cache.set_financial_metrics(
        cache_key,
        [m.model_dump() for m in financial_metrics],
    )
    return financial_metrics
|
||||
|
||||
|
||||
def _map_finnhub_metrics(
    ticker: str,
    end_date: str,
    period: str,
    metric_data: dict,
) -> FinancialMetrics:
    """Map Finnhub metric data to FinancialMetrics model.

    Fields set to None are not provided by Finnhub's basic-financials
    endpoint; `.get()` is used throughout so missing keys also become None.
    """
    return FinancialMetrics(
        ticker=ticker,
        report_period=end_date,
        period=period,
        currency="USD",
        # Valuation
        market_cap=metric_data.get("marketCapitalization"),
        enterprise_value=None,
        price_to_earnings_ratio=metric_data.get("peBasicExclExtraTTM"),
        price_to_book_ratio=metric_data.get("pbAnnual"),
        price_to_sales_ratio=metric_data.get("psAnnual"),
        enterprise_value_to_ebitda_ratio=None,
        enterprise_value_to_revenue_ratio=None,
        free_cash_flow_yield=None,
        peg_ratio=None,
        # Margins and returns
        gross_margin=metric_data.get("grossMarginTTM"),
        operating_margin=metric_data.get("operatingMarginTTM"),
        net_margin=metric_data.get("netProfitMarginTTM"),
        return_on_equity=metric_data.get("roeTTM"),
        return_on_assets=metric_data.get("roaTTM"),
        return_on_invested_capital=metric_data.get("roicTTM"),
        # Efficiency / turnover
        asset_turnover=metric_data.get("assetTurnoverTTM"),
        inventory_turnover=metric_data.get("inventoryTurnoverTTM"),
        receivables_turnover=metric_data.get("receivablesTurnoverTTM"),
        days_sales_outstanding=None,
        operating_cycle=None,
        working_capital_turnover=None,
        # Liquidity
        current_ratio=metric_data.get("currentRatioAnnual"),
        quick_ratio=metric_data.get("quickRatioAnnual"),
        cash_ratio=None,
        operating_cash_flow_ratio=None,
        # Leverage
        debt_to_equity=metric_data.get("totalDebt/totalEquityAnnual"),
        debt_to_assets=None,
        interest_coverage=None,
        # Growth
        revenue_growth=metric_data.get("revenueGrowthTTMYoy"),
        earnings_growth=None,
        book_value_growth=None,
        earnings_per_share_growth=metric_data.get("epsGrowthTTMYoy"),
        free_cash_flow_growth=None,
        operating_income_growth=None,
        ebitda_growth=None,
        # Per-share and payout
        payout_ratio=metric_data.get("payoutRatioAnnual"),
        earnings_per_share=metric_data.get("epsBasicExclExtraItemsTTM"),
        book_value_per_share=metric_data.get("bookValuePerShareAnnual"),
        free_cash_flow_per_share=None,
    )
|
||||
|
||||
|
||||
def search_line_items(
    ticker: str,
    line_items: list[str],
    end_date: str,
    period: str = "ttm",
    limit: int = 10,
) -> list[LineItem]:
    """
    Fetch financial-statement line items for *ticker*.

    Only the Financial Datasets API supports line-item search, so this
    function does not consult the configured data source.

    Args:
        ticker: Stock ticker symbol.
        line_items: Names of the line items to retrieve.
        end_date: Latest report date to include (YYYY-MM-DD).
        period: Reporting period (default "ttm").
        limit: Maximum number of results to return.

    Returns:
        Up to ``limit`` matching line items; an empty list on any API
        error to allow graceful degradation.
    """
    try:
        api_key = get_api_key()
        headers = {"X-API-KEY": api_key}

        url = "https://api.financialdatasets.ai/financials/search/line-items"
        body = {
            "tickers": [ticker],
            "line_items": line_items,
            "end_date": end_date,
            "period": period,
            "limit": limit,
        }
        response = _make_api_request(
            url,
            headers,
            method="POST",
            json_data=body,
        )

        if response.status_code != 200:
            # Fix: these are warnings, not routine info — log at WARNING
            # level so they are visible under the default log filter.
            logger.warning(
                f"Warning: Failed to fetch line items for {ticker}: "
                f"{response.status_code} - {response.text}",
            )
            return []

        data = response.json()
        response_model = LineItemResponse(**data)
        search_results = response_model.search_results

        if not search_results:
            return []

        # The API may return more rows than requested; enforce the limit.
        return search_results[:limit]

    except Exception as e:
        logger.warning(
            f"Warning: Exception while fetching line items for {ticker}: {str(e)}",
        )
        return []
|
||||
|
||||
|
||||
def _fetch_finnhub_insider_trades(
    ticker: str,
    start_date: str | None,
    end_date: str,
    limit: int,
    api_key: str,
) -> list[InsiderTrade]:
    """Fetch insider trades for *ticker* from the Finnhub API.

    When no start date is supplied, a one-year window ending at
    *end_date* is queried. Returns an empty list when Finnhub has no
    transaction data.
    """
    client = finnhub.Client(api_key=api_key)

    if start_date:
        from_date = start_date
    else:
        window_end = datetime.datetime.strptime(end_date, "%Y-%m-%d")
        from_date = (
            window_end - datetime.timedelta(days=365)
        ).strftime("%Y-%m-%d")

    insider_data = client.stock_insider_transactions(
        ticker,
        from_date,
        end_date,
    )

    if not insider_data or "data" not in insider_data:
        return []

    converted: list[InsiderTrade] = []
    for raw_trade in insider_data["data"][:limit]:
        converted.append(_convert_finnhub_insider_trade(ticker, raw_trade))
    return converted
|
||||
|
||||
|
||||
def _fetch_fd_insider_trades(
    ticker: str,
    start_date: str | None,
    end_date: str,
    limit: int,
    api_key: str,
) -> list[InsiderTrade]:
    """Fetch insider trades from Financial Datasets API.

    Pages backwards through filings: each request asks for filings at or
    before ``current_end_date``, which is then moved to the oldest
    filing date returned, until ``start_date`` is reached.

    Raises:
        ValueError: If the API responds with a non-200 status.
    """
    headers = {"X-API-KEY": api_key}
    all_trades = []
    current_end_date = end_date

    while True:
        url = f"https://api.financialdatasets.ai/insider-trades/?ticker={ticker}&filing_date_lte={current_end_date}"
        if start_date:
            url += f"&filing_date_gte={start_date}"
        url += f"&limit={limit}"

        response = _make_api_request(url, headers)
        if response.status_code != 200:
            raise ValueError(
                f"Error fetching data: {ticker} - {response.status_code} - {response.text}",
            )

        data = response.json()
        response_model = InsiderTradeResponse(**data)
        insider_trades = response_model.insider_trades

        if not insider_trades:
            break

        all_trades.extend(insider_trades)

        # Only paginate when a start date bounds the search and the API
        # returned a full page (fewer than `limit` rows means there is
        # nothing older left to fetch).
        if not start_date or len(insider_trades) < limit:
            break

        # Move the window back to the oldest filing date in this page;
        # keep only the date portion in case it carries a time component.
        current_end_date = min(
            trade.filing_date for trade in insider_trades
        ).split("T")[0]

        if current_end_date <= start_date:
            break

    return all_trades
|
||||
|
||||
|
||||
def get_insider_trades(
    ticker: str,
    end_date: str,
    start_date: str | None = None,
    limit: int = 1000,
) -> list[InsiderTrade]:
    """Fetch insider trades from cache or API."""
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    cache_key = (
        f"{ticker}_{start_date or 'none'}_{end_date}_{limit}_{data_source}"
    )

    # Serve from the cache whenever a non-empty entry exists.
    cached_data = _cache.get_insider_trades(cache_key)
    if cached_data:
        return [InsiderTrade(**trade) for trade in cached_data]

    # Dispatch to the configured provider (both share one signature).
    if data_source == "finnhub":
        fetch = _fetch_finnhub_insider_trades
    else:
        fetch = _fetch_fd_insider_trades
    all_trades = fetch(ticker, start_date, end_date, limit, api_key)

    if not all_trades:
        return []

    # Persist for subsequent identical queries, then return.
    _cache.set_insider_trades(
        cache_key,
        [trade.model_dump() for trade in all_trades],
    )
    return all_trades
|
||||
|
||||
|
||||
def _fetch_finnhub_company_news(
    ticker: str,
    start_date: str | None,
    end_date: str,
    limit: int,
    api_key: str,
) -> list[CompanyNews]:
    """Fetch company news for *ticker* from the Finnhub API.

    When no start date is supplied, a 30-day window ending at *end_date*
    is queried. Returns an empty list when Finnhub has no articles.
    """
    client = finnhub.Client(api_key=api_key)

    if start_date:
        from_date = start_date
    else:
        window_end = datetime.datetime.strptime(end_date, "%Y-%m-%d")
        from_date = (
            window_end - datetime.timedelta(days=30)
        ).strftime("%Y-%m-%d")

    news_data = client.company_news(ticker, _from=from_date, to=end_date)

    if not news_data:
        return []

    articles: list[CompanyNews] = []
    for item in news_data[:limit]:
        # Finnhub reports publication time as a unix timestamp; convert
        # to a UTC calendar date, or None when absent/zero.
        timestamp = item.get("datetime")
        if timestamp:
            published = datetime.datetime.fromtimestamp(
                timestamp,
                datetime.timezone.utc,
            ).strftime("%Y-%m-%d")
        else:
            published = None

        articles.append(
            CompanyNews(
                ticker=ticker,
                title=item.get("headline", ""),
                related=item.get("related", ""),
                source=item.get("source", ""),
                date=published,
                url=item.get("url", ""),
                summary=item.get("summary", ""),
                category=item.get("category", ""),
            ),
        )
    return articles
|
||||
|
||||
|
||||
def _fetch_fd_company_news(
    ticker: str,
    start_date: str | None,
    end_date: str,
    limit: int,
    api_key: str,
) -> list[CompanyNews]:
    """Fetch company news from Financial Datasets API.

    Pages backwards in time: each request asks for articles at or before
    ``current_end_date``, which is then moved to the oldest dated
    article returned, until ``start_date`` is reached.

    Raises:
        ValueError: If the API responds with a non-200 status.
    """
    headers = {"X-API-KEY": api_key}
    all_news = []
    current_end_date = end_date

    while True:
        url = f"https://api.financialdatasets.ai/news/?ticker={ticker}&end_date={current_end_date}"
        if start_date:
            url += f"&start_date={start_date}"
        url += f"&limit={limit}"

        response = _make_api_request(url, headers)
        if response.status_code != 200:
            raise ValueError(
                f"Error fetching data: {ticker} - {response.status_code} - {response.text}",
            )

        data = response.json()
        response_model = CompanyNewsResponse(**data)
        company_news = response_model.news

        if not company_news:
            break

        all_news.extend(company_news)

        # Only paginate when a start date bounds the search and the API
        # returned a full page (fewer than `limit` rows means there is
        # nothing older left to fetch).
        if not start_date or len(company_news) < limit:
            break

        # Move the window back to the oldest dated article in this page;
        # undated articles are skipped for the min().
        # NOTE(review): if every article in a full page lacks a date,
        # min() raises ValueError on the empty generator — confirm the
        # API always supplies dates in that case.
        current_end_date = min(
            news.date for news in company_news if news.date is not None
        ).split("T")[0]

        if current_end_date <= start_date:
            break

    return all_news
|
||||
|
||||
|
||||
def get_company_news(
    ticker: str,
    end_date: str,
    start_date: str | None = None,
    limit: int = 1000,
) -> list[CompanyNews]:
    """Fetch company news from cache or API."""
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    cache_key = (
        f"{ticker}_{start_date or 'none'}_{end_date}_{limit}_{data_source}"
    )

    # Serve from the cache whenever a non-empty entry exists.
    cached_data = _cache.get_company_news(cache_key)
    if cached_data:
        return [CompanyNews(**news) for news in cached_data]

    # Dispatch to the configured provider (both share one signature).
    if data_source == "finnhub":
        fetch = _fetch_finnhub_company_news
    else:
        fetch = _fetch_fd_company_news
    all_news = fetch(ticker, start_date, end_date, limit, api_key)

    if not all_news:
        return []

    # Persist for subsequent identical queries, then return.
    _cache.set_company_news(
        cache_key,
        [news.model_dump() for news in all_news],
    )
    return all_news
|
||||
|
||||
|
||||
def _convert_finnhub_insider_trade(ticker: str, trade: dict) -> InsiderTrade:
    """Convert Finnhub insider trade format to InsiderTrade model.

    ``share`` is treated here as the post-transaction holding and
    ``change`` as the signed share delta; fields Finnhub does not
    supply are set to None.
    """
    shares_after = trade.get("share", 0)
    change = trade.get("change", 0)

    return InsiderTrade(
        ticker=ticker,
        issuer=None,
        name=trade.get("name", ""),
        title=None,
        is_board_director=None,
        transaction_date=trade.get("transactionDate", ""),
        # Magnitude only; direction can be read off the before/after counts.
        transaction_shares=abs(change),
        transaction_price_per_share=trade.get("transactionPrice", 0.0),
        transaction_value=abs(change) * trade.get("transactionPrice", 0.0),
        # NOTE(review): a zero `change` or zero holding yields None here
        # even though "after minus change" would still be well-defined —
        # confirm this truthiness check is intentional.
        shares_owned_before_transaction=(
            shares_after - change if shares_after and change else None
        ),
        shares_owned_after_transaction=float(shares_after)
        if shares_after
        else None,
        security_title=None,
        filing_date=trade.get("filingDate", ""),
    )
|
||||
|
||||
|
||||
def get_market_cap(ticker: str, end_date: str) -> float | None:
    """Fetch market cap from the API. Finnhub values are converted from millions.

    Returns None when no data is available for the ticker/date.
    """
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    # For today's date, use company facts API
    # NOTE(review): this branch always queries financialdatasets.ai
    # regardless of the configured data source — confirm intended.
    if end_date == datetime.datetime.now().strftime("%Y-%m-%d"):
        headers = {"X-API-KEY": api_key}
        url = (
            f"https://api.financialdatasets.ai/company/facts/?ticker={ticker}"
        )
        response = _make_api_request(url, headers)
        if response.status_code != 200:
            return None

        data = response.json()
        response_model = CompanyFactsResponse(**data)
        return response_model.company_facts.market_cap

    # Historical dates: fall back to the most recent financial metrics.
    financial_metrics = get_financial_metrics(ticker, end_date)
    if not financial_metrics:
        return None

    market_cap = financial_metrics[0].market_cap
    if not market_cap:
        return None

    # Finnhub returns market cap in millions
    if data_source == "finnhub":
        market_cap = market_cap * 1_000_000

    return market_cap
|
||||
|
||||
|
||||
def prices_to_df(prices: list[Price]) -> pd.DataFrame:
    """Convert Price models to a DataFrame indexed and sorted by date.

    Price columns are coerced to numeric; unparsable values become NaN.
    """
    frame = pd.DataFrame([price.model_dump() for price in prices])
    frame["Date"] = pd.to_datetime(frame["time"])
    frame.set_index("Date", inplace=True)
    for column in ("open", "close", "high", "low", "volume"):
        frame[column] = pd.to_numeric(frame[column], errors="coerce")
    frame.sort_index(inplace=True)
    return frame
|
||||
4
backend/utils/__init__.py
Normal file
4
backend/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# This file can be empty
|
||||
|
||||
"""Utility modules for the application."""
|
||||
449
backend/utils/analyst_tracker.py
Normal file
449
backend/utils/analyst_tracker.py
Normal file
@@ -0,0 +1,449 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Analyst Performance Tracker
|
||||
Tracks analyst predictions and calculates win rates for leaderboard
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnalystPerformanceTracker:
    """
    Tracks analyst predictions and evaluates accuracy.

    Workflow:
        1. Record analyst predictions for each ticker before market close
        2. After market close, evaluate predictions against actual returns
        3. Update leaderboard with win rates and statistics
    """

    def __init__(self):
        # Current day's predictions:
        # {analyst_id: {ticker: "long" | "short" | "hold"}}
        self.daily_predictions = {}

    def record_analyst_predictions(
        self,
        final_predictions: List[Dict[str, Any]],
    ):
        """
        Record predictions from analysts for the current trading day.

        Any previously recorded predictions are discarded.

        Args:
            final_predictions: List of structured prediction results
                Format: [
                    {
                        'agent': 'analyst_name',
                        'predictions': [
                            {'ticker': 'AAPL',
                             'direction': 'up',
                             'confidence': 0.75},
                            ...
                        ]
                    },
                    ...
                ]
        """
        self.daily_predictions = {}

        # Analyst "direction" vocabulary -> internal signal vocabulary.
        direction_mapping = {
            "up": "long",
            "down": "short",
            "neutral": "hold",
        }

        for result in final_predictions:
            analyst_id = result.get("agent")
            if not analyst_id:
                # Unattributed results cannot be scored; skip them.
                continue

            predictions = result.get("predictions", [])

            self.daily_predictions[analyst_id] = {}

            for pred in predictions:
                ticker = pred.get("ticker")
                direction = pred.get("direction", "neutral")

                if ticker:
                    # Unrecognized directions default to "hold".
                    signal = direction_mapping.get(direction, "hold")
                    self.daily_predictions[analyst_id][ticker] = signal

    def evaluate_predictions(
        self,
        open_prices: Optional[Dict[str, float]],
        close_prices: Dict[str, float],
        date: str,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Evaluate analyst predictions against actual market moves.

        A "long" call is correct when close > open; a "short" call is
        correct when close < open; "hold" calls are never scored.
        Tickers with missing or non-positive prices are counted as
        "unknown" and excluded from the win rate.

        Args:
            open_prices: Opening prices for each ticker (may be None)
            close_prices: Closing prices for each ticker
            date: Trading date string (YYYY-MM-DD)

        Returns:
            Dict mapping analyst_id to evaluation results
        """
        evaluation_results = {}

        # Bug fix: open_prices is declared Optional — treat None the
        # same as "no prices available" instead of raising
        # AttributeError on .get() (evaluate_pm_decisions already
        # guards against this; this method did not).
        open_prices = open_prices or {}

        # Map internal signal types to frontend display names
        signal_display_map = {
            "long": "bull",
            "short": "bear",
            "hold": "neutral",
        }

        for analyst_id, predictions in self.daily_predictions.items():
            correct_long = 0
            correct_short = 0
            incorrect_long = 0
            incorrect_short = 0
            unknown_long = 0
            unknown_short = 0
            hold_count = 0

            # Individual signal records for frontend display
            individual_signals: List[Dict[str, Any]] = []

            for ticker, prediction in predictions.items():
                open_price = open_prices.get(ticker, 0)
                close_price = close_prices.get(ticker, 0)

                signal_type = signal_display_map.get(prediction, "neutral")

                # Cannot evaluate if prices are missing
                if open_price <= 0 or close_price <= 0:
                    if prediction == "long":
                        unknown_long += 1
                    elif prediction == "short":
                        unknown_short += 1

                    individual_signals.append(
                        {
                            "ticker": ticker,
                            "signal": signal_type,
                            "date": date,
                            "is_correct": "unknown",
                        },
                    )
                    continue

                actual_return = (close_price - open_price) / open_price

                if prediction == "long":
                    is_correct = actual_return > 0
                    if is_correct:
                        correct_long += 1
                    else:
                        incorrect_long += 1

                    individual_signals.append(
                        {
                            "ticker": ticker,
                            "signal": signal_type,
                            "date": date,
                            "is_correct": is_correct,
                        },
                    )

                elif prediction == "short":
                    is_correct = actual_return < 0
                    if is_correct:
                        correct_short += 1
                    else:
                        incorrect_short += 1

                    individual_signals.append(
                        {
                            "ticker": ticker,
                            "signal": signal_type,
                            "date": date,
                            "is_correct": is_correct,
                        },
                    )

                elif prediction == "hold":
                    hold_count += 1
                    individual_signals.append(
                        {
                            "ticker": ticker,
                            "signal": signal_type,
                            "date": date,
                            "is_correct": None,
                        },
                    )

            total_long = correct_long + incorrect_long + unknown_long
            total_short = correct_short + incorrect_short + unknown_short
            evaluated_long = correct_long + incorrect_long
            evaluated_short = correct_short + incorrect_short
            total_evaluated = evaluated_long + evaluated_short
            correct_predictions = correct_long + correct_short

            # win_rate is None (not 0) when nothing could be evaluated.
            win_rate = (
                correct_predictions / total_evaluated
                if total_evaluated > 0
                else None
            )

            evaluation_results[analyst_id] = {
                "total_predictions": total_evaluated,
                "correct_predictions": correct_predictions,
                "win_rate": win_rate,
                "bull": {
                    "n": total_long,
                    "win": correct_long,
                    "unknown": unknown_long,
                },
                "bear": {
                    "n": total_short,
                    "win": correct_short,
                    "unknown": unknown_short,
                },
                "hold": hold_count,
                "signals": individual_signals,
            }

        return evaluation_results

    def clear_daily_predictions(self):
        """Clear predictions after evaluation"""
        self.daily_predictions = {}

    def _process_single_pm_decision(
        self,
        _ticker: str,
        decision: Dict,
        open_price: float,
        close_price: float,
        _date: str,
    ) -> Tuple[str, Optional[bool], str]:
        """
        Process a single PM decision and evaluate correctness.

        Returns:
            Tuple of (prediction, is_correct, signal_type), where
            is_correct is None for "hold" decisions or invalid prices.
        """
        action = decision.get("action", "hold")

        # Convert action to prediction format
        if action in ["buy", "long"]:
            prediction = "long"
        elif action in ["sell", "short"]:
            prediction = "short"
        else:
            prediction = "hold"

        signal_display_map = {
            "long": "bull",
            "short": "bear",
            "hold": "neutral",
        }
        signal_type = signal_display_map.get(prediction, "neutral")

        # Handle invalid prices
        if open_price <= 0 or close_price <= 0:
            return prediction, None, signal_type

        # Evaluate correctness
        actual_return = (close_price - open_price) / open_price

        if prediction == "long":
            is_correct = actual_return > 0
        elif prediction == "short":
            is_correct = actual_return < 0
        else:  # hold
            is_correct = None

        return prediction, is_correct, signal_type

    def evaluate_pm_decisions(
        self,
        pm_decisions: Dict[str, Dict],
        open_prices: Optional[Dict[str, float]],
        close_prices: Dict[str, float],
        date: str,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Evaluate PM's trading decisions against actual market moves.

        Args:
            pm_decisions: PM decisions {ticker: {action, quantity, ...}}
            open_prices: Opening prices for each ticker
            close_prices: Closing prices for each ticker
            date: Trading date string (YYYY-MM-DD)

        Returns:
            Dict with 'portfolio_manager' key containing evaluation
            results, or an empty dict when any input is missing.
        """
        if not pm_decisions or not open_prices or not close_prices:
            return {}

        correct_long = 0
        correct_short = 0
        incorrect_long = 0
        incorrect_short = 0
        unknown_long = 0
        unknown_short = 0
        hold_count = 0

        individual_signals: List[Dict[str, Any]] = []

        for ticker, decision in pm_decisions.items():
            open_price = open_prices.get(ticker, 0)
            close_price = close_prices.get(ticker, 0)

            (
                prediction,
                is_correct,
                signal_type,
            ) = self._process_single_pm_decision(
                ticker,
                decision,
                open_price,
                close_price,
                date,
            )

            # Missing prices: count long/short signals as "unknown".
            if is_correct is None and (open_price <= 0 or close_price <= 0):
                if prediction == "long":
                    unknown_long += 1
                elif prediction == "short":
                    unknown_short += 1
                individual_signals.append(
                    {
                        "ticker": ticker,
                        "signal": signal_type,
                        "date": date,
                        "is_correct": "unknown",
                    },
                )
            elif prediction == "hold":
                hold_count += 1
                individual_signals.append(
                    {
                        "ticker": ticker,
                        "signal": signal_type,
                        "date": date,
                        "is_correct": None,
                    },
                )
            else:
                if prediction == "long":
                    if is_correct:
                        correct_long += 1
                    else:
                        incorrect_long += 1
                else:
                    if is_correct:
                        correct_short += 1
                    else:
                        incorrect_short += 1

                individual_signals.append(
                    {
                        "ticker": ticker,
                        "signal": signal_type,
                        "date": date,
                        "is_correct": is_correct,
                    },
                )

        total_long = correct_long + incorrect_long + unknown_long
        total_short = correct_short + incorrect_short + unknown_short
        evaluated_long = correct_long + incorrect_long
        evaluated_short = correct_short + incorrect_short
        total_evaluated = evaluated_long + evaluated_short
        correct_predictions = correct_long + correct_short

        win_rate = (
            correct_predictions / total_evaluated
            if total_evaluated > 0
            else None
        )

        return {
            "portfolio_manager": {
                "total_predictions": total_evaluated,
                "correct_predictions": correct_predictions,
                "win_rate": win_rate,
                "bull": {
                    "n": total_long,
                    "win": correct_long,
                    "unknown": unknown_long,
                },
                "bear": {
                    "n": total_short,
                    "win": correct_short,
                    "unknown": unknown_short,
                },
                "hold": hold_count,
                "signals": individual_signals,
            },
        }
|
||||
|
||||
|
||||
def update_leaderboard_with_evaluations(
    leaderboard: List[Dict[str, Any]],
    evaluations: Dict[str, Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Update leaderboard with new evaluation results.

    Mutates the entries in place (aggregate counts, win rate, recent
    signals, rank) and returns the same list.

    Args:
        leaderboard: Current leaderboard data
        evaluations: Evaluation results for the day

    Returns:
        Updated leaderboard
    """
    for entry in leaderboard:
        agent_id = entry.get("agentId")
        if not agent_id or agent_id not in evaluations:
            continue

        eval_result = evaluations[agent_id]

        # Update aggregate stats
        entry["bull"]["n"] += eval_result["bull"]["n"]
        entry["bull"]["win"] += eval_result["bull"]["win"]
        entry["bull"]["unknown"] = (
            entry["bull"].get("unknown", 0) + eval_result["bull"]["unknown"]
        )
        entry["bear"]["n"] += eval_result["bear"]["n"]
        entry["bear"]["win"] += eval_result["bear"]["win"]
        entry["bear"]["unknown"] = (
            entry["bear"].get("unknown", 0) + eval_result["bear"]["unknown"]
        )

        # Calculate win rate based on evaluated signals only
        # evaluated = total - unknown
        evaluated_bull = entry["bull"]["n"] - entry["bull"]["unknown"]
        evaluated_bear = entry["bear"]["n"] - entry["bear"]["unknown"]
        total_evaluated = evaluated_bull + evaluated_bear
        total_wins = entry["bull"]["win"] + entry["bear"]["win"]

        if total_evaluated > 0:
            entry["winRate"] = round(total_wins / total_evaluated, 4)

        # Add individual signal records
        if "signals" not in entry:
            entry["signals"] = []

        entry["signals"].extend(eval_result.get("signals", []))

        # Keep only recent signals (e.g., last 100 individual signals)
        entry["signals"] = entry["signals"][-100:]

    # Re-rank analysts by win rate (rank starts from 1).
    # Fix: `or 0` also covers entries whose winRate is present but None
    # (never evaluated) — `get("winRate", 0)` would return None and make
    # the sort raise TypeError.
    analyst_entries = [e for e in leaderboard if e.get("rank") is not None]
    analyst_entries.sort(key=lambda e: e.get("winRate") or 0, reverse=True)
    for idx, entry in enumerate(analyst_entries):
        entry["rank"] = idx + 1  # Rank 1 = highest win rate (gold medal)

    return leaderboard
|
||||
405
backend/utils/baselines.py
Normal file
405
backend/utils/baselines.py
Normal file
@@ -0,0 +1,405 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Baseline Strategy Calculators
|
||||
Tracks performance of simple baseline strategies for comparison
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Portfolio(TypedDict):
    """State of a baseline portfolio: uninvested cash plus open positions."""

    # Uninvested cash remaining after purchases.
    cash: float
    # Ticker symbol -> number of shares held (fractional shares allowed).
    positions: Dict[str, float]
|
||||
|
||||
|
||||
class BaselineCalculator:
|
||||
"""
|
||||
Calculates baseline strategy returns for comparison
|
||||
|
||||
Strategies:
|
||||
1. Equal-weight: Allocate equal weight to all tickers
|
||||
2. Market-cap-weighted: Allocate proportional to market cap
|
||||
3. Simple momentum: Monthly rebalance,
|
||||
long top 50% momentum, short bottom 50%
|
||||
"""
|
||||
|
||||
def __init__(self, initial_capital: float = 100000.0):
|
||||
self.initial_capital = initial_capital
|
||||
|
||||
self.equal_weight_portfolio: Portfolio = {"cash": 0.0, "positions": {}}
|
||||
self.market_cap_portfolio: Portfolio = {"cash": 0.0, "positions": {}}
|
||||
self.momentum_portfolio: Portfolio = {
|
||||
"cash": initial_capital,
|
||||
"positions": {},
|
||||
}
|
||||
|
||||
self.equal_weight_initialized = False
|
||||
self.market_cap_initialized = False
|
||||
self.momentum_last_rebalance_date = None
|
||||
|
||||
def calculate_equal_weight_value(
|
||||
self,
|
||||
tickers: List[str],
|
||||
open_prices: Dict[str, float],
|
||||
close_prices: Dict[str, float],
|
||||
) -> float:
|
||||
"""
|
||||
Calculate equal-weight portfolio value
|
||||
|
||||
On first call, initialize positions with equal allocation using
|
||||
open prices. Subsequently, mark-to-market existing positions
|
||||
using close prices.
|
||||
|
||||
Args:
|
||||
tickers: List of stock tickers
|
||||
open_prices: Opening prices (used for initial purchase)
|
||||
close_prices: Closing prices (used for valuation)
|
||||
"""
|
||||
if not self.equal_weight_initialized:
|
||||
allocation_per_ticker = self.initial_capital / len(tickers)
|
||||
self.equal_weight_portfolio["cash"] = 0.0
|
||||
for ticker in tickers:
|
||||
price = open_prices.get(ticker, 0) # Use OPEN price for buying
|
||||
if price > 0:
|
||||
shares = allocation_per_ticker / price
|
||||
self.equal_weight_portfolio["positions"][ticker] = shares
|
||||
logger.info(
|
||||
f"Equal Weight: Initialized {ticker} with "
|
||||
f"{shares:.2f} shares @ ${price:.2f} (open)",
|
||||
)
|
||||
self.equal_weight_initialized = True
|
||||
|
||||
total_value = self.equal_weight_portfolio["cash"]
|
||||
positions: Dict[str, float] = self.equal_weight_portfolio["positions"]
|
||||
for ticker, shares in positions.items():
|
||||
price = close_prices.get(ticker, 0)
|
||||
total_value += shares * price
|
||||
|
||||
return total_value
|
||||
|
||||
def calculate_market_cap_weighted_value(
|
||||
self,
|
||||
tickers: List[str],
|
||||
open_prices: Dict[str, float],
|
||||
close_prices: Dict[str, float],
|
||||
market_caps: Dict[str, float],
|
||||
) -> float:
|
||||
"""
|
||||
Calculate market-cap-weighted portfolio value
|
||||
|
||||
On first call, initialize positions weighted by market cap using
|
||||
open prices. Subsequently, mark-to-market existing positions
|
||||
using close prices.
|
||||
|
||||
Args:
|
||||
tickers: List of stock tickers
|
||||
open_prices: Opening prices (used for initial purchase)
|
||||
close_prices: Closing prices (used for valuation)
|
||||
market_caps: Market capitalization for each ticker
|
||||
"""
|
||||
if not self.market_cap_initialized:
|
||||
total_market_cap = sum(market_caps.get(t, 0) for t in tickers)
|
||||
if total_market_cap <= 0:
|
||||
logger.warning("No market cap data, using equal weight")
|
||||
return self.calculate_equal_weight_value(
|
||||
tickers,
|
||||
open_prices,
|
||||
close_prices,
|
||||
)
|
||||
|
||||
self.market_cap_portfolio["cash"] = 0.0
|
||||
for ticker in tickers:
|
||||
market_cap = market_caps.get(ticker, 0)
|
||||
price = open_prices.get(ticker, 0) # Use OPEN price for buying
|
||||
if market_cap > 0 and price > 0:
|
||||
weight = market_cap / total_market_cap
|
||||
allocation = self.initial_capital * weight
|
||||
shares = allocation / price
|
||||
self.market_cap_portfolio["positions"][ticker] = shares
|
||||
logger.info(
|
||||
f"Market Cap Weighted: Initialized {ticker} with "
|
||||
f"{shares:.2f} shares @ ${price:.2f} (open), "
|
||||
f"weight={weight:.2%}",
|
||||
)
|
||||
self.market_cap_initialized = True
|
||||
|
||||
total_value = self.market_cap_portfolio["cash"]
|
||||
positions: Dict[str, float] = self.market_cap_portfolio["positions"]
|
||||
for ticker, shares in positions.items():
|
||||
price = close_prices.get(ticker, 0)
|
||||
total_value += shares * price
|
||||
|
||||
return total_value
|
||||
|
||||
def calculate_momentum_value(
|
||||
self,
|
||||
tickers: List[str],
|
||||
open_prices: Dict[str, float],
|
||||
close_prices: Dict[str, float],
|
||||
momentum_scores: Dict[str, float],
|
||||
date: str,
|
||||
rebalance: bool = False,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate momentum strategy portfolio value
|
||||
|
||||
Strategy: Monthly rebalance
|
||||
- Long top 50% momentum stocks
|
||||
- Short bottom 50% momentum stocks (if shorting enabled)
|
||||
- Equal weight within each group
|
||||
|
||||
Args:
|
||||
tickers: List of tickers
|
||||
open_prices: Opening prices (used for rebalancing trades)
|
||||
close_prices: Closing prices (used for valuation)
|
||||
momentum_scores: Momentum scores for each ticker
|
||||
date: Current date (YYYY-MM-DD)
|
||||
rebalance: Force rebalance if True
|
||||
"""
|
||||
should_rebalance = rebalance
|
||||
if self.momentum_last_rebalance_date is None:
|
||||
should_rebalance = True
|
||||
elif not rebalance:
|
||||
last_date = datetime.strptime(
|
||||
self.momentum_last_rebalance_date,
|
||||
"%Y-%m-%d",
|
||||
)
|
||||
current_date = datetime.strptime(date, "%Y-%m-%d")
|
||||
if (current_date.year, current_date.month) != (
|
||||
last_date.year,
|
||||
last_date.month,
|
||||
):
|
||||
should_rebalance = True
|
||||
|
||||
if should_rebalance:
|
||||
self._rebalance_momentum_portfolio(
|
||||
tickers,
|
||||
open_prices,
|
||||
momentum_scores,
|
||||
)
|
||||
self.momentum_last_rebalance_date = date
|
||||
|
||||
total_value = self.momentum_portfolio["cash"]
|
||||
positions: Dict[str, float] = self.momentum_portfolio["positions"]
|
||||
for ticker, shares in positions.items():
|
||||
price = close_prices.get(ticker, 0)
|
||||
total_value += shares * price
|
||||
|
||||
return total_value
|
||||
|
||||
def _rebalance_momentum_portfolio(
|
||||
self,
|
||||
tickers: List[str],
|
||||
prices: Dict[str, float],
|
||||
momentum_scores: Dict[str, float],
|
||||
):
|
||||
"""Rebalance momentum portfolio based on current momentum scores"""
|
||||
current_value = self.momentum_portfolio["cash"]
|
||||
for ticker, shares in self.momentum_portfolio["positions"].items():
|
||||
price = prices.get(ticker, 0)
|
||||
current_value += shares * price
|
||||
|
||||
self.momentum_portfolio["positions"] = {}
|
||||
|
||||
sorted_tickers = sorted(
|
||||
tickers,
|
||||
key=lambda t: momentum_scores.get(t, 0),
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
mid_point = len(sorted_tickers) // 2
|
||||
long_tickers = (
|
||||
sorted_tickers[:mid_point] if mid_point > 0 else sorted_tickers
|
||||
)
|
||||
|
||||
if len(long_tickers) == 0:
|
||||
self.momentum_portfolio["cash"] = current_value
|
||||
return
|
||||
|
||||
allocation_per_ticker = current_value / len(long_tickers)
|
||||
used_capital = 0.0
|
||||
|
||||
for ticker in long_tickers:
|
||||
price = prices.get(ticker, 0)
|
||||
if price > 0:
|
||||
shares = allocation_per_ticker / price
|
||||
self.momentum_portfolio["positions"][ticker] = shares
|
||||
used_capital += allocation_per_ticker
|
||||
|
||||
self.momentum_portfolio["cash"] = current_value - used_capital
|
||||
|
||||
def get_all_baseline_values(
|
||||
self,
|
||||
tickers: List[str],
|
||||
open_prices: Dict[str, float],
|
||||
close_prices: Dict[str, float],
|
||||
market_caps: Dict[str, float],
|
||||
momentum_scores: Dict[str, float],
|
||||
date: str,
|
||||
rebalance_momentum: bool = False,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Get all baseline portfolio values in one call
|
||||
|
||||
Args:
|
||||
tickers: List of stock tickers
|
||||
open_prices: Opening prices (used for initial purchase/rebalancing)
|
||||
close_prices: Closing prices (used for valuation)
|
||||
market_caps: Market caps for each ticker
|
||||
momentum_scores: Momentum scores for rebalancing
|
||||
date: Current date
|
||||
rebalance_momentum: Whether to rebalance momentum portfolio
|
||||
|
||||
Returns:
|
||||
Dict with keys: equal_weight, market_cap_weighted, momentum
|
||||
"""
|
||||
equal_weight_value = self.calculate_equal_weight_value(
|
||||
tickers,
|
||||
open_prices,
|
||||
close_prices,
|
||||
)
|
||||
market_cap_value = self.calculate_market_cap_weighted_value(
|
||||
tickers,
|
||||
open_prices,
|
||||
close_prices,
|
||||
market_caps,
|
||||
)
|
||||
momentum_value = self.calculate_momentum_value(
|
||||
tickers,
|
||||
open_prices,
|
||||
close_prices,
|
||||
momentum_scores,
|
||||
date,
|
||||
rebalance_momentum,
|
||||
)
|
||||
|
||||
return {
|
||||
"equal_weight": equal_weight_value,
|
||||
"market_cap_weighted": market_cap_value,
|
||||
"momentum": momentum_value,
|
||||
}
|
||||
|
||||
def export_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Export calculator state for persistence
|
||||
|
||||
Returns:
|
||||
Dictionary containing all portfolio states for serialization
|
||||
"""
|
||||
return {
|
||||
"baseline_state": {
|
||||
"initialized": self.equal_weight_initialized,
|
||||
"initial_allocation": dict(
|
||||
self.equal_weight_portfolio["positions"],
|
||||
),
|
||||
},
|
||||
"baseline_vw_state": {
|
||||
"initialized": self.market_cap_initialized,
|
||||
"initial_allocation": dict(
|
||||
self.market_cap_portfolio["positions"],
|
||||
),
|
||||
},
|
||||
"momentum_state": {
|
||||
"positions": dict(self.momentum_portfolio["positions"]),
|
||||
"cash": self.momentum_portfolio["cash"],
|
||||
"initialized": self.momentum_last_rebalance_date is not None,
|
||||
"last_rebalance_date": self.momentum_last_rebalance_date,
|
||||
},
|
||||
}
|
||||
|
||||
def load_state(self, state: Dict[str, Any]):
|
||||
"""
|
||||
Load calculator state from persistence
|
||||
|
||||
Args:
|
||||
state: Dictionary containing baseline_state, baseline_vw_state,
|
||||
momentum_state from storage
|
||||
"""
|
||||
# Load equal-weight state
|
||||
baseline_state = state.get("baseline_state", {})
|
||||
if baseline_state.get("initialized", False):
|
||||
self.equal_weight_initialized = True
|
||||
self.equal_weight_portfolio["positions"] = dict(
|
||||
baseline_state.get("initial_allocation", {}),
|
||||
)
|
||||
self.equal_weight_portfolio["cash"] = 0.0
|
||||
logger.info(
|
||||
f"Restored equal-weight portfolio with "
|
||||
f"{len(self.equal_weight_portfolio['positions'])} positions",
|
||||
)
|
||||
|
||||
# Load market-cap-weighted state
|
||||
baseline_vw_state = state.get("baseline_vw_state", {})
|
||||
if baseline_vw_state.get("initialized", False):
|
||||
self.market_cap_initialized = True
|
||||
self.market_cap_portfolio["positions"] = dict(
|
||||
baseline_vw_state.get("initial_allocation", {}),
|
||||
)
|
||||
self.market_cap_portfolio["cash"] = 0.0
|
||||
logger.info(
|
||||
f"Restored market-cap portfolio with "
|
||||
f"{len(self.market_cap_portfolio['positions'])} positions",
|
||||
)
|
||||
|
||||
# Load momentum state
|
||||
momentum_state = state.get("momentum_state", {})
|
||||
if momentum_state.get("initialized", False):
|
||||
self.momentum_portfolio["positions"] = dict(
|
||||
momentum_state.get("positions", {}),
|
||||
)
|
||||
self.momentum_portfolio["cash"] = momentum_state.get(
|
||||
"cash",
|
||||
self.initial_capital,
|
||||
)
|
||||
self.momentum_last_rebalance_date = momentum_state.get(
|
||||
"last_rebalance_date",
|
||||
)
|
||||
logger.info(
|
||||
f"Restored momentum portfolio with "
|
||||
f"{len(self.momentum_portfolio['positions'])} positions, "
|
||||
f"last rebalance: {self.momentum_last_rebalance_date}",
|
||||
)
|
||||
|
||||
|
||||
def calculate_momentum_scores(
    tickers: List[str],
    prices_history: Dict[str, List[Tuple[str, float]]],
    lookback_days: int = 20,
) -> Dict[str, float]:
    """
    Calculate momentum scores for tickers.

    Args:
        tickers: List of tickers
        prices_history: Dict mapping ticker to list of (date, price) tuples
        lookback_days: Number of days to calculate momentum

    Returns:
        Dict mapping ticker to momentum score (percentage return)
    """
    scores: Dict[str, float] = {}

    for ticker in tickers:
        # Chronological order via the ISO date string in each tuple.
        history = sorted(
            prices_history.get(ticker, []),
            key=lambda entry: entry[0],
        )

        if len(history) < 2:
            # Not enough observations to measure a return.
            scores[ticker] = 0.0
            continue

        # Use the full lookback window when available; otherwise fall back
        # to the oldest observation on record.
        start_index = -lookback_days if len(history) >= lookback_days else 0
        start_price = history[start_index][1]
        end_price = history[-1][1]

        # Guard against a zero/negative start price (division by zero).
        scores[ticker] = (
            (end_price - start_price) / start_price
            if start_price > 0
            else 0.0
        )

    return scores
|
||||
321
backend/utils/msg_adapter.py
Normal file
321
backend/utils/msg_adapter.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Message Adapter - Converts AgentScope Msg to frontend JSON format
|
||||
Ensures compatibility with existing frontend without modifications
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrontendAdapter:
    """
    Adapter to convert AgentScope messages to frontend-compatible format

    Frontend expects specific message types:
    - agent: Agent thinking/analysis messages
    - team_summary: Portfolio summary with equity curves
    - team_holdings: Current portfolio holdings
    - team_stats: Portfolio statistics
    - team_trades: Trade history
    - team_leaderboard: Agent performance rankings
    - price_update: Real-time price updates
    - system: System notifications

    All methods are static: the adapter is stateless and keys every output
    with a fresh wall-clock ISO timestamp.
    """

    @staticmethod
    def parse(msg: Msg) -> Optional[Dict[str, Any]]:
        """
        Parse AgentScope Msg to frontend format

        Args:
            msg: AgentScope Msg object

        Returns:
            Dictionary in frontend format, or None if message should be skipped
        """
        if msg is None:
            return None

        # Determine message type based on metadata or content
        msg_type = FrontendAdapter._determine_type(msg)

        if msg_type == "agent":
            return FrontendAdapter._format_agent_msg(msg)
        elif msg_type == "portfolio_update":
            return FrontendAdapter._format_portfolio_msg(msg)
        elif msg_type == "system":
            return FrontendAdapter._format_system_msg(msg)
        else:
            # Default: treat as agent message (covers any unknown explicit
            # "type" value found in metadata)
            return FrontendAdapter._format_agent_msg(msg)

    @staticmethod
    def _determine_type(msg: Msg) -> str:
        """Determine frontend message type from Msg"""
        # Check metadata for explicit type; an explicit "type" key wins
        # over the "portfolio" heuristic below
        if hasattr(msg, "metadata") and msg.metadata:
            if "type" in msg.metadata:
                return msg.metadata["type"]

            # Check if message contains portfolio update
            if "portfolio" in msg.metadata:
                return "portfolio_update"

        # Check message name/role
        if msg.name == "system":
            return "system"

        # Default to agent message
        return "agent"

    @staticmethod
    def _format_agent_msg(msg: object) -> Dict[str, Any]:
        """
        Format agent message for frontend

        Args:
            msg: Either AgentScope Msg or dict from pipeline results

        Frontend expects:
        {
            "type": "agent",
            "role_key": "analyst_id",
            "content": "message text",
            "timestamp": "ISO timestamp"
        }
        """
        # Handle dict from pipeline results
        if isinstance(msg, dict):
            name = msg.get("agent", "unknown")
            content = msg.get("content", "")
        else:
            # Handle Msg object
            name = msg.name
            content = msg.content

        # Non-string content (e.g. structured tool output) is serialized
        # to JSON so the frontend always receives a string
        return {
            "type": "agent",
            "role_key": name,
            "content": content
            if isinstance(content, str)
            else json.dumps(content),
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def _format_portfolio_msg(msg: Msg) -> Dict[str, Any]:
        """
        Format portfolio update message

        This typically generates multiple frontend messages:
        - team_summary
        - team_holdings
        - team_stats
        - team_trades (if trades were executed)

        Returns a single {"type": "composite", "messages": [...]} envelope
        so the caller can fan the inner messages out to the frontend.
        """
        metadata = msg.metadata or {}
        portfolio = metadata.get("portfolio", {})

        messages: List[Dict[str, Any]] = []

        # Generate holdings message
        holdings = FrontendAdapter.build_holdings(portfolio)
        if holdings:
            messages.append(
                {
                    "type": "team_holdings",
                    "data": holdings,
                    "timestamp": datetime.now().isoformat(),
                },
            )

        # Generate stats message
        stats = FrontendAdapter.build_stats(portfolio)
        if stats:
            messages.append(
                {
                    "type": "team_stats",
                    "data": stats,
                    "timestamp": datetime.now().isoformat(),
                },
            )

        # Generate trades message if execution logs exist
        execution_logs = metadata.get("execution_logs", [])
        if execution_logs:
            trades = FrontendAdapter.build_trades(execution_logs)
            messages.append(
                {
                    "type": "team_trades",
                    "mode": "incremental",
                    "data": trades,
                    "timestamp": datetime.now().isoformat(),
                },
            )

        # Return composite message
        return {
            "type": "composite",
            "messages": messages,
        }

    @staticmethod
    def _format_system_msg(msg: Msg) -> Dict[str, Any]:
        """Format system message"""
        return {
            "type": "system",
            "content": msg.content
            if isinstance(msg.content, str)
            else json.dumps(msg.content),
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def build_holdings(
        portfolio: Dict[str, Any],
        prices: Optional[Dict[str, float]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Build holdings array from portfolio state.

        Each position dict is expected to carry "long"/"short" share counts
        and an "avg_price". When a live price is missing (or falsy), the
        position's average price is used as a fallback for valuation.
        """
        holdings = []
        prices = prices or {}

        positions = portfolio.get("positions", {})
        cash = portfolio.get("cash", 0.0)

        # Calculate total value using current prices
        total_value = cash
        for ticker, position in positions.items():
            long_shares = position.get("long", 0)
            short_shares = position.get("short", 0)
            price = prices.get(ticker) or position.get("avg_price", 0)
            total_value += (long_shares - short_shares) * price

        # Build holdings for each position
        for ticker, position in positions.items():
            long_shares = position.get("long", 0)
            short_shares = position.get("short", 0)
            avg_price = position.get("avg_price", 0)
            current_price = prices.get(ticker) or avg_price

            net_shares = long_shares - short_shares
            if net_shares == 0:
                # Fully hedged / flat positions are omitted from the view
                continue

            market_value = net_shares * current_price
            weight = market_value / total_value if total_value > 0 else 0

            holdings.append(
                {
                    "ticker": ticker,
                    "quantity": net_shares,
                    "avg": avg_price,
                    "currentPrice": current_price,
                    "marketValue": market_value,
                    "weight": weight,
                },
            )

        # Add cash as a holding (quantity 1 at "price" == cash balance)
        if cash > 0:
            holdings.append(
                {
                    "ticker": "CASH",
                    "quantity": 1,
                    "avg": cash,
                    "currentPrice": cash,
                    "marketValue": cash,
                    "weight": cash / total_value if total_value > 0 else 0,
                },
            )

        return holdings

    @staticmethod
    def build_stats(
        portfolio: Dict[str, Any],
        prices: Optional[Dict[str, float]] = None,
    ) -> Dict[str, Any]:
        """
        Build stats dictionary from portfolio.

        Uses the same price-fallback rule as build_holdings: live price if
        available, otherwise the position's recorded average price.
        """
        prices = prices or {}
        positions = portfolio.get("positions", {})
        cash = portfolio.get("cash", 0.0)
        margin_used = portfolio.get("margin_used", 0.0)

        # Calculate total value using current prices
        total_value = cash
        for ticker, position in positions.items():
            long_shares = position.get("long", 0)
            short_shares = position.get("short", 0)
            price = prices.get(ticker) or position.get("avg_price", 0)
            total_value += (long_shares - short_shares) * price

        # Calculate ticker weights (flat positions excluded)
        ticker_weights = {}
        for ticker, position in positions.items():
            long_shares = position.get("long", 0)
            short_shares = position.get("short", 0)
            price = prices.get(ticker) or position.get("avg_price", 0)

            market_value = (long_shares - short_shares) * price
            if market_value != 0:
                ticker_weights[ticker] = (
                    market_value / total_value if total_value > 0 else 0
                )

        # Calculate total return as a percentage of initial capital
        initial_cash = portfolio.get("initial_cash", 100000.0)
        total_return = (
            ((total_value - initial_cash) / initial_cash * 100)
            if initial_cash > 0
            else 0.0
        )

        return {
            "totalAssetValue": round(total_value, 2),
            "totalReturn": round(total_return, 2),
            "cashPosition": round(cash, 2),
            "tickerWeights": ticker_weights,
            "marginUsed": round(margin_used, 2),
        }

    @staticmethod
    def build_trades(execution_logs: List[str]) -> List[Dict[str, Any]]:
        """
        Build trades array from execution logs

        Frontend expects:
        [{
            "ts": 1234567890,
            "ticker": "AAPL",
            "side": "LONG",
            "qty": 100,
            "price": 150.0,
            "reason": "Buy signal"
        }, ...]

        NOTE(review): this is a stub — only logs containing "Executed" are
        converted, and ticker/side/qty/price are placeholders with the raw
        log line carried in "reason". Replace with structured trade data.
        """
        trades = []
        # Single timestamp (ms) shared by all trades in this batch
        timestamp = int(datetime.now().timestamp() * 1000)

        for log in execution_logs:
            # Parse execution log (simplified - should use structured data)
            if "Executed" in log:
                # Extract trade details from log string
                # in real implementation, pass structured data
                trades.append(
                    {
                        "ts": timestamp,
                        "ticker": "UNKNOWN",  # Should parse from log
                        "side": "LONG",  # Should parse from log
                        "qty": 0,  # Should parse from log
                        "price": 0.0,  # Should parse from log
                        "reason": log,
                    },
                )

        return trades
|
||||
140
backend/utils/progress.py
Normal file
140
backend/utils/progress.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.style import Style
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class AgentProgress:
    """
    Manages progress tracking for multiple agents.

    Maintains a per-agent status dict, renders it as a live Rich table in
    the terminal, and notifies registered handlers on every update.
    """

    def __init__(self):
        # agent_name -> {"status": str, "ticker": str|None,
        #                "analysis": str (optional), "timestamp": str}
        self.agent_status = {}
        self.table = Table(show_header=False, box=None, padding=(0, 1))
        self.live = Live(self.table, console=console, refresh_per_second=4)
        self.started = False
        # Callbacks invoked from update_status with
        # (agent_name, ticker, status, analysis, timestamp).
        self.update_handlers = []

    def register_handler(
        self,
        handler: Callable[[str, Optional[str], str, Optional[str], str], None],
    ):
        """
        Register a handler to be called when agent status updates.

        Handlers receive (agent_name, ticker, status, analysis, timestamp);
        see update_status for the invocation.
        """
        self.update_handlers.append(handler)
        return handler  # Return handler to support use as decorator

    def unregister_handler(
        self,
        handler: Callable[[str, Optional[str], str, Optional[str], str], None],
    ):
        """Unregister a previously registered handler (no-op if absent)."""
        if handler in self.update_handlers:
            self.update_handlers.remove(handler)

    def start(self):
        """Start the progress display (idempotent)."""
        if not self.started:
            self.live.start()
            self.started = True

    def stop(self):
        """Stop the progress display (idempotent)."""
        if self.started:
            self.live.stop()
            self.started = False

    def update_status(
        self,
        agent_name: str,
        ticker: Optional[str] = None,
        status: str = "",
        analysis: Optional[str] = None,
    ):
        """
        Update the status of an agent.

        Falsy ticker/status/analysis arguments leave the corresponding
        previously-stored values untouched. Notifies all registered
        handlers, then refreshes the terminal table.
        """
        if agent_name not in self.agent_status:
            self.agent_status[agent_name] = {"status": "", "ticker": None}

        if ticker:
            self.agent_status[agent_name]["ticker"] = ticker
        if status:
            self.agent_status[agent_name]["status"] = status
        if analysis:
            self.agent_status[agent_name]["analysis"] = analysis

        # Set the timestamp as UTC datetime
        timestamp = datetime.now(timezone.utc).isoformat()
        self.agent_status[agent_name]["timestamp"] = timestamp

        # Notify all registered handlers
        for handler in self.update_handlers:
            handler(agent_name, ticker, status, analysis, timestamp)

        self._refresh_display()

    def get_all_status(self):
        """Get the current status of all agents as a dictionary."""
        return {
            agent_name: {
                "ticker": info["ticker"],
                "status": info["status"],
                "display_name": self._get_display_name(agent_name),
            }
            for agent_name, info in self.agent_status.items()
        }

    def _get_display_name(self, agent_name: str) -> str:
        """Convert agent_name to a display-friendly format."""
        return agent_name.replace("_agent", "").replace("_", " ").title()

    def _refresh_display(self):
        """Rebuild the Rich table from the current agent statuses."""
        self.table.columns.clear()
        self.table.add_column(width=100)

        # Sort Risk Management and Portfolio Management at the bottom;
        # everything else is alphabetical above them.
        def sort_key(item):
            agent_name = item[0]
            if "risk_manager" in agent_name:
                return (2, agent_name)
            elif "portfolio_manager" in agent_name:
                return (3, agent_name)
            else:
                return (1, agent_name)

        for agent_name, info in sorted(
            self.agent_status.items(),
            key=sort_key,
        ):
            status = info["status"]
            ticker = info["ticker"]
            # Create the status text with appropriate styling:
            # green check for done, red cross for error, yellow otherwise.
            if status.lower() == "done":
                style = Style(color="green", bold=True)
                symbol = "✓"
            elif status.lower() == "error":
                style = Style(color="red", bold=True)
                symbol = "✗"
            else:
                style = Style(color="yellow")
                symbol = "⋯"

            agent_display = self._get_display_name(agent_name)
            status_text = Text()
            status_text.append(f"{symbol} ", style=style)
            status_text.append(f"{agent_display:<20}", style=Style(bold=True))

            if ticker:
                status_text.append(f"[{ticker}] ", style=Style(color="cyan"))
            status_text.append(status, style=style)

            self.table.add_row(status_text)
|
||||
|
||||
|
||||
# Module-level singleton shared across the backend; import `progress`
# rather than instantiating AgentProgress directly.
progress = AgentProgress()
|
||||
362
backend/utils/settlement.py
Normal file
362
backend/utils/settlement.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Settlement Coordinator
|
||||
Unified daily settlement logic for agent portfolio, baselines, and analyst tracking
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.services.storage import StorageService
|
||||
from backend.utils.analyst_tracker import (
|
||||
AnalystPerformanceTracker,
|
||||
update_leaderboard_with_evaluations,
|
||||
)
|
||||
from backend.utils.baselines import (
|
||||
BaselineCalculator,
|
||||
calculate_momentum_scores,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SettlementCoordinator:
    """
    Coordinates daily settlement after market close

    Responsibilities:
    1. Calculate agent portfolio P&L
    2. Update baseline portfolios (equal-weight, market-cap, momentum)
    3. Evaluate analyst predictions and update leaderboard
    4. Update summary.json with all portfolio values
    5. Persist state to storage
    """

    def __init__(
        self,
        storage: "StorageService",
        initial_capital: float = 100000.0,
    ):
        """
        Args:
            storage: Storage backend used for persistence, leaderboard I/O,
                and portfolio valuation
            initial_capital: Starting cash for the baseline portfolios
        """
        self.storage = storage
        self.initial_capital = initial_capital
        self.baseline_calculator = BaselineCalculator(initial_capital)
        self.analyst_tracker = AnalystPerformanceTracker()

        # Per-ticker list of (date, price) tuples; capped at the last 60
        # entries in update_price_history() and used for momentum scores.
        self.price_history: Dict[str, List[tuple]] = {}

        # Load persisted state from storage
        self._load_persisted_state()

    def _load_persisted_state(self):
        """
        Load persisted baseline and price history state from storage

        This restores the baseline calculator state so that backtest/live mode
        can resume from where it left off.
        """
        internal_state = self.storage.load_internal_state()

        # Load baseline calculator state; missing keys default to empty
        # dicts, which load_state treats as "not initialized".
        baseline_state = {
            "baseline_state": internal_state.get("baseline_state", {}),
            "baseline_vw_state": internal_state.get("baseline_vw_state", {}),
            "momentum_state": internal_state.get("momentum_state", {}),
        }
        self.baseline_calculator.load_state(baseline_state)

        # Load price history for momentum calculation
        saved_price_history = internal_state.get("price_history", {})
        if saved_price_history:
            # Convert saved format back to list of tuples; accepts either
            # {"date": ..., "price": ...} dicts or 2-element sequences and
            # silently skips anything else.
            for ticker, history in saved_price_history.items():
                converted_history = []
                for entry in history:
                    if isinstance(entry, dict):
                        converted_history.append(
                            (entry["date"], entry["price"]),
                        )
                    elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
                        converted_history.append((entry[0], entry[1]))
                    else:
                        continue
                self.price_history[ticker] = converted_history
            logger.info(
                f"Restored price history for {len(self.price_history)} tickers",
            )

    def _save_persisted_state(self):
        """
        Save baseline and price history state to storage

        This persists the baseline calculator state so that backtest/live mode
        can resume from where it left off after restart.
        """
        # Read-modify-write: keep whatever other keys internal state holds.
        internal_state = self.storage.load_internal_state()

        # Export baseline calculator state
        baseline_state = self.baseline_calculator.export_state()
        internal_state["baseline_state"] = baseline_state["baseline_state"]
        internal_state["baseline_vw_state"] = baseline_state[
            "baseline_vw_state"
        ]
        internal_state["momentum_state"] = baseline_state["momentum_state"]

        # Save price history (convert tuples to dicts for JSON serialization)
        price_history_serializable = {}
        for ticker, history in self.price_history.items():
            price_history_serializable[ticker] = [
                {"date": date, "price": price} for date, price in history
            ]
        internal_state["price_history"] = price_history_serializable

        self.storage.save_internal_state(internal_state)
        logger.info("Persisted baseline calculator and price history state")

    def record_analyst_predictions(
        self,
        final_predictions: List[Dict[str, Any]],
    ):
        """
        Record structured analyst predictions before market close

        Args:
            final_predictions: Structured prediction results from analysts
                Format: [
                    {
                        'agent': 'analyst_name',
                        'predictions': [
                            {'ticker': 'AAPL', 'direction': 'up', 'confidence': 0.75},
                            ...
                        ]
                    },
                    ...
                ]
        """
        self.analyst_tracker.record_analyst_predictions(final_predictions)

    def update_price_history(
        self,
        date: str,
        prices: Dict[str, float],
    ):
        """
        Update price history for momentum calculation

        Args:
            date: Trading date (YYYY-MM-DD)
            prices: Current prices for each ticker
        """
        for ticker, price in prices.items():
            if ticker not in self.price_history:
                self.price_history[ticker] = []
            self.price_history[ticker].append((date, price))

            # Keep at most the last 60 observations per ticker to bound
            # memory and persisted-state size.
            self.price_history[ticker] = self.price_history[ticker][-60:]

    def run_daily_settlement(
        self,
        date: str,
        tickers: List[str],
        open_prices: Optional[Dict[str, float]],
        close_prices: Dict[str, float],
        market_caps: Dict[str, float],
        agent_portfolio: Dict[str, Any],
        analyst_results: List[Dict[str, Any]],  # pylint: disable=W0613
        pm_decisions: Optional[Dict[str, Dict]] = None,
    ) -> Dict[str, Any]:
        """
        Run complete daily settlement

        Args:
            date: Trading date (YYYY-MM-DD)
            tickers: List of tickers
            open_prices: Opening prices
            close_prices: Closing prices
            market_caps: Market caps for each ticker
            agent_portfolio: Current agent portfolio state
            analyst_results: Analyst analysis results (currently unused;
                kept for interface compatibility)
            pm_decisions: PM's trading decisions

        Returns:
            Settlement results including all portfolio values and evaluations
        """
        logger.info(f"Running daily settlement for {date}")

        # Record today's closes first so momentum scores include them.
        self.update_price_history(date, close_prices)

        momentum_scores = calculate_momentum_scores(
            tickers,
            self.price_history,
            lookback_days=20,
        )

        # Momentum baseline rebalances on the first settlement of a month.
        rebalance_momentum = self._should_rebalance_momentum(date)

        baseline_values = self.baseline_calculator.get_all_baseline_values(
            tickers=tickers,
            open_prices=open_prices if open_prices else close_prices,
            close_prices=close_prices,
            market_caps=market_caps,
            momentum_scores=momentum_scores,
            date=date,
            rebalance_momentum=rebalance_momentum,
        )

        logger.info(f"Baseline values calculated: {baseline_values}")

        agent_value = self.storage.calculate_portfolio_value(
            agent_portfolio,
            close_prices,
        )

        analyst_evaluations = self.analyst_tracker.evaluate_predictions(
            open_prices,
            close_prices,
            date,
        )

        pm_evaluations = {}
        if pm_decisions:
            pm_evaluations = self.analyst_tracker.evaluate_pm_decisions(
                pm_decisions,
                open_prices,
                close_prices,
                date,
            )

        # PM evaluations override analyst entries on key collision.
        all_evaluations = {**analyst_evaluations, **pm_evaluations}

        leaderboard = self.storage.load_file("leaderboard") or []
        updated_leaderboard = update_leaderboard_with_evaluations(
            leaderboard,
            all_evaluations,
        )
        self.storage.save_file("leaderboard", updated_leaderboard)

        self._update_summary_with_baselines(
            date,
            agent_value,
            baseline_values,
        )

        # Predictions are per-day; reset for the next trading day.
        self.analyst_tracker.clear_daily_predictions()

        # Persist baseline calculator and price history state
        self._save_persisted_state()

        return {
            "date": date,
            "agent_portfolio_value": agent_value,
            "baseline_values": baseline_values,
            "analyst_evaluations": analyst_evaluations,
            "baselines_updated": True,
            "leaderboard_updated": True,
        }

    def _should_rebalance_momentum(self, date: str) -> bool:
        """
        Check if momentum portfolio should rebalance

        Returns True if it's a new month (or if it never rebalanced)
        """
        last_rebalance = self.baseline_calculator.momentum_last_rebalance_date
        if last_rebalance is None:
            return True

        last_date = datetime.strptime(last_rebalance, "%Y-%m-%d")
        current_date = datetime.strptime(date, "%Y-%m-%d")

        # New (year, month) pair means a calendar-month boundary crossed.
        return (current_date.year, current_date.month) != (
            last_date.year,
            last_date.month,
        )

    def _update_summary_with_baselines(
        self,
        date: str,
        agent_value: float,
        baseline_values: Dict[str, float],
    ):
        """
        Update summary.json with agent and baseline portfolio values

        NOTE: History updates are now handled centrally by storage.update_dashboard_after_cycle()
        to ensure all histories (equity, baseline, baseline_vw, momentum) stay synchronized.
        baseline_values are returned in run_daily_settlement() and passed to storage.

        Args:
            date: Trading date (used for backtest-compatible timestamps)
            agent_value: Agent portfolio value
            baseline_values: Baseline portfolio values
        """
        # History updates are now handled by storage.update_dashboard_after_cycle()
        # which receives baseline_values from settlement_result and updates all histories together.
        # This ensures equity and baseline data points are always synchronized.

    def update_intraday_values(
        self,
        tickers: List[str],
        current_prices: Dict[str, float],
        market_caps: Dict[str, float],
        agent_portfolio: Dict[str, Any],
    ) -> Dict[str, float]:
        """
        Update portfolio values with current prices (for live mode intraday updates)

        Args:
            tickers: List of tickers
            current_prices: Current prices
            market_caps: Market caps
            agent_portfolio: Current agent portfolio

        Returns:
            Dict with current portfolio values
        """
        agent_value = self.storage.calculate_portfolio_value(
            agent_portfolio,
            current_prices,
        )

        # Intraday valuation uses the same prices for "open" and "close".
        equal_weight = self.baseline_calculator.calculate_equal_weight_value(
            tickers,
            current_prices,
            current_prices,
        )
        market_cap = (
            self.baseline_calculator.calculate_market_cap_weighted_value(
                tickers,
                current_prices,
                current_prices,
                market_caps,
            )
        )

        momentum_scores = calculate_momentum_scores(
            tickers,
            self.price_history,
            lookback_days=20,
        )

        # Most recent recorded date, taken from an arbitrary ticker's
        # history. NOTE(review): assumes all tickers share the same latest
        # date — confirm against update_price_history callers.
        last_date = (
            list(self.price_history.values())[0][-1][0]
            if self.price_history
            else ""
        )

        # rebalance=False: intraday updates never trigger a rebalance.
        momentum = self.baseline_calculator.calculate_momentum_value(
            tickers,
            current_prices,
            current_prices,
            momentum_scores,
            date=last_date,
            rebalance=False,
        )

        return {
            "agent": agent_value,
            "equal_weight": equal_weight,
            "market_cap_weighted": market_cap,
            "momentum": momentum,
        }
|
||||
348
backend/utils/terminal_dashboard.py
Normal file
348
backend/utils/terminal_dashboard.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Terminal Dashboard - Persistent unified panel using Rich Live
|
||||
"""
|
||||
# pylint: disable=R0915,R0912
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TerminalDashboard:
    """Unified persistent terminal dashboard.

    Renders a single Rich ``Live`` panel with three columns
    (config/status, portfolio, recent trades) and an animated progress
    line at the bottom. A daemon thread refreshes the panel a few times
    per second so the progress dots animate between state updates.
    """

    def __init__(self, console: Optional[Console] = None):
        """Create the dashboard.

        Args:
            console: Rich console to render into; a fresh one is created
                when omitted.
        """
        self.console = console or Console()
        # Active Rich Live context; None whenever the dashboard is stopped.
        self.live: Optional[Live] = None

        # Config state
        self.mode = "live"
        self.config_name = ""
        self.host = "0.0.0.0"
        self.port = 8765
        self.poll_interval = 10
        self.trigger_time = "now"
        self.mock = False
        self.enable_memory = False
        self.local_time = ""
        self.nyse_time = ""
        self.start_date = ""
        self.end_date = ""
        self.tickers: List[str] = []
        self.initial_cash = 100000.0

        # Trading state
        self.current_date = "-"
        self.status = "Initializing"
        self.total_value = 0.0
        self.cash = 0.0
        self.pnl_pct = 0.0
        self.holdings: List[Dict] = []
        self.trades: List[Dict] = []
        self.days_completed = 0
        self.days_total = 0

        # Progress message (last line)
        self.progress = ""
        self._dots_index = 0
        self._animator_running = False
        self._animator_thread: Optional[threading.Thread] = None

    def set_config(
        self,
        mode: str,
        config_name: str,
        host: str,
        port: int,
        poll_interval: int,
        trigger_time: str = "now",
        mock: bool = False,
        enable_memory: bool = False,
        local_time: str = "",
        nyse_time: str = "",
        start_date: str = "",
        end_date: str = "",
        tickers: Optional[List[str]] = None,
        initial_cash: float = 100000.0,
    ):
        """Set configuration state.

        Also seeds NAV and cash to ``initial_cash`` so the panel shows a
        sensible portfolio before the first real update arrives.
        """
        self.mode = mode
        self.config_name = config_name
        self.host = host
        self.port = port
        self.poll_interval = poll_interval
        self.trigger_time = trigger_time
        self.mock = mock
        self.enable_memory = enable_memory
        self.local_time = local_time
        self.nyse_time = nyse_time
        self.start_date = start_date
        self.end_date = end_date
        self.tickers = tickers or []
        self.initial_cash = initial_cash
        self.total_value = initial_cash
        self.cash = initial_cash

    def _build_panel(self) -> Panel:
        """Build the unified dashboard panel from the current state."""
        # Main grid: three fixed-width columns side by side.
        main_table = Table.grid(padding=(0, 2))
        main_table.add_column(width=28)
        main_table.add_column(width=22)
        main_table.add_column(width=22)

        # Left: Config + Status
        left = Table.grid(padding=(0, 0))
        left.add_column()

        # Mode line
        if self.mode == "backtest":
            mode_str = "[cyan]Backtest[/cyan]"
        elif self.mock:
            mode_str = "[yellow]MOCK[/yellow]"
        else:
            mode_str = "[green]LIVE[/green]"

        left.add_row(f"[bold]Mode:[/bold] {mode_str}")
        left.add_row(f"[dim]Config:[/dim] {self.config_name}")
        left.add_row(f"[dim]Server:[/dim] {self.host}:{self.port}")

        # Live-only rows: exchange clock and trigger time.
        if self.mode == "live" and self.nyse_time:
            left.add_row(f"[dim]NYSE:[/dim] {self.nyse_time[:19]}")
            trigger_display = (
                "[green]NOW[/green]"
                if self.trigger_time == "now"
                else self.trigger_time
            )
            left.add_row(f"[dim]Trigger:[/dim] {trigger_display}")

        # Status
        left.add_row("")
        status_style = "green" if self.status == "Running" else "yellow"
        left.add_row(
            "[bold]Status:[/bold] "
            f"[{status_style}]{self.status}[/{status_style}]",
        )
        if self.mode == "backtest":
            left.add_row(
                f"[dim]Backtesting Period:[/dim] {self.days_total} days\n"
                f"  {self.start_date} -> {self.end_date}",
            )
        left.add_row(f"[dim]Current Date:[/dim] {self.current_date}")

        # Middle: Portfolio
        mid = Table.grid(padding=(0, 0))
        mid.add_column()

        pnl_style = "green" if self.pnl_pct >= 0 else "red"
        mid.add_row("[bold]Portfolio[/bold]")
        mid.add_row(f"NAV: [bold]${self.total_value:,.0f}[/bold]")
        mid.add_row(f"Cash: ${self.cash:,.0f}")
        mid.add_row(f"P&L: [{pnl_style}]{self.pnl_pct:+.2f}%[/{pnl_style}]")

        # Positions (cash pseudo-holding excluded; capped at 7 rows)
        mid.add_row("")
        mid.add_row("[bold]Positions[/bold]")
        stock_holdings = [
            h for h in self.holdings if h.get("ticker") != "CASH"
        ]
        if stock_holdings:
            for h in stock_holdings[:7]:
                qty = h.get("quantity", 0)
                ticker = h.get("ticker", "")[:5]
                val = h.get("marketValue", 0)
                qty_str = f"{qty:+d}" if qty != 0 else "0"
                mid.add_row(
                    f"[cyan]{ticker:<5}[/cyan] {qty_str:>5} ${val:>7,.0f}",
                )
            if len(stock_holdings) > 7:
                mid.add_row(f"[dim]+{len(stock_holdings) - 7} more[/dim]")
        else:
            mid.add_row("[dim]No positions[/dim]")

        # Right: Recent Trades (capped at 10 rows)
        right = Table.grid(padding=(0, 0))
        right.add_column()

        right.add_row("[bold]Recent Trades[/bold]")
        if self.trades:
            for t in self.trades[:10]:
                side = t.get("side", "")
                ticker = t.get("ticker", "")[:5]
                qty = t.get("qty", 0)
                if side == "LONG":
                    side_str = "[green]L[/green]"
                elif side == "SHORT":
                    side_str = "[red]S[/red]"
                else:
                    # Anything other than LONG/SHORT renders as hold.
                    side_str = "[dim]H[/dim]"
                right.add_row(f"{side_str} [cyan]{ticker:<5}[/cyan] {qty:>4}")
            if len(self.trades) > 10:
                right.add_row(f"[dim]+{len(self.trades) - 10} more[/dim]")
        else:
            right.add_row("[dim]No trades[/dim]")

        main_table.add_row(left, mid, right)

        # Outer table to add progress line at bottom
        outer = Table.grid(padding=(0, 0))
        outer.add_column()
        outer.add_row(main_table)

        # Progress line (last row) with animated dots
        if self.progress:
            DOTS_FRAMES = [" ", ". ", ".. ", "..."]
            dots = DOTS_FRAMES[self._dots_index % len(DOTS_FRAMES)]
            outer.add_row("")
            outer.add_row(f"[dim]> {self.progress}{dots}[/dim]")

        # Build panel
        title = "[bold cyan]EvoTraders[/bold cyan]"
        if self.mode == "backtest":
            title += " [dim]Backtest[/dim]"
        elif self.mock:
            title += " [dim]Mock[/dim]"
        else:
            title += " [dim]Live[/dim]"

        return Panel(
            outer,
            title=title,
            border_style="cyan",
            padding=(0, 1),
        )

    def _run_animator(self):
        """Background thread loop that advances the dots animation.

        NOTE(review): reads `self.live` without a lock; a concurrent
        stop() could null it between the check and the update — confirm
        this race is acceptable (Rich update on a stopped Live is benign).
        """
        while self._animator_running:
            time.sleep(0.3)
            if self.progress and self.live:
                self._dots_index += 1
                self.live.update(self._build_panel())

    def start(self):
        """Start the live dashboard display and the animator thread."""
        self.live = Live(
            self._build_panel(),
            console=self.console,
            refresh_per_second=4,
            vertical_overflow="visible",
        )
        self.live.start()

        # Start animator thread (daemon so it never blocks interpreter exit)
        self._animator_running = True
        self._animator_thread = threading.Thread(
            target=self._run_animator,
            daemon=True,
        )
        self._animator_thread.start()

    def stop(self):
        """Stop the animator thread, then tear down the live display."""
        self._animator_running = False
        if self._animator_thread:
            # Bounded join: the animator sleeps 0.3s per cycle, so it may
            # still be running when we proceed; it exits on its next check.
            self._animator_thread.join(timeout=0.5)
            self._animator_thread = None
        if self.live:
            self.live.stop()
            self.live = None

    def update(
        self,
        date: Optional[str] = None,
        status: Optional[str] = None,
        portfolio: Optional[Dict[str, Any]] = None,
        holdings: Optional[List[Dict]] = None,
        trades: Optional[List[Dict]] = None,
        days_completed: Optional[int] = None,
        days_total: Optional[int] = None,
    ):
        """Update dashboard state and refresh the display.

        All parameters are optional; only supplied fields are changed.
        Accepts both camelCase ("totalAssetValue"/"cashPosition") and
        snake_case ("total_value"/"cash") portfolio keys.
        """
        if date:
            self.current_date = date
        if status:
            self.status = status
        if days_completed is not None:
            self.days_completed = days_completed
        if days_total is not None:
            self.days_total = days_total

        if portfolio:
            # NOTE(review): the `or` fallback means an explicit 0 value
            # for the camelCase key falls through to the snake_case key
            # (and then to initial_cash) — confirm a zero NAV is not a
            # legitimate state here.
            self.total_value = portfolio.get(
                "totalAssetValue",
                0,
            ) or portfolio.get(
                "total_value",
                self.initial_cash,
            )
            self.cash = portfolio.get("cashPosition", 0) or portfolio.get(
                "cash",
                self.initial_cash,
            )
            if self.total_value > 0 and self.initial_cash > 0:
                self.pnl_pct = (
                    (self.total_value - self.initial_cash) / self.initial_cash
                ) * 100

        if holdings is not None:
            self.holdings = holdings
        if trades is not None:
            self.trades = trades

        if self.live:
            self.live.update(self._build_panel())

    def log(self, msg: str, also_log: bool = True):
        """
        Update progress message and refresh panel

        Args:
            msg: Progress message to display
            also_log: Whether to also write to logger (default True)
        """
        self.progress = msg
        if also_log:
            logger.info(msg)
        if self.live:
            self.live.update(self._build_panel())

    def print_final_summary(self):
        """Print final summary when dashboard stops"""
        pnl_style = "green" if self.pnl_pct >= 0 else "red"

        if self.mode == "backtest":
            msg = (
                f"[bold]Backtest Complete[/bold] | "
                f"Days: {self.days_completed} | "
                f"NAV: ${self.total_value:,.0f} | "
                f"Return: [{pnl_style}]{self.pnl_pct:+.2f}%[/{pnl_style}]"
            )
        else:
            msg = (
                f"[bold]Session End[/bold] | "
                f"NAV: ${self.total_value:,.0f} | "
                f"P&L: [{pnl_style}]{self.pnl_pct:+.2f}%[/{pnl_style}]"
            )

        self.console.print(Panel(msg, border_style="green"))
|
||||
|
||||
|
||||
# Global instance
_dashboard: Optional[TerminalDashboard] = None


def get_dashboard() -> TerminalDashboard:
    """Return the process-wide dashboard, creating it on first access."""
    global _dashboard
    if _dashboard is not None:
        return _dashboard
    _dashboard = TerminalDashboard()
    return _dashboard
|
||||
772
backend/utils/trade_executor.py
Normal file
772
backend/utils/trade_executor.py
Normal file
@@ -0,0 +1,772 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Trading Execution Engine - Supports Two Modes
|
||||
1. Signal mode: Only records directional signal decisions
|
||||
2. Portfolio mode: Executes specific trades and tracks positions
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class DirectionSignalRecorder:
    """Direction signal recorder, records daily investment direction decisions"""

    def __init__(self):
        """Initialize direction signal recorder"""
        # Chronological log of every signal ever recorded by this instance.
        self.signal_log: List[Dict[str, Any]] = []

    def record_direction_signals(
        self,
        decisions: Dict[str, Dict[str, Any]],
        current_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Record Portfolio Manager's directional signal decisions

        Args:
            decisions: PM's direction decisions {ticker: {action, confidence, reasoning}}
            current_date: Current date (used for backtest compatibility);
                defaults to today's date when omitted

        Returns:
            Signal recording report
        """
        if current_date is None:
            current_date = datetime.now().strftime("%Y-%m-%d")

        # Use provided date for timestamp (backtest compatible)
        timestamp = f"{current_date}T09:30:00"

        signal_report: Dict[str, Any] = {
            "recorded_signals": {},
            "date": current_date,
            "timestamp": timestamp,
            "total_signals": len(decisions),
        }

        print(
            f"\n📊 Recording directional signal decisions for {current_date}...",
        )

        # Loop-invariant display mapping, hoisted out of the per-ticker loop.
        action_emoji = {"long": "📈", "short": "📉", "hold": "➖"}

        # Record directional signal for each ticker
        for ticker, decision in decisions.items():
            action = decision.get("action", "hold")
            confidence = decision.get("confidence", 0)
            reasoning = decision.get("reasoning", "")

            # Record signal
            signal_record = {
                "ticker": ticker,
                "action": action,
                "confidence": confidence,
                "reasoning": reasoning,
                "date": current_date,
                "timestamp": timestamp,
            }

            self.signal_log.append(signal_record)
            signal_report["recorded_signals"][ticker] = {
                "action": action,
                "confidence": confidence,
            }

            # Display signal
            emoji = action_emoji.get(action, "❓")
            print(
                f"  {emoji} {ticker}: {action.upper()} (Confidence: {confidence}%) - {reasoning}",
            )

        print(f"\n✅ Recorded directional signals for {len(decisions)} stocks")

        return signal_report

    def get_signal_summary(self) -> Dict[str, Any]:
        """Get signal recording summary (count plus the full signal log)."""
        return {
            "total_signals": len(self.signal_log),
            "signal_log": self.signal_log,
        }
|
||||
|
||||
|
||||
def parse_pm_decisions(pm_output: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """
    Parse Portfolio Manager output format

    Args:
        pm_output: PM's raw output

    Returns:
        Standardized decision format
    """
    if not isinstance(pm_output, dict):
        print(f"Warning: Unable to parse PM output format: {type(pm_output)}")
        return {}
    # A wrapped payload carries decisions under a "decisions" key;
    # otherwise the dict itself is already the decision mapping.
    return pm_output.get("decisions", pm_output)
|
||||
|
||||
|
||||
class PortfolioTradeExecutor:
    """Portfolio mode trade executor, executes specific trades and tracks positions.

    Supports long and short positions per ticker with average-cost
    cost-basis tracking, a configurable margin requirement for shorts,
    and a net-liquidation portfolio valuation.
    """

    # Live portfolio state: cash, per-ticker positions, margin fields.
    portfolio: Dict[str, Any]
    # Every trade record appended by _execute_single_trade, in order.
    trade_history: List[Dict[str, Any]]
    # Portfolio snapshot taken after each execute_trades() call.
    portfolio_history: List[Dict[str, Any]]

    def __init__(self, initial_portfolio: Optional[Dict[str, Any]] = None):
        """
        Initialize Portfolio trade executor

        Args:
            initial_portfolio: Initial portfolio state; deep-copied so the
                caller's dict is never mutated. Defaults to $100,000 cash,
                no positions, margin_requirement 0.0.
        """

        if initial_portfolio is None:
            self.portfolio = {
                "cash": 100000.0,
                "positions": {},
                # Default 0.0 (short selling disabled)
                # NOTE(review): with 0.0 the margin check in
                # _open_short_position never blocks shorting — confirm
                # whether "disabled" is enforced elsewhere.
                "margin_requirement": 0.0,
                "margin_used": 0.0,
            }
        else:
            self.portfolio = deepcopy(initial_portfolio)

        self.trade_history = []  # Trade history
        self.portfolio_history = []  # Portfolio history

    def execute_trade(
        self,
        ticker: str,
        action: str,
        quantity: int,
        price: float,
        current_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Execute a single trade

        Args:
            ticker: Stock ticker
            action: Trade action (long/short/hold)
            quantity: Number of shares
            price: Current price
            current_date: Trade date; defaults to today

        Returns:
            Trade result dictionary
        """
        if current_date is None:
            current_date = datetime.now().strftime("%Y-%m-%d")

        # Holds and zero-quantity orders are no-ops.
        if action == "hold" or quantity == 0:
            return {"status": "success", "message": "No trade needed"}

        if price <= 0:
            return {"status": "failed", "reason": "Invalid price"}

        result = self._execute_single_trade(
            ticker=ticker,
            action=action,
            target_quantity=quantity,
            price=price,
            date=current_date,
        )

        return result

    def execute_trades(
        self,
        decisions: Dict[str, Dict[str, Any]],
        current_prices: Dict[str, float],
        current_date: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Execute trading decisions and update positions

        Args:
            decisions: {ticker: {action, quantity, confidence, reasoning}}
            current_prices: {ticker: current_price}
            current_date: Current date (used for backtest compatibility)

        Returns:
            Trade execution report
        """
        if current_date is None:
            current_date = datetime.now().strftime("%Y-%m-%d")

        # Use provided date for timestamp (backtest compatible)
        timestamp = f"{current_date}T09:30:00"

        execution_report: Dict[str, Any] = {
            "date": current_date,
            "timestamp": timestamp,
            "executed_trades": [],
            "failed_trades": [],
            # Deep copies so later mutation of self.portfolio cannot
            # retroactively change the report.
            "portfolio_before": deepcopy(self.portfolio),
            "portfolio_after": None,
        }

        print(f"\n💼 Executing Portfolio trades for {current_date}...")

        # Execute trades for each ticker
        for ticker, decision in decisions.items():
            action = decision.get("action", "hold")
            quantity = decision.get("quantity", 0)

            if action == "hold" or quantity == 0:
                continue

            price = current_prices.get(ticker, 0)
            if price <= 0:
                execution_report["failed_trades"].append(
                    {
                        "ticker": ticker,
                        "action": action,
                        "quantity": quantity,
                        "reason": "No valid price data",
                    },
                )
                print(
                    f"  ❌ {ticker}: Unable to execute {action} - No valid price",
                )
                continue

            # Execute trade
            trade_result = self._execute_single_trade(
                ticker,
                action,
                quantity,
                price,
                current_date,
            )
            if trade_result["status"] == "success":
                execution_report["executed_trades"].append(trade_result)

                trades_info = ", ".join(trade_result.get("trades", []))
                print(
                    f"  ✔ {ticker}: {action} Target {quantity} shares "
                    f"({trades_info}) @ ${price:.2f}",
                )
            else:
                execution_report["failed_trades"].append(trade_result)
                print(
                    f"  ✗ {ticker}: Unable to execute {action} - {trade_result['reason']}",
                )

        # Record final portfolio state
        execution_report["portfolio_after"] = deepcopy(self.portfolio)
        self.portfolio_history.append(
            {
                "date": current_date,
                "portfolio": deepcopy(self.portfolio),
            },
        )

        # Calculate portfolio value
        portfolio_value = self._calculate_portfolio_value(current_prices)
        execution_report["portfolio_value"] = portfolio_value

        print("\n✔ Trade execution completed:")
        print(f"  Success: {len(execution_report['executed_trades'])} trades")
        print(f"  Failed: {len(execution_report['failed_trades'])} trades")
        print(f"  Portfolio value: ${portfolio_value:,.2f}")
        print(f"  Cash balance: ${self.portfolio['cash']:,.2f}")

        return execution_report

    def _execute_single_trade(
        self,
        ticker: str,
        action: str,
        target_quantity: int,
        price: float,
        date: str,
    ) -> Dict[str, Any]:
        """
        Execute single trade - Incremental mode

        Args:
            ticker: Stock ticker
            action: long(add position)/short(reduce position)/hold
            target_quantity: Incremental quantity (long=buy shares, short=sell shares)
            price: Current price
            date: Trade date

        Returns:
            Trade record dict with status "success", or the failure dict
            from the first step that could not be executed.
        """

        # Ensure position exists
        if ticker not in self.portfolio["positions"]:
            self.portfolio["positions"][ticker] = {
                "long": 0,
                "short": 0,
                "long_cost_basis": 0.0,
                "short_cost_basis": 0.0,
            }

        position = self.portfolio["positions"][ticker]
        current_long = position["long"]
        current_short = position["short"]

        trades_executed = []  # Record actually executed trade steps

        if action == "long":
            result = self._execute_long_action(
                ticker,
                target_quantity,
                price,
                date,
                current_long,
                current_short,
                trades_executed,
            )
            # On failure, return immediately without logging a trade record.
            if result["status"] == "failed":
                return result

        elif action == "short":
            result = self._execute_short_action(
                ticker,
                target_quantity,
                price,
                date,
                current_long,
                current_short,
                trades_executed,
            )
            if result["status"] == "failed":
                return result

        elif action == "hold":
            print(f"\n⏸️ {ticker} Position unchanged: {current_long} shares")

        # Record trade with backtest-compatible timestamp.
        # NOTE(review): reached for hold and for unrecognized actions too,
        # so a success record (with empty trades list) is still appended.
        trade_record = {
            "status": "success",
            "ticker": ticker,
            "action": action,
            "target_quantity": target_quantity,
            "price": price,
            "trades": trades_executed,
            "date": date,
            "timestamp": f"{date}T09:30:00",
        }

        self.trade_history.append(trade_record)

        return trade_record

    def _execute_long_action(
        self,
        ticker: str,
        target_quantity: int,
        price: float,
        date: str,
        current_long: int,
        current_short: int,
        trades_executed: list,
    ) -> Dict[str, Any]:
        """Execute long action: Buy shares or cover shorts first.

        Appends a human-readable step description to trades_executed for
        each step performed.
        """
        print(
            f"\n📈 {ticker} Long operation: Current Long {current_long}, "
            f"Short {current_short} → Target quantity {target_quantity}",
        )

        if target_quantity <= 0:
            print("  ⏸️ Quantity is 0, no trade needed")
            return {"status": "success"}

        remaining = target_quantity

        # If has short position, cover first
        if current_short > 0:
            cover_qty = min(remaining, current_short)
            print(f"  1️⃣ Cover short: {cover_qty} shares")
            cover_result = self._cover_short_position(
                ticker,
                cover_qty,
                price,
                date,
            )
            if cover_result["status"] == "failed":
                return cover_result
            trades_executed.append(f"Cover {cover_qty} shares")
            remaining -= cover_qty

        # If still has remaining quantity, buy long
        if remaining > 0:
            print(f"  2️⃣ Buy long: {remaining} shares")
            buy_result = self._buy_long_position(
                ticker,
                remaining,
                price,
                date,
            )
            if buy_result["status"] == "failed":
                return buy_result
            trades_executed.append(f"Buy {remaining} shares")

        # Display final result
        final_long = self.portfolio["positions"][ticker]["long"]
        final_short = self.portfolio["positions"][ticker]["short"]
        print(
            f"  ✅ Final state: Long {final_long} shares, Short {final_short} shares",
        )

        return {"status": "success"}

    def _execute_short_action(
        self,
        ticker: str,
        target_quantity: int,
        price: float,
        date: str,
        current_long: int,
        current_short: int,
        trades_executed: list,
    ) -> Dict[str, Any]:
        """Execute short action: Sell long positions first, then short if needed."""
        print(
            f"\n📉 {ticker} Short operation (quantity={target_quantity} shares):",
        )
        print(
            f"  Current state: Long {current_long} shares, Short {current_short} shares",
        )

        if target_quantity <= 0:
            print("  ⏸️ Quantity is 0, no trade needed")
            return {"status": "success"}

        remaining_quantity = target_quantity

        # Step 1: If there are long positions, sell first
        if current_long > 0:
            sell_quantity = min(remaining_quantity, current_long)
            print(f"  1️⃣ Sell long: {sell_quantity} shares")
            sell_result = self._sell_long_position(
                ticker,
                sell_quantity,
                price,
                date,
            )
            if sell_result["status"] == "failed":
                return sell_result
            trades_executed.append(f"Sell {sell_quantity} shares")
            remaining_quantity -= sell_quantity

        # Step 2: If there's remaining quantity, establish or increase short position
        if remaining_quantity > 0:
            print(f"  2️⃣ Short: {remaining_quantity} shares")
            short_result = self._open_short_position(
                ticker,
                remaining_quantity,
                price,
                date,
            )
            if short_result["status"] == "failed":
                return short_result
            trades_executed.append(f"Short {remaining_quantity} shares")

        # Display final result
        final_long = self.portfolio["positions"][ticker]["long"]
        final_short = self.portfolio["positions"][ticker]["short"]
        print(
            f"  ✅ Final state: Long {final_long} shares, Short {final_short} shares",
        )

        return {"status": "success"}

    def _buy_long_position(
        self,
        ticker: str,
        quantity: int,
        price: float,
        _date: str,
    ) -> Dict[str, Any]:
        """Buy long position; fails if cash cannot cover the full cost.

        Cost basis is updated as the weighted average of old and new lots.
        """
        position = self.portfolio["positions"][ticker]
        trade_value = quantity * price

        if self.portfolio["cash"] < trade_value:
            return {
                "status": "failed",
                "ticker": ticker,
                "action": "buy",
                "quantity": quantity,
                "price": price,
                "reason": f"Insufficient cash (needed: ${trade_value:.2f}, available: "
                f"${self.portfolio['cash']:.2f})",
            }

        # Update position cost basis
        old_long = position["long"]
        old_cost_basis = position["long_cost_basis"]
        new_long = old_long + quantity

        # 🐛 Debug info
        print(f"  🔍 Buy {ticker}:")
        print(f"    Old position: {old_long} shares @ ${old_cost_basis:.2f}")
        print(f"    Buy: {quantity} shares @ ${price:.2f}")
        print(f"    New position: {new_long} shares")

        if new_long > 0:
            # Weighted-average cost across the existing and new shares.
            new_cost_basis = (
                (old_long * old_cost_basis) + (quantity * price)
            ) / new_long
            print(
                f"    New cost: ${new_cost_basis:.2f} = "
                f"(({old_long} × ${old_cost_basis:.2f}) + "
                f"({quantity} × ${price:.2f})) / {new_long}",
            )
            position["long_cost_basis"] = new_cost_basis
        position["long"] = new_long

        # Deduct cash
        self.portfolio["cash"] -= trade_value

        return {"status": "success"}

    def _sell_long_position(
        self,
        ticker: str,
        quantity: int,
        price: float,
        _date: str,
    ) -> Dict[str, Any]:
        """Sell long position; fails if holdings are smaller than quantity."""
        position = self.portfolio["positions"][ticker]

        if position["long"] < quantity:
            return {
                "status": "failed",
                "ticker": ticker,
                "action": "sell",
                "quantity": quantity,
                "price": price,
                "reason": f"Insufficient long position (holding: {position['long']},"
                f" trying to sell: {quantity})",
            }

        # Reduce position; reset cost basis once fully flat.
        position["long"] -= quantity
        if position["long"] == 0:
            position["long_cost_basis"] = 0.0

        # Increase cash
        trade_value = quantity * price
        self.portfolio["cash"] += trade_value

        return {"status": "success"}

    def _open_short_position(
        self,
        ticker: str,
        quantity: int,
        price: float,
        _date: str,
    ) -> Dict[str, Any]:
        """Open short position; fails if cash cannot cover required margin."""
        position = self.portfolio["positions"][ticker]
        trade_value = quantity * price
        margin_needed = trade_value * self.portfolio["margin_requirement"]

        if self.portfolio["cash"] < margin_needed:
            return {
                "status": "failed",
                "ticker": ticker,
                "action": "short",
                "quantity": quantity,
                "price": price,
                "reason": f"Insufficient margin (needed: ${margin_needed:.2f}, "
                f"available: ${self.portfolio['cash']:.2f})",
            }

        # Update position cost basis (weighted average, as for longs)
        old_short = position["short"]
        old_cost_basis = position["short_cost_basis"]
        new_short = old_short + quantity
        if new_short > 0:
            position["short_cost_basis"] = (
                (old_short * old_cost_basis) + (quantity * price)
            ) / new_short
        position["short"] = new_short

        # Increase cash (short sale proceeds) and margin used.
        # Margin is "frozen" cash: moved out of cash and into margin_used,
        # and added back during valuation.
        self.portfolio["cash"] += trade_value - margin_needed
        self.portfolio["margin_used"] += margin_needed

        return {"status": "success"}

    def _cover_short_position(
        self,
        ticker: str,
        quantity: int,
        price: float,
        _date: str,
    ) -> Dict[str, Any]:
        """Cover (buy back) short position; fails if short holdings are too small."""
        position = self.portfolio["positions"][ticker]

        if position["short"] < quantity:
            return {
                "status": "failed",
                "ticker": ticker,
                "action": "cover",
                "quantity": quantity,
                "price": price,
                "reason": f"Insufficient short position (holding: {position['short']}, "
                f"trying to cover: {quantity})",
            }

        # Calculate released margin - 🔧 FIX: Use cost_basis instead of current price
        # (margin was posted at entry on the cost basis, so it must be
        # released on the same basis regardless of the current price).
        trade_value = quantity * price
        cost_basis = position["short_cost_basis"]
        margin_released = (
            quantity * cost_basis * self.portfolio["margin_requirement"]
        )

        # Reduce position; reset cost basis once fully covered.
        position["short"] -= quantity
        if position["short"] == 0:
            position["short_cost_basis"] = 0.0

        # Deduct cash (buy to cover) and release margin
        self.portfolio["cash"] -= trade_value
        self.portfolio["cash"] += margin_released
        self.portfolio["margin_used"] -= margin_released

        return {"status": "success"}

    def _calculate_portfolio_value(
        self,
        current_prices: Dict[str, float],
    ) -> float:
        """Calculate total portfolio value (net liquidation value).

        Tickers missing from current_prices are skipped entirely — their
        long and short legs contribute nothing to the total.
        """
        # Add margin_used back because it's frozen cash, not lost money
        total_value = self.portfolio["cash"] + self.portfolio["margin_used"]

        for ticker, position in self.portfolio["positions"].items():
            if ticker in current_prices:
                price = current_prices[ticker]
                # Add long position value
                total_value += position["long"] * price
                # Subtract short position value (liability)
                total_value -= position["short"] * price

        return total_value

    def get_portfolio_summary(
        self,
        current_prices: Dict[str, float],
    ) -> Dict[str, Any]:
        """Get portfolio summary: total value, cash, margin, per-position P&L."""
        portfolio_value = self._calculate_portfolio_value(current_prices)

        positions_summary = []
        for ticker, position in self.portfolio["positions"].items():
            # Only report tickers with an open long or short leg.
            if position["long"] > 0 or position["short"] > 0:
                price = current_prices.get(ticker, 0)
                long_value = position["long"] * price
                short_value = position["short"] * price

                positions_summary.append(
                    {
                        "ticker": ticker,
                        "long_shares": position["long"],
                        "short_shares": position["short"],
                        "long_value": long_value,
                        "short_value": short_value,
                        "long_cost_basis": position["long_cost_basis"],
                        "short_cost_basis": position["short_cost_basis"],
                        # Unrealized P&L: market value vs cost basis
                        # (shorts profit when the price falls).
                        "long_pnl": (
                            long_value
                            - (position["long"] * position["long_cost_basis"])
                            if position["long"] > 0
                            else 0
                        ),
                        "short_pnl": (
                            (position["short"] * position["short_cost_basis"])
                            - short_value
                            if position["short"] > 0
                            else 0
                        ),
                    },
                )

        return {
            "portfolio_value": portfolio_value,
            "cash": self.portfolio["cash"],
            "margin_used": self.portfolio["margin_used"],
            "positions": positions_summary,
            "total_trades": len(self.trade_history),
        }
|
||||
|
||||
|
||||
def execute_trading_decisions(
    pm_decisions: Dict[str, Any],
    current_date: str = None,
) -> Dict[str, Any]:
    """
    Convenience function to record directional signal decisions (Signal mode)

    Args:
        pm_decisions: PM's direction decisions
        current_date: Current date (optional)

    Returns:
        Signal recording report
    """
    # Normalize the PM payload, then record it through a throwaway recorder.
    parsed = parse_pm_decisions(pm_decisions)
    recorder = DirectionSignalRecorder()
    return recorder.record_direction_signals(parsed, current_date)
|
||||
|
||||
|
||||
def execute_portfolio_trades(
    pm_decisions: Dict[str, Any],
    current_prices: Dict[str, float],
    portfolio: Dict[str, Any],
    current_date: str = None,
) -> Dict[str, Any]:
    """
    Execute Portfolio mode trading decisions

    Args:
        pm_decisions: PM's trading decisions
        current_prices: Current prices
        portfolio: Current portfolio state
        current_date: Current date (optional)

    Returns:
        Trade execution report and updated portfolio
    """
    parsed = parse_pm_decisions(pm_decisions)

    # Run the trades against a fresh executor seeded with the caller's
    # portfolio; the executor deep-copies it, leaving `portfolio` intact.
    executor = PortfolioTradeExecutor(initial_portfolio=portfolio)
    report = executor.execute_trades(parsed, current_prices, current_date)

    # Attach a valuation summary and the post-trade portfolio state.
    report["portfolio_summary"] = executor.get_portfolio_summary(current_prices)
    report["updated_portfolio"] = executor.portfolio

    return report
|
||||
Reference in New Issue
Block a user