# -*- coding: utf-8 -*-
import os
import os.path as osp
import json
import asyncio
from agentscope.message import TextBlock
from agentscope.tool import ToolResponse
from .op_manager.op_retrieval import retrieve_ops
# Path of the JSON file holding the operator metadata used for formatting;
# it sits in the "op_manager" directory next to this module.
TOOLS_INFO_PATH = osp.join(
    osp.dirname(__file__), "op_manager", "dj_funcs_all.json"
)
def _load_tools_info():
    """Return the operator metadata, using ``TOOLS_INFO_PATH`` as a cache.

    The JSON file at ``TOOLS_INFO_PATH`` is read when it exists and parses
    cleanly. Otherwise ``dj_func_info`` is imported from
    ``.op_manager.create_dj_func_info``, written to ``TOOLS_INFO_PATH`` on a
    best-effort basis, and returned.

    Returns:
        The parsed JSON content, or ``dj_func_info`` when the cache file is
        missing or unreadable as JSON.
    """
    try:
        with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # No cache yet: fall through and build it.
        pass
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError: a truncated or
        # corrupt cache (e.g. from an interrupted write) must not break every
        # later call, so it is rebuilt below instead.
        pass

    # Imported lazily so the module is only loaded when the cache is unusable.
    from .op_manager.create_dj_func_info import dj_func_info

    try:
        with open(TOOLS_INFO_PATH, "w", encoding="utf-8") as f:
            json.dump(dj_func_info, f)
    except OSError:
        # Caching is an optimisation only (the package directory may be
        # read-only); the in-memory metadata is still valid to return.
        pass
    return dj_func_info
def _format_tool_names_to_class_entries(tool_names):
    """Render the known operators among ``tool_names`` as a numbered listing.

    Each name found in the metadata returned by ``_load_tools_info`` becomes
    an entry of the form ``"<n>. <class_name>: <class_desc>"`` followed, on
    the next line, by the operator's ``arguments`` text. ``<n>`` is the
    1-based position of the name in ``tool_names``; names missing from the
    metadata are skipped but still consume their position number.

    Args:
        tool_names: Operator class names to format.

    Returns:
        The entries joined by newlines, or ``""`` when ``tool_names`` is
        empty.
    """
    if not tool_names:
        return ""

    # Index the metadata by class name so each lookup is a dict access.
    info_by_class = {}
    for record in _load_tools_info():
        info_by_class[record["class_name"]] = record

    def render(position, record):
        heading = f"{position}. {record['class_name']}: {record['class_desc']}"
        return heading + "\n" + record["arguments"]

    return "\n".join(
        render(position, info_by_class[name])
        for position, name in enumerate(tool_names, start=1)
        if name in info_by_class
    )
async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
    """Look up DataJuicer operators matching a natural-language query.

    The query is passed to ``retrieve_ops`` together with ``limit`` and the
    retrieval mode taken from the ``RETRIEVAL_MODE`` environment variable
    (``"auto"`` when unset). The returned operator names are then formatted
    into a textual report.

    Args:
        query (str): Natural language operator query
        limit (int): Maximum number of operators to return (default: 20)

    Returns:
        ToolResponse: A single text block containing the formatted
            operators, a hint message when nothing matched, or an error
            message when any exception was raised along the way.
    """

    def as_response(message):
        # Every outcome is reported as one text block.
        return ToolResponse(content=[TextBlock(type="text", text=message)])

    try:
        mode = os.environ.get("RETRIEVAL_MODE", "auto")
        matched_names = await retrieve_ops(query, limit=limit, mode=mode)

        if matched_names:
            entries = _format_tool_names_to_class_entries(matched_names)
            report = "\n".join(
                [
                    "🔍 DataJuicer Operator Query Results",
                    f"Query: {query}",
                    f"Limit: {limit} operators",
                    "=" * 50,
                    "",
                    entries,
                ],
            )
            return as_response(report)

        # Nothing retrieved: explain how to rephrase the query.
        hints = "\n".join(
            [
                f"No matching DataJuicer operators found for query: {query}",
                "Suggestions:",
                "1. Use more specific keywords like 'text filter', "
                "'image processing'",
                "2. Check spelling and try alternative terms",
                "3. Try English keywords for better matching",
            ],
        )
        return as_response(hints)
    except Exception as exc:
        # Tool boundary: report the failure as text instead of raising.
        return as_response(
            f"Error querying DataJuicer operators: {exc!s}\n"
            "Please verify query parameters and retry.",
        )
# Programs that execute_safe_command may launch. The order is kept because
# the list is echoed verbatim in the rejection message.
_ALLOWED_COMMANDS = (
    # DataJuicer commands
    "dj-process",
    "dj-analyze",
    # File system operations
    "mkdir",
    "ls",
    "pwd",
    "cat",
    "echo",
    "cp",
    "mv",
    "rm",
    # Text processing
    "grep",
    "head",
    "tail",
    "wc",
    "sort",
    "uniq",
    # Archive operations
    "tar",
    "zip",
    "unzip",
    # Information commands
    "which",
    "whoami",
    "date",
    "find",
    # Python commands
    "python",
    "python3",
    "pip",
    "uv",
)

