Files
evotraders/backend/agents/base/hooks.py

614 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# -*- coding: utf-8 -*-
"""Hook system for EvoAgent.

Provides pre_reasoning and post_acting hooks with built-in implementations:
- BootstrapHook: First-time setup guidance
- WorkspaceWatchHook: Rebuilds the system prompt when workspace markdown
  files change
- MemoryCompactionHook: Automatic memory compression

Based on CoPaw's hooks design.
"""
from __future__ import annotations

import inspect
import logging
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from agentscope.agent import ReActAgent
logger = logging.getLogger(name=__name__)

# A hook type is a plain string tag naming a lifecycle point.
HookType = str

# Tag for hooks registered to run before the agent's reasoning step.
HOOK_PRE_REASONING: HookType = "pre_reasoning"
# Tag for hooks registered to run after the agent's acting step.
HOOK_POST_ACTING: HookType = "post_acting"
class Hook(ABC):
    """Interface implemented by every agent lifecycle hook.

    Subclasses supply an async ``__call__`` that receives the agent and
    the keyword arguments of the method being hooked.
    """

    @abstractmethod
    async def __call__(
        self,
        agent: "ReActAgent",
        kwargs: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """Run the hook.

        Args:
            agent: Agent whose method is being hooked.
            kwargs: Keyword arguments destined for the hooked method.

        Returns:
            A replacement kwargs dict, or ``None`` to keep the originals.
        """
        ...
class HookManager:
    """Manages agent hooks.

    Provides registration and execution of hooks for different lifecycle
    events in the agent's operation. Hooks are grouped by hook type
    (``HOOK_PRE_REASONING`` / ``HOOK_POST_ACTING``) and executed in
    registration order; a hook may return a replacement kwargs dict that is
    handed to the next hook.
    """

    def __init__(self):
        # hook type -> ordered list of (hook_name, hook) pairs
        self._hooks: Dict[HookType, List[tuple[str, Hook]]] = {
            HOOK_PRE_REASONING: [],
            HOOK_POST_ACTING: [],
        }

    def register(
        self,
        hook_type: HookType,
        hook_name: str,
        hook: Hook | Callable,
    ) -> None:
        """Register a hook, replacing any existing hook with the same name.

        A re-registered name moves to the end of the execution order.

        Args:
            hook_type: Type of hook (pre_reasoning, post_acting)
            hook_name: Unique name for this hook
            hook: Hook instance, async callable, or plain (sync) callable
                taking ``(agent, kwargs)``

        Raises:
            KeyError: If ``hook_type`` is not a known hook type.
        """
        if hook_type not in self._hooks:
            raise KeyError(
                f"Unknown hook type {hook_type!r}; "
                f"expected one of {sorted(self._hooks)}"
            )
        # Lists are rebuilt rather than mutated so that an execute() loop
        # currently iterating the old list is unaffected.
        self._hooks[hook_type] = [
            (name, h) for name, h in self._hooks[hook_type] if name != hook_name
        ]
        self._hooks[hook_type].append((hook_name, hook))
        logger.debug("Registered hook '%s' for type '%s'", hook_name, hook_type)

    def unregister(self, hook_type: HookType, hook_name: str) -> bool:
        """Unregister a hook.

        Args:
            hook_type: Type of hook
            hook_name: Name of the hook to remove

        Returns:
            True if the hook was found and removed; False otherwise
            (including when ``hook_type`` is unknown).
        """
        hooks = self._hooks.get(hook_type)
        if not hooks:
            return False
        remaining = [(name, h) for name, h in hooks if name != hook_name]
        if len(remaining) == len(hooks):
            return False
        self._hooks[hook_type] = remaining
        logger.debug("Unregistered hook '%s' from type '%s'", hook_name, hook_type)
        return True

    async def execute(
        self,
        hook_type: HookType,
        agent: "ReActAgent",
        kwargs: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Execute all hooks of a given type in registration order.

        Both async and sync hooks are supported: the return value of each
        hook is awaited only if it is awaitable. A failing hook is logged
        and skipped; the remaining hooks still run.

        Args:
            hook_type: Type of hooks to execute (unknown types run nothing)
            agent: The agent instance
            kwargs: Input arguments

        Returns:
            Potentially modified kwargs
        """
        for name, hook in self._hooks.get(hook_type, ()):
            try:
                result = hook(agent, kwargs)
                # register() accepts plain callables, so only await when the
                # hook actually returned an awaitable.
                if inspect.isawaitable(result):
                    result = await result
            except Exception as e:
                logger.error("Hook '%s' failed: %s", name, e, exc_info=True)
                continue
            if result is not None:
                kwargs = result
        return kwargs

    def list_hooks(self, hook_type: Optional[HookType] = None) -> List[str]:
        """List registered hook names.

        Args:
            hook_type: Optional type to filter by

        Returns:
            List of hook names, in execution order per type
        """
        if hook_type:
            return [name for name, _ in self._hooks.get(hook_type, [])]
        return [name for hooks in self._hooks.values() for name, _ in hooks]
class BootstrapHook(Hook):
    """Hook for bootstrap guidance on first user interaction.

    This hook looks for a BOOTSTRAP.md file in the working directory and, if
    found, prepends guidance to the first user message in the agent's memory
    to help establish the agent's identity and user preferences. A
    ``.bootstrap_completed`` flag file is written once the guidance has
    actually been attached, so it is delivered at most once.
    """

    def __init__(
        self,
        workspace_dir: Path,
        language: str = "zh",
    ):
        """Initialize bootstrap hook.

        Args:
            workspace_dir: Working directory containing BOOTSTRAP.md
            language: Language code for bootstrap guidance ("zh" selects
                Chinese; any other value selects English)
        """
        self.workspace_dir = Path(workspace_dir)
        self.language = language
        self._completed_flag = self.workspace_dir / ".bootstrap_completed"

    def _is_first_user_interaction(self, agent: "ReActAgent") -> bool:
        """Check if this is the first user interaction.

        Args:
            agent: The agent instance

        Returns:
            True if the agent has no memory yet, or at most one user message
        """
        if not hasattr(agent, "memory") or not agent.memory.content:
            return True
        # memory.content holds (msg, marks) pairs
        user_count = sum(
            1 for msg, _ in agent.memory.content if msg.role == "user"
        )
        return user_count <= 1

    def _build_bootstrap_guidance(self) -> str:
        """Build bootstrap guidance message.

        Returns:
            Formatted bootstrap guidance in the configured language
        """
        if self.language == "zh":
            return (
                "# 引导模式\n"
                "\n"
                "工作目录中存在 `BOOTSTRAP.md` — 首次设置。\n"
                "\n"
                "1. 阅读 BOOTSTRAP.md,友好地表示初次见面,"
                "引导用户完成设置。\n"
                "2. 按照 BOOTSTRAP.md 的指示,"
                "帮助用户定义你的身份和偏好。\n"
                "3. 按指南创建/更新必要文件"
                "(PROFILE.md、MEMORY.md 等)。\n"
                "4. 完成后删除 BOOTSTRAP.md。\n"
                "\n"
                "如果用户希望跳过,直接回答下面的问题即可。\n"
                "\n"
                "---\n"
                "\n"
            )
        return (
            "# BOOTSTRAP MODE\n"
            "\n"
            "`BOOTSTRAP.md` exists — first-time setup.\n"
            "\n"
            "1. Read BOOTSTRAP.md, greet the user, "
            "and guide them through setup.\n"
            "2. Follow BOOTSTRAP.md instructions "
            "to define identity and preferences.\n"
            "3. Create/update files "
            "(PROFILE.md, MEMORY.md, etc.) as described.\n"
            "4. Delete BOOTSTRAP.md when done.\n"
            "\n"
            "If the user wants to skip, answer their "
            "question directly instead.\n"
            "\n"
            "---\n"
            "\n"
        )

    @staticmethod
    def _prepend_guidance(agent: "ReActAgent", guidance: str) -> bool:
        """Prepend ``guidance`` to the first user message in agent memory.

        Scans all memory entries (system messages are not assumed to be
        contiguous at the front).

        Returns:
            True if the guidance was attached, False if there was no user
            message or its content had an unsupported type.
        """
        if not hasattr(agent, "memory") or not agent.memory.content:
            return False
        for msg, _ in agent.memory.content:
            if msg.role != "user":
                continue
            if isinstance(msg.content, str):
                msg.content = guidance + msg.content
                return True
            if isinstance(msg.content, list):
                # NOTE(review): assumes list content is a sequence of
                # AgentScope-style content-block dicts; the guidance is added
                # as a leading {"type": "text"} block — confirm.
                msg.content = [{"type": "text", "text": guidance}, *msg.content]
                return True
            return False
        return False

    async def __call__(
        self,
        agent: "ReActAgent",
        kwargs: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """Check and load BOOTSTRAP.md on first user interaction.

        The completion flag is only written after the guidance was actually
        attached, so a call with nothing to attach to (e.g. empty memory)
        does not consume the one-time bootstrap.

        Args:
            agent: The agent instance
            kwargs: Input arguments to the _reasoning method

        Returns:
            None (hook doesn't modify kwargs)
        """
        try:
            bootstrap_path = self.workspace_dir / "BOOTSTRAP.md"
            # Check if bootstrap has already been triggered
            if self._completed_flag.exists():
                return None
            if not bootstrap_path.exists():
                return None
            if not self._is_first_user_interaction(agent):
                return None

            bootstrap_guidance = self._build_bootstrap_guidance()
            logger.debug("Found BOOTSTRAP.md [%s], prepending guidance", self.language)

            if not self._prepend_guidance(agent, bootstrap_guidance):
                # Nothing to attach to yet; retry on a later call instead of
                # marking bootstrap as completed.
                return None
            logger.debug("Bootstrap guidance prepended to first user message")

            # Create completion flag to prevent repeated triggering
            self._completed_flag.touch()
            logger.debug("Created bootstrap completion flag")
        except Exception as e:
            logger.error("Failed to process bootstrap: %s", e, exc_info=True)
        return None
class WorkspaceWatchHook(Hook):
    """Hook for auto-reloading workspace markdown files on change.

    Monitors SOUL.md, AGENTS.md, PROFILE.md, etc. and triggers a prompt
    rebuild when any of them is added, modified, or deleted. Files are
    re-scanned at most once every ``poll_interval`` seconds. Based on
    CoPaw's AgentConfigWatcher approach but for markdown files.
    """

    # Files to monitor (same as PromptBuilder.DEFAULT_FILES)
    WATCHED_FILES = frozenset([
        "SOUL.md", "AGENTS.md", "PROFILE.md",
        "POLICY.md", "MEMORY.md",
        "BOOTSTRAP.md",
    ])

    def __init__(
        self,
        workspace_dir: Path,
        poll_interval: float = 2.0,
    ):
        """Initialize workspace watch hook.

        Args:
            workspace_dir: Workspace directory to monitor
            poll_interval: Minimum seconds between file scans; a value
                <= 0 scans on every call
        """
        self.workspace_dir = Path(workspace_dir)
        self.poll_interval = poll_interval
        self._last_mtimes: dict[str, float] = {}
        self._initialized = False
        # time.monotonic() of the most recent scan; None before the first.
        self._last_scan: Optional[float] = None

    def _scan_mtimes(self) -> dict[str, float]:
        """Scan watched files and return their current mtimes.

        Missing files are omitted. stat() is attempted directly (rather than
        exists() followed by stat()) so a file deleted mid-scan is simply
        treated as missing instead of aborting the scan.
        """
        mtimes: dict[str, float] = {}
        for name in self.WATCHED_FILES:
            try:
                mtimes[name] = (self.workspace_dir / name).stat().st_mtime
            except (FileNotFoundError, NotADirectoryError):
                continue
        return mtimes

    def _poll_due(self) -> bool:
        """Return True (and record the scan time) if a scan is due."""
        now = time.monotonic()
        if (
            self._last_scan is not None
            and self.poll_interval > 0
            and now - self._last_scan < self.poll_interval
        ):
            return False
        self._last_scan = now
        return True

    def _has_changes(self) -> bool:
        """Check if any watched file changed since the last check.

        The first call only records a baseline and returns False. Adding,
        deleting, or modifying a watched file all count as a change.
        """
        current = self._scan_mtimes()
        if not self._initialized:
            self._last_mtimes = current
            self._initialized = True
            return False
        # Dict equality covers new, deleted, and modified files at once.
        if current == self._last_mtimes:
            return False
        self._last_mtimes = current
        return True

    async def __call__(
        self,
        agent: "ReActAgent",
        kwargs: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """Check for file changes and rebuild the prompt if needed.

        Args:
            agent: The agent instance; its ``rebuild_sys_prompt`` (sync or
                async) is called when a change is detected
            kwargs: Input arguments (unused)

        Returns:
            None
        """
        try:
            if not self._poll_due() or not self._has_changes():
                return None
            agent_id = getattr(agent, "agent_id", "unknown")
            logger.info(
                "Workspace files changed, triggering prompt rebuild for: %s",
                agent_id,
            )
            rebuild = getattr(agent, "rebuild_sys_prompt", None)
            if rebuild is None:
                logger.warning(
                    "Agent %s has no rebuild_sys_prompt method",
                    agent_id,
                )
                return None
            result = rebuild()
            # Support agents whose rebuild_sys_prompt is a coroutine function.
            if inspect.isawaitable(result):
                await result
        except Exception as e:
            logger.error("Workspace watch hook failed: %s", e, exc_info=True)
        return None
class MemoryCompactionHook(Hook):
    """Hook for automatic memory compaction when context is full.

    This hook estimates the token count of the agent's memory and triggers
    compaction when it reaches ``memory_compact_threshold``. System messages
    and the most recent messages are kept; older conversation history is
    passed to ``memory_manager.compact_memory`` for summarization and then
    marked as compressed (when the memory supports marks).

    Based on CoPaw's memory compaction design with additional options:
    - memory_compact_ratio: Fraction of the current token count to keep
    - memory_reserve_ratio: Fraction of the current token count that a
      compaction always frees
    - enable_tool_result_compact: Compact tool results separately
    - tool_result_compact_keep_n: Number of tool results to keep
    """

    def __init__(
        self,
        memory_manager: Any,
        memory_compact_threshold: Optional[int] = None,
        memory_compact_ratio: float = 0.75,
        memory_reserve_ratio: float = 0.1,
        enable_tool_result_compact: bool = False,
        tool_result_compact_keep_n: int = 5,
    ):
        """Initialize memory compaction hook.

        Args:
            memory_manager: Memory manager instance for compaction
            memory_compact_threshold: Token threshold for compaction
                (None disables compaction)
            memory_compact_ratio: Target ratio to compact to
                (e.g., 0.75 = keep ~75% of the current tokens)
            memory_reserve_ratio: Ratio always freed by a compaction
                (e.g., 0.1 = keep at most ~90%)
            enable_tool_result_compact: Enable tool result compaction
            tool_result_compact_keep_n: Number of tool results to keep
        """
        self.memory_manager = memory_manager
        self.memory_compact_threshold = memory_compact_threshold
        self.memory_compact_ratio = memory_compact_ratio
        self.memory_reserve_ratio = memory_reserve_ratio
        self.enable_tool_result_compact = enable_tool_result_compact
        self.tool_result_compact_keep_n = tool_result_compact_keep_n

    async def __call__(
        self,
        agent: "ReActAgent",
        kwargs: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """Pre-reasoning hook to check and compact memory if needed.

        Args:
            agent: The agent instance
            kwargs: Input arguments to the _reasoning method

        Returns:
            None (hook doesn't modify kwargs)
        """
        try:
            if not hasattr(agent, "memory") or not self.memory_manager:
                return None
            # Cheap check first: no threshold means compaction is disabled.
            if self.memory_compact_threshold is None:
                return None

            messages = await agent.memory.get_memory()
            total_tokens = self._estimate_tokens(messages)
            if total_tokens < self.memory_compact_threshold:
                return None

            logger.info(
                "Memory compaction triggered: %d tokens (threshold: %d)",
                total_tokens,
                self.memory_compact_threshold,
            )
            await self._compact_memory(agent, messages)
        except Exception as e:
            logger.error("Failed to compact memory: %s", e, exc_info=True)
        return None

    def _estimate_tokens(self, messages: List[Any]) -> int:
        """Estimate token count for messages.

        Args:
            messages: List of messages

        Returns:
            Estimated token count (~4 characters of str(content) per token)
        """
        total_chars = sum(
            len(str(getattr(msg, "content", "")))
            for msg in messages
        )
        return total_chars // 4

    def _select_messages_to_compact(
        self,
        messages: List[Any],
        target_tokens: int,
    ) -> List[Any]:
        """Return the older non-system messages that exceed the token budget.

        Walks from newest to oldest, keeping messages while they fit within
        ``target_tokens``; every non-system message older than the first one
        that does not fit is selected, so the kept window is contiguous.
        System messages are never selected and count against the budget.
        The newest non-system message is always kept.
        """
        def is_system(msg: Any) -> bool:
            return getattr(msg, "role", None) == "system"

        kept_tokens = sum(
            self._estimate_tokens([msg]) for msg in messages if is_system(msg)
        )
        cutoff = len(messages)  # messages[cutoff:] form the kept window
        keep_unconditionally = True
        for idx in range(len(messages) - 1, -1, -1):
            msg = messages[idx]
            if is_system(msg):
                continue
            msg_tokens = self._estimate_tokens([msg])
            if not keep_unconditionally and kept_tokens + msg_tokens > target_tokens:
                break
            kept_tokens += msg_tokens
            cutoff = idx
            keep_unconditionally = False
        return [msg for msg in messages[:cutoff] if not is_system(msg)]

    async def _compact_memory(
        self,
        agent: "ReActAgent",
        messages: List[Any],
    ) -> None:
        """Compact memory by summarizing older messages.

        Uses CoPaw-style memory management:
        - memory_compact_ratio: keep about this fraction of current tokens
        - memory_reserve_ratio: always free at least this fraction, i.e. the
          kept budget is ``min(compact_ratio, 1 - reserve_ratio) * total``

        Args:
            agent: The agent instance
            messages: Current messages in memory
        """
        if self.memory_compact_threshold is None:
            return

        total_tokens = self._estimate_tokens(messages)
        reserve_tokens = int(total_tokens * self.memory_reserve_ratio)
        target_tokens = int(total_tokens * self.memory_compact_ratio)
        # The reserve caps what may remain after compaction. (Using max()
        # here would let the reserve override memory_compact_ratio, keeping
        # ~90% with the defaults instead of the documented 75%.)
        target_tokens = min(target_tokens, total_tokens - reserve_tokens)

        messages_to_compact = self._select_messages_to_compact(
            messages, target_tokens,
        )
        if not messages_to_compact:
            return

        logger.info(
            "Compacting %d messages (%d tokens) to target %d tokens",
            len(messages_to_compact),
            self._estimate_tokens(messages_to_compact),
            target_tokens,
        )

        # Use memory manager to compact if available
        if hasattr(self.memory_manager, "compact_memory"):
            try:
                summary = await self.memory_manager.compact_memory(
                    messages=messages_to_compact,
                )
                # str() so a non-string summary cannot break logging and skip
                # the marking step below.
                logger.info(
                    "Memory compacted: %d messages summarized, summary: %s",
                    len(messages_to_compact),
                    str(summary)[:200] if summary else "N/A",
                )
                # Mark messages as compressed if supported
                if hasattr(agent.memory, "update_messages_mark"):
                    from agentscope.agent._react_agent import _MemoryMark
                    await agent.memory.update_messages_mark(
                        new_mark=_MemoryMark.COMPRESSED,
                        msg_ids=[msg.id for msg in messages_to_compact],
                    )
            except Exception as e:
                logger.error("Memory manager compaction failed: %s", e)

        # Tool result compaction (CoPaw-style)
        if self.enable_tool_result_compact:
            await self._compact_tool_results(agent, messages)

    async def _compact_tool_results(
        self,
        agent: "ReActAgent",
        messages: List[Any],
    ) -> None:
        """Compact tool results by keeping only the N most recent ones.

        Based on CoPaw's tool_result_compact_keep_n pattern. Reads
        ``agent.memory.content`` directly; ``messages`` is currently unused
        and kept for interface compatibility.

        Args:
            agent: The agent instance
            messages: Current messages in memory (unused)
        """
        if not hasattr(agent.memory, "content"):
            return

        # Clamp so keep_n=0 means "keep none" (a [:-0] slice would select
        # nothing) and negative values behave like 0.
        keep_n = max(0, self.tool_result_compact_keep_n)

        # NOTE(review): tool results are identified by role == "tool"; confirm
        # this matches how the memory backend stores tool results.
        tool_results = [
            msg for msg, _ in agent.memory.content
            if getattr(msg, "role", None) == "tool"
        ]
        if len(tool_results) <= keep_n:
            return

        excess_results = tool_results[:len(tool_results) - keep_n]
        logger.info(
            "Tool result compaction: %d tool results found, keeping %d, compacting %d",
            len(tool_results),
            keep_n,
            len(excess_results),
        )

        # Mark excess tool results as compressed if supported
        if hasattr(agent.memory, "update_messages_mark"):
            from agentscope.agent._react_agent import _MemoryMark
            await agent.memory.update_messages_mark(
                new_mark=_MemoryMark.COMPRESSED,
                msg_ids=[msg.id for msg in excess_results],
            )
# Public API exported by ``from ... import *``.
__all__ = [
    # Hook interface and registry
    "Hook",
    "HookManager",
    # Hook type identifiers
    "HookType",
    "HOOK_PRE_REASONING",
    "HOOK_POST_ACTING",
    # Built-in hook implementations
    "BootstrapHook",
    "MemoryCompactionHook",
    "WorkspaceWatchHook",
]