feat: Refactor services architecture and update project structure
- Remove Docker-based microservices (docker-compose.yml, Makefile, Dockerfiles) - Update start-dev.sh to use backend.app:app entry point - Add shared schema and client modules for service communication - Add team coordination modules (messenger, registry, task_delegator, coordinator) - Add evaluation hooks and skill adaptation hooks - Add skill template and gateway server - Update frontend WebSocket URL configuration - Add explain components for insider and technical analysis Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
452
backend/agents/base/evaluation_hook.py
Normal file
452
backend/agents/base/evaluation_hook.py
Normal file
@@ -0,0 +1,452 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Evaluation hooks system for skills.
|
||||
|
||||
Provides evaluation metric collection and storage for skill performance tracking.
|
||||
Based on the evaluation hooks design in SKILL_TEMPLATE.md.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(Enum):
    """Types of evaluation metrics."""
    HIT_RATE = "hit_rate"  # signal hit rate
    RISK_VIOLATION = "risk_violation"  # risk-control violation rate
    POSITION_DEVIATION = "position_deviation"  # position deviation rate
    # NOTE: member name deviates from UPPER_CASE convention; kept as-is since
    # it is part of the public enum interface.
    PnL_ATTRIBUTION = "pnl_attribution"  # P&L attribution consistency
    SIGNAL_CONSISTENCY = "signal_consistency"  # signal consistency
    DECISION_LATENCY = "decision_latency"  # decision latency
    TOOL_USAGE = "tool_usage"  # tool usage rate
    CUSTOM = "custom"  # user-defined metric
|
||||
|
||||
|
||||
@dataclass
class EvaluationMetric:
    """A single evaluation metric."""
    name: str
    metric_type: MetricType
    value: float
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this metric into a JSON-compatible dictionary."""
        payload = dict(
            name=self.name,
            metric_type=self.metric_type.value,
            value=self.value,
            timestamp=self.timestamp,
            metadata=self.metadata,
        )
        return payload
|
||||
|
||||
|
||||
@dataclass
class EvaluationResult:
    """Evaluation result for a skill execution."""
    skill_name: str
    run_id: str
    agent_id: str
    metrics: List[EvaluationMetric] = field(default_factory=list)
    inputs: Dict[str, Any] = field(default_factory=dict)
    outputs: Dict[str, Any] = field(default_factory=dict)
    decision: Optional[str] = None
    success: bool = True
    error_message: Optional[str] = None
    started_at: Optional[str] = None
    completed_at: Optional[str] = field(default_factory=lambda: datetime.now().isoformat())

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result, including nested metrics, into a plain dict."""
        serialized_metrics = [metric.to_dict() for metric in self.metrics]
        return dict(
            skill_name=self.skill_name,
            run_id=self.run_id,
            agent_id=self.agent_id,
            metrics=serialized_metrics,
            inputs=self.inputs,
            outputs=self.outputs,
            decision=self.decision,
            success=self.success,
            error_message=self.error_message,
            started_at=self.started_at,
            completed_at=self.completed_at,
        )
