diff --git a/backend/agents/base/__init__.py b/backend/agents/base/__init__.py index 85d83a2..f566569 100644 --- a/backend/agents/base/__init__.py +++ b/backend/agents/base/__init__.py @@ -13,6 +13,26 @@ from .command_handler import ( create_command_dispatcher, ) +# 评估钩子 (从evaluation_hook.py导入) +from .evaluation_hook import ( + EvaluationHook, + EvaluationCollector, + MetricType, + EvaluationMetric, + EvaluationResult, + parse_evaluation_hooks, +) + +# 技能适配钩子 (从skill_adaptation_hook.py导入) +from .skill_adaptation_hook import ( + AdaptationAction, + AdaptationThreshold, + AdaptationEvent, + SkillAdaptationHook, + AdaptationManager, + get_adaptation_manager, +) + __all__ = [ # 命令处理 "AgentCommandDispatcher", @@ -20,4 +40,18 @@ __all__ = [ "CommandHandler", "CommandResult", "create_command_dispatcher", + # 评估钩子 + "EvaluationHook", + "EvaluationCollector", + "MetricType", + "EvaluationMetric", + "EvaluationResult", + "parse_evaluation_hooks", + # 技能适配钩子 + "AdaptationAction", + "AdaptationThreshold", + "AdaptationEvent", + "SkillAdaptationHook", + "AdaptationManager", + "get_adaptation_manager", ] diff --git a/backend/agents/base/evo_agent.py b/backend/agents/base/evo_agent.py index bb2cc87..5bd0e4d 100644 --- a/backend/agents/base/evo_agent.py +++ b/backend/agents/base/evo_agent.py @@ -27,6 +27,7 @@ from .hooks import ( HookManager, BootstrapHook, MemoryCompactionHook, + WorkspaceWatchHook, HOOK_PRE_REASONING, ) from ..prompts.builder import ( @@ -36,6 +37,16 @@ from ..prompts.builder import ( from ..agent_workspace import load_agent_workspace_config from ..skills_manager import SkillsManager +# Team infrastructure imports (graceful import - may not exist yet) +try: + from backend.agents.team.messenger import AgentMessenger + from backend.agents.team.task_delegator import TaskDelegator + TEAM_INFRA_AVAILABLE = True +except ImportError: + TEAM_INFRA_AVAILABLE = False + AgentMessenger = None + TaskDelegator = None + if TYPE_CHECKING: from agentscope.formatter import 
FormatterBase from agentscope.model import ModelWrapperBase @@ -152,6 +163,12 @@ class EvoAgent(ToolGuardMixin, ReActAgent): memory_compact_threshold=memory_compact_threshold, ) + # Initialize team infrastructure if available + self._messenger: Optional["AgentMessenger"] = None + self._task_delegator: Optional["TaskDelegator"] = None + if TEAM_INFRA_AVAILABLE: + self._init_team_infrastructure() + logger.info( "EvoAgent initialized: %s (workspace: %s)", agent_id, @@ -268,6 +285,17 @@ class EvoAgent(ToolGuardMixin, ReActAgent): ) logger.debug("Registered memory compaction hook") + # Workspace watch hook - auto-reload markdown files on change + workspace_watch_hook = WorkspaceWatchHook( + workspace_dir=self.workspace_dir, + ) + self._hook_manager.register( + hook_type=HOOK_PRE_REASONING, + hook_name="workspace_watch", + hook=workspace_watch_hook, + ) + logger.debug("Registered workspace watch hook") + async def _reasoning(self, **kwargs) -> Msg: """Override reasoning to execute pre-reasoning hooks. @@ -405,7 +433,78 @@ class EvoAgent(ToolGuardMixin, ReActAgent): ) ]), "registered_hooks": self._hook_manager.list_hooks(), + "team_infra_available": TEAM_INFRA_AVAILABLE, } + def _init_team_infrastructure(self) -> None: + """Initialize team infrastructure components (messenger and task delegator). + + This method initializes the AgentMessenger for inter-agent communication + and the TaskDelegator for subagent delegation. + """ + if not TEAM_INFRA_AVAILABLE: + return + + try: + self._messenger = AgentMessenger(agent_id=self.agent_id) + self._task_delegator = TaskDelegator(agent=self) + logger.debug( + "Team infrastructure initialized for agent: %s", + self.agent_id, + ) + except Exception as e: + logger.warning( + "Failed to initialize team infrastructure for %s: %s", + self.agent_id, + e, + ) + self._messenger = None + self._task_delegator = None + + @property + def messenger(self) -> Optional["AgentMessenger"]: + """Get the agent's messenger for inter-agent communication. 
+ + Returns: + AgentMessenger instance if available, None otherwise + """ + return self._messenger + + def delegate_task( + self, + task_type: str, + task_data: Dict[str, Any], + target_agent: Optional[str] = None, + ) -> Dict[str, Any]: + """Delegate a task to a subagent using the TaskDelegator. + + Args: + task_type: Type of task to delegate + task_data: Data/payload for the task + target_agent: Optional specific agent ID to delegate to + + Returns: + Dict containing the delegation result + """ + if not TEAM_INFRA_AVAILABLE or self._task_delegator is None: + return { + "success": False, + "error": "Team infrastructure not available", + } + + try: + return self._task_delegator.delegate_task( + task_type=task_type, + task_data=task_data, + target_agent=target_agent, + ) + except Exception as e: + logger.error( + "Task delegation failed for %s: %s", + self.agent_id, + e, + ) + return {"success": False, "error": str(e)} + __all__ = ["EvoAgent"] diff --git a/backend/agents/base/hooks.py b/backend/agents/base/hooks.py index 97ab7bd..387e5c2 100644 --- a/backend/agents/base/hooks.py +++ b/backend/agents/base/hooks.py @@ -284,19 +284,120 @@ class BootstrapHook(Hook): return None +class WorkspaceWatchHook(Hook): + """Hook for auto-reloading workspace markdown files on change. + + Monitors SOUL.md, AGENTS.md, PROFILE.md, etc. and triggers + a prompt rebuild when any of them change. Based on CoPaw's + AgentConfigWatcher approach but for markdown files. + """ + + # Files to monitor (same as PromptBuilder.DEFAULT_FILES) + WATCHED_FILES = frozenset([ + "SOUL.md", "AGENTS.md", "PROFILE.md", "ROLE.md", + "POLICY.md", "MEMORY.md", "HEARTBEAT.md", "STYLE.md", + "BOOTSTRAP.md", + ]) + + def __init__( + self, + workspace_dir: Path, + poll_interval: float = 2.0, + ): + """Initialize workspace watch hook. 
+ + Args: + workspace_dir: Workspace directory to monitor + poll_interval: How often to check for changes (seconds) + """ + self.workspace_dir = Path(workspace_dir) + self.poll_interval = poll_interval + self._last_mtimes: dict[str, float] = {} + self._initialized = False + + def _scan_mtimes(self) -> dict[str, float]: + """Scan watched files and return their current mtimes.""" + mtimes = {} + for name in self.WATCHED_FILES: + path = self.workspace_dir / name + if path.exists(): + mtimes[name] = path.stat().st_mtime + return mtimes + + def _has_changes(self) -> bool: + """Check if any watched file has changed since last check.""" + current = self._scan_mtimes() + + if not self._initialized: + self._last_mtimes = current + self._initialized = True + return False + + # Check for new, modified, or deleted files + if set(current.keys()) != set(self._last_mtimes.keys()): + self._last_mtimes = current + return True + + for name, mtime in current.items(): + if mtime != self._last_mtimes.get(name): + self._last_mtimes = current + return True + + return False + + async def __call__( + self, + agent: "ReActAgent", + kwargs: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + """Check for file changes and rebuild prompt if needed. + + Args: + agent: The agent instance + kwargs: Input arguments (unused) + + Returns: + None + """ + try: + if self._has_changes(): + logger.info( + "Workspace files changed, triggering prompt rebuild for: %s", + getattr(agent, "agent_id", "unknown"), + ) + if hasattr(agent, "rebuild_sys_prompt"): + agent.rebuild_sys_prompt() + else: + logger.warning( + "Agent %s has no rebuild_sys_prompt method", + getattr(agent, "agent_id", "unknown"), + ) + except Exception as e: + logger.error("Workspace watch hook failed: %s", e, exc_info=True) + + return None + + class MemoryCompactionHook(Hook): """Hook for automatic memory compaction when context is full. This hook monitors the token count of messages and triggers compaction when it exceeds the threshold. 
It preserves the system prompt and recent messages while summarizing older conversation history. + + Based on CoPaw's memory compaction design with additional improvements: + - memory_compact_ratio: Ratio to compact when threshold reached + - memory_reserve_ratio: Always keep a reserve of tokens for recent messages + - enable_tool_result_compact: Compact tool results separately + - tool_result_compact_keep_n: Number of tool results to keep """ def __init__( self, memory_manager: Any, memory_compact_threshold: Optional[int] = None, - memory_compact_reserve: Optional[int] = None, + memory_compact_ratio: float = 0.75, + memory_reserve_ratio: float = 0.1, enable_tool_result_compact: bool = False, tool_result_compact_keep_n: int = 5, ): @@ -305,13 +406,15 @@ class MemoryCompactionHook(Hook): Args: memory_manager: Memory manager instance for compaction memory_compact_threshold: Token threshold for compaction - memory_compact_reserve: Reserve tokens for recent messages + memory_compact_ratio: Target ratio to compact to (e.g., 0.75 = compact to 75%) + memory_reserve_ratio: Reserve ratio to always keep free (e.g., 0.1 = 10%) enable_tool_result_compact: Enable tool result compaction tool_result_compact_keep_n: Number of tool results to keep """ self.memory_manager = memory_manager self.memory_compact_threshold = memory_compact_threshold - self.memory_compact_reserve = memory_compact_reserve + self.memory_compact_ratio = memory_compact_ratio + self.memory_reserve_ratio = memory_reserve_ratio self.enable_tool_result_compact = enable_tool_result_compact self.tool_result_compact_keep_n = tool_result_compact_keep_n @@ -382,32 +485,61 @@ class MemoryCompactionHook(Hook): ) -> None: """Compact memory by summarizing older messages. 
+ Uses CoPaw-style memory management: + - memory_compact_ratio: Target ratio to compact to (e.g., 0.75 means compact to 75%) + - memory_reserve_ratio: Always keep this ratio free (e.g., 0.1 means keep 10% for recent) + Args: agent: The agent instance messages: Current messages in memory """ - if self.memory_compact_reserve is None: + if self.memory_compact_threshold is None: return - # Keep recent messages - keep_count = min( - len(messages) // 4, - 10, # Max 10 recent messages - ) - keep_count = max(keep_count, 2) # At least 2 + # Estimate total tokens + total_tokens = self._estimate_tokens(messages) - messages_to_compact = messages[:-keep_count] if keep_count < len(messages) else [] + # Calculate reserve based on ratio (CoPaw-style) + reserve_tokens = int(total_tokens * self.memory_reserve_ratio) + + # Calculate target tokens after compaction + target_tokens = int(total_tokens * self.memory_compact_ratio) + target_tokens = max(target_tokens, total_tokens - reserve_tokens) + + # Find messages to compact (older ones) + # Keep recent messages that fit within target + messages_to_compact = [] + kept_tokens = 0 + + # Start from oldest, stop when we've kept enough + for msg in messages: + msg_tokens = self._estimate_tokens([msg]) + if kept_tokens + msg_tokens > target_tokens: + messages_to_compact.append(msg) + else: + kept_tokens += msg_tokens if not messages_to_compact: return + logger.info( + "Compacting %d messages (%d tokens) to target %d tokens", + len(messages_to_compact), + self._estimate_tokens(messages_to_compact), + target_tokens, + ) + # Use memory manager to compact if available if hasattr(self.memory_manager, "compact_memory"): try: summary = await self.memory_manager.compact_memory( messages=messages_to_compact, ) - logger.info("Memory compacted: %d messages summarized", len(messages_to_compact)) + logger.info( + "Memory compacted: %d messages summarized, summary: %s", + len(messages_to_compact), + summary[:200] if summary else "N/A", + ) # Mark messages 
as compressed if supported if hasattr(agent.memory, "update_messages_mark"): @@ -420,6 +552,142 @@ class MemoryCompactionHook(Hook): except Exception as e: logger.error("Memory manager compaction failed: %s", e) + # Tool result compaction (CoPaw-style) + if self.enable_tool_result_compact: + await self._compact_tool_results(agent, messages) + + async def _compact_tool_results( + self, + agent: "ReActAgent", + messages: List[Any], + ) -> None: + """Compact tool results by keeping only recent ones. + + Based on CoPaw's tool_result_compact_keep_n pattern. + Tool results can be very verbose, so we keep only the N most recent ones. + + Args: + agent: The agent instance + messages: Current messages in memory + """ + if not hasattr(agent.memory, "content"): + return + + # Find tool result messages (usually have "tool" role or tool_related content) + tool_results = [] + for msg, _ in agent.memory.content: + if hasattr(msg, "role") and msg.role == "tool": + tool_results.append(msg) + + if len(tool_results) <= self.tool_result_compact_keep_n: + return + + # Keep only the most recent N tool results + excess_results = tool_results[:-self.tool_result_compact_keep_n] + + logger.info( + "Tool result compaction: %d tool results found, keeping %d, compacting %d", + len(tool_results), + self.tool_result_compact_keep_n, + len(excess_results), + ) + + # Mark excess tool results as compressed if supported + if hasattr(agent.memory, "update_messages_mark"): + from agentscope.agent._react_agent import _MemoryMark + await agent.memory.update_messages_mark( + new_mark=_MemoryMark.COMPRESSED, + msg_ids=[msg.id for msg in excess_results], + ) + + +class HeartbeatHook(Hook): + """Pre-reasoning hook that injects HEARTBEAT.md content. + + Reads the agent's HEARTBEAT.md file and prepends it to the + reasoning input, causing the agent to perform self-checks. + + This enables "主动检查" (proactive monitoring) - periodic + market condition and position checks during trading hours. 
+ """ + + HEARTBEAT_FILE = "HEARTBEAT.md" + + def __init__(self, workspace_dir: Path): + """Initialize heartbeat hook. + + Args: + workspace_dir: Working directory containing HEARTBEAT.md + """ + self.workspace_dir = Path(workspace_dir) + self._completed_flag = self.workspace_dir / ".heartbeat_completed" + + def _read_heartbeat_content(self) -> Optional[str]: + """Read HEARTBEAT.md if it exists and is non-empty. + + Returns: + The HEARTBEAT.md content stripped of whitespace, or None + if the file is absent or empty. + """ + hb_path = self.workspace_dir / self.HEARTBEAT_FILE + if not hb_path.exists(): + return None + content = hb_path.read_text(encoding="utf-8").strip() + return content if content else None + + async def __call__( + self, + agent: "ReActAgent", + kwargs: Dict[str, Any], + ) -> Optional[Dict[str, Any]]: + """Prepend heartbeat task to user message. + + Args: + agent: The agent instance + kwargs: Input arguments to the _reasoning method + + Returns: + Modified kwargs with heartbeat content prepended, or None + if no HEARTBEAT.md content is available. 
+ """ + try: + content = self._read_heartbeat_content() + if not content: + return None + + logger.debug( + "Heartbeat: found HEARTBEAT.md for agent %s", + getattr(agent, "agent_id", "unknown"), + ) + + # Build heartbeat task instruction (Chinese) + hb_task = ( + "# 定期主动检查\n\n" + f"{content}\n\n" + "请执行上述检查并报告结果。" + ) + + # Inject into the first user message in memory + if hasattr(agent, "memory") and agent.memory.content: + system_count = sum( + 1 for msg, _ in agent.memory.content if msg.role == "system" + ) + for msg, _ in agent.memory.content[system_count:]: + if msg.role == "user": + original_content = msg.content + msg.content = hb_task + "\n\n" + original_content + break + + logger.debug( + "Heartbeat task prepended for agent %s", + getattr(agent, "agent_id", "unknown"), + ) + + except Exception as e: + logger.error("Heartbeat hook failed: %s", e, exc_info=True) + + return None + __all__ = [ "Hook", @@ -428,5 +696,7 @@ __all__ = [ "HOOK_PRE_REASONING", "HOOK_POST_ACTING", "BootstrapHook", + "HeartbeatHook", "MemoryCompactionHook", + "WorkspaceWatchHook", ] diff --git a/backend/agents/factory.py b/backend/agents/factory.py index ac3db82..28794c5 100644 --- a/backend/agents/factory.py +++ b/backend/agents/factory.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Agent Factory - Dynamic creation and management of EvoAgents.""" +import logging import shutil from dataclasses import dataclass from pathlib import Path @@ -8,6 +9,8 @@ from typing import Any, Dict, List, Optional import yaml +logger = logging.getLogger(__name__) + @dataclass class ModelConfig: @@ -342,9 +345,8 @@ class AgentFactory: "agent_type": config.get("agent_type", "unknown"), "config_path": str(config_path), }) - except Exception: - # Skip invalid agent configs - pass + except Exception as e: + logger.warning(f"Failed to load agent config {config_path}: {e}") return agents diff --git a/backend/agents/portfolio_manager.py b/backend/agents/portfolio_manager.py index 7a3444b..473f2d4 100644 --- 
a/backend/agents/portfolio_manager.py +++ b/backend/agents/portfolio_manager.py @@ -4,7 +4,8 @@ Portfolio Manager Agent - Based on AgentScope ReActAgent Responsible for decision-making (NOT trade execution) """ -from typing import Any, Dict, Optional +from pathlib import Path +from typing import Any, Dict, Optional, Callable from agentscope.agent import ReActAgent from agentscope.memory import InMemoryMemory, LongTermMemoryBase @@ -13,6 +14,8 @@ from agentscope.tool import Toolkit, ToolResponse from ..utils.progress import progress from .prompt_factory import build_agent_system_prompt, clear_prompt_factory_cache +from .team_pipeline_config import update_active_analysts +from ..config.constants import ANALYST_TYPES class PMAgent(ReActAgent): @@ -61,6 +64,8 @@ class PMAgent(ReActAgent): "_toolkit_factory_kwargs", toolkit_factory_kwargs, ) + object.__setattr__(self, "_create_team_agent_cb", None) + object.__setattr__(self, "_remove_team_agent_cb", None) # Create toolkit after local state is ready so bound tool methods can be registered. if toolkit is None: @@ -152,6 +157,107 @@ class PMAgent(ReActAgent): ], ) + def _add_team_analyst(self, agent_id: str) -> ToolResponse: + """Add one analyst to active discussion team.""" + config_name = self.config.get("config_name", "default") + project_root = Path(__file__).resolve().parents[2] + active = update_active_analysts( + project_root=project_root, + config_name=config_name, + available_analysts=list(ANALYST_TYPES.keys()), + add=[agent_id], + ) + return ToolResponse( + content=[ + TextBlock( + type="text", + text=( + f"Active analyst team updated. Added: {agent_id}. 
" + f"Current active analysts: {', '.join(active)}" + ), + ), + ], + ) + + def _remove_team_analyst(self, agent_id: str) -> ToolResponse: + """Remove one analyst from active discussion team.""" + callback_msg = "" + callback = self._remove_team_agent_cb + if callback is not None: + callback_msg = callback(agent_id=agent_id) + + config_name = self.config.get("config_name", "default") + project_root = Path(__file__).resolve().parents[2] + active = update_active_analysts( + project_root=project_root, + config_name=config_name, + available_analysts=list(ANALYST_TYPES.keys()), + remove=[agent_id], + ) + return ToolResponse( + content=[ + TextBlock( + type="text", + text=( + f"Active analyst team updated. Removed: {agent_id}. " + f"Current active analysts: {', '.join(active)}" + + (f" | {callback_msg}" if callback_msg else "") + ), + ), + ], + ) + + def _set_active_analysts(self, agent_ids: str) -> ToolResponse: + """Set active analysts from comma-separated agent ids.""" + requested = [ + item.strip() for item in str(agent_ids or "").split(",") if item.strip() + ] + config_name = self.config.get("config_name", "default") + project_root = Path(__file__).resolve().parents[2] + active = update_active_analysts( + project_root=project_root, + config_name=config_name, + available_analysts=list(ANALYST_TYPES.keys()), + set_to=requested, + ) + return ToolResponse( + content=[ + TextBlock( + type="text", + text=f"Active analyst team set to: {', '.join(active)}", + ), + ], + ) + + def _create_team_analyst(self, agent_id: str, analyst_type: str) -> ToolResponse: + """Create a runtime analyst instance and activate it.""" + callback = self._create_team_agent_cb + if callback is None: + return ToolResponse( + content=[ + TextBlock( + type="text", + text="Runtime agent creation is not available in current pipeline.", + ), + ], + ) + result = callback(agent_id=agent_id, analyst_type=analyst_type) + return ToolResponse( + content=[ + TextBlock(type="text", text=result), + ], + ) + + def 
set_team_controller( + self, + *, + create_agent_callback: Optional[Callable[..., str]] = None, + remove_agent_callback: Optional[Callable[..., str]] = None, + ) -> None: + """Inject runtime team lifecycle callbacks from pipeline.""" + object.__setattr__(self, "_create_team_agent_cb", create_agent_callback) + object.__setattr__(self, "_remove_team_agent_cb", remove_agent_callback) + async def reply(self, x: Msg = None) -> Msg: """ Make investment decisions diff --git a/backend/agents/prompt_factory.py b/backend/agents/prompt_factory.py index 187b83f..6e9431d 100644 --- a/backend/agents/prompt_factory.py +++ b/backend/agents/prompt_factory.py @@ -50,7 +50,13 @@ def build_agent_system_prompt( toolkit: Any, analyst_type: Optional[str] = None, ) -> str: - """Build the final system prompt for an agent.""" + """Build the final system prompt for an agent. + + Always reads fresh from disk — no caching. + """ + # Clear any cached templates before building (CoPaw-style, no caching) + _prompt_loader.clear_cache() + sections: list[str] = [] canonical_agent_id = ( "portfolio_manager" diff --git a/backend/agents/prompt_loader.py b/backend/agents/prompt_loader.py index af649bc..53ca308 100644 --- a/backend/agents/prompt_loader.py +++ b/backend/agents/prompt_loader.py @@ -27,10 +27,6 @@ class PromptLoader: else: self.prompts_dir = Path(prompts_dir) - # Cache loaded prompts - self._prompt_cache: Dict[str, str] = {} - self._yaml_cache: Dict[str, Dict] = {} - def load_prompt( self, agent_type: str, @@ -38,37 +34,20 @@ class PromptLoader: variables: Optional[Dict[str, Any]] = None, ) -> str: """ - Load and render Prompt + Load and render Prompt. 
- Args: - agent_type: Agent type (analyst, portfolio_manager, risk_manager) - prompt_name: Prompt file name (without extension) - variables: Variable dictionary for rendering Prompt - - Returns: - Rendered prompt string - - Examples: - loader = PromptLoader() - prompt = loader.load_prompt("analyst", "tool_selection", - {"analyst_persona": "Technical Analyst"}) + No caching — always reads fresh from disk (CoPaw-style). """ - cache_key = f"{agent_type}/{prompt_name}" + prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md" - # Try to load from cache - if cache_key not in self._prompt_cache: - prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md" + if not prompt_path.exists(): + raise FileNotFoundError( + f"Prompt file not found: {prompt_path}\n" + f"Please create the prompt file or check the path.", + ) - if not prompt_path.exists(): - raise FileNotFoundError( - f"Prompt file not found: {prompt_path}\n" - f"Please create the prompt file or check the path.", - ) - - with open(prompt_path, "r", encoding="utf-8") as f: - self._prompt_cache[cache_key] = f.read() - - prompt_template = self._prompt_cache[cache_key] + with open(prompt_path, "r", encoding="utf-8") as f: + prompt_template = f.read() # If variables provided, use simple string replacement if variables: @@ -76,8 +55,6 @@ class PromptLoader: else: rendered = prompt_template - # Smart escaping: escape braces in JSON code blocks - # rendered = self._escape_json_braces(rendered) return rendered def _render_template( @@ -140,45 +117,26 @@ class PromptLoader: config_name: str, ) -> Dict[str, Any]: """ - Load YAML configuration file + Load YAML configuration file. - Args: - agent_type: Agent type - config_name: Configuration file name (without extension) - - Returns: - Configuration dictionary - - Examples: - >>> loader = PromptLoader() - >>> config = loader.load_yaml_config("analyst", "personas") + No caching — always reads fresh from disk (CoPaw-style). 
""" - cache_key = f"{agent_type}/{config_name}" + yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml" - if cache_key not in self._yaml_cache: - yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml" + if not yaml_path.exists(): + raise FileNotFoundError(f"YAML config not found: {yaml_path}") - if not yaml_path.exists(): - raise FileNotFoundError(f"YAML config not found: {yaml_path}") - - with open(yaml_path, "r", encoding="utf-8") as f: - self._yaml_cache[cache_key] = yaml.safe_load(f) - - return self._yaml_cache[cache_key] + with open(yaml_path, "r", encoding="utf-8") as f: + return yaml.safe_load(f) or {} def clear_cache(self): - """Clear cache (for hot reload)""" - self._prompt_cache.clear() - self._yaml_cache.clear() + """No-op — caching removed (CoPaw-style, always fresh reads).""" + pass def reload_prompt(self, agent_type: str, prompt_name: str): - """Reload specified prompt (force cache refresh)""" - cache_key = f"{agent_type}/{prompt_name}" - if cache_key in self._prompt_cache: - del self._prompt_cache[cache_key] + """No-op — caching removed.""" + pass def reload_config(self, agent_type: str, config_name: str): - """Reload specified configuration (force cache refresh)""" - cache_key = f"{agent_type}/{config_name}" - if cache_key in self._yaml_cache: - del self._yaml_cache[cache_key] + """No-op — caching removed.""" + pass diff --git a/backend/agents/skill_metadata.py b/backend/agents/skill_metadata.py index b5fda7b..8175ee1 100644 --- a/backend/agents/skill_metadata.py +++ b/backend/agents/skill_metadata.py @@ -19,6 +19,8 @@ class SkillMetadata: description: str version: str = "" tools: List[str] = field(default_factory=list) + allowed_tools: List[str] = field(default_factory=list) + denied_tools: List[str] = field(default_factory=list) def parse_skill_metadata(skill_dir: Path, source: str) -> SkillMetadata: @@ -60,6 +62,8 @@ def parse_skill_metadata(skill_dir: Path, source: str) -> SkillMetadata: description=description, 
version=str(frontmatter.get("version") or "").strip(), tools=_string_list(frontmatter.get("tools")), + allowed_tools=_string_list(frontmatter.get("allowed_tools")), + denied_tools=_string_list(frontmatter.get("denied_tools")), ) diff --git a/backend/agents/skills_manager.py b/backend/agents/skills_manager.py index 4e522a8..e922400 100644 --- a/backend/agents/skills_manager.py +++ b/backend/agents/skills_manager.py @@ -3,14 +3,29 @@ from pathlib import Path import shutil -from typing import Dict, Iterable, List +import tempfile +import zipfile +from typing import Any, Dict, Iterable, Iterator, List, Optional, Set +from urllib.parse import urlparse +from urllib.request import urlretrieve import yaml from backend.agents.agent_workspace import load_agent_workspace_config from backend.agents.skill_metadata import SkillMetadata, parse_skill_metadata +from backend.agents.skill_loader import validate_skill from backend.config.bootstrap_config import get_bootstrap_config_for_run +try: + from watchdog.observers import Observer + from watchdog.events import FileSystemEventHandler, FileSystemEvent + WATCHDOG_AVAILABLE = True +except ImportError: + WATCHDOG_AVAILABLE = False + Observer = None + FileSystemEventHandler = object + FileSystemEvent = object # type: ignore[misc,assignment] + class SkillsManager: """Sync named skills into a run-scoped active skills workspace.""" @@ -178,6 +193,57 @@ class SkillsManager: ) return skill_dir + def install_external_skill_for_agent( + self, + config_name: str, + agent_id: str, + source: str, + *, + skill_name: str | None = None, + activate: bool = True, + ) -> Dict[str, object]: + """ + Install an external skill into one agent's local skill space. 
+ + Supports: + - local skill directory containing SKILL.md + - local zip archive containing one skill directory + - http(s) URL to zip archive + """ + source_path = self._resolve_external_source_path(source) + skill_dir = self._resolve_external_skill_dir(source_path) + metadata = parse_skill_metadata(skill_dir, source="external") + final_name = _normalize_skill_name(skill_name or metadata.skill_name or skill_dir.name) + if not final_name: + raise ValueError("Could not determine skill name from external source.") + + target_dir = self.get_agent_local_root(config_name, agent_id) / final_name + target_dir.parent.mkdir(parents=True, exist_ok=True) + if target_dir.exists(): + shutil.rmtree(target_dir) + shutil.copytree(skill_dir, target_dir) + + validation = validate_skill(target_dir) + if not validation.get("valid", False): + shutil.rmtree(target_dir, ignore_errors=True) + raise ValueError( + "Installed skill is invalid: " + + "; ".join(validation.get("errors", [])) + ) + + if activate: + self.update_agent_skill_overrides( + config_name=config_name, + agent_id=agent_id, + enable=[final_name], + ) + return { + "skill_name": final_name, + "target_dir": str(target_dir), + "activated": activate, + "warnings": validation.get("warnings", []), + } + def update_agent_local_skill( self, config_name: str, @@ -239,6 +305,58 @@ class SkillsManager: "content": body, } + def _resolve_external_source_path(self, source: str) -> Path: + """Resolve source into a local path; download URL when needed.""" + parsed = urlparse(source) + if parsed.scheme in {"http", "https"}: + suffix = Path(parsed.path).suffix or ".zip" + with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp: + temp_path = Path(tmp.name) + urlretrieve(source, temp_path) + return temp_path + return Path(source).expanduser().resolve() + + def _resolve_external_skill_dir(self, source_path: Path) -> Path: + """Resolve external source path to a skill directory containing SKILL.md.""" + if not source_path.exists(): 
+ raise FileNotFoundError(f"Source does not exist: {source_path}") + + if source_path.is_dir(): + if (source_path / "SKILL.md").exists(): + return source_path + children = [ + item for item in source_path.iterdir() + if item.is_dir() and (item / "SKILL.md").exists() + ] + if len(children) == 1: + return children[0] + raise ValueError( + "Source directory must contain SKILL.md " + "or exactly one child directory containing SKILL.md." + ) + + if source_path.suffix.lower() != ".zip": + raise ValueError("External source file must be a .zip archive.") + + temp_root = Path(tempfile.mkdtemp(prefix="external_skill_")) + with zipfile.ZipFile(source_path, "r") as archive: + archive.extractall(temp_root) + + candidates = [ + item.parent + for item in temp_root.rglob("SKILL.md") + if item.is_file() + ] + unique = [] + for item in candidates: + if item not in unique: + unique.append(item) + if len(unique) != 1: + raise ValueError( + "Zip archive must contain exactly one skill directory with SKILL.md." + ) + return unique[0] + def update_agent_skill_overrides( self, config_name: str, @@ -500,6 +618,7 @@ class SkillsManager: self, config_name: str, agent_defaults: Dict[str, Iterable[str]], + auto_reload: bool = False, ) -> Dict[str, List[Path]]: """Resolve all agent skills into per-agent installed/active workspaces.""" resolved: Dict[str, List[str]] = {} @@ -574,6 +693,9 @@ class SkillsManager: skill_sources=disabled_sources, ) + if auto_reload: + self.watch_active_skills(config_name, agent_defaults) + return active_map def _is_shared_skill(self, skill_name: str) -> bool: @@ -583,6 +705,72 @@ class SkillsManager: return False return True + def watch_active_skills( + self, + config_name: str, + agent_defaults: Dict[str, Iterable[str]], + callback: Optional[Any] = None, + ) -> "_SkillsWatcher": + """Start file system monitoring on active skill directories. + + Args: + config_name: Run configuration name. + agent_defaults: Map of agent_id -> default skill names. 
+ callback: Optional callable invoked on file changes with + (changed_paths: List[Path]). + + Returns: + A _SkillsWatcher instance. Call .stop() to halt monitoring. + """ + if not WATCHDOG_AVAILABLE: + raise ImportError( + "watchdog is required for watch_active_skills. " + "Install it with: pip install watchdog" + ) + + watched_paths: List[Path] = [] + for agent_id in agent_defaults: + active_root = self.get_agent_active_root(config_name, agent_id) + if active_root.exists(): + watched_paths.append(active_root) + local_root = self.get_agent_local_root(config_name, agent_id) + if local_root.exists(): + watched_paths.append(local_root) + + handler = _SkillsChangeHandler(watched_paths, callback) + observer = Observer() + for path in watched_paths: + observer.schedule(handler, str(path), recursive=True) + observer.start() + return _SkillsWatcher(observer, handler) + + def reload_skills_if_changed( + self, + config_name: str, + agent_defaults: Dict[str, Iterable[str]], + ) -> Dict[str, List[Path]]: + """Check for file changes and reload active skills if needed. + + Args: + config_name: Run configuration name. + agent_defaults: Map of agent_id -> default skill names. + + Returns: + Map of agent_id -> list of reloaded skill paths, or empty dict + if no changes were detected. 
+ """ + changed = self._pending_skill_changes.get(config_name) + if not changed: + return {} + + self._pending_skill_changes[config_name] = set() + return self.prepare_active_skills(config_name, agent_defaults) + + # ------------------------------------------------------------------------- + # Internal change-tracking state (populated by _SkillsChangeHandler) + # ------------------------------------------------------------------------- + _pending_skill_changes: Dict[str, Set[Path]] = {} + def _resolve_disabled_skill_names( self, config_name: str, @@ -613,6 +801,53 @@ class SkillsManager: ] +class _SkillsWatcher: + """Handle returned by watch_active_skills; call .stop() to halt monitoring.""" + + def __init__(self, observer: Observer, handler: "_SkillsChangeHandler") -> None: + self._observer = observer + self._handler = handler + + def stop(self) -> None: + """Stop the file system observer.""" + self._observer.stop() + self._observer.join() + + +class _SkillsChangeHandler(FileSystemEventHandler): + """Collects file-change events on skill directories.""" + + def __init__( + self, + watched_paths: List[Path], + callback: Optional[Any] = None, + ) -> None: + super().__init__() + self._watched_paths = watched_paths + self._callback = callback + + def on_any_event(self, event: FileSystemEvent) -> None: + if event.is_directory: + return + src_path = Path(event.src_path) + for watched in self._watched_paths: + if src_path.is_relative_to(watched): + SkillsManager._pending_skill_changes.setdefault( + self._run_id_from_path(src_path), set() + ).add(src_path) + if self._callback: + self._callback([src_path]) + break + + @staticmethod + def _run_id_from_path(path: Path) -> str: + """Infer config_name from a path like runs/{config_name}/skills/active/...""" + parts = path.parts + for i, part in enumerate(parts): + if part == "runs" and i + 1 < len(parts): + return parts[i + 1] + return "default" + def _dedupe_preserve_order(items: Iterable[str]) -> List[str]: result: List[str] 
= [] for item in items: diff --git a/backend/agents/toolkit_factory.py b/backend/agents/toolkit_factory.py index 9feb442..ed0141e 100644 --- a/backend/agents/toolkit_factory.py +++ b/backend/agents/toolkit_factory.py @@ -5,7 +5,7 @@ 以及合并Agent特定工具。 """ -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Set from pathlib import Path import yaml @@ -13,6 +13,7 @@ import yaml from backend.agents.agent_workspace import load_agent_workspace_config from backend.agents.skills_manager import SkillsManager from backend.agents.skill_loader import load_skill_from_dir, get_skill_tools +from backend.agents.skill_metadata import parse_skill_metadata from backend.config.bootstrap_config import get_bootstrap_config_for_run @@ -117,6 +118,26 @@ def _register_portfolio_tool_groups(toolkit: Any, pm_agent: Any) -> None: pm_agent._make_decision, group_name="portfolio_ops", ) + if hasattr(pm_agent, "_add_team_analyst"): + toolkit.register_tool_function( + pm_agent._add_team_analyst, + group_name="portfolio_ops", + ) + if hasattr(pm_agent, "_remove_team_analyst"): + toolkit.register_tool_function( + pm_agent._remove_team_analyst, + group_name="portfolio_ops", + ) + if hasattr(pm_agent, "_set_active_analysts"): + toolkit.register_tool_function( + pm_agent._set_active_analysts, + group_name="portfolio_ops", + ) + if hasattr(pm_agent, "_create_team_analyst"): + toolkit.register_tool_function( + pm_agent._create_team_analyst, + group_name="portfolio_ops", + ) def _register_risk_tool_groups(toolkit: Any) -> None: @@ -223,6 +244,8 @@ def create_agent_toolkit( for skill_dir in active_skill_dirs: toolkit.register_agent_skill(str(skill_dir)) + apply_skill_tool_restrictions(toolkit, active_skill_dirs) + if active_groups: toolkit.update_tool_groups(group_names=active_groups, active=True) @@ -309,6 +332,8 @@ def create_toolkit_from_workspace( for skill_dir in skill_dirs: toolkit.register_agent_skill(str(skill_dir)) + 
apply_skill_tool_restrictions(toolkit, skill_dirs) + # 激活指定的工具组 if active_groups is None: # 从配置中读取 @@ -397,3 +422,96 @@ def refresh_toolkit_skills( for skill_dir in sorted(local_root.iterdir()): if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists(): toolkit.register_agent_skill(str(skill_dir)) + + +def apply_skill_tool_restrictions(toolkit: Any, skill_dirs: List[Path]) -> None: + """Apply per-skill allowed_tools / denied_tools restrictions to a toolkit. + + If a skill specifies allowed_tools, only those tools are accessible when + that skill is active. If a skill specifies denied_tools, those tools are + removed regardless of allowed_tools. Denied tools take precedence. + + This function annotates the toolkit with a _skill_tool_restrictions map + that downstream code can consult when resolving available tools. + + Args: + toolkit: The agentscope Toolkit instance. + skill_dirs: List of skill directory paths to inspect. + """ + restrictions: Dict[str, Dict[str, Set[str]]] = {} + for skill_dir in skill_dirs: + metadata = parse_skill_metadata(skill_dir, source="active") + if not metadata.allowed_tools and not metadata.denied_tools: + continue + restrictions[skill_dir.name] = { + "allowed": set(metadata.allowed_tools), + "denied": set(metadata.denied_tools), + } + if hasattr(toolkit, "agent_skills"): + for skill in toolkit.agent_skills: + skill_name = getattr(skill, "name", "") or "" + if skill_name in restrictions: + setattr( + skill, + "_tool_allowed", + restrictions[skill_name]["allowed"], + ) + setattr( + skill, + "_tool_denied", + restrictions[skill_name]["denied"], + ) + + +def get_skill_effective_tools(skill: Any) -> Optional[Set[str]]: + """Return the effective tool set for a skill after applying restrictions. + + If the skill has no restrictions (no allowed_tools / denied_tools), + returns None to indicate "all tools allowed". + + If allowed_tools is set, returns only those tools minus denied_tools. 
+ If only denied_tools is set, returns all tools minus denied_tools. + + Args: + skill: A skill object previously registered via register_agent_skill. + + Returns: + A set of allowed tool names, or None if unrestricted. + """ + allowed = getattr(skill, "_tool_allowed", None) + denied = getattr(skill, "_tool_denied", set()) + + if allowed is None: + return None + + effective = allowed - denied + return effective + + +def filter_toolkit_by_skill( + toolkit: Any, + skill_name: str, +) -> Set[str]: + """Return the set of tool names that are accessible for a given skill. + + Args: + toolkit: The agentscope Toolkit instance. + skill_name: Name of the skill to query. + + Returns: + Set of allowed tool names, or all registered tool names if unrestricted. + """ + if not hasattr(toolkit, "agent_skills"): + return set() + + for skill in toolkit.agent_skills: + name = getattr(skill, "name", "") or "" + if name != skill_name: + continue + effective = get_skill_effective_tools(skill) + if effective is None: + return set() + return effective + + return set() + diff --git a/backend/agents/workspace.py b/backend/agents/workspace.py index df89677..2f3872b 100644 --- a/backend/agents/workspace.py +++ b/backend/agents/workspace.py @@ -1,12 +1,15 @@ # -*- coding: utf-8 -*- """Workspace Manager - Create and manage agent workspaces.""" +import logging from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional import yaml +logger = logging.getLogger(__name__) + @dataclass class WorkspaceConfig: @@ -123,9 +126,8 @@ class WorkspaceRegistry: with open(config_path, "r", encoding="utf-8") as f: data = yaml.safe_load(f) or {} workspaces.append(WorkspaceConfig.from_dict(data)) - except Exception: - # Skip invalid workspace configs - pass + except Exception as e: + logger.warning(f"Failed to load workspace config {config_path}: {e}") return workspaces @@ -167,9 +169,8 @@ class WorkspaceRegistry: "agent_type": config.get("agent_type", "unknown"), 
"config_path": str(config_path), }) - except Exception: - # Skip invalid agent configs - pass + except Exception as e: + logger.warning(f"Failed to load agent config {config_path}: {e}") return agents @@ -294,8 +295,8 @@ class WorkspaceRegistry: try: with open(config_path, "r", encoding="utf-8") as f: current_config = yaml.safe_load(f) or {} - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to load existing config {config_path}: {e}") # Update fields if name is not None: diff --git a/backend/agents/workspace_manager.py b/backend/agents/workspace_manager.py index 043fe36..57509fc 100644 --- a/backend/agents/workspace_manager.py +++ b/backend/agents/workspace_manager.py @@ -7,6 +7,7 @@ from typing import Dict, Iterable, Optional import yaml from .skills_manager import SkillsManager +from .team_pipeline_config import ensure_team_pipeline_config class RunWorkspaceManager: @@ -23,6 +24,16 @@ class RunWorkspaceManager: run_dir = self.get_run_dir(config_name) run_dir.mkdir(parents=True, exist_ok=True) self.skills_manager.ensure_activation_manifest(config_name) + ensure_team_pipeline_config( + project_root=self.project_root, + config_name=config_name, + default_analysts=[ + "fundamentals_analyst", + "technical_analyst", + "sentiment_analyst", + "valuation_analyst", + ], + ) bootstrap_path = run_dir / "BOOTSTRAP.md" if not bootstrap_path.exists(): bootstrap_path.write_text( diff --git a/backend/api/agents.py b/backend/api/agents.py index 26ae94a..73a562a 100644 --- a/backend/api/agents.py +++ b/backend/api/agents.py @@ -4,15 +4,20 @@ Agent API Routes Provides REST API endpoints for agent management within workspaces. 
""" -from typing import Any, Dict, List, Optional +import logging +import os +import tempfile from pathlib import Path +from typing import Any, Dict, List, Optional -from fastapi import APIRouter, HTTPException, Depends, Body +from fastapi import APIRouter, HTTPException, Depends, Body, UploadFile, File, Form from pydantic import BaseModel, Field from backend.agents import AgentFactory, WorkspaceManager, get_registry from backend.agents.skills_manager import SkillsManager +logger = logging.getLogger(__name__) + router = APIRouter(prefix="/api/workspaces/{workspace_id}/agents", tags=["agents"]) @@ -35,6 +40,13 @@ class UpdateAgentRequest(BaseModel): disabled_skills: Optional[List[str]] = None +class InstallExternalSkillRequest(BaseModel): + """Request to install an external skill for one agent.""" + source: str = Field(..., description="Directory path, zip path, or http(s) zip URL") + name: Optional[str] = Field(None, description="Optional override skill name") + activate: bool = Field(True, description="Whether to enable skill immediately") + + class AgentResponse(BaseModel): """Agent information response.""" agent_id: str @@ -344,6 +356,86 @@ async def disable_skill( } +@router.post("/{agent_id}/skills/install") +async def install_external_skill( + workspace_id: str, + agent_id: str, + request: InstallExternalSkillRequest, + registry=Depends(get_registry), +): + """Install an external skill into one agent's local skills.""" + agent_info = registry.get(agent_id) + if not agent_info or agent_info.workspace_id != workspace_id: + raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found") + + skills_manager = SkillsManager() + try: + result = skills_manager.install_external_skill_for_agent( + config_name=workspace_id, + agent_id=agent_id, + source=request.source, + skill_name=request.name, + activate=request.activate, + ) + except (FileNotFoundError, ValueError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + return { + "message": 
f"Installed external skill '{result['skill_name']}' for '{agent_id}'", + **result, + } + + +@router.post("/{agent_id}/skills/upload") +async def upload_external_skill( + workspace_id: str, + agent_id: str, + file: UploadFile = File(...), + name: Optional[str] = Form(None), + activate: bool = Form(True), + registry=Depends(get_registry), +): + """Upload a zip skill package from frontend and install for one agent.""" + agent_info = registry.get(agent_id) + if not agent_info or agent_info.workspace_id != workspace_id: + raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found") + + original_name = (file.filename or "").strip() + if not original_name.lower().endswith(".zip"): + raise HTTPException(status_code=400, detail="Uploaded file must be a .zip archive") + + suffix = Path(original_name).suffix or ".zip" + temp_path: Optional[str] = None + try: + with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: + temp_path = tmp.name + content = await file.read() + tmp.write(content) + + skills_manager = SkillsManager() + result = skills_manager.install_external_skill_for_agent( + config_name=workspace_id, + agent_id=agent_id, + source=temp_path, + skill_name=name, + activate=activate, + ) + except (FileNotFoundError, ValueError) as exc: + raise HTTPException(status_code=400, detail=str(exc)) + finally: + try: + await file.close() + except Exception as e: + logger.warning(f"Failed to close uploaded file: {e}") + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + return { + "message": f"Uploaded and installed external skill '{result['skill_name']}' for '{agent_id}'", + **result, + } + + @router.get("/{agent_id}/files/{filename}", response_model=AgentFileResponse) async def get_agent_file( workspace_id: str, diff --git a/backend/api/runtime.py b/backend/api/runtime.py index 57dae78..6f2b7a2 100644 --- a/backend/api/runtime.py +++ b/backend/api/runtime.py @@ -1,19 +1,25 @@ # -*- coding: utf-8 -*- -"""Runtime API routes 
exposing the latest trading run state.""" +"""Runtime API routes - Control Plane for managing Gateway processes.""" from __future__ import annotations import asyncio import json +import logging +import os +import signal +import subprocess +import sys from datetime import datetime from pathlib import Path from typing import Any, Dict, List, Optional +logger = logging.getLogger(__name__) + from fastapi import APIRouter, HTTPException, BackgroundTasks from pydantic import BaseModel, Field from backend.runtime.agent_runtime import AgentRuntimeState -from backend.runtime.context import TradingRunContext from backend.runtime.manager import TradingRuntimeManager, get_global_runtime_manager router = APIRouter(prefix="/api/runtime", tags=["runtime"]) @@ -21,9 +27,9 @@ router = APIRouter(prefix="/api/runtime", tags=["runtime"]) runtime_manager: Optional[TradingRuntimeManager] = None PROJECT_ROOT = Path(__file__).resolve().parents[2] -# Global task reference for running pipeline -_running_task: Optional[asyncio.Task] = None -_stop_event: Optional[asyncio.Event] = None +# Gateway process management +_gateway_process: Optional[subprocess.Popen] = None +_gateway_port: int = 8765 class RunContextResponse(BaseModel): @@ -67,12 +73,15 @@ class LaunchConfig(BaseModel): mode: str = Field(default="live", description="运行模式: live, backtest") start_date: Optional[str] = Field(default=None, description="回测开始日期 YYYY-MM-DD") end_date: Optional[str] = Field(default=None, description="回测结束日期 YYYY-MM-DD") + poll_interval: int = Field(default=10, ge=1, le=300, description="市场数据轮询间隔(秒)") + enable_mock: bool = Field(default=False, description="是否启用模拟模式(使用模拟价格数据)") class LaunchResponse(BaseModel): run_id: str status: str run_dir: str + gateway_port: int message: str @@ -81,10 +90,10 @@ class StopResponse(BaseModel): message: str -class RestartResponse(BaseModel): - run_id: str - status: str - message: str +class GatewayStatusResponse(BaseModel): + is_running: bool + port: int + run_id: 
Optional[str] = None def _generate_run_id() -> str: @@ -97,44 +106,92 @@ def _get_run_dir(run_id: str) -> Path: return PROJECT_ROOT / "runs" / run_id -def _latest_snapshot_path() -> Optional[Path]: - candidates = sorted( - PROJECT_ROOT.glob("runs/*/state/runtime_state.json"), - key=lambda path: path.stat().st_mtime, - reverse=True, +def _find_available_port(start_port: int = 8765, max_port: int = 9000) -> int: + """Find an available port for Gateway.""" + import socket + for port in range(start_port, max_port): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + if s.connect_ex(('localhost', port)) != 0: + return port + raise RuntimeError("No available port found") + + +def _is_gateway_running() -> bool: + """Check if Gateway process is running.""" + global _gateway_process + if _gateway_process is None: + return False + return _gateway_process.poll() is None + + +def _stop_gateway() -> bool: + """Stop the Gateway process.""" + global _gateway_process + if _gateway_process is None: + return False + + try: + # Try graceful shutdown first + _gateway_process.terminate() + try: + _gateway_process.wait(timeout=5) + except subprocess.TimeoutExpired: + # Force kill if graceful shutdown fails + _gateway_process.kill() + _gateway_process.wait() + except Exception as e: + logger.warning(f"Error during gateway shutdown: {e}") + finally: + _gateway_process = None + + return True + + +def _start_gateway_process( + run_id: str, + run_dir: Path, + bootstrap: Dict[str, Any], + port: int +) -> subprocess.Popen: + """Start Gateway as a separate process.""" + # Prepare environment + env = os.environ.copy() + + # Create command arguments + cmd = [ + sys.executable, + "-m", "backend.gateway_server", + "--run-id", run_id, + "--run-dir", str(run_dir), + "--port", str(port), + "--bootstrap", json.dumps(bootstrap) + ] + + # Start process + process = subprocess.Popen( + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=PROJECT_ROOT ) - return 
candidates[0] if candidates else None - -def _load_snapshot() -> Dict[str, Any]: - snapshot_path = _latest_snapshot_path() - if snapshot_path is None or not snapshot_path.exists(): - raise HTTPException(status_code=503, detail="runtime manager is not initialized") - return json.loads(snapshot_path.read_text(encoding="utf-8")) - - -def _get_runtime_payload() -> Dict[str, Any]: - if runtime_manager is not None: - return runtime_manager.build_snapshot() - return _load_snapshot() - - -def _to_state_response(state: AgentRuntimeState) -> RuntimeAgentState: - return RuntimeAgentState( - agent_id=state.agent_id, - status=state.status, - last_session=state.last_session, - last_updated=state.last_updated.isoformat(), - ) + return process @router.get("/context", response_model=RunContextResponse) async def get_run_context() -> RunContextResponse: """Return the most recent run context.""" - payload = _get_runtime_payload() - context = payload.get("context") + snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") + snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) + + if not snapshots: + raise HTTPException(status_code=404, detail="No run context available") + + latest = json.loads(snapshots[0].read_text(encoding="utf-8")) + context = latest.get("context") if context is None: - raise HTTPException(status_code=404, detail="run context is not ready") + raise HTTPException(status_code=404, detail="Run context is not ready") return RunContextResponse( config_name=context["config_name"], @@ -144,88 +201,74 @@ async def get_run_context() -> RunContextResponse: @router.get("/agents", response_model=RuntimeAgentsResponse) -async def list_agent_states() -> RuntimeAgentsResponse: - """List the current runtime state of every registered agent.""" - payload = _get_runtime_payload() - agents = [RuntimeAgentState(**agent) for agent in payload.get("agents", [])] - return RuntimeAgentsResponse(agents=agents) +async def get_runtime_agents() -> 
RuntimeAgentsResponse: + """Return agent states from the most recent run.""" + snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") + snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) + + if not snapshots: + raise HTTPException(status_code=404, detail="No runtime state available") + + latest = json.loads(snapshots[0].read_text(encoding="utf-8")) + agents = latest.get("agents", []) + + return RuntimeAgentsResponse( + agents=[RuntimeAgentState(**a) for a in agents] + ) @router.get("/events", response_model=RuntimeEventsResponse) -async def list_runtime_events() -> RuntimeEventsResponse: - """Return the recent runtime events that TradingRuntimeManager emitted.""" - payload = _get_runtime_payload() - events = [RuntimeEvent(**event) for event in payload.get("events", [])] - return RuntimeEventsResponse(events=events) +async def get_runtime_events() -> RuntimeEventsResponse: + """Return events from the most recent run.""" + snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") + snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) + if not snapshots: + raise HTTPException(status_code=404, detail="No runtime state available") -@router.get("/agents/{agent_id}", response_model=RuntimeAgentState) -async def get_agent_state(agent_id: str) -> RuntimeAgentState: - """Return the current runtime state for a single agent.""" - payload = _get_runtime_payload() - state = next( - (agent for agent in payload.get("agents", []) if agent["agent_id"] == agent_id), - None, + latest = json.loads(snapshots[0].read_text(encoding="utf-8")) + events = latest.get("events", []) + + return RuntimeEventsResponse( + events=[RuntimeEvent(**e) for e in events] ) - if state is None: - raise HTTPException(status_code=404, detail=f"agent '{agent_id}' not registered") - return RuntimeAgentState(**state) -def register_runtime_manager(manager: TradingRuntimeManager) -> None: - """Allow other modules to expose the 
runtime manager to the API.""" - global runtime_manager - runtime_manager = manager +@router.get("/gateway/status", response_model=GatewayStatusResponse) +async def get_gateway_status() -> GatewayStatusResponse: + """Get Gateway process status and port.""" + global _gateway_port + is_running = _is_gateway_running() + run_id = None -def unregister_runtime_manager() -> None: - """Drop the runtime manager reference (used for shutdown/testing).""" - global runtime_manager - runtime_manager = None - - -async def _stop_current_runtime(force: bool = True) -> bool: - """Stop the current running runtime if exists. - - Args: - force: If True, cancel the running task immediately - - Returns: - True if a runtime was stopped, False if no runtime was running - """ - global _running_task, _stop_event - - # Signal stop - if _stop_event is not None: - _stop_event.set() - - # Cancel running task - if _running_task is not None and not _running_task.done(): - if force: - _running_task.cancel() + if is_running: + # Try to find run_id from runtime state + snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") + snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) + if snapshots: try: - await _running_task - except asyncio.CancelledError: - pass - else: - # Wait for graceful shutdown - try: - await asyncio.wait_for(_running_task, timeout=30.0) - except asyncio.TimeoutError: - _running_task.cancel() - try: - await _running_task - except asyncio.CancelledError: - pass + latest = json.loads(snapshots[0].read_text(encoding="utf-8")) + run_id = latest.get("context", {}).get("config_name") + except Exception as e: + logger.warning(f"Failed to parse latest snapshot: {e}") - _running_task = None - _stop_event = None + return GatewayStatusResponse( + is_running=is_running, + port=_gateway_port, + run_id=run_id + ) - # Unregister runtime manager - if runtime_manager is not None: - unregister_runtime_manager() - return True +@router.get("/gateway/port") 
+async def get_gateway_port() -> Dict[str, Any]: + """Get WebSocket Gateway port for frontend connection.""" + global _gateway_port + return { + "port": _gateway_port, + "is_running": _is_gateway_running(), + "ws_url": f"ws://localhost:{_gateway_port}" + } @router.post("/start", response_model=LaunchResponse) @@ -235,13 +278,18 @@ async def start_runtime( ) -> LaunchResponse: """Start a new trading runtime with the given configuration. - If a runtime is already running, it will be forcefully stopped first. - Creates a new timestamped run directory. + 1. Stop existing Gateway if running + 2. Generate run ID and directory + 3. Create runtime manager + 4. Start Gateway as subprocess (Data Plane) + 5. Return Gateway port for WebSocket connection """ - global _running_task, _stop_event, runtime_manager + global _gateway_process, _gateway_port - # 1. Stop current runtime if exists - await _stop_current_runtime(force=True) + # 1. Stop existing Gateway + if _is_gateway_running(): + _stop_gateway() + await asyncio.sleep(1) # Wait for port release # 2. Generate run ID and directory run_id = _generate_run_id() @@ -260,92 +308,136 @@ async def start_runtime( "mode": config.mode, "start_date": config.start_date, "end_date": config.end_date, + "poll_interval": config.poll_interval, + "enable_mock": config.enable_mock, } - # 4. Create and prepare runtime manager - runtime_manager = TradingRuntimeManager( + # 4. Create runtime manager + manager = TradingRuntimeManager( config_name=run_id, run_dir=run_dir, bootstrap=bootstrap, ) - runtime_manager.prepare_run() - set_global_runtime_manager = None # Will be set by main module + manager.prepare_run() + register_runtime_manager(manager) # 5. Write BOOTSTRAP.md _write_bootstrap_md(run_dir, bootstrap) - # 6. Start pipeline in background - _stop_event = asyncio.Event() - _running_task = asyncio.create_task( - _run_pipeline(run_id, run_dir, bootstrap, _stop_event) - ) + # 6. 
Find available port and start Gateway process + _gateway_port = _find_available_port(start_port=8765) + + try: + _gateway_process = _start_gateway_process( + run_id=run_id, + run_dir=run_dir, + bootstrap=bootstrap, + port=_gateway_port + ) + + # Wait briefly to check if process started successfully + await asyncio.sleep(2) + + if not _is_gateway_running(): + stdout, stderr = _gateway_process.communicate(timeout=1) + _gateway_process = None + raise HTTPException( + status_code=500, + detail=f"Gateway failed to start: {stderr.decode() if stderr else 'Unknown error'}" + ) + + except Exception as e: + _stop_gateway() + raise HTTPException(status_code=500, detail=f"Failed to start Gateway: {str(e)}") return LaunchResponse( run_id=run_id, status="started", run_dir=str(run_dir), - message=f"Runtime started with run_id: {run_id}", + gateway_port=_gateway_port, + message=f"Runtime started with run_id: {run_id}, Gateway on port: {_gateway_port}", ) @router.post("/stop", response_model=StopResponse) async def stop_runtime(force: bool = True) -> StopResponse: - """Stop the current running runtime. + """Stop the current running runtime.""" + global _gateway_process - Args: - force: If True, forcefully cancel the running task - """ - was_running = await _stop_current_runtime(force=force) + was_running = _is_gateway_running() if not was_running: raise HTTPException(status_code=404, detail="No runtime is currently running") + # Stop Gateway process + _stop_gateway() + + # Unregister runtime manager + unregister_runtime_manager() + return StopResponse( status="stopped", message="Runtime stopped successfully", ) -@router.post("/restart", response_model=RestartResponse) +@router.post("/restart") async def restart_runtime( config: LaunchConfig, background_tasks: BackgroundTasks -) -> RestartResponse: - """Restart the runtime with a new configuration. - - Equivalent to stop + start. 
- """ +): + """Restart the runtime with a new configuration.""" # Stop current runtime - await _stop_current_runtime(force=True) + await stop_runtime(force=True) # Start new runtime response = await start_runtime(config, background_tasks) - return RestartResponse( - run_id=response.run_id, - status="restarted", - message=f"Runtime restarted with run_id: {response.run_id}", - ) + return { + "run_id": response.run_id, + "status": "restarted", + "gateway_port": response.gateway_port, + "message": f"Runtime restarted with run_id: {response.run_id}", + } @router.get("/current") async def get_current_runtime(): """Get information about the currently running runtime.""" - global _running_task, runtime_manager - - is_running = _running_task is not None and not _running_task.done() - - if not is_running or runtime_manager is None: + if not _is_gateway_running(): raise HTTPException(status_code=404, detail="No runtime is currently running") + # Find latest runtime state + snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json") + snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True) + + if not snapshots: + raise HTTPException(status_code=404, detail="No runtime information available") + + latest = json.loads(snapshots[0].read_text(encoding="utf-8")) + context = latest.get("context", {}) + return { - "run_id": runtime_manager.config_name, - "run_dir": str(runtime_manager.run_dir), - "is_running": is_running, - "bootstrap": runtime_manager.bootstrap, + "run_id": context.get("config_name"), + "run_dir": context.get("run_dir"), + "is_running": True, + "gateway_port": _gateway_port, + "bootstrap": context.get("bootstrap_values", {}), } +def register_runtime_manager(manager: TradingRuntimeManager) -> None: + """Allow other modules to expose the runtime manager to the API.""" + global runtime_manager + runtime_manager = manager + + +def unregister_runtime_manager() -> None: + """Drop the runtime manager reference.""" + global runtime_manager + 
runtime_manager = None + + def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None: """Write bootstrap configuration to BOOTSTRAP.md.""" try: @@ -362,38 +454,7 @@ def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None: if yaml: front_matter = yaml.safe_dump(values, allow_unicode=True, sort_keys=False) else: - # Fallback to JSON if yaml not available front_matter = json.dumps(values, ensure_ascii=False, indent=2) content = f"---\n{front_matter}---\n" bootstrap_path.write_text(content, encoding="utf-8") - - -async def _run_pipeline( - run_id: str, - run_dir: Path, - bootstrap: Dict[str, Any], - stop_event: asyncio.Event -) -> None: - """Background task to run the trading pipeline.""" - import logging - logger = logging.getLogger(__name__) - - from backend.core.pipeline_runner import run_pipeline - - try: - logger.info(f"Starting pipeline for run_id: {run_id}") - await run_pipeline( - run_id=run_id, - run_dir=run_dir, - bootstrap=bootstrap, - stop_event=stop_event, - ) - logger.info(f"Pipeline completed for run_id: {run_id}") - except asyncio.CancelledError: - logger.info(f"Pipeline cancelled for run_id: {run_id}") - raise - except Exception as e: - logger.exception(f"Pipeline failed for run_id: {run_id}: {e}") - # Re-raise to allow proper cleanup - raise diff --git a/backend/cli.py b/backend/cli.py index be23d09..d09bf72 100644 --- a/backend/cli.py +++ b/backend/cli.py @@ -8,6 +8,7 @@ and frontend development server. 
""" # flake8: noqa: E501 # pylint: disable=R0912, R0915 +import logging import os import shutil import subprocess @@ -17,7 +18,10 @@ from pathlib import Path from typing import Optional from zoneinfo import ZoneInfo +logger = logging.getLogger(__name__) + import typer +import yaml from rich.console import Console from rich.panel import Panel from rich.prompt import Confirm @@ -27,7 +31,12 @@ from dotenv import load_dotenv from backend.agents.agent_workspace import load_agent_workspace_config from backend.agents.prompt_loader import PromptLoader from backend.agents.skills_manager import SkillsManager +from backend.agents.team_pipeline_config import ( + ensure_team_pipeline_config, + load_team_pipeline_config, +) from backend.agents.workspace_manager import WorkspaceManager +from backend.config.constants import ANALYST_TYPES from backend.data.market_ingest import ingest_symbols from backend.data.market_store import MarketStore from backend.enrich.llm_enricher import get_explain_model_info, llm_enrichment_enabled @@ -42,6 +51,8 @@ ingest_app = typer.Typer(help="Ingest Polygon market data into the research ware app.add_typer(ingest_app, name="ingest") skills_app = typer.Typer(help="Inspect and manage per-agent skills.") app.add_typer(skills_app, name="skills") +team_app = typer.Typer(help="Inspect and manage run-scoped team pipeline config.") +app.add_typer(team_app, name="team") console = Console() _prompt_loader = PromptLoader() @@ -95,8 +106,8 @@ def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None: ) else: console.print(f" Directory size: [cyan]{size_mb:.1f} MB[/cyan]") - except Exception: - pass + except Exception as e: + logger.debug(f"Could not calculate directory size: {e}") # Show last modified time state_dir = base_data_dir / "state" @@ -197,7 +208,8 @@ def run_data_updater(project_root: Path) -> None: console.print( "[yellow] Data updater module not available, skipping update[/yellow]\n", ) - except Exception: + except Exception as 
e: + logger.debug(f"Data updater check failed: {e}") console.print( "[yellow] Data updater check failed, skipping update[/yellow]\n", ) @@ -777,6 +789,78 @@ def skills_disable( console.print(f"Disabled skills: {', '.join(result['disabled_skills']) or '-'}") +@skills_app.command("install") +def skills_install( + agent_id: str = typer.Option(..., "--agent-id", "-a", help="Target agent id."), + source: str = typer.Option( + ..., + "--source", + "-s", + help="External skill source: directory path, zip path, or http(s) zip URL.", + ), + config_name: str = typer.Option( + "default", + "--config-name", + "-c", + help="Run config name.", + ), + name: Optional[str] = typer.Option( + None, + "--name", + help="Optional override skill name.", + ), + activate: bool = typer.Option( + True, + "--activate/--no-activate", + help="Enable the skill for this agent immediately.", + ), +): + """Install an external skill into one agent's local skill directory.""" + _require_agent_asset_dir(config_name, agent_id) + skills_manager = SkillsManager(project_root=get_project_root()) + result = skills_manager.install_external_skill_for_agent( + config_name=config_name, + agent_id=agent_id, + source=source, + skill_name=name, + activate=activate, + ) + console.print( + f"[green]Installed[/green] `{result['skill_name']}` to `{agent_id}`", + ) + console.print(f"Path: {result['target_dir']}") + console.print(f"Activated: {result['activated']}") + warnings = result.get("warnings") or [] + if warnings: + console.print(f"Warnings: {'; '.join(warnings)}") + + +@team_app.command("show") +def team_show( + config_name: str = typer.Option( + "default", + "--config-name", + "-c", + help="Run config name.", + ), +): + """Show TEAM_PIPELINE.yaml for one run.""" + project_root = get_project_root() + ensure_team_pipeline_config( + project_root=project_root, + config_name=config_name, + default_analysts=list(ANALYST_TYPES.keys()), + ) + config = load_team_pipeline_config(project_root, config_name) + 
console.print( + Panel.fit( + yaml.safe_dump(config, allow_unicode=True, sort_keys=False), + title=f"TEAM_PIPELINE ({config_name})", + border_style="cyan", + ), + ) + + @app.command() def backtest( start: Optional[str] = typer.Option( diff --git a/backend/core/pipeline.py b/backend/core/pipeline.py index 4390214..e96cf3e 100644 --- a/backend/core/pipeline.py +++ b/backend/core/pipeline.py @@ -10,6 +10,8 @@ import json import logging import os import re +from contextlib import nullcontext +from pathlib import Path from typing import Any, Awaitable, Callable, Dict, List, Optional from agentscope.message import Msg @@ -21,6 +23,26 @@ from backend.core.state_sync import StateSync from backend.utils.trade_executor import PortfolioTradeExecutor from backend.runtime.manager import TradingRuntimeManager from backend.runtime.session import TradingSessionKey +from backend.agents.team_pipeline_config import ( + resolve_active_analysts, + update_active_analysts, +) +from backend.agents import AnalystAgent +from backend.agents.toolkit_factory import create_agent_toolkit +from backend.agents.workspace_manager import WorkspaceManager +from backend.agents.prompt_loader import PromptLoader +from backend.llm.models import get_agent_formatter, get_agent_model +from backend.config.constants import ANALYST_TYPES + +# Team infrastructure imports (graceful import - may not exist yet) +try: + from backend.agents.team.team_coordinator import TeamCoordinator + from backend.agents.team.msg_hub import MsgHub as TeamMsgHub + TEAM_COORD_AVAILABLE = True +except ImportError: + TEAM_COORD_AVAILABLE = False + TeamCoordinator = None + TeamMsgHub = None logger = logging.getLogger(__name__) @@ -77,6 +99,13 @@ class TradingPipeline: self.agent_factory = agent_factory self.runtime_manager = runtime_manager self._session_key: Optional[str] = None + self._dynamic_analysts: Dict[str, Any] = {} + + if hasattr(self.pm, "set_team_controller"): + self.pm.set_team_controller( + 
create_agent_callback=self._create_runtime_analyst, + remove_agent_callback=self._remove_runtime_analyst, + ) async def run_cycle( self, @@ -115,16 +144,17 @@ class TradingPipeline: _log(f"Starting cycle {date} - {len(tickers)} tickers") session_key = TradingSessionKey(date=date).key() self._session_key = session_key + active_analysts = self._get_active_analysts() if self.runtime_manager: self.runtime_manager.set_session_key(session_key) self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date}) - self._runtime_batch_status(self.analysts, "analysis_in_progress") + self._runtime_batch_status(active_analysts, "analysis_in_progress") # Phase 0: Clear short-term memory to avoid cross-day context pollution _log("Phase 0: Clearing memory") await self._clear_all_agent_memory() - participants = self.analysts + [self.risk_manager, self.pm] + participants = self._all_analysts() + [self.risk_manager, self.pm] # Single MsgHub for entire cycle - no nesting async with MsgHub( @@ -135,9 +165,13 @@ class TradingPipeline: "system", ), ): - # Phase 1.1: Analysts - _log("Phase 1.1: Analyst analysis") - analyst_results = await self._run_analysts_with_sync(tickers, date) + # Phase 1.1: Analysts (parallel execution with TeamCoordinator) + _log("Phase 1.1: Analyst analysis (parallel)") + analyst_results = await self._run_analysts_parallel( + tickers, + date, + active_analysts=active_analysts, + ) # Phase 1.2: Risk Manager _log("Phase 1.2: Risk assessment") @@ -164,6 +198,7 @@ class TradingPipeline: final_predictions = await self._collect_final_predictions( tickers, date, + active_analysts=active_analysts, ) # Record final predictions for leaderboard ranking @@ -212,7 +247,7 @@ class TradingPipeline: if close_prices and self.settlement_coordinator: _log("Phase 5: Daily review and generate memories") self._runtime_batch_status( - [self.risk_manager] + self.analysts + [self.pm], + [self.risk_manager] + self._all_analysts() + [self.pm], "settlement", ) @@ -246,13 +281,13 @@ 
class TradingPipeline: conference_summary=self.conference_summary, ) self._runtime_batch_status( - [self.risk_manager] + self.analysts + [self.pm], + [self.risk_manager] + self._all_analysts() + [self.pm], "reflection", ) _log(f"Cycle complete: {date}") self._runtime_batch_status( - self.analysts + [self.risk_manager, self.pm], + self._all_analysts() + [self.risk_manager, self.pm], "idle", ) self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date}) @@ -288,7 +323,7 @@ class TradingPipeline: }, ) - for analyst in self.analysts: + for analyst in self._all_analysts(): analyst.reload_runtime_assets( active_skill_dirs=active_skill_map.get(analyst.name, []), ) @@ -302,7 +337,7 @@ class TradingPipeline: return { "config_name": config_name, - "reloaded_agents": [agent.name for agent in self.analysts] + "reloaded_agents": [agent.name for agent in self._all_analysts()] + ["risk_manager", "portfolio_manager"], "active_skills": { agent_id: [path.name for path in paths] @@ -313,7 +348,7 @@ class TradingPipeline: async def _clear_all_agent_memory(self): """Clear short-term memory for all agents""" - for analyst in self.analysts: + for analyst in self._all_analysts(): await analyst.memory.clear() await self.risk_manager.memory.clear() @@ -395,7 +430,7 @@ class TradingPipeline: trajectories = {} # Capture analyst trajectories - for analyst in self.analysts: + for analyst in self._all_analysts(): try: msgs = await analyst.memory.get_memory() if msgs: @@ -605,7 +640,7 @@ class TradingPipeline: ) # Record for analysts - for analyst in self.analysts: + for analyst in self._all_analysts(): if ( hasattr(analyst, "long_term_memory") and analyst.long_term_memory is not None @@ -724,67 +759,82 @@ class TradingPipeline: date=date, ) - # Run discussion cycles (no new MsgHub - use parent's) - for cycle in range(self.max_comm_cycles): + # Conference participants: analysts + PM + conference_participants = self._get_active_analysts() + [self.pm] + + # Use TeamMsgHub for conference 
if available + if TEAM_COORD_AVAILABLE and TeamMsgHub is not None: _log( - "Phase 2.1: Conference discussion - " - f"Conference {cycle + 1}/{self.max_comm_cycles}", + f"Phase 2.1: Conference using TeamMsgHub with " + f"{len(conference_participants)} participants" ) + conference_hub = TeamMsgHub(participants=conference_participants) + else: + _log("Phase 2.1: Conference using standard MsgHub context") + conference_hub = None - if self.state_sync: - await self.state_sync.on_conference_cycle_start( - cycle=cycle + 1, - total_cycles=self.max_comm_cycles, + # Run discussion cycles + async with conference_hub if conference_hub else nullcontext(None): + for cycle in range(self.max_comm_cycles): + _log( + "Phase 2.1: Conference discussion - " + f"Conference {cycle + 1}/{self.max_comm_cycles}", ) - # PM sets agenda or asks questions - pm_prompt = self._build_pm_discussion_prompt( - cycle=cycle, - tickers=tickers, - date=date, - prices=prices, - analyst_results=analyst_results, - risk_assessment=risk_assessment, - ) + if self.state_sync: + await self.state_sync.on_conference_cycle_start( + cycle=cycle + 1, + total_cycles=self.max_comm_cycles, + ) - pm_msg = Msg(name="system", content=pm_prompt, role="user") - pm_response = await self.pm.reply(pm_msg) - - if self.state_sync: - pm_content = self._extract_text_content(pm_response.content) - await self.state_sync.on_conference_message( - agent_id="portfolio_manager", - content=pm_content, - ) - - # Analysts share perspectives - for analyst in self.analysts: - analyst_prompt = self._build_analyst_discussion_prompt( + # PM sets agenda or asks questions + pm_prompt = self._build_pm_discussion_prompt( cycle=cycle, tickers=tickers, date=date, + prices=prices, + analyst_results=analyst_results, + risk_assessment=risk_assessment, ) - analyst_msg = Msg( - name="system", - content=analyst_prompt, - role="user", - ) - analyst_response = await analyst.reply(analyst_msg) + pm_msg = Msg(name="system", content=pm_prompt, role="user") + 
pm_response = await self.pm.reply(pm_msg) if self.state_sync: - analyst_content = self._extract_text_content( - analyst_response.content, - ) + pm_content = self._extract_text_content(pm_response.content) await self.state_sync.on_conference_message( - agent_id=analyst.name, - content=analyst_content, + agent_id="portfolio_manager", + content=pm_content, ) - if self.state_sync: - await self.state_sync.on_conference_cycle_end( - cycle=cycle + 1, - ) + # Analysts share perspectives (supports per-round active team updates) + for analyst in self._get_active_analysts(): + analyst_prompt = self._build_analyst_discussion_prompt( + cycle=cycle, + tickers=tickers, + date=date, + ) + + analyst_msg = Msg( + name="system", + content=analyst_prompt, + role="user", + ) + analyst_response = await analyst.reply(analyst_msg) + + if self.state_sync: + analyst_content = self._extract_text_content( + analyst_response.content, + ) + await self.state_sync.on_conference_message( + agent_id=analyst.name, + content=analyst_content, + ) + + if self.state_sync: + await self.state_sync.on_conference_cycle_end( + cycle=cycle + 1, + ) # Generate conference summary by PM _log( @@ -885,6 +935,7 @@ class TradingPipeline: self, tickers: List[str], date: str, + active_analysts: Optional[List[Any]] = None, ) -> List[Dict[str, Any]]: """ Collect final predictions from all analysts as simple text responses. 
@@ -892,14 +943,15 @@ class TradingPipeline: """ _log( "Phase 2.2: Analysts generate final structured predictions\n" - f" Starting _collect_final_predictions for {len(self.analysts)} analysts", + f" Starting _collect_final_predictions for {len(active_analysts or self.analysts)} analysts", ) final_predictions = [] - for i, analyst in enumerate(self.analysts): + analysts = active_analysts or self.analysts + for i, analyst in enumerate(analysts): _log( "Phase 2.2: Analysts generate final structured predictions\n" - f" Collecting prediction from analyst {i+1}/{len(self.analysts)}: {analyst.name}", + f" Collecting prediction from analyst {i+1}/{len(analysts)}: {analyst.name}", ) prompt = ( @@ -995,11 +1047,13 @@ class TradingPipeline: self, tickers: List[str], date: str, + active_analysts: Optional[List[Any]] = None, ) -> List[Dict[str, Any]]: """Run all analysts with real-time sync after each completion""" results = [] + analysts = active_analysts or self.analysts - for analyst in self.analysts: + for analyst in analysts: content = ( f"Analyze the following stocks for date {date}: {', '.join(tickers)}. " f"Provide investment signals with confidence scores and reasoning." @@ -1029,15 +1083,107 @@ class TradingPipeline: return results + async def _run_analysts_parallel( + self, + tickers: List[str], + date: str, + active_analysts: Optional[List[Any]] = None, + ) -> List[Dict[str, Any]]: + """Run all analysts in parallel using TeamCoordinator. + + This method replaces the sequential analyst loop with parallel execution + using the TeamCoordinator for orchestration. 
+ + Args: + tickers: List of stock tickers to analyze + date: Trading date + active_analysts: Optional list of analysts to run + + Returns: + List of analyst result dictionaries + """ + analysts = active_analysts or self.analysts + + if not analysts: + return [] + + if not TEAM_COORD_AVAILABLE: + _log("TeamCoordinator not available, falling back to sequential execution") + return await self._run_analysts_with_sync( + tickers=tickers, + date=date, + active_analysts=active_analysts, + ) + + _log( + f"Phase 1.1: Running {len(analysts)} analysts in parallel " + f"[{', '.join(a.name for a in analysts)}]" + ) + + # Build the analyst prompt + content = ( + f"Analyze the following stocks for date {date}: {', '.join(tickers)}. " + f"Provide investment signals with confidence scores and reasoning." + ) + + # Create coordinator for parallel execution + coordinator = TeamCoordinator( + participants=analysts, + task_content=content, + ) + + # Run analysts in parallel via TeamCoordinator + results = await coordinator.run_phase( + "analyst_analysis", + metadata={"tickers": tickers, "date": date}, + ) + + # Process results and sync + processed_results = [] + for i, (analyst, result) in enumerate(zip(analysts, results)): + if result is not None: + extracted = self._extract_result_from_msg(result) + processed_results.append(extracted) + + # Sync retrieved memory + await self._sync_memory_if_retrieved(analyst) + + # Broadcast agent result via StateSync + if self.state_sync: + text_content = self._extract_text_content(result.content) + await self.state_sync.on_agent_complete( + agent_id=analyst.name, + content=text_content, + ) + else: + logger.warning( + "Analyst %s returned no result", + analyst.name, + ) + processed_results.append({ + "agent": analyst.name, + "content": "", + "success": False, + }) + + _log( + f"Phase 1.1: Parallel analyst execution complete " + f"({len(processed_results)}/{len(analysts)} successful)" + ) + + return processed_results + async def _run_analysts( 
self, tickers: List[str], date: str, + active_analysts: Optional[List[Any]] = None, ) -> List[Dict[str, Any]]: """Run all analysts (without sync, for backward compatibility)""" results = [] + analysts = active_analysts or self.analysts - for analyst in self.analysts: + for analyst in analysts: content = ( f"Analyze the following stocks for date {date}: {', '.join(tickers)}. " f"Provide investment signals with confidence scores and reasoning." @@ -1461,6 +1607,83 @@ class TradingPipeline: for agent in agents: self._runtime_update_status(agent, status) + def _all_analysts(self) -> List[Any]: + """Return static analysts plus runtime-created analysts.""" + return list(self.analysts) + list(self._dynamic_analysts.values()) + + def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str: + """Create one runtime analyst instance.""" + if analyst_type not in ANALYST_TYPES: + return ( + f"Unknown analyst_type '{analyst_type}'. " + f"Available: {', '.join(ANALYST_TYPES.keys())}" + ) + if agent_id in {agent.name for agent in self._all_analysts()}: + return f"Analyst '{agent_id}' already exists." + + config_name = getattr(self.pm, "config", {}).get("config_name", "default") + project_root = Path(__file__).resolve().parents[2] + personas = PromptLoader().load_yaml_config("analyst", "personas") + persona = personas.get(analyst_type, {}) + WorkspaceManager(project_root=project_root).ensure_agent_assets( + config_name=config_name, + agent_id=agent_id, + role_seed=persona.get("description", "").strip(), + style_seed="\n".join(f"- {item}" for item in persona.get("focus", [])), + policy_seed=( + "State a clear signal, confidence, and the conditions " + "that would invalidate the thesis." 
+ ), + ) + + agent = AnalystAgent( + analyst_type=analyst_type, + toolkit=create_agent_toolkit( + agent_id=agent_id, + config_name=config_name, + active_skill_dirs=[], + ), + model=get_agent_model(analyst_type), + formatter=get_agent_formatter(analyst_type), + agent_id=agent_id, + config={"config_name": config_name}, + ) + self._dynamic_analysts[agent_id] = agent + update_active_analysts( + project_root=project_root, + config_name=config_name, + available_analysts=[item.name for item in self._all_analysts()], + add=[agent_id], + ) + return f"Created runtime analyst '{agent_id}' ({analyst_type})." + + def _remove_runtime_analyst(self, agent_id: str) -> str: + """Remove one runtime-created analyst instance.""" + if agent_id not in self._dynamic_analysts: + return f"Runtime analyst '{agent_id}' not found." + self._dynamic_analysts.pop(agent_id, None) + config_name = getattr(self.pm, "config", {}).get("config_name", "default") + project_root = Path(__file__).resolve().parents[2] + update_active_analysts( + project_root=project_root, + config_name=config_name, + available_analysts=[item.name for item in self._all_analysts()], + remove=[agent_id], + ) + return f"Removed runtime analyst '{agent_id}'." 
+ + def _get_active_analysts(self) -> List[Any]: + """Resolve active analyst participants from run-scoped team pipeline config.""" + config_name = getattr(self.pm, "config", {}).get("config_name", "default") + project_root = Path(__file__).resolve().parents[2] + analyst_map = {agent.name: agent for agent in self._all_analysts()} + active_ids = resolve_active_analysts( + project_root=project_root, + config_name=config_name, + available_analysts=list(analyst_map.keys()), + ) + return [analyst_map[agent_id] for agent_id in active_ids if agent_id in analyst_map] + def _runtime_log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> None: if not self.runtime_manager: return diff --git a/backend/core/pipeline_runner.py b/backend/core/pipeline_runner.py index eedd5d9..76270cb 100644 --- a/backend/core/pipeline_runner.py +++ b/backend/core/pipeline_runner.py @@ -61,7 +61,7 @@ def stop_gateway() -> None: _gateway_instance = None -async def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path): +def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path): """Create ReMeTaskLongTermMemory for an agent.""" try: from agentscope.memory import ReMeTaskLongTermMemory @@ -206,6 +206,13 @@ async def run_pipeline( """ Run the trading pipeline with the given configuration. 
+ Service Startup Order: + Phase 1: WebSocket Server - Frontend can connect + Phase 2: Market Service - Price data starts flowing + Phase 3: Agent Runtime - Create all agents + Phase 4: Pipeline & Scheduler - Trading logic ready + Phase 5: Gateway Fully Operational - All systems running + Args: run_id: Unique run identifier (timestamp) run_dir: Run directory path @@ -219,7 +226,9 @@ async def run_pipeline( # Set global shutdown event set_shutdown_event(stop_event) - logger.info(f"[Pipeline {run_id}] Starting...") + logger.info(f"[Pipeline {run_id}] ======================================") + logger.info(f"[Pipeline {run_id}] Starting with 5-phase initialization...") + logger.info(f"[Pipeline {run_id}] ======================================") try: # Extract config values @@ -230,15 +239,21 @@ async def run_pipeline( schedule_mode = bootstrap.get("schedule_mode", "daily") trigger_time = bootstrap.get("trigger_time", "09:30") interval_minutes = int(bootstrap.get("interval_minutes", 60)) + heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0)) mode = bootstrap.get("mode", "live") start_date = bootstrap.get("start_date") end_date = bootstrap.get("end_date") enable_memory = bootstrap.get("enable_memory", False) + enable_mock = bootstrap.get("enable_mock", False) is_backtest = mode == "backtest" - is_mock = mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true") + is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true") + + # ====================================================================== + # PHASE 0: Initialize runtime manager + # ====================================================================== + logger.info("[Phase 0/5] Initializing runtime manager...") - # Get or create runtime manager from backend.api.runtime import runtime_manager if runtime_manager is None: @@ -255,16 +270,11 @@ async def run_pipeline( from backend.api.runtime import 
register_runtime_manager register_runtime_manager(runtime_manager) - # Create market service - market_service = MarketService( - tickers=tickers, - poll_interval=10, - mock_mode=is_mock and not is_backtest, - backtest_mode=is_backtest, - api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None, - backtest_start_date=start_date if is_backtest else None, - backtest_end_date=end_date if is_backtest else None, - ) + # ====================================================================== + # PHASE 1 & 2: Create infrastructure services (Market, Storage) + # These will be started by Gateway in the correct order + # ====================================================================== + logger.info("[Phase 1-2/5] Creating infrastructure services...") # Create storage service storage_service = StorageService( @@ -278,7 +288,22 @@ async def run_pipeline( else: storage_service.update_leaderboard_model_info() - # Create agents and pipeline + # Create market service (data source) + market_service = MarketService( + tickers=tickers, + poll_interval=10, + mock_mode=is_mock and not is_backtest, + backtest_mode=is_backtest, + api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None, + backtest_start_date=start_date if is_backtest else None, + backtest_end_date=end_date if is_backtest else None, + ) + + # ====================================================================== + # PHASE 3: Create Agent Runtime + # ====================================================================== + logger.info("[Phase 3/5] Creating agent runtime...") + analysts, risk_manager, pm, long_term_memories = create_agents( run_id=run_id, run_dir=run_dir, @@ -303,6 +328,11 @@ async def run_pipeline( initial_capital=initial_cash, ) + # ====================================================================== + # PHASE 4: Create Pipeline & Scheduler + # ====================================================================== + logger.info("[Phase 4/5] Creating 
pipeline and scheduler...") + # Create pipeline pipeline = TradingPipeline( analysts=analysts, @@ -336,6 +366,7 @@ async def run_pipeline( mode=schedule_mode, trigger_time=trigger_time, interval_minutes=interval_minutes, + heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None, config={"config_name": run_id}, ) @@ -344,7 +375,15 @@ async def run_pipeline( scheduler_callback = scheduler_callback_fn - # Create Gateway for WebSocket connections (after pipeline and scheduler are ready) + # ====================================================================== + # PHASE 5: Start Gateway (WebSocket → Market → Scheduler) + # Gateway.start() will handle the final startup sequence: + # - WebSocket Server first (frontend can connect) + # - Market Service second (price data flows) + # - Scheduler last (trading begins) + # ====================================================================== + logger.info("[Phase 5/5] Starting Gateway (WebSocket → Market → Scheduler)...") + gateway = Gateway( market_service=market_service, storage_service=storage_service, @@ -359,6 +398,7 @@ async def run_pipeline( "schedule_mode": schedule_mode, "interval_minutes": interval_minutes, "trigger_time": trigger_time, + "heartbeat_interval": heartbeat_interval, "initial_cash": initial_cash, "margin_requirement": margin_requirement, "max_comm_cycles": max_comm_cycles, @@ -374,13 +414,17 @@ async def run_pipeline( for memory in long_term_memories: await stack.enter_async_context(memory) - # Start Gateway in background task + # Start Gateway - this will execute the 4-phase startup: + # Phase 1: WebSocket Server (frontend can connect immediately) + # Phase 2: Market Service (price updates start flowing) + # Phase 3: Market Status Monitor + # Phase 4: Scheduler (trading cycles begin) gateway_task = asyncio.create_task( gateway.start(host="0.0.0.0", port=8765) ) - logger.info("[Pipeline] Gateway started on ws://localhost:8765") + logger.info("[Pipeline] Gateway startup initiated on 
ws://localhost:8765") - # Give Gateway a moment to start + # Wait for Gateway to fully initialize all phases await asyncio.sleep(0.5) # Define the trading cycle callback diff --git a/backend/core/scheduler.py b/backend/core/scheduler.py index 983759b..9bd2133 100644 --- a/backend/core/scheduler.py +++ b/backend/core/scheduler.py @@ -4,7 +4,7 @@ Scheduler - Market-aware trigger system for trading cycles """ import asyncio import logging -from datetime import datetime, timedelta +from datetime import datetime, time, timedelta from typing import Any, Callable, Optional from zoneinfo import ZoneInfo @@ -28,17 +28,21 @@ class Scheduler: mode: str = "daily", trigger_time: Optional[str] = None, interval_minutes: Optional[int] = None, + heartbeat_interval: Optional[int] = None, config: Optional[dict] = None, ): self.mode = mode self.trigger_time = trigger_time or "09:30" # NYSE timezone self.trigger_now = self.trigger_time == "now" self.interval_minutes = interval_minutes or 60 + self.heartbeat_interval = heartbeat_interval # e.g. 
3600 = 1 hour self.config = config or {} self.running = False self._task: Optional[asyncio.Task] = None + self._heartbeat_task: Optional[asyncio.Task] = None self._callback: Optional[Callable] = None + self._heartbeat_callback: Optional[Callable] = None def _now_nyse(self) -> datetime: """Get current time in NYSE timezone""" @@ -53,6 +57,15 @@ class Scheduler: ) return len(valid_days) > 0 + def _is_trading_hours(self, now: datetime) -> bool: + """Check if current time is within NYSE trading hours (9:30-16:00 ET).""" + market_time = now.time() + return time(9, 30) <= market_time <= time(16, 0) + + def set_heartbeat_callback(self, callback: Callable) -> None: + """Register callback for heartbeat triggers.""" + self._heartbeat_callback = callback + def _next_trading_day(self, from_date: datetime) -> datetime: """Find the next trading day from given date""" check_date = from_date @@ -72,6 +85,13 @@ class Scheduler: self._callback = callback self._schedule_task() + # Start heartbeat loop if configured + if self.heartbeat_interval and self._heartbeat_callback: + self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop()) + logger.info( + f"Heartbeat loop started: interval={self.heartbeat_interval}s", + ) + logger.info( f"Scheduler started: mode={self.mode}, timezone=America/New_York", ) @@ -132,6 +152,30 @@ class Scheduler: return changed + async def _run_heartbeat_loop(self): + """Run heartbeat checks on a separate interval during trading hours.""" + while self.running: + now = self._now_nyse() + if self._is_trading_day(now) and self._is_trading_hours(now): + if self._heartbeat_callback: + try: + current_date = now.strftime("%Y-%m-%d") + logger.debug( + f"[Heartbeat] Triggering heartbeat check for {current_date}", + ) + await self._heartbeat_callback(date=current_date) + except Exception as e: + logger.error( + f"[Heartbeat] Callback failed: {e}", + exc_info=True, + ) + else: + logger.warning( + "[Heartbeat] Callback not set, skipping heartbeat", + ) + + 
await asyncio.sleep(self.heartbeat_interval) + async def _run_daily(self, callback: Callable): """Run once per trading day at specified time (NYSE timezone)""" first_run = True @@ -206,6 +250,9 @@ class Scheduler: if self._task: self._task.cancel() self._task = None + if self._heartbeat_task: + self._heartbeat_task.cancel() + self._heartbeat_task = None logger.info("Scheduler stopped") diff --git a/backend/data/schema.py b/backend/data/schema.py index e73dade..c9e4bde 100644 --- a/backend/data/schema.py +++ b/backend/data/schema.py @@ -163,6 +163,16 @@ class AnalystSignal(BaseModel): signal: str | None = None confidence: float | None = None reasoning: dict | str | None = None + # Extended fields for richer signal information + reasons: list[str] | None = None # Core drivers/reasons for the signal + risks: list[str] | None = None # Key risk factors + invalidation: str | None = None # Conditions that would invalidate the thesis + next_action: str | None = None # Suggested next action for PM + # Valuation-related fields + intrinsic_value: float | None = None # DCF intrinsic value + fair_value_range: dict | None = None # {bear, base, bull} fair value range + value_gap_pct: float | None = None # Value gap percentage + valuation_methods: list[str] | None = None # List of valuation methods used max_position_size: float | None = None # For risk management signals diff --git a/backend/enrich/llm_enricher.py b/backend/enrich/llm_enricher.py index f84039b..2c05dbd 100644 --- a/backend/enrich/llm_enricher.py +++ b/backend/enrich/llm_enricher.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio +import logging from concurrent.futures import ThreadPoolExecutor from typing import Any @@ -12,6 +13,8 @@ from pydantic import BaseModel, Field from backend.config.env_config import canonicalize_model_provider, get_env_bool, get_env_str from backend.llm.models import create_model +logger = logging.getLogger(__name__) + class EnrichedNewsItem(BaseModel): """Structured 
output schema for one enriched article.""" @@ -156,7 +159,8 @@ def analyze_news_row_with_llm(row: dict[str, Any]) -> dict[str, Any] | None: ] try: response = _run_async(model(messages=messages, structured_model=EnrichedNewsItem)) - except Exception: + except Exception as e: + logger.warning(f"LLM enrichment failed: {e}") return None payload = _normalize_enrichment_payload(getattr(response, "metadata", None)) @@ -268,7 +272,8 @@ def analyze_range_with_llm(payload: dict[str, Any]) -> dict[str, Any] | None: response = _run_async( model(messages=messages, structured_model=RangeAnalysisPayload), ) - except Exception: + except Exception as e: + logger.warning(f"LLM enrichment failed: {e}") return None metadata = getattr(response, "metadata", None) diff --git a/backend/llm/models.py b/backend/llm/models.py index d28b428..202c130 100644 --- a/backend/llm/models.py +++ b/backend/llm/models.py @@ -3,9 +3,11 @@ AgentScope Native Model Factory Uses native AgentScope model classes for LLM calls """ -from enum import Enum -from typing import Optional, Tuple import os +import time +import logging +from enum import Enum +from typing import Any, Callable, Optional, Tuple, TypeVar, Union from agentscope.formatter import ( AnthropicChatFormatter, DashScopeChatFormatter, @@ -26,6 +28,244 @@ from backend.config.env_config import ( get_env_str, ) +logger = logging.getLogger(__name__) + +# Retry wrapper types +T = TypeVar("T") + + +class RetryChatModel: + """Wraps an AgentScope model with automatic retry for transient errors. + + Based on CoPaw's RetryChatModel design. Handles rate limits, timeouts, + and other transient failures with exponential backoff. 
+ """ + + DEFAULT_MAX_RETRIES = 3 + DEFAULT_INITIAL_DELAY = 1.0 + DEFAULT_MAX_DELAY = 60.0 + DEFAULT_BACKOFF_MULTIPLIER = 2.0 + + # Transient error codes/messages that should trigger retry + TRANSIENT_ERROR_KEYWORDS = frozenset([ + "rate_limit", + "429", + "timeout", + "503", + "502", + "504", + "connection", + "temporary", + "overloaded", + "too_many_requests", + ]) + + def __init__( + self, + model: Any, + max_retries: int = DEFAULT_MAX_RETRIES, + initial_delay: float = DEFAULT_INITIAL_DELAY, + max_delay: float = DEFAULT_MAX_DELAY, + backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER, + on_retry: Optional[Callable[[int, Exception, float], None]] = None, + ): + """Initialize retry wrapper. + + Args: + model: The underlying AgentScope model to wrap + max_retries: Maximum number of retry attempts + initial_delay: Initial delay in seconds before first retry + max_delay: Maximum delay between retries + backoff_multiplier: Multiplier for exponential backoff + on_retry: Optional callback(retry_count, exception, delay) for logging + """ + self._model = model + self._max_retries = max_retries + self._initial_delay = initial_delay + self._max_delay = max_delay + self._backoff_multiplier = backoff_multiplier + self._on_retry = on_retry + self._total_tokens_used = 0 + self._total_cost = 0.0 + + @property + def model_name(self) -> str: + return getattr(self._model, "model_name", str(self._model)) + + @property + def total_tokens_used(self) -> int: + return self._total_tokens_used + + @property + def total_cost(self) -> float: + return self._total_cost + + def _is_transient_error(self, error: Exception) -> bool: + """Check if an error is transient and should be retried. 
+ + Args: + error: The exception to check + + Returns: + True if the error is transient + """ + error_str = str(error).lower() + for keyword in self.TRANSIENT_ERROR_KEYWORDS: + if keyword in error_str: + return True + return False + + def _calculate_delay(self, retry_count: int) -> float: + """Calculate delay for given retry attempt with exponential backoff. + + Args: + retry_count: Current retry attempt number (1-based) + + Returns: + Delay in seconds + """ + delay = self._initial_delay * (self._backoff_multiplier ** (retry_count - 1)) + return min(delay, self._max_delay) + + def _call_with_retry(self, func: Callable[..., T], *args, **kwargs) -> T: + """Call a function with retry logic for transient errors. + + Args: + func: Function to call + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Result from func + + Raises: + Last exception if all retries exhausted + """ + last_error: Optional[Exception] = None + + for attempt in range(1, max(self._max_retries, 1) + 1): + try: + result = func(*args, **kwargs) + + # Track usage if available + if hasattr(result, "usage") and result.usage: + usage = result.usage + self._total_tokens_used += getattr(usage, "total_tokens", 0) + self._total_cost += getattr(usage, "cost", 0.0) + + return result + + except Exception as e: + last_error = e + + if attempt >= self._max_retries: + logger.error( + "RetryChatModel: Max retries (%d) exhausted for %s", + self._max_retries, + self.model_name, + ) + break + + if not self._is_transient_error(e): + logger.warning( + "RetryChatModel: Non-transient error, not retrying: %s", + str(e), + ) + break + + delay = self._calculate_delay(attempt) + logger.warning( + "RetryChatModel: Transient error on attempt %d/%d, " + "retrying in %.1fs: %s", + attempt, + self._max_retries, + delay, + str(e)[:200], + ) + + if self._on_retry: + self._on_retry(attempt, e, delay) + + time.sleep(delay) + + if last_error is not None: + raise last_error + raise RuntimeError("RetryChatModel: Unexpected
state, no error but no result") + + def __call__(self, *args, **kwargs) -> Any: + """Forward calls to the wrapped model with retry logic.""" + return self._call_with_retry(self._model, *args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Proxy attribute access to the wrapped model.""" + return getattr(self._model, name) + + +class TokenRecordingModelWrapper: + """Wraps a model to track token usage per provider. + + Based on CoPaw's TokenRecordingModelWrapper design. + """ + + def __init__(self, model: Any): + """Initialize token recorder. + + Args: + model: The underlying AgentScope model to wrap + """ + self._model = model + self._total_tokens = 0 + self._prompt_tokens = 0 + self._completion_tokens = 0 + self._total_cost = 0.0 + + @property + def model_name(self) -> str: + return getattr(self._model, "model_name", str(self._model)) + + @property + def total_tokens(self) -> int: + return self._total_tokens + + @property + def prompt_tokens(self) -> int: + return self._prompt_tokens + + @property + def completion_tokens(self) -> int: + return self._completion_tokens + + @property + def total_cost(self) -> float: + return self._total_cost + + def record_usage(self, usage: Any) -> None: + """Record token usage from a model response. 
+ + Args: + usage: Usage object from model response + """ + if usage is None: + return + + self._prompt_tokens += getattr(usage, "prompt_tokens", 0) + self._completion_tokens += getattr(usage, "completion_tokens", 0) + self._total_tokens += getattr(usage, "total_tokens", 0) + self._total_cost += getattr(usage, "cost", 0.0) + + def __call__(self, *args, **kwargs) -> Any: + """Forward calls and record usage.""" + result = self._model(*args, **kwargs) + + if hasattr(result, "usage") and result.usage: + self.record_usage(result.usage) + + return result + + def __getattr__(self, name: str) -> Any: + """Proxy attribute access to the wrapped model.""" + return getattr(self._model, name) + class ModelProvider(Enum): """Supported model providers""" diff --git a/backend/services/gateway.py b/backend/services/gateway.py index ceca502..2378580 100644 --- a/backend/services/gateway.py +++ b/backend/services/gateway.py @@ -37,6 +37,9 @@ from backend.services.storage import StorageService from backend.data.provider_router import get_provider_router from backend.tools.data_tools import get_prices from backend.tools.data_tools import get_company_news +from backend.tools.data_tools import get_insider_trades +from backend.tools.data_tools import prices_to_df +from backend.tools.technical_signals import StockTechnicalAnalyzer from backend.core.scheduler import Scheduler logger = logging.getLogger(__name__) @@ -99,9 +102,15 @@ class Gateway: self._provider_router = get_provider_router() self._loop: Optional[asyncio.AbstractEventLoop] = None self._project_root = Path(__file__).resolve().parents[2] + self._technical_analyzer = StockTechnicalAnalyzer() async def start(self, host: str = "0.0.0.0", port: int = 8766): - """Start gateway server""" + """Start gateway server with proper initialization order. 
+ + Phase 1: Start WebSocket server first so frontend can connect immediately + Phase 2: Start market data service (pushes data to connected clients) + Phase 3: Start scheduler last (triggers trading cycles) + """ logger.info(f"Starting gateway on {host}:{port}") self._loop = asyncio.get_running_loop() self._provider_router.add_listener(self._on_provider_usage_changed) @@ -124,7 +133,7 @@ class Gateway: self.state_sync.load_state() self.market_service.set_price_recorder(self.storage.record_price_point) - self.state_sync.update_state("status", "running") + self.state_sync.update_state("status", "initializing") self.state_sync.update_state("server_mode", self.mode) self.state_sync.update_state("is_backtest", self.is_backtest) self.state_sync.update_state( @@ -171,30 +180,72 @@ class Gateway: f"{summary.get('totalAssetValue', 0):,.2f}", ) - await self.market_service.start(broadcast_func=self.broadcast) + # ====================================================================== + # PHASE 1: Start WebSocket server first + # This allows frontend to connect immediately and receive status updates + # ====================================================================== + logger.info("[Phase 1/4] Starting WebSocket server...") + self.state_sync.update_state("status", "websocket_ready") - if self.scheduler: - await self.scheduler.start(self.on_strategy_trigger) - elif self.scheduler_callback: - await self.scheduler_callback(callback=self.on_strategy_trigger) - - # Start market status monitoring (only for live mode) - if not self.is_backtest: - self._market_status_task = asyncio.create_task( - self._market_status_monitor(), - ) - - async with websockets.serve( + # Create server but don't block yet - we'll serve inside the context manager + server = await websockets.serve( self.handle_client, host, port, ping_interval=30, ping_timeout=60, - ): - logger.info( - f"Gateway started: ws://{host}:{port}, mode={self.mode}", + ) + logger.info(f"WebSocket server ready: 
ws://{host}:{port}") + + # Give a brief moment for any existing clients to reconnect + await asyncio.sleep(0.1) + + # ====================================================================== + # PHASE 2: Start market data service + # Now frontend is connected, start pushing price updates + # ====================================================================== + logger.info("[Phase 2/4] Starting market data service...") + self.state_sync.update_state("status", "market_service_starting") + await self.market_service.start(broadcast_func=self.broadcast) + self.state_sync.update_state("status", "market_service_ready") + logger.info("Market data service ready - price updates active") + + # ====================================================================== + # PHASE 3: Start market status monitoring + # Monitors market open/close and broadcasts status + # ====================================================================== + logger.info("[Phase 3/4] Starting market status monitoring...") + if not self.is_backtest: + self._market_status_task = asyncio.create_task( + self._market_status_monitor(), ) - await asyncio.Future() + + # ====================================================================== + # PHASE 4: Start scheduler last + # Only start trading after everything else is ready + # ====================================================================== + logger.info("[Phase 4/4] Starting scheduler...") + self.state_sync.update_state("status", "scheduler_starting") + + if self.scheduler: + # Wire up heartbeat callback if heartbeat is configured + heartbeat_interval = self.config.get("heartbeat_interval", 0) + if heartbeat_interval and heartbeat_interval > 0: + self.scheduler.set_heartbeat_callback(self.on_heartbeat_trigger) + logger.info( + f"[Heartbeat] Registered heartbeat callback (interval={heartbeat_interval}s)", + ) + await self.scheduler.start(self.on_strategy_trigger) + elif self.scheduler_callback: + await 
self.scheduler_callback(callback=self.on_strategy_trigger) + + self.state_sync.update_state("status", "running") + logger.info( + f"Gateway fully operational: ws://{host}:{port}, mode={self.mode}", + ) + + # Keep server running + await asyncio.Future() def _on_provider_usage_changed(self, snapshot: Dict[str, Any]): """Handle provider routing updates from the shared router.""" @@ -275,8 +326,8 @@ class Gateway: ensure_ascii=False, ), ) - except Exception: - pass + except Exception as e: + logger.warning(f"Failed to send error response to client: {e}") async def _handle_client_messages( self, @@ -343,10 +394,14 @@ class Gateway: await self._handle_get_stock_news_categories(websocket, data) elif msg_type == "get_stock_range_explain": await self._handle_get_stock_range_explain(websocket, data) + elif msg_type == "get_stock_insider_trades": + await self._handle_get_stock_insider_trades(websocket, data) elif msg_type == "get_stock_story": await self._handle_get_stock_story(websocket, data) elif msg_type == "get_stock_similar_days": await self._handle_get_stock_similar_days(websocket, data) + elif msg_type == "get_stock_technical_indicators": + await self._handle_get_stock_technical_indicators(websocket, data) elif msg_type == "run_stock_enrich": await self._handle_run_stock_enrich(websocket, data) @@ -862,6 +917,94 @@ class Gateway: ), ) + async def _handle_get_stock_insider_trades( + self, + websocket: ServerConnection, + data: Dict[str, Any], + ): + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send( + json.dumps( + { + "type": "stock_insider_trades_loaded", + "ticker": "", + "trades": [], + "error": "invalid ticker", + }, + ensure_ascii=False, + ), + ) + return + + end_date = str( + data.get("end_date") + or self.state_sync.state.get("current_date") + or datetime.now().strftime("%Y-%m-%d") + ).strip()[:10] + start_date = str(data.get("start_date") or "").strip()[:10] + limit = int(data.get("limit", 50)) + + trades = await 
asyncio.to_thread( + get_insider_trades, + ticker=ticker, + end_date=end_date, + start_date=start_date if start_date else None, + limit=limit, + ) + + # Sort by transaction date descending + sorted_trades = sorted( + trades or [], + key=lambda t: t.transaction_date or "", + reverse=True, + ) + + # Format for frontend + formatted_trades = [ + { + "ticker": t.ticker, + "name": t.name, + "title": t.title, + "is_board_director": t.is_board_director, + "transaction_date": t.transaction_date, + "transaction_shares": t.transaction_shares, + "transaction_price_per_share": t.transaction_price_per_share, + "transaction_value": t.transaction_value, + "shares_owned_before_transaction": t.shares_owned_before_transaction, + "shares_owned_after_transaction": t.shares_owned_after_transaction, + "security_title": t.security_title, + "filing_date": t.filing_date, + # Calculated fields + "holding_change": ( + t.shares_owned_after_transaction + - t.shares_owned_before_transaction + if t.shares_owned_after_transaction is not None and t.shares_owned_before_transaction is not None + else None + ), + "is_buy": ( + (t.transaction_shares or 0) > 0 + if t.transaction_shares is not None + else None + ), + } + for t in sorted_trades + ] + + await websocket.send( + json.dumps( + { + "type": "stock_insider_trades_loaded", + "ticker": ticker, + "start_date": start_date or None, + "end_date": end_date, + "trades": formatted_trades, + }, + ensure_ascii=False, + default=str, + ), + ) + async def _handle_get_stock_story( self, websocket: ServerConnection, @@ -969,6 +1112,136 @@ class Gateway: ), ) + async def _handle_get_stock_technical_indicators( + self, + websocket: ServerConnection, + data: Dict[str, Any], + ): + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send( + json.dumps( + { + "type": "stock_technical_indicators_loaded", + "ticker": ticker, + "indicators": None, + "error": "ticker is required", + }, + ensure_ascii=False, + ), + ) + return + + try: + # Get price data
for the ticker + from datetime import datetime, timedelta + end_date = datetime.now() + start_date = end_date - timedelta(days=365) # 365 calendar days ~= 252 trading days, enough for MA200 + + prices = await asyncio.to_thread( + get_prices, ticker=ticker, + start_date=start_date.strftime("%Y-%m-%d"), + end_date=end_date.strftime("%Y-%m-%d"), + ) + + if not prices or len(prices) < 20: + await websocket.send( + json.dumps( + { + "type": "stock_technical_indicators_loaded", + "ticker": ticker, + "indicators": None, + "error": "Insufficient price data", + }, + ensure_ascii=False, + ), + ) + return + + # Analyze technical indicators + df = prices_to_df(prices) + signal = self._technical_analyzer.analyze(ticker, df) + + # Calculate additional volatility metrics + import pandas as pd + df_sorted = df.sort_values("time").reset_index(drop=True) + df_sorted["returns"] = df_sorted["close"].pct_change() + + vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None + vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None + vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None + + # Calculate MA distance from current price + ma_distance = {} + for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]: + ma_value = getattr(signal, ma_key, None) + if ma_value and ma_value > 0: + ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 + else: + ma_distance[ma_key] = None + + indicators = { + "ticker": ticker, + "current_price": signal.current_price, + "ma": { + "ma5": signal.ma5, + "ma10": signal.ma10, + "ma20": signal.ma20, + "ma50": signal.ma50, + "ma200": signal.ma200, + "distance": ma_distance, + }, + "rsi": { + "rsi14": signal.rsi14, + "status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral", + }, + "macd": { + "macd": signal.macd, + "signal": signal.macd_signal, + "histogram": signal.macd - signal.macd_signal, + }, + "bollinger": { +
"upper": signal.bollinger_upper, + "mid": signal.bollinger_mid, + "lower": signal.bollinger_lower, + }, + "volatility": { + "vol_10d": vol_10, + "vol_20d": vol_20, + "vol_60d": vol_60, + "annualized": signal.annualized_volatility_pct, + "risk_level": signal.risk_level, + }, + "trend": signal.trend, + "mean_reversion": signal.mean_reversion_signal, + } + + await websocket.send( + json.dumps( + { + "type": "stock_technical_indicators_loaded", + "ticker": ticker, + "indicators": indicators, + }, + ensure_ascii=False, + default=str, + ), + ) + + except Exception as e: + logger.exception(f"Error getting technical indicators for {ticker}") + await websocket.send( + json.dumps( + { + "type": "stock_technical_indicators_loaded", + "ticker": ticker, + "indicators": None, + "error": str(e), + }, + ensure_ascii=False, + ), + ) + async def _handle_run_stock_enrich( self, websocket: ServerConnection, @@ -2288,6 +2561,58 @@ class Gateway: else: await self._run_live_cycle(date, tickers) + async def on_heartbeat_trigger(self, date: str): + """Run lightweight heartbeat check for all analysts. + + Each analyst reads its HEARTBEAT.md and performs a self-check + without running the full trading pipeline. 
+ """ + logger.info(f"[Heartbeat] Running heartbeat check for {date}") + + tickers = self.config.get("tickers", []) + analysts = self.pipeline._all_analysts() + + for analyst in analysts: + try: + ws_id = getattr(analyst, "workspace_id", None) + if ws_id: + from backend.agents.workspace_manager import get_workspace_dir + ws_dir = get_workspace_dir(ws_id) + if ws_dir: + from pathlib import Path + hb_path = Path(ws_dir) / "HEARTBEAT.md" + if hb_path.exists(): + content = hb_path.read_text(encoding="utf-8").strip() + if content: + hb_task = ( + f"# 定期主动检查\n\n{content}\n\n" + "请执行上述检查并报告结果。" + ) + logger.info( + f"[Heartbeat] Running heartbeat for {analyst.name}", + ) + # Build a minimal user message and let the analyst reply + from agentscope.message import Msg + msg = Msg( + role="user", + content=hb_task, + name="system", + ) + result = await analyst.reply([msg]) + logger.info( + f"[Heartbeat] {analyst.name} heartbeat complete", + ) + continue + + logger.debug( + f"[Heartbeat] No HEARTBEAT.md for {analyst.name}, skipping", + ) + except Exception as e: + logger.error( + f"[Heartbeat] {analyst.name} failed: {e}", + exc_info=True, + ) + async def _run_backtest_cycle(self, date: str, tickers: List[str]): """Run backtest cycle with pre-loaded prices""" self.market_service.set_backtest_date(date) @@ -2428,7 +2753,8 @@ class Gateway: market_caps[ticker] = market_cap else: market_caps[ticker] = 1e9 - except Exception: + except Exception as e: + logger.warning(f"Failed to get market cap for {ticker}, using default 1e9: {e}") market_caps[ticker] = 1e9 return market_caps diff --git a/backend/services/runtime_db.py b/backend/services/runtime_db.py index a9d69e2..f7865ec 100644 --- a/backend/services/runtime_db.py +++ b/backend/services/runtime_db.py @@ -48,6 +48,14 @@ CREATE TABLE IF NOT EXISTS signals ( signal TEXT, confidence REAL, reasoning_json TEXT, + reasons_json TEXT, + risks_json TEXT, + invalidation TEXT, + next_action TEXT, + intrinsic_value REAL, + 
fair_value_range_json TEXT, + value_gap_pct REAL, + valuation_methods_json TEXT, real_return REAL, is_correct TEXT, trade_date TEXT, @@ -270,8 +278,10 @@ class RuntimeDb: """ INSERT OR REPLACE INTO signals (id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json, + reasons_json, risks_json, invalidation, next_action, intrinsic_value, + fair_value_range_json, value_gap_pct, valuation_methods_json, real_return, is_correct, trade_date, created_at, meta_json) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( signal_id, @@ -282,6 +292,14 @@ class RuntimeDb: payload.get("signal"), payload.get("confidence"), _json_dumps(payload.get("reasoning")), + _json_dumps(payload.get("reasons")), + _json_dumps(payload.get("risks")), + payload.get("invalidation"), + payload.get("next_action"), + payload.get("intrinsic_value"), + _json_dumps(payload.get("fair_value_range")), + payload.get("value_gap_pct"), + _json_dumps(payload.get("valuation_methods")), payload.get("real_return"), None if payload.get("is_correct") is None else str(payload.get("is_correct")), payload.get("date"), @@ -313,8 +331,10 @@ class RuntimeDb: """ INSERT INTO signals (id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json, + reasons_json, risks_json, invalidation, next_action, intrinsic_value, + fair_value_range_json, value_gap_pct, valuation_methods_json, real_return, is_correct, trade_date, created_at, meta_json) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""", ( signal_id, @@ -325,6 +345,14 @@ class RuntimeDb: payload.get("signal"), payload.get("confidence"), _json_dumps(payload.get("reasoning")), + _json_dumps(payload.get("reasons")), + _json_dumps(payload.get("risks")), + payload.get("invalidation"), + payload.get("next_action"), + payload.get("intrinsic_value"), + _json_dumps(payload.get("fair_value_range")), + payload.get("value_gap_pct"), + _json_dumps(payload.get("valuation_methods")), payload.get("real_return"), None if payload.get("is_correct") is None else str(payload.get("is_correct")), payload.get("date"), @@ -461,6 +489,18 @@ class RuntimeDb: else "该信号暂未完成后验评估" ), "tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral", + # Extended signal fields + "signal": row["signal"], + "confidence": row["confidence"], + "reasoning": json.loads(row["reasoning_json"]) if row["reasoning_json"] else None, + "reasons": json.loads(row["reasons_json"]) if row["reasons_json"] else None, + "risks": json.loads(row["risks_json"]) if row["risks_json"] else None, + "invalidation": row["invalidation"], + "next_action": row["next_action"], + "intrinsic_value": row["intrinsic_value"], + "fair_value_range": json.loads(row["fair_value_range_json"]) if row["fair_value_range_json"] else None, + "value_gap_pct": row["value_gap_pct"], + "valuation_methods": json.loads(row["valuation_methods_json"]) if row["valuation_methods_json"] else None, } for row in signal_rows ] diff --git a/backend/skills/builtin/fundamental_review/SKILL.md b/backend/skills/builtin/fundamental_review/SKILL.md index 936de49..2bc4897 100644 --- a/backend/skills/builtin/fundamental_review/SKILL.md +++ b/backend/skills/builtin/fundamental_review/SKILL.md @@ -8,15 +8,42 @@ version: 1.0.0 当用户希望从公司质量、资产负债表强度、盈利能力或长期盈利韧性出发判断标的时,使用这个技能。 -## 工作流程 +## 1) When to use -1. 在形成结论前,先检查盈利能力、成长性、财务健康度和经营效率。 -2. 区分可持续的业务质量和短期噪音。 -3. 明确指出会推翻当前判断的条件。 -4. 
最终给出清晰的信号、置信度和主要驱动因素。 +- 适用于需要判断“公司基本面质量是否支撑当前估值/交易观点”的任务。 +- 优先在中长期视角下使用(财务稳健性、盈利韧性、成长持续性)。 +- 当任务明确以短线事件驱动为主时,不应单独依赖本技能,应与情绪/技术信号联合。 -## 约束 +## 2) Required inputs -- 不要孤立依赖单一指标。 -- 缺失数据要明确指出。 -- 当财务质量优劣混杂时,优先给出保守结论。 +- 最少输入:`tickers`、关键财务指标(盈利、成长、偿债、效率)。 +- 推荐输入:行业背景、公司阶段、近期重大事件。 +- 若关键数据缺失(例如利润质量或现金流质量无法判断),必须在结论中显式标注“不足信息风险”,并降低置信度。 + +## 3) Decision procedure + +1. 先做四维诊断:盈利能力、成长质量、财务健康度、经营效率。 +2. 区分“结构性优势”与“周期性改善/短期噪音”。 +3. 识别关键风险与失效条件(invalidation),明确什么情况会推翻当前判断。 +4. 合成最终观点:`signal + confidence + drivers + risks`。 + +## 4) Tool call policy + +- 优先使用基本面与财务相关工具组获取证据,再形成结论。 +- 在数据完备且任务允许时,可补充估值相关工具进行交叉验证。 +- 若工具失败或返回异常:保留已验证证据,明确未验证部分,不允许伪造数据。 + +## 5) Output schema + +- `signal`: `bullish | bearish | neutral` +- `confidence`: `0-100` +- `reasons`: 2-4 条核心驱动 +- `risks`: 1-3 条关键风险 +- `invalidation`: 触发观点失效的条件 +- `next_action`: 对 PM 的可执行建议(如“仅小仓位试错/等待下一季报确认”) + +## 6) Failure fallback + +- 数据稀疏或矛盾时:默认 `neutral` 或低置信度方向结论。 +- 不允许因单一亮点指标给出高置信度信号。 +- 当财务质量优劣混杂时,优先保守结论并附加“需补充验证”的下一步建议。 diff --git a/backend/skills/builtin/portfolio_decisioning/SKILL.md b/backend/skills/builtin/portfolio_decisioning/SKILL.md index 68d3a71..5d3c31b 100644 --- a/backend/skills/builtin/portfolio_decisioning/SKILL.md +++ b/backend/skills/builtin/portfolio_decisioning/SKILL.md @@ -8,15 +8,43 @@ version: 1.0.0 当用户需要把团队分析转化为最终交易决策时,使用这个技能。 -## 工作流程 +## 1) When to use -1. 行动前先阅读分析师结论和风险警示。 -2. 评估当前组合、现金和保证金约束。 -3. 使用决策工具为每个 ticker 记录一个明确决策。 -4. 在全部决策记录完成后,总结组合层面的整体理由。 +- 适用于“最终下单前”的收口阶段:将多方观点转成单一可执行指令。 +- 必须在获取分析师观点与风险审查后触发,不应跳过上游输入。 +- 当任务只要求研究观点、未要求执行决策时,不强制触发。 -## 约束 +## 2) Required inputs -- 仓位大小必须遵守资金和保证金限制。 -- 当分析师信心与风险信号不一致时,优先采用更小仓位。 -- 当任务要求完整决策清单时,不要让任何 ticker 处于未决状态。 +- 最少输入:`analyst_signals`、`risk_warnings`、`portfolio_state`、`cash`、`margin_requirement`、`prices`。 +- 推荐输入:会议共识摘要、历史表现偏差、当前组合拥挤度。 +- 若缺失关键执行约束(现金/保证金/价格),应降级为“条件决策草案”,不可直接给激进仓位。 + +## 3) Decision procedure + +1. 汇总并比较 analyst 信号,识别共识与分歧。 +2. 将风险警示映射到仓位上限与禁开条件。 +3. 在资金与保证金约束下,为每个 ticker 生成候选动作与数量。 +4. 
对冲突信号执行保守仲裁:降低仓位、提高触发门槛或改为 `hold`。 +5. 逐个 ticker 记录最终决策,并给出组合级理由。 + +## 4) Tool call policy + +- 必须使用决策工具记录每个 ticker 的最终 `action/quantity`。 +- 在讨论阶段如发现当前团队能力不足,可使用团队工具动态创建或移除 analyst(再继续讨论)。 +- 若风险工具提示阻断项,优先遵循阻断,不得绕过。 +- 工具调用失败时:重试一次;仍失败则输出结构化“未完成决策清单”和人工处理建议。 + +## 5) Output schema + +- `decisions`: 每个 ticker 的 `{action: long|short|hold, quantity, confidence, reasoning}` +- `portfolio_rationale`: 组合层面的配置逻辑与取舍依据 +- `constraint_check`: 资金、保证金、集中度是否满足 +- `conflict_resolution`: 对信号冲突的处理说明 +- `pending_items`: 未决事项与补充数据需求(若有) + +## 6) Failure fallback + +- 当分析师信号与风险结论显著冲突时,默认采用更小仓位或 `hold`。 +- 当约束校验失败(现金/保证金不足)时,自动下调数量,不输出不可执行指令。 +- 当任务要求完整清单时,不允许遗漏 ticker;无法决策时必须显式标记 `hold` 并说明原因。 diff --git a/backend/skills/builtin/risk_review/SKILL.md b/backend/skills/builtin/risk_review/SKILL.md index cee6fcb..eea1099 100644 --- a/backend/skills/builtin/risk_review/SKILL.md +++ b/backend/skills/builtin/risk_review/SKILL.md @@ -8,15 +8,41 @@ version: 1.0.0 当用户需要识别集中度、波动率、杠杆和情景风险时,使用这个技能。 -## 工作流程 +## 1) When to use -1. 按 ticker 和主题检查拟议敞口。 -2. 识别集中度、波动率、流动性和杠杆方面的风险点。 -3. 按严重程度排序风险警示。 -4. 将风险结论转化为给投资经理的具体限制或注意事项。 +- 适用于下单前风险闸门、仓位复核、组合再平衡前的约束审查。 +- 当需要把“风险观点”转成“可执行限制”时必须使用本技能。 +- 若任务仅为单纯行情解读且不涉及仓位执行,可不独立触发。 -## 约束 +## 2) Required inputs -- 聚焦可执行的风险控制措施。 -- 当数据支持时尽量量化限制。 -- 明确区分致命阻断项和可管理风险。 +- 最少输入:`portfolio positions`、`cash/margin`、`proposed decisions`、`current prices`。 +- 推荐输入:波动率指标、流动性指标、相关性/主题暴露。 +- 若缺失关键风险数据,必须输出“暂定限制”并标明待补数据项。 + +## 3) Decision procedure + +1. 按 ticker、行业主题、净敞口做集中度盘点。 +2. 评估波动、流动性与杠杆压力,识别潜在连锁风险。 +3. 将风险分级:`fatal blocker / major caution / manageable`。 +4. 
将每类风险映射为明确限制(仓位上限、减仓条件、禁开仓条件)。 + +## 4) Tool call policy + +- 优先调用风险工具组量化集中度、保证金压力、波动暴露。 +- 无量化证据时,不给“无风险”结论;只能给保守警示。 +- 工具失败时应回退到规则化约束(更低仓位上限、更严格止损条件)。 + +## 5) Output schema + +- `risk_level`: `low | medium | high | critical` +- `warnings`: 按严重度排序的风险列表(含原因) +- `limits`: 可执行限制(仓位/杠杆/单票上限) +- `blockers`: 必须先解决的阻断项 +- `recommendation_to_pm`: 对 PM 的执行建议(允许/限制/禁止) + +## 6) Failure fallback + +- 关键数据缺失或工具不可用时:默认提高一级风险等级并收紧仓位限制。 +- 无法确认保证金与流动性安全时,默认禁止新增高风险敞口。 +- 明确区分“硬阻断”与“可带条件执行”的风险,避免含糊建议。 diff --git a/backend/tests/test_skills_cli.py b/backend/tests/test_skills_cli.py index 3011b3c..502aa3f 100644 --- a/backend/tests/test_skills_cli.py +++ b/backend/tests/test_skills_cli.py @@ -2,6 +2,11 @@ from backend import cli from backend.agents.skill_metadata import parse_skill_metadata from backend.agents.skills_manager import SkillsManager +from backend.agents.team_pipeline_config import ( + ensure_team_pipeline_config, + load_team_pipeline_config, + update_active_analysts, +) def test_parse_skill_metadata_extended_frontmatter(tmp_path): @@ -70,3 +75,45 @@ def test_skills_enable_disable_and_list(monkeypatch, tmp_path): assert "Enabled" in text_dump assert "Disabled" in text_dump assert any(getattr(item, "title", None) == "Skill Catalog" for item in printed) + + +def test_install_external_skill_for_agent(tmp_path): + manager = SkillsManager(project_root=tmp_path) + skill_dir = tmp_path / "downloaded" / "new_skill" + skill_dir.mkdir(parents=True, exist_ok=True) + (skill_dir / "SKILL.md").write_text( + "---\n" + "name: new_skill\n" + "description: external skill\n" + "---\n\n" + "# New Skill\n", + encoding="utf-8", + ) + + result = manager.install_external_skill_for_agent( + config_name="demo", + agent_id="risk_manager", + source=str(skill_dir), + activate=True, + ) + + assert result["skill_name"] == "new_skill" + target = manager.get_agent_local_root("demo", "risk_manager") / "new_skill" + assert target.exists() + + +def test_team_pipeline_active_analyst_updates(tmp_path): + 
project_root = tmp_path + ensure_team_pipeline_config( + project_root=project_root, + config_name="demo", + default_analysts=["fundamentals_analyst", "technical_analyst"], + ) + update_active_analysts( + project_root=project_root, + config_name="demo", + available_analysts=["fundamentals_analyst", "technical_analyst"], + remove=["technical_analyst"], + ) + config = load_team_pipeline_config(project_root, "demo") + assert config["discussion"]["active_analysts"] == ["fundamentals_analyst"] diff --git a/frontend/package.json b/frontend/package.json index a864bcf..be9e5a7 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -27,6 +27,7 @@ "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", "framer-motion": "^12.23.13", + "jszip": "^3.10.1", "lucide-react": "^0.544.0", "react": "^19.1.1", "react-dom": "^19.1.1", diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index cd1b75d..a243630 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -5,7 +5,7 @@ import { AGENTS, INITIAL_TICKERS } from './config/constants'; // Services import { ReadOnlyClient } from './services/websocket'; -import { startRuntime } from './services/runtimeApi'; +import { startRuntime, uploadAgentSkillZip } from './services/runtimeApi'; // Hooks import { useFeedProcessor } from './hooks/useFeedProcessor'; @@ -98,6 +98,8 @@ export default function LiveTradingApp() { const [ohlcHistoryByTicker, setOhlcHistoryByTicker] = useState({}); const [explainEventsByTicker, setExplainEventsByTicker] = useState({}); const [newsByTicker, setNewsByTicker] = useState({}); + const [insiderTradesByTicker, setInsiderTradesByTicker] = useState({}); + const [technicalIndicatorsByTicker, setTechnicalIndicatorsByTicker] = useState({}); const [selectedExplainSymbol, setSelectedExplainSymbol] = useState(''); const [historySourceByTicker, setHistorySourceByTicker] = useState({}); @@ -127,6 +129,11 @@ export default function LiveTradingApp() { const [initialCashDraft, setInitialCashDraft] = 
useState('100000'); const [marginRequirementDraft, setMarginRequirementDraft] = useState('0'); const [enableMemoryDraft, setEnableMemoryDraft] = useState(false); + const [modeDraft, setModeDraft] = useState('live'); + const [pollIntervalDraft, setPollIntervalDraft] = useState('10'); + const [startDateDraft, setStartDateDraft] = useState(''); + const [endDateDraft, setEndDateDraft] = useState(''); + const [enableMockDraft, setEnableMockDraft] = useState(false); const [runtimeConfigFeedback, setRuntimeConfigFeedback] = useState(null); const [isRuntimeConfigSaving, setIsRuntimeConfigSaving] = useState(false); const [selectedSkillAgentId, setSelectedSkillAgentId] = useState(AGENTS[0]?.id || 'portfolio_manager'); @@ -602,7 +609,11 @@ export default function LiveTradingApp() { initial_cash: initialCash, margin_requirement: marginRequirement, enable_memory: Boolean(enableMemoryDraft), - mode: serverMode || 'live' + mode: modeDraft || 'live', + poll_interval: Number(pollIntervalDraft) || 10, + start_date: startDateDraft || null, + end_date: endDateDraft || null, + enable_mock: Boolean(enableMockDraft) }); setIsRuntimeConfigSaving(false); @@ -630,9 +641,13 @@ export default function LiveTradingApp() { initialCashDraft, marginRequirementDraft, enableMemoryDraft, + modeDraft, + pollIntervalDraft, + startDateDraft, + endDateDraft, + enableMockDraft, watchlistDraftSymbols, watchlistInputValue, - serverMode, addSystemMessage ]); @@ -644,6 +659,11 @@ export default function LiveTradingApp() { setInitialCashDraft('100000'); setMarginRequirementDraft('0'); setEnableMemoryDraft(false); + setModeDraft('live'); + setPollIntervalDraft('10'); + setStartDateDraft(''); + setEndDateDraft(''); + setEnableMockDraft(false); setRuntimeConfigFeedback(null); }, []); @@ -862,6 +882,38 @@ export default function LiveTradingApp() { } }, [selectedSkillAgentId, selectedWorkspaceFile, workspaceDraftContent]); + const handleUploadExternalSkill = useCallback(async (file) => { + if (!(file instanceof 
File)) { + setAgentSkillsFeedback({ type: 'error', text: '请选择 zip 文件后再上传' }); + return; + } + if (!selectedSkillAgentId) { + setAgentSkillsFeedback({ type: 'error', text: '未选择目标 Agent' }); + return; + } + setAgentSkillsSavingKey(`${selectedSkillAgentId}:__upload__`); + setAgentSkillsFeedback(null); + try { + const result = await uploadAgentSkillZip({ + agentId: selectedSkillAgentId, + file, + activate: true + }); + setAgentSkillsFeedback({ + type: 'success', + text: `已上传并安装技能 ${result.skill_name || ''}`.trim() + }); + requestAgentSkills(selectedSkillAgentId); + } catch (error) { + setAgentSkillsFeedback({ + type: 'error', + text: `上传失败: ${error.message || '未知错误'}` + }); + } finally { + setAgentSkillsSavingKey(null); + } + }, [requestAgentSkills, selectedSkillAgentId]); + useEffect(() => { setWorkspaceDraftContent(selectedWorkspaceContent); }, [selectedWorkspaceContent]); @@ -967,6 +1019,31 @@ export default function LiveTradingApp() { }); }, []); + const requestStockInsiderTrades = useCallback((symbol, startDate = null, endDate = null) => { + const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; + if (!normalized || !clientRef.current) { + return false; + } + return clientRef.current.send({ + type: 'get_stock_insider_trades', + ticker: normalized, + start_date: startDate, + end_date: endDate, + limit: 50 + }); + }, []); + + const requestStockTechnicalIndicators = useCallback((symbol) => { + const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; + if (!normalized || !clientRef.current) { + return false; + } + return clientRef.current.send({ + type: 'get_stock_technical_indicators', + ticker: normalized + }); + }, []); + const requestStockRangeExplain = useCallback((symbol, startDate, endDate, articleIds = []) => { const normalized = typeof symbol === 'string' ? 
symbol.trim().toUpperCase() : ''; if (!normalized || !startDate || !endDate || !clientRef.current) { @@ -1050,13 +1127,15 @@ export default function LiveTradingApp() { }, [isLiveEnabled, chartTab]); useEffect(() => { - if (!isWatchlistPanelOpen || !isWatchlistDraftDirty) { + // Only reset when watchlist panel is closed AND runtime settings is also closed + // This prevents reset when user is editing in RuntimeSettingsPanel + if ((!isWatchlistPanelOpen && !isRuntimeSettingsOpen) || !isWatchlistDraftDirty) { setWatchlistDraftSymbols(runtimeWatchlistSymbols); - if (!isWatchlistPanelOpen) { + if (!isWatchlistPanelOpen && !isRuntimeSettingsOpen) { setWatchlistInputValue(''); } } - }, [isWatchlistDraftDirty, isWatchlistPanelOpen, runtimeWatchlistSymbols]); + }, [isWatchlistDraftDirty, isWatchlistPanelOpen, isRuntimeSettingsOpen, runtimeWatchlistSymbols]); useEffect(() => { isWatchlistSavingRef.current = isWatchlistSaving; @@ -1084,6 +1163,8 @@ export default function LiveTradingApp() { requestStockNews, requestStockNewsCategories, requestStockNewsTimeline, + requestStockInsiderTrades, + requestStockTechnicalIndicators, requestStockStory, selectedExplainSymbol ]); @@ -1682,6 +1763,32 @@ export default function LiveTradingApp() { })); }, + stock_insider_trades_loaded: (e) => { + const symbol = typeof e.ticker === 'string' ? e.ticker.trim().toUpperCase() : ''; + if (!symbol) { + return; + } + setInsiderTradesByTicker((prev) => ({ + ...prev, + [symbol]: { + trades: Array.isArray(e.trades) ? e.trades : [], + startDate: e.start_date || null, + endDate: e.end_date || null + } + })); + }, + + stock_technical_indicators_loaded: (e) => { + const symbol = typeof e.ticker === 'string' ? e.ticker.trim().toUpperCase() : ''; + if (!symbol) { + return; + } + setTechnicalIndicatorsByTicker((prev) => ({ + ...prev, + [symbol]: e.indicators || null + })); + }, + stock_range_explain_loaded: (e) => { const symbol = typeof e.ticker === 'string' ? 
e.ticker.trim().toUpperCase() : ''; if (!symbol) { @@ -2388,6 +2495,11 @@ export default function LiveTradingApp() { initialCash={initialCashDraft} marginRequirement={marginRequirementDraft} enableMemory={enableMemoryDraft} + mode={modeDraft} + pollInterval={pollIntervalDraft} + startDate={startDateDraft} + endDate={endDateDraft} + enableMock={enableMockDraft} watchlistSymbols={watchlistDraftSymbols} watchlistInputValue={watchlistInputValue} watchlistSuggestions={watchlistSuggestions} @@ -2400,6 +2512,11 @@ export default function LiveTradingApp() { onInitialCashChange={setInitialCashDraft} onMarginRequirementChange={setMarginRequirementDraft} onEnableMemoryChange={setEnableMemoryDraft} + onModeChange={setModeDraft} + onPollIntervalChange={setPollIntervalDraft} + onStartDateChange={setStartDateDraft} + onEndDateChange={setEndDateDraft} + onEnableMockChange={setEnableMockDraft} onWatchlistInputChange={handleWatchlistInputChange} onWatchlistInputKeyDown={handleWatchlistInputKeyDown} onWatchlistAdd={() => commitWatchlistInput(watchlistInputValue)} @@ -2539,6 +2656,7 @@ export default function LiveTradingApp() { onWorkspaceFileChange={handleWorkspaceFileChange} onWorkspaceDraftChange={setWorkspaceDraftContent} onWorkspaceFileSave={handleWorkspaceFileSave} + onUploadExternalSkill={handleUploadExternalSkill} /> @@ -2573,9 +2691,13 @@ export default function LiveTradingApp() { selectedHistorySource={historySourceByTicker[selectedExplainSymbol] || null} explainEventsSnapshot={explainEventsByTicker[selectedExplainSymbol] || null} newsSnapshot={newsByTicker[selectedExplainSymbol] || null} + insiderTradesSnapshot={insiderTradesByTicker[selectedExplainSymbol] || null} + technicalIndicatorsSnapshot={technicalIndicatorsByTicker[selectedExplainSymbol] || null} onRequestRangeExplain={requestStockRangeExplain} onRequestNewsForDate={requestStockNewsForDate} onRequestStory={requestStockStory} + onRequestInsiderTrades={requestStockInsiderTrades} + 
onRequestTechnicalIndicators={requestStockTechnicalIndicators} currentDate={currentDate} onRequestSimilarDays={requestStockSimilarDays} onRequestStockEnrich={requestStockEnrich} diff --git a/frontend/src/components/AgentCard.jsx b/frontend/src/components/AgentCard.jsx index 0a456a9..bdcb041 100644 --- a/frontend/src/components/AgentCard.jsx +++ b/frontend/src/components/AgentCard.jsx @@ -41,6 +41,7 @@ export default function AgentCard({ agent, onClose, isClosing }) { const rankMedal = agent.rank ? getRankMedal(agent.rank) : null; const isPortfolioManager = agent.id === 'portfolio_manager'; const isRiskManager = agent.id === 'risk_manager'; + const isValuationAnalyst = agent.id === 'valuation_analyst'; const displayName = isPortfolioManager ? '团队' : agent.name; // Get model icon configuration @@ -483,6 +484,78 @@ export default function AgentCard({ agent, onClose, isClosing }) { )} + + {/* Valuation Results Card - Only show for valuation_analyst */} + {isValuationAnalyst && agent.signals && agent.signals.length > 0 && ( +
+ {JSON.stringify(runtimeState.context.team_config, null, 2)}
+
+