Initial commit of integrated agent system
This commit is contained in:
389
backend/agents/team/team_coordinator.py
Normal file
389
backend/agents/team/team_coordinator.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TeamCoordinator - Agent lifecycle management and execution.
|
||||
|
||||
Provides run_parallel() using asyncio.gather() and run_sequential()
|
||||
for coordinating multiple agents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamCoordinator:
|
||||
"""Coordinates agent lifecycle and execution.
|
||||
|
||||
Supports:
|
||||
- run_parallel(): Execute multiple agents concurrently with asyncio.gather()
|
||||
- run_sequential(): Execute agents one after another
|
||||
- run_phase(): Execute a named phase with registered agents
|
||||
- register_agent(): Add agent to coordinator
|
||||
- unregister_agent(): Remove agent from coordinator
|
||||
|
||||
Each agent maintains separate context/memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: Optional[List[Any]] = None,
|
||||
task_content: Optional[str] = None,
|
||||
messenger: Optional[Any] = None,
|
||||
registry: Optional[Any] = None,
|
||||
):
|
||||
"""Initialize TeamCoordinator.
|
||||
|
||||
Args:
|
||||
participants: List of agent instances to coordinate
|
||||
task_content: Task description content for the agents
|
||||
messenger: AgentMessenger for communication (optional)
|
||||
registry: AgentRegistry for agent lookup (optional)
|
||||
"""
|
||||
self._participants = participants or []
|
||||
self._task_content = task_content or ""
|
||||
self._messenger = messenger
|
||||
self._registry = registry
|
||||
self._agents: Dict[str, Any] = {}
|
||||
self._running_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Auto-register participants
|
||||
for agent in self._participants:
|
||||
if hasattr(agent, "name"):
|
||||
self._agents[agent.name] = agent
|
||||
elif hasattr(agent, "id"):
|
||||
self._agents[agent.id] = agent
|
||||
|
||||
def register_agent(self, agent_id: str, agent: Any) -> None:
|
||||
"""Register an agent with the coordinator.
|
||||
|
||||
Args:
|
||||
agent_id: Unique agent identifier
|
||||
agent: Agent instance
|
||||
"""
|
||||
self._agents[agent_id] = agent
|
||||
logger.info("Registered agent: %s", agent_id)
|
||||
|
||||
def unregister_agent(self, agent_id: str) -> None:
|
||||
"""Unregister an agent from the coordinator.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to remove
|
||||
"""
|
||||
if agent_id in self._agents:
|
||||
del self._agents[agent_id]
|
||||
logger.info("Unregistered agent: %s", agent_id)
|
||||
|
||||
def get_agent(self, agent_id: str) -> Any:
|
||||
"""Get registered agent by ID.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Agent instance
|
||||
"""
|
||||
return self._agents.get(agent_id)
|
||||
|
||||
def list_agents(self) -> List[str]:
|
||||
"""List all registered agent IDs.
|
||||
|
||||
Returns:
|
||||
List of agent identifiers
|
||||
"""
|
||||
return list(self._agents.keys())
|
||||
|
||||
async def run_parallel(
|
||||
self,
|
||||
agent_ids: List[str],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run multiple agents in parallel using asyncio.gather().
|
||||
|
||||
Args:
|
||||
agent_ids: List of agent IDs to run concurrently
|
||||
initial_message: Optional initial message to broadcast
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
async def _run_agent(aid: str) -> tuple[str, Any]:
|
||||
agent = self._agents.get(aid)
|
||||
if agent is None:
|
||||
logger.error("Agent %s not found", aid)
|
||||
return (aid, None)
|
||||
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
if initial_message:
|
||||
result = await agent.reply(initial_message)
|
||||
else:
|
||||
result = await agent.reply()
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
logger.info("Agent %s completed successfully", aid)
|
||||
return (aid, result)
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed: %s", aid, e)
|
||||
return (aid, {"error": str(e)})
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[_run_agent(aid) for aid in agent_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
output: Dict[str, Any] = {}
|
||||
for result in results:
|
||||
if isinstance(result, tuple):
|
||||
agent_id, agent_result = result
|
||||
output[agent_id] = agent_result
|
||||
else:
|
||||
logger.error("Unexpected result from asyncio.gather: %s", result)
|
||||
|
||||
logger.info("Parallel run completed for %d agents", len(agent_ids))
|
||||
return output
|
||||
|
||||
async def run_sequential(
|
||||
self,
|
||||
agent_ids: List[str],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run agents one after another in order.
|
||||
|
||||
Args:
|
||||
agent_ids: List of agent IDs to run in sequence
|
||||
initial_message: Optional initial message for first agent
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
output: Dict[str, Any] = {}
|
||||
current_message = initial_message
|
||||
|
||||
for agent_id in agent_ids:
|
||||
agent = self._agents.get(agent_id)
|
||||
if agent is None:
|
||||
logger.error("Agent %s not found", agent_id)
|
||||
output[agent_id] = {"error": "Agent not found"}
|
||||
continue
|
||||
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
result = await agent.reply(current_message)
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
|
||||
output[agent_id] = result
|
||||
current_message = result
|
||||
logger.info("Agent %s completed sequentially", agent_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed: %s", agent_id, e)
|
||||
output[agent_id] = {"error": str(e)}
|
||||
break
|
||||
|
||||
logger.info("Sequential run completed for %d agents", len(agent_ids))
|
||||
return output
|
||||
|
||||
async def run_phase(
|
||||
self,
|
||||
phase_name: str,
|
||||
agent_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Any]:
|
||||
"""Execute a named phase with registered agents.
|
||||
|
||||
Args:
|
||||
phase_name: Name of the phase (e.g., "analyst_analysis")
|
||||
agent_ids: Optional list of agent IDs; if None, uses all registered
|
||||
metadata: Optional metadata to include in the message (e.g., tickers, date)
|
||||
|
||||
Returns:
|
||||
List of results from each agent
|
||||
"""
|
||||
if agent_ids is None:
|
||||
agent_ids = list(self._agents.keys())
|
||||
|
||||
_agent_ids = [aid for aid in agent_ids if aid in self._agents]
|
||||
|
||||
logger.info(
|
||||
"Running phase '%s' with %d agents: %s",
|
||||
phase_name,
|
||||
len(_agent_ids),
|
||||
_agent_ids,
|
||||
)
|
||||
|
||||
# Create messages for each agent
|
||||
results: List[Any] = []
|
||||
for agent_id in _agent_ids:
|
||||
agent = self._agents[agent_id]
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
# Create a message for the agent with proper structure
|
||||
msg = Msg(
|
||||
name="system",
|
||||
content=self._task_content or f"Please execute phase: {phase_name}",
|
||||
role="user",
|
||||
metadata=metadata,
|
||||
)
|
||||
result = await agent.reply(msg)
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
results.append(result)
|
||||
logger.info("Phase '%s': Agent %s completed", phase_name, agent_id)
|
||||
except Exception as e:
|
||||
logger.error("Phase '%s': Agent %s failed: %s", phase_name, agent_id, e)
|
||||
results.append(None)
|
||||
|
||||
logger.info("Phase '%s' completed with %d results", phase_name, len(results))
|
||||
return results
|
||||
|
||||
async def run_with_dependencies(
|
||||
self,
|
||||
agent_tasks: Dict[str, List[str]],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run agents respecting dependency graph.
|
||||
|
||||
Args:
|
||||
agent_tasks: Dict mapping agent_id to list of prerequisite agent_ids
|
||||
initial_message: Optional initial message
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
completed: Dict[str, Any] = {}
|
||||
remaining = set(agent_tasks.keys())
|
||||
|
||||
while remaining:
|
||||
ready = [
|
||||
aid for aid in remaining
|
||||
if all(dep in completed for dep in agent_tasks.get(aid, []))
|
||||
]
|
||||
|
||||
if not ready:
|
||||
logger.error("Circular dependency detected in agent tasks")
|
||||
for aid in remaining:
|
||||
completed[aid] = {"error": "Circular dependency"}
|
||||
break
|
||||
|
||||
results = await self.run_parallel(ready, initial_message)
|
||||
completed.update(results)
|
||||
|
||||
for aid in ready:
|
||||
remaining.discard(aid)
|
||||
initial_message = results.get(aid)
|
||||
|
||||
return completed
|
||||
|
||||
async def fanout_pipeline(
|
||||
self,
|
||||
agents: List[Any],
|
||||
msg: Optional[Msg] = None,
|
||||
) -> List[Msg]:
|
||||
"""Fanout a message to multiple agents concurrently and collect all responses.
|
||||
|
||||
Similar to AgentScope's fanout_pipeline, this sends the same message
|
||||
to all specified agents and returns a list of all agent responses.
|
||||
|
||||
Args:
|
||||
agents: List of agent instances to fanout the message to
|
||||
msg: Message to send to all agents (optional)
|
||||
|
||||
Returns:
|
||||
List of Msg responses from each agent (in the same order as input agents)
|
||||
|
||||
Example:
|
||||
>>> responses = await fanout_pipeline(
|
||||
... agents=[alice, bob, charlie],
|
||||
... msg=question,
|
||||
... )
|
||||
>>> # responses is a list of Msg responses from each agent
|
||||
"""
|
||||
async def _fanout_to_agent(agent: Any) -> Optional[Msg]:
|
||||
"""Send message to a single agent and return its response."""
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
result = await agent.reply(msg) if msg is not None else await agent.reply()
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
|
||||
# Convert result to Msg if needed
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, Msg):
|
||||
return result
|
||||
# If result is a dict with content, wrap it
|
||||
if isinstance(result, dict) and "content" in result:
|
||||
return Msg(
|
||||
name=getattr(agent, "name", "unknown"),
|
||||
content=result.get("content", ""),
|
||||
role="assistant",
|
||||
metadata=result.get("metadata"),
|
||||
)
|
||||
# Otherwise wrap the result
|
||||
return Msg(
|
||||
name=getattr(agent, "name", "unknown"),
|
||||
content=str(result),
|
||||
role="assistant",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed in fanout_pipeline: %s",
|
||||
getattr(agent, "name", "unknown"), e)
|
||||
return None
|
||||
|
||||
# Run all agents concurrently
|
||||
results = await asyncio.gather(
|
||||
*[_fanout_to_agent(agent) for agent in agents],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Filter out exceptions and keep only valid responses
|
||||
responses: List[Msg] = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error("Fanout to agent %d failed: %s", i, result)
|
||||
responses.append(None) # type: ignore[arg-type]
|
||||
else:
|
||||
responses.append(result) # type: ignore[arg-type]
|
||||
|
||||
logger.info("Fanout pipeline completed for %d agents", len(agents))
|
||||
return responses
|
||||
|
||||
async def shutdown(self, timeout: Optional[float] = 5.0) -> None:
|
||||
"""Shutdown all running agents gracefully.
|
||||
|
||||
Args:
|
||||
timeout: Timeout for graceful shutdown
|
||||
"""
|
||||
logger.info("Shutting down TeamCoordinator...")
|
||||
|
||||
cancel_tasks = [
|
||||
asyncio.create_task(asyncio.wait_for(task, timeout=timeout))
|
||||
for task in self._running_tasks.values()
|
||||
]
|
||||
|
||||
if cancel_tasks:
|
||||
await asyncio.gather(*cancel_tasks, return_exceptions=True)
|
||||
|
||||
self._running_tasks.clear()
|
||||
logger.info("TeamCoordinator shutdown complete")
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, Any]:
|
||||
"""Get copy of registered agents dict."""
|
||||
return dict(self._agents)
|
||||
|
||||
|
||||
__all__ = ["TeamCoordinator"]
|
||||
Reference in New Issue
Block a user