feat: Add agent workspace system and runtime management
- Add agent core modules (agent_core, factory, registry, skill_loader) - Add runtime system for agent execution management - Add REST API for agents, workspaces, and runtime control - Add process supervisor for agent lifecycle management - Add workspace template system with agent profiles - Add frontend RuntimeView and runtime API integration - Add per-agent skill workspaces for smoke_fullstack run - Refactor skill system with active/installed separation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
23
backend/agents/base/__init__.py
Normal file
23
backend/agents/base/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Base agent module for EvoTraders.
|
||||
|
||||
提供Agent基础类、命令处理、工具守卫和钩子管理等功能。
|
||||
"""
|
||||
|
||||
# 命令处理器 (从command_handler.py导入)
|
||||
from .command_handler import (
|
||||
AgentCommandDispatcher,
|
||||
CommandContext,
|
||||
CommandHandler,
|
||||
CommandResult,
|
||||
create_command_dispatcher,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 命令处理
|
||||
"AgentCommandDispatcher",
|
||||
"CommandContext",
|
||||
"CommandHandler",
|
||||
"CommandResult",
|
||||
"create_command_dispatcher",
|
||||
]
|
||||
543
backend/agents/base/command_handler.py
Normal file
543
backend/agents/base/command_handler.py
Normal file
@@ -0,0 +1,543 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Agent command handler for system commands.
|
||||
|
||||
This module handles system commands like /save, /compact, /skills, /reload, etc.
|
||||
参考CoPaw设计,为EvoAgent提供命令处理能力。
|
||||
"""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent import EvoAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandResult:
|
||||
"""命令执行结果"""
|
||||
success: bool
|
||||
message: str
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class CommandContext:
|
||||
"""命令执行上下文"""
|
||||
|
||||
def __init__(self, agent: "EvoAgent", raw_query: str, args: str = ""):
|
||||
self.agent = agent
|
||||
self.raw_query = raw_query
|
||||
self.args = args
|
||||
self.config_name = getattr(agent, "config_name", "default")
|
||||
self.agent_id = getattr(agent, "agent_id", "unknown")
|
||||
|
||||
|
||||
class CommandHandler(ABC):
|
||||
"""命令处理器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
"""处理命令"""
|
||||
pass
|
||||
|
||||
|
||||
class SaveCommandHandler(CommandHandler):
|
||||
"""处理 /save <message> 命令 - 保存内容到MEMORY.md"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
message = ctx.args.strip()
|
||||
if not message:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="Usage: /save <message>\n请提供要保存的内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
memory_path = self._get_memory_path(ctx)
|
||||
memory_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = self._get_timestamp()
|
||||
entry = f"\n## {timestamp}\n\n{message}\n"
|
||||
|
||||
with open(memory_path, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 内容已保存到 MEMORY.md\n- 路径: {memory_path}\n- 长度: {len(message)} 字符",
|
||||
data={"path": str(memory_path), "length": len(message)}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save to MEMORY.md: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 保存失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _get_memory_path(self, ctx: CommandContext) -> Path:
|
||||
"""获取MEMORY.md路径"""
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
asset_dir = sm.get_agent_asset_dir(ctx.config_name, ctx.agent_id)
|
||||
return asset_dir / "MEMORY.md"
|
||||
|
||||
def _get_timestamp(self) -> str:
|
||||
"""获取当前时间戳"""
|
||||
from datetime import datetime
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
class CompactCommandHandler(CommandHandler):
|
||||
"""处理 /compact 命令 - 压缩记忆"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
try:
|
||||
agent = ctx.agent
|
||||
memory_manager = getattr(agent, "memory_manager", None)
|
||||
|
||||
if memory_manager is None:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="❌ Memory Manager 未启用\n\n- 记忆压缩功能不可用\n- 请在配置中启用 memory_manager"
|
||||
)
|
||||
|
||||
messages = await self._get_messages(agent)
|
||||
if not messages:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="⚠️ 没有可压缩的消息\n\n- 当前记忆为空\n- 无需执行压缩"
|
||||
)
|
||||
|
||||
compact_content = await memory_manager.compact_memory(messages)
|
||||
await self._update_compressed_summary(agent, compact_content)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 记忆压缩完成\n\n- 压缩了 {len(messages)} 条消息\n- 摘要长度: {len(compact_content)} 字符",
|
||||
data={"message_count": len(messages), "summary_length": len(compact_content)}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compact memory: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 压缩失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def _get_messages(self, agent: "EvoAgent") -> List[Any]:
|
||||
"""获取Agent的记忆消息"""
|
||||
memory = getattr(agent, "memory", None)
|
||||
if memory is None:
|
||||
return []
|
||||
return await memory.get_memory() if hasattr(memory, "get_memory") else []
|
||||
|
||||
async def _update_compressed_summary(self, agent: "EvoAgent", content: str) -> None:
|
||||
"""更新压缩摘要"""
|
||||
memory = getattr(agent, "memory", None)
|
||||
if memory and hasattr(memory, "update_compressed_summary"):
|
||||
await memory.update_compressed_summary(content)
|
||||
|
||||
|
||||
class SkillsListCommandHandler(CommandHandler):
|
||||
"""处理 /skills list 命令 - 列出已激活技能"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
try:
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
|
||||
active_skills = sm.list_active_skill_metadata(ctx.config_name, ctx.agent_id)
|
||||
catalog = sm.list_agent_skill_catalog(ctx.config_name, ctx.agent_id)
|
||||
|
||||
lines = ["📋 技能列表", ""]
|
||||
|
||||
if active_skills:
|
||||
lines.append("✅ 已激活技能:")
|
||||
for skill in active_skills:
|
||||
lines.append(f" • {skill.name} - {skill.description[:50]}...")
|
||||
else:
|
||||
lines.append("⚠️ 当前没有激活的技能")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"📚 可用技能总数: {len(catalog)}")
|
||||
lines.append("💡 使用 /skills enable <name> 启用技能")
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message="\n".join(lines),
|
||||
data={
|
||||
"active_count": len(active_skills),
|
||||
"catalog_count": len(catalog),
|
||||
"active": [s.skill_name for s in active_skills]
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list skills: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 获取技能列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SkillsEnableCommandHandler(CommandHandler):
|
||||
"""处理 /skills enable <name> 命令 - 启用技能"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
skill_name = ctx.args.strip()
|
||||
if not skill_name:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="Usage: /skills enable <skill_name>\n请提供技能名称。"
|
||||
)
|
||||
|
||||
try:
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
|
||||
result = sm.update_agent_skill_overrides(
|
||||
ctx.config_name,
|
||||
ctx.agent_id,
|
||||
enable=[skill_name]
|
||||
)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 技能已启用: {skill_name}\n\n已启用技能: {', '.join(result['enabled_skills'])}",
|
||||
data=result
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to enable skill: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 启用技能失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SkillsDisableCommandHandler(CommandHandler):
|
||||
"""处理 /skills disable <name> 命令 - 禁用技能"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
skill_name = ctx.args.strip()
|
||||
if not skill_name:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="Usage: /skills disable <skill_name>\n请提供技能名称。"
|
||||
)
|
||||
|
||||
try:
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
|
||||
result = sm.update_agent_skill_overrides(
|
||||
ctx.config_name,
|
||||
ctx.agent_id,
|
||||
disable=[skill_name]
|
||||
)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 技能已禁用: {skill_name}\n\n已禁用技能: {', '.join(result['disabled_skills'])}",
|
||||
data=result
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to disable skill: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 禁用技能失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SkillsInstallCommandHandler(CommandHandler):
|
||||
"""处理 /skills install <name> 命令 - 安装技能"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
skill_name = ctx.args.strip()
|
||||
if not skill_name:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="Usage: /skills install <skill_name>\n请提供技能名称。"
|
||||
)
|
||||
|
||||
try:
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.skill_loader import load_skill_from_dir
|
||||
sm = SkillsManager()
|
||||
|
||||
# 查找技能源目录
|
||||
source_dir = self._resolve_skill_source(sm, skill_name)
|
||||
if not source_dir:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 技能未找到: {skill_name}\n\n请检查技能名称是否正确,或技能是否存在于 builtin/customized 目录。"
|
||||
)
|
||||
|
||||
# 加载并验证技能
|
||||
skill_info = load_skill_from_dir(source_dir)
|
||||
if not skill_info:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 技能加载失败: {skill_name}\n\n技能格式可能不正确。"
|
||||
)
|
||||
|
||||
# 安装到agent的installed目录
|
||||
installed_root = sm.get_agent_installed_root(ctx.config_name, ctx.agent_id)
|
||||
target_dir = installed_root / skill_name
|
||||
|
||||
import shutil
|
||||
if target_dir.exists():
|
||||
shutil.rmtree(target_dir)
|
||||
shutil.copytree(source_dir, target_dir)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 技能已安装: {skill_name}\n\n- 名称: {skill_info.get('name', skill_name)}\n- 版本: {skill_info.get('version', 'unknown')}\n- 路径: {target_dir}",
|
||||
data={"skill_name": skill_name, "target_dir": str(target_dir)}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to install skill: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 安装技能失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _resolve_skill_source(self, sm: "SkillsManager", skill_name: str) -> Optional[Path]:
|
||||
"""解析技能源目录"""
|
||||
for root in [sm.customized_root, sm.builtin_root]:
|
||||
candidate = root / skill_name
|
||||
if candidate.exists() and (candidate / "SKILL.md").exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
class ReloadCommandHandler(CommandHandler):
|
||||
"""处理 /reload 命令 - 重新加载配置"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
try:
|
||||
agent = ctx.agent
|
||||
|
||||
# 重新加载配置
|
||||
if hasattr(agent, "reload_config"):
|
||||
await agent.reload_config()
|
||||
|
||||
# 重新加载技能
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
|
||||
# 刷新技能同步
|
||||
active_root = sm.get_agent_active_root(ctx.config_name, ctx.agent_id)
|
||||
if active_root.exists():
|
||||
# 清除缓存,强制重新加载
|
||||
import shutil
|
||||
for item in active_root.iterdir():
|
||||
if item.is_dir():
|
||||
shutil.rmtree(item)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message="✅ 配置已重新加载\n\n- Agent配置已刷新\n- 技能缓存已清除\n- 请重启对话以应用所有更改",
|
||||
data={"config_name": ctx.config_name, "agent_id": ctx.agent_id}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload config: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 重新加载失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class StatusCommandHandler(CommandHandler):
|
||||
"""处理 /status 命令 - 显示Agent状态"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
try:
|
||||
agent = ctx.agent
|
||||
|
||||
lines = ["📊 Agent 状态", ""]
|
||||
lines.append(f"🆔 Agent ID: {ctx.agent_id}")
|
||||
lines.append(f"⚙️ Config: {ctx.config_name}")
|
||||
|
||||
# 模型信息
|
||||
model = getattr(agent, "model", None)
|
||||
if model:
|
||||
lines.append(f"🤖 Model: {model}")
|
||||
|
||||
# 记忆状态
|
||||
memory = getattr(agent, "memory", None)
|
||||
if memory:
|
||||
msg_count = len(getattr(memory, "content", []))
|
||||
lines.append(f"💾 Memory: {msg_count} messages")
|
||||
|
||||
# 技能状态
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
active_skills = sm.list_active_skill_metadata(ctx.config_name, ctx.agent_id)
|
||||
lines.append(f"🔧 Active Skills: {len(active_skills)}")
|
||||
|
||||
# 工具组状态
|
||||
toolkit = getattr(agent, "toolkit", None)
|
||||
if toolkit:
|
||||
groups = getattr(toolkit, "tool_groups", {})
|
||||
active_groups = [name for name, g in groups.items() if getattr(g, "active", False)]
|
||||
lines.append(f"🛠️ Active Tool Groups: {', '.join(active_groups) if active_groups else 'None'}")
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message="\n".join(lines),
|
||||
data={
|
||||
"agent_id": ctx.agent_id,
|
||||
"config_name": ctx.config_name,
|
||||
"active_skills_count": len(active_skills)
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 获取状态失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class HelpCommandHandler(CommandHandler):
|
||||
"""处理 /help 命令 - 显示帮助"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
help_text = """📖 EvoAgent 命令帮助
|
||||
|
||||
可用命令:
|
||||
/save <message> - 保存内容到 MEMORY.md
|
||||
/compact - 压缩记忆
|
||||
/skills list - 列出已激活技能
|
||||
/skills enable <name> - 启用技能
|
||||
/skills disable <name>- 禁用技能
|
||||
/skills install <name>- 安装技能
|
||||
/reload - 重新加载配置
|
||||
/status - 显示Agent状态
|
||||
/help - 显示此帮助信息
|
||||
|
||||
提示:
|
||||
• 所有命令以 / 开头
|
||||
• 命令不区分大小写
|
||||
• 使用 Tab 键可自动补全命令
|
||||
"""
|
||||
return CommandResult(success=True, message=help_text)
|
||||
|
||||
|
||||
class AgentCommandDispatcher:
|
||||
"""Agent命令分发器
|
||||
|
||||
参考CoPaw的CommandHandler设计,为EvoAgent提供统一的命令处理入口。
|
||||
"""
|
||||
|
||||
# 支持的系统命令
|
||||
SYSTEM_COMMANDS = frozenset({
|
||||
"save", "compact",
|
||||
"skills", "reload",
|
||||
"status", "help"
|
||||
})
|
||||
|
||||
def __init__(self):
|
||||
self._handlers: Dict[str, CommandHandler] = {}
|
||||
self._subcommands: Dict[str, Dict[str, CommandHandler]] = {}
|
||||
self._register_default_handlers()
|
||||
|
||||
def _register_default_handlers(self) -> None:
|
||||
"""注册默认命令处理器"""
|
||||
self._handlers["save"] = SaveCommandHandler()
|
||||
self._handlers["compact"] = CompactCommandHandler()
|
||||
self._handlers["reload"] = ReloadCommandHandler()
|
||||
self._handlers["status"] = StatusCommandHandler()
|
||||
self._handlers["help"] = HelpCommandHandler()
|
||||
|
||||
# 子命令: /skills list/enable/disable/install
|
||||
self._subcommands["skills"] = {
|
||||
"list": SkillsListCommandHandler(),
|
||||
"enable": SkillsEnableCommandHandler(),
|
||||
"disable": SkillsDisableCommandHandler(),
|
||||
"install": SkillsInstallCommandHandler(),
|
||||
}
|
||||
|
||||
def is_command(self, query: str | None) -> bool:
|
||||
"""检查是否为命令
|
||||
|
||||
Args:
|
||||
query: 用户输入字符串
|
||||
|
||||
Returns:
|
||||
True 如果是系统命令
|
||||
"""
|
||||
if not isinstance(query, str) or not query.startswith("/"):
|
||||
return False
|
||||
|
||||
parts = query.strip().lstrip("/").split()
|
||||
if not parts:
|
||||
return False
|
||||
|
||||
cmd = parts[0].lower()
|
||||
|
||||
# 检查主命令
|
||||
if cmd in self.SYSTEM_COMMANDS:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def handle(self, agent: "EvoAgent", query: str) -> CommandResult:
|
||||
"""处理命令
|
||||
|
||||
Args:
|
||||
agent: EvoAgent实例
|
||||
query: 命令字符串
|
||||
|
||||
Returns:
|
||||
命令执行结果
|
||||
"""
|
||||
if not self.is_command(query):
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"未知命令: {query}\n使用 /help 查看可用命令。"
|
||||
)
|
||||
|
||||
# 解析命令和参数
|
||||
parts = query.strip().lstrip("/").split(maxsplit=1)
|
||||
cmd = parts[0].lower()
|
||||
args = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
logger.info(f"Processing command: {cmd}, args: {args}")
|
||||
|
||||
# 处理子命令 (e.g., /skills list)
|
||||
if cmd in self._subcommands:
|
||||
sub_parts = args.split(maxsplit=1)
|
||||
sub_cmd = sub_parts[0].lower() if sub_parts else ""
|
||||
sub_args = sub_parts[1] if len(sub_parts) > 1 else ""
|
||||
|
||||
handlers = self._subcommands[cmd]
|
||||
handler = handlers.get(sub_cmd)
|
||||
|
||||
if handler is None:
|
||||
available = ", ".join(handlers.keys())
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"未知子命令: {sub_cmd}\n可用子命令: {available}"
|
||||
)
|
||||
|
||||
ctx = CommandContext(agent, query, sub_args)
|
||||
return await handler.handle(ctx)
|
||||
|
||||
# 处理主命令
|
||||
handler = self._handlers.get(cmd)
|
||||
if handler is None:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"命令未实现: {cmd}"
|
||||
)
|
||||
|
||||
ctx = CommandContext(agent, query, args)
|
||||
return await handler.handle(ctx)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def create_command_dispatcher() -> AgentCommandDispatcher:
|
||||
"""创建命令分发器实例"""
|
||||
return AgentCommandDispatcher()
|
||||
411
backend/agents/base/evo_agent.py
Normal file
411
backend/agents/base/evo_agent.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""EvoAgent - Core agent implementation for EvoTraders.
|
||||
|
||||
This module provides the main EvoAgent class built on AgentScope's ReActAgent,
|
||||
with integrated tools, skills, and memory management based on CoPaw design.
|
||||
|
||||
Key features:
|
||||
- Workspace-driven configuration from Markdown files
|
||||
- Dynamic skill loading from skills/active directories
|
||||
- Tool-guard security interception
|
||||
- Hook system for extensibility
|
||||
- Runtime skill and prompt reloading
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.tool import Toolkit
|
||||
|
||||
from .tool_guard import ToolGuardMixin
|
||||
from .hooks import (
|
||||
HookManager,
|
||||
BootstrapHook,
|
||||
MemoryCompactionHook,
|
||||
HOOK_PRE_REASONING,
|
||||
)
|
||||
from ..prompts.builder import (
|
||||
PromptBuilder,
|
||||
build_system_prompt_from_workspace,
|
||||
)
|
||||
from ..agent_workspace import load_agent_workspace_config
|
||||
from ..skills_manager import SkillsManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentscope.formatter import FormatterBase
|
||||
from agentscope.model import ModelWrapperBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvoAgent(ToolGuardMixin, ReActAgent):
|
||||
"""EvoAgent with integrated tools, skills, and memory management.
|
||||
|
||||
This agent extends ReActAgent with:
|
||||
- Workspace-driven configuration from AGENTS.md/SOUL.md/PROFILE.md/etc.
|
||||
- Dynamic skill loading from skills/active directories
|
||||
- Tool-guard security interception (via ToolGuardMixin)
|
||||
- Hook system for extensibility (bootstrap, memory compaction)
|
||||
- Runtime skill and prompt reloading
|
||||
|
||||
MRO note
|
||||
~~~~~~~~
|
||||
``ToolGuardMixin`` overrides ``_acting`` and ``_reasoning`` via
|
||||
Python's MRO: EvoAgent → ToolGuardMixin → ReActAgent.
|
||||
|
||||
Example:
|
||||
agent = EvoAgent(
|
||||
agent_id="fundamentals_analyst",
|
||||
config_name="smoke_fullstack",
|
||||
workspace_dir=Path("runs/smoke_fullstack/agents/fundamentals_analyst"),
|
||||
model=model_instance,
|
||||
formatter=formatter_instance,
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
config_name: str,
|
||||
workspace_dir: Path,
|
||||
model: "ModelWrapperBase",
|
||||
formatter: "FormatterBase",
|
||||
skills_manager: Optional[SkillsManager] = None,
|
||||
sys_prompt: Optional[str] = None,
|
||||
max_iters: int = 10,
|
||||
memory: Optional[Any] = None,
|
||||
enable_tool_guard: bool = True,
|
||||
enable_bootstrap_hook: bool = True,
|
||||
enable_memory_compaction: bool = False,
|
||||
memory_manager: Optional[Any] = None,
|
||||
memory_compact_threshold: Optional[int] = None,
|
||||
env_context: Optional[str] = None,
|
||||
prompt_files: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize EvoAgent.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for this agent
|
||||
config_name: Run configuration name (e.g., "smoke_fullstack")
|
||||
workspace_dir: Agent workspace directory containing markdown files
|
||||
model: LLM model instance
|
||||
formatter: Message formatter instance
|
||||
skills_manager: Optional SkillsManager instance
|
||||
sys_prompt: Optional override for system prompt
|
||||
max_iters: Maximum reasoning-acting iterations
|
||||
memory: Optional memory instance (defaults to InMemoryMemory)
|
||||
enable_tool_guard: Enable tool-guard security interception
|
||||
enable_bootstrap_hook: Enable bootstrap guidance on first interaction
|
||||
enable_memory_compaction: Enable automatic memory compaction
|
||||
memory_manager: Optional memory manager for compaction
|
||||
memory_compact_threshold: Token threshold for memory compaction
|
||||
env_context: Optional environment context to prepend to system prompt
|
||||
prompt_files: List of markdown files to load (defaults to standard set)
|
||||
"""
|
||||
self.agent_id = agent_id
|
||||
self.config_name = config_name
|
||||
self.workspace_dir = Path(workspace_dir)
|
||||
self._skills_manager = skills_manager or SkillsManager()
|
||||
self._env_context = env_context
|
||||
self._prompt_files = prompt_files
|
||||
|
||||
# Initialize tool guard
|
||||
if enable_tool_guard:
|
||||
self._init_tool_guard()
|
||||
|
||||
# Load agent configuration from workspace
|
||||
self._agent_config = self._load_agent_config()
|
||||
|
||||
# Build or use provided system prompt
|
||||
if sys_prompt is not None:
|
||||
self._sys_prompt = sys_prompt
|
||||
else:
|
||||
self._sys_prompt = self._build_system_prompt()
|
||||
|
||||
# Create toolkit with skills
|
||||
toolkit = self._create_toolkit()
|
||||
|
||||
# Initialize hook manager
|
||||
self._hook_manager = HookManager()
|
||||
|
||||
# Initialize parent ReActAgent
|
||||
super().__init__(
|
||||
name=agent_id,
|
||||
model=model,
|
||||
sys_prompt=self._sys_prompt,
|
||||
toolkit=toolkit,
|
||||
memory=memory or InMemoryMemory(),
|
||||
formatter=formatter,
|
||||
max_iters=max_iters,
|
||||
)
|
||||
|
||||
# Register hooks
|
||||
self._register_hooks(
|
||||
enable_bootstrap=enable_bootstrap_hook,
|
||||
enable_memory_compaction=enable_memory_compaction,
|
||||
memory_manager=memory_manager,
|
||||
memory_compact_threshold=memory_compact_threshold,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"EvoAgent initialized: %s (workspace: %s)",
|
||||
agent_id,
|
||||
workspace_dir,
|
||||
)
|
||||
|
||||
def _load_agent_config(self) -> Dict[str, Any]:
|
||||
"""Load agent configuration from workspace.
|
||||
|
||||
Returns:
|
||||
Agent configuration dictionary
|
||||
"""
|
||||
config_path = self.workspace_dir / "agent.yaml"
|
||||
if config_path.exists():
|
||||
loaded = load_agent_workspace_config(config_path)
|
||||
return dict(loaded.values)
|
||||
return {}
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build system prompt from workspace markdown files.
|
||||
|
||||
Uses PromptBuilder to load and combine AGENTS.md, SOUL.md,
|
||||
PROFILE.md, and other configured files.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
"""
|
||||
prompt = build_system_prompt_from_workspace(
|
||||
workspace_dir=self.workspace_dir,
|
||||
enabled_files=self._prompt_files,
|
||||
agent_id=self.agent_id,
|
||||
extra_context=self._env_context,
|
||||
)
|
||||
return prompt
|
||||
|
||||
def _create_toolkit(self) -> Toolkit:
|
||||
"""Create and populate toolkit with agent skills.
|
||||
|
||||
Loads skills from the agent's active skills directory and
|
||||
registers them with the toolkit.
|
||||
|
||||
Returns:
|
||||
Configured Toolkit instance
|
||||
"""
|
||||
toolkit = Toolkit(
|
||||
agent_skill_instruction=(
|
||||
"<system-info>You have access to specialized skills. "
|
||||
"Each skill lives in a directory and is described by SKILL.md. "
|
||||
"Follow the skill instructions when they are relevant to the current task."
|
||||
"</system-info>"
|
||||
),
|
||||
agent_skill_template="- {name} (dir: {dir}): {description}",
|
||||
)
|
||||
|
||||
# Register skills from active directory
|
||||
active_skills_dir = self._skills_manager.get_agent_active_root(
|
||||
self.config_name,
|
||||
self.agent_id,
|
||||
)
|
||||
|
||||
if active_skills_dir.exists():
|
||||
for skill_dir in sorted(active_skills_dir.iterdir()):
|
||||
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||
try:
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
logger.debug("Registered skill: %s", skill_dir.name)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to register skill '%s': %s",
|
||||
skill_dir.name,
|
||||
e,
|
||||
)
|
||||
|
||||
return toolkit
|
||||
|
||||
def _register_hooks(
|
||||
self,
|
||||
enable_bootstrap: bool,
|
||||
enable_memory_compaction: bool,
|
||||
memory_manager: Optional[Any],
|
||||
memory_compact_threshold: Optional[int],
|
||||
) -> None:
|
||||
"""Register agent hooks.
|
||||
|
||||
Args:
|
||||
enable_bootstrap: Enable bootstrap hook
|
||||
enable_memory_compaction: Enable memory compaction hook
|
||||
memory_manager: Memory manager instance
|
||||
memory_compact_threshold: Token threshold for compaction
|
||||
"""
|
||||
# Bootstrap hook - checks BOOTSTRAP.md on first interaction
|
||||
if enable_bootstrap:
|
||||
bootstrap_hook = BootstrapHook(
|
||||
workspace_dir=self.workspace_dir,
|
||||
language="zh",
|
||||
)
|
||||
self._hook_manager.register(
|
||||
hook_type=HOOK_PRE_REASONING,
|
||||
hook_name="bootstrap",
|
||||
hook=bootstrap_hook,
|
||||
)
|
||||
logger.debug("Registered bootstrap hook")
|
||||
|
||||
# Memory compaction hook
|
||||
if enable_memory_compaction and memory_manager is not None:
|
||||
compaction_hook = MemoryCompactionHook(
|
||||
memory_manager=memory_manager,
|
||||
memory_compact_threshold=memory_compact_threshold,
|
||||
)
|
||||
self._hook_manager.register(
|
||||
hook_type=HOOK_PRE_REASONING,
|
||||
hook_name="memory_compaction",
|
||||
hook=compaction_hook,
|
||||
)
|
||||
logger.debug("Registered memory compaction hook")
|
||||
|
||||
async def _reasoning(self, **kwargs) -> Msg:
|
||||
"""Override reasoning to execute pre-reasoning hooks.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments for reasoning
|
||||
|
||||
Returns:
|
||||
Response message
|
||||
"""
|
||||
# Execute pre-reasoning hooks
|
||||
kwargs = await self._hook_manager.execute(
|
||||
hook_type=HOOK_PRE_REASONING,
|
||||
agent=self,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Call parent (which may be ToolGuardMixin's _reasoning)
|
||||
return await super()._reasoning(**kwargs)
|
||||
|
||||
def reload_skills(self, active_skill_dirs: Optional[List[Path]] = None) -> None:
|
||||
"""Reload skills at runtime.
|
||||
|
||||
Rebuilds the toolkit with current skills from the active directory.
|
||||
|
||||
Args:
|
||||
active_skill_dirs: Optional list of specific skill directories to load
|
||||
"""
|
||||
logger.info("Reloading skills for agent: %s", self.agent_id)
|
||||
|
||||
# Create new toolkit
|
||||
new_toolkit = Toolkit(
|
||||
agent_skill_instruction=(
|
||||
"<system-info>You have access to specialized skills. "
|
||||
"Each skill lives in a directory and is described by SKILL.md. "
|
||||
"Follow the skill instructions when they are relevant to the current task."
|
||||
"</system-info>"
|
||||
),
|
||||
agent_skill_template="- {name} (dir: {dir}): {description}",
|
||||
)
|
||||
|
||||
# Register skills
|
||||
if active_skill_dirs is None:
|
||||
active_skills_dir = self._skills_manager.get_agent_active_root(
|
||||
self.config_name,
|
||||
self.agent_id,
|
||||
)
|
||||
if active_skills_dir.exists():
|
||||
active_skill_dirs = [
|
||||
d for d in active_skills_dir.iterdir()
|
||||
if d.is_dir() and (d / "SKILL.md").exists()
|
||||
]
|
||||
else:
|
||||
active_skill_dirs = []
|
||||
|
||||
for skill_dir in active_skill_dirs:
|
||||
if skill_dir.exists() and (skill_dir / "SKILL.md").exists():
|
||||
try:
|
||||
new_toolkit.register_agent_skill(str(skill_dir))
|
||||
logger.debug("Reloaded skill: %s", skill_dir.name)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to reload skill '%s': %s",
|
||||
skill_dir.name,
|
||||
e,
|
||||
)
|
||||
|
||||
# Replace toolkit
|
||||
self.toolkit = new_toolkit
|
||||
logger.info("Skills reloaded for agent: %s", self.agent_id)
|
||||
|
||||
def rebuild_sys_prompt(self) -> None:
|
||||
"""Rebuild and replace the system prompt at runtime.
|
||||
|
||||
Useful after updating AGENTS.md, SOUL.md, PROFILE.md, etc.
|
||||
to ensure the prompt reflects the latest configuration.
|
||||
|
||||
Updates both self._sys_prompt and the first system-role
|
||||
message stored in self.memory.content.
|
||||
"""
|
||||
logger.info("Rebuilding system prompt for agent: %s", self.agent_id)
|
||||
|
||||
# Reload agent config in case it changed
|
||||
self._agent_config = self._load_agent_config()
|
||||
|
||||
# Rebuild prompt
|
||||
self._sys_prompt = self._build_system_prompt()
|
||||
|
||||
# Update memory if system message exists
|
||||
if hasattr(self, "memory") and self.memory.content:
|
||||
for msg, _marks in self.memory.content:
|
||||
if getattr(msg, "role", None) == "system":
|
||||
msg.content = self._sys_prompt
|
||||
logger.debug("Updated system message in memory")
|
||||
break
|
||||
|
||||
logger.info("System prompt rebuilt for agent: %s", self.agent_id)
|
||||
|
||||
async def reply(
|
||||
self,
|
||||
msg: Msg | List[Msg] | None = None,
|
||||
structured_model: Optional[Type[Any]] = None,
|
||||
) -> Msg:
|
||||
"""Process a message and return a response.
|
||||
|
||||
Args:
|
||||
msg: Input message(s) from user
|
||||
structured_model: Optional pydantic model for structured output
|
||||
|
||||
Returns:
|
||||
Response message
|
||||
"""
|
||||
# Handle list of messages
|
||||
if isinstance(msg, list):
|
||||
# Process each message in sequence
|
||||
for m in msg[:-1]:
|
||||
await self.memory.add(m)
|
||||
msg = msg[-1] if msg else None
|
||||
|
||||
return await super().reply(msg=msg, structured_model=structured_model)
|
||||
|
||||
def get_agent_info(self) -> Dict[str, Any]:
|
||||
"""Get agent information.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent metadata
|
||||
"""
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"config_name": self.config_name,
|
||||
"workspace_dir": str(self.workspace_dir),
|
||||
"skills_count": len([
|
||||
s for s in self._skills_manager.list_active_skill_metadata(
|
||||
self.config_name,
|
||||
self.agent_id,
|
||||
)
|
||||
]),
|
||||
"registered_hooks": self._hook_manager.list_hooks(),
|
||||
}
|
||||
|
||||
|
||||
__all__ = ["EvoAgent"]
|
||||
432
backend/agents/base/hooks.py
Normal file
432
backend/agents/base/hooks.py
Normal file
@@ -0,0 +1,432 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Hook system for EvoAgent.
|
||||
|
||||
Provides pre_reasoning and post_acting hooks with built-in implementations:
|
||||
- BootstrapHook: First-time setup guidance
|
||||
- MemoryCompactionHook: Automatic memory compression
|
||||
|
||||
Based on CoPaw's hooks design.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentscope.agent import ReActAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hook types
|
||||
HookType = str
|
||||
HOOK_PRE_REASONING: HookType = "pre_reasoning"
|
||||
HOOK_POST_ACTING: HookType = "post_acting"
|
||||
|
||||
|
||||
class Hook(ABC):
|
||||
"""Abstract base class for agent hooks."""
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Execute the hook.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments to the method being hooked
|
||||
|
||||
Returns:
|
||||
Modified kwargs or None to use original
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HookManager:
|
||||
"""Manages agent hooks.
|
||||
|
||||
Provides registration and execution of hooks for different
|
||||
lifecycle events in the agent's operation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks: Dict[HookType, List[tuple[str, Hook]]] = {
|
||||
HOOK_PRE_REASONING: [],
|
||||
HOOK_POST_ACTING: [],
|
||||
}
|
||||
|
||||
def register(
|
||||
self,
|
||||
hook_type: HookType,
|
||||
hook_name: str,
|
||||
hook: Hook | Callable,
|
||||
) -> None:
|
||||
"""Register a hook.
|
||||
|
||||
Args:
|
||||
hook_type: Type of hook (pre_reasoning, post_acting)
|
||||
hook_name: Unique name for this hook
|
||||
hook: Hook instance or callable
|
||||
"""
|
||||
# Remove existing hook with same name
|
||||
self._hooks[hook_type] = [
|
||||
(name, h) for name, h in self._hooks[hook_type] if name != hook_name
|
||||
]
|
||||
self._hooks[hook_type].append((hook_name, hook))
|
||||
logger.debug("Registered hook '%s' for type '%s'", hook_name, hook_type)
|
||||
|
||||
def unregister(self, hook_type: HookType, hook_name: str) -> bool:
|
||||
"""Unregister a hook.
|
||||
|
||||
Args:
|
||||
hook_type: Type of hook
|
||||
hook_name: Name of the hook to remove
|
||||
|
||||
Returns:
|
||||
True if hook was found and removed
|
||||
"""
|
||||
original_len = len(self._hooks[hook_type])
|
||||
self._hooks[hook_type] = [
|
||||
(name, h) for name, h in self._hooks[hook_type] if name != hook_name
|
||||
]
|
||||
removed = len(self._hooks[hook_type]) < original_len
|
||||
if removed:
|
||||
logger.debug("Unregistered hook '%s' from type '%s'", hook_name, hook_type)
|
||||
return removed
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
hook_type: HookType,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute all hooks of a given type.
|
||||
|
||||
Args:
|
||||
hook_type: Type of hooks to execute
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments
|
||||
|
||||
Returns:
|
||||
Potentially modified kwargs
|
||||
"""
|
||||
for name, hook in self._hooks[hook_type]:
|
||||
try:
|
||||
result = await hook(agent, kwargs)
|
||||
if result is not None:
|
||||
kwargs = result
|
||||
except Exception as e:
|
||||
logger.error("Hook '%s' failed: %s", name, e, exc_info=True)
|
||||
|
||||
return kwargs
|
||||
|
||||
def list_hooks(self, hook_type: Optional[HookType] = None) -> List[str]:
|
||||
"""List registered hook names.
|
||||
|
||||
Args:
|
||||
hook_type: Optional type to filter by
|
||||
|
||||
Returns:
|
||||
List of hook names
|
||||
"""
|
||||
if hook_type:
|
||||
return [name for name, _ in self._hooks.get(hook_type, [])]
|
||||
|
||||
names = []
|
||||
for hooks in self._hooks.values():
|
||||
names.extend([name for name, _ in hooks])
|
||||
return names
|
||||
|
||||
|
||||
class BootstrapHook(Hook):
|
||||
"""Hook for bootstrap guidance on first user interaction.
|
||||
|
||||
This hook looks for a BOOTSTRAP.md file in the working directory
|
||||
and if found, prepends guidance to the first user message to help
|
||||
establish the agent's identity and user preferences.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
language: str = "zh",
|
||||
):
|
||||
"""Initialize bootstrap hook.
|
||||
|
||||
Args:
|
||||
workspace_dir: Working directory containing BOOTSTRAP.md
|
||||
language: Language code for bootstrap guidance (en/zh)
|
||||
"""
|
||||
self.workspace_dir = Path(workspace_dir)
|
||||
self.language = language
|
||||
self._completed_flag = self.workspace_dir / ".bootstrap_completed"
|
||||
|
||||
def _is_first_user_interaction(self, agent: "ReActAgent") -> bool:
|
||||
"""Check if this is the first user interaction.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
|
||||
Returns:
|
||||
True if first user interaction
|
||||
"""
|
||||
if not hasattr(agent, "memory") or not agent.memory.content:
|
||||
return True
|
||||
|
||||
# Count user messages (excluding system)
|
||||
user_count = sum(
|
||||
1 for msg, _ in agent.memory.content if msg.role == "user"
|
||||
)
|
||||
return user_count <= 1
|
||||
|
||||
def _build_bootstrap_guidance(self) -> str:
|
||||
"""Build bootstrap guidance message.
|
||||
|
||||
Returns:
|
||||
Formatted bootstrap guidance
|
||||
"""
|
||||
if self.language == "zh":
|
||||
return (
|
||||
"# 引导模式\n"
|
||||
"\n"
|
||||
"工作目录中存在 `BOOTSTRAP.md` — 首次设置。\n"
|
||||
"\n"
|
||||
"1. 阅读 BOOTSTRAP.md,友好地表示初次见面,"
|
||||
"引导用户完成设置。\n"
|
||||
"2. 按照 BOOTSTRAP.md 的指示,"
|
||||
"帮助用户定义你的身份和偏好。\n"
|
||||
"3. 按指南创建/更新必要文件"
|
||||
"(PROFILE.md、MEMORY.md 等)。\n"
|
||||
"4. 完成后删除 BOOTSTRAP.md。\n"
|
||||
"\n"
|
||||
"如果用户希望跳过,直接回答下面的问题即可。\n"
|
||||
"\n"
|
||||
"---\n"
|
||||
"\n"
|
||||
)
|
||||
|
||||
return (
|
||||
"# BOOTSTRAP MODE\n"
|
||||
"\n"
|
||||
"`BOOTSTRAP.md` exists — first-time setup.\n"
|
||||
"\n"
|
||||
"1. Read BOOTSTRAP.md, greet the user, "
|
||||
"and guide them through setup.\n"
|
||||
"2. Follow BOOTSTRAP.md instructions "
|
||||
"to define identity and preferences.\n"
|
||||
"3. Create/update files "
|
||||
"(PROFILE.md, MEMORY.md, etc.) as described.\n"
|
||||
"4. Delete BOOTSTRAP.md when done.\n"
|
||||
"\n"
|
||||
"If the user wants to skip, answer their "
|
||||
"question directly instead.\n"
|
||||
"\n"
|
||||
"---\n"
|
||||
"\n"
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Check and load BOOTSTRAP.md on first user interaction.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments to the _reasoning method
|
||||
|
||||
Returns:
|
||||
None (hook doesn't modify kwargs)
|
||||
"""
|
||||
try:
|
||||
bootstrap_path = self.workspace_dir / "BOOTSTRAP.md"
|
||||
|
||||
# Check if bootstrap has already been triggered
|
||||
if self._completed_flag.exists():
|
||||
return None
|
||||
|
||||
if not bootstrap_path.exists():
|
||||
return None
|
||||
|
||||
if not self._is_first_user_interaction(agent):
|
||||
return None
|
||||
|
||||
bootstrap_guidance = self._build_bootstrap_guidance()
|
||||
|
||||
logger.debug("Found BOOTSTRAP.md [%s], prepending guidance", self.language)
|
||||
|
||||
# Prepend to first user message in memory
|
||||
if hasattr(agent, "memory") and agent.memory.content:
|
||||
system_count = sum(
|
||||
1 for msg, _ in agent.memory.content if msg.role == "system"
|
||||
)
|
||||
for msg, _ in agent.memory.content[system_count:]:
|
||||
if msg.role == "user":
|
||||
# Prepend guidance to message content
|
||||
original_content = msg.content
|
||||
msg.content = bootstrap_guidance + original_content
|
||||
break
|
||||
|
||||
logger.debug("Bootstrap guidance prepended to first user message")
|
||||
|
||||
# Create completion flag to prevent repeated triggering
|
||||
self._completed_flag.touch()
|
||||
logger.debug("Created bootstrap completion flag")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to process bootstrap: %s", e, exc_info=True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MemoryCompactionHook(Hook):
|
||||
"""Hook for automatic memory compaction when context is full.
|
||||
|
||||
This hook monitors the token count of messages and triggers compaction
|
||||
when it exceeds the threshold. It preserves the system prompt and recent
|
||||
messages while summarizing older conversation history.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory_manager: Any,
|
||||
memory_compact_threshold: Optional[int] = None,
|
||||
memory_compact_reserve: Optional[int] = None,
|
||||
enable_tool_result_compact: bool = False,
|
||||
tool_result_compact_keep_n: int = 5,
|
||||
):
|
||||
"""Initialize memory compaction hook.
|
||||
|
||||
Args:
|
||||
memory_manager: Memory manager instance for compaction
|
||||
memory_compact_threshold: Token threshold for compaction
|
||||
memory_compact_reserve: Reserve tokens for recent messages
|
||||
enable_tool_result_compact: Enable tool result compaction
|
||||
tool_result_compact_keep_n: Number of tool results to keep
|
||||
"""
|
||||
self.memory_manager = memory_manager
|
||||
self.memory_compact_threshold = memory_compact_threshold
|
||||
self.memory_compact_reserve = memory_compact_reserve
|
||||
self.enable_tool_result_compact = enable_tool_result_compact
|
||||
self.tool_result_compact_keep_n = tool_result_compact_keep_n
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Pre-reasoning hook to check and compact memory if needed.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments to the _reasoning method
|
||||
|
||||
Returns:
|
||||
None (hook doesn't modify kwargs)
|
||||
"""
|
||||
try:
|
||||
if not hasattr(agent, "memory") or not self.memory_manager:
|
||||
return None
|
||||
|
||||
memory = agent.memory
|
||||
|
||||
# Get current token count estimate
|
||||
messages = await memory.get_memory()
|
||||
total_tokens = self._estimate_tokens(messages)
|
||||
|
||||
if self.memory_compact_threshold is None:
|
||||
return None
|
||||
|
||||
if total_tokens < self.memory_compact_threshold:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"Memory compaction triggered: %d tokens (threshold: %d)",
|
||||
total_tokens,
|
||||
self.memory_compact_threshold,
|
||||
)
|
||||
|
||||
# Compact memory
|
||||
await self._compact_memory(agent, messages)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to compact memory: %s", e, exc_info=True)
|
||||
|
||||
return None
|
||||
|
||||
def _estimate_tokens(self, messages: List[Any]) -> int:
|
||||
"""Estimate token count for messages.
|
||||
|
||||
Args:
|
||||
messages: List of messages
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
# Simple estimation: ~4 chars per token
|
||||
total_chars = sum(
|
||||
len(str(getattr(msg, "content", "")))
|
||||
for msg in messages
|
||||
)
|
||||
return total_chars // 4
|
||||
|
||||
async def _compact_memory(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
messages: List[Any],
|
||||
) -> None:
|
||||
"""Compact memory by summarizing older messages.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Current messages in memory
|
||||
"""
|
||||
if self.memory_compact_reserve is None:
|
||||
return
|
||||
|
||||
# Keep recent messages
|
||||
keep_count = min(
|
||||
len(messages) // 4,
|
||||
10, # Max 10 recent messages
|
||||
)
|
||||
keep_count = max(keep_count, 2) # At least 2
|
||||
|
||||
messages_to_compact = messages[:-keep_count] if keep_count < len(messages) else []
|
||||
|
||||
if not messages_to_compact:
|
||||
return
|
||||
|
||||
# Use memory manager to compact if available
|
||||
if hasattr(self.memory_manager, "compact_memory"):
|
||||
try:
|
||||
summary = await self.memory_manager.compact_memory(
|
||||
messages=messages_to_compact,
|
||||
)
|
||||
logger.info("Memory compacted: %d messages summarized", len(messages_to_compact))
|
||||
|
||||
# Mark messages as compressed if supported
|
||||
if hasattr(agent.memory, "update_messages_mark"):
|
||||
from agentscope.agent._react_agent import _MemoryMark
|
||||
await agent.memory.update_messages_mark(
|
||||
new_mark=_MemoryMark.COMPRESSED,
|
||||
msg_ids=[msg.id for msg in messages_to_compact],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Memory manager compaction failed: %s", e)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Hook",
|
||||
"HookManager",
|
||||
"HookType",
|
||||
"HOOK_PRE_REASONING",
|
||||
"HOOK_POST_ACTING",
|
||||
"BootstrapHook",
|
||||
"MemoryCompactionHook",
|
||||
]
|
||||
674
backend/agents/base/tool_guard.py
Normal file
674
backend/agents/base/tool_guard.py
Normal file
@@ -0,0 +1,674 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ToolGuardMixin - Security interception for dangerous tool calls.
|
||||
|
||||
Provides ``_acting`` and ``_reasoning`` overrides that intercept
|
||||
sensitive tool calls before execution, implementing the deny /
|
||||
guard / approve flow.
|
||||
|
||||
Based on CoPaw's tool_guard_mixin.py design.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set
|
||||
|
||||
from agentscope.message import Msg
|
||||
from backend.runtime.manager import get_global_runtime_manager
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SeverityLevel(str, Enum):
|
||||
"""Risk severity level."""
|
||||
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
|
||||
"""Approval lifecycle state."""
|
||||
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class ToolFindingRecord:
|
||||
"""Internal representation of a guard finding."""
|
||||
|
||||
def __init__(self, severity: SeverityLevel, message: str, field: Optional[str] = None) -> None:
|
||||
self.severity = severity
|
||||
self.message = message
|
||||
self.field = field
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"severity": self.severity.value,
|
||||
"message": self.message,
|
||||
"field": self.field,
|
||||
}
|
||||
|
||||
|
||||
class ApprovalRecord:
|
||||
"""Stores the state of an approval request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
approval_id: str,
|
||||
tool_name: str,
|
||||
tool_input: Dict[str, Any],
|
||||
agent_id: str,
|
||||
workspace_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
findings: Optional[List[ToolFindingRecord]] = None,
|
||||
) -> None:
|
||||
self.approval_id = approval_id
|
||||
self.tool_name = tool_name
|
||||
self.tool_input = tool_input
|
||||
self.agent_id = agent_id
|
||||
self.workspace_id = workspace_id
|
||||
self.session_id = session_id
|
||||
self.status = ApprovalStatus.PENDING
|
||||
self.findings = findings or []
|
||||
self.created_at = datetime.utcnow()
|
||||
self.resolved_at: Optional[datetime] = None
|
||||
self.resolved_by: Optional[str] = None
|
||||
self.metadata: Dict[str, Any] = {}
|
||||
self.pending_request: "ToolApprovalRequest" | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"approval_id": self.approval_id,
|
||||
"status": self.status.value,
|
||||
"tool_name": self.tool_name,
|
||||
"tool_input": self.tool_input,
|
||||
"agent_id": self.agent_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"session_id": self.session_id,
|
||||
"findings": [f.to_dict() for f in self.findings],
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"resolved_at": self.resolved_at.isoformat() if self.resolved_at else None,
|
||||
"resolved_by": self.resolved_by,
|
||||
}
|
||||
|
||||
|
||||
class ToolGuardStore:
|
||||
"""Simple in-memory approval store for development/testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._records: Dict[str, ApprovalRecord] = {}
|
||||
self._counter = 0
|
||||
|
||||
def next_id(self) -> str:
|
||||
self._counter += 1
|
||||
return f"approval_{self._counter:06d}"
|
||||
|
||||
def list(
|
||||
self,
|
||||
status: ApprovalStatus | None = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> Iterable[ApprovalRecord]:
|
||||
for record in self._records.values():
|
||||
if status and record.status != status:
|
||||
continue
|
||||
if workspace_id and record.workspace_id != workspace_id:
|
||||
continue
|
||||
if agent_id and record.agent_id != agent_id:
|
||||
continue
|
||||
yield record
|
||||
|
||||
def get(self, approval_id: str) -> Optional[ApprovalRecord]:
|
||||
return self._records.get(approval_id)
|
||||
|
||||
def create_pending(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: Dict[str, Any],
|
||||
agent_id: str,
|
||||
workspace_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
findings: Optional[List[ToolFindingRecord]] = None,
|
||||
) -> ApprovalRecord:
|
||||
record = ApprovalRecord(
|
||||
approval_id=self.next_id(),
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
agent_id=agent_id,
|
||||
workspace_id=workspace_id,
|
||||
session_id=session_id,
|
||||
findings=findings,
|
||||
)
|
||||
self._records[record.approval_id] = record
|
||||
return record
|
||||
|
||||
def set_status(
|
||||
self,
|
||||
approval_id: str,
|
||||
status: ApprovalStatus,
|
||||
resolved_by: Optional[str] = None,
|
||||
notify_request: bool = True,
|
||||
) -> ApprovalRecord:
|
||||
record = self._records[approval_id]
|
||||
if record.status == status:
|
||||
return record
|
||||
|
||||
record.status = status
|
||||
record.resolved_at = datetime.utcnow()
|
||||
record.resolved_by = resolved_by
|
||||
if notify_request and record.pending_request:
|
||||
if status == ApprovalStatus.APPROVED:
|
||||
record.pending_request.approve()
|
||||
elif status == ApprovalStatus.DENIED:
|
||||
record.pending_request.deny()
|
||||
return record
|
||||
|
||||
def cancel(self, approval_id: str) -> None:
|
||||
self._records.pop(approval_id, None)
|
||||
|
||||
|
||||
TOOL_GUARD_STORE = ToolGuardStore()
|
||||
|
||||
|
||||
def get_tool_guard_store() -> ToolGuardStore:
|
||||
return TOOL_GUARD_STORE
|
||||
|
||||
|
||||
# Default tools that require approval
|
||||
DEFAULT_GUARDED_TOOLS: Set[str] = {
|
||||
"execute_shell_command",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"place_order",
|
||||
"modify_position",
|
||||
"delete_file",
|
||||
}
|
||||
|
||||
# Default denied tools (cannot be approved)
|
||||
DEFAULT_DENIED_TOOLS: Set[str] = {
|
||||
"execute_shell_command", # Shell execution is dangerous
|
||||
}
|
||||
|
||||
# Mark for tool guard denied messages
|
||||
TOOL_GUARD_DENIED_MARK = "tool_guard_denied"
|
||||
|
||||
|
||||
def default_findings_for_tool(tool_name: str) -> List[ToolFindingRecord]:
|
||||
findings: List[ToolFindingRecord] = []
|
||||
if tool_name in {"execute_trade", "modify_portfolio"}:
|
||||
findings.append(
|
||||
ToolFindingRecord(
|
||||
severity=SeverityLevel.HIGH,
|
||||
message=f"Tool '{tool_name}' touches portfolio state",
|
||||
)
|
||||
)
|
||||
return findings
|
||||
|
||||
|
||||
class ToolApprovalRequest:
|
||||
"""Represents a pending tool approval request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
approval_id: str,
|
||||
tool_name: str,
|
||||
tool_input: Dict[str, Any],
|
||||
tool_call_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
self.approval_id = approval_id
|
||||
self.tool_name = tool_name
|
||||
self.tool_input = tool_input
|
||||
self.tool_call_id = tool_call_id
|
||||
self.session_id = session_id
|
||||
self.approved: Optional[bool] = None
|
||||
self._event = asyncio.Event()
|
||||
|
||||
async def wait_for_approval(self, timeout: Optional[float] = None) -> bool:
|
||||
"""Wait for approval decision.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds
|
||||
|
||||
Returns:
|
||||
True if approved, False otherwise
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._event.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
return self.approved is True
|
||||
|
||||
def approve(self) -> None:
|
||||
"""Approve this request."""
|
||||
self.approved = True
|
||||
self._event.set()
|
||||
|
||||
def deny(self) -> None:
|
||||
"""Deny this request."""
|
||||
self.approved = False
|
||||
self._event.set()
|
||||
|
||||
|
||||
class ToolGuardMixin:
|
||||
"""Mixin that adds tool-guard interception to a ReActAgent.
|
||||
|
||||
At runtime this class is combined with ReActAgent via MRO,
|
||||
so ``super()._acting`` and ``super()._reasoning`` resolve to
|
||||
the concrete agent methods.
|
||||
|
||||
Usage:
|
||||
class MyAgent(ToolGuardMixin, ReActAgent):
|
||||
def __init__(self, ...):
|
||||
super().__init__(...)
|
||||
self._init_tool_guard()
|
||||
"""
|
||||
|
||||
def _init_tool_guard(
|
||||
self,
|
||||
guarded_tools: Optional[Set[str]] = None,
|
||||
denied_tools: Optional[Set[str]] = None,
|
||||
approval_timeout: float = 300.0,
|
||||
) -> None:
|
||||
"""Initialize tool guard.
|
||||
|
||||
Args:
|
||||
guarded_tools: Set of tool names requiring approval
|
||||
denied_tools: Set of tool names that are always denied
|
||||
approval_timeout: Timeout for approval requests in seconds
|
||||
"""
|
||||
self._guarded_tools = guarded_tools or DEFAULT_GUARDED_TOOLS.copy()
|
||||
self._denied_tools = denied_tools or DEFAULT_DENIED_TOOLS.copy()
|
||||
self._approval_timeout = approval_timeout
|
||||
self._pending_approval: Optional[ToolApprovalRequest] = None
|
||||
self._approval_callback: Optional[Callable[[ToolApprovalRequest], None]] = None
|
||||
|
||||
def set_approval_callback(
|
||||
self,
|
||||
callback: Callable[[ToolApprovalRequest], None],
|
||||
) -> None:
|
||||
"""Set callback for approval requests.
|
||||
|
||||
Args:
|
||||
callback: Function called when approval is needed
|
||||
"""
|
||||
self._approval_callback = callback
|
||||
|
||||
def _is_tool_guarded(self, tool_name: str) -> bool:
|
||||
"""Check if a tool requires approval.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
True if tool requires approval
|
||||
"""
|
||||
return tool_name in self._guarded_tools
|
||||
|
||||
def _is_tool_denied(self, tool_name: str) -> bool:
|
||||
"""Check if a tool is always denied.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
True if tool is denied
|
||||
"""
|
||||
return tool_name in self._denied_tools
|
||||
|
||||
def _last_tool_response_is_denied(self) -> bool:
|
||||
"""Check if the last message is a guard-denied tool result."""
|
||||
if not hasattr(self, "memory") or not self.memory.content:
|
||||
return False
|
||||
|
||||
msg, marks = self.memory.content[-1]
|
||||
return TOOL_GUARD_DENIED_MARK in marks and msg.role == "system"
|
||||
|
||||
async def _cleanup_tool_guard_denied_messages(
|
||||
self,
|
||||
include_denial_response: bool = True,
|
||||
) -> None:
|
||||
"""Remove tool-guard denied messages from memory.
|
||||
|
||||
Args:
|
||||
include_denial_response: Also remove the assistant's denial explanation
|
||||
"""
|
||||
if not hasattr(self, "memory"):
|
||||
return
|
||||
|
||||
ids_to_delete: list[str] = []
|
||||
last_marked_idx = -1
|
||||
|
||||
for i, (msg, marks) in enumerate(self.memory.content):
|
||||
if TOOL_GUARD_DENIED_MARK in marks:
|
||||
ids_to_delete.append(msg.id)
|
||||
last_marked_idx = i
|
||||
|
||||
if (
|
||||
include_denial_response
|
||||
and last_marked_idx >= 0
|
||||
and last_marked_idx + 1 < len(self.memory.content)
|
||||
):
|
||||
next_msg, _ = self.memory.content[last_marked_idx + 1]
|
||||
if next_msg.role == "assistant":
|
||||
ids_to_delete.append(next_msg.id)
|
||||
|
||||
if ids_to_delete:
|
||||
removed = await self.memory.delete(ids_to_delete)
|
||||
logger.info("Tool guard: cleaned up %d denied message(s)", removed)
|
||||
|
||||
async def _request_guard_approval(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: Dict[str, Any],
|
||||
tool_call_id: str,
|
||||
) -> bool:
|
||||
"""Request approval for a guarded tool call.
|
||||
|
||||
This method creates a ToolApprovalRequest and waits for
|
||||
external approval via approve_guard_call() or deny_guard_call().
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
tool_input: Tool input parameters
|
||||
tool_call_id: ID of the tool call
|
||||
|
||||
Returns:
|
||||
True if approved, False otherwise
|
||||
"""
|
||||
record = TOOL_GUARD_STORE.create_pending(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
agent_id=getattr(self, "agent_id", "unknown"),
|
||||
workspace_id=getattr(self, "workspace_id", "default"),
|
||||
session_id=getattr(self, "session_id", None),
|
||||
findings=default_findings_for_tool(tool_name),
|
||||
)
|
||||
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.register_pending_approval(
|
||||
record.approval_id,
|
||||
{
|
||||
"tool_name": record.tool_name,
|
||||
"agent_id": record.agent_id,
|
||||
"workspace_id": record.workspace_id,
|
||||
"session_id": record.session_id,
|
||||
"tool_input": record.tool_input,
|
||||
},
|
||||
)
|
||||
|
||||
self._pending_approval = ToolApprovalRequest(
|
||||
approval_id=record.approval_id,
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
tool_call_id=tool_call_id,
|
||||
session_id=getattr(self, "session_id", None),
|
||||
)
|
||||
record.pending_request = self._pending_approval
|
||||
|
||||
# Notify via callback if set
|
||||
if self._approval_callback:
|
||||
self._approval_callback(self._pending_approval)
|
||||
|
||||
# Wait for approval
|
||||
approval_request = self._pending_approval
|
||||
approved = await approval_request.wait_for_approval(
|
||||
timeout=self._approval_timeout
|
||||
)
|
||||
|
||||
if approval_request:
|
||||
status = (
|
||||
ApprovalStatus.APPROVED
|
||||
if approval_request.approved is True
|
||||
else ApprovalStatus.DENIED
|
||||
if approval_request.approved is False
|
||||
else ApprovalStatus.EXPIRED
|
||||
)
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
approval_request.approval_id,
|
||||
status,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
approval_request.approval_id,
|
||||
resolved_by="agent",
|
||||
status=status.value,
|
||||
)
|
||||
|
||||
self._pending_approval = None
|
||||
return approved
|
||||
|
||||
def approve_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||
"""Approve a pending guard request.
|
||||
|
||||
This method is called externally to approve a tool call
|
||||
that is waiting for approval.
|
||||
|
||||
Args:
|
||||
request_id: Optional request ID to verify (not yet implemented)
|
||||
|
||||
Returns:
|
||||
True if a request was approved, False if no pending request
|
||||
"""
|
||||
if self._pending_approval is None:
|
||||
logger.warning("No pending approval request to approve")
|
||||
return False
|
||||
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
self._pending_approval.approval_id,
|
||||
ApprovalStatus.APPROVED,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
self._pending_approval.approval_id,
|
||||
resolved_by="agent",
|
||||
status=ApprovalStatus.APPROVED.value,
|
||||
)
|
||||
self._pending_approval.approve()
|
||||
logger.info("Approved tool call: %s", self._pending_approval.tool_name)
|
||||
return True
|
||||
|
||||
def deny_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||
"""Deny a pending guard request.
|
||||
|
||||
This method is called externally to deny a tool call
|
||||
that is waiting for approval.
|
||||
|
||||
Args:
|
||||
request_id: Optional request ID to verify (not yet implemented)
|
||||
|
||||
Returns:
|
||||
True if a request was denied, False if no pending request
|
||||
"""
|
||||
if self._pending_approval is None:
|
||||
logger.warning("No pending approval request to deny")
|
||||
return False
|
||||
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
self._pending_approval.approval_id,
|
||||
ApprovalStatus.DENIED,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
self._pending_approval.approval_id,
|
||||
resolved_by="agent",
|
||||
status=ApprovalStatus.DENIED.value,
|
||||
)
|
||||
self._pending_approval.deny()
|
||||
logger.info("Denied tool call: %s", self._pending_approval.tool_name)
|
||||
return True
|
||||
|
||||
async def _acting(self, tool_call) -> dict | None:
|
||||
"""Intercept sensitive tool calls before execution.
|
||||
|
||||
1. If tool is in denied_tools, auto-deny unconditionally.
|
||||
2. Check for a one-shot pre-approval.
|
||||
3. If tool is in the guarded scope, request approval.
|
||||
4. Otherwise, delegate to parent _acting.
|
||||
|
||||
Args:
|
||||
tool_call: Tool call from the model
|
||||
|
||||
Returns:
|
||||
Tool result dict or None
|
||||
"""
|
||||
tool_name: str = tool_call.get("name", "")
|
||||
tool_input: dict = tool_call.get("input", {})
|
||||
tool_call_id: str = tool_call.get("id", "")
|
||||
|
||||
# Check if tool is denied
|
||||
if tool_name and self._is_tool_denied(tool_name):
|
||||
logger.warning("Tool '%s' is in the denied set, auto-denying", tool_name)
|
||||
return await self._acting_auto_denied(tool_call, tool_name)
|
||||
|
||||
# Check if tool is guarded
|
||||
if tool_name and self._is_tool_guarded(tool_name):
|
||||
approved = await self._request_guard_approval(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
|
||||
if not approved:
|
||||
return await self._acting_with_denial(tool_call, tool_name)
|
||||
|
||||
# Call parent _acting
|
||||
return await super()._acting(tool_call) # type: ignore[misc]
|
||||
|
||||
async def _acting_auto_denied(
|
||||
self,
|
||||
tool_call: Dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> dict | None:
|
||||
"""Auto-deny a tool call without offering approval.
|
||||
|
||||
Args:
|
||||
tool_call: Tool call from the model
|
||||
tool_name: Name of the denied tool
|
||||
|
||||
Returns:
|
||||
Denial result
|
||||
"""
|
||||
from agentscope.message import ToolResultBlock
|
||||
|
||||
denied_text = (
|
||||
f"⛔ **Tool Blocked / 工具已拦截**\n\n"
|
||||
f"- Tool / 工具: `{tool_name}`\n"
|
||||
f"- Reason / 原因: This tool is blocked for security reasons\n\n"
|
||||
f"This tool is blocked and cannot be approved.\n"
|
||||
f"该工具已被禁止,无法批准执行。"
|
||||
)
|
||||
|
||||
tool_res_msg = Msg(
|
||||
"system",
|
||||
[
|
||||
ToolResultBlock(
|
||||
type="tool_result",
|
||||
id=tool_call.get("id", ""),
|
||||
name=tool_name,
|
||||
output=[{"type": "text", "text": denied_text}],
|
||||
),
|
||||
],
|
||||
"system",
|
||||
)
|
||||
|
||||
await self.print(tool_res_msg, True)
|
||||
await self.memory.add(tool_res_msg)
|
||||
return None
|
||||
|
||||
async def _acting_with_denial(
|
||||
self,
|
||||
tool_call: Dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> dict | None:
|
||||
"""Deny the tool call after approval was rejected.
|
||||
|
||||
Args:
|
||||
tool_call: Tool call from the model
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Denial result
|
||||
"""
|
||||
from agentscope.message import ToolResultBlock
|
||||
|
||||
params_text = json.dumps(
|
||||
tool_call.get("input", {}),
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
denied_text = (
|
||||
f"⚠️ **Tool Call Denied / 工具调用被拒绝**\n\n"
|
||||
f"- Tool / 工具: `{tool_name}`\n"
|
||||
f"- Parameters / 参数:\n"
|
||||
f"```json\n{params_text}\n```\n\n"
|
||||
f"The tool call was denied by the user or timed out.\n"
|
||||
f"工具调用被用户拒绝或已超时。"
|
||||
)
|
||||
|
||||
tool_res_msg = Msg(
|
||||
"system",
|
||||
[
|
||||
ToolResultBlock(
|
||||
type="tool_result",
|
||||
id=tool_call.get("id", ""),
|
||||
name=tool_name,
|
||||
output=[{"type": "text", "text": denied_text}],
|
||||
),
|
||||
],
|
||||
"system",
|
||||
)
|
||||
|
||||
await self.print(tool_res_msg, True)
|
||||
await self.memory.add(tool_res_msg, marks=TOOL_GUARD_DENIED_MARK)
|
||||
return None
|
||||
|
||||
async def _reasoning(self, **kwargs) -> Msg:
|
||||
"""Short-circuit reasoning when awaiting guard approval.
|
||||
|
||||
If the last message was a guard denial, return a waiting message
|
||||
instead of continuing reasoning.
|
||||
|
||||
Returns:
|
||||
Response message
|
||||
"""
|
||||
if self._last_tool_response_is_denied():
|
||||
msg = Msg(
|
||||
self.name,
|
||||
"⏳ Waiting for approval / 等待审批...\n\n"
|
||||
"Type `/approve` to approve, or send any message to deny.\n"
|
||||
"输入 `/approve` 批准执行,或发送任意消息拒绝。",
|
||||
"assistant",
|
||||
)
|
||||
await self.print(msg, True)
|
||||
await self.memory.add(msg)
|
||||
return msg
|
||||
|
||||
return await super()._reasoning(**kwargs) # type: ignore[misc]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolGuardMixin",
|
||||
"ToolApprovalRequest",
|
||||
"DEFAULT_GUARDED_TOOLS",
|
||||
"DEFAULT_DENIED_TOOLS",
|
||||
"TOOL_GUARD_DENIED_MARK",
|
||||
]
|
||||
Reference in New Issue
Block a user