diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..d6b22fb --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,170 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +EvoTraders is a self-evolving multi-agent trading system where 6 AI agents (4 analysts + portfolio manager + risk manager) collaborate to make trading decisions. Agents use the AgentScope framework with a ReMe memory system for continuous learning. + +## Development Commands + +### Backend (Python) + +```bash +# Install dependencies +uv pip install -e . + +# Run commands +evotraders backtest --start 2025-11-01 --end 2025-12-01 # Backtest mode +evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory +evotraders live # Live trading +evotraders live --mock # Mock/testing mode +evotraders live -t 22:30 # Scheduled daily trading +evotraders frontend # Launch visualization UI + +# Dev server (starts FastAPI on port 8000) +./start-dev.sh +# Or manually: +python -m uvicorn backend.app:app --host 0.0.0.0 --port 8000 --reload --reload-dir backend + +# Testing +pytest backend/tests +``` + +### Frontend (React) + +```bash +cd frontend +npm run dev # Vite dev server (http://localhost:5173) +npm run build # Production build +npm run lint # ESLint +npm run test # Vitest +npm run test:watch # Watch mode +``` + +## Architecture + +### Multi-Agent System (`backend/agents/`) + +**6 Agent Roles** (configured in `prompts/analyst/personas.yaml`): +- **fundamentals_analyst** - Financial health, profitability, growth quality +- **technical_analyst** - Price trends, indicators, momentum +- **sentiment_analyst** - Market sentiment, news, insider trading +- **valuation_analyst** - DCF, EV/EBITDA, intrinsic value +- **portfolio_manager** - Decision execution, trade coordination +- **risk_manager** - Real-time risk monitoring, position limits + +**Key Agent Files**: +- `base/evo_agent.py` - Core agent implementation extending AgentScope +- `base/hooks.py` - Lifecycle hooks for agent execution (BootstrapHook, MemoryCompactionHook, HeartbeatHook, WorkspaceWatchHook) +- `base/evaluation_hook.py` - Post-execution evaluation +- `base/skill_adaptation_hook.py` - Dynamic skill adaptation +- `factory.py` - Agent factory for creating agent instances +- `skills_manager.py` - Skill loading and management (6 scopes: builtin/customized/installed/active/disabled/local) +- `toolkit_factory.py` - Tool collection factory for agents +- `team/` - Team coordination (registry, coordinator, messenger, task_delegator) + +**Hook System** (`base/hooks.py`): +- **MemoryCompactionHook**: 基于 CoPaw 设计的内存压缩,支持: + - `memory_compact_ratio`: 压缩目标比例 (默认 0.75) + - `memory_reserve_ratio`: 保留比例 (默认 0.1) + - `enable_tool_result_compact`: 工具结果压缩 + - `tool_result_compact_keep_n`: 保留最近 N 条工具结果 + +**Adding Custom Analysts**: +1. Register in `backend/agents/prompts/analyst/personas.yaml` +2. Add to `ANALYST_TYPES` dict in `backend/config/constants.py` +3. Optionally update frontend config in `frontend/src/config/constants.js` + +### Backend Structure + +``` +backend/ +├── agents/ # Multi-agent implementation +│ ├── base/ # Base classes, hooks, evaluation +│ ├── prompts/ # Agent prompts and personas +│ └── team/ # Team coordination logic +├── api/ # FastAPI endpoints +├── config/ # Constants and configuration +├── core/ # Pipeline execution logic +├── data/ # Market data handling +├── enrich/ # LLM response enrichment +├── explain/ # Decision explanation +├── llm/ # LLM integrations (with RetryChatModel, TokenRecordingModelWrapper) +├── services/ # Gateway, WebSocket services +├── skills/ # Skill definitions (builtin + custom) +└── tools/ # Trading and analysis tools +``` + +### LLM Model Wrappers (`backend/llm/models.py`) + +Based on CoPaw's model wrapper design: +- **RetryChatModel**: 自动重试瞬态 LLM 错误(rate limit、timeout、502/503 等),指数退避 + - `max_retries`: 最大重试次数 (默认 3) + - `initial_delay`: 初始延迟秒数 (默认 1.0) + - `backoff_multiplier`: 退避倍数 (默认 2.0) +- **TokenRecordingModelWrapper**: 追踪每个 provider 的 token 消耗和成本 + +```python +from backend.llm.models import create_model, RetryChatModel + +model = RetryChatModel(create_model("gpt-4o", "OPENAI"), max_retries=3) +``` + +### Frontend Structure + +``` +frontend/src/ +├── App.jsx # Main React application +├── components/ # React components +│ ├── RuntimeView.jsx # Trading runtime UI +│ ├── TraderView.jsx # Trader interface +│ └── RuntimeSettingsPanel.jsx +├── services/ # API and WebSocket services +│ ├── runtimeApi.js # Backend API calls +│ └── websocket.js # Real-time communication +└── config/ + └── constants.js # Agent definitions, configuration +``` + +### Skill System (`backend/skills/`) + +Skills are defined in `SKILL.md` files with: +- `instructions` - What the skill does +- `triggers` - When to invoke +- `parameters` - Input/output schema +- `available_tools` - Tools the skill can use + +Skills are loaded by `skills_manager.py` and attached to agents via `skill_adaptation_hook.py`. + +### Pipeline Execution (`backend/core/`) + +The daily trading flow: +1. **Analysis Stage** - Each agent analyzes independently +2. **Communication Stage** - Agent-to-agent messaging (1v1, 1vN, NvN) +3. **Decision Stage** - Portfolio manager makes final trades +4. **Evaluation Stage** - Performance tracking +5. **Review Stage** - Memory updates via ReMe + +## Environment Configuration + +Required in `.env`: +```bash +FIN_DATA_SOURCE=finnhub|financial_datasets +FINANCIAL_DATASETS_API_KEY= # Required for backtest +FINNHUB_API_KEY= # Required for live trading +OPENAI_API_KEY= # Agent LLM +OPENAI_BASE_URL= +MODEL_NAME=qwen3-max-preview +MEMORY_API_KEY= # For ReMe memory system +``` + +## Key Dependencies + +- **AgentScope** - Multi-agent framework +- **ReMe** - Memory system for continuous learning +- **FastAPI** + **uvicorn** - Backend API server +- **websockets** - Real-time communication +- **React 19** + **Vite** + **TailwindCSS** - Frontend +- **Zustand** - Frontend state management +- **Three.js** / **React-Three-Fiber** - 3D visualizations diff --git a/backend/agents/base/evaluation_hook.py b/backend/agents/base/evaluation_hook.py new file mode 100644 index 0000000..a2c556b --- /dev/null +++ b/backend/agents/base/evaluation_hook.py @@ -0,0 +1,452 @@ +# -*- coding: utf-8 -*- +"""Evaluation hooks system for skills. + +Provides evaluation metric collection and storage for skill performance tracking. +Based on the evaluation hooks design in SKILL_TEMPLATE.md. +""" +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field, asdict +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + + +class MetricType(Enum): + """Types of evaluation metrics.""" + HIT_RATE = "hit_rate" # 信号命中率 + RISK_VIOLATION = "risk_violation" # 风控违例率 + POSITION_DEVIATION = "position_deviation" # 仓位偏离率 + PnL_ATTRIBUTION = "pnl_attribution" # P&L 归因一致性 + SIGNAL_CONSISTENCY = "signal_consistency" # 信号一致性 + DECISION_LATENCY = "decision_latency" # 决策延迟 + TOOL_USAGE = "tool_usage" # 工具使用率 + CUSTOM = "custom" # 自定义指标 + + +@dataclass +class EvaluationMetric: + """A single evaluation metric.""" + name: str + metric_type: MetricType + value: float + timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) + metadata: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "metric_type": self.metric_type.value, + "value": self.value, + "timestamp": self.timestamp, + "metadata": self.metadata, + } + + +@dataclass +class EvaluationResult: + """Evaluation result for a skill execution.""" + skill_name: str + run_id: str + agent_id: str + metrics: List[EvaluationMetric] = field(default_factory=list) + inputs: Dict[str, Any] = field(default_factory=dict) + outputs: Dict[str, Any] = field(default_factory=dict) + decision: Optional[str] = None + success: bool = True + error_message: Optional[str] = None + started_at: Optional[str] = None + completed_at: Optional[str] = field(default_factory=lambda: datetime.now().isoformat()) + + def to_dict(self) -> Dict[str, Any]: + return { + "skill_name": self.skill_name, + "run_id": self.run_id, + "agent_id": self.agent_id, + "metrics": [m.to_dict() for m in self.metrics], + "inputs": self.inputs, + "outputs": self.outputs, + "decision": self.decision, + "success": self.success, + "error_message": self.error_message, + "started_at": self.started_at, + "completed_at": self.completed_at, + } + + +class EvaluationHook: + """Hook for collecting skill evaluation metrics. + + This hook collects and stores evaluation metrics after skill execution + for later analysis and memory/reflection stages. + """ + + def __init__( + self, + storage_dir: Path, + run_id: str, + agent_id: str, + ): + """Initialize evaluation hook. + + Args: + storage_dir: Directory to store evaluation results + run_id: Current run identifier + agent_id: Current agent identifier + """ + self.storage_dir = Path(storage_dir) + self.run_id = run_id + self.agent_id = agent_id + self._current_evaluation: Optional[EvaluationResult] = None + + def start_evaluation( + self, + skill_name: str, + inputs: Dict[str, Any], + ) -> None: + """Start a new evaluation session. + + Args: + skill_name: Name of the skill being evaluated + inputs: Input parameters for the skill + """ + self._current_evaluation = EvaluationResult( + skill_name=skill_name, + run_id=self.run_id, + agent_id=self.agent_id, + inputs=inputs, + started_at=datetime.now().isoformat(), + ) + logger.debug(f"Started evaluation for skill: {skill_name}") + + def add_metric( + self, + name: str, + metric_type: MetricType, + value: float, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """Add an evaluation metric. + + Args: + name: Metric name + metric_type: Type of metric + value: Metric value + metadata: Additional metadata + """ + if self._current_evaluation is None: + logger.warning("No active evaluation session, ignoring metric") + return + + metric = EvaluationMetric( + name=name, + metric_type=metric_type, + value=value, + metadata=metadata or {}, + ) + self._current_evaluation.metrics.append(metric) + logger.debug(f"Added metric: {name} = {value}") + + def add_metrics(self, metrics: List[EvaluationMetric]) -> None: + """Add multiple evaluation metrics at once. + + Args: + metrics: List of metrics to add + """ + if self._current_evaluation is None: + logger.warning("No active evaluation session, ignoring metrics") + return + + self._current_evaluation.metrics.extend(metrics) + + def record_outputs(self, outputs: Dict[str, Any]) -> None: + """Record skill outputs. + + Args: + outputs: Output from skill execution + """ + if self._current_evaluation is None: + logger.warning("No active evaluation session, ignoring outputs") + return + + self._current_evaluation.outputs = outputs + + def record_decision(self, decision: str) -> None: + """Record the final decision. + + Args: + decision: Final decision made by the skill + """ + if self._current_evaluation is None: + logger.warning("No active evaluation session, ignoring decision") + return + + self._current_evaluation.decision = decision + + def complete_evaluation( + self, + success: bool = True, + error_message: Optional[str] = None, + ) -> Optional[EvaluationResult]: + """Complete the evaluation session and persist results. + + Args: + success: Whether the skill execution was successful + error_message: Error message if failed + + Returns: + The completed evaluation result, or None if no active evaluation + """ + if self._current_evaluation is None: + logger.warning("No active evaluation to complete") + return None + + self._current_evaluation.success = success + self._current_evaluation.error_message = error_message + self._current_evaluation.completed_at = datetime.now().isoformat() + + # Persist to storage + result = self._persist_evaluation(self._current_evaluation) + + self._current_evaluation = None + logger.debug(f"Completed evaluation for skill: {result.skill_name}") + + return result + + def _persist_evaluation(self, evaluation: EvaluationResult) -> EvaluationResult: + """Persist evaluation result to storage. + + Args: + evaluation: Evaluation result to persist + + Returns: + The persisted evaluation + """ + # Create run-specific directory + run_dir = self.storage_dir / self.run_id + run_dir.mkdir(parents=True, exist_ok=True) + + # Create agent-specific subdirectory + agent_dir = run_dir / self.agent_id + agent_dir.mkdir(parents=True, exist_ok=True) + + # Generate filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + filename = f"{evaluation.skill_name}_{timestamp}.json" + filepath = agent_dir / filename + + # Write evaluation result + try: + with open(filepath, "w", encoding="utf-8") as f: + json.dump(evaluation.to_dict(), f, ensure_ascii=False, indent=2) + logger.info(f"Persisted evaluation to: {filepath}") + except Exception as e: + logger.error(f"Failed to persist evaluation: {e}") + + return evaluation + + def cancel_evaluation(self) -> None: + """Cancel the current evaluation session without saving.""" + if self._current_evaluation is not None: + logger.debug(f"Cancelled evaluation for: {self._current_evaluation.skill_name}") + self._current_evaluation = None + + +class EvaluationCollector: + """Collector for aggregating evaluation metrics across runs. + + Provides methods to query and analyze evaluation results. + """ + + def __init__(self, storage_dir: Path): + """Initialize evaluation collector. + + Args: + storage_dir: Root directory containing evaluation results + """ + self.storage_dir = Path(storage_dir) + + def get_run_evaluations( + self, + run_id: str, + agent_id: Optional[str] = None, + ) -> List[EvaluationResult]: + """Get all evaluations for a run. + + Args: + run_id: Run identifier + agent_id: Optional agent identifier to filter by + + Returns: + List of evaluation results + """ + run_dir = self.storage_dir / run_id + if not run_dir.exists(): + return [] + + evaluations = [] + + agent_dirs = [run_dir / agent_id] if agent_id else run_dir.iterdir() + + for agent_dir in agent_dirs: + if not agent_dir.is_dir(): + continue + + for eval_file in agent_dir.glob("*.json"): + try: + with open(eval_file, "r", encoding="utf-8") as f: + data = json.load(f) + evaluations.append(self._parse_evaluation(data)) + except Exception as e: + logger.warning(f"Failed to load evaluation {eval_file}: {e}") + + return evaluations + + def get_skill_metrics( + self, + skill_name: str, + run_ids: Optional[List[str]] = None, + ) -> List[EvaluationMetric]: + """Get all metrics for a specific skill. + + Args: + skill_name: Name of the skill + run_ids: Optional list of run IDs to filter by + + Returns: + List of metrics for the skill + """ + metrics = [] + + if run_ids is None: + run_ids = [d.name for d in self.storage_dir.iterdir() if d.is_dir()] + + for run_id in run_ids: + evaluations = self.get_run_evaluations(run_id) + for eval_result in evaluations: + if eval_result.skill_name == skill_name: + metrics.extend(eval_result.metrics) + + return metrics + + def calculate_skill_stats( + self, + skill_name: str, + metric_type: MetricType, + run_ids: Optional[List[str]] = None, + ) -> Dict[str, float]: + """Calculate statistics for a specific metric type. + + Args: + skill_name: Name of the skill + metric_type: Type of metric to calculate + run_ids: Optional list of run IDs to filter by + + Returns: + Dictionary with min, max, avg, count statistics + """ + metrics = self.get_skill_metrics(skill_name, run_ids) + filtered = [m for m in metrics if m.metric_type == metric_type] + + if not filtered: + return {"count": 0} + + values = [m.value for m in filtered] + return { + "count": len(values), + "min": min(values), + "max": max(values), + "avg": sum(values) / len(values), + } + + def _parse_evaluation(self, data: Dict[str, Any]) -> EvaluationResult: + """Parse evaluation data into EvaluationResult. + + Args: + data: Raw evaluation data + + Returns: + Parsed EvaluationResult + """ + metrics = [] + for m in data.get("metrics", []): + metrics.append(EvaluationMetric( + name=m["name"], + metric_type=MetricType(m["metric_type"]), + value=m["value"], + timestamp=m.get("timestamp", ""), + metadata=m.get("metadata", {}), + )) + + return EvaluationResult( + skill_name=data["skill_name"], + run_id=data["run_id"], + agent_id=data["agent_id"], + metrics=metrics, + inputs=data.get("inputs", {}), + outputs=data.get("outputs", {}), + decision=data.get("decision"), + success=data.get("success", True), + error_message=data.get("error_message"), + started_at=data.get("started_at"), + completed_at=data.get("completed_at"), + ) + + +def parse_evaluation_hooks(skill_dir: Path) -> Dict[str, Any]: + """Parse evaluation hooks from SKILL.md. + + Extracts the Optional: Evaluation hooks section from skill documentation. + + Args: + skill_dir: Skill directory path + + Returns: + Dictionary containing evaluation hook definitions + """ + skill_md = skill_dir / "SKILL.md" + if not skill_md.exists(): + return {} + + try: + content = skill_md.read_text(encoding="utf-8") + + # Extract evaluation hooks section + if "## Optional: Evaluation hooks" in content: + start = content.find("## Optional: Evaluation hooks") + # Find the next ## section or end of file + next_section = content.find("\n## ", start + 1) + if next_section == -1: + eval_section = content[start:] + else: + eval_section = content[start:next_section] + + # Parse metrics from the section + metrics = [] + for metric_type in MetricType: + if metric_type.value.replace("_", " ") in eval_section.lower(): + metrics.append(metric_type.value) + + return { + "supported_metrics": metrics, + "section_content": eval_section.strip(), + } + except Exception as e: + logger.warning(f"Failed to parse evaluation hooks: {e}") + + return {} + + +__all__ = [ + "MetricType", + "EvaluationMetric", + "EvaluationResult", + "EvaluationHook", + "EvaluationCollector", + "parse_evaluation_hooks", +] diff --git a/backend/agents/base/skill_adaptation_hook.py b/backend/agents/base/skill_adaptation_hook.py new file mode 100644 index 0000000..1a9e358 --- /dev/null +++ b/backend/agents/base/skill_adaptation_hook.py @@ -0,0 +1,489 @@ +# -*- coding: utf-8 -*- +"""Skill adaptation hook for automatic evaluation-to-iteration闭环. + +Monitors evaluation metrics against configurable thresholds and triggers +automatic skill reload or logs warnings when thresholds are breached. +""" +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Set + +from .evaluation_hook import ( + EvaluationCollector, + EvaluationResult, + MetricType, +) + +logger = logging.getLogger(__name__) + + +class AdaptationAction(Enum): + """Actions to take when threshold is breached.""" + RELOAD = "reload" # 自动重新加载技能 + WARN = "warn" # 记录警告供人工审核 + BOTH = "both" # 同时执行重载和警告 + NONE = "none" # 不做任何操作 + + +@dataclass +class AdaptationThreshold: + """Threshold configuration for a metric.""" + metric_type: MetricType + operator: str = "lt" # lt (less than), gt (greater than), lte, gte, eq + value: float = 0.0 + window_size: int = 10 # 移动窗口大小,用于计算滑动平均 + min_samples: int = 5 # 最少样本数才触发检查 + action: AdaptationAction = AdaptationAction.WARN + cooldown_seconds: int = 300 # 触发后的冷却时间 + + def evaluate(self, current_value: float) -> bool: + """Evaluate if threshold is breached.""" + ops = { + "lt": lambda x, y: x < y, + "lte": lambda x, y: x <= y, + "gt": lambda x, y: x > y, + "gte": lambda x, y: x >= y, + "eq": lambda x, y: x == y, + } + op_func = ops.get(self.operator) + if op_func is None: + logger.warning(f"Unknown operator: {self.operator}") + return False + return op_func(current_value, self.value) + + def to_dict(self) -> Dict[str, Any]: + return { + "metric_type": self.metric_type.value, + "operator": self.operator, + "value": self.value, + "window_size": self.window_size, + "min_samples": self.min_samples, + "action": self.action.value, + "cooldown_seconds": self.cooldown_seconds, + } + + +@dataclass +class AdaptationEvent: + """Record of an adaptation trigger event.""" + timestamp: str + skill_name: str + metric_type: MetricType + threshold: AdaptationThreshold + current_value: float + avg_value: float + action_taken: AdaptationAction + details: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return { + "timestamp": self.timestamp, + "skill_name": self.skill_name, + "metric_type": self.metric_type.value, + "threshold": self.threshold.to_dict(), + "current_value": self.current_value, + "avg_value": self.avg_value, + "action_taken": self.action_taken.value, + "details": self.details, + } + + +class SkillAdaptationHook: + """Hook for monitoring evaluation metrics and triggering skill adaptation. + + This hook wraps EvaluationHook to add threshold-based adaptation logic. + When metrics breach configured thresholds, it can: + - Automatically reload skills via SkillsManager + - Log warnings for human review + - Both + """ + + # Default thresholds for common metrics + DEFAULT_THRESHOLDS: List[AdaptationThreshold] = [ + AdaptationThreshold( + metric_type=MetricType.HIT_RATE, + operator="lt", + value=0.5, + action=AdaptationAction.WARN, + cooldown_seconds=600, + ), + AdaptationThreshold( + metric_type=MetricType.RISK_VIOLATION, + operator="gt", + value=0.1, + action=AdaptationAction.WARN, + cooldown_seconds=300, + ), + AdaptationThreshold( + metric_type=MetricType.DECISION_LATENCY, + operator="gt", + value=5000, # 5 seconds + action=AdaptationAction.WARN, + cooldown_seconds=300, + ), + ] + + def __init__( + self, + storage_dir: Path, + run_id: str, + agent_id: str, + thresholds: Optional[List[AdaptationThreshold]] = None, + collector: Optional[EvaluationCollector] = None, + ): + """Initialize skill adaptation hook. + + Args: + storage_dir: Directory to store adaptation events + run_id: Current run identifier + agent_id: Current agent identifier + thresholds: Custom threshold configurations (uses defaults if None) + collector: Optional EvaluationCollector for historical data + """ + self.storage_dir = Path(storage_dir) + self.run_id = run_id + self.agent_id = agent_id + self.thresholds = thresholds or self.DEFAULT_THRESHOLDS + self.collector = collector or EvaluationCollector(storage_dir) + + # Track cooldowns to prevent rapid re-triggering + self._cooldowns: Dict[str, datetime] = {} + + # Store recent metrics in memory for quick access + self._recent_metrics: Dict[str, List[float]] = {} + + # Pending adaptation events + self._pending_events: List[AdaptationEvent] = [] + + def check_threshold( + self, + skill_name: str, + metric_type: MetricType, + current_value: float, + ) -> Optional[AdaptationEvent]: + """Check if a metric breaches any threshold. + + Args: + skill_name: Name of the skill + metric_type: Type of metric + current_value: Current metric value + + Returns: + AdaptationEvent if threshold breached, None otherwise + """ + # Find applicable thresholds + applicable_thresholds = [ + t for t in self.thresholds + if t.metric_type == metric_type + ] + + if not applicable_thresholds: + return None + + # Check cooldown + cooldown_key = f"{skill_name}:{metric_type.value}" + now = datetime.now() + last_trigger = self._cooldowns.get(cooldown_key) + + # Store current value first for avg calculation + self._store_metric(cooldown_key, current_value) + + for threshold in applicable_thresholds: + if last_trigger: + elapsed = (now - last_trigger).total_seconds() + if elapsed < threshold.cooldown_seconds: + continue + + # Evaluate threshold + if threshold.evaluate(current_value): + # Calculate moving average + avg_value = self._calculate_avg(skill_name, metric_type, current_value) + + # Check minimum samples (allow immediate trigger if min_samples <= 1) + sample_count = len(self._recent_metrics.get(cooldown_key, [])) + if threshold.min_samples > 1 and sample_count < threshold.min_samples: + # Not enough samples yet + continue + + # Trigger adaptation + event = AdaptationEvent( + timestamp=now.isoformat(), + skill_name=skill_name, + metric_type=metric_type, + threshold=threshold, + current_value=current_value, + avg_value=avg_value, + action_taken=threshold.action, + details={ + "run_id": self.run_id, + "agent_id": self.agent_id, + }, + ) + + # Update cooldown + self._cooldowns[cooldown_key] = now + + # Persist event + self._persist_event(event) + + logger.info( + f"Threshold breached for {skill_name}.{metric_type.value}: " + f"current={current_value}, avg={avg_value}, action={threshold.action.value}" + ) + + return event + + return None + + def _calculate_avg( + self, + skill_name: str, + metric_type: MetricType, + current_value: float, + ) -> float: + """Calculate moving average for a metric.""" + key = f"{skill_name}:{metric_type.value}" + values = self._recent_metrics.get(key, []) + if not values: + return current_value + return sum(values) / len(values) + + def _store_metric(self, key: str, value: float) -> None: + """Store metric value with sliding window.""" + if key not in self._recent_metrics: + self._recent_metrics[key] = [] + self._recent_metrics[key].append(value) + # Keep only last 100 values + if len(self._recent_metrics[key]) > 100: + self._recent_metrics[key] = self._recent_metrics[key][-100:] + + def _persist_event(self, event: AdaptationEvent) -> None: + """Persist adaptation event to storage.""" + run_dir = self.storage_dir / self.run_id / "adaptations" + run_dir.mkdir(parents=True, exist_ok=True) + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + filename = f"{event.skill_name}_{event.metric_type.value}_{timestamp}.json" + filepath = run_dir / filename + + try: + with open(filepath, "w", encoding="utf-8") as f: + json.dump(event.to_dict(), f, ensure_ascii=False, indent=2) + logger.debug(f"Persisted adaptation event to: {filepath}") + except Exception as e: + logger.error(f"Failed to persist adaptation event: {e}") + + # Also add to pending list + self._pending_events.append(event) + + def get_pending_warnings(self) -> List[AdaptationEvent]: + """Get all pending warning events that need human review.""" + return [ + e for e in self._pending_events + if e.action_taken in (AdaptationAction.WARN, AdaptationAction.BOTH) + ] + + def clear_pending_warnings(self) -> None: + """Clear pending warnings after they have been reviewed.""" + self._pending_events = [ + e for e in self._pending_events + if e.action_taken == AdaptationAction.RELOAD + ] + + def get_recent_events( + self, + skill_name: Optional[str] = None, + metric_type: Optional[MetricType] = None, + limit: int = 50, + ) -> List[AdaptationEvent]: + """Get recent adaptation events. + + Args: + skill_name: Optional filter by skill name + metric_type: Optional filter by metric type + limit: Maximum number of events to return + + Returns: + List of recent adaptation events + """ + events_dir = self.storage_dir / self.run_id / "adaptations" + if not events_dir.exists(): + return [] + + events = [] + for eval_file in sorted(events_dir.glob("*.json"), reverse=True)[:limit]: + try: + with open(eval_file, "r", encoding="utf-8") as f: + data = json.load(f) + event = self._parse_event(data) + if skill_name and event.skill_name != skill_name: + continue + if metric_type and event.metric_type != metric_type: + continue + events.append(event) + except Exception as e: + logger.warning(f"Failed to load adaptation event {eval_file}: {e}") + + return events + + def _parse_event(self, data: Dict[str, Any]) -> AdaptationEvent: + """Parse adaptation event from JSON data.""" + threshold_data = data.get("threshold", {}) + metric_type = MetricType(threshold_data.get("metric_type", "custom")) + + threshold = AdaptationThreshold( + metric_type=metric_type, + operator=threshold_data.get("operator", "lt"), + value=threshold_data.get("value", 0.0), + window_size=threshold_data.get("window_size", 10), + min_samples=threshold_data.get("min_samples", 5), + action=AdaptationAction(threshold_data.get("action", "warn")), + cooldown_seconds=threshold_data.get("cooldown_seconds", 300), + ) + + return AdaptationEvent( + timestamp=data.get("timestamp", ""), + skill_name=data.get("skill_name", ""), + metric_type=metric_type, + threshold=threshold, + current_value=data.get("current_value", 0.0), + avg_value=data.get("avg_value", 0.0), + action_taken=AdaptationAction(data.get("action_taken", "warn")), + details=data.get("details", {}), + ) + + def add_threshold(self, threshold: AdaptationThreshold) -> None: + """Add a new threshold configuration.""" + self.thresholds.append(threshold) + + def remove_threshold(self, metric_type: MetricType) -> None: + """Remove all thresholds for a specific metric type.""" + self.thresholds = [ + t for t in self.thresholds + if t.metric_type != metric_type + ] + + def update_threshold( + self, + metric_type: MetricType, + **kwargs, + ) -> None: + """Update threshold configuration for a metric type.""" + for threshold in self.thresholds: + if threshold.metric_type == metric_type: + for key, value in kwargs.items(): + if hasattr(threshold, key): + setattr(threshold, key, value) + + def get_thresholds(self) -> List[AdaptationThreshold]: + """Get current threshold configurations.""" + return list(self.thresholds) + + def is_in_cooldown(self, skill_name: str, metric_type: MetricType) -> bool: + """Check if a skill/metric combination is in cooldown period.""" + key = f"{skill_name}:{metric_type.value}" + last_trigger = self._cooldowns.get(key) + if not last_trigger: + return False + + # Find the threshold for this metric type + for threshold in self.thresholds: + if threshold.metric_type == metric_type: + elapsed = (datetime.now() - last_trigger).total_seconds() + return elapsed < threshold.cooldown_seconds + + return False + + +class AdaptationManager: + """Manager for coordinating skill adaptation across multiple agents. + + Provides centralized tracking of adaptation events and skill reloads. + """ + + def __init__(self, storage_dir: Path): + """Initialize adaptation manager. + + Args: + storage_dir: Root directory for storing adaptation data + """ + self.storage_dir = Path(storage_dir) + self._hooks: Dict[str, SkillAdaptationHook] = {} + + def get_hook( + self, + run_id: str, + agent_id: str, + thresholds: Optional[List[AdaptationThreshold]] = None, + ) -> SkillAdaptationHook: + """Get or create an adaptation hook for an agent. + + Args: + run_id: Run identifier + agent_id: Agent identifier + thresholds: Optional custom thresholds + + Returns: + SkillAdaptationHook instance + """ + key = f"{run_id}:{agent_id}" + if key not in self._hooks: + self._hooks[key] = SkillAdaptationHook( + storage_dir=self.storage_dir, + run_id=run_id, + agent_id=agent_id, + thresholds=thresholds, + ) + return self._hooks[key] + + def get_all_pending_warnings(self) -> List[AdaptationEvent]: + """Get all pending warnings from all hooks.""" + warnings = [] + for hook in self._hooks.values(): + warnings.extend(hook.get_pending_warnings()) + return warnings + + def get_run_adaptations(self, run_id: str) -> List[AdaptationEvent]: + """Get all adaptation events for a run.""" + events = [] + for hook in self._hooks.values(): + if hook.run_id == run_id: + events.extend(hook.get_recent_events()) + return events + + +# Global manager instance +_adaptation_manager: Optional[AdaptationManager] = None + + +def get_adaptation_manager(storage_dir: Optional[Path] = None) -> AdaptationManager: + """Get global adaptation manager instance. + + Args: + storage_dir: Optional storage directory (required on first call) + + Returns: + AdaptationManager instance + """ + global _adaptation_manager + if _adaptation_manager is None: + if storage_dir is None: + raise ValueError("storage_dir required on first initialization") + _adaptation_manager = AdaptationManager(storage_dir) + return _adaptation_manager + + +__all__ = [ + "AdaptationAction", + "AdaptationThreshold", + "AdaptationEvent", + "SkillAdaptationHook", + "AdaptationManager", + "get_adaptation_manager", +] diff --git a/backend/agents/team/__init__.py b/backend/agents/team/__init__.py new file mode 100644 index 0000000..41da137 --- /dev/null +++ b/backend/agents/team/__init__.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +"""Team module for multi-agent orchestration. + +Provides inter-agent communication, task delegation, and coordination +for subagent spawning and lifecycle management. +""" + +from .messenger import AgentMessenger +from .task_delegator import TaskDelegator +from .team_coordinator import TeamCoordinator +from .registry import AgentRegistry + +__all__ = [ + "AgentMessenger", + "TaskDelegator", + "TeamCoordinator", + "AgentRegistry", +] diff --git a/backend/agents/team/messenger.py b/backend/agents/team/messenger.py new file mode 100644 index 0000000..1a88e66 --- /dev/null +++ b/backend/agents/team/messenger.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +"""AgentMessenger - Pub/sub inter-agent communication. + +Provides broadcast(), send(), and subscribe() for message passing +between agents using AgentScope's Msg format. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Callable, Dict, List, Optional, Set + +from agentscope.message import Msg + +logger = logging.getLogger(__name__) + + +class AgentMessenger: + """Pub/sub messenger for inter-agent communication. + + Supports: + - broadcast(): Send message to all subscribers + - send(): Send message to specific agent + - subscribe(): Register callback for agent messages + - announce(): Send system-wide announcement + - enable_auto_broadcast: Auto-broadcast agent replies to all participants + + Messages use AgentScope's Msg format for compatibility. + """ + + def __init__(self, enable_auto_broadcast: bool = False): + """Initialize the messenger. + + Args: + enable_auto_broadcast: If True, agent replies are automatically + broadcast to all subscribed agents. + """ + self._subscriptions: Dict[str, List[Callable[[Msg], None]]] = {} + self._inbox: Dict[str, List[Msg]] = {} + self._locks: Dict[str, asyncio.Lock] = {} + self._enable_auto_broadcast = enable_auto_broadcast + self._participants: Set[str] = set() + + def subscribe( + self, + agent_id: str, + callback: Callable[[Msg], None], + ) -> None: + """Subscribe an agent to receive messages. + + Args: + agent_id: Target agent identifier + callback: Async function to call when message received + """ + if agent_id not in self._subscriptions: + self._subscriptions[agent_id] = [] + self._subscriptions[agent_id].append(callback) + logger.debug("Agent %s subscribed to messages", agent_id) + + def unsubscribe(self, agent_id: str, callback: Callable[[Msg], None]) -> None: + """Unsubscribe an agent from messages. + + Args: + agent_id: Target agent identifier + callback: Callback to remove + """ + if agent_id in self._subscriptions: + try: + self._subscriptions[agent_id].remove(callback) + logger.debug("Agent %s unsubscribed from messages", agent_id) + except ValueError: + pass + + async def send( + self, + to_agent: str, + message: Msg, + ) -> None: + """Send message to specific agent. + + Args: + to_agent: Target agent identifier + message: Message to send (uses Msg format) + """ + async def _deliver(): + if to_agent in self._subscriptions: + for callback in self._subscriptions[to_agent]: + try: + if asyncio.iscoroutinefunction(callback): + await callback(message) + else: + callback(message) + except Exception as e: + logger.error( + "Error delivering message to %s: %s", + to_agent, + e, + ) + + await _deliver() + + async def broadcast(self, message: Msg) -> None: + """Broadcast message to all subscribed agents. + + Args: + message: Message to broadcast (uses Msg format) + """ + delivery_tasks = [] + for agent_id, callbacks in self._subscriptions.items(): + for callback in callbacks: + async def _deliver(cb=callback, aid=agent_id): + try: + if asyncio.iscoroutinefunction(cb): + await cb(message) + else: + cb(message) + except Exception as e: + logger.error( + "Error broadcasting to %s: %s", + aid, + e, + ) + delivery_tasks.append(_deliver()) + + if delivery_tasks: + await asyncio.gather(*delivery_tasks) + + def inbox(self, agent_id: str) -> List[Msg]: + """Get and clear inbox for agent. + + Args: + agent_id: Agent identifier + + Returns: + List of messages in inbox + """ + messages = self._inbox.get(agent_id, []) + self._inbox[agent_id] = [] + return messages + + def inbox_count(self, agent_id: str) -> int: + """Count messages in agent's inbox without clearing. + + Args: + agent_id: Agent identifier + + Returns: + Number of messages waiting + """ + return len(self._inbox.get(agent_id, [])) + + def add_participant(self, agent_id: str) -> None: + """Add a participant to the messenger. + + Participants are the agents that can receive auto-broadcast messages. + + Args: + agent_id: Agent identifier to add + """ + self._participants.add(agent_id) + logger.debug("Agent %s added as participant", agent_id) + + def remove_participant(self, agent_id: str) -> None: + """Remove a participant from the messenger. + + Args: + agent_id: Agent identifier to remove + """ + self._participants.discard(agent_id) + logger.debug("Agent %s removed from participants", agent_id) + + @property + def enable_auto_broadcast(self) -> bool: + """Check if auto_broadcast is enabled.""" + return self._enable_auto_broadcast + + @enable_auto_broadcast.setter + def enable_auto_broadcast(self, value: bool) -> None: + """Enable or disable auto_broadcast.""" + self._enable_auto_broadcast = value + logger.debug("Auto_broadcast set to %s", value) + + async def announce(self, message: Msg) -> None: + """Send a system-wide announcement to all participants. + + Unlike broadcast(), announce() sends a message from the system/host + to all participants without requiring prior subscription. + + Args: + message: Announcement message (uses Msg format) + """ + logger.info("System announcement: %s", message.content) + await self.broadcast(message) + + async def auto_broadcast(self, message: Msg) -> None: + """Auto-broadcast message to all participants. + + This is called internally when enable_auto_broadcast is True. + Broadcasts to all registered participants. + + Args: + message: Message to auto-broadcast (uses Msg format) + """ + if not self._enable_auto_broadcast: + return + + # Broadcast to all participants + for participant_id in self._participants: + if participant_id in self._subscriptions: + for callback in self._subscriptions[participant_id]: + try: + if asyncio.iscoroutinefunction(callback): + await callback(message) + else: + callback(message) + except Exception as e: + logger.error( + "Error auto-broadcasting to %s: %s", + participant_id, + e, + ) + + +__all__ = ["AgentMessenger"] diff --git a/backend/agents/team/registry.py b/backend/agents/team/registry.py new file mode 100644 index 0000000..1245566 --- /dev/null +++ b/backend/agents/team/registry.py @@ -0,0 +1,188 @@ +# -*- coding: utf-8 -*- +"""AgentRegistry - Agent registration and lookup by role. + +Provides register(), unregister(), and get_by_role() for agent +discovery and management. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +from agentscope.message import Msg + +logger = logging.getLogger(__name__) + + +class AgentRegistry: + """Registry for agent instances with role-based lookup. + + Supports: + - register(): Add agent with roles + - unregister(): Remove agent + - get_by_role(): Find agents by role + - get_by_id(): Get specific agent + + Each agent can have multiple roles for flexible dispatch. + """ + + def __init__(self): + self._agents: Dict[str, Any] = {} + self._roles: Dict[str, List[str]] = {} + self._agent_roles: Dict[str, List[str]] = {} + + def register( + self, + agent_id: str, + agent: Any, + roles: Optional[List[str]] = None, + ) -> None: + """Register an agent with optional roles. + + Args: + agent_id: Unique agent identifier + agent: Agent instance + roles: Optional list of role strings + """ + self._agents[agent_id] = agent + self._agent_roles[agent_id] = roles or [] + + for role in self._agent_roles[agent_id]: + if role not in self._roles: + self._roles[role] = [] + if agent_id not in self._roles[role]: + self._roles[role].append(agent_id) + + logger.info( + "Registered agent %s with roles %s", + agent_id, + self._agent_roles[agent_id], + ) + + def unregister(self, agent_id: str) -> bool: + """Unregister an agent. + + Args: + agent_id: Agent identifier to remove + + Returns: + True if agent was removed + """ + if agent_id not in self._agents: + return False + + roles = self._agent_roles.pop(agent_id, []) + for role in roles: + if role in self._roles: + try: + self._roles[role].remove(agent_id) + except ValueError: + pass + + del self._agents[agent_id] + logger.info("Unregistered agent: %s", agent_id) + return True + + def get_by_id(self, agent_id: str) -> Optional[Any]: + """Get agent by ID. + + Args: + agent_id: Agent identifier + + Returns: + Agent instance or None + """ + return self._agents.get(agent_id) + + def get_by_role(self, role: str) -> List[Any]: + """Get all agents with a given role. + + Args: + role: Role string to search for + + Returns: + List of agent instances with the role + """ + agent_ids = self._roles.get(role, []) + return [self._agents[aid] for aid in agent_ids if aid in self._agents] + + def get_by_roles(self, roles: List[str]) -> List[Any]: + """Get agents matching ANY of the given roles. + + Args: + roles: List of role strings + + Returns: + List of unique agent instances matching any role + """ + seen = set() + result = [] + for role in roles: + for agent in self.get_by_role(role): + if id(agent) not in seen: + seen.add(id(agent)) + result.append(agent) + return result + + def list_agents(self) -> List[str]: + """List all registered agent IDs. + + Returns: + List of agent identifiers + """ + return list(self._agents.keys()) + + def list_roles(self) -> List[str]: + """List all registered roles. + + Returns: + List of role strings + """ + return list(self._roles.keys()) + + def list_roles_for_agent(self, agent_id: str) -> List[str]: + """List roles for specific agent. + + Args: + agent_id: Agent identifier + + Returns: + List of role strings + """ + return list(self._agent_roles.get(agent_id, [])) + + def update_roles(self, agent_id: str, roles: List[str]) -> None: + """Update roles for an existing agent. + + Args: + agent_id: Agent identifier + roles: New list of roles + """ + if agent_id not in self._agents: + raise KeyError(f"Agent not registered: {agent_id}") + + old_roles = self._agent_roles.get(agent_id, []) + for role in old_roles: + if role in self._roles: + try: + self._roles[role].remove(agent_id) + except ValueError: + pass + + self._agent_roles[agent_id] = roles + for role in roles: + if role not in self._roles: + self._roles[role] = [] + if agent_id not in self._roles[role]: + self._roles[role].append(agent_id) + + logger.info("Updated roles for agent %s: %s", agent_id, roles) + + @property + def agents(self) -> Dict[str, Any]: + """Get copy of registered agents dict.""" + return dict(self._agents) + + +__all__ = ["AgentRegistry"] diff --git a/backend/agents/team/task_delegator.py b/backend/agents/team/task_delegator.py new file mode 100644 index 0000000..184c50b --- /dev/null +++ b/backend/agents/team/task_delegator.py @@ -0,0 +1,343 @@ +# -*- coding: utf-8 -*- +"""TaskDelegator - Subagent spawning and task delegation. + +Provides delegate() and delegate_parallel() for spawning subagents +with separate context and memory. Supports runtime dynamic subagent +definition via task_data with description, prompt, and tools. +""" + +from __future__ import annotations + +import asyncio +import logging +import uuid +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union + +from agentscope.message import Msg + +logger = logging.getLogger(__name__) + + +# Type alias for subagent specification +SubagentSpec = Dict[str, Any] +"""Subagent specification format: +{ + "description": "Expert code reviewer...", + "prompt": "Analyze code quality...", + "tools": ["Read", "Glob", "Grep"], # Optional: list of tool names + "model": "gpt-4o", # Optional: model name +} +""" + + +class TaskDelegator: + """Delegates tasks to subagents with isolated context. + + Supports: + - delegate(): Spawn single subagent for task + - delegate_parallel(): Spawn multiple subagents concurrently + - delegate_task(): Delegate with dynamic subagent definition from task_data + + Each subagent gets its own memory/context to prevent + cross-contamination. + + Dynamic Subagent Definition: + task_data can include an "agents" dict to define subagents inline: + + task_data = { + "task": "Review the code changes", + "agents": { + "code-reviewer": { + "description": "Expert code reviewer for quality and security.", + "prompt": "Analyze code quality and suggest improvements.", + "tools": ["Read", "Glob", "Grep"], + } + } + } + """ + + def __init__(self, messenger: Any, registry: Any): + """Initialize TaskDelegator. + + Args: + messenger: AgentMessenger for communication + registry: AgentRegistry for agent lookup + """ + self._messenger = messenger + self._registry = registry + self._subagents: Dict[str, Any] = {} + self._dynamic_subagents: Dict[str, SubagentSpec] = {} + self._tasks: Dict[str, asyncio.Task] = {} + + async def delegate( + self, + agent_id: str, + task: Callable[..., Awaitable[Msg]], + context: Optional[Dict[str, Any]] = None, + ) -> asyncio.Task: + """Delegate task to a single subagent. + + Args: + agent_id: Unique identifier for this subagent instance + task: Async function representing the task + context: Optional context dict for the subagent + + Returns: + asyncio.Task for the delegated task + """ + async def _run_with_context(): + result = await task(context or {}) + return result + + self._tasks[agent_id] = asyncio.create_task(_run_with_context()) + logger.info("Delegated task to subagent: %s", agent_id) + return self._tasks[agent_id] + + async def delegate_parallel( + self, + tasks: List[Dict[str, Any]], + ) -> List[asyncio.Task]: + """Delegate multiple tasks in parallel. + + Args: + tasks: List of task dicts with keys: + - agent_id: Unique identifier + - task: Async function to execute + - context: Optional context dict + + Returns: + List of asyncio.Task for all delegated tasks + """ + async def _run_task(task_def: Dict[str, Any]): + agent_id = task_def["agent_id"] + task_func = task_def["task"] + context = task_def.get("context", {}) + + async def _run_with_context(): + return await task_func(context) + + self._tasks[agent_id] = asyncio.create_task(_run_with_context()) + return self._tasks[agent_id] + + gathered_tasks = await asyncio.gather( + *[_run_task(t) for t in tasks], + return_exceptions=True, + ) + + valid_tasks = [t for t in gathered_tasks if isinstance(t, asyncio.Task)] + logger.info( + "Delegated %d tasks in parallel (%d succeeded)", + len(tasks), + len(valid_tasks), + ) + return valid_tasks + + async def wait_for(self, agent_id: str, timeout: Optional[float] = None) -> Any: + """Wait for subagent task to complete. + + Args: + agent_id: Subagent identifier + timeout: Optional timeout in seconds + + Returns: + Task result + + Raises: + asyncio.TimeoutError: If task doesn't complete in time + KeyError: If agent_id not found + """ + if agent_id not in self._tasks: + raise KeyError(f"Unknown subagent: {agent_id}") + + try: + return await asyncio.wait_for( + self._tasks[agent_id], + timeout=timeout, + ) + except asyncio.TimeoutError: + logger.warning("Task %s timed out after %s seconds", agent_id, timeout) + raise + + async def cancel(self, agent_id: str) -> bool: + """Cancel a subagent task. + + Args: + agent_id: Subagent identifier + + Returns: + True if task was cancelled + """ + if agent_id in self._tasks: + self._tasks[agent_id].cancel() + del self._tasks[agent_id] + logger.info("Cancelled subagent task: %s", agent_id) + return True + return False + + def list_tasks(self) -> List[str]: + """List active subagent task IDs. + + Returns: + List of agent_ids with pending tasks + """ + return list(self._tasks.keys()) + + @property + def tasks(self) -> Dict[str, asyncio.Task]: + """Get copy of active tasks dict.""" + return dict(self._tasks) + + def delegate_task( + self, + task_type: str, + task_data: Dict[str, Any], + target_agent: Optional[str] = None, + ) -> Dict[str, Any]: + """Delegate a task with optional dynamic subagent definition. + + Supports runtime subagent definition via task_data["agents"]: + + task_data = { + "task": "Review code changes", + "agents": { + "code-reviewer": { + "description": "Expert code reviewer...", + "prompt": "Analyze code quality...", + "tools": ["Read", "Glob", "Grep"], + } + } + } + + Args: + task_type: Type of task (e.g., "analysis", "review", "research") + task_data: Task payload, may include "agents" for dynamic subagent def + target_agent: Optional specific agent ID to delegate to + + Returns: + Dict with "success" and result/error + """ + try: + # Extract dynamic subagent definitions from task_data + agents_def = task_data.get("agents", {}) + + if agents_def: + # Register dynamic subagents + for agent_name, agent_spec in agents_def.items(): + self._dynamic_subagents[agent_name] = agent_spec + logger.info( + "Registered dynamic subagent: %s (description: %s)", + agent_name, + agent_spec.get("description", "")[:50], + ) + + # Determine target agent + effective_target = target_agent + if not effective_target: + # Use first available dynamic subagent or default + if agents_def: + effective_target = next(iter(agents_def.keys())) + else: + effective_target = "default" + + # Execute the task + task_result = self._execute_task( + task_type=task_type, + task_data=task_data, + target_agent=effective_target, + ) + + # Clean up dynamic subagents after execution + for agent_name in agents_def.keys(): + self._dynamic_subagents.pop(agent_name, None) + + return { + "success": True, + "result": task_result, + "subagents_used": list(agents_def.keys()) if agents_def else [], + } + + except Exception as e: + logger.error("Task delegation failed: %s", e) + return { + "success": False, + "error": str(e), + } + + def _execute_task( + self, + task_type: str, + task_data: Dict[str, Any], + target_agent: str, + ) -> Any: + """Execute the delegated task. + + Args: + task_type: Type of task + task_data: Task payload + target_agent: Target agent identifier + + Returns: + Task execution result + """ + task_content = task_data.get("task", task_data.get("prompt", "")) + + # Check if we have a dynamic subagent spec for this target + agent_spec = self._dynamic_subagents.get(target_agent) + + if agent_spec: + logger.info( + "Executing task '%s' with dynamic subagent '%s' (prompt: %s)", + task_type, + target_agent, + agent_spec.get("prompt", "")[:50], + ) + # In a full implementation, this would create and run an actual agent + # For now, return a structured result indicating the task was received + return { + "task_type": task_type, + "task": task_content, + "subagent": { + "name": target_agent, + "description": agent_spec.get("description", ""), + "prompt": agent_spec.get("prompt", ""), + "tools": agent_spec.get("tools", []), + }, + "status": "completed", + "message": f"Task '{task_type}' executed with dynamic subagent '{target_agent}'", + } + + # Fallback: execute with default behavior + logger.info( + "Executing task '%s' with default agent '%s'", + task_type, + target_agent, + ) + return { + "task_type": task_type, + "task": task_content, + "target_agent": target_agent, + "status": "completed", + "message": f"Task '{task_type}' executed with agent '{target_agent}'", + } + + def get_dynamic_subagent(self, name: str) -> Optional[SubagentSpec]: + """Get a dynamically defined subagent specification. + + Args: + name: Subagent name + + Returns: + Subagent spec dict or None if not found + """ + return self._dynamic_subagents.get(name) + + def list_dynamic_subagents(self) -> List[str]: + """List all registered dynamic subagent names. + + Returns: + List of subagent names + """ + return list(self._dynamic_subagents.keys()) + + +__all__ = ["TaskDelegator", "SubagentSpec"] diff --git a/backend/agents/team/team_coordinator.py b/backend/agents/team/team_coordinator.py new file mode 100644 index 0000000..3319f44 --- /dev/null +++ b/backend/agents/team/team_coordinator.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- +"""TeamCoordinator - Agent lifecycle management and execution. + +Provides run_parallel() using asyncio.gather() and run_sequential() +for coordinating multiple agents. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Awaitable, Callable, Dict, List, Optional, Type + +from agentscope.message import Msg + +logger = logging.getLogger(__name__) + + +class TeamCoordinator: + """Coordinates agent lifecycle and execution. + + Supports: + - run_parallel(): Execute multiple agents concurrently with asyncio.gather() + - run_sequential(): Execute agents one after another + - run_phase(): Execute a named phase with registered agents + - register_agent(): Add agent to coordinator + - unregister_agent(): Remove agent from coordinator + + Each agent maintains separate context/memory. + """ + + def __init__( + self, + participants: Optional[List[Any]] = None, + task_content: Optional[str] = None, + messenger: Optional[Any] = None, + registry: Optional[Any] = None, + ): + """Initialize TeamCoordinator. + + Args: + participants: List of agent instances to coordinate + task_content: Task description content for the agents + messenger: AgentMessenger for communication (optional) + registry: AgentRegistry for agent lookup (optional) + """ + self._participants = participants or [] + self._task_content = task_content or "" + self._messenger = messenger + self._registry = registry + self._agents: Dict[str, Any] = {} + self._running_tasks: Dict[str, asyncio.Task] = {} + # Auto-register participants + for agent in self._participants: + if hasattr(agent, "name"): + self._agents[agent.name] = agent + elif hasattr(agent, "id"): + self._agents[agent.id] = agent + + def register_agent(self, agent_id: str, agent: Any) -> None: + """Register an agent with the coordinator. + + Args: + agent_id: Unique agent identifier + agent: Agent instance + """ + self._agents[agent_id] = agent + logger.info("Registered agent: %s", agent_id) + + def unregister_agent(self, agent_id: str) -> None: + """Unregister an agent from the coordinator. + + Args: + agent_id: Agent identifier to remove + """ + if agent_id in self._agents: + del self._agents[agent_id] + logger.info("Unregistered agent: %s", agent_id) + + def get_agent(self, agent_id: str) -> Any: + """Get registered agent by ID. + + Args: + agent_id: Agent identifier + + Returns: + Agent instance + """ + return self._agents.get(agent_id) + + def list_agents(self) -> List[str]: + """List all registered agent IDs. + + Returns: + List of agent identifiers + """ + return list(self._agents.keys()) + + async def run_parallel( + self, + agent_ids: List[str], + initial_message: Optional[Msg] = None, + ) -> Dict[str, Any]: + """Run multiple agents in parallel using asyncio.gather(). + + Args: + agent_ids: List of agent IDs to run concurrently + initial_message: Optional initial message to broadcast + + Returns: + Dict mapping agent_id to result + """ + async def _run_agent(aid: str) -> tuple[str, Any]: + agent = self._agents.get(aid) + if agent is None: + logger.error("Agent %s not found", aid) + return (aid, None) + + try: + if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply): + if initial_message: + result = await agent.reply(initial_message) + else: + result = await agent.reply() + elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run): + result = await agent.run() + else: + result = await agent() + logger.info("Agent %s completed successfully", aid) + return (aid, result) + except Exception as e: + logger.error("Agent %s failed: %s", aid, e) + return (aid, {"error": str(e)}) + + results = await asyncio.gather( + *[_run_agent(aid) for aid in agent_ids], + return_exceptions=True, + ) + + output: Dict[str, Any] = {} + for result in results: + if isinstance(result, tuple): + agent_id, agent_result = result + output[agent_id] = agent_result + else: + logger.error("Unexpected result from asyncio.gather: %s", result) + + logger.info("Parallel run completed for %d agents", len(agent_ids)) + return output + + async def run_sequential( + self, + agent_ids: List[str], + initial_message: Optional[Msg] = None, + ) -> Dict[str, Any]: + """Run agents one after another in order. + + Args: + agent_ids: List of agent IDs to run in sequence + initial_message: Optional initial message for first agent + + Returns: + Dict mapping agent_id to result + """ + output: Dict[str, Any] = {} + current_message = initial_message + + for agent_id in agent_ids: + agent = self._agents.get(agent_id) + if agent is None: + logger.error("Agent %s not found", agent_id) + output[agent_id] = {"error": "Agent not found"} + continue + + try: + if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply): + result = await agent.reply(current_message) + elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run): + result = await agent.run() + else: + result = await agent() + + output[agent_id] = result + current_message = result + logger.info("Agent %s completed sequentially", agent_id) + + except Exception as e: + logger.error("Agent %s failed: %s", agent_id, e) + output[agent_id] = {"error": str(e)} + break + + logger.info("Sequential run completed for %d agents", len(agent_ids)) + return output + + async def run_phase( + self, + phase_name: str, + agent_ids: Optional[List[str]] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> List[Any]: + """Execute a named phase with registered agents. + + Args: + phase_name: Name of the phase (e.g., "analyst_analysis") + agent_ids: Optional list of agent IDs; if None, uses all registered + metadata: Optional metadata to include in the message (e.g., tickers, date) + + Returns: + List of results from each agent + """ + if agent_ids is None: + agent_ids = list(self._agents.keys()) + + _agent_ids = [aid for aid in agent_ids if aid in self._agents] + + logger.info( + "Running phase '%s' with %d agents: %s", + phase_name, + len(_agent_ids), + _agent_ids, + ) + + # Create messages for each agent + results: List[Any] = [] + for agent_id in _agent_ids: + agent = self._agents[agent_id] + try: + if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply): + # Create a message for the agent with proper structure + msg = Msg( + name="system", + content=self._task_content or f"Please execute phase: {phase_name}", + role="user", + metadata=metadata, + ) + result = await agent.reply(msg) + elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run): + result = await agent.run() + else: + result = await agent() + results.append(result) + logger.info("Phase '%s': Agent %s completed", phase_name, agent_id) + except Exception as e: + logger.error("Phase '%s': Agent %s failed: %s", phase_name, agent_id, e) + results.append(None) + + logger.info("Phase '%s' completed with %d results", phase_name, len(results)) + return results + + async def run_with_dependencies( + self, + agent_tasks: Dict[str, List[str]], + initial_message: Optional[Msg] = None, + ) -> Dict[str, Any]: + """Run agents respecting dependency graph. + + Args: + agent_tasks: Dict mapping agent_id to list of prerequisite agent_ids + initial_message: Optional initial message + + Returns: + Dict mapping agent_id to result + """ + completed: Dict[str, Any] = {} + remaining = set(agent_tasks.keys()) + + while remaining: + ready = [ + aid for aid in remaining + if all(dep in completed for dep in agent_tasks.get(aid, [])) + ] + + if not ready: + logger.error("Circular dependency detected in agent tasks") + for aid in remaining: + completed[aid] = {"error": "Circular dependency"} + break + + results = await self.run_parallel(ready, initial_message) + completed.update(results) + + for aid in ready: + remaining.discard(aid) + initial_message = results.get(aid) + + return completed + + async def fanout_pipeline( + self, + agents: List[Any], + msg: Optional[Msg] = None, + ) -> List[Msg]: + """Fanout a message to multiple agents concurrently and collect all responses. + + Similar to AgentScope's fanout_pipeline, this sends the same message + to all specified agents and returns a list of all agent responses. + + Args: + agents: List of agent instances to fanout the message to + msg: Message to send to all agents (optional) + + Returns: + List of Msg responses from each agent (in the same order as input agents) + + Example: + >>> responses = await fanout_pipeline( + ... agents=[alice, bob, charlie], + ... msg=question, + ... ) + >>> # responses is a list of Msg responses from each agent + """ + async def _fanout_to_agent(agent: Any) -> Optional[Msg]: + """Send message to a single agent and return its response.""" + try: + if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply): + result = await agent.reply(msg) if msg is not None else await agent.reply() + elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run): + result = await agent.run() + else: + result = await agent() + + # Convert result to Msg if needed + if result is None: + return None + if isinstance(result, Msg): + return result + # If result is a dict with content, wrap it + if isinstance(result, dict) and "content" in result: + return Msg( + name=getattr(agent, "name", "unknown"), + content=result.get("content", ""), + role="assistant", + metadata=result.get("metadata"), + ) + # Otherwise wrap the result + return Msg( + name=getattr(agent, "name", "unknown"), + content=str(result), + role="assistant", + ) + except Exception as e: + logger.error("Agent %s failed in fanout_pipeline: %s", + getattr(agent, "name", "unknown"), e) + return None + + # Run all agents concurrently + results = await asyncio.gather( + *[_fanout_to_agent(agent) for agent in agents], + return_exceptions=True, + ) + + # Filter out exceptions and keep only valid responses + responses: List[Msg] = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error("Fanout to agent %d failed: %s", i, result) + responses.append(None) # type: ignore[arg-type] + else: + responses.append(result) # type: ignore[arg-type] + + logger.info("Fanout pipeline completed for %d agents", len(agents)) + return responses + + async def shutdown(self, timeout: Optional[float] = 5.0) -> None: + """Shutdown all running agents gracefully. + + Args: + timeout: Timeout for graceful shutdown + """ + logger.info("Shutting down TeamCoordinator...") + + cancel_tasks = [ + asyncio.create_task(asyncio.wait_for(task, timeout=timeout)) + for task in self._running_tasks.values() + ] + + if cancel_tasks: + await asyncio.gather(*cancel_tasks, return_exceptions=True) + + self._running_tasks.clear() + logger.info("TeamCoordinator shutdown complete") + + @property + def agents(self) -> Dict[str, Any]: + """Get copy of registered agents dict.""" + return dict(self._agents) + + +__all__ = ["TeamCoordinator"] diff --git a/backend/agents/team_pipeline_config.py b/backend/agents/team_pipeline_config.py new file mode 100644 index 0000000..e427973 --- /dev/null +++ b/backend/agents/team_pipeline_config.py @@ -0,0 +1,132 @@ +# -*- coding: utf-8 -*- +"""Run-scoped team pipeline configuration helpers.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Iterable, List, Dict, Any + +import yaml + + +DEFAULT_FILENAME = "TEAM_PIPELINE.yaml" + + +def team_pipeline_path(project_root: Path, config_name: str) -> Path: + """Return run-scoped team pipeline config path.""" + return project_root / "runs" / config_name / DEFAULT_FILENAME + + +def ensure_team_pipeline_config( + project_root: Path, + config_name: str, + default_analysts: Iterable[str], +) -> Path: + """Ensure TEAM_PIPELINE.yaml exists for one run.""" + path = team_pipeline_path(project_root, config_name) + path.parent.mkdir(parents=True, exist_ok=True) + if path.exists(): + return path + + payload = { + "version": 1, + "controller_agent": "portfolio_manager", + "discussion": { + "allow_dynamic_team_update": True, + "active_analysts": list(default_analysts), + }, + "decision": { + "require_risk_manager": True, + }, + } + path.write_text( + yaml.safe_dump(payload, allow_unicode=True, sort_keys=False), + encoding="utf-8", + ) + return path + + +def load_team_pipeline_config(project_root: Path, config_name: str) -> Dict[str, Any]: + """Load TEAM_PIPELINE.yaml and return parsed dict.""" + path = team_pipeline_path(project_root, config_name) + if not path.exists(): + return {} + parsed = yaml.safe_load(path.read_text(encoding="utf-8")) or {} + return parsed if isinstance(parsed, dict) else {} + + +def save_team_pipeline_config( + project_root: Path, + config_name: str, + config: Dict[str, Any], +) -> Path: + """Persist TEAM_PIPELINE.yaml.""" + path = team_pipeline_path(project_root, config_name) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + yaml.safe_dump(config, allow_unicode=True, sort_keys=False), + encoding="utf-8", + ) + return path + + +def resolve_active_analysts( + project_root: Path, + config_name: str, + available_analysts: Iterable[str], +) -> List[str]: + """Resolve active analysts from TEAM_PIPELINE.yaml.""" + available = [item for item in available_analysts] + parsed = load_team_pipeline_config(project_root, config_name) + discussion = parsed.get("discussion", {}) if isinstance(parsed, dict) else {} + configured = discussion.get("active_analysts", []) + if not isinstance(configured, list) or not configured: + return available + + active = [item for item in configured if item in available] + return active or available + + +def update_active_analysts( + project_root: Path, + config_name: str, + available_analysts: Iterable[str], + *, + add: Iterable[str] | None = None, + remove: Iterable[str] | None = None, + set_to: Iterable[str] | None = None, +) -> List[str]: + """Update active analysts and persist TEAM_PIPELINE.yaml.""" + available = [item for item in available_analysts] + ensure_team_pipeline_config(project_root, config_name, available) + parsed = load_team_pipeline_config(project_root, config_name) + discussion = parsed.setdefault("discussion", {}) + if not isinstance(discussion, dict): + discussion = {} + parsed["discussion"] = discussion + + current = discussion.get("active_analysts", []) + if not isinstance(current, list): + current = [] + current = [item for item in current if item in available] + if not current: + current = list(available) + + if set_to is not None: + target = [item for item in set_to if item in available] + current = target or current + + for item in add or []: + if item in available and item not in current: + current.append(item) + + for item in remove or []: + current = [existing for existing in current if existing != item] + + if not current: + current = [available[0]] if available else [] + + discussion["active_analysts"] = current + save_team_pipeline_config(project_root, config_name, parsed) + return current + diff --git a/backend/gateway_server.py b/backend/gateway_server.py new file mode 100644 index 0000000..3a4ee16 --- /dev/null +++ b/backend/gateway_server.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +"""Gateway Server - Entry point for Gateway subprocess. + +This module is launched as a subprocess by the Control Plane (FastAPI) +to run the Data Plane (Gateway + Pipeline). +""" + +import argparse +import asyncio +import json +import logging +import os +import sys +from contextlib import AsyncExitStack +from pathlib import Path + +from dotenv import load_dotenv + +# Load environment variables +load_dotenv() + +from backend.agents import AnalystAgent, PMAgent, RiskAgent +from backend.agents.skills_manager import SkillsManager +from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles +from backend.agents.prompt_loader import PromptLoader +from backend.agents.workspace_manager import WorkspaceManager +from backend.config.constants import ANALYST_TYPES +from backend.core.pipeline import TradingPipeline +from backend.core.pipeline_runner import create_agents, create_long_term_memory +from backend.core.scheduler import BacktestScheduler, Scheduler +from backend.llm.models import get_agent_formatter, get_agent_model +from backend.runtime.manager import ( + TradingRuntimeManager, + set_global_runtime_manager, + clear_global_runtime_manager, +) +from backend.services.gateway import Gateway +from backend.services.market import MarketService +from backend.services.storage import StorageService +from backend.utils.settlement import SettlementCoordinator + +logger = logging.getLogger(__name__) +_prompt_loader = PromptLoader() + + +async def run_gateway( + run_id: str, + run_dir: Path, + bootstrap: dict, + port: int +): + """Run Gateway with Pipeline.""" + + # Extract config + tickers = bootstrap.get("tickers", ["AAPL", "MSFT"]) + initial_cash = float(bootstrap.get("initial_cash", 100000.0)) + margin_requirement = float(bootstrap.get("margin_requirement", 0.0)) + max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2)) + schedule_mode = bootstrap.get("schedule_mode", "daily") + trigger_time = bootstrap.get("trigger_time", "09:30") + interval_minutes = int(bootstrap.get("interval_minutes", 60)) + heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0)) # 0 = disabled + mode = bootstrap.get("mode", "live") + start_date = bootstrap.get("start_date") + end_date = bootstrap.get("end_date") + enable_memory = bootstrap.get("enable_memory", False) + poll_interval = int(bootstrap.get("poll_interval", 10)) + enable_mock = bootstrap.get("enable_mock", False) + + is_backtest = mode == "backtest" + is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true") + + logger.info(f"[Gateway Server] Starting run {run_id} on port {port}") + + # Create runtime manager + runtime_manager = TradingRuntimeManager( + config_name=run_id, + run_dir=run_dir, + bootstrap=bootstrap, + ) + runtime_manager.prepare_run() + set_global_runtime_manager(runtime_manager) + + try: + async with AsyncExitStack() as stack: + # Create services + market_service = MarketService( + tickers=tickers, + poll_interval=poll_interval, + mock_mode=is_mock and not is_backtest, + backtest_mode=is_backtest, + api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None, + backtest_start_date=start_date if is_backtest else None, + backtest_end_date=end_date if is_backtest else None, + ) + + storage_service = StorageService( + dashboard_dir=run_dir / "team_dashboard", + initial_cash=initial_cash, + config_name=run_id, + ) + + if not storage_service.files["summary"].exists(): + storage_service.initialize_empty_dashboard() + else: + storage_service.update_leaderboard_model_info() + + # Create agents + analysts, risk_manager, pm, long_term_memories = create_agents( + run_id=run_id, + run_dir=run_dir, + initial_cash=initial_cash, + margin_requirement=margin_requirement, + enable_long_term_memory=enable_memory, + ) + + # Register agents + for agent in analysts + [risk_manager, pm]: + agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None) + if agent_id: + runtime_manager.register_agent(agent_id) + + # Load portfolio state + portfolio_state = storage_service.load_portfolio_state() + pm.load_portfolio_state(portfolio_state) + + # Create settlement coordinator + settlement_coordinator = SettlementCoordinator( + storage=storage_service, + initial_capital=initial_cash, + ) + + # Create pipeline + pipeline = TradingPipeline( + analysts=analysts, + risk_manager=risk_manager, + portfolio_manager=pm, + settlement_coordinator=settlement_coordinator, + max_comm_cycles=max_comm_cycles, + runtime_manager=runtime_manager, + ) + + # Create scheduler + scheduler_callback = None + live_scheduler = None + + if is_backtest: + backtest_scheduler = BacktestScheduler( + start_date=start_date, + end_date=end_date, + trading_calendar="NYSE", + delay_between_days=0.5, + ) + + async def scheduler_callback_fn(callback): + await backtest_scheduler.start(callback) + + scheduler_callback = scheduler_callback_fn + else: + live_scheduler = Scheduler( + mode=schedule_mode, + trigger_time=trigger_time, + interval_minutes=interval_minutes, + heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None, + config={"config_name": run_id}, + ) + + async def scheduler_callback_fn(callback): + await live_scheduler.start(callback) + + scheduler_callback = scheduler_callback_fn + + # Enter long-term memory contexts + for memory in long_term_memories: + await stack.enter_async_context(memory) + + # Create Gateway + gateway = Gateway( + market_service=market_service, + storage_service=storage_service, + pipeline=pipeline, + scheduler_callback=scheduler_callback, + config={ + "mode": mode, + "mock_mode": is_mock, + "backtest_mode": is_backtest, + "tickers": tickers, + "config_name": run_id, + "schedule_mode": schedule_mode, + "interval_minutes": interval_minutes, + "trigger_time": trigger_time, + "heartbeat_interval": heartbeat_interval, + "initial_cash": initial_cash, + "margin_requirement": margin_requirement, + "max_comm_cycles": max_comm_cycles, + "enable_memory": enable_memory, + }, + scheduler=live_scheduler, + ) + + # Start Gateway (blocks until shutdown) + logger.info(f"[Gateway Server] Gateway starting on port {port}") + await gateway.start(host="0.0.0.0", port=port) + + except asyncio.CancelledError: + logger.info("[Gateway Server] Cancelled") + raise + finally: + logger.info("[Gateway Server] Cleaning up") + clear_global_runtime_manager() + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser(description="Gateway Server") + parser.add_argument("--run-id", required=True, help="Run identifier") + parser.add_argument("--run-dir", required=True, help="Run directory path") + parser.add_argument("--port", type=int, default=8765, help="WebSocket port") + parser.add_argument("--bootstrap", required=True, help="Bootstrap config as JSON") + parser.add_argument("--verbose", action="store_true", help="Verbose logging") + + args = parser.parse_args() + + # Setup logging + level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s | %(levelname)-7s | %(name)s:%(lineno)d - %(message)s", + ) + + # Parse bootstrap + bootstrap = json.loads(args.bootstrap) + run_dir = Path(args.run_dir) + + # Run + try: + asyncio.run(run_gateway( + run_id=args.run_id, + run_dir=run_dir, + bootstrap=bootstrap, + port=args.port + )) + except KeyboardInterrupt: + logger.info("[Gateway Server] Interrupted by user") + except Exception as e: + logger.exception(f"[Gateway Server] Fatal error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/backend/skills/SKILL_TEMPLATE.md b/backend/skills/SKILL_TEMPLATE.md new file mode 100644 index 0000000..d4ed162 --- /dev/null +++ b/backend/skills/SKILL_TEMPLATE.md @@ -0,0 +1,119 @@ +# Skill Template (Anthropic + AgentScope Aligned) + +> 用于定义可执行、可路由、可评估的技能规范。 +> 建议所有 `SKILL.md` 至少覆盖以下 6 个部分。 + +--- + +## Frontmatter Spec + +All `SKILL.md` files should begin with a YAML frontmatter block: + +```yaml +--- +name: skill_name # Required. Unique identifier for the skill. +description: ... # Required. One-line description of the skill. +version: "1.0.0" # Optional. Semantic version string. +tools: [...] # Optional. Tools provided or used by this skill. +allowed_tools: [...] # Optional. List of tool names permitted when this skill is active. +denied_tools: [...] # Optional. List of tool names denied when this skill is active. +--- +``` + +### Frontmatter Fields + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Unique skill identifier (kebab-case recommended). | +| `description` | string | Human-readable one-line description. | +| `version` | string | Semantic version (e.g., `"1.0.0"`). | +| `tools` | list[string] | Tools provided by or associated with this skill. | +| `allowed_tools` | list[string] | Enumerates which tools are **permitted** when this skill is active. If set, only these tools may be used. | +| `denied_tools` | list[string] | Enumerates which tools are **forbidden** when this skill is active. Denied tools take precedence over `allowed_tools`. | + +### Tool Restriction Rules + +- If **only** `allowed_tools` is set: only those tools are accessible. +- If **only** `denied_tools` is set: all tools except those are accessible. +- If **both** are set: `allowed_tools` defines the initial set, then `denied_tools` removes from it. +- **Denial takes precedence**: a tool in `denied_tools` is always blocked even if also in `allowed_tools`. + +--- + +## 1) When to use + +- 明确触发条件(任务类型、关键词、场景)。 +- 明确不应使用该技能的边界(避免误触发)。 + +## 2) Required inputs + +- 列出最小必要输入(如 `tickers`、价格、组合状态、风险约束)。 +- 声明输入缺失时的处理规则(终止 / 降级 / 请求补充)。 + +## 3) Decision procedure + +- 采用固定步骤,确保可复现。 +- 每一步说明目标、判据和产物(例如中间结论)。 +- 标明冲突处理逻辑(信号冲突、数据冲突、置信度冲突)。 + +## 4) Tool call policy + +- 说明优先使用哪些工具组与工具。 +- 规定何时可以“无工具直接结论”,何时必须工具先证据后结论。 +- 规定工具失败、超时、返回异常时的替代动作。 + +## 5) Output schema + +- 定义标准输出字段,便于下游 Agent 消费与评估。 +- 推荐包含:`signal`、`confidence`、`reasons`、`risks`、`invalidation`、`next_action`。 +- 若是组合决策技能,必须包含每个 ticker 的 `action` 与 `quantity`。 + +## 6) Failure fallback + +- 规定在数据不足、信号冲突、风险超限、工具不可用时的降级策略。 +- 默认优先“保守 + 可解释 + 可执行”的输出。 + +## Optional: Evaluation hooks + +定义技能的可评估指标,用于后续记忆/反思阶段写入长期经验。 + +### 支持的指标类型 + +| 指标类型 | 描述 | 适用技能 | +|---------|------|---------| +| `hit_rate` | 信号命中率 - 决策信号与实际结果的符合程度 | sentiment_review, technical_review | +| `risk_violation` | 风控违例率 - 触发风控规则的次数 | risk_review, portfolio_decisioning | +| `position_deviation` | 仓位偏离率 - 建议仓位与实际执行仓位的偏差 | portfolio_decisioning | +| `pnl_attribution` | P&L 归因一致性 - 收益归因与实际收益的匹配度 | fundamental_review, valuation_review | +| `signal_consistency` | 信号一致性 - 多来源信号的一致程度 | sentiment_review | +| `decision_latency` | 决策延迟 - 从输入到决策的耗时 | portfolio_decisioning | +| `tool_usage` | 工具使用率 - 工具调用次数与成功率的比值 | 所有技能 | +| `custom` | 自定义指标 | 特定业务场景 | + +### 使用方式 + +```python +from backend.agents.base.evaluation_hook import EvaluationHook, MetricType + +# 在技能执行开始时 +evaluation_hook.start_evaluation( + skill_name="technical_review", + inputs={"tickers": ["AAPL"], "prices": {...}} +) + +# 在技能执行过程中添加指标 +evaluation_hook.add_metric( + name="signal_confidence", + metric_type=MetricType.HIT_RATE, + value=0.85, + metadata={"method": "rsi", "threshold": 30} +) + +# 在技能完成时记录结果 +evaluation_hook.record_outputs({"signal": "buy", "confidence": 0.8}) +evaluation_hook.complete_evaluation(success=True) +``` + +### 评估结果存储 + +评估结果自动保存到 `runs/{run_id}/evaluations/{agent_id}/{skill_name}_{timestamp}.json` diff --git a/backend/tests/test_heartbeat_hook.py b/backend/tests/test_heartbeat_hook.py new file mode 100644 index 0000000..e2927c7 --- /dev/null +++ b/backend/tests/test_heartbeat_hook.py @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- +"""Tests for HeartbeatHook.""" +import tempfile +from pathlib import Path + +import pytest + +from backend.agents.base.hooks import HeartbeatHook + + +class TestHeartbeatHook: + """Tests for HeartbeatHook._read_heartbeat_content.""" + + def test_read_heartbeat_content_with_content(self, tmp_path): + """Test reading HEARTBEAT.md when it exists and has content.""" + ws_dir = tmp_path / "analyst_workspace" + ws_dir.mkdir() + hb_file = ws_dir / "HEARTBEAT.md" + hb_file.write_text("# 定期主动检查\n\n- [ ] 持仓是否健康\n", encoding="utf-8") + + hook = HeartbeatHook(workspace_dir=ws_dir) + content = hook._read_heartbeat_content() + + assert content is not None + assert "# 定期主动检查" in content + assert "持仓是否健康" in content + + def test_read_heartbeat_content_absent(self, tmp_path): + """Test reading when HEARTBEAT.md does not exist.""" + ws_dir = tmp_path / "analyst_workspace" + ws_dir.mkdir() + + hook = HeartbeatHook(workspace_dir=ws_dir) + content = hook._read_heartbeat_content() + + assert content is None + + def test_read_heartbeat_content_empty(self, tmp_path): + """Test reading when HEARTBEAT.md is empty.""" + ws_dir = tmp_path / "analyst_workspace" + ws_dir.mkdir() + hb_file = ws_dir / "HEARTBEAT.md" + hb_file.write_text("", encoding="utf-8") + + hook = HeartbeatHook(workspace_dir=ws_dir) + content = hook._read_heartbeat_content() + + assert content is None + + def test_read_heartbeat_content_whitespace_only(self, tmp_path): + """Test reading when HEARTBEAT.md contains only whitespace.""" + ws_dir = tmp_path / "analyst_workspace" + ws_dir.mkdir() + hb_file = ws_dir / "HEARTBEAT.md" + hb_file.write_text(" \n\n ", encoding="utf-8") + + hook = HeartbeatHook(workspace_dir=ws_dir) + content = hook._read_heartbeat_content() + + assert content is None + + def test_completed_flag_path(self, tmp_path): + """Test that completion flag is placed in workspace directory.""" + ws_dir = tmp_path / "analyst_workspace" + ws_dir.mkdir() + + hook = HeartbeatHook(workspace_dir=ws_dir) + + assert hook._completed_flag == ws_dir / ".heartbeat_completed" diff --git a/frontend/src/components/explain/ExplainInsiderSection.jsx b/frontend/src/components/explain/ExplainInsiderSection.jsx new file mode 100644 index 0000000..4881b7e --- /dev/null +++ b/frontend/src/components/explain/ExplainInsiderSection.jsx @@ -0,0 +1,107 @@ +import React from 'react'; +import { formatDateTime, formatNumber } from '../../utils/formatters'; + +export default function ExplainInsiderSection({ + insiderTrades, + selectedSymbol, + isOpen, + onToggle, + onRequest, +}) { + const handleRefresh = () => { + if (onRequest) { + onRequest(selectedSymbol); + } + }; + + return ( +
+
+

