feat: Add evaluation hooks, skill adaptation and team pipeline config
- Add EvaluationHook for post-execution agent evaluation - Add SkillAdaptationHook for dynamic skill adaptation - Add team/ directory with team coordination logic - Add TEAM_PIPELINE.yaml for smoke_fullstack pipeline config - Update RuntimeView, TraderView and RuntimeSettingsPanel UI - Add runtimeApi and websocket services - Add runtime_state.json to smoke_fullstack state Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,6 +13,26 @@ from .command_handler import (
|
||||
create_command_dispatcher,
|
||||
)
|
||||
|
||||
# 评估钩子 (从evaluation_hook.py导入)
|
||||
from .evaluation_hook import (
|
||||
EvaluationHook,
|
||||
EvaluationCollector,
|
||||
MetricType,
|
||||
EvaluationMetric,
|
||||
EvaluationResult,
|
||||
parse_evaluation_hooks,
|
||||
)
|
||||
|
||||
# 技能适配钩子 (从skill_adaptation_hook.py导入)
|
||||
from .skill_adaptation_hook import (
|
||||
AdaptationAction,
|
||||
AdaptationThreshold,
|
||||
AdaptationEvent,
|
||||
SkillAdaptationHook,
|
||||
AdaptationManager,
|
||||
get_adaptation_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 命令处理
|
||||
"AgentCommandDispatcher",
|
||||
@@ -20,4 +40,18 @@ __all__ = [
|
||||
"CommandHandler",
|
||||
"CommandResult",
|
||||
"create_command_dispatcher",
|
||||
# 评估钩子
|
||||
"EvaluationHook",
|
||||
"EvaluationCollector",
|
||||
"MetricType",
|
||||
"EvaluationMetric",
|
||||
"EvaluationResult",
|
||||
"parse_evaluation_hooks",
|
||||
# 技能适配钩子
|
||||
"AdaptationAction",
|
||||
"AdaptationThreshold",
|
||||
"AdaptationEvent",
|
||||
"SkillAdaptationHook",
|
||||
"AdaptationManager",
|
||||
"get_adaptation_manager",
|
||||
]
|
||||
|
||||
@@ -27,6 +27,7 @@ from .hooks import (
|
||||
HookManager,
|
||||
BootstrapHook,
|
||||
MemoryCompactionHook,
|
||||
WorkspaceWatchHook,
|
||||
HOOK_PRE_REASONING,
|
||||
)
|
||||
from ..prompts.builder import (
|
||||
@@ -36,6 +37,16 @@ from ..prompts.builder import (
|
||||
from ..agent_workspace import load_agent_workspace_config
|
||||
from ..skills_manager import SkillsManager
|
||||
|
||||
# Team infrastructure imports (graceful import - may not exist yet)
|
||||
try:
|
||||
from backend.agents.team.messenger import AgentMessenger
|
||||
from backend.agents.team.task_delegator import TaskDelegator
|
||||
TEAM_INFRA_AVAILABLE = True
|
||||
except ImportError:
|
||||
TEAM_INFRA_AVAILABLE = False
|
||||
AgentMessenger = None
|
||||
TaskDelegator = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentscope.formatter import FormatterBase
|
||||
from agentscope.model import ModelWrapperBase
|
||||
@@ -152,6 +163,12 @@ class EvoAgent(ToolGuardMixin, ReActAgent):
|
||||
memory_compact_threshold=memory_compact_threshold,
|
||||
)
|
||||
|
||||
# Initialize team infrastructure if available
|
||||
self._messenger: Optional["AgentMessenger"] = None
|
||||
self._task_delegator: Optional["TaskDelegator"] = None
|
||||
if TEAM_INFRA_AVAILABLE:
|
||||
self._init_team_infrastructure()
|
||||
|
||||
logger.info(
|
||||
"EvoAgent initialized: %s (workspace: %s)",
|
||||
agent_id,
|
||||
@@ -268,6 +285,17 @@ class EvoAgent(ToolGuardMixin, ReActAgent):
|
||||
)
|
||||
logger.debug("Registered memory compaction hook")
|
||||
|
||||
# Workspace watch hook - auto-reload markdown files on change
|
||||
workspace_watch_hook = WorkspaceWatchHook(
|
||||
workspace_dir=self.workspace_dir,
|
||||
)
|
||||
self._hook_manager.register(
|
||||
hook_type=HOOK_PRE_REASONING,
|
||||
hook_name="workspace_watch",
|
||||
hook=workspace_watch_hook,
|
||||
)
|
||||
logger.debug("Registered workspace watch hook")
|
||||
|
||||
async def _reasoning(self, **kwargs) -> Msg:
|
||||
"""Override reasoning to execute pre-reasoning hooks.
|
||||
|
||||
@@ -405,7 +433,78 @@ class EvoAgent(ToolGuardMixin, ReActAgent):
|
||||
)
|
||||
]),
|
||||
"registered_hooks": self._hook_manager.list_hooks(),
|
||||
"team_infra_available": TEAM_INFRA_AVAILABLE,
|
||||
}
|
||||
|
||||
def _init_team_infrastructure(self) -> None:
|
||||
"""Initialize team infrastructure components (messenger and task delegator).
|
||||
|
||||
This method initializes the AgentMessenger for inter-agent communication
|
||||
and the TaskDelegator for subagent delegation.
|
||||
"""
|
||||
if not TEAM_INFRA_AVAILABLE:
|
||||
return
|
||||
|
||||
try:
|
||||
self._messenger = AgentMessenger(agent_id=self.agent_id)
|
||||
self._task_delegator = TaskDelegator(agent=self)
|
||||
logger.debug(
|
||||
"Team infrastructure initialized for agent: %s",
|
||||
self.agent_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to initialize team infrastructure for %s: %s",
|
||||
self.agent_id,
|
||||
e,
|
||||
)
|
||||
self._messenger = None
|
||||
self._task_delegator = None
|
||||
|
||||
@property
|
||||
def messenger(self) -> Optional["AgentMessenger"]:
|
||||
"""Get the agent's messenger for inter-agent communication.
|
||||
|
||||
Returns:
|
||||
AgentMessenger instance if available, None otherwise
|
||||
"""
|
||||
return self._messenger
|
||||
|
||||
def delegate_task(
|
||||
self,
|
||||
task_type: str,
|
||||
task_data: Dict[str, Any],
|
||||
target_agent: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Delegate a task to a subagent using the TaskDelegator.
|
||||
|
||||
Args:
|
||||
task_type: Type of task to delegate
|
||||
task_data: Data/payload for the task
|
||||
target_agent: Optional specific agent ID to delegate to
|
||||
|
||||
Returns:
|
||||
Dict containing the delegation result
|
||||
"""
|
||||
if not TEAM_INFRA_AVAILABLE or self._task_delegator is None:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Team infrastructure not available",
|
||||
}
|
||||
|
||||
try:
|
||||
return self._task_delegator.delegate_task(
|
||||
task_type=task_type,
|
||||
task_data=task_data,
|
||||
target_agent=target_agent,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Task delegation failed for %s: %s",
|
||||
self.agent_id,
|
||||
e,
|
||||
)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
__all__ = ["EvoAgent"]
|
||||
|
||||
@@ -284,19 +284,120 @@ class BootstrapHook(Hook):
|
||||
return None
|
||||
|
||||
|
||||
class WorkspaceWatchHook(Hook):
|
||||
"""Hook for auto-reloading workspace markdown files on change.
|
||||
|
||||
Monitors SOUL.md, AGENTS.md, PROFILE.md, etc. and triggers
|
||||
a prompt rebuild when any of them change. Based on CoPaw's
|
||||
AgentConfigWatcher approach but for markdown files.
|
||||
"""
|
||||
|
||||
# Files to monitor (same as PromptBuilder.DEFAULT_FILES)
|
||||
WATCHED_FILES = frozenset([
|
||||
"SOUL.md", "AGENTS.md", "PROFILE.md", "ROLE.md",
|
||||
"POLICY.md", "MEMORY.md", "HEARTBEAT.md", "STYLE.md",
|
||||
"BOOTSTRAP.md",
|
||||
])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
poll_interval: float = 2.0,
|
||||
):
|
||||
"""Initialize workspace watch hook.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory to monitor
|
||||
poll_interval: How often to check for changes (seconds)
|
||||
"""
|
||||
self.workspace_dir = Path(workspace_dir)
|
||||
self.poll_interval = poll_interval
|
||||
self._last_mtimes: dict[str, float] = {}
|
||||
self._initialized = False
|
||||
|
||||
def _scan_mtimes(self) -> dict[str, float]:
|
||||
"""Scan watched files and return their current mtimes."""
|
||||
mtimes = {}
|
||||
for name in self.WATCHED_FILES:
|
||||
path = self.workspace_dir / name
|
||||
if path.exists():
|
||||
mtimes[name] = path.stat().st_mtime
|
||||
return mtimes
|
||||
|
||||
def _has_changes(self) -> bool:
|
||||
"""Check if any watched file has changed since last check."""
|
||||
current = self._scan_mtimes()
|
||||
|
||||
if not self._initialized:
|
||||
self._last_mtimes = current
|
||||
self._initialized = True
|
||||
return False
|
||||
|
||||
# Check for new, modified, or deleted files
|
||||
if set(current.keys()) != set(self._last_mtimes.keys()):
|
||||
self._last_mtimes = current
|
||||
return True
|
||||
|
||||
for name, mtime in current.items():
|
||||
if mtime != self._last_mtimes.get(name):
|
||||
self._last_mtimes = current
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Check for file changes and rebuild prompt if needed.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments (unused)
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
if self._has_changes():
|
||||
logger.info(
|
||||
"Workspace files changed, triggering prompt rebuild for: %s",
|
||||
getattr(agent, "agent_id", "unknown"),
|
||||
)
|
||||
if hasattr(agent, "rebuild_sys_prompt"):
|
||||
agent.rebuild_sys_prompt()
|
||||
else:
|
||||
logger.warning(
|
||||
"Agent %s has no rebuild_sys_prompt method",
|
||||
getattr(agent, "agent_id", "unknown"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Workspace watch hook failed: %s", e, exc_info=True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MemoryCompactionHook(Hook):
|
||||
"""Hook for automatic memory compaction when context is full.
|
||||
|
||||
This hook monitors the token count of messages and triggers compaction
|
||||
when it exceeds the threshold. It preserves the system prompt and recent
|
||||
messages while summarizing older conversation history.
|
||||
|
||||
Based on CoPaw's memory compaction design with additional improvements:
|
||||
- memory_compact_ratio: Ratio to compact when threshold reached
|
||||
- memory_reserve_ratio: Always keep a reserve of tokens for recent messages
|
||||
- enable_tool_result_compact: Compact tool results separately
|
||||
- tool_result_compact_keep_n: Number of tool results to keep
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory_manager: Any,
|
||||
memory_compact_threshold: Optional[int] = None,
|
||||
memory_compact_reserve: Optional[int] = None,
|
||||
memory_compact_ratio: float = 0.75,
|
||||
memory_reserve_ratio: float = 0.1,
|
||||
enable_tool_result_compact: bool = False,
|
||||
tool_result_compact_keep_n: int = 5,
|
||||
):
|
||||
@@ -305,13 +406,15 @@ class MemoryCompactionHook(Hook):
|
||||
Args:
|
||||
memory_manager: Memory manager instance for compaction
|
||||
memory_compact_threshold: Token threshold for compaction
|
||||
memory_compact_reserve: Reserve tokens for recent messages
|
||||
memory_compact_ratio: Target ratio to compact to (e.g., 0.75 = compact to 75%)
|
||||
memory_reserve_ratio: Reserve ratio to always keep free (e.g., 0.1 = 10%)
|
||||
enable_tool_result_compact: Enable tool result compaction
|
||||
tool_result_compact_keep_n: Number of tool results to keep
|
||||
"""
|
||||
self.memory_manager = memory_manager
|
||||
self.memory_compact_threshold = memory_compact_threshold
|
||||
self.memory_compact_reserve = memory_compact_reserve
|
||||
self.memory_compact_ratio = memory_compact_ratio
|
||||
self.memory_reserve_ratio = memory_reserve_ratio
|
||||
self.enable_tool_result_compact = enable_tool_result_compact
|
||||
self.tool_result_compact_keep_n = tool_result_compact_keep_n
|
||||
|
||||
@@ -382,32 +485,61 @@ class MemoryCompactionHook(Hook):
|
||||
) -> None:
|
||||
"""Compact memory by summarizing older messages.
|
||||
|
||||
Uses CoPaw-style memory management:
|
||||
- memory_compact_ratio: Target ratio to compact to (e.g., 0.75 means compact to 75%)
|
||||
- memory_reserve_ratio: Always keep this ratio free (e.g., 0.1 means keep 10% for recent)
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Current messages in memory
|
||||
"""
|
||||
if self.memory_compact_reserve is None:
|
||||
if self.memory_compact_threshold is None:
|
||||
return
|
||||
|
||||
# Keep recent messages
|
||||
keep_count = min(
|
||||
len(messages) // 4,
|
||||
10, # Max 10 recent messages
|
||||
)
|
||||
keep_count = max(keep_count, 2) # At least 2
|
||||
# Estimate total tokens
|
||||
total_tokens = self._estimate_tokens(messages)
|
||||
|
||||
messages_to_compact = messages[:-keep_count] if keep_count < len(messages) else []
|
||||
# Calculate reserve based on ratio (CoPaw-style)
|
||||
reserve_tokens = int(total_tokens * self.memory_reserve_ratio)
|
||||
|
||||
# Calculate target tokens after compaction
|
||||
target_tokens = int(total_tokens * self.memory_compact_ratio)
|
||||
target_tokens = max(target_tokens, total_tokens - reserve_tokens)
|
||||
|
||||
# Find messages to compact (older ones)
|
||||
# Keep recent messages that fit within target
|
||||
messages_to_compact = []
|
||||
kept_tokens = 0
|
||||
|
||||
# Start from oldest, stop when we've kept enough
|
||||
for msg in messages:
|
||||
msg_tokens = self._estimate_tokens([msg])
|
||||
if kept_tokens + msg_tokens > target_tokens:
|
||||
messages_to_compact.append(msg)
|
||||
else:
|
||||
kept_tokens += msg_tokens
|
||||
|
||||
if not messages_to_compact:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Compacting %d messages (%d tokens) to target %d tokens",
|
||||
len(messages_to_compact),
|
||||
self._estimate_tokens(messages_to_compact),
|
||||
target_tokens,
|
||||
)
|
||||
|
||||
# Use memory manager to compact if available
|
||||
if hasattr(self.memory_manager, "compact_memory"):
|
||||
try:
|
||||
summary = await self.memory_manager.compact_memory(
|
||||
messages=messages_to_compact,
|
||||
)
|
||||
logger.info("Memory compacted: %d messages summarized", len(messages_to_compact))
|
||||
logger.info(
|
||||
"Memory compacted: %d messages summarized, summary: %s",
|
||||
len(messages_to_compact),
|
||||
summary[:200] if summary else "N/A",
|
||||
)
|
||||
|
||||
# Mark messages as compressed if supported
|
||||
if hasattr(agent.memory, "update_messages_mark"):
|
||||
@@ -420,6 +552,142 @@ class MemoryCompactionHook(Hook):
|
||||
except Exception as e:
|
||||
logger.error("Memory manager compaction failed: %s", e)
|
||||
|
||||
# Tool result compaction (CoPaw-style)
|
||||
if self.enable_tool_result_compact:
|
||||
await self._compact_tool_results(agent, messages)
|
||||
|
||||
async def _compact_tool_results(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
messages: List[Any],
|
||||
) -> None:
|
||||
"""Compact tool results by keeping only recent ones.
|
||||
|
||||
Based on CoPaw's tool_result_compact_keep_n pattern.
|
||||
Tool results can be very verbose, so we keep only the N most recent ones.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Current messages in memory
|
||||
"""
|
||||
if not hasattr(agent.memory, "content"):
|
||||
return
|
||||
|
||||
# Find tool result messages (usually have "tool" role or tool_related content)
|
||||
tool_results = []
|
||||
for msg, _ in agent.memory.content:
|
||||
if hasattr(msg, "role") and msg.role == "tool":
|
||||
tool_results.append(msg)
|
||||
|
||||
if len(tool_results) <= self.tool_result_compact_keep_n:
|
||||
return
|
||||
|
||||
# Keep only the most recent N tool results
|
||||
excess_results = tool_results[:-self.tool_result_compact_keep_n]
|
||||
|
||||
logger.info(
|
||||
"Tool result compaction: %d tool results found, keeping %d, compacting %d",
|
||||
len(tool_results),
|
||||
self.tool_result_compact_keep_n,
|
||||
len(excess_results),
|
||||
)
|
||||
|
||||
# Mark excess tool results as compressed if supported
|
||||
if hasattr(agent.memory, "update_messages_mark"):
|
||||
from agentscope.agent._react_agent import _MemoryMark
|
||||
await agent.memory.update_messages_mark(
|
||||
new_mark=_MemoryMark.COMPRESSED,
|
||||
msg_ids=[msg.id for msg in excess_results],
|
||||
)
|
||||
|
||||
|
||||
class HeartbeatHook(Hook):
|
||||
"""Pre-reasoning hook that injects HEARTBEAT.md content.
|
||||
|
||||
Reads the agent's HEARTBEAT.md file and prepends it to the
|
||||
reasoning input, causing the agent to perform self-checks.
|
||||
|
||||
This enables "主动检查" (proactive monitoring) - periodic
|
||||
market condition and position checks during trading hours.
|
||||
"""
|
||||
|
||||
HEARTBEAT_FILE = "HEARTBEAT.md"
|
||||
|
||||
def __init__(self, workspace_dir: Path):
|
||||
"""Initialize heartbeat hook.
|
||||
|
||||
Args:
|
||||
workspace_dir: Working directory containing HEARTBEAT.md
|
||||
"""
|
||||
self.workspace_dir = Path(workspace_dir)
|
||||
self._completed_flag = self.workspace_dir / ".heartbeat_completed"
|
||||
|
||||
def _read_heartbeat_content(self) -> Optional[str]:
|
||||
"""Read HEARTBEAT.md if it exists and is non-empty.
|
||||
|
||||
Returns:
|
||||
The HEARTBEAT.md content stripped of whitespace, or None
|
||||
if the file is absent or empty.
|
||||
"""
|
||||
hb_path = self.workspace_dir / self.HEARTBEAT_FILE
|
||||
if not hb_path.exists():
|
||||
return None
|
||||
content = hb_path.read_text(encoding="utf-8").strip()
|
||||
return content if content else None
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Prepend heartbeat task to user message.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments to the _reasoning method
|
||||
|
||||
Returns:
|
||||
Modified kwargs with heartbeat content prepended, or None
|
||||
if no HEARTBEAT.md content is available.
|
||||
"""
|
||||
try:
|
||||
content = self._read_heartbeat_content()
|
||||
if not content:
|
||||
return None
|
||||
|
||||
logger.debug(
|
||||
"Heartbeat: found HEARTBEAT.md for agent %s",
|
||||
getattr(agent, "agent_id", "unknown"),
|
||||
)
|
||||
|
||||
# Build heartbeat task instruction (Chinese)
|
||||
hb_task = (
|
||||
"# 定期主动检查\n\n"
|
||||
f"{content}\n\n"
|
||||
"请执行上述检查并报告结果。"
|
||||
)
|
||||
|
||||
# Inject into the first user message in memory
|
||||
if hasattr(agent, "memory") and agent.memory.content:
|
||||
system_count = sum(
|
||||
1 for msg, _ in agent.memory.content if msg.role == "system"
|
||||
)
|
||||
for msg, _ in agent.memory.content[system_count:]:
|
||||
if msg.role == "user":
|
||||
original_content = msg.content
|
||||
msg.content = hb_task + "\n\n" + original_content
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
"Heartbeat task prepended for agent %s",
|
||||
getattr(agent, "agent_id", "unknown"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Heartbeat hook failed: %s", e, exc_info=True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Hook",
|
||||
@@ -428,5 +696,7 @@ __all__ = [
|
||||
"HOOK_PRE_REASONING",
|
||||
"HOOK_POST_ACTING",
|
||||
"BootstrapHook",
|
||||
"HeartbeatHook",
|
||||
"MemoryCompactionHook",
|
||||
"WorkspaceWatchHook",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Agent Factory - Dynamic creation and management of EvoAgents."""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -8,6 +9,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
@@ -342,9 +345,8 @@ class AgentFactory:
|
||||
"agent_type": config.get("agent_type", "unknown"),
|
||||
"config_path": str(config_path),
|
||||
})
|
||||
except Exception:
|
||||
# Skip invalid agent configs
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load agent config {config_path}: {e}")
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@ Portfolio Manager Agent - Based on AgentScope ReActAgent
|
||||
Responsible for decision-making (NOT trade execution)
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||
@@ -13,6 +14,8 @@ from agentscope.tool import Toolkit, ToolResponse
|
||||
|
||||
from ..utils.progress import progress
|
||||
from .prompt_factory import build_agent_system_prompt, clear_prompt_factory_cache
|
||||
from .team_pipeline_config import update_active_analysts
|
||||
from ..config.constants import ANALYST_TYPES
|
||||
|
||||
|
||||
class PMAgent(ReActAgent):
|
||||
@@ -61,6 +64,8 @@ class PMAgent(ReActAgent):
|
||||
"_toolkit_factory_kwargs",
|
||||
toolkit_factory_kwargs,
|
||||
)
|
||||
object.__setattr__(self, "_create_team_agent_cb", None)
|
||||
object.__setattr__(self, "_remove_team_agent_cb", None)
|
||||
|
||||
# Create toolkit after local state is ready so bound tool methods can be registered.
|
||||
if toolkit is None:
|
||||
@@ -152,6 +157,107 @@ class PMAgent(ReActAgent):
|
||||
],
|
||||
)
|
||||
|
||||
def _add_team_analyst(self, agent_id: str) -> ToolResponse:
|
||||
"""Add one analyst to active discussion team."""
|
||||
config_name = self.config.get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
active = update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=list(ANALYST_TYPES.keys()),
|
||||
add=[agent_id],
|
||||
)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=(
|
||||
f"Active analyst team updated. Added: {agent_id}. "
|
||||
f"Current active analysts: {', '.join(active)}"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _remove_team_analyst(self, agent_id: str) -> ToolResponse:
|
||||
"""Remove one analyst from active discussion team."""
|
||||
callback_msg = ""
|
||||
callback = self._remove_team_agent_cb
|
||||
if callback is not None:
|
||||
callback_msg = callback(agent_id=agent_id)
|
||||
|
||||
config_name = self.config.get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
active = update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=list(ANALYST_TYPES.keys()),
|
||||
remove=[agent_id],
|
||||
)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=(
|
||||
f"Active analyst team updated. Removed: {agent_id}. "
|
||||
f"Current active analysts: {', '.join(active)}"
|
||||
+ (f" | {callback_msg}" if callback_msg else "")
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _set_active_analysts(self, agent_ids: str) -> ToolResponse:
|
||||
"""Set active analysts from comma-separated agent ids."""
|
||||
requested = [
|
||||
item.strip() for item in str(agent_ids or "").split(",") if item.strip()
|
||||
]
|
||||
config_name = self.config.get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
active = update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=list(ANALYST_TYPES.keys()),
|
||||
set_to=requested,
|
||||
)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"Active analyst team set to: {', '.join(active)}",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _create_team_analyst(self, agent_id: str, analyst_type: str) -> ToolResponse:
|
||||
"""Create a runtime analyst instance and activate it."""
|
||||
callback = self._create_team_agent_cb
|
||||
if callback is None:
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text="Runtime agent creation is not available in current pipeline.",
|
||||
),
|
||||
],
|
||||
)
|
||||
result = callback(agent_id=agent_id, analyst_type=analyst_type)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(type="text", text=result),
|
||||
],
|
||||
)
|
||||
|
||||
def set_team_controller(
|
||||
self,
|
||||
*,
|
||||
create_agent_callback: Optional[Callable[..., str]] = None,
|
||||
remove_agent_callback: Optional[Callable[..., str]] = None,
|
||||
) -> None:
|
||||
"""Inject runtime team lifecycle callbacks from pipeline."""
|
||||
object.__setattr__(self, "_create_team_agent_cb", create_agent_callback)
|
||||
object.__setattr__(self, "_remove_team_agent_cb", remove_agent_callback)
|
||||
|
||||
async def reply(self, x: Msg = None) -> Msg:
|
||||
"""
|
||||
Make investment decisions
|
||||
|
||||
@@ -50,7 +50,13 @@ def build_agent_system_prompt(
|
||||
toolkit: Any,
|
||||
analyst_type: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Build the final system prompt for an agent."""
|
||||
"""Build the final system prompt for an agent.
|
||||
|
||||
Always reads fresh from disk — no caching.
|
||||
"""
|
||||
# Clear any cached templates before building (CoPaw-style, no caching)
|
||||
_prompt_loader.clear_cache()
|
||||
|
||||
sections: list[str] = []
|
||||
canonical_agent_id = (
|
||||
"portfolio_manager"
|
||||
|
||||
@@ -27,10 +27,6 @@ class PromptLoader:
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
|
||||
# Cache loaded prompts
|
||||
self._prompt_cache: Dict[str, str] = {}
|
||||
self._yaml_cache: Dict[str, Dict] = {}
|
||||
|
||||
def load_prompt(
|
||||
self,
|
||||
agent_type: str,
|
||||
@@ -38,37 +34,20 @@ class PromptLoader:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Load and render Prompt
|
||||
Load and render Prompt.
|
||||
|
||||
Args:
|
||||
agent_type: Agent type (analyst, portfolio_manager, risk_manager)
|
||||
prompt_name: Prompt file name (without extension)
|
||||
variables: Variable dictionary for rendering Prompt
|
||||
|
||||
Returns:
|
||||
Rendered prompt string
|
||||
|
||||
Examples:
|
||||
loader = PromptLoader()
|
||||
prompt = loader.load_prompt("analyst", "tool_selection",
|
||||
{"analyst_persona": "Technical Analyst"})
|
||||
No caching — always reads fresh from disk (CoPaw-style).
|
||||
"""
|
||||
cache_key = f"{agent_type}/{prompt_name}"
|
||||
prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md"
|
||||
|
||||
# Try to load from cache
|
||||
if cache_key not in self._prompt_cache:
|
||||
prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md"
|
||||
if not prompt_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Prompt file not found: {prompt_path}\n"
|
||||
f"Please create the prompt file or check the path.",
|
||||
)
|
||||
|
||||
if not prompt_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Prompt file not found: {prompt_path}\n"
|
||||
f"Please create the prompt file or check the path.",
|
||||
)
|
||||
|
||||
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||
self._prompt_cache[cache_key] = f.read()
|
||||
|
||||
prompt_template = self._prompt_cache[cache_key]
|
||||
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||
prompt_template = f.read()
|
||||
|
||||
# If variables provided, use simple string replacement
|
||||
if variables:
|
||||
@@ -76,8 +55,6 @@ class PromptLoader:
|
||||
else:
|
||||
rendered = prompt_template
|
||||
|
||||
# Smart escaping: escape braces in JSON code blocks
|
||||
# rendered = self._escape_json_braces(rendered)
|
||||
return rendered
|
||||
|
||||
def _render_template(
|
||||
@@ -140,45 +117,26 @@ class PromptLoader:
|
||||
config_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load YAML configuration file
|
||||
Load YAML configuration file.
|
||||
|
||||
Args:
|
||||
agent_type: Agent type
|
||||
config_name: Configuration file name (without extension)
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
|
||||
Examples:
|
||||
>>> loader = PromptLoader()
|
||||
>>> config = loader.load_yaml_config("analyst", "personas")
|
||||
No caching — always reads fresh from disk (CoPaw-style).
|
||||
"""
|
||||
cache_key = f"{agent_type}/{config_name}"
|
||||
yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml"
|
||||
|
||||
if cache_key not in self._yaml_cache:
|
||||
yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml"
|
||||
if not yaml_path.exists():
|
||||
raise FileNotFoundError(f"YAML config not found: {yaml_path}")
|
||||
|
||||
if not yaml_path.exists():
|
||||
raise FileNotFoundError(f"YAML config not found: {yaml_path}")
|
||||
|
||||
with open(yaml_path, "r", encoding="utf-8") as f:
|
||||
self._yaml_cache[cache_key] = yaml.safe_load(f)
|
||||
|
||||
return self._yaml_cache[cache_key]
|
||||
with open(yaml_path, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear cache (for hot reload)"""
|
||||
self._prompt_cache.clear()
|
||||
self._yaml_cache.clear()
|
||||
"""No-op — caching removed (CoPaw-style, always fresh reads)."""
|
||||
pass
|
||||
|
||||
def reload_prompt(self, agent_type: str, prompt_name: str):
|
||||
"""Reload specified prompt (force cache refresh)"""
|
||||
cache_key = f"{agent_type}/{prompt_name}"
|
||||
if cache_key in self._prompt_cache:
|
||||
del self._prompt_cache[cache_key]
|
||||
"""No-op — caching removed."""
|
||||
pass
|
||||
|
||||
def reload_config(self, agent_type: str, config_name: str):
|
||||
"""Reload specified configuration (force cache refresh)"""
|
||||
cache_key = f"{agent_type}/{config_name}"
|
||||
if cache_key in self._yaml_cache:
|
||||
del self._yaml_cache[cache_key]
|
||||
"""No-op — caching removed."""
|
||||
pass
|
||||
|
||||
@@ -19,6 +19,8 @@ class SkillMetadata:
|
||||
description: str
|
||||
version: str = ""
|
||||
tools: List[str] = field(default_factory=list)
|
||||
allowed_tools: List[str] = field(default_factory=list)
|
||||
denied_tools: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def parse_skill_metadata(skill_dir: Path, source: str) -> SkillMetadata:
|
||||
@@ -60,6 +62,8 @@ def parse_skill_metadata(skill_dir: Path, source: str) -> SkillMetadata:
|
||||
description=description,
|
||||
version=str(frontmatter.get("version") or "").strip(),
|
||||
tools=_string_list(frontmatter.get("tools")),
|
||||
allowed_tools=_string_list(frontmatter.get("allowed_tools")),
|
||||
denied_tools=_string_list(frontmatter.get("denied_tools")),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,14 +3,29 @@
|
||||
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Dict, Iterable, List
|
||||
import tempfile
|
||||
import zipfile
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import yaml
|
||||
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.skill_metadata import SkillMetadata, parse_skill_metadata
|
||||
from backend.agents.skill_loader import validate_skill
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
|
||||
try:
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler, FileSystemEvent
|
||||
WATCHDOG_AVAILABLE = True
|
||||
except ImportError:
|
||||
WATCHDOG_AVAILABLE = False
|
||||
Observer = None
|
||||
FileSystemEventHandler = object
|
||||
FileSystemEvent = object # type: ignore[misc,assignment]
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
"""Sync named skills into a run-scoped active skills workspace."""
|
||||
@@ -178,6 +193,57 @@ class SkillsManager:
|
||||
)
|
||||
return skill_dir
|
||||
|
||||
def install_external_skill_for_agent(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
source: str,
|
||||
*,
|
||||
skill_name: str | None = None,
|
||||
activate: bool = True,
|
||||
) -> Dict[str, object]:
|
||||
"""
|
||||
Install an external skill into one agent's local skill space.
|
||||
|
||||
Supports:
|
||||
- local skill directory containing SKILL.md
|
||||
- local zip archive containing one skill directory
|
||||
- http(s) URL to zip archive
|
||||
"""
|
||||
source_path = self._resolve_external_source_path(source)
|
||||
skill_dir = self._resolve_external_skill_dir(source_path)
|
||||
metadata = parse_skill_metadata(skill_dir, source="external")
|
||||
final_name = _normalize_skill_name(skill_name or metadata.skill_name or skill_dir.name)
|
||||
if not final_name:
|
||||
raise ValueError("Could not determine skill name from external source.")
|
||||
|
||||
target_dir = self.get_agent_local_root(config_name, agent_id) / final_name
|
||||
target_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
if target_dir.exists():
|
||||
shutil.rmtree(target_dir)
|
||||
shutil.copytree(skill_dir, target_dir)
|
||||
|
||||
validation = validate_skill(target_dir)
|
||||
if not validation.get("valid", False):
|
||||
shutil.rmtree(target_dir, ignore_errors=True)
|
||||
raise ValueError(
|
||||
"Installed skill is invalid: "
|
||||
+ "; ".join(validation.get("errors", []))
|
||||
)
|
||||
|
||||
if activate:
|
||||
self.update_agent_skill_overrides(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
enable=[final_name],
|
||||
)
|
||||
return {
|
||||
"skill_name": final_name,
|
||||
"target_dir": str(target_dir),
|
||||
"activated": activate,
|
||||
"warnings": validation.get("warnings", []),
|
||||
}
|
||||
|
||||
def update_agent_local_skill(
|
||||
self,
|
||||
config_name: str,
|
||||
@@ -239,6 +305,58 @@ class SkillsManager:
|
||||
"content": body,
|
||||
}
|
||||
|
||||
def _resolve_external_source_path(self, source: str) -> Path:
|
||||
"""Resolve source into a local path; download URL when needed."""
|
||||
parsed = urlparse(source)
|
||||
if parsed.scheme in {"http", "https"}:
|
||||
suffix = Path(parsed.path).suffix or ".zip"
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||
temp_path = Path(tmp.name)
|
||||
urlretrieve(source, temp_path)
|
||||
return temp_path
|
||||
return Path(source).expanduser().resolve()
|
||||
|
||||
def _resolve_external_skill_dir(self, source_path: Path) -> Path:
|
||||
"""Resolve external source path to a skill directory containing SKILL.md."""
|
||||
if not source_path.exists():
|
||||
raise FileNotFoundError(f"Source does not exist: {source_path}")
|
||||
|
||||
if source_path.is_dir():
|
||||
if (source_path / "SKILL.md").exists():
|
||||
return source_path
|
||||
children = [
|
||||
item for item in source_path.iterdir()
|
||||
if item.is_dir() and (item / "SKILL.md").exists()
|
||||
]
|
||||
if len(children) == 1:
|
||||
return children[0]
|
||||
raise ValueError(
|
||||
"Source directory must contain SKILL.md "
|
||||
"or exactly one child directory containing SKILL.md."
|
||||
)
|
||||
|
||||
if source_path.suffix.lower() != ".zip":
|
||||
raise ValueError("External source file must be a .zip archive.")
|
||||
|
||||
temp_root = Path(tempfile.mkdtemp(prefix="external_skill_"))
|
||||
with zipfile.ZipFile(source_path, "r") as archive:
|
||||
archive.extractall(temp_root)
|
||||
|
||||
candidates = [
|
||||
item.parent
|
||||
for item in temp_root.rglob("SKILL.md")
|
||||
if item.is_file()
|
||||
]
|
||||
unique = []
|
||||
for item in candidates:
|
||||
if item not in unique:
|
||||
unique.append(item)
|
||||
if len(unique) != 1:
|
||||
raise ValueError(
|
||||
"Zip archive must contain exactly one skill directory with SKILL.md."
|
||||
)
|
||||
return unique[0]
|
||||
|
||||
def update_agent_skill_overrides(
|
||||
self,
|
||||
config_name: str,
|
||||
@@ -500,6 +618,7 @@ class SkillsManager:
|
||||
self,
|
||||
config_name: str,
|
||||
agent_defaults: Dict[str, Iterable[str]],
|
||||
auto_reload: bool = False,
|
||||
) -> Dict[str, List[Path]]:
|
||||
"""Resolve all agent skills into per-agent installed/active workspaces."""
|
||||
resolved: Dict[str, List[str]] = {}
|
||||
@@ -574,6 +693,9 @@ class SkillsManager:
|
||||
skill_sources=disabled_sources,
|
||||
)
|
||||
|
||||
if auto_reload:
|
||||
self.watch_active_skills(config_name, agent_defaults)
|
||||
|
||||
return active_map
|
||||
|
||||
def _is_shared_skill(self, skill_name: str) -> bool:
|
||||
@@ -583,6 +705,72 @@ class SkillsManager:
|
||||
return False
|
||||
return True
|
||||
|
||||
def watch_active_skills(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_defaults: Dict[str, Iterable[str]],
|
||||
callback: Optional[Any] = None,
|
||||
) -> "_SkillsWatcher":
|
||||
"""Start file system monitoring on active skill directories.
|
||||
|
||||
Args:
|
||||
config_name: Run configuration name.
|
||||
agent_defaults: Map of agent_id -> default skill names.
|
||||
callback: Optional callable invoked on file changes with
|
||||
(changed_paths: List[Path]).
|
||||
|
||||
Returns:
|
||||
A _SkillsWatcher instance. Call .stop() to halt monitoring.
|
||||
"""
|
||||
if not WATCHDOG_AVAILABLE:
|
||||
raise ImportError(
|
||||
"watchdog is required for watch_active_skills. "
|
||||
"Install it with: pip install watchdog"
|
||||
)
|
||||
|
||||
watched_paths: List[Path] = []
|
||||
for agent_id in agent_defaults:
|
||||
active_root = self.get_agent_active_root(config_name, agent_id)
|
||||
if active_root.exists():
|
||||
watched_paths.append(active_root)
|
||||
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||
if local_root.exists():
|
||||
watched_paths.append(local_root)
|
||||
|
||||
handler = _SkillsChangeHandler(watched_paths, callback)
|
||||
observer = Observer()
|
||||
for path in watched_paths:
|
||||
observer.schedule(handler, str(path), recursive=True)
|
||||
observer.start()
|
||||
return _SkillsWatcher(observer, handler)
|
||||
|
||||
def reload_skills_if_changed(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_defaults: Dict[str, Iterable[str]],
|
||||
) -> Dict[str, List[Path]]:
|
||||
"""Check for file changes and reload active skills if needed.
|
||||
|
||||
Args:
|
||||
config_name: Run configuration name.
|
||||
agent_defaults: Map of agent_id -> default skill names.
|
||||
|
||||
Returns:
|
||||
Map of agent_id -> list of reloaded skill paths, or empty dict
|
||||
if no changes were detected.
|
||||
"""
|
||||
changed = self._pending_skill_changes.get(config_name)
|
||||
if not changed:
|
||||
return {}
|
||||
|
||||
self._pending_skill_changes[config_name] = set()
|
||||
return self.prepare_active_skills(config_name, agent_defaults)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal change-tracking state (populated by _SkillsChangeHandler)
|
||||
# -------------------------------------------------------------------------
|
||||
_pending_skill_changes: Dict[str, Set[Path]] = {}
|
||||
|
||||
def _resolve_disabled_skill_names(
|
||||
self,
|
||||
config_name: str,
|
||||
@@ -613,6 +801,53 @@ class SkillsManager:
|
||||
]
|
||||
|
||||
|
||||
class _SkillsWatcher:
|
||||
"""Handle returned by watch_active_skills; call .stop() to halt monitoring."""
|
||||
|
||||
def __init__(self, observer: Observer, handler: "_SkillsChangeHandler") -> None:
|
||||
self._observer = observer
|
||||
self._handler = handler
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the file system observer."""
|
||||
self._observer.stop()
|
||||
self._observer.join()
|
||||
|
||||
|
||||
class _SkillsChangeHandler(FileSystemEventHandler):
|
||||
"""Collects file-change events on skill directories."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
watched_paths: List[Path],
|
||||
callback: Optional[Any] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._watched_paths = watched_paths
|
||||
self._callback = callback
|
||||
|
||||
def on_any_event(self, event: FileSystemEvent) -> None:
|
||||
if event.is_directory:
|
||||
return
|
||||
src_path = Path(event.src_path)
|
||||
for watched in self._watched_paths:
|
||||
if src_path.is_relative_to(watched):
|
||||
SkillsManager._pending_skill_changes.setdefault(
|
||||
self._run_id_from_path(src_path), set()
|
||||
).add(src_path)
|
||||
if self._callback:
|
||||
self._callback([src_path])
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _run_id_from_path(path: Path) -> str:
|
||||
"""Infer config_name from a path like runs/{config_name}/skills/active/..."""
|
||||
parts = path.parts
|
||||
for i, part in enumerate(parts):
|
||||
if part == "runs" and i + 1 < len(parts):
|
||||
return parts[i + 1]
|
||||
return "default"
|
||||
|
||||
def _dedupe_preserve_order(items: Iterable[str]) -> List[str]:
|
||||
result: List[str] = []
|
||||
for item in items:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
以及合并Agent特定工具。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
@@ -13,6 +13,7 @@ import yaml
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.skill_loader import load_skill_from_dir, get_skill_tools
|
||||
from backend.agents.skill_metadata import parse_skill_metadata
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
|
||||
|
||||
@@ -117,6 +118,26 @@ def _register_portfolio_tool_groups(toolkit: Any, pm_agent: Any) -> None:
|
||||
pm_agent._make_decision,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
if hasattr(pm_agent, "_add_team_analyst"):
|
||||
toolkit.register_tool_function(
|
||||
pm_agent._add_team_analyst,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
if hasattr(pm_agent, "_remove_team_analyst"):
|
||||
toolkit.register_tool_function(
|
||||
pm_agent._remove_team_analyst,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
if hasattr(pm_agent, "_set_active_analysts"):
|
||||
toolkit.register_tool_function(
|
||||
pm_agent._set_active_analysts,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
if hasattr(pm_agent, "_create_team_analyst"):
|
||||
toolkit.register_tool_function(
|
||||
pm_agent._create_team_analyst,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
|
||||
|
||||
def _register_risk_tool_groups(toolkit: Any) -> None:
|
||||
@@ -223,6 +244,8 @@ def create_agent_toolkit(
|
||||
for skill_dir in active_skill_dirs:
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
|
||||
apply_skill_tool_restrictions(toolkit, active_skill_dirs)
|
||||
|
||||
if active_groups:
|
||||
toolkit.update_tool_groups(group_names=active_groups, active=True)
|
||||
|
||||
@@ -309,6 +332,8 @@ def create_toolkit_from_workspace(
|
||||
for skill_dir in skill_dirs:
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
|
||||
apply_skill_tool_restrictions(toolkit, skill_dirs)
|
||||
|
||||
# 激活指定的工具组
|
||||
if active_groups is None:
|
||||
# 从配置中读取
|
||||
@@ -397,3 +422,96 @@ def refresh_toolkit_skills(
|
||||
for skill_dir in sorted(local_root.iterdir()):
|
||||
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
|
||||
|
||||
def apply_skill_tool_restrictions(toolkit: Any, skill_dirs: List[Path]) -> None:
|
||||
"""Apply per-skill allowed_tools / denied_tools restrictions to a toolkit.
|
||||
|
||||
If a skill specifies allowed_tools, only those tools are accessible when
|
||||
that skill is active. If a skill specifies denied_tools, those tools are
|
||||
removed regardless of allowed_tools. Denied tools take precedence.
|
||||
|
||||
This function annotates the toolkit with a _skill_tool_restrictions map
|
||||
that downstream code can consult when resolving available tools.
|
||||
|
||||
Args:
|
||||
toolkit: The agentscope Toolkit instance.
|
||||
skill_dirs: List of skill directory paths to inspect.
|
||||
"""
|
||||
restrictions: Dict[str, Dict[str, Set[str]]] = {}
|
||||
for skill_dir in skill_dirs:
|
||||
metadata = parse_skill_metadata(skill_dir, source="active")
|
||||
if not metadata.allowed_tools and not metadata.denied_tools:
|
||||
continue
|
||||
restrictions[skill_dir.name] = {
|
||||
"allowed": set(metadata.allowed_tools),
|
||||
"denied": set(metadata.denied_tools),
|
||||
}
|
||||
if hasattr(toolkit, "agent_skills"):
|
||||
for skill in toolkit.agent_skills:
|
||||
skill_name = getattr(skill, "name", "") or ""
|
||||
if skill_name in restrictions:
|
||||
setattr(
|
||||
skill,
|
||||
"_tool_allowed",
|
||||
restrictions[skill_name]["allowed"],
|
||||
)
|
||||
setattr(
|
||||
skill,
|
||||
"_tool_denied",
|
||||
restrictions[skill_name]["denied"],
|
||||
)
|
||||
|
||||
|
||||
def get_skill_effective_tools(skill: Any) -> Optional[Set[str]]:
|
||||
"""Return the effective tool set for a skill after applying restrictions.
|
||||
|
||||
If the skill has no restrictions (no allowed_tools / denied_tools),
|
||||
returns None to indicate "all tools allowed".
|
||||
|
||||
If allowed_tools is set, returns only those tools minus denied_tools.
|
||||
If only denied_tools is set, returns all tools minus denied_tools.
|
||||
|
||||
Args:
|
||||
skill: A skill object previously registered via register_agent_skill.
|
||||
|
||||
Returns:
|
||||
A set of allowed tool names, or None if unrestricted.
|
||||
"""
|
||||
allowed = getattr(skill, "_tool_allowed", None)
|
||||
denied = getattr(skill, "_tool_denied", set())
|
||||
|
||||
if allowed is None:
|
||||
return None
|
||||
|
||||
effective = allowed - denied
|
||||
return effective
|
||||
|
||||
|
||||
def filter_toolkit_by_skill(
|
||||
toolkit: Any,
|
||||
skill_name: str,
|
||||
) -> Set[str]:
|
||||
"""Return the set of tool names that are accessible for a given skill.
|
||||
|
||||
Args:
|
||||
toolkit: The agentscope Toolkit instance.
|
||||
skill_name: Name of the skill to query.
|
||||
|
||||
Returns:
|
||||
Set of allowed tool names, or all registered tool names if unrestricted.
|
||||
"""
|
||||
if not hasattr(toolkit, "agent_skills"):
|
||||
return set()
|
||||
|
||||
for skill in toolkit.agent_skills:
|
||||
name = getattr(skill, "name", "") or ""
|
||||
if name != skill_name:
|
||||
continue
|
||||
effective = get_skill_effective_tools(skill)
|
||||
if effective is None:
|
||||
return set()
|
||||
return effective
|
||||
|
||||
return set()
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Workspace Manager - Create and manage agent workspaces."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceConfig:
|
||||
@@ -123,9 +126,8 @@ class WorkspaceRegistry:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
workspaces.append(WorkspaceConfig.from_dict(data))
|
||||
except Exception:
|
||||
# Skip invalid workspace configs
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load workspace config {config_path}: {e}")
|
||||
|
||||
return workspaces
|
||||
|
||||
@@ -167,9 +169,8 @@ class WorkspaceRegistry:
|
||||
"agent_type": config.get("agent_type", "unknown"),
|
||||
"config_path": str(config_path),
|
||||
})
|
||||
except Exception:
|
||||
# Skip invalid agent configs
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load agent config {config_path}: {e}")
|
||||
|
||||
return agents
|
||||
|
||||
@@ -294,8 +295,8 @@ class WorkspaceRegistry:
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
current_config = yaml.safe_load(f) or {}
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing config {config_path}: {e}")
|
||||
|
||||
# Update fields
|
||||
if name is not None:
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Dict, Iterable, Optional
|
||||
import yaml
|
||||
|
||||
from .skills_manager import SkillsManager
|
||||
from .team_pipeline_config import ensure_team_pipeline_config
|
||||
|
||||
|
||||
class RunWorkspaceManager:
|
||||
@@ -23,6 +24,16 @@ class RunWorkspaceManager:
|
||||
run_dir = self.get_run_dir(config_name)
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.skills_manager.ensure_activation_manifest(config_name)
|
||||
ensure_team_pipeline_config(
|
||||
project_root=self.project_root,
|
||||
config_name=config_name,
|
||||
default_analysts=[
|
||||
"fundamentals_analyst",
|
||||
"technical_analyst",
|
||||
"sentiment_analyst",
|
||||
"valuation_analyst",
|
||||
],
|
||||
)
|
||||
bootstrap_path = run_dir / "BOOTSTRAP.md"
|
||||
if not bootstrap_path.exists():
|
||||
bootstrap_path.write_text(
|
||||
|
||||
@@ -4,15 +4,20 @@ Agent API Routes
|
||||
|
||||
Provides REST API endpoints for agent management within workspaces.
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/workspaces/{workspace_id}/agents", tags=["agents"])
|
||||
|
||||
|
||||
@@ -35,6 +40,13 @@ class UpdateAgentRequest(BaseModel):
|
||||
disabled_skills: Optional[List[str]] = None
|
||||
|
||||
|
||||
class InstallExternalSkillRequest(BaseModel):
|
||||
"""Request to install an external skill for one agent."""
|
||||
source: str = Field(..., description="Directory path, zip path, or http(s) zip URL")
|
||||
name: Optional[str] = Field(None, description="Optional override skill name")
|
||||
activate: bool = Field(True, description="Whether to enable skill immediately")
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Agent information response."""
|
||||
agent_id: str
|
||||
@@ -344,6 +356,86 @@ async def disable_skill(
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/install")
|
||||
async def install_external_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
request: InstallExternalSkillRequest,
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
"""Install an external skill into one agent's local skills."""
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
try:
|
||||
result = skills_manager.install_external_skill_for_agent(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
source=request.source,
|
||||
skill_name=request.name,
|
||||
activate=request.activate,
|
||||
)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
return {
|
||||
"message": f"Installed external skill '{result['skill_name']}' for '{agent_id}'",
|
||||
**result,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/upload")
|
||||
async def upload_external_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
file: UploadFile = File(...),
|
||||
name: Optional[str] = Form(None),
|
||||
activate: bool = Form(True),
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
"""Upload a zip skill package from frontend and install for one agent."""
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
original_name = (file.filename or "").strip()
|
||||
if not original_name.lower().endswith(".zip"):
|
||||
raise HTTPException(status_code=400, detail="Uploaded file must be a .zip archive")
|
||||
|
||||
suffix = Path(original_name).suffix or ".zip"
|
||||
temp_path: Optional[str] = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
temp_path = tmp.name
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
result = skills_manager.install_external_skill_for_agent(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
source=temp_path,
|
||||
skill_name=name,
|
||||
activate=activate,
|
||||
)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
finally:
|
||||
try:
|
||||
await file.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to close uploaded file: {e}")
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
return {
|
||||
"message": f"Uploaded and installed external skill '{result['skill_name']}' for '{agent_id}'",
|
||||
**result,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{agent_id}/files/{filename}", response_model=AgentFileResponse)
|
||||
async def get_agent_file(
|
||||
workspace_id: str,
|
||||
|
||||
@@ -1,19 +1,25 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Runtime API routes exposing the latest trading run state."""
|
||||
"""Runtime API routes - Control Plane for managing Gateway processes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.runtime.agent_runtime import AgentRuntimeState
|
||||
from backend.runtime.context import TradingRunContext
|
||||
from backend.runtime.manager import TradingRuntimeManager, get_global_runtime_manager
|
||||
|
||||
router = APIRouter(prefix="/api/runtime", tags=["runtime"])
|
||||
@@ -21,9 +27,9 @@ router = APIRouter(prefix="/api/runtime", tags=["runtime"])
|
||||
runtime_manager: Optional[TradingRuntimeManager] = None
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
# Global task reference for running pipeline
|
||||
_running_task: Optional[asyncio.Task] = None
|
||||
_stop_event: Optional[asyncio.Event] = None
|
||||
# Gateway process management
|
||||
_gateway_process: Optional[subprocess.Popen] = None
|
||||
_gateway_port: int = 8765
|
||||
|
||||
|
||||
class RunContextResponse(BaseModel):
|
||||
@@ -67,12 +73,15 @@ class LaunchConfig(BaseModel):
|
||||
mode: str = Field(default="live", description="运行模式: live, backtest")
|
||||
start_date: Optional[str] = Field(default=None, description="回测开始日期 YYYY-MM-DD")
|
||||
end_date: Optional[str] = Field(default=None, description="回测结束日期 YYYY-MM-DD")
|
||||
poll_interval: int = Field(default=10, ge=1, le=300, description="市场数据轮询间隔(秒)")
|
||||
enable_mock: bool = Field(default=False, description="是否启用模拟模式(使用模拟价格数据)")
|
||||
|
||||
|
||||
class LaunchResponse(BaseModel):
|
||||
run_id: str
|
||||
status: str
|
||||
run_dir: str
|
||||
gateway_port: int
|
||||
message: str
|
||||
|
||||
|
||||
@@ -81,10 +90,10 @@ class StopResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
run_id: str
|
||||
status: str
|
||||
message: str
|
||||
class GatewayStatusResponse(BaseModel):
|
||||
is_running: bool
|
||||
port: int
|
||||
run_id: Optional[str] = None
|
||||
|
||||
|
||||
def _generate_run_id() -> str:
|
||||
@@ -97,44 +106,92 @@ def _get_run_dir(run_id: str) -> Path:
|
||||
return PROJECT_ROOT / "runs" / run_id
|
||||
|
||||
|
||||
def _latest_snapshot_path() -> Optional[Path]:
|
||||
candidates = sorted(
|
||||
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
def _find_available_port(start_port: int = 8765, max_port: int = 9000) -> int:
|
||||
"""Find an available port for Gateway."""
|
||||
import socket
|
||||
for port in range(start_port, max_port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
if s.connect_ex(('localhost', port)) != 0:
|
||||
return port
|
||||
raise RuntimeError("No available port found")
|
||||
|
||||
|
||||
def _is_gateway_running() -> bool:
|
||||
"""Check if Gateway process is running."""
|
||||
global _gateway_process
|
||||
if _gateway_process is None:
|
||||
return False
|
||||
return _gateway_process.poll() is None
|
||||
|
||||
|
||||
def _stop_gateway() -> bool:
|
||||
"""Stop the Gateway process."""
|
||||
global _gateway_process
|
||||
if _gateway_process is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Try graceful shutdown first
|
||||
_gateway_process.terminate()
|
||||
try:
|
||||
_gateway_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
# Force kill if graceful shutdown fails
|
||||
_gateway_process.kill()
|
||||
_gateway_process.wait()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during gateway shutdown: {e}")
|
||||
finally:
|
||||
_gateway_process = None
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _start_gateway_process(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
bootstrap: Dict[str, Any],
|
||||
port: int
|
||||
) -> subprocess.Popen:
|
||||
"""Start Gateway as a separate process."""
|
||||
# Prepare environment
|
||||
env = os.environ.copy()
|
||||
|
||||
# Create command arguments
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m", "backend.gateway_server",
|
||||
"--run-id", run_id,
|
||||
"--run-dir", str(run_dir),
|
||||
"--port", str(port),
|
||||
"--bootstrap", json.dumps(bootstrap)
|
||||
]
|
||||
|
||||
# Start process
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=PROJECT_ROOT
|
||||
)
|
||||
return candidates[0] if candidates else None
|
||||
|
||||
|
||||
def _load_snapshot() -> Dict[str, Any]:
|
||||
snapshot_path = _latest_snapshot_path()
|
||||
if snapshot_path is None or not snapshot_path.exists():
|
||||
raise HTTPException(status_code=503, detail="runtime manager is not initialized")
|
||||
return json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _get_runtime_payload() -> Dict[str, Any]:
|
||||
if runtime_manager is not None:
|
||||
return runtime_manager.build_snapshot()
|
||||
return _load_snapshot()
|
||||
|
||||
|
||||
def _to_state_response(state: AgentRuntimeState) -> RuntimeAgentState:
|
||||
return RuntimeAgentState(
|
||||
agent_id=state.agent_id,
|
||||
status=state.status,
|
||||
last_session=state.last_session,
|
||||
last_updated=state.last_updated.isoformat(),
|
||||
)
|
||||
return process
|
||||
|
||||
|
||||
@router.get("/context", response_model=RunContextResponse)
|
||||
async def get_run_context() -> RunContextResponse:
|
||||
"""Return the most recent run context."""
|
||||
payload = _get_runtime_payload()
|
||||
context = payload.get("context")
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No run context available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context")
|
||||
if context is None:
|
||||
raise HTTPException(status_code=404, detail="run context is not ready")
|
||||
raise HTTPException(status_code=404, detail="Run context is not ready")
|
||||
|
||||
return RunContextResponse(
|
||||
config_name=context["config_name"],
|
||||
@@ -144,88 +201,74 @@ async def get_run_context() -> RunContextResponse:
|
||||
|
||||
|
||||
@router.get("/agents", response_model=RuntimeAgentsResponse)
|
||||
async def list_agent_states() -> RuntimeAgentsResponse:
|
||||
"""List the current runtime state of every registered agent."""
|
||||
payload = _get_runtime_payload()
|
||||
agents = [RuntimeAgentState(**agent) for agent in payload.get("agents", [])]
|
||||
return RuntimeAgentsResponse(agents=agents)
|
||||
async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||
"""Return agent states from the most recent run."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
agents = latest.get("agents", [])
|
||||
|
||||
return RuntimeAgentsResponse(
|
||||
agents=[RuntimeAgentState(**a) for a in agents]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/events", response_model=RuntimeEventsResponse)
|
||||
async def list_runtime_events() -> RuntimeEventsResponse:
|
||||
"""Return the recent runtime events that TradingRuntimeManager emitted."""
|
||||
payload = _get_runtime_payload()
|
||||
events = [RuntimeEvent(**event) for event in payload.get("events", [])]
|
||||
return RuntimeEventsResponse(events=events)
|
||||
async def get_runtime_events() -> RuntimeEventsResponse:
|
||||
"""Return events from the most recent run."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||
|
||||
@router.get("/agents/{agent_id}", response_model=RuntimeAgentState)
|
||||
async def get_agent_state(agent_id: str) -> RuntimeAgentState:
|
||||
"""Return the current runtime state for a single agent."""
|
||||
payload = _get_runtime_payload()
|
||||
state = next(
|
||||
(agent for agent in payload.get("agents", []) if agent["agent_id"] == agent_id),
|
||||
None,
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
events = latest.get("events", [])
|
||||
|
||||
return RuntimeEventsResponse(
|
||||
events=[RuntimeEvent(**e) for e in events]
|
||||
)
|
||||
if state is None:
|
||||
raise HTTPException(status_code=404, detail=f"agent '{agent_id}' not registered")
|
||||
return RuntimeAgentState(**state)
|
||||
|
||||
|
||||
def register_runtime_manager(manager: TradingRuntimeManager) -> None:
|
||||
"""Allow other modules to expose the runtime manager to the API."""
|
||||
global runtime_manager
|
||||
runtime_manager = manager
|
||||
@router.get("/gateway/status", response_model=GatewayStatusResponse)
|
||||
async def get_gateway_status() -> GatewayStatusResponse:
|
||||
"""Get Gateway process status and port."""
|
||||
global _gateway_port
|
||||
|
||||
is_running = _is_gateway_running()
|
||||
run_id = None
|
||||
|
||||
def unregister_runtime_manager() -> None:
|
||||
"""Drop the runtime manager reference (used for shutdown/testing)."""
|
||||
global runtime_manager
|
||||
runtime_manager = None
|
||||
|
||||
|
||||
async def _stop_current_runtime(force: bool = True) -> bool:
|
||||
"""Stop the current running runtime if exists.
|
||||
|
||||
Args:
|
||||
force: If True, cancel the running task immediately
|
||||
|
||||
Returns:
|
||||
True if a runtime was stopped, False if no runtime was running
|
||||
"""
|
||||
global _running_task, _stop_event
|
||||
|
||||
# Signal stop
|
||||
if _stop_event is not None:
|
||||
_stop_event.set()
|
||||
|
||||
# Cancel running task
|
||||
if _running_task is not None and not _running_task.done():
|
||||
if force:
|
||||
_running_task.cancel()
|
||||
if is_running:
|
||||
# Try to find run_id from runtime state
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
if snapshots:
|
||||
try:
|
||||
await _running_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
else:
|
||||
# Wait for graceful shutdown
|
||||
try:
|
||||
await asyncio.wait_for(_running_task, timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
_running_task.cancel()
|
||||
try:
|
||||
await _running_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
run_id = latest.get("context", {}).get("config_name")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse latest snapshot: {e}")
|
||||
|
||||
_running_task = None
|
||||
_stop_event = None
|
||||
return GatewayStatusResponse(
|
||||
is_running=is_running,
|
||||
port=_gateway_port,
|
||||
run_id=run_id
|
||||
)
|
||||
|
||||
# Unregister runtime manager
|
||||
if runtime_manager is not None:
|
||||
unregister_runtime_manager()
|
||||
|
||||
return True
|
||||
@router.get("/gateway/port")
|
||||
async def get_gateway_port() -> Dict[str, Any]:
|
||||
"""Get WebSocket Gateway port for frontend connection."""
|
||||
global _gateway_port
|
||||
return {
|
||||
"port": _gateway_port,
|
||||
"is_running": _is_gateway_running(),
|
||||
"ws_url": f"ws://localhost:{_gateway_port}"
|
||||
}
|
||||
|
||||
|
||||
@router.post("/start", response_model=LaunchResponse)
|
||||
@@ -235,13 +278,18 @@ async def start_runtime(
|
||||
) -> LaunchResponse:
|
||||
"""Start a new trading runtime with the given configuration.
|
||||
|
||||
If a runtime is already running, it will be forcefully stopped first.
|
||||
Creates a new timestamped run directory.
|
||||
1. Stop existing Gateway if running
|
||||
2. Generate run ID and directory
|
||||
3. Create runtime manager
|
||||
4. Start Gateway as subprocess (Data Plane)
|
||||
5. Return Gateway port for WebSocket connection
|
||||
"""
|
||||
global _running_task, _stop_event, runtime_manager
|
||||
global _gateway_process, _gateway_port
|
||||
|
||||
# 1. Stop current runtime if exists
|
||||
await _stop_current_runtime(force=True)
|
||||
# 1. Stop existing Gateway
|
||||
if _is_gateway_running():
|
||||
_stop_gateway()
|
||||
await asyncio.sleep(1) # Wait for port release
|
||||
|
||||
# 2. Generate run ID and directory
|
||||
run_id = _generate_run_id()
|
||||
@@ -260,92 +308,136 @@ async def start_runtime(
|
||||
"mode": config.mode,
|
||||
"start_date": config.start_date,
|
||||
"end_date": config.end_date,
|
||||
"poll_interval": config.poll_interval,
|
||||
"enable_mock": config.enable_mock,
|
||||
}
|
||||
|
||||
# 4. Create and prepare runtime manager
|
||||
runtime_manager = TradingRuntimeManager(
|
||||
# 4. Create runtime manager
|
||||
manager = TradingRuntimeManager(
|
||||
config_name=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
runtime_manager.prepare_run()
|
||||
set_global_runtime_manager = None # Will be set by main module
|
||||
manager.prepare_run()
|
||||
register_runtime_manager(manager)
|
||||
|
||||
# 5. Write BOOTSTRAP.md
|
||||
_write_bootstrap_md(run_dir, bootstrap)
|
||||
|
||||
# 6. Start pipeline in background
|
||||
_stop_event = asyncio.Event()
|
||||
_running_task = asyncio.create_task(
|
||||
_run_pipeline(run_id, run_dir, bootstrap, _stop_event)
|
||||
)
|
||||
# 6. Find available port and start Gateway process
|
||||
_gateway_port = _find_available_port(start_port=8765)
|
||||
|
||||
try:
|
||||
_gateway_process = _start_gateway_process(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
port=_gateway_port
|
||||
)
|
||||
|
||||
# Wait briefly to check if process started successfully
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if not _is_gateway_running():
|
||||
stdout, stderr = _gateway_process.communicate(timeout=1)
|
||||
_gateway_process = None
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Gateway failed to start: {stderr.decode() if stderr else 'Unknown error'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_stop_gateway()
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start Gateway: {str(e)}")
|
||||
|
||||
return LaunchResponse(
|
||||
run_id=run_id,
|
||||
status="started",
|
||||
run_dir=str(run_dir),
|
||||
message=f"Runtime started with run_id: {run_id}",
|
||||
gateway_port=_gateway_port,
|
||||
message=f"Runtime started with run_id: {run_id}, Gateway on port: {_gateway_port}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stop", response_model=StopResponse)
|
||||
async def stop_runtime(force: bool = True) -> StopResponse:
|
||||
"""Stop the current running runtime.
|
||||
"""Stop the current running runtime."""
|
||||
global _gateway_process
|
||||
|
||||
Args:
|
||||
force: If True, forcefully cancel the running task
|
||||
"""
|
||||
was_running = await _stop_current_runtime(force=force)
|
||||
was_running = _is_gateway_running()
|
||||
|
||||
if not was_running:
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
# Stop Gateway process
|
||||
_stop_gateway()
|
||||
|
||||
# Unregister runtime manager
|
||||
unregister_runtime_manager()
|
||||
|
||||
return StopResponse(
|
||||
status="stopped",
|
||||
message="Runtime stopped successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/restart", response_model=RestartResponse)
|
||||
@router.post("/restart")
|
||||
async def restart_runtime(
|
||||
config: LaunchConfig,
|
||||
background_tasks: BackgroundTasks
|
||||
) -> RestartResponse:
|
||||
"""Restart the runtime with a new configuration.
|
||||
|
||||
Equivalent to stop + start.
|
||||
"""
|
||||
):
|
||||
"""Restart the runtime with a new configuration."""
|
||||
# Stop current runtime
|
||||
await _stop_current_runtime(force=True)
|
||||
await stop_runtime(force=True)
|
||||
|
||||
# Start new runtime
|
||||
response = await start_runtime(config, background_tasks)
|
||||
|
||||
return RestartResponse(
|
||||
run_id=response.run_id,
|
||||
status="restarted",
|
||||
message=f"Runtime restarted with run_id: {response.run_id}",
|
||||
)
|
||||
return {
|
||||
"run_id": response.run_id,
|
||||
"status": "restarted",
|
||||
"gateway_port": response.gateway_port,
|
||||
"message": f"Runtime restarted with run_id: {response.run_id}",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
async def get_current_runtime():
|
||||
"""Get information about the currently running runtime."""
|
||||
global _running_task, runtime_manager
|
||||
|
||||
is_running = _running_task is not None and not _running_task.done()
|
||||
|
||||
if not is_running or runtime_manager is None:
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
# Find latest runtime state
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context", {})
|
||||
|
||||
return {
|
||||
"run_id": runtime_manager.config_name,
|
||||
"run_dir": str(runtime_manager.run_dir),
|
||||
"is_running": is_running,
|
||||
"bootstrap": runtime_manager.bootstrap,
|
||||
"run_id": context.get("config_name"),
|
||||
"run_dir": context.get("run_dir"),
|
||||
"is_running": True,
|
||||
"gateway_port": _gateway_port,
|
||||
"bootstrap": context.get("bootstrap_values", {}),
|
||||
}
|
||||
|
||||
|
||||
def register_runtime_manager(manager: TradingRuntimeManager) -> None:
|
||||
"""Allow other modules to expose the runtime manager to the API."""
|
||||
global runtime_manager
|
||||
runtime_manager = manager
|
||||
|
||||
|
||||
def unregister_runtime_manager() -> None:
|
||||
"""Drop the runtime manager reference."""
|
||||
global runtime_manager
|
||||
runtime_manager = None
|
||||
|
||||
|
||||
def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None:
|
||||
"""Write bootstrap configuration to BOOTSTRAP.md."""
|
||||
try:
|
||||
@@ -362,38 +454,7 @@ def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None:
|
||||
if yaml:
|
||||
front_matter = yaml.safe_dump(values, allow_unicode=True, sort_keys=False)
|
||||
else:
|
||||
# Fallback to JSON if yaml not available
|
||||
front_matter = json.dumps(values, ensure_ascii=False, indent=2)
|
||||
|
||||
content = f"---\n{front_matter}---\n"
|
||||
bootstrap_path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
async def _run_pipeline(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
bootstrap: Dict[str, Any],
|
||||
stop_event: asyncio.Event
|
||||
) -> None:
|
||||
"""Background task to run the trading pipeline."""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from backend.core.pipeline_runner import run_pipeline
|
||||
|
||||
try:
|
||||
logger.info(f"Starting pipeline for run_id: {run_id}")
|
||||
await run_pipeline(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
logger.info(f"Pipeline completed for run_id: {run_id}")
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Pipeline cancelled for run_id: {run_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"Pipeline failed for run_id: {run_id}: {e}")
|
||||
# Re-raise to allow proper cleanup
|
||||
raise
|
||||
|
||||
@@ -8,6 +8,7 @@ and frontend development server.
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=R0912, R0915
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -17,7 +18,10 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
@@ -27,7 +31,12 @@ from dotenv import load_dotenv
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.team_pipeline_config import (
|
||||
ensure_team_pipeline_config,
|
||||
load_team_pipeline_config,
|
||||
)
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.data.market_ingest import ingest_symbols
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.enrich.llm_enricher import get_explain_model_info, llm_enrichment_enabled
|
||||
@@ -42,6 +51,8 @@ ingest_app = typer.Typer(help="Ingest Polygon market data into the research ware
|
||||
app.add_typer(ingest_app, name="ingest")
|
||||
skills_app = typer.Typer(help="Inspect and manage per-agent skills.")
|
||||
app.add_typer(skills_app, name="skills")
|
||||
team_app = typer.Typer(help="Inspect and manage run-scoped team pipeline config.")
|
||||
app.add_typer(team_app, name="team")
|
||||
|
||||
console = Console()
|
||||
_prompt_loader = PromptLoader()
|
||||
@@ -95,8 +106,8 @@ def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None:
|
||||
)
|
||||
else:
|
||||
console.print(f" Directory size: [cyan]{size_mb:.1f} MB[/cyan]")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not calculate directory size: {e}")
|
||||
|
||||
# Show last modified time
|
||||
state_dir = base_data_dir / "state"
|
||||
@@ -197,7 +208,8 @@ def run_data_updater(project_root: Path) -> None:
|
||||
console.print(
|
||||
"[yellow] Data updater module not available, skipping update[/yellow]\n",
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.debug(f"Data updater check failed: {e}")
|
||||
console.print(
|
||||
"[yellow] Data updater check failed, skipping update[/yellow]\n",
|
||||
)
|
||||
@@ -777,6 +789,78 @@ def skills_disable(
|
||||
console.print(f"Disabled skills: {', '.join(result['disabled_skills']) or '-'}")
|
||||
|
||||
|
||||
@skills_app.command("install")
|
||||
def skills_install(
|
||||
agent_id: str = typer.Option(..., "--agent-id", "-a", help="Target agent id."),
|
||||
source: str = typer.Option(
|
||||
...,
|
||||
"--source",
|
||||
"-s",
|
||||
help="External skill source: directory path, zip path, or http(s) zip URL.",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Run config name.",
|
||||
),
|
||||
name: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--name",
|
||||
help="Optional override skill name.",
|
||||
),
|
||||
activate: bool = typer.Option(
|
||||
True,
|
||||
"--activate/--no-activate",
|
||||
help="Enable the skill for this agent immediately.",
|
||||
),
|
||||
):
|
||||
"""Install an external skill into one agent's local skill directory."""
|
||||
_require_agent_asset_dir(config_name, agent_id)
|
||||
skills_manager = SkillsManager(project_root=get_project_root())
|
||||
result = skills_manager.install_external_skill_for_agent(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
source=source,
|
||||
skill_name=name,
|
||||
activate=activate,
|
||||
)
|
||||
console.print(
|
||||
f"[green]Installed[/green] `{result['skill_name']}` to `{agent_id}`",
|
||||
)
|
||||
console.print(f"Path: {result['target_dir']}")
|
||||
console.print(f"Activated: {result['activated']}")
|
||||
warnings = result.get("warnings") or []
|
||||
if warnings:
|
||||
console.print(f"Warnings: {'; '.join(warnings)}")
|
||||
|
||||
|
||||
@team_app.command("show")
|
||||
def team_show(
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Run config name.",
|
||||
),
|
||||
):
|
||||
"""Show TEAM_PIPELINE.yaml for one run."""
|
||||
project_root = get_project_root()
|
||||
ensure_team_pipeline_config(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
default_analysts=list(ANALYST_TYPES.keys()),
|
||||
)
|
||||
config = load_team_pipeline_config(project_root, config_name)
|
||||
console.print(
|
||||
Panel.fit(
|
||||
yaml.safe_dump(config, allow_unicode=True, sort_keys=False),
|
||||
title=f"TEAM_PIPELINE ({config_name})",
|
||||
border_style="cyan",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def backtest(
|
||||
start: Optional[str] = typer.Option(
|
||||
|
||||
@@ -10,6 +10,8 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
@@ -21,6 +23,26 @@ from backend.core.state_sync import StateSync
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
from backend.runtime.manager import TradingRuntimeManager
|
||||
from backend.runtime.session import TradingSessionKey
|
||||
from backend.agents.team_pipeline_config import (
|
||||
resolve_active_analysts,
|
||||
update_active_analysts,
|
||||
)
|
||||
from backend.agents import AnalystAgent
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
|
||||
# Team infrastructure imports (graceful import - may not exist yet)
|
||||
try:
|
||||
from backend.agents.team.team_coordinator import TeamCoordinator
|
||||
from backend.agents.team.msg_hub import MsgHub as TeamMsgHub
|
||||
TEAM_COORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
TEAM_COORD_AVAILABLE = False
|
||||
TeamCoordinator = None
|
||||
TeamMsgHub = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -77,6 +99,13 @@ class TradingPipeline:
|
||||
self.agent_factory = agent_factory
|
||||
self.runtime_manager = runtime_manager
|
||||
self._session_key: Optional[str] = None
|
||||
self._dynamic_analysts: Dict[str, Any] = {}
|
||||
|
||||
if hasattr(self.pm, "set_team_controller"):
|
||||
self.pm.set_team_controller(
|
||||
create_agent_callback=self._create_runtime_analyst,
|
||||
remove_agent_callback=self._remove_runtime_analyst,
|
||||
)
|
||||
|
||||
async def run_cycle(
|
||||
self,
|
||||
@@ -115,16 +144,17 @@ class TradingPipeline:
|
||||
_log(f"Starting cycle {date} - {len(tickers)} tickers")
|
||||
session_key = TradingSessionKey(date=date).key()
|
||||
self._session_key = session_key
|
||||
active_analysts = self._get_active_analysts()
|
||||
if self.runtime_manager:
|
||||
self.runtime_manager.set_session_key(session_key)
|
||||
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date})
|
||||
self._runtime_batch_status(self.analysts, "analysis_in_progress")
|
||||
self._runtime_batch_status(active_analysts, "analysis_in_progress")
|
||||
|
||||
# Phase 0: Clear short-term memory to avoid cross-day context pollution
|
||||
_log("Phase 0: Clearing memory")
|
||||
await self._clear_all_agent_memory()
|
||||
|
||||
participants = self.analysts + [self.risk_manager, self.pm]
|
||||
participants = self._all_analysts() + [self.risk_manager, self.pm]
|
||||
|
||||
# Single MsgHub for entire cycle - no nesting
|
||||
async with MsgHub(
|
||||
@@ -135,9 +165,13 @@ class TradingPipeline:
|
||||
"system",
|
||||
),
|
||||
):
|
||||
# Phase 1.1: Analysts
|
||||
_log("Phase 1.1: Analyst analysis")
|
||||
analyst_results = await self._run_analysts_with_sync(tickers, date)
|
||||
# Phase 1.1: Analysts (parallel execution with TeamCoordinator)
|
||||
_log("Phase 1.1: Analyst analysis (parallel)")
|
||||
analyst_results = await self._run_analysts_parallel(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
|
||||
# Phase 1.2: Risk Manager
|
||||
_log("Phase 1.2: Risk assessment")
|
||||
@@ -164,6 +198,7 @@ class TradingPipeline:
|
||||
final_predictions = await self._collect_final_predictions(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
|
||||
# Record final predictions for leaderboard ranking
|
||||
@@ -212,7 +247,7 @@ class TradingPipeline:
|
||||
if close_prices and self.settlement_coordinator:
|
||||
_log("Phase 5: Daily review and generate memories")
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self.analysts + [self.pm],
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"settlement",
|
||||
)
|
||||
|
||||
@@ -246,13 +281,13 @@ class TradingPipeline:
|
||||
conference_summary=self.conference_summary,
|
||||
)
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self.analysts + [self.pm],
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"reflection",
|
||||
)
|
||||
|
||||
_log(f"Cycle complete: {date}")
|
||||
self._runtime_batch_status(
|
||||
self.analysts + [self.risk_manager, self.pm],
|
||||
self._all_analysts() + [self.risk_manager, self.pm],
|
||||
"idle",
|
||||
)
|
||||
self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date})
|
||||
@@ -288,7 +323,7 @@ class TradingPipeline:
|
||||
},
|
||||
)
|
||||
|
||||
for analyst in self.analysts:
|
||||
for analyst in self._all_analysts():
|
||||
analyst.reload_runtime_assets(
|
||||
active_skill_dirs=active_skill_map.get(analyst.name, []),
|
||||
)
|
||||
@@ -302,7 +337,7 @@ class TradingPipeline:
|
||||
|
||||
return {
|
||||
"config_name": config_name,
|
||||
"reloaded_agents": [agent.name for agent in self.analysts]
|
||||
"reloaded_agents": [agent.name for agent in self._all_analysts()]
|
||||
+ ["risk_manager", "portfolio_manager"],
|
||||
"active_skills": {
|
||||
agent_id: [path.name for path in paths]
|
||||
@@ -313,7 +348,7 @@ class TradingPipeline:
|
||||
|
||||
async def _clear_all_agent_memory(self):
|
||||
"""Clear short-term memory for all agents"""
|
||||
for analyst in self.analysts:
|
||||
for analyst in self._all_analysts():
|
||||
await analyst.memory.clear()
|
||||
|
||||
await self.risk_manager.memory.clear()
|
||||
@@ -395,7 +430,7 @@ class TradingPipeline:
|
||||
trajectories = {}
|
||||
|
||||
# Capture analyst trajectories
|
||||
for analyst in self.analysts:
|
||||
for analyst in self._all_analysts():
|
||||
try:
|
||||
msgs = await analyst.memory.get_memory()
|
||||
if msgs:
|
||||
@@ -605,7 +640,7 @@ class TradingPipeline:
|
||||
)
|
||||
|
||||
# Record for analysts
|
||||
for analyst in self.analysts:
|
||||
for analyst in self._all_analysts():
|
||||
if (
|
||||
hasattr(analyst, "long_term_memory")
|
||||
and analyst.long_term_memory is not None
|
||||
@@ -724,67 +759,82 @@ class TradingPipeline:
|
||||
date=date,
|
||||
)
|
||||
|
||||
# Run discussion cycles (no new MsgHub - use parent's)
|
||||
for cycle in range(self.max_comm_cycles):
|
||||
# Conference participants: analysts + PM
|
||||
conference_participants = self._get_active_analysts() + [self.pm]
|
||||
|
||||
# Use TeamMsgHub for conference if available
|
||||
if TEAM_COORD_AVAILABLE and TeamMsgHub is not None:
|
||||
_log(
|
||||
"Phase 2.1: Conference discussion - "
|
||||
f"Conference {cycle + 1}/{self.max_comm_cycles}",
|
||||
f"Phase 2.1: Conference using TeamMsgHub with "
|
||||
f"{len(conference_participants)} participants"
|
||||
)
|
||||
conference_hub = TeamMsgHub(participants=conference_participants)
|
||||
else:
|
||||
_log("Phase 2.1: Conference using standard MsgHub context")
|
||||
conference_hub = None
|
||||
|
||||
if self.state_sync:
|
||||
await self.state_sync.on_conference_cycle_start(
|
||||
cycle=cycle + 1,
|
||||
total_cycles=self.max_comm_cycles,
|
||||
# Run discussion cycles
|
||||
async with conference_hub if conference_hub else nullcontext(None):
|
||||
for cycle in range(self.max_comm_cycles):
|
||||
_log(
|
||||
"Phase 2.1: Conference discussion - "
|
||||
f"Conference {cycle + 1}/{self.max_comm_cycles}",
|
||||
)
|
||||
|
||||
# PM sets agenda or asks questions
|
||||
pm_prompt = self._build_pm_discussion_prompt(
|
||||
cycle=cycle,
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
analyst_results=analyst_results,
|
||||
risk_assessment=risk_assessment,
|
||||
)
|
||||
if self.state_sync:
|
||||
await self.state_sync.on_conference_cycle_start(
|
||||
cycle=cycle + 1,
|
||||
total_cycles=self.max_comm_cycles,
|
||||
)
|
||||
|
||||
pm_msg = Msg(name="system", content=pm_prompt, role="user")
|
||||
pm_response = await self.pm.reply(pm_msg)
|
||||
|
||||
if self.state_sync:
|
||||
pm_content = self._extract_text_content(pm_response.content)
|
||||
await self.state_sync.on_conference_message(
|
||||
agent_id="portfolio_manager",
|
||||
content=pm_content,
|
||||
)
|
||||
|
||||
# Analysts share perspectives
|
||||
for analyst in self.analysts:
|
||||
analyst_prompt = self._build_analyst_discussion_prompt(
|
||||
# PM sets agenda or asks questions
|
||||
pm_prompt = self._build_pm_discussion_prompt(
|
||||
cycle=cycle,
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
analyst_results=analyst_results,
|
||||
risk_assessment=risk_assessment,
|
||||
)
|
||||
|
||||
analyst_msg = Msg(
|
||||
name="system",
|
||||
content=analyst_prompt,
|
||||
role="user",
|
||||
)
|
||||
analyst_response = await analyst.reply(analyst_msg)
|
||||
pm_msg = Msg(name="system", content=pm_prompt, role="user")
|
||||
pm_response = await self.pm.reply(pm_msg)
|
||||
|
||||
if self.state_sync:
|
||||
analyst_content = self._extract_text_content(
|
||||
analyst_response.content,
|
||||
)
|
||||
pm_content = self._extract_text_content(pm_response.content)
|
||||
await self.state_sync.on_conference_message(
|
||||
agent_id=analyst.name,
|
||||
content=analyst_content,
|
||||
agent_id="portfolio_manager",
|
||||
content=pm_content,
|
||||
)
|
||||
|
||||
if self.state_sync:
|
||||
await self.state_sync.on_conference_cycle_end(
|
||||
cycle=cycle + 1,
|
||||
)
|
||||
# Analysts share perspectives (supports per-round active team updates)
|
||||
for analyst in self._get_active_analysts():
|
||||
analyst_prompt = self._build_analyst_discussion_prompt(
|
||||
cycle=cycle,
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
)
|
||||
|
||||
analyst_msg = Msg(
|
||||
name="system",
|
||||
content=analyst_prompt,
|
||||
role="user",
|
||||
)
|
||||
analyst_response = await analyst.reply(analyst_msg)
|
||||
|
||||
if self.state_sync:
|
||||
analyst_content = self._extract_text_content(
|
||||
analyst_response.content,
|
||||
)
|
||||
await self.state_sync.on_conference_message(
|
||||
agent_id=analyst.name,
|
||||
content=analyst_content,
|
||||
)
|
||||
|
||||
if self.state_sync:
|
||||
await self.state_sync.on_conference_cycle_end(
|
||||
cycle=cycle + 1,
|
||||
)
|
||||
|
||||
# Generate conference summary by PM
|
||||
_log(
|
||||
@@ -885,6 +935,7 @@ class TradingPipeline:
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
active_analysts: Optional[List[Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Collect final predictions from all analysts as simple text responses.
|
||||
@@ -892,14 +943,15 @@ class TradingPipeline:
|
||||
"""
|
||||
_log(
|
||||
"Phase 2.2: Analysts generate final structured predictions\n"
|
||||
f" Starting _collect_final_predictions for {len(self.analysts)} analysts",
|
||||
f" Starting _collect_final_predictions for {len(active_analysts or self.analysts)} analysts",
|
||||
)
|
||||
final_predictions = []
|
||||
|
||||
for i, analyst in enumerate(self.analysts):
|
||||
analysts = active_analysts or self.analysts
|
||||
for i, analyst in enumerate(analysts):
|
||||
_log(
|
||||
"Phase 2.2: Analysts generate final structured predictions\n"
|
||||
f" Collecting prediction from analyst {i+1}/{len(self.analysts)}: {analyst.name}",
|
||||
f" Collecting prediction from analyst {i+1}/{len(analysts)}: {analyst.name}",
|
||||
)
|
||||
|
||||
prompt = (
|
||||
@@ -995,11 +1047,13 @@ class TradingPipeline:
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
active_analysts: Optional[List[Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Run all analysts with real-time sync after each completion"""
|
||||
results = []
|
||||
analysts = active_analysts or self.analysts
|
||||
|
||||
for analyst in self.analysts:
|
||||
for analyst in analysts:
|
||||
content = (
|
||||
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||
f"Provide investment signals with confidence scores and reasoning."
|
||||
@@ -1029,15 +1083,107 @@ class TradingPipeline:
|
||||
|
||||
return results
|
||||
|
||||
async def _run_analysts_parallel(
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
active_analysts: Optional[List[Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Run all analysts in parallel using TeamCoordinator.
|
||||
|
||||
This method replaces the sequential analyst loop with parallel execution
|
||||
using the TeamCoordinator for orchestration.
|
||||
|
||||
Args:
|
||||
tickers: List of stock tickers to analyze
|
||||
date: Trading date
|
||||
active_analysts: Optional list of analysts to run
|
||||
|
||||
Returns:
|
||||
List of analyst result dictionaries
|
||||
"""
|
||||
analysts = active_analysts or self.analysts
|
||||
|
||||
if not analysts:
|
||||
return []
|
||||
|
||||
if not TEAM_COORD_AVAILABLE:
|
||||
_log("TeamCoordinator not available, falling back to sequential execution")
|
||||
return await self._run_analysts_with_sync(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
|
||||
_log(
|
||||
f"Phase 1.1: Running {len(analysts)} analysts in parallel "
|
||||
f"[{', '.join(a.name for a in analysts)}]"
|
||||
)
|
||||
|
||||
# Build the analyst prompt
|
||||
content = (
|
||||
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||
f"Provide investment signals with confidence scores and reasoning."
|
||||
)
|
||||
|
||||
# Create coordinator for parallel execution
|
||||
coordinator = TeamCoordinator(
|
||||
participants=analysts,
|
||||
task_content=content,
|
||||
)
|
||||
|
||||
# Run analysts in parallel via TeamCoordinator
|
||||
results = await coordinator.run_phase(
|
||||
"analyst_analysis",
|
||||
metadata={"tickers": tickers, "date": date},
|
||||
)
|
||||
|
||||
# Process results and sync
|
||||
processed_results = []
|
||||
for i, (analyst, result) in enumerate(zip(analysts, results)):
|
||||
if result is not None:
|
||||
extracted = self._extract_result_from_msg(result)
|
||||
processed_results.append(extracted)
|
||||
|
||||
# Sync retrieved memory
|
||||
await self._sync_memory_if_retrieved(analyst)
|
||||
|
||||
# Broadcast agent result via StateSync
|
||||
if self.state_sync:
|
||||
text_content = self._extract_text_content(result.content)
|
||||
await self.state_sync.on_agent_complete(
|
||||
agent_id=analyst.name,
|
||||
content=text_content,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Analyst %s returned no result",
|
||||
analyst.name,
|
||||
)
|
||||
processed_results.append({
|
||||
"agent": analyst.name,
|
||||
"content": "",
|
||||
"success": False,
|
||||
})
|
||||
|
||||
_log(
|
||||
f"Phase 1.1: Parallel analyst execution complete "
|
||||
f"({len(processed_results)}/{len(analysts)} successful)"
|
||||
)
|
||||
|
||||
return processed_results
|
||||
|
||||
async def _run_analysts(
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
active_analysts: Optional[List[Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Run all analysts (without sync, for backward compatibility)"""
|
||||
results = []
|
||||
analysts = active_analysts or self.analysts
|
||||
|
||||
for analyst in self.analysts:
|
||||
for analyst in analysts:
|
||||
content = (
|
||||
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||
f"Provide investment signals with confidence scores and reasoning."
|
||||
@@ -1461,6 +1607,83 @@ class TradingPipeline:
|
||||
for agent in agents:
|
||||
self._runtime_update_status(agent, status)
|
||||
|
||||
def _all_analysts(self) -> List[Any]:
|
||||
"""Return static analysts plus runtime-created analysts."""
|
||||
return list(self.analysts) + list(self._dynamic_analysts.values())
|
||||
|
||||
def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str:
|
||||
"""Create one runtime analyst instance."""
|
||||
if analyst_type not in ANALYST_TYPES:
|
||||
return (
|
||||
f"Unknown analyst_type '{analyst_type}'. "
|
||||
f"Available: {', '.join(ANALYST_TYPES.keys())}"
|
||||
)
|
||||
if agent_id in {agent.name for agent in self._all_analysts()}:
|
||||
return f"Analyst '{agent_id}' already exists."
|
||||
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
personas = PromptLoader().load_yaml_config("analyst", "personas")
|
||||
persona = personas.get(analyst_type, {})
|
||||
WorkspaceManager(project_root=project_root).ensure_agent_assets(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
role_seed=persona.get("description", "").strip(),
|
||||
style_seed="\n".join(f"- {item}" for item in persona.get("focus", [])),
|
||||
policy_seed=(
|
||||
"State a clear signal, confidence, and the conditions "
|
||||
"that would invalidate the thesis."
|
||||
),
|
||||
)
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type=analyst_type,
|
||||
toolkit=create_agent_toolkit(
|
||||
agent_id=agent_id,
|
||||
config_name=config_name,
|
||||
active_skill_dirs=[],
|
||||
),
|
||||
model=get_agent_model(analyst_type),
|
||||
formatter=get_agent_formatter(analyst_type),
|
||||
agent_id=agent_id,
|
||||
config={"config_name": config_name},
|
||||
)
|
||||
self._dynamic_analysts[agent_id] = agent
|
||||
update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=[item.name for item in self._all_analysts()],
|
||||
add=[agent_id],
|
||||
)
|
||||
return f"Created runtime analyst '{agent_id}' ({analyst_type})."
|
||||
|
||||
def _remove_runtime_analyst(self, agent_id: str) -> str:
|
||||
"""Remove one runtime-created analyst instance."""
|
||||
if agent_id not in self._dynamic_analysts:
|
||||
return f"Runtime analyst '{agent_id}' not found."
|
||||
self._dynamic_analysts.pop(agent_id, None)
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=[item.name for item in self._all_analysts()],
|
||||
remove=[agent_id],
|
||||
)
|
||||
return f"Removed runtime analyst '{agent_id}'."
|
||||
|
||||
def _get_active_analysts(self) -> List[Any]:
|
||||
"""Resolve active analyst participants from run-scoped team pipeline config."""
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
analyst_map = {agent.name: agent for agent in self._all_analysts()}
|
||||
active_ids = resolve_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=list(analyst_map.keys()),
|
||||
)
|
||||
return [analyst_map[agent_id] for agent_id in active_ids if agent_id in analyst_map]
|
||||
|
||||
def _runtime_log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> None:
|
||||
if not self.runtime_manager:
|
||||
return
|
||||
|
||||
@@ -61,7 +61,7 @@ def stop_gateway() -> None:
|
||||
_gateway_instance = None
|
||||
|
||||
|
||||
async def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
|
||||
def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
|
||||
"""Create ReMeTaskLongTermMemory for an agent."""
|
||||
try:
|
||||
from agentscope.memory import ReMeTaskLongTermMemory
|
||||
@@ -206,6 +206,13 @@ async def run_pipeline(
|
||||
"""
|
||||
Run the trading pipeline with the given configuration.
|
||||
|
||||
Service Startup Order:
|
||||
Phase 1: WebSocket Server - Frontend can connect
|
||||
Phase 2: Market Service - Price data starts flowing
|
||||
Phase 3: Agent Runtime - Create all agents
|
||||
Phase 4: Pipeline & Scheduler - Trading logic ready
|
||||
Phase 5: Gateway Fully Operational - All systems running
|
||||
|
||||
Args:
|
||||
run_id: Unique run identifier (timestamp)
|
||||
run_dir: Run directory path
|
||||
@@ -219,7 +226,9 @@ async def run_pipeline(
|
||||
# Set global shutdown event
|
||||
set_shutdown_event(stop_event)
|
||||
|
||||
logger.info(f"[Pipeline {run_id}] Starting...")
|
||||
logger.info(f"[Pipeline {run_id}] ======================================")
|
||||
logger.info(f"[Pipeline {run_id}] Starting with 5-phase initialization...")
|
||||
logger.info(f"[Pipeline {run_id}] ======================================")
|
||||
|
||||
try:
|
||||
# Extract config values
|
||||
@@ -230,15 +239,21 @@ async def run_pipeline(
|
||||
schedule_mode = bootstrap.get("schedule_mode", "daily")
|
||||
trigger_time = bootstrap.get("trigger_time", "09:30")
|
||||
interval_minutes = int(bootstrap.get("interval_minutes", 60))
|
||||
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0))
|
||||
mode = bootstrap.get("mode", "live")
|
||||
start_date = bootstrap.get("start_date")
|
||||
end_date = bootstrap.get("end_date")
|
||||
enable_memory = bootstrap.get("enable_memory", False)
|
||||
enable_mock = bootstrap.get("enable_mock", False)
|
||||
|
||||
is_backtest = mode == "backtest"
|
||||
is_mock = mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
|
||||
is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 0: Initialize runtime manager
|
||||
# ======================================================================
|
||||
logger.info("[Phase 0/5] Initializing runtime manager...")
|
||||
|
||||
# Get or create runtime manager
|
||||
from backend.api.runtime import runtime_manager
|
||||
|
||||
if runtime_manager is None:
|
||||
@@ -255,16 +270,11 @@ async def run_pipeline(
|
||||
from backend.api.runtime import register_runtime_manager
|
||||
register_runtime_manager(runtime_manager)
|
||||
|
||||
# Create market service
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=10,
|
||||
mock_mode=is_mock and not is_backtest,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
|
||||
backtest_start_date=start_date if is_backtest else None,
|
||||
backtest_end_date=end_date if is_backtest else None,
|
||||
)
|
||||
# ======================================================================
|
||||
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
|
||||
# These will be started by Gateway in the correct order
|
||||
# ======================================================================
|
||||
logger.info("[Phase 1-2/5] Creating infrastructure services...")
|
||||
|
||||
# Create storage service
|
||||
storage_service = StorageService(
|
||||
@@ -278,7 +288,22 @@ async def run_pipeline(
|
||||
else:
|
||||
storage_service.update_leaderboard_model_info()
|
||||
|
||||
# Create agents and pipeline
|
||||
# Create market service (data source)
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=10,
|
||||
mock_mode=is_mock and not is_backtest,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
|
||||
backtest_start_date=start_date if is_backtest else None,
|
||||
backtest_end_date=end_date if is_backtest else None,
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 3: Create Agent Runtime
|
||||
# ======================================================================
|
||||
logger.info("[Phase 3/5] Creating agent runtime...")
|
||||
|
||||
analysts, risk_manager, pm, long_term_memories = create_agents(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
@@ -303,6 +328,11 @@ async def run_pipeline(
|
||||
initial_capital=initial_cash,
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 4: Create Pipeline & Scheduler
|
||||
# ======================================================================
|
||||
logger.info("[Phase 4/5] Creating pipeline and scheduler...")
|
||||
|
||||
# Create pipeline
|
||||
pipeline = TradingPipeline(
|
||||
analysts=analysts,
|
||||
@@ -336,6 +366,7 @@ async def run_pipeline(
|
||||
mode=schedule_mode,
|
||||
trigger_time=trigger_time,
|
||||
interval_minutes=interval_minutes,
|
||||
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
|
||||
config={"config_name": run_id},
|
||||
)
|
||||
|
||||
@@ -344,7 +375,15 @@ async def run_pipeline(
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
|
||||
# Create Gateway for WebSocket connections (after pipeline and scheduler are ready)
|
||||
# ======================================================================
|
||||
# PHASE 5: Start Gateway (WebSocket → Market → Scheduler)
|
||||
# Gateway.start() will handle the final startup sequence:
|
||||
# - WebSocket Server first (frontend can connect)
|
||||
# - Market Service second (price data flows)
|
||||
# - Scheduler last (trading begins)
|
||||
# ======================================================================
|
||||
logger.info("[Phase 5/5] Starting Gateway (WebSocket → Market → Scheduler)...")
|
||||
|
||||
gateway = Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage_service,
|
||||
@@ -359,6 +398,7 @@ async def run_pipeline(
|
||||
"schedule_mode": schedule_mode,
|
||||
"interval_minutes": interval_minutes,
|
||||
"trigger_time": trigger_time,
|
||||
"heartbeat_interval": heartbeat_interval,
|
||||
"initial_cash": initial_cash,
|
||||
"margin_requirement": margin_requirement,
|
||||
"max_comm_cycles": max_comm_cycles,
|
||||
@@ -374,13 +414,17 @@ async def run_pipeline(
|
||||
for memory in long_term_memories:
|
||||
await stack.enter_async_context(memory)
|
||||
|
||||
# Start Gateway in background task
|
||||
# Start Gateway - this will execute the 4-phase startup:
|
||||
# Phase 1: WebSocket Server (frontend can connect immediately)
|
||||
# Phase 2: Market Service (price updates start flowing)
|
||||
# Phase 3: Market Status Monitor
|
||||
# Phase 4: Scheduler (trading cycles begin)
|
||||
gateway_task = asyncio.create_task(
|
||||
gateway.start(host="0.0.0.0", port=8765)
|
||||
)
|
||||
logger.info("[Pipeline] Gateway started on ws://localhost:8765")
|
||||
logger.info("[Pipeline] Gateway startup initiated on ws://localhost:8765")
|
||||
|
||||
# Give Gateway a moment to start
|
||||
# Wait for Gateway to fully initialize all phases
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Define the trading cycle callback
|
||||
|
||||
@@ -4,7 +4,7 @@ Scheduler - Market-aware trigger system for trading cycles
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Any, Callable, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
@@ -28,17 +28,21 @@ class Scheduler:
|
||||
mode: str = "daily",
|
||||
trigger_time: Optional[str] = None,
|
||||
interval_minutes: Optional[int] = None,
|
||||
heartbeat_interval: Optional[int] = None,
|
||||
config: Optional[dict] = None,
|
||||
):
|
||||
self.mode = mode
|
||||
self.trigger_time = trigger_time or "09:30" # NYSE timezone
|
||||
self.trigger_now = self.trigger_time == "now"
|
||||
self.interval_minutes = interval_minutes or 60
|
||||
self.heartbeat_interval = heartbeat_interval # e.g. 3600 = 1 hour
|
||||
self.config = config or {}
|
||||
|
||||
self.running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._callback: Optional[Callable] = None
|
||||
self._heartbeat_callback: Optional[Callable] = None
|
||||
|
||||
def _now_nyse(self) -> datetime:
|
||||
"""Get current time in NYSE timezone"""
|
||||
@@ -53,6 +57,15 @@ class Scheduler:
|
||||
)
|
||||
return len(valid_days) > 0
|
||||
|
||||
def _is_trading_hours(self, now: datetime) -> bool:
|
||||
"""Check if current time is within NYSE trading hours (9:30-16:00 ET)."""
|
||||
market_time = now.time()
|
||||
return time(9, 30) <= market_time <= time(16, 0)
|
||||
|
||||
def set_heartbeat_callback(self, callback: Callable) -> None:
|
||||
"""Register callback for heartbeat triggers."""
|
||||
self._heartbeat_callback = callback
|
||||
|
||||
def _next_trading_day(self, from_date: datetime) -> datetime:
|
||||
"""Find the next trading day from given date"""
|
||||
check_date = from_date
|
||||
@@ -72,6 +85,13 @@ class Scheduler:
|
||||
self._callback = callback
|
||||
self._schedule_task()
|
||||
|
||||
# Start heartbeat loop if configured
|
||||
if self.heartbeat_interval and self._heartbeat_callback:
|
||||
self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop())
|
||||
logger.info(
|
||||
f"Heartbeat loop started: interval={self.heartbeat_interval}s",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduler started: mode={self.mode}, timezone=America/New_York",
|
||||
)
|
||||
@@ -132,6 +152,30 @@ class Scheduler:
|
||||
|
||||
return changed
|
||||
|
||||
async def _run_heartbeat_loop(self):
|
||||
"""Run heartbeat checks on a separate interval during trading hours."""
|
||||
while self.running:
|
||||
now = self._now_nyse()
|
||||
if self._is_trading_day(now) and self._is_trading_hours(now):
|
||||
if self._heartbeat_callback:
|
||||
try:
|
||||
current_date = now.strftime("%Y-%m-%d")
|
||||
logger.debug(
|
||||
f"[Heartbeat] Triggering heartbeat check for {current_date}",
|
||||
)
|
||||
await self._heartbeat_callback(date=current_date)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Heartbeat] Callback failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[Heartbeat] Callback not set, skipping heartbeat",
|
||||
)
|
||||
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
|
||||
async def _run_daily(self, callback: Callable):
|
||||
"""Run once per trading day at specified time (NYSE timezone)"""
|
||||
first_run = True
|
||||
@@ -206,6 +250,9 @@ class Scheduler:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
logger.info("Scheduler stopped")
|
||||
|
||||
|
||||
|
||||
@@ -163,6 +163,16 @@ class AnalystSignal(BaseModel):
|
||||
signal: str | None = None
|
||||
confidence: float | None = None
|
||||
reasoning: dict | str | None = None
|
||||
# Extended fields for richer signal information
|
||||
reasons: list[str] | None = None # Core drivers/reasons for the signal
|
||||
risks: list[str] | None = None # Key risk factors
|
||||
invalidation: str | None = None # Conditions that would invalidate the thesis
|
||||
next_action: str | None = None # Suggested next action for PM
|
||||
# Valuation-related fields
|
||||
intrinsic_value: float | None = None # DCF intrinsic value
|
||||
fair_value_range: dict | None = None # {bear, base, bull} fair value range
|
||||
value_gap_pct: float | None = None # Value gap percentage
|
||||
valuation_methods: list[str] | None = None # List of valuation methods used
|
||||
max_position_size: float | None = None # For risk management signals
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
@@ -12,6 +13,8 @@ from pydantic import BaseModel, Field
|
||||
from backend.config.env_config import canonicalize_model_provider, get_env_bool, get_env_str
|
||||
from backend.llm.models import create_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnrichedNewsItem(BaseModel):
|
||||
"""Structured output schema for one enriched article."""
|
||||
@@ -156,7 +159,8 @@ def analyze_news_row_with_llm(row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
]
|
||||
try:
|
||||
response = _run_async(model(messages=messages, structured_model=EnrichedNewsItem))
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM enrichment failed: {e}")
|
||||
return None
|
||||
|
||||
payload = _normalize_enrichment_payload(getattr(response, "metadata", None))
|
||||
@@ -268,7 +272,8 @@ def analyze_range_with_llm(payload: dict[str, Any]) -> dict[str, Any] | None:
|
||||
response = _run_async(
|
||||
model(messages=messages, structured_model=RangeAnalysisPayload),
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM enrichment failed: {e}")
|
||||
return None
|
||||
|
||||
metadata = getattr(response, "metadata", None)
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
AgentScope Native Model Factory
|
||||
Uses native AgentScope model classes for LLM calls
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
|
||||
from agentscope.formatter import (
|
||||
AnthropicChatFormatter,
|
||||
DashScopeChatFormatter,
|
||||
@@ -26,6 +28,244 @@ from backend.config.env_config import (
|
||||
get_env_str,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Retry wrapper types
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RetryChatModel:
|
||||
"""Wraps an AgentScope model with automatic retry for transient errors.
|
||||
|
||||
Based on CoPaw's RetryChatModel design. Handles rate limits, timeouts,
|
||||
and other transient failures with exponential backoff.
|
||||
"""
|
||||
|
||||
DEFAULT_MAX_RETRIES = 3
|
||||
DEFAULT_INITIAL_DELAY = 1.0
|
||||
DEFAULT_MAX_DELAY = 60.0
|
||||
DEFAULT_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
# Transient error codes/messages that should trigger retry
|
||||
TRANSIENT_ERROR_KEYWORDS = frozenset([
|
||||
"rate_limit",
|
||||
"429",
|
||||
"timeout",
|
||||
"503",
|
||||
"502",
|
||||
"504",
|
||||
"connection",
|
||||
"temporary",
|
||||
"overloaded",
|
||||
"too_many_requests",
|
||||
])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Any,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
initial_delay: float = DEFAULT_INITIAL_DELAY,
|
||||
max_delay: float = DEFAULT_MAX_DELAY,
|
||||
backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER,
|
||||
on_retry: Optional[Callable[[int, Exception, float], None]] = None,
|
||||
):
|
||||
"""Initialize retry wrapper.
|
||||
|
||||
Args:
|
||||
model: The underlying AgentScope model to wrap
|
||||
max_retries: Maximum number of retry attempts
|
||||
initial_delay: Initial delay in seconds before first retry
|
||||
max_delay: Maximum delay between retries
|
||||
backoff_multiplier: Multiplier for exponential backoff
|
||||
on_retry: Optional callback(retry_count, exception, delay) for logging
|
||||
"""
|
||||
self._model = model
|
||||
self._max_retries = max_retries
|
||||
self._initial_delay = initial_delay
|
||||
self._max_delay = max_delay
|
||||
self._backoff_multiplier = backoff_multiplier
|
||||
self._on_retry = on_retry
|
||||
self._total_tokens_used = 0
|
||||
self._total_cost = 0.0
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return getattr(self._model, "model_name", str(self._model))
|
||||
|
||||
@property
|
||||
def total_tokens_used(self) -> int:
|
||||
return self._total_tokens_used
|
||||
|
||||
@property
|
||||
def total_cost(self) -> float:
|
||||
return self._total_cost
|
||||
|
||||
def _is_transient_error(self, error: Exception) -> bool:
|
||||
"""Check if an error is transient and should be retried.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
True if the error is transient
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
for keyword in self.TRANSIENT_ERROR_KEYWORDS:
|
||||
if keyword in error_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _calculate_delay(self, retry_count: int) -> float:
|
||||
"""Calculate delay for given retry attempt with exponential backoff.
|
||||
|
||||
Args:
|
||||
retry_count: Current retry attempt number (1-based)
|
||||
|
||||
Returns:
|
||||
Delay in seconds
|
||||
"""
|
||||
delay = self._initial_delay * (self._backoff_multiplier ** (retry_count - 1))
|
||||
return min(delay, self._max_delay)
|
||||
|
||||
def _call_with_retry(self, func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""Call a function with retry logic for transient errors.
|
||||
|
||||
Args:
|
||||
func: Function to call
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
Result from func
|
||||
|
||||
Raises:
|
||||
Last exception if all retries exhausted
|
||||
"""
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(1, self._max_retries + 1):
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Track usage if available
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
usage = result.usage
|
||||
self._total_tokens_used += getattr(usage, "total_tokens", 0)
|
||||
self._total_cost += getattr(usage, "cost", 0.0)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt >= self._max_retries:
|
||||
logger.error(
|
||||
"RetryChatModel: Max retries (%d) exhausted for %s",
|
||||
self._max_retries,
|
||||
self.model_name,
|
||||
)
|
||||
break
|
||||
|
||||
if not self._is_transient_error(e):
|
||||
logger.warning(
|
||||
"RetryChatModel: Non-transient error, not retrying: %s",
|
||||
str(e),
|
||||
)
|
||||
break
|
||||
|
||||
delay = self._calculate_delay(attempt)
|
||||
logger.warning(
|
||||
"RetryChatModel: Transient error on attempt %d/%d, "
|
||||
"retrying in %.1fs: %s",
|
||||
attempt,
|
||||
self._max_retries,
|
||||
delay,
|
||||
str(e)[:200],
|
||||
)
|
||||
|
||||
if self._on_retry:
|
||||
self._on_retry(attempt, e, delay)
|
||||
|
||||
time.sleep(delay)
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Forward calls to the wrapped model with retry logic."""
|
||||
return self._call_with_retry(self._model, *args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Proxy attribute access to the wrapped model."""
|
||||
return getattr(self._model, name)
|
||||
|
||||
|
||||
class TokenRecordingModelWrapper:
|
||||
"""Wraps a model to track token usage per provider.
|
||||
|
||||
Based on CoPaw's TokenRecordingModelWrapper design.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Any):
|
||||
"""Initialize token recorder.
|
||||
|
||||
Args:
|
||||
model: The underlying AgentScope model to wrap
|
||||
"""
|
||||
self._model = model
|
||||
self._total_tokens = 0
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_cost = 0.0
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return getattr(self._model, "model_name", str(self._model))
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self._total_tokens
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return self._prompt_tokens
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return self._completion_tokens
|
||||
|
||||
@property
|
||||
def total_cost(self) -> float:
|
||||
return self._total_cost
|
||||
|
||||
def record_usage(self, usage: Any) -> None:
|
||||
"""Record token usage from a model response.
|
||||
|
||||
Args:
|
||||
usage: Usage object from model response
|
||||
"""
|
||||
if usage is None:
|
||||
return
|
||||
|
||||
self._prompt_tokens += getattr(usage, "prompt_tokens", 0)
|
||||
self._completion_tokens += getattr(usage, "completion_tokens", 0)
|
||||
self._total_tokens += getattr(usage, "total_tokens", 0)
|
||||
self._total_cost += getattr(usage, "cost", 0.0)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Forward calls and record usage."""
|
||||
result = self._model(*args, **kwargs)
|
||||
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
self.record_usage(result.usage)
|
||||
|
||||
return result
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Proxy attribute access to the wrapped model."""
|
||||
return getattr(self._model, name)
|
||||
|
||||
|
||||
class ModelProvider(Enum):
|
||||
"""Supported model providers"""
|
||||
|
||||
@@ -37,6 +37,9 @@ from backend.services.storage import StorageService
|
||||
from backend.data.provider_router import get_provider_router
|
||||
from backend.tools.data_tools import get_prices
|
||||
from backend.tools.data_tools import get_company_news
|
||||
from backend.tools.data_tools import get_insider_trades
|
||||
from backend.tools.data_tools import prices_to_df
|
||||
from backend.tools.technical_signals import StockTechnicalAnalyzer
|
||||
from backend.core.scheduler import Scheduler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -99,9 +102,15 @@ class Gateway:
|
||||
self._provider_router = get_provider_router()
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._project_root = Path(__file__).resolve().parents[2]
|
||||
self._technical_analyzer = StockTechnicalAnalyzer()
|
||||
|
||||
async def start(self, host: str = "0.0.0.0", port: int = 8766):
|
||||
"""Start gateway server"""
|
||||
"""Start gateway server with proper initialization order.
|
||||
|
||||
Phase 1: Start WebSocket server first so frontend can connect immediately
|
||||
Phase 2: Start market data service (pushes data to connected clients)
|
||||
Phase 3: Start scheduler last (triggers trading cycles)
|
||||
"""
|
||||
logger.info(f"Starting gateway on {host}:{port}")
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._provider_router.add_listener(self._on_provider_usage_changed)
|
||||
@@ -124,7 +133,7 @@ class Gateway:
|
||||
|
||||
self.state_sync.load_state()
|
||||
self.market_service.set_price_recorder(self.storage.record_price_point)
|
||||
self.state_sync.update_state("status", "running")
|
||||
self.state_sync.update_state("status", "initializing")
|
||||
self.state_sync.update_state("server_mode", self.mode)
|
||||
self.state_sync.update_state("is_backtest", self.is_backtest)
|
||||
self.state_sync.update_state(
|
||||
@@ -171,30 +180,72 @@ class Gateway:
|
||||
f"{summary.get('totalAssetValue', 0):,.2f}",
|
||||
)
|
||||
|
||||
await self.market_service.start(broadcast_func=self.broadcast)
|
||||
# ======================================================================
|
||||
# PHASE 1: Start WebSocket server first
|
||||
# This allows frontend to connect immediately and receive status updates
|
||||
# ======================================================================
|
||||
logger.info("[Phase 1/4] Starting WebSocket server...")
|
||||
self.state_sync.update_state("status", "websocket_ready")
|
||||
|
||||
if self.scheduler:
|
||||
await self.scheduler.start(self.on_strategy_trigger)
|
||||
elif self.scheduler_callback:
|
||||
await self.scheduler_callback(callback=self.on_strategy_trigger)
|
||||
|
||||
# Start market status monitoring (only for live mode)
|
||||
if not self.is_backtest:
|
||||
self._market_status_task = asyncio.create_task(
|
||||
self._market_status_monitor(),
|
||||
)
|
||||
|
||||
async with websockets.serve(
|
||||
# Create server but don't block yet - we'll serve inside the context manager
|
||||
server = await websockets.serve(
|
||||
self.handle_client,
|
||||
host,
|
||||
port,
|
||||
ping_interval=30,
|
||||
ping_timeout=60,
|
||||
):
|
||||
logger.info(
|
||||
f"Gateway started: ws://{host}:{port}, mode={self.mode}",
|
||||
)
|
||||
logger.info(f"WebSocket server ready: ws://{host}:{port}")
|
||||
|
||||
# Give a brief moment for any existing clients to reconnect
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 2: Start market data service
|
||||
# Now frontend is connected, start pushing price updates
|
||||
# ======================================================================
|
||||
logger.info("[Phase 2/4] Starting market data service...")
|
||||
self.state_sync.update_state("status", "market_service_starting")
|
||||
await self.market_service.start(broadcast_func=self.broadcast)
|
||||
self.state_sync.update_state("status", "market_service_ready")
|
||||
logger.info("Market data service ready - price updates active")
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 3: Start market status monitoring
|
||||
# Monitors market open/close and broadcasts status
|
||||
# ======================================================================
|
||||
logger.info("[Phase 3/4] Starting market status monitoring...")
|
||||
if not self.is_backtest:
|
||||
self._market_status_task = asyncio.create_task(
|
||||
self._market_status_monitor(),
|
||||
)
|
||||
await asyncio.Future()
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 4: Start scheduler last
|
||||
# Only start trading after everything else is ready
|
||||
# ======================================================================
|
||||
logger.info("[Phase 4/4] Starting scheduler...")
|
||||
self.state_sync.update_state("status", "scheduler_starting")
|
||||
|
||||
if self.scheduler:
|
||||
# Wire up heartbeat callback if heartbeat is configured
|
||||
heartbeat_interval = self.config.get("heartbeat_interval", 0)
|
||||
if heartbeat_interval and heartbeat_interval > 0:
|
||||
self.scheduler.set_heartbeat_callback(self.on_heartbeat_trigger)
|
||||
logger.info(
|
||||
f"[Heartbeat] Registered heartbeat callback (interval={heartbeat_interval}s)",
|
||||
)
|
||||
await self.scheduler.start(self.on_strategy_trigger)
|
||||
elif self.scheduler_callback:
|
||||
await self.scheduler_callback(callback=self.on_strategy_trigger)
|
||||
|
||||
self.state_sync.update_state("status", "running")
|
||||
logger.info(
|
||||
f"Gateway fully operational: ws://{host}:{port}, mode={self.mode}",
|
||||
)
|
||||
|
||||
# Keep server running
|
||||
await asyncio.Future()
|
||||
|
||||
def _on_provider_usage_changed(self, snapshot: Dict[str, Any]):
|
||||
"""Handle provider routing updates from the shared router."""
|
||||
@@ -275,8 +326,8 @@ class Gateway:
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send error response to client: {e}")
|
||||
|
||||
async def _handle_client_messages(
|
||||
self,
|
||||
@@ -343,10 +394,14 @@ class Gateway:
|
||||
await self._handle_get_stock_news_categories(websocket, data)
|
||||
elif msg_type == "get_stock_range_explain":
|
||||
await self._handle_get_stock_range_explain(websocket, data)
|
||||
elif msg_type == "get_stock_insider_trades":
|
||||
await self._handle_get_stock_insider_trades(websocket, data)
|
||||
elif msg_type == "get_stock_story":
|
||||
await self._handle_get_stock_story(websocket, data)
|
||||
elif msg_type == "get_stock_similar_days":
|
||||
await self._handle_get_stock_similar_days(websocket, data)
|
||||
elif msg_type == "get_stock_technical_indicators":
|
||||
await self._handle_get_stock_technical_indicators(websocket, data)
|
||||
elif msg_type == "run_stock_enrich":
|
||||
await self._handle_run_stock_enrich(websocket, data)
|
||||
|
||||
@@ -862,6 +917,94 @@ class Gateway:
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_get_stock_insider_trades(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
):
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_insider_trades_loaded",
|
||||
"ticker": "",
|
||||
"trades": [],
|
||||
"error": "invalid ticker",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
end_date = str(
|
||||
data.get("end_date")
|
||||
or self.state_sync.state.get("current_date")
|
||||
or datetime.now().strftime("%Y-%m-%d")
|
||||
).strip()[:10]
|
||||
start_date = str(data.get("start_date") or "").strip()[:10]
|
||||
limit = int(data.get("limit", 50))
|
||||
|
||||
trades = await asyncio.to_thread(
|
||||
get_insider_trades,
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date if start_date else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Sort by transaction date descending
|
||||
sorted_trades = sorted(
|
||||
trades,
|
||||
key=lambda t: t.transaction_date or "",
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Format for frontend
|
||||
formatted_trades = [
|
||||
{
|
||||
"ticker": t.ticker,
|
||||
"name": t.name,
|
||||
"title": t.title,
|
||||
"is_board_director": t.is_board_director,
|
||||
"transaction_date": t.transaction_date,
|
||||
"transaction_shares": t.transaction_shares,
|
||||
"transaction_price_per_share": t.transaction_price_per_share,
|
||||
"transaction_value": t.transaction_value,
|
||||
"shares_owned_before_transaction": t.shares_owned_before_transaction,
|
||||
"shares_owned_after_transaction": t.shares_owned_after_transaction,
|
||||
"security_title": t.security_title,
|
||||
"filing_date": t.filing_date,
|
||||
# Calculated fields
|
||||
"holding_change": (
|
||||
(t.shares_owned_after_transaction or 0)
|
||||
- (t.shares_owned_before_transaction or 0)
|
||||
if t.shares_owned_after_transaction and t.shares_owned_before_transaction
|
||||
else None
|
||||
),
|
||||
"is_buy": (
|
||||
(t.transaction_shares or 0) > 0
|
||||
if t.transaction_shares is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
for t in sorted_trades
|
||||
]
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_insider_trades_loaded",
|
||||
"ticker": ticker,
|
||||
"start_date": start_date or None,
|
||||
"end_date": end_date,
|
||||
"trades": formatted_trades,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_get_stock_story(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
@@ -969,6 +1112,136 @@ class Gateway:
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_get_stock_technical_indicators(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
data: Dict[str, Any],
|
||||
):
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_technical_indicators_loaded",
|
||||
"ticker": ticker,
|
||||
"indicators": None,
|
||||
"error": "ticker is required",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Get price data for the ticker
|
||||
from datetime import datetime, timedelta
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=250) # ~1 year for MA200
|
||||
|
||||
prices = get_prices(
|
||||
ticker=ticker,
|
||||
start_date=start_date.strftime("%Y-%m-%d"),
|
||||
end_date=end_date.strftime("%Y-%m-%d"),
|
||||
)
|
||||
|
||||
if not prices or len(prices) < 20:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_technical_indicators_loaded",
|
||||
"ticker": ticker,
|
||||
"indicators": None,
|
||||
"error": "Insufficient price data",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
# Analyze technical indicators
|
||||
df = prices_to_df(prices)
|
||||
signal = self._technical_analyzer.analyze(ticker, df)
|
||||
|
||||
# Calculate additional volatility metrics
|
||||
import pandas as pd
|
||||
df_sorted = df.sort_values("time").reset_index(drop=True)
|
||||
df_sorted["returns"] = df_sorted["close"].pct_change()
|
||||
|
||||
vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
|
||||
vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
|
||||
vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
|
||||
|
||||
# Calculate MA distance from current price
|
||||
ma_distance = {}
|
||||
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
|
||||
ma_value = getattr(signal, ma_key, None)
|
||||
if ma_value and ma_value > 0:
|
||||
ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100
|
||||
else:
|
||||
ma_distance[ma_key] = None
|
||||
|
||||
indicators = {
|
||||
"ticker": ticker,
|
||||
"current_price": signal.current_price,
|
||||
"ma": {
|
||||
"ma5": signal.ma5,
|
||||
"ma10": signal.ma10,
|
||||
"ma20": signal.ma20,
|
||||
"ma50": signal.ma50,
|
||||
"ma200": signal.ma200,
|
||||
"distance": ma_distance,
|
||||
},
|
||||
"rsi": {
|
||||
"rsi14": signal.rsi14,
|
||||
"status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
|
||||
},
|
||||
"macd": {
|
||||
"macd": signal.macd,
|
||||
"signal": signal.macd_signal,
|
||||
"histogram": signal.macd - signal.macd_signal,
|
||||
},
|
||||
"bollinger": {
|
||||
"upper": signal.bollinger_upper,
|
||||
"mid": signal.bollinger_mid,
|
||||
"lower": signal.bollinger_lower,
|
||||
},
|
||||
"volatility": {
|
||||
"vol_10d": vol_10,
|
||||
"vol_20d": vol_20,
|
||||
"vol_60d": vol_60,
|
||||
"annualized": signal.annualized_volatility_pct,
|
||||
"risk_level": signal.risk_level,
|
||||
},
|
||||
"trend": signal.trend,
|
||||
"mean_reversion": signal.mean_reversion_signal,
|
||||
}
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_technical_indicators_loaded",
|
||||
"ticker": ticker,
|
||||
"indicators": indicators,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting technical indicators for {ticker}")
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "stock_technical_indicators_loaded",
|
||||
"ticker": ticker,
|
||||
"indicators": None,
|
||||
"error": str(e),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_run_stock_enrich(
|
||||
self,
|
||||
websocket: ServerConnection,
|
||||
@@ -2288,6 +2561,58 @@ class Gateway:
|
||||
else:
|
||||
await self._run_live_cycle(date, tickers)
|
||||
|
||||
async def on_heartbeat_trigger(self, date: str):
|
||||
"""Run lightweight heartbeat check for all analysts.
|
||||
|
||||
Each analyst reads its HEARTBEAT.md and performs a self-check
|
||||
without running the full trading pipeline.
|
||||
"""
|
||||
logger.info(f"[Heartbeat] Running heartbeat check for {date}")
|
||||
|
||||
tickers = self.config.get("tickers", [])
|
||||
analysts = self.pipeline._all_analysts()
|
||||
|
||||
for analyst in analysts:
|
||||
try:
|
||||
ws_id = getattr(analyst, "workspace_id", None)
|
||||
if ws_id:
|
||||
from backend.agents.workspace_manager import get_workspace_dir
|
||||
ws_dir = get_workspace_dir(ws_id)
|
||||
if ws_dir:
|
||||
from pathlib import Path
|
||||
hb_path = Path(ws_dir) / "HEARTBEAT.md"
|
||||
if hb_path.exists():
|
||||
content = hb_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
hb_task = (
|
||||
f"# 定期主动检查\n\n{content}\n\n"
|
||||
"请执行上述检查并报告结果。"
|
||||
)
|
||||
logger.info(
|
||||
f"[Heartbeat] Running heartbeat for {analyst.name}",
|
||||
)
|
||||
# Build a minimal user message and let the analyst reply
|
||||
from agentscope.message import Msg
|
||||
msg = Msg(
|
||||
role="user",
|
||||
content=hb_task,
|
||||
name="system",
|
||||
)
|
||||
result = await analyst.reply([msg])
|
||||
logger.info(
|
||||
f"[Heartbeat] {analyst.name} heartbeat complete",
|
||||
)
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
f"[Heartbeat] No HEARTBEAT.md for {analyst.name}, skipping",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Heartbeat] {analyst.name} failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _run_backtest_cycle(self, date: str, tickers: List[str]):
|
||||
"""Run backtest cycle with pre-loaded prices"""
|
||||
self.market_service.set_backtest_date(date)
|
||||
@@ -2428,7 +2753,8 @@ class Gateway:
|
||||
market_caps[ticker] = market_cap
|
||||
else:
|
||||
market_caps[ticker] = 1e9
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get market cap for {ticker}, using default 1e9: {e}")
|
||||
market_caps[ticker] = 1e9
|
||||
|
||||
return market_caps
|
||||
|
||||
@@ -48,6 +48,14 @@ CREATE TABLE IF NOT EXISTS signals (
|
||||
signal TEXT,
|
||||
confidence REAL,
|
||||
reasoning_json TEXT,
|
||||
reasons_json TEXT,
|
||||
risks_json TEXT,
|
||||
invalidation TEXT,
|
||||
next_action TEXT,
|
||||
intrinsic_value REAL,
|
||||
fair_value_range_json TEXT,
|
||||
value_gap_pct REAL,
|
||||
valuation_methods_json TEXT,
|
||||
real_return REAL,
|
||||
is_correct TEXT,
|
||||
trade_date TEXT,
|
||||
@@ -270,8 +278,10 @@ class RuntimeDb:
|
||||
"""
|
||||
INSERT OR REPLACE INTO signals
|
||||
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
|
||||
reasons_json, risks_json, invalidation, next_action, intrinsic_value,
|
||||
fair_value_range_json, value_gap_pct, valuation_methods_json,
|
||||
real_return, is_correct, trade_date, created_at, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
signal_id,
|
||||
@@ -282,6 +292,14 @@ class RuntimeDb:
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
_json_dumps(payload.get("reasoning")),
|
||||
_json_dumps(payload.get("reasons")),
|
||||
_json_dumps(payload.get("risks")),
|
||||
payload.get("invalidation"),
|
||||
payload.get("next_action"),
|
||||
payload.get("intrinsic_value"),
|
||||
_json_dumps(payload.get("fair_value_range")),
|
||||
payload.get("value_gap_pct"),
|
||||
_json_dumps(payload.get("valuation_methods")),
|
||||
payload.get("real_return"),
|
||||
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
|
||||
payload.get("date"),
|
||||
@@ -313,8 +331,10 @@ class RuntimeDb:
|
||||
"""
|
||||
INSERT INTO signals
|
||||
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
|
||||
reasons_json, risks_json, invalidation, next_action, intrinsic_value,
|
||||
fair_value_range_json, value_gap_pct, valuation_methods_json,
|
||||
real_return, is_correct, trade_date, created_at, meta_json)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
signal_id,
|
||||
@@ -325,6 +345,14 @@ class RuntimeDb:
|
||||
payload.get("signal"),
|
||||
payload.get("confidence"),
|
||||
_json_dumps(payload.get("reasoning")),
|
||||
_json_dumps(payload.get("reasons")),
|
||||
_json_dumps(payload.get("risks")),
|
||||
payload.get("invalidation"),
|
||||
payload.get("next_action"),
|
||||
payload.get("intrinsic_value"),
|
||||
_json_dumps(payload.get("fair_value_range")),
|
||||
payload.get("value_gap_pct"),
|
||||
_json_dumps(payload.get("valuation_methods")),
|
||||
payload.get("real_return"),
|
||||
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
|
||||
payload.get("date"),
|
||||
@@ -461,6 +489,18 @@ class RuntimeDb:
|
||||
else "该信号暂未完成后验评估"
|
||||
),
|
||||
"tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral",
|
||||
# Extended signal fields
|
||||
"signal": row["signal"],
|
||||
"confidence": row["confidence"],
|
||||
"reasoning": json.loads(row["reasoning_json"]) if row["reasoning_json"] else None,
|
||||
"reasons": json.loads(row["reasons_json"]) if row["reasons_json"] else None,
|
||||
"risks": json.loads(row["risks_json"]) if row["risks_json"] else None,
|
||||
"invalidation": row["invalidation"],
|
||||
"next_action": row["next_action"],
|
||||
"intrinsic_value": row["intrinsic_value"],
|
||||
"fair_value_range": json.loads(row["fair_value_range_json"]) if row["fair_value_range_json"] else None,
|
||||
"value_gap_pct": row["value_gap_pct"],
|
||||
"valuation_methods": json.loads(row["valuation_methods_json"]) if row["valuation_methods_json"] else None,
|
||||
}
|
||||
for row in signal_rows
|
||||
]
|
||||
|
||||
@@ -8,15 +8,42 @@ version: 1.0.0
|
||||
|
||||
当用户希望从公司质量、资产负债表强度、盈利能力或长期盈利韧性出发判断标的时,使用这个技能。
|
||||
|
||||
## 工作流程
|
||||
## 1) When to use
|
||||
|
||||
1. 在形成结论前,先检查盈利能力、成长性、财务健康度和经营效率。
|
||||
2. 区分可持续的业务质量和短期噪音。
|
||||
3. 明确指出会推翻当前判断的条件。
|
||||
4. 最终给出清晰的信号、置信度和主要驱动因素。
|
||||
- 适用于需要判断“公司基本面质量是否支撑当前估值/交易观点”的任务。
|
||||
- 优先在中长期视角下使用(财务稳健性、盈利韧性、成长持续性)。
|
||||
- 当任务明确以短线事件驱动为主时,不应单独依赖本技能,应与情绪/技术信号联合。
|
||||
|
||||
## 约束
|
||||
## 2) Required inputs
|
||||
|
||||
- 不要孤立依赖单一指标。
|
||||
- 缺失数据要明确指出。
|
||||
- 当财务质量优劣混杂时,优先给出保守结论。
|
||||
- 最少输入:`tickers`、关键财务指标(盈利、成长、偿债、效率)。
|
||||
- 推荐输入:行业背景、公司阶段、近期重大事件。
|
||||
- 若关键数据缺失(例如利润质量或现金流质量无法判断),必须在结论中显式标注“不足信息风险”,并降低置信度。
|
||||
|
||||
## 3) Decision procedure
|
||||
|
||||
1. 先做四维诊断:盈利能力、成长质量、财务健康度、经营效率。
|
||||
2. 区分“结构性优势”与“周期性改善/短期噪音”。
|
||||
3. 识别关键风险与失效条件(invalidation),明确什么情况会推翻当前判断。
|
||||
4. 合成最终观点:`signal + confidence + drivers + risks`。
|
||||
|
||||
## 4) Tool call policy
|
||||
|
||||
- 优先使用基本面与财务相关工具组获取证据,再形成结论。
|
||||
- 在数据完备且任务允许时,可补充估值相关工具进行交叉验证。
|
||||
- 若工具失败或返回异常:保留已验证证据,明确未验证部分,不允许伪造数据。
|
||||
|
||||
## 5) Output schema
|
||||
|
||||
- `signal`: `bullish | bearish | neutral`
|
||||
- `confidence`: `0-100`
|
||||
- `reasons`: 2-4 条核心驱动
|
||||
- `risks`: 1-3 条关键风险
|
||||
- `invalidation`: 触发观点失效的条件
|
||||
- `next_action`: 对 PM 的可执行建议(如“仅小仓位试错/等待下一季报确认”)
|
||||
|
||||
## 6) Failure fallback
|
||||
|
||||
- 数据稀疏或矛盾时:默认 `neutral` 或低置信度方向结论。
|
||||
- 不允许因单一亮点指标给出高置信度信号。
|
||||
- 当财务质量优劣混杂时,优先保守结论并附加“需补充验证”的下一步建议。
|
||||
|
||||
@@ -8,15 +8,43 @@ version: 1.0.0
|
||||
|
||||
当用户需要把团队分析转化为最终交易决策时,使用这个技能。
|
||||
|
||||
## 工作流程
|
||||
## 1) When to use
|
||||
|
||||
1. 行动前先阅读分析师结论和风险警示。
|
||||
2. 评估当前组合、现金和保证金约束。
|
||||
3. 使用决策工具为每个 ticker 记录一个明确决策。
|
||||
4. 在全部决策记录完成后,总结组合层面的整体理由。
|
||||
- 适用于“最终下单前”的收口阶段:将多方观点转成单一可执行指令。
|
||||
- 必须在获取分析师观点与风险审查后触发,不应跳过上游输入。
|
||||
- 当任务只要求研究观点、未要求执行决策时,不强制触发。
|
||||
|
||||
## 约束
|
||||
## 2) Required inputs
|
||||
|
||||
- 仓位大小必须遵守资金和保证金限制。
|
||||
- 当分析师信心与风险信号不一致时,优先采用更小仓位。
|
||||
- 当任务要求完整决策清单时,不要让任何 ticker 处于未决状态。
|
||||
- 最少输入:`analyst_signals`、`risk_warnings`、`portfolio_state`、`cash`、`margin_requirement`、`prices`。
|
||||
- 推荐输入:会议共识摘要、历史表现偏差、当前组合拥挤度。
|
||||
- 若缺失关键执行约束(现金/保证金/价格),应降级为“条件决策草案”,不可直接给激进仓位。
|
||||
|
||||
## 3) Decision procedure
|
||||
|
||||
1. 汇总并比较 analyst 信号,识别共识与分歧。
|
||||
2. 将风险警示映射到仓位上限与禁开条件。
|
||||
3. 在资金与保证金约束下,为每个 ticker 生成候选动作与数量。
|
||||
4. 对冲突信号执行保守仲裁:降低仓位、提高触发门槛或改为 `hold`。
|
||||
5. 逐个 ticker 记录最终决策,并给出组合级理由。
|
||||
|
||||
## 4) Tool call policy
|
||||
|
||||
- 必须使用决策工具记录每个 ticker 的最终 `action/quantity`。
|
||||
- 在讨论阶段如发现当前团队能力不足,可使用团队工具动态创建或移除 analyst(再继续讨论)。
|
||||
- 若风险工具提示阻断项,优先遵循阻断,不得绕过。
|
||||
- 工具调用失败时:重试一次;仍失败则输出结构化“未完成决策清单”和人工处理建议。
|
||||
|
||||
## 5) Output schema
|
||||
|
||||
- `decisions`: 每个 ticker 的 `{action: long|short|hold, quantity, confidence, reasoning}`
|
||||
- `portfolio_rationale`: 组合层面的配置逻辑与取舍依据
|
||||
- `constraint_check`: 资金、保证金、集中度是否满足
|
||||
- `conflict_resolution`: 对信号冲突的处理说明
|
||||
- `pending_items`: 未决事项与补充数据需求(若有)
|
||||
|
||||
## 6) Failure fallback
|
||||
|
||||
- 当分析师信号与风险结论显著冲突时,默认采用更小仓位或 `hold`。
|
||||
- 当约束校验失败(现金/保证金不足)时,自动下调数量,不输出不可执行指令。
|
||||
- 当任务要求完整清单时,不允许遗漏 ticker;无法决策时必须显式标记 `hold` 并说明原因。
|
||||
|
||||
@@ -8,15 +8,41 @@ version: 1.0.0
|
||||
|
||||
当用户需要识别集中度、波动率、杠杆和情景风险时,使用这个技能。
|
||||
|
||||
## 工作流程
|
||||
## 1) When to use
|
||||
|
||||
1. 按 ticker 和主题检查拟议敞口。
|
||||
2. 识别集中度、波动率、流动性和杠杆方面的风险点。
|
||||
3. 按严重程度排序风险警示。
|
||||
4. 将风险结论转化为给投资经理的具体限制或注意事项。
|
||||
- 适用于下单前风险闸门、仓位复核、组合再平衡前的约束审查。
|
||||
- 当需要把“风险观点”转成“可执行限制”时必须使用本技能。
|
||||
- 若任务仅为单纯行情解读且不涉及仓位执行,可不独立触发。
|
||||
|
||||
## 约束
|
||||
## 2) Required inputs
|
||||
|
||||
- 聚焦可执行的风险控制措施。
|
||||
- 当数据支持时尽量量化限制。
|
||||
- 明确区分致命阻断项和可管理风险。
|
||||
- 最少输入:`portfolio positions`、`cash/margin`、`proposed decisions`、`current prices`。
|
||||
- 推荐输入:波动率指标、流动性指标、相关性/主题暴露。
|
||||
- 若缺失关键风险数据,必须输出“暂定限制”并标明待补数据项。
|
||||
|
||||
## 3) Decision procedure
|
||||
|
||||
1. 按 ticker、行业主题、净敞口做集中度盘点。
|
||||
2. 评估波动、流动性与杠杆压力,识别潜在连锁风险。
|
||||
3. 将风险分级:`fatal blocker / major caution / manageable`。
|
||||
4. 将每类风险映射为明确限制(仓位上限、减仓条件、禁开仓条件)。
|
||||
|
||||
## 4) Tool call policy
|
||||
|
||||
- 优先调用风险工具组量化集中度、保证金压力、波动暴露。
|
||||
- 无量化证据时,不给“无风险”结论;只能给保守警示。
|
||||
- 工具失败时应回退到规则化约束(更低仓位上限、更严格止损条件)。
|
||||
|
||||
## 5) Output schema
|
||||
|
||||
- `risk_level`: `low | medium | high | critical`
|
||||
- `warnings`: 按严重度排序的风险列表(含原因)
|
||||
- `limits`: 可执行限制(仓位/杠杆/单票上限)
|
||||
- `blockers`: 必须先解决的阻断项
|
||||
- `recommendation_to_pm`: 对 PM 的执行建议(允许/限制/禁止)
|
||||
|
||||
## 6) Failure fallback
|
||||
|
||||
- 关键数据缺失或工具不可用时:默认提高一级风险等级并收紧仓位限制。
|
||||
- 无法确认保证金与流动性安全时,默认禁止新增高风险敞口。
|
||||
- 明确区分“硬阻断”与“可带条件执行”的风险,避免含糊建议。
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
from backend import cli
|
||||
from backend.agents.skill_metadata import parse_skill_metadata
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.team_pipeline_config import (
|
||||
ensure_team_pipeline_config,
|
||||
load_team_pipeline_config,
|
||||
update_active_analysts,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_skill_metadata_extended_frontmatter(tmp_path):
|
||||
@@ -70,3 +75,45 @@ def test_skills_enable_disable_and_list(monkeypatch, tmp_path):
|
||||
assert "Enabled" in text_dump
|
||||
assert "Disabled" in text_dump
|
||||
assert any(getattr(item, "title", None) == "Skill Catalog" for item in printed)
|
||||
|
||||
|
||||
def test_install_external_skill_for_agent(tmp_path):
|
||||
manager = SkillsManager(project_root=tmp_path)
|
||||
skill_dir = tmp_path / "downloaded" / "new_skill"
|
||||
skill_dir.mkdir(parents=True, exist_ok=True)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
"name: new_skill\n"
|
||||
"description: external skill\n"
|
||||
"---\n\n"
|
||||
"# New Skill\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
result = manager.install_external_skill_for_agent(
|
||||
config_name="demo",
|
||||
agent_id="risk_manager",
|
||||
source=str(skill_dir),
|
||||
activate=True,
|
||||
)
|
||||
|
||||
assert result["skill_name"] == "new_skill"
|
||||
target = manager.get_agent_local_root("demo", "risk_manager") / "new_skill"
|
||||
assert target.exists()
|
||||
|
||||
|
||||
def test_team_pipeline_active_analyst_updates(tmp_path):
|
||||
project_root = tmp_path
|
||||
ensure_team_pipeline_config(
|
||||
project_root=project_root,
|
||||
config_name="demo",
|
||||
default_analysts=["fundamentals_analyst", "technical_analyst"],
|
||||
)
|
||||
update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name="demo",
|
||||
available_analysts=["fundamentals_analyst", "technical_analyst"],
|
||||
remove=["technical_analyst"],
|
||||
)
|
||||
config = load_team_pipeline_config(project_root, "demo")
|
||||
assert config["discussion"]["active_analysts"] == ["fundamentals_analyst"]
|
||||
|
||||
Reference in New Issue
Block a user