# -*- coding: utf-8 -*- """ToolGuardMixin - Security interception for dangerous tool calls. Provides ``_acting`` and ``_reasoning`` overrides that intercept sensitive tool calls before execution, implementing the deny / guard / approve flow. Based on CoPaw's tool_guard_mixin.py design. """ from __future__ import annotations import asyncio import json import logging from datetime import UTC, datetime from enum import Enum from typing import Any, Callable, Dict, Iterable, List, Optional, Set from agentscope.message import Msg from backend.runtime.manager import get_global_runtime_manager logger = logging.getLogger(__name__) class SeverityLevel(str, Enum): """Risk severity level.""" LOW = "low" MEDIUM = "medium" HIGH = "high" CRITICAL = "critical" class ApprovalStatus(str, Enum): """Approval lifecycle state.""" PENDING = "pending" APPROVED = "approved" DENIED = "denied" EXPIRED = "expired" class ToolFindingRecord: """Internal representation of a guard finding.""" def __init__(self, severity: SeverityLevel, message: str, field: Optional[str] = None) -> None: self.severity = severity self.message = message self.field = field def to_dict(self) -> Dict[str, Any]: return { "severity": self.severity.value, "message": self.message, "field": self.field, } class ApprovalRecord: """Stores the state of an approval request.""" def __init__( self, approval_id: str, tool_name: str, tool_input: Dict[str, Any], agent_id: str, workspace_id: str, session_id: Optional[str] = None, findings: Optional[List[ToolFindingRecord]] = None, ) -> None: self.approval_id = approval_id self.tool_name = tool_name self.tool_input = tool_input self.agent_id = agent_id # run_id is the new preferred name; workspace_id is kept for backward compatibility self.run_id = workspace_id self.workspace_id = workspace_id self.session_id = session_id self.status = ApprovalStatus.PENDING self.findings = findings or [] self.created_at = datetime.now(UTC) self.resolved_at: Optional[datetime] = None self.resolved_by: Optional[str] = None self.metadata: Dict[str, Any] = {} self.pending_request: "ToolApprovalRequest" | None = None def to_dict(self) -> Dict[str, Any]: return { "approval_id": self.approval_id, "status": self.status.value, "tool_name": self.tool_name, "tool_input": self.tool_input, "agent_id": self.agent_id, "run_id": self.run_id, "workspace_id": self.workspace_id, "session_id": self.session_id, "findings": [f.to_dict() for f in self.findings], "created_at": self.created_at.isoformat(), "resolved_at": self.resolved_at.isoformat() if self.resolved_at else None, "resolved_by": self.resolved_by, } class ToolGuardStore: """Simple in-memory approval store for development/testing.""" def __init__(self) -> None: self._records: Dict[str, ApprovalRecord] = {} self._counter = 0 def next_id(self) -> str: self._counter += 1 return f"approval_{self._counter:06d}" def list( self, status: ApprovalStatus | None = None, workspace_id: Optional[str] = None, agent_id: Optional[str] = None, ) -> Iterable[ApprovalRecord]: for record in self._records.values(): if status and record.status != status: continue if workspace_id and record.workspace_id != workspace_id: continue if agent_id and record.agent_id != agent_id: continue yield record def get(self, approval_id: str) -> Optional[ApprovalRecord]: return self._records.get(approval_id) def create_pending( self, tool_name: str, tool_input: Dict[str, Any], agent_id: str, workspace_id: str, session_id: Optional[str] = None, findings: Optional[List[ToolFindingRecord]] = None, ) -> ApprovalRecord: record = ApprovalRecord( approval_id=self.next_id(), tool_name=tool_name, tool_input=tool_input, agent_id=agent_id, workspace_id=workspace_id, session_id=session_id, findings=findings, ) self._records[record.approval_id] = record return record def set_status( self, approval_id: str, status: ApprovalStatus, resolved_by: Optional[str] = None, notify_request: bool = True, ) -> ApprovalRecord: record = self._records[approval_id] if record.status == status: return record record.status = status record.resolved_at = datetime.now(UTC) record.resolved_by = resolved_by if notify_request and record.pending_request: if status == ApprovalStatus.APPROVED: record.pending_request.approve() elif status == ApprovalStatus.DENIED: record.pending_request.deny() return record def cancel(self, approval_id: str) -> None: self._records.pop(approval_id, None) TOOL_GUARD_STORE = ToolGuardStore() def get_tool_guard_store() -> ToolGuardStore: return TOOL_GUARD_STORE # Default tools that require approval DEFAULT_GUARDED_TOOLS: Set[str] = { "execute_shell_command", "write_file", "edit_file", "place_order", "modify_position", "delete_file", } # Default denied tools (cannot be approved) DEFAULT_DENIED_TOOLS: Set[str] = { "execute_shell_command", # Shell execution is dangerous } # Mark for tool guard denied messages TOOL_GUARD_DENIED_MARK = "tool_guard_denied" def default_findings_for_tool(tool_name: str) -> List[ToolFindingRecord]: findings: List[ToolFindingRecord] = [] if tool_name in {"execute_trade", "modify_portfolio"}: findings.append( ToolFindingRecord( severity=SeverityLevel.HIGH, message=f"Tool '{tool_name}' touches portfolio state", ) ) return findings class ToolApprovalRequest: """Represents a pending tool approval request.""" def __init__( self, approval_id: str, tool_name: str, tool_input: Dict[str, Any], tool_call_id: str, session_id: Optional[str] = None, ): self.approval_id = approval_id self.tool_name = tool_name self.tool_input = tool_input self.tool_call_id = tool_call_id self.session_id = session_id self.approved: Optional[bool] = None self._event = asyncio.Event() async def wait_for_approval(self, timeout: Optional[float] = None) -> bool: """Wait for approval decision. Args: timeout: Maximum time to wait in seconds Returns: True if approved, False otherwise """ try: await asyncio.wait_for(self._event.wait(), timeout=timeout) except asyncio.TimeoutError: return False return self.approved is True def approve(self) -> None: """Approve this request.""" self.approved = True self._event.set() def deny(self) -> None: """Deny this request.""" self.approved = False self._event.set() class ToolGuardMixin: """Mixin that adds tool-guard interception to a ReActAgent. At runtime this class is combined with ReActAgent via MRO, so ``super()._acting`` and ``super()._reasoning`` resolve to the concrete agent methods. Usage: class MyAgent(ToolGuardMixin, ReActAgent): def __init__(self, ...): super().__init__(...) self._init_tool_guard() """ def _init_tool_guard( self, guarded_tools: Optional[Set[str]] = None, denied_tools: Optional[Set[str]] = None, approval_timeout: float = 300.0, ) -> None: """Initialize tool guard. Args: guarded_tools: Set of tool names requiring approval denied_tools: Set of tool names that are always denied approval_timeout: Timeout for approval requests in seconds """ self._guarded_tools = guarded_tools or DEFAULT_GUARDED_TOOLS.copy() self._denied_tools = denied_tools or DEFAULT_DENIED_TOOLS.copy() self._approval_timeout = approval_timeout self._pending_approval: Optional[ToolApprovalRequest] = None self._approval_callback: Optional[Callable[[ToolApprovalRequest], None]] = None self._approval_lock = asyncio.Lock() def set_approval_callback( self, callback: Callable[[ToolApprovalRequest], None], ) -> None: """Set callback for approval requests. Args: callback: Function called when approval is needed """ self._approval_callback = callback def _is_tool_guarded(self, tool_name: str) -> bool: """Check if a tool requires approval. Args: tool_name: Name of the tool Returns: True if tool requires approval """ return tool_name in self._guarded_tools def _is_tool_denied(self, tool_name: str) -> bool: """Check if a tool is always denied. Args: tool_name: Name of the tool Returns: True if tool is denied """ return tool_name in self._denied_tools def _last_tool_response_is_denied(self) -> bool: """Check if the last message is a guard-denied tool result.""" if not hasattr(self, "memory") or not self.memory.content: return False msg, marks = self.memory.content[-1] return TOOL_GUARD_DENIED_MARK in marks and msg.role == "system" async def _cleanup_tool_guard_denied_messages( self, include_denial_response: bool = True, ) -> None: """Remove tool-guard denied messages from memory. Args: include_denial_response: Also remove the assistant's denial explanation """ if not hasattr(self, "memory"): return ids_to_delete: list[str] = [] last_marked_idx = -1 for i, (msg, marks) in enumerate(self.memory.content): if TOOL_GUARD_DENIED_MARK in marks: ids_to_delete.append(msg.id) last_marked_idx = i if ( include_denial_response and last_marked_idx >= 0 and last_marked_idx + 1 < len(self.memory.content) ): next_msg, _ = self.memory.content[last_marked_idx + 1] if next_msg.role == "assistant": ids_to_delete.append(next_msg.id) if ids_to_delete: removed = await self.memory.delete(ids_to_delete) logger.info("Tool guard: cleaned up %d denied message(s)", removed) async def _request_guard_approval( self, tool_name: str, tool_input: Dict[str, Any], tool_call_id: str, ) -> bool: """Request approval for a guarded tool call. This method creates a ToolApprovalRequest and waits for external approval via approve_guard_call() or deny_guard_call(). Args: tool_name: Name of the tool tool_input: Tool input parameters tool_call_id: ID of the tool call Returns: True if approved, False otherwise """ async with self._approval_lock: record = TOOL_GUARD_STORE.create_pending( tool_name=tool_name, tool_input=tool_input, agent_id=getattr(self, "agent_id", "unknown"), workspace_id=getattr(self, "workspace_id", "default"), session_id=getattr(self, "session_id", None), findings=default_findings_for_tool(tool_name), ) manager = get_global_runtime_manager() approval_data = { "tool_name": record.tool_name, "agent_id": record.agent_id, "workspace_id": record.workspace_id, "session_id": record.session_id, "tool_input": record.tool_input, } if manager: manager.register_pending_approval( record.approval_id, approval_data, ) # Broadcast WebSocket event for real-time UI updates try: if hasattr(manager, 'broadcast_event'): await manager.broadcast_event({ "type": "approval_requested", "approval_id": record.approval_id, "agent_id": record.agent_id, "tool_name": record.tool_name, "timestamp": record.created_at.isoformat(), "data": approval_data, }) except Exception as e: logger.warning(f"Failed to broadcast approval event: {e}") self._pending_approval = ToolApprovalRequest( approval_id=record.approval_id, tool_name=tool_name, tool_input=tool_input, tool_call_id=tool_call_id, session_id=getattr(self, "session_id", None), ) record.pending_request = self._pending_approval # Notify via callback if set if self._approval_callback: self._approval_callback(self._pending_approval) # Wait for approval (lock is released during wait, re-acquired after) approval_request = self._pending_approval # Wait for approval outside the lock to allow concurrent approval approved = await approval_request.wait_for_approval( timeout=self._approval_timeout ) async with self._approval_lock: if approval_request: status = ( ApprovalStatus.APPROVED if approval_request.approved is True else ApprovalStatus.DENIED if approval_request.approved is False else ApprovalStatus.EXPIRED ) TOOL_GUARD_STORE.set_status( approval_request.approval_id, status, resolved_by="agent", notify_request=False, ) manager = get_global_runtime_manager() if manager: manager.resolve_pending_approval( approval_request.approval_id, resolved_by="agent", status=status.value, ) # Only clear if this is still the same request if self._pending_approval is approval_request: self._pending_approval = None return approved async def approve_guard_call(self, request_id: Optional[str] = None) -> bool: """Approve a pending guard request. This method is called externally to approve a tool call that is waiting for approval. Args: request_id: Optional request ID to verify (not yet implemented) Returns: True if a request was approved, False if no pending request """ async with self._approval_lock: if self._pending_approval is None: logger.warning("No pending approval request to approve") return False TOOL_GUARD_STORE.set_status( self._pending_approval.approval_id, ApprovalStatus.APPROVED, resolved_by="agent", notify_request=False, ) manager = get_global_runtime_manager() if manager: manager.resolve_pending_approval( self._pending_approval.approval_id, resolved_by="agent", status=ApprovalStatus.APPROVED.value, ) self._pending_approval.approve() logger.info("Approved tool call: %s", self._pending_approval.tool_name) return True async def deny_guard_call(self, request_id: Optional[str] = None) -> bool: """Deny a pending guard request. This method is called externally to deny a tool call that is waiting for approval. Args: request_id: Optional request ID to verify (not yet implemented) Returns: True if a request was denied, False if no pending request """ async with self._approval_lock: if self._pending_approval is None: logger.warning("No pending approval request to deny") return False TOOL_GUARD_STORE.set_status( self._pending_approval.approval_id, ApprovalStatus.DENIED, resolved_by="agent", notify_request=False, ) manager = get_global_runtime_manager() if manager: manager.resolve_pending_approval( self._pending_approval.approval_id, resolved_by="agent", status=ApprovalStatus.DENIED.value, ) self._pending_approval.deny() logger.info("Denied tool call: %s", self._pending_approval.tool_name) return True async def _acting(self, tool_call) -> dict | None: """Intercept sensitive tool calls before execution. 1. If tool is in denied_tools, auto-deny unconditionally. 2. Check for a one-shot pre-approval. 3. If tool is in the guarded scope, request approval. 4. Otherwise, delegate to parent _acting. Args: tool_call: Tool call from the model Returns: Tool result dict or None """ tool_name: str = tool_call.get("name", "") tool_input: dict = tool_call.get("input", {}) tool_call_id: str = tool_call.get("id", "") # Check if tool is denied if tool_name and self._is_tool_denied(tool_name): logger.warning("Tool '%s' is in the denied set, auto-denying", tool_name) return await self._acting_auto_denied(tool_call, tool_name) # Check if tool is guarded if tool_name and self._is_tool_guarded(tool_name): approved = await self._request_guard_approval( tool_name=tool_name, tool_input=tool_input, tool_call_id=tool_call_id, ) if not approved: return await self._acting_with_denial(tool_call, tool_name) # Call parent _acting return await super()._acting(tool_call) # type: ignore[misc] async def _acting_auto_denied( self, tool_call: Dict[str, Any], tool_name: str, ) -> dict | None: """Auto-deny a tool call without offering approval. Args: tool_call: Tool call from the model tool_name: Name of the denied tool Returns: Denial result """ from agentscope.message import ToolResultBlock denied_text = ( f"⛔ **Tool Blocked / 工具已拦截**\n\n" f"- Tool / 工具: `{tool_name}`\n" f"- Reason / 原因: This tool is blocked for security reasons\n\n" f"This tool is blocked and cannot be approved.\n" f"该工具已被禁止,无法批准执行。" ) tool_res_msg = Msg( "system", [ ToolResultBlock( type="tool_result", id=tool_call.get("id", ""), name=tool_name, output=[{"type": "text", "text": denied_text}], ), ], "system", ) await self.print(tool_res_msg, True) await self.memory.add(tool_res_msg) return None async def _acting_with_denial( self, tool_call: Dict[str, Any], tool_name: str, ) -> dict | None: """Deny the tool call after approval was rejected. Args: tool_call: Tool call from the model tool_name: Name of the tool Returns: Denial result """ from agentscope.message import ToolResultBlock params_text = json.dumps( tool_call.get("input", {}), ensure_ascii=False, indent=2, ) denied_text = ( f"⚠️ **Tool Call Denied / 工具调用被拒绝**\n\n" f"- Tool / 工具: `{tool_name}`\n" f"- Parameters / 参数:\n" f"```json\n{params_text}\n```\n\n" f"The tool call was denied by the user or timed out.\n" f"工具调用被用户拒绝或已超时。" ) tool_res_msg = Msg( "system", [ ToolResultBlock( type="tool_result", id=tool_call.get("id", ""), name=tool_name, output=[{"type": "text", "text": denied_text}], ), ], "system", ) await self.print(tool_res_msg, True) await self.memory.add(tool_res_msg, marks=TOOL_GUARD_DENIED_MARK) return None async def _reasoning(self, **kwargs) -> Msg: """Short-circuit reasoning when awaiting guard approval. If the last message was a guard denial, return a waiting message instead of continuing reasoning. Returns: Response message """ if self._last_tool_response_is_denied(): msg = Msg( self.name, "⏳ Waiting for approval / 等待审批...\n\n" "Type `/approve` to approve, or send any message to deny.\n" "输入 `/approve` 批准执行,或发送任意消息拒绝。", "assistant", ) await self.print(msg, True) await self.memory.add(msg) return msg return await super()._reasoning(**kwargs) # type: ignore[misc] __all__ = [ "ToolGuardMixin", "ToolApprovalRequest", "DEFAULT_GUARDED_TOOLS", "DEFAULT_DENIED_TOOLS", "TOOL_GUARD_DENIED_MARK", ]