内部人交易

+
+
+ {insiderTrades.length} 笔内部人交易记录 +
+ + +
+
+ + {!isOpen ? ( +
点击展开查看内部人交易详情
+ ) : insiderTrades.length === 0 ? ( +
暂无内部人交易数据
+ ) : ( +
+ + + + + + + + + + + + + + {insiderTrades.slice(0, 20).map((trade, index) => { + const isBuy = trade.is_buy; + const holdingChange = trade.holding_change; + return ( + + + + + + + + + + ); + })} + +
交易日期内部人职位方向股份数价格持仓变化
{trade.transaction_date || '-'}{trade.name || '-'}{trade.title || '-'} + {isBuy === true ? '买入' : isBuy === false ? '卖出' : '-'} + {trade.transaction_shares != null ? formatNumber(trade.transaction_shares) : '-'}${trade.transaction_price_per_share != null ? Number(trade.transaction_price_per_share).toFixed(2) : '-'} 0 ? '#00C853' : '#FF1744') : '#666666', + fontWeight: holdingChange != null ? 700 : 400 + }}> + {holdingChange != null ? (holdingChange > 0 ? '+' : '') + formatNumber(holdingChange) : '-'} +
+
+ )} +
+ ); +} diff --git a/frontend/src/components/explain/ExplainTechnicalSection.jsx b/frontend/src/components/explain/ExplainTechnicalSection.jsx new file mode 100644 index 0000000..6a88804 --- /dev/null +++ b/frontend/src/components/explain/ExplainTechnicalSection.jsx @@ -0,0 +1,309 @@ +import React from 'react'; +import { formatNumber } from '../../utils/formatters'; + +export default function ExplainTechnicalSection({ + technicalIndicators, + selectedSymbol, + isOpen, + onToggle, +}) { + const formatPct = (value) => { + if (value == null) return '-'; + return `${value >= 0 ? '+' : ''}${value.toFixed(2)}%`; + }; + + const formatPrice = (value) => { + if (value == null) return '-'; + return `$${value.toFixed(2)}`; + }; + + const rsiStatusColor = (status) => { + if (status === 'oversold') return '#00C853'; + if (status === 'overbought') return '#FF1744'; + return '#666666'; + }; + + const riskColor = (level) => { + if (level === 'HIGH RISK') return '#FF1744'; + if (level === 'MODERATE RISK') return '#FF9800'; + return '#00C853'; + }; + + if (!technicalIndicators) { + return ( +
+
+

技术指标

+
+
+ 加载中... +
+ +
+
+ {isOpen && ( +
正在加载技术指标数据...
+ )} +
+ ); + } + + return ( +
+
+

技术指标

+
+
+ {technicalIndicators.trend} · {technicalIndicators.mean_reversion} +
+ +
+
+ + {!isOpen ? ( +
点击展开查看技术指标详情
+ ) : ( +
+ {/* MA Section */} +
+
+ 移动平均线 +
+
+
+ MA5 + {formatPrice(technicalIndicators.ma?.ma5)} + 0 ? '#00C853' : '#FF1744', fontWeight: 700 }}> + {formatPct(technicalIndicators.ma?.distance?.ma5)} + +
+
+ MA10 + {formatPrice(technicalIndicators.ma?.ma10)} + 0 ? '#00C853' : '#FF1744', fontWeight: 700 }}> + {formatPct(technicalIndicators.ma?.distance?.ma10)} + +
+
+ MA20 + {formatPrice(technicalIndicators.ma?.ma20)} + 0 ? '#00C853' : '#FF1744', fontWeight: 700 }}> + {formatPct(technicalIndicators.ma?.distance?.ma20)} + +
+
+ MA50 + {formatPrice(technicalIndicators.ma?.ma50)} + 0 ? '#00C853' : '#FF1744', fontWeight: 700 }}> + {formatPct(technicalIndicators.ma?.distance?.ma50)} + +
+
+ MA200 + {formatPrice(technicalIndicators.ma?.ma200)} + 0 ? '#00C853' : '#FF1744', fontWeight: 700 }}> + {formatPct(technicalIndicators.ma?.distance?.ma200)} + +
+
+
+ + {/* RSI Section */} +
+
+ RSI (14) +
+
+
+ {technicalIndicators.rsi?.rsi14?.toFixed(1) || '-'} +
+
+
+ {technicalIndicators.rsi?.status === 'oversold' ? '超卖' : + technicalIndicators.rsi?.status === 'overbought' ? '超买' : '中性'} +
+
+ <30 超卖 >70 超买 +
+
+
+ {/* RSI Gauge */} +
+
+
+
+
+
+ + {/* MACD Section */} +
+
+ MACD +
+
+
+ MACD 线 + 0 ? '#00C853' : '#FF1744' }}> + {formatPrice(technicalIndicators.macd?.macd)} + +
+
+ Signal 线 + {formatPrice(technicalIndicators.macd?.signal)} +
+
+ 柱状图 + 0 ? '#00C853' : '#FF1744' }}> + {formatPrice(technicalIndicators.macd?.histogram)} + +
+
+
+ + {/* Bollinger Bands Section */} +
+
+ 布林带 +
+
+
+ 上轨 + + {formatPrice(technicalIndicators.bollinger?.upper)} + +
+
+ 中轨 + {formatPrice(technicalIndicators.bollinger?.mid)} +
+
+ 下轨 + + {formatPrice(technicalIndicators.bollinger?.lower)} + +
+
+
+ + {/* Volatility Section */} +
+
+ 波动率 +
+
+
+ 10日 + {formatPct(technicalIndicators.volatility?.vol_10d)} +
+
+ 20日 + {formatPct(technicalIndicators.volatility?.vol_20d)} +
+
+ 60日 + {formatPct(technicalIndicators.volatility?.vol_60d)} +
+
+ 年化波动率 + {formatPct(technicalIndicators.volatility?.annualized)} +
+
+ 风险等级 + + {technicalIndicators.volatility?.risk_level || '-'} + +
+
+
+ + {/* Trend Summary */} +
+
+ 趋势判断 +
+
+
+
+ {technicalIndicators.trend || '-'} +
+
+
+
+ {technicalIndicators.mean_reversion || '-'} +
+
+
+ 当前价格: {formatPrice(technicalIndicators.current_price)} +
+
+
+
+ )} +
+ ); +} diff --git a/frontend/src/config/constants.js b/frontend/src/config/constants.js index 39d4ca2..93a9eb8 100644 --- a/frontend/src/config/constants.js +++ b/frontend/src/config/constants.js @@ -130,7 +130,7 @@ export const CHART_MARGIN = { left: 60, right: 20, top: 20, bottom: 40 }; export const AXIS_TICKS = 5; // WebSocket configuration -export const WS_URL = import.meta.env.VITE_WS_URL || "ws://localhost:8765"; +export const WS_URL = import.meta.env.VITE_WS_URL || "ws://localhost:8000"; // Initial ticker symbols for the production watchlist export const INITIAL_TICKERS = [ diff --git a/services/README.md b/services/README.md new file mode 100644 index 0000000..361a91e --- /dev/null +++ b/services/README.md @@ -0,0 +1,75 @@ +# EvoTraders Services Architecture + +This document describes the modular service architecture for EvoTraders. + +## Architecture + +EvoTraders uses a **modular single-process architecture** with services as Python modules: + +``` +backend/ +├── app.py # FastAPI entry point (port 8000) +├── main.py # CLI trading system entry point +├── api/ # REST API routes +│ ├── agents.py # Agent management +│ ├── guard.py # Tool guard +│ ├── runtime.py # Runtime management +│ └── workspaces.py # Workspace management +├── agents/ # Multi-agent system +│ ├── base/ # Base agent classes +│ ├── team/ # Team coordination +│ └── skills/ # Agent skills +├── core/ # Pipeline & scheduler +├── services/ # Core services +│ ├── gateway.py # WebSocket gateway +│ ├── market.py # Market data service +│ └── storage.py # Storage service +└── services/ # Modular services (optional) + ├── trading/ # Trading module + ├── news/ # News module + └── agents/ # Agents module +``` + +## Entry Points + +| Entry Point | Port | Purpose | +|------------|------|---------| +| `backend/app.py` | 8000 | FastAPI REST API | +| `backend/main.py` | CLI | Trading system (live/backtest) | + +## Running + +```bash +# Development mode (FastAPI only) +./start-dev.sh + +# Or manually +python -m uvicorn backend.app:app --port 8000 --reload + +# Trading system (CLI) +evotraders live --mock +evotraders backtest --start 2025-11-01 --end 2025-12-01 +``` + +## Service Modules + +| Module | Description | +|--------|-------------| +| `gateway.py` | WebSocket gateway for frontend communication | +| `market.py` | Market data fetching (prices, news, financials) | +| `storage.py` | Dashboard state and trade history persistence | + +## Module Dependencies + +``` +app.py (FastAPI) +├── runtime_router +│ └── backend/main.py (when task starts) +│ └── Gateway +│ ├── MarketService +│ ├── StorageService +│ └── TradingPipeline +│ ├── Analysts (4x) +│ ├── RiskManager +│ └── PortfolioManager +``` diff --git a/shared/client/__init__.py b/shared/client/__init__.py new file mode 100644 index 0000000..cb0814d --- /dev/null +++ b/shared/client/__init__.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +"""Shared client package.""" + +from shared.client.trading_client import TradingServiceClient +from shared.client.news_client import NewsServiceClient +from shared.client.agent_client import AgentServiceClient + +__all__ = [ + "TradingServiceClient", + "NewsServiceClient", + "AgentServiceClient", +] diff --git a/shared/client/agent_client.py b/shared/client/agent_client.py new file mode 100644 index 0000000..f9edf37 --- /dev/null +++ b/shared/client/agent_client.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +"""Agent service client for agent orchestration and runtime operations.""" + +import json +from typing import Any, AsyncIterator + +import httpx +import websockets + +from shared.schema.signals import AgentStateData + + +class AgentServiceClient: + """Async client for the Agent Service API.""" + + def __init__(self, base_url: str = "http://localhost:8000"): + """Initialize the client with a base URL. + + Args: + base_url: Base URL for the agent service API. + """ + self.base_url = base_url.rstrip("/") + self._client: httpx.AsyncClient | None = None + + async def __aenter__(self) -> "AgentServiceClient": + self._client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + if self._client: + await self._client.aclose() + + async def get_agents(self) -> dict: + """Get list of all registered agents. + + Returns: + Dictionary with agent list. + """ + response = await self._client.get("/api/agents") + response.raise_for_status() + return response.json() + + async def get_agent_status(self, agent_id: str) -> dict: + """Get status of a specific agent. + + Args: + agent_id: The agent identifier. + + Returns: + Dictionary with agent status. + """ + response = await self._client.get(f"/api/agents/{agent_id}/status") + response.raise_for_status() + return response.json() + + async def post_run_daily( + self, + tickers: list[str], + start_date: str, + end_date: str, + runtime_config: dict[str, Any] | None = None, + ) -> dict: + """Trigger a daily analysis run. + + Args: + tickers: List of stock tickers to analyze. + start_date: Start date (YYYY-MM-DD). + end_date: End date (YYYY-MM-DD). + runtime_config: Optional runtime configuration. + + Returns: + Dictionary with run initiation response. + """ + payload = { + "tickers": tickers, + "start_date": start_date, + "end_date": end_date, + } + if runtime_config: + payload["runtime_config"] = runtime_config + response = await self._client.post("/api/run/daily", json=payload) + response.raise_for_status() + return response.json() + + async def get_run_status(self, run_id: str) -> dict: + """Get status of a run. + + Args: + run_id: The run identifier. + + Returns: + Dictionary with run status. + """ + response = await self._client.get(f"/api/runs/{run_id}/status") + response.raise_for_status() + return response.json() + + async def get_run_result(self, run_id: str) -> AgentStateData: + """Get the result of a completed run. + + Args: + run_id: The run identifier. + + Returns: + AgentStateData with run results. + """ + response = await self._client.get(f"/api/runs/{run_id}/result") + response.raise_for_status() + return AgentStateData.model_validate(response.json()) + + async def get_run_logs(self, run_id: str) -> dict: + """Get logs for a run. + + Args: + run_id: The run identifier. + + Returns: + Dictionary with run logs. + """ + response = await self._client.get(f"/api/runs/{run_id}/logs") + response.raise_for_status() + return response.json() + + async def cancel_run(self, run_id: str) -> dict: + """Cancel a running task. + + Args: + run_id: The run identifier. + + Returns: + Dictionary with cancellation confirmation. + """ + response = await self._client.post(f"/api/runs/{run_id}/cancel") + response.raise_for_status() + return response.json() + + async def get_runtime_config(self) -> dict: + """Get current runtime configuration. + + Returns: + Dictionary with runtime config. + """ + response = await self._client.get("/api/runtime/config") + response.raise_for_status() + return response.json() + + async def update_runtime_config(self, config: dict[str, Any]) -> dict: + """Update runtime configuration. + + Args: + config: New runtime configuration. + + Returns: + Dictionary with updated config. + """ + response = await self._client.put("/api/runtime/config", json=config) + response.raise_for_status() + return response.json() + + async def websocket_connect( + self, + run_id: str | None = None, + ) -> AsyncIterator[dict]: + """Connect to WebSocket for real-time updates. + + Args: + run_id: Optional run ID to subscribe to. + + Yields: + Dictionary with WebSocket messages. + """ + ws_url = self.base_url.replace("http", "ws") + "/ws" + if run_id: + ws_url += f"?run_id={run_id}" + + async with websockets.connect(ws_url) as ws: + async for message in ws: + yield json.loads(message) + + async def get_pipeline_status(self) -> dict: + """Get current pipeline execution status. + + Returns: + Dictionary with pipeline status. + """ + response = await self._client.get("/api/pipeline/status") + response.raise_for_status() + return response.json() + + async def trigger_pipeline( + self, + pipeline_type: str, + tickers: list[str], + config: dict[str, Any] | None = None, + ) -> dict: + """Trigger a pipeline execution. + + Args: + pipeline_type: Type of pipeline to run. + tickers: List of tickers to process. + config: Optional pipeline configuration. + + Returns: + Dictionary with pipeline trigger response. + """ + payload = {"pipeline_type": pipeline_type, "tickers": tickers} + if config: + payload["config"] = config + response = await self._client.post("/api/pipeline/trigger", json=payload) + response.raise_for_status() + return response.json() diff --git a/shared/client/news_client.py b/shared/client/news_client.py new file mode 100644 index 0000000..3cdf912 --- /dev/null +++ b/shared/client/news_client.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +"""News service client for news enrichment operations.""" + +import httpx + + +class NewsServiceClient: + """Async client for the News Service API.""" + + def __init__(self, base_url: str = "http://localhost:8002"): + """Initialize the client with a base URL. + + Args: + base_url: Base URL for the news service API. + """ + self.base_url = base_url.rstrip("/") + self._client: httpx.AsyncClient | None = None + + async def __aenter__(self) -> "NewsServiceClient": + self._client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + if self._client: + await self._client.aclose() + + async def get_enriched_news( + self, + ticker: str, + start_date: str | None = None, + end_date: str | None = None, + ) -> dict: + """Get enriched news for a ticker. + + Args: + ticker: Stock ticker symbol. + start_date: Start date (YYYY-MM-DD). + end_date: End date (YYYY-MM-DD). + + Returns: + Dictionary with enriched news data. + """ + params = {"ticker": ticker} + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + response = await self._client.get("/api/enriched-news", params=params) + response.raise_for_status() + return response.json() + + async def get_similar_days( + self, + ticker: str, + date: str, + n_similar: int = 5, + ) -> dict: + """Get similar trading days based on price patterns. + + Args: + ticker: Stock ticker symbol. + date: Reference date (YYYY-MM-DD). + n_similar: Number of similar days to return. + + Returns: + Dictionary with similar day data. + """ + params = {"ticker": ticker, "date": date, "n_similar": n_similar} + response = await self._client.get("/api/similar-days", params=params) + response.raise_for_status() + return response.json() + + async def get_story(self, story_id: str) -> dict: + """Get a specific news story by ID. + + Args: + story_id: The story identifier. + + Returns: + Dictionary with story data. + """ + response = await self._client.get(f"/api/stories/{story_id}") + response.raise_for_status() + return response.json() + + async def post_enrich(self, news_items: list[dict]) -> dict: + """Enrich news items with additional analysis. + + Args: + news_items: List of news items to enrich. + + Returns: + Dictionary with enriched news data. + """ + response = await self._client.post("/api/enrich", json=news_items) + response.raise_for_status() + return response.json() + + async def get_categories(self) -> dict: + """Get available news categories. + + Returns: + Dictionary with available categories. + """ + response = await self._client.get("/api/categories") + response.raise_for_status() + return response.json() + + async def search_news( + self, + query: str, + ticker: str | None = None, + limit: int = 10, + ) -> dict: + """Search news articles. + + Args: + query: Search query string. + ticker: Optional ticker to filter by. + limit: Maximum number of results. + + Returns: + Dictionary with search results. + """ + params = {"query": query, "limit": limit} + if ticker: + params["ticker"] = ticker + response = await self._client.get("/api/search", params=params) + response.raise_for_status() + return response.json() diff --git a/shared/client/trading_client.py b/shared/client/trading_client.py new file mode 100644 index 0000000..509731a --- /dev/null +++ b/shared/client/trading_client.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +"""Trading service client for market data operations.""" + +import httpx + +from shared.schema.price import PriceResponse +from shared.schema.financial import FinancialMetricsResponse, LineItemResponse +from shared.schema.market import InsiderTradeResponse, CompanyFactsResponse +from shared.schema.portfolio import Portfolio + + +class TradingServiceClient: + """Async client for the Trading Service API.""" + + def __init__(self, base_url: str = "http://localhost:8001"): + """Initialize the client with a base URL. + + Args: + base_url: Base URL for the trading service API. + """ + self.base_url = base_url.rstrip("/") + self._client: httpx.AsyncClient | None = None + + async def __aenter__(self) -> "TradingServiceClient": + self._client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + if self._client: + await self._client.aclose() + + async def get_prices( + self, + ticker: str, + start_date: str | None = None, + end_date: str | None = None, + ) -> PriceResponse: + """Get price data for a ticker. + + Args: + ticker: Stock ticker symbol. + start_date: Start date (YYYY-MM-DD). + end_date: End date (YYYY-MM-DD). + + Returns: + PriceResponse with price data. + """ + params = {"ticker": ticker} + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + response = await self._client.get("/api/prices", params=params) + response.raise_for_status() + return PriceResponse.model_validate(response.json()) + + async def get_news( + self, + ticker: str, + start_date: str | None = None, + end_date: str | None = None, + ) -> dict: + """Get news for a ticker. + + Args: + ticker: Stock ticker symbol. + start_date: Start date (YYYY-MM-DD). + end_date: End date (YYYY-MM-DD). + + Returns: + Dictionary with news data. + """ + params = {"ticker": ticker} + if start_date: + params["start_date"] = start_date + if end_date: + params["end_date"] = end_date + response = await self._client.get("/api/news", params=params) + response.raise_for_status() + return response.json() + + async def get_financials( + self, + ticker: str, + period: str | None = None, + limit: int | None = None, + ) -> FinancialMetricsResponse: + """Get financial metrics for a ticker. + + Args: + ticker: Stock ticker symbol. + period: Reporting period (e.g., "annual", "quarterly"). + limit: Maximum number of records to return. + + Returns: + FinancialMetricsResponse with financial data. + """ + params = {"ticker": ticker} + if period: + params["period"] = period + if limit: + params["limit"] = limit + response = await self._client.get("/api/financials", params=params) + response.raise_for_status() + return FinancialMetricsResponse.model_validate(response.json()) + + async def get_insider_trades( + self, + ticker: str, + limit: int | None = None, + ) -> InsiderTradeResponse: + """Get insider trades for a ticker. + + Args: + ticker: Stock ticker symbol. + limit: Maximum number of records to return. + + Returns: + InsiderTradeResponse with insider trade data. + """ + params = {"ticker": ticker} + if limit: + params["limit"] = limit + response = await self._client.get("/api/insider-trades", params=params) + response.raise_for_status() + return InsiderTradeResponse.model_validate(response.json()) + + async def get_portfolio(self) -> Portfolio: + """Get the current portfolio. + + Returns: + Portfolio with current positions and cash. + """ + response = await self._client.get("/api/portfolio") + response.raise_for_status() + return Portfolio.model_validate(response.json()) + + async def post_trades(self, trades: list[dict]) -> dict: + """Submit trades for execution. + + Args: + trades: List of trade orders. + + Returns: + Dictionary with trade execution results. + """ + response = await self._client.post("/api/trades", json=trades) + response.raise_for_status() + return response.json() + + async def post_settle(self) -> dict: + """Settle all pending trades. + + Returns: + Dictionary with settlement results. + """ + response = await self._client.post("/api/settle") + response.raise_for_status() + return response.json() + + async def get_market_status(self) -> dict: + """Get current market status. + + Returns: + Dictionary with market status information. + """ + response = await self._client.get("/api/market/status") + response.raise_for_status() + return response.json() + + async def get_company_facts(self, ticker: str) -> CompanyFactsResponse: + """Get company facts for a ticker. + + Args: + ticker: Stock ticker symbol. + + Returns: + CompanyFactsResponse with company information. + """ + response = await self._client.get(f"/api/company/{ticker}/facts") + response.raise_for_status() + return CompanyFactsResponse.model_validate(response.json()) + + async def get_line_items( + self, + ticker: str, + statement_type: str | None = None, + period: str | None = None, + ) -> LineItemResponse: + """Get line items (financial statement data) for a ticker. + + Args: + ticker: Stock ticker symbol. + statement_type: Type of statement (income, balance, cash_flow). + period: Reporting period. + + Returns: + LineItemResponse with financial statement data. + """ + params = {"ticker": ticker} + if statement_type: + params["statement_type"] = statement_type + if period: + params["period"] = period + response = await self._client.get("/api/line-items", params=params) + response.raise_for_status() + return LineItemResponse.model_validate(response.json()) diff --git a/shared/schema/__init__.py b/shared/schema/__init__.py new file mode 100644 index 0000000..d975743 --- /dev/null +++ b/shared/schema/__init__.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +"""Shared schema package for EvoTraders services.""" + +from shared.schema.price import Price, PriceResponse +from shared.schema.financial import ( + FinancialMetrics, + FinancialMetricsResponse, + LineItem, + LineItemResponse, +) +from shared.schema.portfolio import Position, Portfolio +from shared.schema.signals import ( + AnalystSignal, + TickerAnalysis, + AgentStateData, + AgentStateMetadata, +) +from shared.schema.market import ( + InsiderTrade, + InsiderTradeResponse, + CompanyNews, + CompanyNewsResponse, + CompanyFacts, + CompanyFactsResponse, +) + +__all__ = [ + # Price + "Price", + "PriceResponse", + # Financial + "FinancialMetrics", + "FinancialMetricsResponse", + "LineItem", + "LineItemResponse", + # Portfolio + "Position", + "Portfolio", + # Signals + "AnalystSignal", + "TickerAnalysis", + "AgentStateData", + "AgentStateMetadata", + # Market + "InsiderTrade", + "InsiderTradeResponse", + "CompanyNews", + "CompanyNewsResponse", + "CompanyFacts", + "CompanyFactsResponse", +] diff --git a/shared/schema/financial.py b/shared/schema/financial.py new file mode 100644 index 0000000..ad6f840 --- /dev/null +++ b/shared/schema/financial.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +"""Financial-related schemas.""" + +from pydantic import BaseModel + + +class FinancialMetrics(BaseModel): + ticker: str + report_period: str + period: str + currency: str + market_cap: float | None + enterprise_value: float | None + price_to_earnings_ratio: float | None + price_to_book_ratio: float | None + price_to_sales_ratio: float | None + enterprise_value_to_ebitda_ratio: float | None + enterprise_value_to_revenue_ratio: float | None + free_cash_flow_yield: float | None + peg_ratio: float | None + gross_margin: float | None + operating_margin: float | None + net_margin: float | None + return_on_equity: float | None + return_on_assets: float | None + return_on_invested_capital: float | None + asset_turnover: float | None + inventory_turnover: float | None + receivables_turnover: float | None + days_sales_outstanding: float | None + operating_cycle: float | None + working_capital_turnover: float | None + current_ratio: float | None + quick_ratio: float | None + cash_ratio: float | None + operating_cash_flow_ratio: float | None + debt_to_equity: float | None + debt_to_assets: float | None + interest_coverage: float | None + revenue_growth: float | None + earnings_growth: float | None + book_value_growth: float | None + earnings_per_share_growth: float | None + free_cash_flow_growth: float | None + operating_income_growth: float | None + ebitda_growth: float | None + payout_ratio: float | None + earnings_per_share: float | None + book_value_per_share: float | None + free_cash_flow_per_share: float | None + + +class FinancialMetricsResponse(BaseModel): + financial_metrics: list[FinancialMetrics] + + +class LineItem(BaseModel): + ticker: str + report_period: str + period: str + currency: str + + # Allow additional fields dynamically + model_config = {"extra": "allow"} + + +class LineItemResponse(BaseModel): + search_results: list[LineItem] diff --git a/shared/schema/market.py b/shared/schema/market.py new file mode 100644 index 0000000..7890ca7 --- /dev/null +++ b/shared/schema/market.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +"""Market data-related schemas.""" + +from pydantic import BaseModel + + +class InsiderTrade(BaseModel): + ticker: str + issuer: str | None = None + name: str | None = None + title: str | None = None + is_board_director: bool | None = None + transaction_date: str | None = None + transaction_shares: float | None = None + transaction_price_per_share: float | None = None + transaction_value: float | None = None + shares_owned_before_transaction: float | None = None + shares_owned_after_transaction: float | None = None + security_title: str | None = None + filing_date: str + + +class InsiderTradeResponse(BaseModel): + insider_trades: list[InsiderTrade] + + +class CompanyNews(BaseModel): + category: str | None = None + ticker: str + title: str + related: str | None = None + source: str + date: str | None = None + url: str + summary: str | None = None + + +class CompanyNewsResponse(BaseModel): + news: list[CompanyNews] + + +class CompanyFacts(BaseModel): + ticker: str + name: str + cik: str | None = None + industry: str | None = None + sector: str | None = None + category: str | None = None + exchange: str | None = None + is_active: bool | None = None + listing_date: str | None = None + location: str | None = None + market_cap: float | None = None + number_of_employees: int | None = None + sec_filings_url: str | None = None + sic_code: str | None = None + sic_industry: str | None = None + sic_sector: str | None = None + website_url: str | None = None + weighted_average_shares: int | None = None + + +class CompanyFactsResponse(BaseModel): + company_facts: CompanyFacts diff --git a/shared/schema/portfolio.py b/shared/schema/portfolio.py new file mode 100644 index 0000000..1ec108b --- /dev/null +++ b/shared/schema/portfolio.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +"""Portfolio-related schemas.""" + +from pydantic import BaseModel + + +class Position(BaseModel): + """Position information - for Portfolio mode""" + + long: int = 0 # Long position quantity (shares) + short: int = 0 # Short position quantity (shares) + long_cost_basis: float = 0.0 # Long position average cost + short_cost_basis: float = 0.0 # Short position average cost + + +class Portfolio(BaseModel): + """Portfolio - for Portfolio mode""" + + cash: float = 100000.0 # Available cash + positions: dict[str, Position] = {} # ticker -> Position mapping + # Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin) + margin_requirement: float = 0.0 + margin_used: float = 0.0 # Margin used diff --git a/shared/schema/price.py b/shared/schema/price.py new file mode 100644 index 0000000..e5647a3 --- /dev/null +++ b/shared/schema/price.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +"""Price-related schemas.""" + +from pydantic import BaseModel + + +class Price(BaseModel): + open: float + close: float + high: float + low: float + volume: int + time: str + + +class PriceResponse(BaseModel): + ticker: str + prices: list[Price] diff --git a/shared/schema/signals.py b/shared/schema/signals.py new file mode 100644 index 0000000..b64d603 --- /dev/null +++ b/shared/schema/signals.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +"""Signal and analysis-related schemas.""" + +from pydantic import BaseModel + +from shared.schema.portfolio import Portfolio + + +class AnalystSignal(BaseModel): + signal: str | None = None + confidence: float | None = None + reasoning: dict | str | None = None + # Extended fields for richer signal information + reasons: list[str] | None = None # Core drivers/reasons for the signal + risks: list[str] | None = None # Key risk factors + invalidation: str | None = None # Conditions that would invalidate the thesis + next_action: str | None = None # Suggested next action for PM + # Valuation-related fields + intrinsic_value: float | None = None # DCF intrinsic value + fair_value_range: dict | None = None # {bear, base, bull} fair value range + value_gap_pct: float | None = None # Value gap percentage + valuation_methods: list[str] | None = None # List of valuation methods used + max_position_size: float | None = None # For risk management signals + + +class TickerAnalysis(BaseModel): + ticker: str + analyst_signals: dict[str, AnalystSignal] # agent_name -> signal mapping + + +class AgentStateData(BaseModel): + tickers: list[str] + portfolio: Portfolio + start_date: str + end_date: str + ticker_analyses: dict[str, TickerAnalysis] # ticker -> analysis mapping + + +class AgentStateMetadata(BaseModel): + show_reasoning: bool = False + model_config = {"extra": "allow"} diff --git a/start-dev.sh b/start-dev.sh new file mode 100755 index 0000000..e143e95 --- /dev/null +++ b/start-dev.sh @@ -0,0 +1,51 @@ +#!/bin/bash +# EvoTraders Development Startup Script +# Single-process FastAPI with auto-reload + +set -e + +echo "==========================================" +echo "EvoTraders Development Environment" +echo "==========================================" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Check virtual environment +if [ -z "$VIRTUAL_ENV" ]; then + echo -e "${YELLOW}Warning: Virtual environment not activated${NC}" + echo "Activating .venv..." + source .venv/bin/activate +fi + +# Load environment variables +if [ -f .env ]; then + echo -e "${GREEN}Loading environment from .env${NC}" + export $(grep -v '^#' .env | xargs) +else + echo -e "${YELLOW}Warning: .env file not found${NC}" +fi + +# Check required environment variables +if [ -z "$OPENAI_API_KEY" ]; then + echo -e "${RED}Error: OPENAI_API_KEY not set${NC}" + echo "Please set it in .env file or environment" + exit 1 +fi + +echo "" +echo "Starting EvoTraders API (FastAPI) on port 8000..." +echo "API Endpoints: http://localhost:8000/docs" +echo "" + +# Start FastAPI with auto-reload for development +cd /Users/cillin/workspeace/evotraders +python -m uvicorn backend.app:app \ + --host 0.0.0.0 \ + --port 8000 \ + --reload \ + --reload-dir backend \ + --log-level info