# -*- coding: utf-8 -*-
"""TeamCoordinator - Agent lifecycle management and execution.

Provides run_parallel() using asyncio.gather() and run_sequential() for
coordinating multiple agents.
"""
from __future__ import annotations

import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type

from agentscope.message import Msg

logger = logging.getLogger(__name__)


class TeamCoordinator:
    """Coordinates agent lifecycle and execution.

    Supports:
    - run_parallel(): Execute multiple agents concurrently with asyncio.gather()
    - run_sequential(): Execute agents one after another
    - run_phase(): Execute a named phase with registered agents
    - run_with_dependencies(): Execute agents in dependency-ordered waves
    - fanout_pipeline(): Send one message to many agents concurrently
    - register_agent(): Add agent to coordinator
    - unregister_agent(): Remove agent from coordinator

    Agents are invoked with the first calling convention they support:
    an async ``reply(msg)`` method, then an async ``run()`` method, then
    ``await agent()``. The coordinator itself keeps no per-agent context;
    any memory lives on the agent instances.
    """

    def __init__(
        self,
        participants: Optional[List[Any]] = None,
        task_content: Optional[str] = None,
        messenger: Optional[Any] = None,
        registry: Optional[Any] = None,
    ):
        """Initialize TeamCoordinator.

        Args:
            participants: List of agent instances to coordinate. Each one is
                auto-registered under its ``name`` attribute, or its ``id``
                attribute if it has no ``name``.
            task_content: Task description content for the agents (used as
                the message content in run_phase()).
            messenger: AgentMessenger for communication (optional; stored only).
            registry: AgentRegistry for agent lookup (optional; stored only).
        """
        self._participants = participants or []
        self._task_content = task_content or ""
        self._messenger = messenger
        self._registry = registry
        self._agents: Dict[str, Any] = {}
        # NOTE(review): nothing in this module adds entries here; shutdown()
        # only acts on tasks a subclass or caller stores — confirm intent.
        self._running_tasks: Dict[str, asyncio.Task] = {}

        # Auto-register participants
        for agent in self._participants:
            if hasattr(agent, "name"):
                self._agents[agent.name] = agent
            elif hasattr(agent, "id"):
                self._agents[agent.id] = agent
            else:
                # Previously skipped silently, which made missing agents hard to debug.
                logger.warning(
                    "Participant %r has neither 'name' nor 'id'; not registered",
                    agent,
                )

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def register_agent(self, agent_id: str, agent: Any) -> None:
        """Register an agent with the coordinator.

        An existing registration under the same id is replaced.

        Args:
            agent_id: Unique agent identifier
            agent: Agent instance
        """
        if agent_id in self._agents:
            logger.warning("Replacing existing agent registration: %s", agent_id)
        self._agents[agent_id] = agent
        logger.info("Registered agent: %s", agent_id)

    def unregister_agent(self, agent_id: str) -> None:
        """Unregister an agent from the coordinator.

        Unknown ids are ignored.

        Args:
            agent_id: Agent identifier to remove
        """
        if agent_id in self._agents:
            del self._agents[agent_id]
            logger.info("Unregistered agent: %s", agent_id)

    def get_agent(self, agent_id: str) -> Optional[Any]:
        """Get registered agent by ID.

        Args:
            agent_id: Agent identifier

        Returns:
            Agent instance, or None if no agent is registered under that id.
        """
        return self._agents.get(agent_id)

    def list_agents(self) -> List[str]:
        """List all registered agent IDs.

        Returns:
            List of agent identifiers, in registration order.
        """
        return list(self._agents.keys())

    # ------------------------------------------------------------------
    # Invocation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _has_async_method(agent: Any, name: str) -> bool:
        """Return True if ``agent.<name>`` exists and is a coroutine function."""
        return hasattr(agent, name) and asyncio.iscoroutinefunction(getattr(agent, name))

    @classmethod
    async def _invoke_agent(cls, agent: Any, *reply_args: Any) -> Any:
        """Await ``agent`` using the first calling convention it supports.

        Order: async ``reply(*reply_args)``, async ``run()``, ``await agent()``.
        ``reply_args`` are only passed to ``reply``; exceptions propagate.
        """
        if cls._has_async_method(agent, "reply"):
            return await agent.reply(*reply_args)
        if cls._has_async_method(agent, "run"):
            return await agent.run()
        return await agent()

    @staticmethod
    def _as_msg(agent: Any, result: Any) -> Optional[Msg]:
        """Coerce an agent result into a Msg (None stays None).

        A dict with a "content" key is wrapped using its content/metadata;
        any other non-Msg value is wrapped with ``str(result)`` as content.
        """
        if result is None or isinstance(result, Msg):
            return result
        name = getattr(agent, "name", "unknown")
        if isinstance(result, dict) and "content" in result:
            return Msg(
                name=name,
                content=result.get("content", ""),
                role="assistant",
                metadata=result.get("metadata"),
            )
        return Msg(name=name, content=str(result), role="assistant")

    # ------------------------------------------------------------------
    # Execution strategies
    # ------------------------------------------------------------------

    async def run_parallel(
        self,
        agent_ids: List[str],
        initial_message: Optional[Msg] = None,
    ) -> Dict[str, Any]:
        """Run multiple agents in parallel using asyncio.gather().

        Args:
            agent_ids: List of agent IDs to run concurrently
            initial_message: Optional initial message passed to every agent's
                ``reply``; when None, ``reply()`` is called with no arguments.

        Returns:
            Dict mapping every requested agent_id to its result: the agent's
            return value, None if the agent is not registered, or
            ``{"error": ...}`` if it raised.
        """
        reply_args = () if initial_message is None else (initial_message,)

        async def _run_agent(aid: str) -> Any:
            agent = self._agents.get(aid)
            if agent is None:
                logger.error("Agent %s not found", aid)
                return None
            try:
                result = await self._invoke_agent(agent, *reply_args)
            except Exception as e:
                logger.exception("Agent %s failed: %s", aid, e)
                return {"error": str(e)}
            logger.info("Agent %s completed successfully", aid)
            return result

        results = await asyncio.gather(
            *(_run_agent(aid) for aid in agent_ids),
            return_exceptions=True,
        )

        # gather() preserves input order, so zip keeps each result tied to its id
        # even when a child ended with a non-Exception (e.g. CancelledError).
        output: Dict[str, Any] = {}
        for aid, result in zip(agent_ids, results):
            if isinstance(result, BaseException):
                logger.error("Unexpected result from asyncio.gather: %s", result)
                output[aid] = {"error": str(result) or type(result).__name__}
            else:
                output[aid] = result

        logger.info("Parallel run completed for %d agents", len(agent_ids))
        return output

    async def run_sequential(
        self,
        agent_ids: List[str],
        initial_message: Optional[Msg] = None,
    ) -> Dict[str, Any]:
        """Run agents one after another in order.

        Each agent's result becomes the message passed to the next agent's
        ``reply``. An unregistered agent is recorded as an error and skipped;
        an agent that raises is recorded as an error and stops the chain.

        Args:
            agent_ids: List of agent IDs to run in sequence
            initial_message: Optional initial message for first agent

        Returns:
            Dict mapping agent_id to result (agents after a failure are absent).
        """
        output: Dict[str, Any] = {}
        current_message = initial_message

        for agent_id in agent_ids:
            agent = self._agents.get(agent_id)
            if agent is None:
                logger.error("Agent %s not found", agent_id)
                output[agent_id] = {"error": "Agent not found"}
                continue
            try:
                # current_message is always passed, even when it is None.
                result = await self._invoke_agent(agent, current_message)
            except Exception as e:
                logger.exception("Agent %s failed: %s", agent_id, e)
                output[agent_id] = {"error": str(e)}
                break
            output[agent_id] = result
            current_message = result
            logger.info("Agent %s completed sequentially", agent_id)

        logger.info("Sequential run completed for %d agents", len(agent_ids))
        return output

    async def run_phase(
        self,
        phase_name: str,
        agent_ids: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Any]:
        """Execute a named phase with registered agents, one at a time.

        Agents exposing an async ``reply`` receive a fresh system Msg whose
        content is the coordinator's task content (or a generic phase prompt).

        Args:
            phase_name: Name of the phase (e.g., "analyst_analysis")
            agent_ids: Optional list of agent IDs; if None, uses all registered.
                Unregistered ids are skipped with a warning.
            metadata: Optional metadata to include in the message (e.g., tickers, date)

        Returns:
            List of results, one per registered agent run, in order; None for
            agents that raised.
        """
        if agent_ids is None:
            agent_ids = list(self._agents.keys())

        _agent_ids = [aid for aid in agent_ids if aid in self._agents]
        skipped = [aid for aid in agent_ids if aid not in self._agents]
        if skipped:
            logger.warning(
                "Phase '%s': skipping unregistered agents: %s", phase_name, skipped
            )
        logger.info(
            "Running phase '%s' with %d agents: %s",
            phase_name,
            len(_agent_ids),
            _agent_ids,
        )

        results: List[Any] = []
        for agent_id in _agent_ids:
            agent = self._agents[agent_id]
            try:
                reply_args: tuple = ()
                if self._has_async_method(agent, "reply"):
                    # A fresh Msg per agent, built only for reply-style agents.
                    reply_args = (
                        Msg(
                            name="system",
                            content=self._task_content
                            or f"Please execute phase: {phase_name}",
                            role="user",
                            metadata=metadata,
                        ),
                    )
                result = await self._invoke_agent(agent, *reply_args)
            except Exception as e:
                logger.exception(
                    "Phase '%s': Agent %s failed: %s", phase_name, agent_id, e
                )
                results.append(None)
                continue
            results.append(result)
            logger.info("Phase '%s': Agent %s completed", phase_name, agent_id)

        logger.info("Phase '%s' completed with %d results", phase_name, len(results))
        return results

    async def run_with_dependencies(
        self,
        agent_tasks: Dict[str, List[str]],
        initial_message: Optional[Msg] = None,
    ) -> Dict[str, Any]:
        """Run agents respecting dependency graph.

        Agents run in waves: each wave is every not-yet-run agent whose
        prerequisites have all produced a result (a failed prerequisite
        still counts as finished). Each wave runs via run_parallel().

        The message given to a wave is the last ``Msg`` result of the previous
        wave (in ``agent_tasks`` order), or the prior message if that wave
        produced no ``Msg`` — error dicts and None are never forwarded.

        Args:
            agent_tasks: Dict mapping agent_id to list of prerequisite agent_ids
                (None is treated as no prerequisites)
            initial_message: Optional initial message for the first wave

        Returns:
            Dict mapping agent_id to result; agents that can never become ready
            (cycles or prerequisites absent from agent_tasks) map to
            ``{"error": "Circular dependency"}``.
        """
        completed: Dict[str, Any] = {}
        # A list (not a set) keeps waves and message forwarding deterministic.
        remaining: List[str] = list(agent_tasks)

        unknown = list(
            dict.fromkeys(
                dep
                for deps in agent_tasks.values()
                for dep in (deps or [])
                if dep not in agent_tasks
            )
        )
        if unknown:
            logger.warning(
                "Dependencies missing from agent_tasks can never be satisfied: %s",
                unknown,
            )

        while remaining:
            ready = [
                aid
                for aid in remaining
                if all(dep in completed for dep in (agent_tasks.get(aid) or []))
            ]

            if not ready:
                logger.error("Circular dependency detected in agent tasks")
                for aid in remaining:
                    completed[aid] = {"error": "Circular dependency"}
                break

            results = await self.run_parallel(ready, initial_message)
            completed.update(results)

            for aid in ready:
                result = results.get(aid)
                if isinstance(result, Msg):
                    initial_message = result

            ready_set = set(ready)
            remaining = [aid for aid in remaining if aid not in ready_set]

        return completed

    async def fanout_pipeline(
        self,
        agents: List[Any],
        msg: Optional[Msg] = None,
    ) -> List[Optional[Msg]]:
        """Fanout a message to multiple agents concurrently and collect all responses.

        Similar to AgentScope's fanout_pipeline, this sends the same message to
        all specified agents and returns a list of all agent responses.

        Args:
            agents: List of agent instances to fanout the message to
            msg: Message to send to all agents (optional; when None, ``reply()``
                is called with no arguments)

        Returns:
            List with one entry per input agent, in the same order: the agent's
            response coerced to Msg, or None if it returned None or failed.

        Example:
            >>> responses = await fanout_pipeline(
            ...     agents=[alice, bob, charlie],
            ...     msg=question,
            ... )
            >>> # responses is a list of Msg responses from each agent
        """
        reply_args = () if msg is None else (msg,)

        async def _fanout_to_agent(agent: Any) -> Optional[Msg]:
            """Send message to a single agent and return its response."""
            try:
                result = await self._invoke_agent(agent, *reply_args)
                return self._as_msg(agent, result)
            except Exception as e:
                logger.exception(
                    "Agent %s failed in fanout_pipeline: %s",
                    getattr(agent, "name", "unknown"),
                    e,
                )
                return None

        results = await asyncio.gather(
            *(_fanout_to_agent(agent) for agent in agents),
            return_exceptions=True,
        )

        # BaseException (not Exception) so a child's CancelledError becomes None
        # instead of being returned as if it were a response.
        responses: List[Optional[Msg]] = []
        for i, result in enumerate(results):
            if isinstance(result, BaseException):
                logger.error("Fanout to agent %d failed: %s", i, result)
                responses.append(None)
            else:
                responses.append(result)

        logger.info("Fanout pipeline completed for %d agents", len(agents))
        return responses

    async def shutdown(self, timeout: Optional[float] = 5.0) -> None:
        """Wait for tracked running tasks to finish, then forget them.

        Each task in ``_running_tasks`` is awaited via asyncio.wait_for, which
        cancels it if it does not finish within ``timeout`` seconds. Failures
        and timeouts are swallowed.

        Args:
            timeout: Per-task timeout in seconds (None waits indefinitely)
        """
        logger.info("Shutting down TeamCoordinator...")
        cancel_tasks = [
            asyncio.create_task(asyncio.wait_for(task, timeout=timeout))
            for task in self._running_tasks.values()
        ]
        if cancel_tasks:
            await asyncio.gather(*cancel_tasks, return_exceptions=True)
        self._running_tasks.clear()
        logger.info("TeamCoordinator shutdown complete")

    @property
    def agents(self) -> Dict[str, Any]:
        """Get copy of registered agents dict."""
        return dict(self._agents)


__all__ = ["TeamCoordinator"]