# Seconds to wait for the remaining output after each stop signal once the
# caller's timeout has expired.
_STOP_GRACE_SECONDS = 5


def _leading_command_name(command: str) -> str:
    """Return the program name that a shell command line starts with."""
    words = command.split(maxsplit=1)
    if not words:
        return ""
    name = words[0]
    # Cut at shell operators so "ls|grep x" still yields "ls".
    for operator in ";|&<>()":
        name = name.split(operator, 1)[0]
    return name


def _is_command_allowed(command: str) -> bool:
    """Tell whether the command's leading program is on the allow-list."""
    name = _leading_command_name(command)
    # Exact match: a prefix test would also accept "lsblk", "rmdir", "mvn"...
    if name not in _ALLOWED_COMMANDS:
        return False
    # rm/mv are only accepted for plain names in the working directory.
    if name in ("rm", "mv") and ("/" in command or ".." in command):
        return False
    return True


def _shell_result_text(returncode, stdout: str, stderr: str) -> str:
    """Wrap return code, stdout and stderr in their delimiting tags."""
    return (
        f"<returncode>{returncode}</returncode>"
        f"<stdout>{stdout}</stdout>"
        f"<stderr>{stderr}</stderr>"
    )


async def _stop_and_collect(proc, reader):
    """Stop a timed-out process and return whatever output can be collected.

    Sends terminate, then kill, waiting at most ``_STOP_GRACE_SECONDS`` after
    each for ``reader`` (the pending ``proc.communicate()`` task) to finish.
    Returns ``(b"", b"")`` if the pipes never close, e.g. because a
    descendant process keeps them open.
    """
    for stop in (proc.terminate, proc.kill):
        try:
            stop()
        except ProcessLookupError:
            # Already exited; its output may still be pending in the pipes.
            pass
        finished, _ = await asyncio.wait(
            {reader},
            timeout=_STOP_GRACE_SECONDS,
        )
        if finished:
            return reader.result()
    reader.cancel()
    return b"", b""


async def execute_safe_command(
    command: str,
    timeout: int = 300,
) -> ToolResponse:
    """Execute safe commands including DataJuicer commands and other safe
    system commands.

    Returns the return code, standard output and error within
    <returncode></returncode>, <stdout></stdout> and <stderr></stderr> tags.

    NOTE(review): only the leading program of ``command`` is validated.
    The command is run through a shell, so operators such as ``;``, ``&&``,
    ``|`` or ``$(...)`` can still start other programs. Restricting them
    would break pipelines and redirections callers may rely on, so this is
    flagged here rather than changed.

    Args:
        command (`str`):
            The command to execute. Allowed commands include:
            - DataJuicer commands: dj-process, dj-analyze
            - File system commands: mkdir, ls, pwd, cat, echo, cp, mv, rm
              (rm and mv are rejected when the command contains "/" or "..")
            - Text processing: grep, head, tail, wc, sort, uniq
            - Archive commands: tar, zip, unzip
            - Other safe commands: which, whoami, date, find
            - Python commands: python, python3, pip, uv
        timeout (`int`, defaults to `300`):
            The maximum time (in seconds) allowed for the command to run.

    Returns:
        `ToolResponse`:
            The tool response containing the return code, standard output, and
            standard error of the executed command. A rejected or timed-out
            command is reported with return code -1 and an explanation in
            the standard error part.
    """
    command_stripped = command.strip()

    if not _is_command_allowed(command_stripped):
        error_msg = (
            "Error: Command not allowed for security reasons. "
            "Allowed commands: "
            f"{', '.join(_ALLOWED_COMMANDS)}. "
            f"Received command: {command}"
        )
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=_shell_result_text(-1, "", error_msg),
                ),
            ],
        )

    proc = await asyncio.create_subprocess_shell(
        command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        bufsize=0,
    )

    # communicate() keeps draining both pipes while waiting. Waiting on
    # proc.wait() instead stalls as soon as the child fills a pipe buffer.
    # asyncio.wait() (unlike wait_for) does not cancel the task on timeout,
    # so output read so far is not lost.
    reader = asyncio.create_task(proc.communicate())
    try:
        finished, _ = await asyncio.wait({reader}, timeout=timeout)
        if finished:
            stdout, stderr = reader.result()
            returncode = proc.returncode
            timeout_note = ""
        else:
            stdout, stderr = await _stop_and_collect(proc, reader)
            returncode = -1
            timeout_note = (
                "TimeoutError: The command execution exceeded "
                f"the timeout of {timeout} seconds."
            )
    finally:
        # Do not leave the reader task behind if this call is cancelled.
        if not reader.done():
            reader.cancel()

    # Commands such as `cat` may emit arbitrary bytes; never fail on decoding.
    stdout_str = stdout.decode("utf-8", errors="replace")
    stderr_str = stderr.decode("utf-8", errors="replace")
    if timeout_note:
        stderr_str = (
            f"{stderr_str}\n{timeout_note}" if stderr_str else timeout_note
        )

    return ToolResponse(
        content=[
            TextBlock(
                type="text",
                text=_shell_result_text(returncode, stdout_str, stderr_str),
            ),
        ],
    )