|
||||
|
||||
|
||||
class EvaluationHook:
    """Hook for collecting skill evaluation metrics.

    This hook collects and stores evaluation metrics after skill execution
    for later analysis and memory/reflection stages.

    Lifecycle: ``start_evaluation()`` opens a session; ``add_metric()``,
    ``record_outputs()`` and ``record_decision()`` enrich it; and
    ``complete_evaluation()`` finalizes the session and persists it as JSON
    under ``storage_dir/<run_id>/<agent_id>/``. Only one session may be
    active at a time.
    """

    def __init__(
        self,
        storage_dir: Path,
        run_id: str,
        agent_id: str,
    ):
        """Initialize evaluation hook.

        Args:
            storage_dir: Directory to store evaluation results
            run_id: Current run identifier
            agent_id: Current agent identifier
        """
        self.storage_dir = Path(storage_dir)
        self.run_id = run_id
        self.agent_id = agent_id
        # The single active evaluation session, or None when idle.
        self._current_evaluation: Optional[EvaluationResult] = None

    def start_evaluation(
        self,
        skill_name: str,
        inputs: Dict[str, Any],
    ) -> None:
        """Start a new evaluation session.

        Args:
            skill_name: Name of the skill being evaluated
            inputs: Input parameters for the skill
        """
        if self._current_evaluation is not None:
            # Previously an unfinished session was silently discarded here;
            # warn so the data loss is at least visible in the logs.
            logger.warning(
                "Starting new evaluation while another is active; discarding "
                f"unfinished evaluation for skill: {self._current_evaluation.skill_name}"
            )
        self._current_evaluation = EvaluationResult(
            skill_name=skill_name,
            run_id=self.run_id,
            agent_id=self.agent_id,
            inputs=inputs,
            started_at=datetime.now().isoformat(),
        )
        logger.debug(f"Started evaluation for skill: {skill_name}")

    def add_metric(
        self,
        name: str,
        metric_type: MetricType,
        value: float,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add an evaluation metric.

        Args:
            name: Metric name
            metric_type: Type of metric
            value: Metric value
            metadata: Additional metadata
        """
        if self._current_evaluation is None:
            logger.warning("No active evaluation session, ignoring metric")
            return

        metric = EvaluationMetric(
            name=name,
            metric_type=metric_type,
            value=value,
            metadata=metadata or {},
        )
        self._current_evaluation.metrics.append(metric)
        logger.debug(f"Added metric: {name} = {value}")

    def add_metrics(self, metrics: List[EvaluationMetric]) -> None:
        """Add multiple evaluation metrics at once.

        Args:
            metrics: List of metrics to add
        """
        if self._current_evaluation is None:
            logger.warning("No active evaluation session, ignoring metrics")
            return

        self._current_evaluation.metrics.extend(metrics)

    def record_outputs(self, outputs: Dict[str, Any]) -> None:
        """Record skill outputs.

        Args:
            outputs: Output from skill execution
        """
        if self._current_evaluation is None:
            logger.warning("No active evaluation session, ignoring outputs")
            return

        self._current_evaluation.outputs = outputs

    def record_decision(self, decision: str) -> None:
        """Record the final decision.

        Args:
            decision: Final decision made by the skill
        """
        if self._current_evaluation is None:
            logger.warning("No active evaluation session, ignoring decision")
            return

        self._current_evaluation.decision = decision

    def complete_evaluation(
        self,
        success: bool = True,
        error_message: Optional[str] = None,
    ) -> Optional[EvaluationResult]:
        """Complete the evaluation session and persist results.

        Args:
            success: Whether the skill execution was successful
            error_message: Error message if failed

        Returns:
            The completed evaluation result, or None if no active evaluation
        """
        if self._current_evaluation is None:
            logger.warning("No active evaluation to complete")
            return None

        self._current_evaluation.success = success
        self._current_evaluation.error_message = error_message
        self._current_evaluation.completed_at = datetime.now().isoformat()

        # Persist to storage (best-effort; see _persist_evaluation).
        result = self._persist_evaluation(self._current_evaluation)

        self._current_evaluation = None
        logger.debug(f"Completed evaluation for skill: {result.skill_name}")

        return result

    def _persist_evaluation(self, evaluation: EvaluationResult) -> EvaluationResult:
        """Persist evaluation result to storage.

        Args:
            evaluation: Evaluation result to persist

        Returns:
            The persisted evaluation (returned unchanged even when the write
            fails; persistence is best-effort)
        """
        # parents=True creates the run-level directory too, so a separate
        # mkdir for run_dir (as in the original version) is unnecessary.
        agent_dir = self.storage_dir / self.run_id / self.agent_id
        agent_dir.mkdir(parents=True, exist_ok=True)

        # Microsecond-resolution timestamp keeps filenames unique per skill.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{evaluation.skill_name}_{timestamp}.json"
        filepath = agent_dir / filename

        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(evaluation.to_dict(), f, ensure_ascii=False, indent=2)
            logger.info(f"Persisted evaluation to: {filepath}")
        except Exception as e:
            # A storage failure must not break the skill execution pipeline,
            # so log and continue.
            logger.error(f"Failed to persist evaluation: {e}")

        return evaluation

    def cancel_evaluation(self) -> None:
        """Cancel the current evaluation session without saving."""
        if self._current_evaluation is not None:
            logger.debug(f"Cancelled evaluation for: {self._current_evaluation.skill_name}")
        self._current_evaluation = None
|
||||
|
||||
|
||||
class EvaluationCollector:
    """Collector for aggregating evaluation metrics across runs.

    Provides read-only query and analysis helpers over the JSON files
    written by EvaluationHook.
    """

    def __init__(self, storage_dir: Path):
        """Initialize evaluation collector.

        Args:
            storage_dir: Root directory containing evaluation results
        """
        self.storage_dir = Path(storage_dir)

    def get_run_evaluations(
        self,
        run_id: str,
        agent_id: Optional[str] = None,
    ) -> List[EvaluationResult]:
        """Get all evaluations for a run.

        Args:
            run_id: Run identifier
            agent_id: Optional agent identifier to filter by

        Returns:
            List of evaluation results
        """
        run_dir = self.storage_dir / run_id
        if not run_dir.exists():
            return []

        if agent_id:
            candidate_dirs = [run_dir / agent_id]
        else:
            candidate_dirs = run_dir.iterdir()

        results: List[EvaluationResult] = []
        for candidate in candidate_dirs:
            if not candidate.is_dir():
                continue
            for eval_file in candidate.glob("*.json"):
                # Unreadable or malformed files are skipped with a warning.
                try:
                    with open(eval_file, "r", encoding="utf-8") as f:
                        raw = json.load(f)
                    results.append(self._parse_evaluation(raw))
                except Exception as e:
                    logger.warning(f"Failed to load evaluation {eval_file}: {e}")

        return results

    def get_skill_metrics(
        self,
        skill_name: str,
        run_ids: Optional[List[str]] = None,
    ) -> List[EvaluationMetric]:
        """Get all metrics for a specific skill.

        Args:
            skill_name: Name of the skill
            run_ids: Optional list of run IDs to filter by

        Returns:
            List of metrics for the skill
        """
        if run_ids is None:
            # Default to scanning every run directory under storage_dir.
            run_ids = [entry.name for entry in self.storage_dir.iterdir() if entry.is_dir()]

        collected: List[EvaluationMetric] = []
        for rid in run_ids:
            for result in self.get_run_evaluations(rid):
                if result.skill_name == skill_name:
                    collected.extend(result.metrics)

        return collected

    def calculate_skill_stats(
        self,
        skill_name: str,
        metric_type: MetricType,
        run_ids: Optional[List[str]] = None,
    ) -> Dict[str, float]:
        """Calculate statistics for a specific metric type.

        Args:
            skill_name: Name of the skill
            metric_type: Type of metric to calculate
            run_ids: Optional list of run IDs to filter by

        Returns:
            Dictionary with min, max, avg, count statistics
        """
        values = [
            m.value
            for m in self.get_skill_metrics(skill_name, run_ids)
            if m.metric_type == metric_type
        ]

        if not values:
            return {"count": 0}

        return {
            "count": len(values),
            "min": min(values),
            "max": max(values),
            "avg": sum(values) / len(values),
        }

    def _parse_evaluation(self, data: Dict[str, Any]) -> EvaluationResult:
        """Parse evaluation data into EvaluationResult.

        Args:
            data: Raw evaluation data

        Returns:
            Parsed EvaluationResult
        """
        parsed_metrics = [
            EvaluationMetric(
                name=m["name"],
                metric_type=MetricType(m["metric_type"]),
                value=m["value"],
                timestamp=m.get("timestamp", ""),
                metadata=m.get("metadata", {}),
            )
            for m in data.get("metrics", [])
        ]

        return EvaluationResult(
            skill_name=data["skill_name"],
            run_id=data["run_id"],
            agent_id=data["agent_id"],
            metrics=parsed_metrics,
            inputs=data.get("inputs", {}),
            outputs=data.get("outputs", {}),
            decision=data.get("decision"),
            success=data.get("success", True),
            error_message=data.get("error_message"),
            started_at=data.get("started_at"),
            completed_at=data.get("completed_at"),
        )
|
||||
|
||||
|
||||
def parse_evaluation_hooks(skill_dir: Path) -> Dict[str, Any]:
    """Parse evaluation hooks from SKILL.md.

    Extracts the "Optional: Evaluation hooks" section from skill documentation
    and detects which MetricType values it mentions.

    Args:
        skill_dir: Skill directory path

    Returns:
        Dictionary with "supported_metrics" (list of metric value strings)
        and "section_content" (the raw section text), or an empty dict when
        SKILL.md is missing, has no hooks section, or cannot be parsed.
    """
    skill_md = skill_dir / "SKILL.md"
    if not skill_md.exists():
        return {}

    try:
        content = skill_md.read_text(encoding="utf-8")

        # Extract evaluation hooks section
        header = "## Optional: Evaluation hooks"
        if header in content:
            start = content.find(header)
            # Find the next ## section or end of file
            next_section = content.find("\n## ", start + 1)
            if next_section == -1:
                eval_section = content[start:]
            else:
                eval_section = content[start:next_section]

            # Detect mentioned metrics. The original check only matched the
            # space-separated form ("hit rate"), so the canonical underscore
            # form ("hit_rate") used by the enum values themselves was never
            # matched; accept both forms.
            section_lower = eval_section.lower()
            metrics = []
            for metric_type in MetricType:
                token = metric_type.value
                if token in section_lower or token.replace("_", " ") in section_lower:
                    metrics.append(token)

            return {
                "supported_metrics": metrics,
                "section_content": eval_section.strip(),
            }
    except Exception as e:
        logger.warning(f"Failed to parse evaluation hooks: {e}")

    return {}
|
||||
|
||||
|
||||
# Public API of the evaluation hooks module.
__all__ = [
    "MetricType",
    "EvaluationMetric",
    "EvaluationResult",
    "EvaluationHook",
    "EvaluationCollector",
    "parse_evaluation_hooks",
]
|
||||
489
backend/agents/base/skill_adaptation_hook.py
Normal file
489
backend/agents/base/skill_adaptation_hook.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Skill adaptation hook for automatic evaluation-to-iteration闭环.
|
||||
|
||||
Monitors evaluation metrics against configurable thresholds and triggers
|
||||
automatic skill reload or logs warnings when thresholds are breached.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from .evaluation_hook import (
|
||||
EvaluationCollector,
|
||||
EvaluationResult,
|
||||
MetricType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdaptationAction(Enum):
    """Actions to take when threshold is breached."""
    RELOAD = "reload"  # automatically reload the skill
    WARN = "warn"  # log a warning for human review
    BOTH = "both"  # reload the skill and log a warning
    NONE = "none"  # take no action
|
||||
|
||||
|
||||
@dataclass
class AdaptationThreshold:
    """Threshold configuration for a metric.

    A threshold is breached when ``current_value <operator> value`` holds,
    e.g. operator "lt" with value 0.5 fires once the metric drops below 0.5.
    """
    metric_type: MetricType
    operator: str = "lt"  # lt (less than), gt (greater than), lte, gte, eq
    value: float = 0.0
    window_size: int = 10  # moving-window size for the sliding average
    min_samples: int = 5  # minimum sample count before checks trigger
    action: AdaptationAction = AdaptationAction.WARN
    cooldown_seconds: int = 300  # cool-down period after a trigger

    def evaluate(self, current_value: float) -> bool:
        """Return True when current_value breaches this threshold."""
        bound = self.value
        if self.operator == "lt":
            return current_value < bound
        if self.operator == "lte":
            return current_value <= bound
        if self.operator == "gt":
            return current_value > bound
        if self.operator == "gte":
            return current_value >= bound
        if self.operator == "eq":
            return current_value == bound
        # Unrecognized operators never trigger; surface the config error.
        logger.warning(f"Unknown operator: {self.operator}")
        return False

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the threshold into a JSON-compatible dictionary."""
        return dict(
            metric_type=self.metric_type.value,
            operator=self.operator,
            value=self.value,
            window_size=self.window_size,
            min_samples=self.min_samples,
            action=self.action.value,
            cooldown_seconds=self.cooldown_seconds,
        )
|
||||
|
||||
|
||||
@dataclass
class AdaptationEvent:
    """Record of an adaptation trigger event."""
    timestamp: str
    skill_name: str
    metric_type: MetricType
    threshold: AdaptationThreshold
    current_value: float
    avg_value: float
    action_taken: AdaptationAction
    details: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the event (nested threshold included) into a plain dict."""
        serialized = dict(
            timestamp=self.timestamp,
            skill_name=self.skill_name,
            metric_type=self.metric_type.value,
            threshold=self.threshold.to_dict(),
            current_value=self.current_value,
            avg_value=self.avg_value,
            action_taken=self.action_taken.value,
            details=self.details,
        )
        return serialized
|
||||
|
||||
|
||||
class SkillAdaptationHook:
    """Hook for monitoring evaluation metrics and triggering skill adaptation.

    This hook wraps EvaluationHook to add threshold-based adaptation logic.
    When metrics breach configured thresholds, it can:
    - Automatically reload skills via SkillsManager
    - Log warnings for human review
    - Both
    """

    # Default thresholds for common metrics. Treated as a read-only template:
    # __init__ copies these, so per-instance edits via add_threshold /
    # remove_threshold / update_threshold never leak into other instances.
    DEFAULT_THRESHOLDS: List[AdaptationThreshold] = [
        AdaptationThreshold(
            metric_type=MetricType.HIT_RATE,
            operator="lt",
            value=0.5,
            action=AdaptationAction.WARN,
            cooldown_seconds=600,
        ),
        AdaptationThreshold(
            metric_type=MetricType.RISK_VIOLATION,
            operator="gt",
            value=0.1,
            action=AdaptationAction.WARN,
            cooldown_seconds=300,
        ),
        AdaptationThreshold(
            metric_type=MetricType.DECISION_LATENCY,
            operator="gt",
            value=5000,  # 5 seconds — presumably milliseconds; confirm units with metric producer
            action=AdaptationAction.WARN,
            cooldown_seconds=300,
        ),
    ]

    def __init__(
        self,
        storage_dir: Path,
        run_id: str,
        agent_id: str,
        thresholds: Optional[List[AdaptationThreshold]] = None,
        collector: Optional[EvaluationCollector] = None,
    ):
        """Initialize skill adaptation hook.

        Args:
            storage_dir: Directory to store adaptation events
            run_id: Current run identifier
            agent_id: Current agent identifier
            thresholds: Custom threshold configurations (uses defaults if None)
            collector: Optional EvaluationCollector for historical data
        """
        from dataclasses import replace  # local import; only needed here

        self.storage_dir = Path(storage_dir)
        self.run_id = run_id
        self.agent_id = agent_id
        # BUG FIX: the original assigned DEFAULT_THRESHOLDS by reference
        # (`thresholds or self.DEFAULT_THRESHOLDS`), so add/remove/update
        # _threshold on one instance mutated the class-level defaults (and
        # their shared AdaptationThreshold objects) for every instance.
        # Copy the list — and, for the defaults, the threshold objects too.
        if thresholds:
            self.thresholds = list(thresholds)
        else:
            self.thresholds = [replace(t) for t in self.DEFAULT_THRESHOLDS]
        self.collector = collector or EvaluationCollector(storage_dir)

        # Track cooldowns to prevent rapid re-triggering
        self._cooldowns: Dict[str, datetime] = {}

        # Store recent metrics in memory for quick access
        self._recent_metrics: Dict[str, List[float]] = {}

        # Pending adaptation events
        self._pending_events: List[AdaptationEvent] = []

    def check_threshold(
        self,
        skill_name: str,
        metric_type: MetricType,
        current_value: float,
    ) -> Optional[AdaptationEvent]:
        """Check if a metric breaches any threshold.

        Args:
            skill_name: Name of the skill
            metric_type: Type of metric
            current_value: Current metric value

        Returns:
            AdaptationEvent if threshold breached, None otherwise
        """
        # Find applicable thresholds
        applicable_thresholds = [
            t for t in self.thresholds
            if t.metric_type == metric_type
        ]

        if not applicable_thresholds:
            return None

        # Check cooldown
        cooldown_key = f"{skill_name}:{metric_type.value}"
        now = datetime.now()
        last_trigger = self._cooldowns.get(cooldown_key)

        # Store current value first for avg calculation
        self._store_metric(cooldown_key, current_value)

        for threshold in applicable_thresholds:
            if last_trigger:
                elapsed = (now - last_trigger).total_seconds()
                if elapsed < threshold.cooldown_seconds:
                    continue

            # Evaluate threshold
            if threshold.evaluate(current_value):
                # Calculate moving average
                avg_value = self._calculate_avg(skill_name, metric_type, current_value)

                # Check minimum samples (allow immediate trigger if min_samples <= 1)
                sample_count = len(self._recent_metrics.get(cooldown_key, []))
                if threshold.min_samples > 1 and sample_count < threshold.min_samples:
                    # Not enough samples yet
                    continue

                # Trigger adaptation
                event = AdaptationEvent(
                    timestamp=now.isoformat(),
                    skill_name=skill_name,
                    metric_type=metric_type,
                    threshold=threshold,
                    current_value=current_value,
                    avg_value=avg_value,
                    action_taken=threshold.action,
                    details={
                        "run_id": self.run_id,
                        "agent_id": self.agent_id,
                    },
                )

                # Update cooldown
                self._cooldowns[cooldown_key] = now

                # Persist event
                self._persist_event(event)

                logger.info(
                    f"Threshold breached for {skill_name}.{metric_type.value}: "
                    f"current={current_value}, avg={avg_value}, action={threshold.action.value}"
                )

                return event

        return None

    def _calculate_avg(
        self,
        skill_name: str,
        metric_type: MetricType,
        current_value: float,
    ) -> float:
        """Calculate moving average for a metric."""
        key = f"{skill_name}:{metric_type.value}"
        values = self._recent_metrics.get(key, [])
        if not values:
            return current_value
        return sum(values) / len(values)

    def _store_metric(self, key: str, value: float) -> None:
        """Store metric value with sliding window."""
        if key not in self._recent_metrics:
            self._recent_metrics[key] = []
        self._recent_metrics[key].append(value)
        # Keep only last 100 values.
        # NOTE(review): window is fixed at 100 and does not honor
        # AdaptationThreshold.window_size — confirm whether intended.
        if len(self._recent_metrics[key]) > 100:
            self._recent_metrics[key] = self._recent_metrics[key][-100:]

    def _persist_event(self, event: AdaptationEvent) -> None:
        """Persist adaptation event to storage (best-effort)."""
        run_dir = self.storage_dir / self.run_id / "adaptations"
        run_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"{event.skill_name}_{event.metric_type.value}_{timestamp}.json"
        filepath = run_dir / filename

        try:
            with open(filepath, "w", encoding="utf-8") as f:
                json.dump(event.to_dict(), f, ensure_ascii=False, indent=2)
            logger.debug(f"Persisted adaptation event to: {filepath}")
        except Exception as e:
            # Storage failure must not break metric checking; log and continue.
            logger.error(f"Failed to persist adaptation event: {e}")

        # Also add to pending list
        self._pending_events.append(event)

    def get_pending_warnings(self) -> List[AdaptationEvent]:
        """Get all pending warning events that need human review."""
        return [
            e for e in self._pending_events
            if e.action_taken in (AdaptationAction.WARN, AdaptationAction.BOTH)
        ]

    def clear_pending_warnings(self) -> None:
        """Clear pending warnings after they have been reviewed."""
        # RELOAD-only events are kept; they are not human-review items.
        self._pending_events = [
            e for e in self._pending_events
            if e.action_taken == AdaptationAction.RELOAD
        ]

    def get_recent_events(
        self,
        skill_name: Optional[str] = None,
        metric_type: Optional[MetricType] = None,
        limit: int = 50,
    ) -> List[AdaptationEvent]:
        """Get recent adaptation events.

        Args:
            skill_name: Optional filter by skill name
            metric_type: Optional filter by metric type
            limit: Maximum number of events to return

        Returns:
            List of recent adaptation events
        """
        events_dir = self.storage_dir / self.run_id / "adaptations"
        if not events_dir.exists():
            return []

        events = []
        # Filenames start with skill/metric and end with a timestamp, so
        # reverse-sorted order is newest-first per skill/metric prefix.
        for eval_file in sorted(events_dir.glob("*.json"), reverse=True)[:limit]:
            try:
                with open(eval_file, "r", encoding="utf-8") as f:
                    data = json.load(f)
                event = self._parse_event(data)
                if skill_name and event.skill_name != skill_name:
                    continue
                if metric_type and event.metric_type != metric_type:
                    continue
                events.append(event)
            except Exception as e:
                logger.warning(f"Failed to load adaptation event {eval_file}: {e}")

        return events

    def _parse_event(self, data: Dict[str, Any]) -> AdaptationEvent:
        """Parse adaptation event from JSON data."""
        threshold_data = data.get("threshold", {})
        metric_type = MetricType(threshold_data.get("metric_type", "custom"))

        threshold = AdaptationThreshold(
            metric_type=metric_type,
            operator=threshold_data.get("operator", "lt"),
            value=threshold_data.get("value", 0.0),
            window_size=threshold_data.get("window_size", 10),
            min_samples=threshold_data.get("min_samples", 5),
            action=AdaptationAction(threshold_data.get("action", "warn")),
            cooldown_seconds=threshold_data.get("cooldown_seconds", 300),
        )

        return AdaptationEvent(
            timestamp=data.get("timestamp", ""),
            skill_name=data.get("skill_name", ""),
            metric_type=metric_type,
            threshold=threshold,
            current_value=data.get("current_value", 0.0),
            avg_value=data.get("avg_value", 0.0),
            action_taken=AdaptationAction(data.get("action_taken", "warn")),
            details=data.get("details", {}),
        )

    def add_threshold(self, threshold: AdaptationThreshold) -> None:
        """Add a new threshold configuration (instance-local)."""
        self.thresholds.append(threshold)

    def remove_threshold(self, metric_type: MetricType) -> None:
        """Remove all thresholds for a specific metric type (instance-local)."""
        self.thresholds = [
            t for t in self.thresholds
            if t.metric_type != metric_type
        ]

    def update_threshold(
        self,
        metric_type: MetricType,
        **kwargs,
    ) -> None:
        """Update threshold configuration for a metric type.

        Unknown attribute names in kwargs are silently ignored.
        """
        for threshold in self.thresholds:
            if threshold.metric_type == metric_type:
                for key, value in kwargs.items():
                    if hasattr(threshold, key):
                        setattr(threshold, key, value)

    def get_thresholds(self) -> List[AdaptationThreshold]:
        """Get current threshold configurations (shallow copy of the list)."""
        return list(self.thresholds)

    def is_in_cooldown(self, skill_name: str, metric_type: MetricType) -> bool:
        """Check if a skill/metric combination is in cooldown period."""
        key = f"{skill_name}:{metric_type.value}"
        last_trigger = self._cooldowns.get(key)
        if not last_trigger:
            return False

        # Find the threshold for this metric type (first match wins).
        for threshold in self.thresholds:
            if threshold.metric_type == metric_type:
                elapsed = (datetime.now() - last_trigger).total_seconds()
                return elapsed < threshold.cooldown_seconds

        return False
|
||||
|
||||
|
||||
class AdaptationManager:
    """Manager for coordinating skill adaptation across multiple agents.

    Provides centralized tracking of adaptation events and skill reloads.
    """

    def __init__(self, storage_dir: Path):
        """Initialize adaptation manager.

        Args:
            storage_dir: Root directory for storing adaptation data
        """
        self.storage_dir = Path(storage_dir)
        # One hook per "run_id:agent_id" pair, created lazily by get_hook().
        self._hooks: Dict[str, SkillAdaptationHook] = {}

    def get_hook(
        self,
        run_id: str,
        agent_id: str,
        thresholds: Optional[List[AdaptationThreshold]] = None,
    ) -> SkillAdaptationHook:
        """Get or create an adaptation hook for an agent.

        Args:
            run_id: Run identifier
            agent_id: Agent identifier
            thresholds: Optional custom thresholds (used only on creation)

        Returns:
            SkillAdaptationHook instance
        """
        hook_key = f"{run_id}:{agent_id}"
        existing = self._hooks.get(hook_key)
        if existing is not None:
            return existing

        created = SkillAdaptationHook(
            storage_dir=self.storage_dir,
            run_id=run_id,
            agent_id=agent_id,
            thresholds=thresholds,
        )
        self._hooks[hook_key] = created
        return created

    def get_all_pending_warnings(self) -> List[AdaptationEvent]:
        """Get all pending warnings from all hooks."""
        return [
            event
            for hook in self._hooks.values()
            for event in hook.get_pending_warnings()
        ]

    def get_run_adaptations(self, run_id: str) -> List[AdaptationEvent]:
        """Get all adaptation events for a run."""
        return [
            event
            for hook in self._hooks.values()
            if hook.run_id == run_id
            for event in hook.get_recent_events()
        ]
|
||||
|
||||
|
||||
# Global manager instance
_adaptation_manager: Optional[AdaptationManager] = None


def get_adaptation_manager(storage_dir: Optional[Path] = None) -> AdaptationManager:
    """Get global adaptation manager instance.

    Args:
        storage_dir: Optional storage directory (required on first call)

    Returns:
        AdaptationManager instance

    Raises:
        ValueError: If called before initialization without a storage_dir.
    """
    global _adaptation_manager
    if _adaptation_manager is not None:
        return _adaptation_manager
    if storage_dir is None:
        raise ValueError("storage_dir required on first initialization")
    _adaptation_manager = AdaptationManager(storage_dir)
    return _adaptation_manager
|
||||
|
||||
|
||||
# Public API of the skill adaptation module.
__all__ = [
    "AdaptationAction",
    "AdaptationThreshold",
    "AdaptationEvent",
    "SkillAdaptationHook",
    "AdaptationManager",
    "get_adaptation_manager",
]
|
||||
18
backend/agents/team/__init__.py
Normal file
18
backend/agents/team/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Team module for multi-agent orchestration.
|
||||
|
||||
Provides inter-agent communication, task delegation, and coordination
|
||||
for subagent spawning and lifecycle management.
|
||||
"""
|
||||
|
||||
from .messenger import AgentMessenger
|
||||
from .task_delegator import TaskDelegator
|
||||
from .team_coordinator import TeamCoordinator
|
||||
from .registry import AgentRegistry
|
||||
|
||||
__all__ = [
|
||||
"AgentMessenger",
|
||||
"TaskDelegator",
|
||||
"TeamCoordinator",
|
||||
"AgentRegistry",
|
||||
]
|
||||
225
backend/agents/team/messenger.py
Normal file
225
backend/agents/team/messenger.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""AgentMessenger - Pub/sub inter-agent communication.
|
||||
|
||||
Provides broadcast(), send(), and subscribe() for message passing
|
||||
between agents using AgentScope's Msg format.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentMessenger:
    """Pub/sub messenger for inter-agent communication.

    Supports:
    - broadcast(): Send message to all subscribers
    - send(): Send message to specific agent (queued in that agent's inbox
      when it has no live subscribers)
    - subscribe(): Register callback for agent messages
    - announce(): Send system-wide announcement
    - enable_auto_broadcast: Auto-broadcast agent replies to all participants

    Messages use AgentScope's Msg format for compatibility.
    """

    def __init__(self, enable_auto_broadcast: bool = False):
        """Initialize the messenger.

        Args:
            enable_auto_broadcast: If True, agent replies are automatically
                broadcast to all subscribed agents.
        """
        # agent_id -> list of registered callbacks (sync or async).
        self._subscriptions: Dict[str, List[Callable[[Msg], None]]] = {}
        # agent_id -> messages queued while the agent had no subscribers.
        self._inbox: Dict[str, List[Msg]] = {}
        # Reserved for per-agent delivery locking; currently unused.
        self._locks: Dict[str, asyncio.Lock] = {}
        self._enable_auto_broadcast = enable_auto_broadcast
        # Agents eligible to receive auto-broadcast messages.
        self._participants: Set[str] = set()

    async def _dispatch(self, agent_id: str, message: Msg, action: str) -> None:
        """Deliver *message* to every callback subscribed by *agent_id*.

        A failing callback is logged and skipped so that one faulty
        subscriber cannot prevent delivery to the others.

        Args:
            agent_id: Target agent identifier
            message: Message to deliver
            action: Verb used in error logs (e.g. "delivering message")
        """
        for callback in self._subscriptions.get(agent_id, []):
            try:
                if asyncio.iscoroutinefunction(callback):
                    await callback(message)
                else:
                    callback(message)
            except Exception as e:
                logger.error("Error %s to %s: %s", action, agent_id, e)

    def subscribe(
        self,
        agent_id: str,
        callback: Callable[[Msg], None],
    ) -> None:
        """Subscribe an agent to receive messages.

        Args:
            agent_id: Target agent identifier
            callback: Sync or async function to call when a message arrives
        """
        self._subscriptions.setdefault(agent_id, []).append(callback)
        logger.debug("Agent %s subscribed to messages", agent_id)

    def unsubscribe(self, agent_id: str, callback: Callable[[Msg], None]) -> None:
        """Unsubscribe an agent from messages.

        Removing a callback that was never registered is a no-op.

        Args:
            agent_id: Target agent identifier
            callback: Callback to remove
        """
        if agent_id in self._subscriptions:
            try:
                self._subscriptions[agent_id].remove(callback)
                logger.debug("Agent %s unsubscribed from messages", agent_id)
            except ValueError:
                pass

    async def send(
        self,
        to_agent: str,
        message: Msg,
    ) -> None:
        """Send message to a specific agent.

        When the target has at least one subscriber, the message is delivered
        to every registered callback. Otherwise it is queued in the agent's
        inbox for later retrieval via inbox(). (Previously messages to
        unsubscribed targets were silently dropped and the inbox was never
        populated.)

        Args:
            to_agent: Target agent identifier
            message: Message to send (uses Msg format)
        """
        if self._subscriptions.get(to_agent):
            await self._dispatch(to_agent, message, "delivering message")
        else:
            self._inbox.setdefault(to_agent, []).append(message)

    async def broadcast(self, message: Msg) -> None:
        """Broadcast message to all subscribed agents.

        Per-agent deliveries run concurrently via asyncio.gather(); within
        one agent, callbacks are invoked in registration order.

        Args:
            message: Message to broadcast (uses Msg format)
        """
        deliveries = [
            self._dispatch(agent_id, message, "broadcasting")
            for agent_id in self._subscriptions
        ]
        if deliveries:
            await asyncio.gather(*deliveries)

    def inbox(self, agent_id: str) -> List[Msg]:
        """Get and clear inbox for agent.

        Args:
            agent_id: Agent identifier

        Returns:
            List of messages queued for the agent (possibly empty)
        """
        messages = self._inbox.get(agent_id, [])
        self._inbox[agent_id] = []
        return messages

    def inbox_count(self, agent_id: str) -> int:
        """Count messages in agent's inbox without clearing.

        Args:
            agent_id: Agent identifier

        Returns:
            Number of messages waiting
        """
        return len(self._inbox.get(agent_id, []))

    def add_participant(self, agent_id: str) -> None:
        """Add a participant to the messenger.

        Participants are the agents that can receive auto-broadcast messages.

        Args:
            agent_id: Agent identifier to add
        """
        self._participants.add(agent_id)
        logger.debug("Agent %s added as participant", agent_id)

    def remove_participant(self, agent_id: str) -> None:
        """Remove a participant from the messenger.

        Args:
            agent_id: Agent identifier to remove
        """
        self._participants.discard(agent_id)
        logger.debug("Agent %s removed from participants", agent_id)

    @property
    def enable_auto_broadcast(self) -> bool:
        """Check if auto_broadcast is enabled."""
        return self._enable_auto_broadcast

    @enable_auto_broadcast.setter
    def enable_auto_broadcast(self, value: bool) -> None:
        """Enable or disable auto_broadcast."""
        self._enable_auto_broadcast = value
        logger.debug("Auto_broadcast set to %s", value)

    async def announce(self, message: Msg) -> None:
        """Send a system-wide announcement to all participants.

        Unlike broadcast(), announce() sends a message from the system/host
        to all participants without requiring prior subscription.

        Args:
            message: Announcement message (uses Msg format)
        """
        logger.info("System announcement: %s", message.content)
        await self.broadcast(message)

    async def auto_broadcast(self, message: Msg) -> None:
        """Auto-broadcast message to all participants.

        This is called internally when enable_auto_broadcast is True.
        Broadcasts to all registered participants.

        Args:
            message: Message to auto-broadcast (uses Msg format)
        """
        if not self._enable_auto_broadcast:
            return

        # Deliver sequentially to each registered participant that has
        # subscriptions (matching the original per-participant loop).
        for participant_id in self._participants:
            await self._dispatch(participant_id, message, "auto-broadcasting")
|
||||
|
||||
|
||||
__all__ = ["AgentMessenger"]
|
||||
188
backend/agents/team/registry.py
Normal file
188
backend/agents/team/registry.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""AgentRegistry - Agent registration and lookup by role.
|
||||
|
||||
Provides register(), unregister(), and get_by_role() for agent
|
||||
discovery and management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRegistry:
    """Registry of agent instances with role-based lookup.

    Supports:
    - register(): Add agent with roles
    - unregister(): Remove agent
    - get_by_role(): Find agents by role
    - get_by_id(): Get specific agent

    Each agent can have multiple roles for flexible dispatch.
    """

    def __init__(self):
        # agent_id -> agent instance
        self._agents: Dict[str, Any] = {}
        # role -> list of agent_ids holding that role
        self._roles: Dict[str, List[str]] = {}
        # agent_id -> list of roles assigned to that agent
        self._agent_roles: Dict[str, List[str]] = {}

    def _link(self, agent_id: str, roles: List[str]) -> None:
        """Index *agent_id* under each role (idempotent)."""
        for role in roles:
            members = self._roles.setdefault(role, [])
            if agent_id not in members:
                members.append(agent_id)

    def _unlink(self, agent_id: str, roles: List[str]) -> None:
        """Drop *agent_id* from each role's index entry if present."""
        for role in roles:
            members = self._roles.get(role)
            if members is not None and agent_id in members:
                members.remove(agent_id)

    def register(
        self,
        agent_id: str,
        agent: Any,
        roles: Optional[List[str]] = None,
    ) -> None:
        """Register an agent with optional roles.

        Args:
            agent_id: Unique agent identifier
            agent: Agent instance
            roles: Optional list of role strings
        """
        self._agents[agent_id] = agent
        self._agent_roles[agent_id] = roles or []
        self._link(agent_id, self._agent_roles[agent_id])
        logger.info(
            "Registered agent %s with roles %s",
            agent_id,
            self._agent_roles[agent_id],
        )

    def unregister(self, agent_id: str) -> bool:
        """Unregister an agent.

        Args:
            agent_id: Agent identifier to remove

        Returns:
            True if agent was removed
        """
        if agent_id not in self._agents:
            return False
        self._unlink(agent_id, self._agent_roles.pop(agent_id, []))
        del self._agents[agent_id]
        logger.info("Unregistered agent: %s", agent_id)
        return True

    def get_by_id(self, agent_id: str) -> Optional[Any]:
        """Get agent by ID.

        Args:
            agent_id: Agent identifier

        Returns:
            Agent instance or None
        """
        return self._agents.get(agent_id)

    def get_by_role(self, role: str) -> List[Any]:
        """Get all agents with a given role.

        Args:
            role: Role string to search for

        Returns:
            List of agent instances with the role
        """
        return [
            self._agents[aid]
            for aid in self._roles.get(role, [])
            if aid in self._agents
        ]

    def get_by_roles(self, roles: List[str]) -> List[Any]:
        """Get agents matching ANY of the given roles.

        Args:
            roles: List of role strings

        Returns:
            List of unique agent instances matching any role
        """
        seen_ids: set = set()
        matched: List[Any] = []
        for role in roles:
            for agent in self.get_by_role(role):
                marker = id(agent)
                if marker not in seen_ids:
                    seen_ids.add(marker)
                    matched.append(agent)
        return matched

    def list_agents(self) -> List[str]:
        """List all registered agent IDs.

        Returns:
            List of agent identifiers
        """
        return list(self._agents)

    def list_roles(self) -> List[str]:
        """List all registered roles.

        Returns:
            List of role strings
        """
        return list(self._roles)

    def list_roles_for_agent(self, agent_id: str) -> List[str]:
        """List roles for specific agent.

        Args:
            agent_id: Agent identifier

        Returns:
            List of role strings (empty when the agent is unknown)
        """
        return list(self._agent_roles.get(agent_id, []))

    def update_roles(self, agent_id: str, roles: List[str]) -> None:
        """Update roles for an existing agent.

        Args:
            agent_id: Agent identifier
            roles: New list of roles

        Raises:
            KeyError: If the agent is not registered.
        """
        if agent_id not in self._agents:
            raise KeyError(f"Agent not registered: {agent_id}")
        self._unlink(agent_id, self._agent_roles.get(agent_id, []))
        self._agent_roles[agent_id] = roles
        self._link(agent_id, roles)
        logger.info("Updated roles for agent %s: %s", agent_id, roles)

    @property
    def agents(self) -> Dict[str, Any]:
        """Get copy of registered agents dict."""
        return dict(self._agents)
|
||||
|
||||
|
||||
__all__ = ["AgentRegistry"]
|
||||
343
backend/agents/team/task_delegator.py
Normal file
343
backend/agents/team/task_delegator.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskDelegator - Subagent spawning and task delegation.
|
||||
|
||||
Provides delegate() and delegate_parallel() for spawning subagents
|
||||
with separate context and memory. Supports runtime dynamic subagent
|
||||
definition via task_data with description, prompt, and tools.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Type alias for subagent specification
|
||||
SubagentSpec = Dict[str, Any]
|
||||
"""Subagent specification format:
|
||||
{
|
||||
"description": "Expert code reviewer...",
|
||||
"prompt": "Analyze code quality...",
|
||||
"tools": ["Read", "Glob", "Grep"], # Optional: list of tool names
|
||||
"model": "gpt-4o", # Optional: model name
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class TaskDelegator:
    """Delegates tasks to subagents with isolated context.

    Supports:
    - delegate(): Spawn single subagent for task
    - delegate_parallel(): Spawn multiple subagents concurrently
    - delegate_task(): Delegate with dynamic subagent definition from task_data

    Each subagent gets its own memory/context to prevent
    cross-contamination.

    Dynamic Subagent Definition:
        task_data can include an "agents" dict to define subagents inline:

        task_data = {
            "task": "Review the code changes",
            "agents": {
                "code-reviewer": {
                    "description": "Expert code reviewer for quality and security.",
                    "prompt": "Analyze code quality and suggest improvements.",
                    "tools": ["Read", "Glob", "Grep"],
                }
            }
        }
    """

    def __init__(self, messenger: Any, registry: Any):
        """Initialize TaskDelegator.

        Args:
            messenger: AgentMessenger for communication
            registry: AgentRegistry for agent lookup
        """
        self._messenger = messenger
        self._registry = registry
        self._subagents: Dict[str, Any] = {}
        # Inline subagent specs registered per delegate_task() call.
        self._dynamic_subagents: Dict[str, SubagentSpec] = {}
        # agent_id -> running asyncio.Task for the delegated work.
        self._tasks: Dict[str, asyncio.Task] = {}

    async def delegate(
        self,
        agent_id: str,
        task: Callable[..., Awaitable[Msg]],
        context: Optional[Dict[str, Any]] = None,
    ) -> asyncio.Task:
        """Delegate task to a single subagent.

        Args:
            agent_id: Unique identifier for this subagent instance
            task: Async function representing the task; receives the context
                dict as its only argument
            context: Optional context dict for the subagent

        Returns:
            asyncio.Task for the delegated task
        """
        async def _run_with_context():
            # Defer invoking `task` until the event loop runs it, so a
            # synchronously-raising callable fails inside the Task rather
            # than inside delegate() itself.
            return await task(context or {})

        self._tasks[agent_id] = asyncio.create_task(_run_with_context())
        logger.info("Delegated task to subagent: %s", agent_id)
        return self._tasks[agent_id]

    async def delegate_parallel(
        self,
        tasks: List[Dict[str, Any]],
    ) -> List[asyncio.Task]:
        """Delegate multiple tasks in parallel.

        Args:
            tasks: List of task dicts with keys:
                - agent_id: Unique identifier
                - task: Async function to execute
                - context: Optional context dict

        Returns:
            List of asyncio.Task for the successfully delegated tasks
        """
        async def _spawn(task_def: Dict[str, Any]) -> asyncio.Task:
            agent_id = task_def["agent_id"]
            task_func = task_def["task"]
            context = task_def.get("context", {})

            async def _run_with_context():
                return await task_func(context)

            self._tasks[agent_id] = asyncio.create_task(_run_with_context())
            return self._tasks[agent_id]

        outcomes = await asyncio.gather(
            *[_spawn(t) for t in tasks],
            return_exceptions=True,
        )

        valid_tasks: List[asyncio.Task] = []
        for task_def, outcome in zip(tasks, outcomes):
            if isinstance(outcome, asyncio.Task):
                valid_tasks.append(outcome)
            else:
                # Previously malformed task dicts (e.g. missing "agent_id")
                # were dropped silently; surface the failure in the log.
                logger.error(
                    "Failed to delegate task %r: %s",
                    task_def.get("agent_id"),
                    outcome,
                )

        logger.info(
            "Delegated %d tasks in parallel (%d succeeded)",
            len(tasks),
            len(valid_tasks),
        )
        return valid_tasks

    async def wait_for(self, agent_id: str, timeout: Optional[float] = None) -> Any:
        """Wait for subagent task to complete.

        Args:
            agent_id: Subagent identifier
            timeout: Optional timeout in seconds

        Returns:
            Task result

        Raises:
            asyncio.TimeoutError: If the task doesn't complete in time
                (asyncio.wait_for also cancels the underlying task)
            KeyError: If agent_id not found
        """
        if agent_id not in self._tasks:
            raise KeyError(f"Unknown subagent: {agent_id}")

        try:
            return await asyncio.wait_for(
                self._tasks[agent_id],
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            logger.warning("Task %s timed out after %s seconds", agent_id, timeout)
            raise

    async def cancel(self, agent_id: str) -> bool:
        """Cancel a subagent task.

        Args:
            agent_id: Subagent identifier

        Returns:
            True if a cancellation request was issued to a still-pending
            task; False when agent_id is unknown or the task had already
            finished. (Previously True was returned even for tasks that
            had already completed, contradicting this contract.)
        """
        task = self._tasks.pop(agent_id, None)
        if task is None:
            return False
        # Task.cancel() returns False for tasks that already completed.
        cancelled = task.cancel()
        if cancelled:
            logger.info("Cancelled subagent task: %s", agent_id)
        return cancelled

    def list_tasks(self) -> List[str]:
        """List active subagent task IDs.

        Returns:
            List of agent_ids with pending tasks
        """
        return list(self._tasks.keys())

    @property
    def tasks(self) -> Dict[str, asyncio.Task]:
        """Get copy of active tasks dict."""
        return dict(self._tasks)

    def delegate_task(
        self,
        task_type: str,
        task_data: Dict[str, Any],
        target_agent: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Delegate a task with optional dynamic subagent definition.

        Supports runtime subagent definition via task_data["agents"]:

            task_data = {
                "task": "Review code changes",
                "agents": {
                    "code-reviewer": {
                        "description": "Expert code reviewer...",
                        "prompt": "Analyze code quality...",
                        "tools": ["Read", "Glob", "Grep"],
                    }
                }
            }

        Args:
            task_type: Type of task (e.g., "analysis", "review", "research")
            task_data: Task payload, may include "agents" for dynamic subagent def
            target_agent: Optional specific agent ID to delegate to

        Returns:
            Dict with "success" and result/error
        """
        try:
            # Extract dynamic subagent definitions from task_data
            agents_def = task_data.get("agents", {})

            if agents_def:
                # Register dynamic subagents for the duration of this call.
                for agent_name, agent_spec in agents_def.items():
                    self._dynamic_subagents[agent_name] = agent_spec
                    logger.info(
                        "Registered dynamic subagent: %s (description: %s)",
                        agent_name,
                        agent_spec.get("description", "")[:50],
                    )

            # Determine target agent: explicit target wins, then the first
            # dynamically defined subagent, then "default".
            effective_target = target_agent
            if not effective_target:
                if agents_def:
                    effective_target = next(iter(agents_def.keys()))
                else:
                    effective_target = "default"

            # Execute the task
            task_result = self._execute_task(
                task_type=task_type,
                task_data=task_data,
                target_agent=effective_target,
            )

            # Clean up dynamic subagents after execution
            for agent_name in agents_def.keys():
                self._dynamic_subagents.pop(agent_name, None)

            return {
                "success": True,
                "result": task_result,
                "subagents_used": list(agents_def.keys()) if agents_def else [],
            }

        except Exception as e:
            logger.error("Task delegation failed: %s", e)
            return {
                "success": False,
                "error": str(e),
            }

    def _execute_task(
        self,
        task_type: str,
        task_data: Dict[str, Any],
        target_agent: str,
    ) -> Any:
        """Execute the delegated task.

        Args:
            task_type: Type of task
            task_data: Task payload ("task" key preferred, "prompt" fallback)
            target_agent: Target agent identifier

        Returns:
            Task execution result
        """
        task_content = task_data.get("task", task_data.get("prompt", ""))

        # Check if we have a dynamic subagent spec for this target
        agent_spec = self._dynamic_subagents.get(target_agent)

        if agent_spec:
            logger.info(
                "Executing task '%s' with dynamic subagent '%s' (prompt: %s)",
                task_type,
                target_agent,
                agent_spec.get("prompt", "")[:50],
            )
            # In a full implementation, this would create and run an actual agent
            # For now, return a structured result indicating the task was received
            return {
                "task_type": task_type,
                "task": task_content,
                "subagent": {
                    "name": target_agent,
                    "description": agent_spec.get("description", ""),
                    "prompt": agent_spec.get("prompt", ""),
                    "tools": agent_spec.get("tools", []),
                },
                "status": "completed",
                "message": f"Task '{task_type}' executed with dynamic subagent '{target_agent}'",
            }

        # Fallback: execute with default behavior
        logger.info(
            "Executing task '%s' with default agent '%s'",
            task_type,
            target_agent,
        )
        return {
            "task_type": task_type,
            "task": task_content,
            "target_agent": target_agent,
            "status": "completed",
            "message": f"Task '{task_type}' executed with agent '{target_agent}'",
        }

    def get_dynamic_subagent(self, name: str) -> Optional[SubagentSpec]:
        """Get a dynamically defined subagent specification.

        Args:
            name: Subagent name

        Returns:
            Subagent spec dict or None if not found
        """
        return self._dynamic_subagents.get(name)

    def list_dynamic_subagents(self) -> List[str]:
        """List all registered dynamic subagent names.

        Returns:
            List of subagent names
        """
        return list(self._dynamic_subagents.keys())
|
||||
|
||||
|
||||
__all__ = ["TaskDelegator", "SubagentSpec"]
|
||||
389
backend/agents/team/team_coordinator.py
Normal file
389
backend/agents/team/team_coordinator.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TeamCoordinator - Agent lifecycle management and execution.
|
||||
|
||||
Provides run_parallel() using asyncio.gather() and run_sequential()
|
||||
for coordinating multiple agents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamCoordinator:
|
||||
"""Coordinates agent lifecycle and execution.
|
||||
|
||||
Supports:
|
||||
- run_parallel(): Execute multiple agents concurrently with asyncio.gather()
|
||||
- run_sequential(): Execute agents one after another
|
||||
- run_phase(): Execute a named phase with registered agents
|
||||
- register_agent(): Add agent to coordinator
|
||||
- unregister_agent(): Remove agent from coordinator
|
||||
|
||||
Each agent maintains separate context/memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: Optional[List[Any]] = None,
|
||||
task_content: Optional[str] = None,
|
||||
messenger: Optional[Any] = None,
|
||||
registry: Optional[Any] = None,
|
||||
):
|
||||
"""Initialize TeamCoordinator.
|
||||
|
||||
Args:
|
||||
participants: List of agent instances to coordinate
|
||||
task_content: Task description content for the agents
|
||||
messenger: AgentMessenger for communication (optional)
|
||||
registry: AgentRegistry for agent lookup (optional)
|
||||
"""
|
||||
self._participants = participants or []
|
||||
self._task_content = task_content or ""
|
||||
self._messenger = messenger
|
||||
self._registry = registry
|
||||
self._agents: Dict[str, Any] = {}
|
||||
self._running_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Auto-register participants
|
||||
for agent in self._participants:
|
||||
if hasattr(agent, "name"):
|
||||
self._agents[agent.name] = agent
|
||||
elif hasattr(agent, "id"):
|
||||
self._agents[agent.id] = agent
|
||||
|
||||
def register_agent(self, agent_id: str, agent: Any) -> None:
|
||||
"""Register an agent with the coordinator.
|
||||
|
||||
Args:
|
||||
agent_id: Unique agent identifier
|
||||
agent: Agent instance
|
||||
"""
|
||||
self._agents[agent_id] = agent
|
||||
logger.info("Registered agent: %s", agent_id)
|
||||
|
||||
def unregister_agent(self, agent_id: str) -> None:
|
||||
"""Unregister an agent from the coordinator.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to remove
|
||||
"""
|
||||
if agent_id in self._agents:
|
||||
del self._agents[agent_id]
|
||||
logger.info("Unregistered agent: %s", agent_id)
|
||||
|
||||
def get_agent(self, agent_id: str) -> Any:
|
||||
"""Get registered agent by ID.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Agent instance
|
||||
"""
|
||||
return self._agents.get(agent_id)
|
||||
|
||||
def list_agents(self) -> List[str]:
|
||||
"""List all registered agent IDs.
|
||||
|
||||
Returns:
|
||||
List of agent identifiers
|
||||
"""
|
||||
return list(self._agents.keys())
|
||||
|
||||
    async def run_parallel(
        self,
        agent_ids: List[str],
        initial_message: Optional[Msg] = None,
    ) -> Dict[str, Any]:
        """Run multiple agents in parallel using asyncio.gather().

        Unknown agent IDs map to ``None`` in the output (with an error log)
        rather than raising; a per-agent exception maps to
        ``{"error": str(e)}``.

        Args:
            agent_ids: List of agent IDs to run concurrently
            initial_message: Optional initial message to broadcast

        Returns:
            Dict mapping agent_id to result
        """
        async def _run_agent(aid: str) -> tuple[str, Any]:
            # Always return an (agent_id, result) pair so the caller can
            # rebuild the output dict regardless of success or failure.
            agent = self._agents.get(aid)
            if agent is None:
                logger.error("Agent %s not found", aid)
                return (aid, None)

            try:
                # Dispatch preference: async reply() > async run() > calling
                # the agent object itself.
                if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
                    # Truthiness check: a falsy Msg counts as "no message".
                    if initial_message:
                        result = await agent.reply(initial_message)
                    else:
                        result = await agent.reply()
                elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
                    result = await agent.run()
                else:
                    # NOTE(review): an agent with a *sync* reply()/run() falls
                    # through to `await agent()` — confirm such agents are
                    # awaitable callables.
                    result = await agent()
                logger.info("Agent %s completed successfully", aid)
                return (aid, result)
            except Exception as e:
                logger.error("Agent %s failed: %s", aid, e)
                return (aid, {"error": str(e)})

        # _run_agent already catches exceptions, so return_exceptions=True is
        # a belt-and-braces guard that keeps one failure from cancelling the
        # sibling agents.
        results = await asyncio.gather(
            *[_run_agent(aid) for aid in agent_ids],
            return_exceptions=True,
        )

        output: Dict[str, Any] = {}
        for result in results:
            if isinstance(result, tuple):
                agent_id, agent_result = result
                output[agent_id] = agent_result
            else:
                logger.error("Unexpected result from asyncio.gather: %s", result)

        logger.info("Parallel run completed for %d agents", len(agent_ids))
        return output
|
||||
|
||||
    async def run_sequential(
        self,
        agent_ids: List[str],
        initial_message: Optional[Msg] = None,
    ) -> Dict[str, Any]:
        """Run agents one after another in order.

        Each agent's result is forwarded as the message for the next agent
        (pipeline style). The run stops at the first agent that raises;
        agents not yet run are absent from the returned dict.

        Args:
            agent_ids: List of agent IDs to run in sequence
            initial_message: Optional initial message for first agent

        Returns:
            Dict mapping agent_id to result
        """
        output: Dict[str, Any] = {}
        current_message = initial_message

        for agent_id in agent_ids:
            agent = self._agents.get(agent_id)
            if agent is None:
                # A missing agent is recorded but does not stop the pipeline.
                logger.error("Agent %s not found", agent_id)
                output[agent_id] = {"error": "Agent not found"}
                continue

            try:
                # Dispatch preference: async reply() > async run() > calling
                # the agent object itself.
                if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
                    result = await agent.reply(current_message)
                elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
                    result = await agent.run()
                else:
                    result = await agent()

                output[agent_id] = result
                # Chain: the next agent receives this agent's result.
                # NOTE(review): run()/callable results may not be Msg
                # instances — confirm downstream reply() accepts them.
                current_message = result
                logger.info("Agent %s completed sequentially", agent_id)

            except Exception as e:
                # First failure aborts the rest of the sequence.
                logger.error("Agent %s failed: %s", agent_id, e)
                output[agent_id] = {"error": str(e)}
                break

        logger.info("Sequential run completed for %d agents", len(agent_ids))
        return output
|
||||
|
||||
    async def run_phase(
        self,
        phase_name: str,
        agent_ids: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Any]:
        """Execute a named phase with registered agents.

        Agents run one at a time, in the order given; IDs that are not
        registered are silently filtered out. A failing agent contributes
        ``None`` to the results so positions stay aligned with the agents
        that actually ran.

        Args:
            phase_name: Name of the phase (e.g., "analyst_analysis")
            agent_ids: Optional list of agent IDs; if None, uses all registered
            metadata: Optional metadata to include in the message (e.g., tickers, date)

        Returns:
            List of results from each agent
        """
        if agent_ids is None:
            agent_ids = list(self._agents.keys())

        # Keep only IDs that are actually registered.
        _agent_ids = [aid for aid in agent_ids if aid in self._agents]

        logger.info(
            "Running phase '%s' with %d agents: %s",
            phase_name,
            len(_agent_ids),
            _agent_ids,
        )

        # Create messages for each agent
        results: List[Any] = []
        for agent_id in _agent_ids:
            agent = self._agents[agent_id]
            try:
                if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
                    # Create a message for the agent with proper structure;
                    # the shared task content (if any) takes precedence over
                    # the generic phase prompt.
                    msg = Msg(
                        name="system",
                        content=self._task_content or f"Please execute phase: {phase_name}",
                        role="user",
                        metadata=metadata,
                    )
                    result = await agent.reply(msg)
                elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
                    result = await agent.run()
                else:
                    # NOTE(review): agents without an async reply()/run() are
                    # awaited directly — confirm they are awaitable callables.
                    result = await agent()
                results.append(result)
                logger.info("Phase '%s': Agent %s completed", phase_name, agent_id)
            except Exception as e:
                # Placeholder keeps the results list aligned with _agent_ids.
                logger.error("Phase '%s': Agent %s failed: %s", phase_name, agent_id, e)
                results.append(None)

        logger.info("Phase '%s' completed with %d results", phase_name, len(results))
        return results
|
||||
|
||||
    async def run_with_dependencies(
        self,
        agent_tasks: Dict[str, List[str]],
        initial_message: Optional[Msg] = None,
    ) -> Dict[str, Any]:
        """Run agents respecting dependency graph.

        Executes the graph in waves: every agent whose prerequisites have all
        completed runs in the next parallel batch. If no agent is ready while
        some remain, a circular dependency is assumed and the stragglers are
        marked with an error instead of looping forever.

        Args:
            agent_tasks: Dict mapping agent_id to list of prerequisite agent_ids
            initial_message: Optional initial message

        Returns:
            Dict mapping agent_id to result
        """
        completed: Dict[str, Any] = {}
        remaining = set(agent_tasks.keys())

        while remaining:
            # Agents whose prerequisites have all finished — note an
            # {"error": ...} result still counts as completed, so failures
            # do not block their dependents.
            ready = [
                aid for aid in remaining
                if all(dep in completed for dep in agent_tasks.get(aid, []))
            ]

            if not ready:
                logger.error("Circular dependency detected in agent tasks")
                for aid in remaining:
                    completed[aid] = {"error": "Circular dependency"}
                break

            results = await self.run_parallel(ready, initial_message)
            completed.update(results)

            for aid in ready:
                remaining.discard(aid)
                # NOTE(review): this keeps overwriting initial_message, so the
                # next wave receives the result of whichever ready agent came
                # last in this list — confirm that hand-off is intentional.
                initial_message = results.get(aid)

        return completed
|
||||
|
||||
async def fanout_pipeline(
    self,
    agents: List[Any],
    msg: Optional[Msg] = None,
) -> List[Optional[Msg]]:
    """Fanout a message to multiple agents concurrently and collect all responses.

    Similar to AgentScope's fanout_pipeline, this sends the same message
    to all specified agents and returns a list of all agent responses.
    The result list always has the same length and order as ``agents``;
    an agent that fails or produces no result contributes ``None`` at its
    position (note: the previous ``List[Msg]`` annotation was inaccurate).

    Args:
        agents: List of agent instances to fanout the message to
        msg: Message to send to all agents (optional)

    Returns:
        List of per-agent responses in input order; ``None`` marks a
        failed agent or an empty result.

    Example:
        >>> responses = await fanout_pipeline(
        ...     agents=[alice, bob, charlie],
        ...     msg=question,
        ... )
        >>> # responses is a list of Msg responses from each agent
    """
    async def _fanout_to_agent(agent: Any) -> Optional[Msg]:
        """Send message to a single agent and return its response."""
        try:
            # Prefer the async reply() API, then async run(), then a bare call.
            if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
                result = await agent.reply(msg) if msg is not None else await agent.reply()
            elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
                result = await agent.run()
            else:
                result = await agent()

            # Normalize the raw agent output into a Msg (or None).
            if result is None:
                return None
            if isinstance(result, Msg):
                return result
            # Dict payloads with a "content" key map cleanly onto Msg.
            if isinstance(result, dict) and "content" in result:
                return Msg(
                    name=getattr(agent, "name", "unknown"),
                    content=result.get("content", ""),
                    role="assistant",
                    metadata=result.get("metadata"),
                )
            # Fallback: stringify anything else into a Msg.
            return Msg(
                name=getattr(agent, "name", "unknown"),
                content=str(result),
                role="assistant",
            )
        except Exception as e:
            logger.error("Agent %s failed in fanout_pipeline: %s",
                         getattr(agent, "name", "unknown"), e)
            return None

    # Run all agents concurrently. return_exceptions=True is a defensive
    # net only: _fanout_to_agent already swallows exceptions, so entries
    # of type Exception are not expected in practice.
    results = await asyncio.gather(
        *[_fanout_to_agent(agent) for agent in agents],
        return_exceptions=True,
    )

    # Keep positional correspondence with `agents`: a failure becomes a
    # None placeholder rather than being dropped.
    responses: List[Optional[Msg]] = []
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            logger.error("Fanout to agent %d failed: %s", i, result)
            responses.append(None)
        else:
            responses.append(result)

    logger.info("Fanout pipeline completed for %d agents", len(agents))
    return responses
|
||||
|
||||
async def shutdown(self, timeout: Optional[float] = 5.0) -> None:
    """Gracefully stop every agent task tracked by this coordinator.

    Each running task is awaited for up to ``timeout`` seconds; tasks
    still alive after that are cancelled by asyncio.wait_for.

    Args:
        timeout: Timeout for graceful shutdown
    """
    logger.info("Shutting down TeamCoordinator...")

    waiters = [
        asyncio.create_task(asyncio.wait_for(pending, timeout=timeout))
        for pending in self._running_tasks.values()
    ]

    if waiters:
        # Collect results without letting timeouts/cancellations propagate.
        await asyncio.gather(*waiters, return_exceptions=True)

    self._running_tasks.clear()
    logger.info("TeamCoordinator shutdown complete")
|
||||
|
||||
@property
def agents(self) -> Dict[str, Any]:
    """Return a shallow copy of the registered-agents mapping."""
    return {**self._agents}
|
||||
|
||||
|
||||
__all__ = ["TeamCoordinator"]
|
||||
132
backend/agents/team_pipeline_config.py
Normal file
132
backend/agents/team_pipeline_config.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Run-scoped team pipeline configuration helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Dict, Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
# Filename used for every run-scoped pipeline configuration.
DEFAULT_FILENAME = "TEAM_PIPELINE.yaml"


def team_pipeline_path(project_root: Path, config_name: str) -> Path:
    """Return run-scoped team pipeline config path."""
    return project_root.joinpath("runs", config_name, DEFAULT_FILENAME)
|
||||
|
||||
|
||||
def ensure_team_pipeline_config(
    project_root: Path,
    config_name: str,
    default_analysts: Iterable[str],
) -> Path:
    """Ensure TEAM_PIPELINE.yaml exists for one run.

    Creates the parent directory and, when the file is missing, writes a
    default configuration seeded with ``default_analysts``. An existing
    file is never overwritten.
    """
    path = team_pipeline_path(project_root, config_name)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        # Never clobber a config the user may have edited.
        return path

    default_config: Dict[str, Any] = {
        "version": 1,
        "controller_agent": "portfolio_manager",
        "discussion": {
            "allow_dynamic_team_update": True,
            "active_analysts": list(default_analysts),
        },
        "decision": {
            "require_risk_manager": True,
        },
    }
    serialized = yaml.safe_dump(default_config, allow_unicode=True, sort_keys=False)
    path.write_text(serialized, encoding="utf-8")
    return path
|
||||
|
||||
|
||||
def load_team_pipeline_config(project_root: Path, config_name: str) -> Dict[str, Any]:
    """Load TEAM_PIPELINE.yaml and return parsed dict.

    Returns an empty dict when the file is absent or when its top-level
    YAML value is not a mapping.
    """
    path = team_pipeline_path(project_root, config_name)
    if not path.exists():
        return {}
    data = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    if isinstance(data, dict):
        return data
    return {}
|
||||
|
||||
|
||||
def save_team_pipeline_config(
    project_root: Path,
    config_name: str,
    config: Dict[str, Any],
) -> Path:
    """Persist TEAM_PIPELINE.yaml.

    Serializes ``config`` with key order preserved and returns the path
    that was written.
    """
    path = team_pipeline_path(project_root, config_name)
    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = yaml.safe_dump(config, allow_unicode=True, sort_keys=False)
    path.write_text(serialized, encoding="utf-8")
    return path
|
||||
|
||||
|
||||
def resolve_active_analysts(
    project_root: Path,
    config_name: str,
    available_analysts: Iterable[str],
) -> List[str]:
    """Resolve active analysts from TEAM_PIPELINE.yaml.

    Args:
        project_root: Repository root containing the ``runs`` directory.
        config_name: Name of the run whose config is consulted.
        available_analysts: Full set of analysts that may be activated.

    Returns:
        The configured analysts intersected with ``available_analysts``,
        preserving the configured order; falls back to all available
        analysts when the config is missing, empty, or invalid.
    """
    available = list(available_analysts)
    parsed = load_team_pipeline_config(project_root, config_name)
    discussion = parsed.get("discussion", {})
    if not isinstance(discussion, dict):
        # Guard against hand-edited YAML where "discussion" is a list or
        # scalar: previously this raised AttributeError on .get(). Mirrors
        # the isinstance guard used in update_active_analysts.
        discussion = {}
    configured = discussion.get("active_analysts", [])
    if not isinstance(configured, list) or not configured:
        return available

    active = [item for item in configured if item in available]
    return active or available
|
||||
|
||||
|
||||
def update_active_analysts(
    project_root: Path,
    config_name: str,
    available_analysts: Iterable[str],
    *,
    add: Iterable[str] | None = None,
    remove: Iterable[str] | None = None,
    set_to: Iterable[str] | None = None,
) -> List[str]:
    """Update active analysts and persist TEAM_PIPELINE.yaml.

    Precedence of the mutations: ``set_to`` replaces the selection first
    (ignored if it filters down to nothing), then ``add`` appends, then
    ``remove`` deletes. The selection never ends up empty while analysts
    are available.
    """
    available = list(available_analysts)
    ensure_team_pipeline_config(project_root, config_name, available)
    config = load_team_pipeline_config(project_root, config_name)

    discussion = config.setdefault("discussion", {})
    if not isinstance(discussion, dict):
        # Repair a malformed "discussion" entry in place.
        discussion = {}
        config["discussion"] = discussion

    # Start from the persisted selection, dropping unknown names.
    selection = discussion.get("active_analysts", [])
    if not isinstance(selection, list):
        selection = []
    selection = [name for name in selection if name in available]
    if not selection:
        selection = list(available)

    if set_to is not None:
        requested = [name for name in set_to if name in available]
        if requested:
            selection = requested

    for name in add or ():
        if name in available and name not in selection:
            selection.append(name)

    for name in remove or ():
        selection = [kept for kept in selection if kept != name]

    if not selection:
        # Never persist an empty selection while analysts exist.
        selection = [available[0]] if available else []

    discussion["active_analysts"] = selection
    save_team_pipeline_config(project_root, config_name, config)
    return selection
|
||||
|
||||
Reference in New Issue
Block a user