release datajuicer agent
This commit is contained in:
89
data_juicer_agent/tools/__init__.py
Normal file
89
data_juicer_agent/tools/__init__.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Tools package for data-agent.
|
||||
|
||||
This module provides a unified entry point for all agent tools,
|
||||
organized by agent type for easy access and management.
|
||||
"""
|
||||
import asyncio
|
||||
from typing import List
|
||||
from agentscope.agent import AgentBase
|
||||
from agentscope.tool import (
|
||||
view_text_file,
|
||||
write_text_file,
|
||||
)
|
||||
from agentscope.tool import Toolkit
|
||||
|
||||
from .dj_tools import execute_safe_command
|
||||
from .router_tools import agent_to_tool
|
||||
from .dj_tools import query_dj_operators
|
||||
from .dj_dev_tools import get_basic_files, get_operator_example, configure_data_juicer_path
|
||||
from .mcp_tools import get_mcp_toolkit
|
||||
|
||||
def create_toolkit(tools: list) -> Toolkit:
    """Create a Toolkit with the given tool functions registered.

    Args:
        tools (list): Tool callables to register.  (Fixed: the previous
            annotation said ``List[str]``, but each item is passed
            directly to ``Toolkit.register_tool_function`` and must be a
            callable, not a name string.)

    Returns:
        Toolkit: A toolkit with every given tool registered.
    """
    # Create toolkit and register tools
    toolkit = Toolkit()
    for tool in tools:
        toolkit.register_tool_function(tool)

    return toolkit
|
||||
# DJ Agent tools - registered on the main DataJuicer processing agent
dj_tools = [
    execute_safe_command,
    view_text_file,
    write_text_file,
    query_dj_operators,
]

# DJ Development Agent tools - for developing DataJuicer operators
dj_dev_tools = [
    view_text_file,
    write_text_file,
    get_basic_files,
    get_operator_example,
    configure_data_juicer_path,
]

# MCP Agent tools - for advanced data processing with Recipe Flow MCP
# (the MCP server contributes its own tools at runtime; only the local
# file helpers are listed here)
mcp_tools = [
    view_text_file,
    write_text_file,
]
||||
|
||||
def agents2toolkit(agents: List[AgentBase]):
    """Wrap each agent as a callable tool and bundle them into one toolkit."""
    wrapped = []
    for agent in agents:
        wrapped.append(agent_to_tool(agent))
    return create_toolkit(wrapped)
||||
|
||||
# Pre-built toolkits for the static agent types.
dj_toolkit = create_toolkit(dj_tools)
dj_dev_toolkit = create_toolkit(dj_dev_tools)


# All available toolkits, keyed by agent type.  NOTE: "dj_mcp" and
# "router" map to factory callables (async / agent-dependent), not to
# ready-made Toolkit instances.
all_toolkit = {
    "dj": dj_toolkit,
    "dj_dev": dj_dev_toolkit,
    "dj_mcp": get_mcp_toolkit,
    "router": agents2toolkit,
}

# Public API
__all__ = [
    "dj_tools",
    "dj_dev_tools",
    "mcp_tools",
    # Fixed: this entry was "all_tools", a name that does not exist in the
    # module, which would make `from ... import *` raise AttributeError.
    "all_toolkit",
    "agents2toolkit",
    "create_toolkit",
    "dj_toolkit",
    "dj_dev_toolkit",
    "get_mcp_toolkit",
    # Individual tools for direct import
    "execute_safe_command",
    "view_text_file",
    "write_text_file",
    "agent_to_tool",
    "query_dj_operators",
    "get_basic_files",
    "get_operator_example",
    "configure_data_juicer_path",
]
|
||||
235
data_juicer_agent/tools/dj_dev_tools.py
Normal file
235
data_juicer_agent/tools/dj_dev_tools.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
DataJuicer Development Tools
|
||||
|
||||
Tools for developing DataJuicer operators, including access to basic documentation
|
||||
and example code for different operator types.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from agentscope.message import TextBlock
|
||||
from agentscope.tool import ToolResponse
|
||||
|
||||
# DataJuicer home path - should be configured based on your environment
|
||||
DATA_JUICER_PATH = os.getenv("DATA_JUICER_PATH", None)
|
||||
|
||||
BASIC_LIST_RELATIVE = [
|
||||
"data_juicer/ops/base_op.py",
|
||||
"docs/DeveloperGuide.md",
|
||||
"docs/DeveloperGuide_ZH.md",
|
||||
]
|
||||
|
||||
|
||||
def get_basic_files() -> ToolResponse:
    """Get basic DataJuicer development files content.

    Returns the content of essential files needed for DJ operator development:
    - base_op.py: Base operator class
    - DeveloperGuide.md: English developer guide
    - DeveloperGuide_ZH.md: Chinese developer guide

    Returns:
        ToolResponse: Combined content of all basic development files
    """
    # Read-only access to the module globals; no `global` statement needed.
    if DATA_JUICER_PATH is None:
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text="DATA_JUICER_PATH is not configured. Please ask the user to provide the DATA_JUICER_PATH",
                )
            ]
        )

    try:
        combined_content = "# DataJuicer Operator Development Basic Files\n\n"

        for relative_path in BASIC_LIST_RELATIVE:
            file_path = os.path.join(DATA_JUICER_PATH, relative_path)
            if not os.path.exists(file_path):
                # Missing files are silently skipped (best-effort bundle).
                continue
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    content = f.read()

                filename = os.path.basename(file_path)
                # Fixed: the section heading previously rendered a literal
                # placeholder instead of the file name (`filename` was
                # computed but never used).
                combined_content += f"## {filename}\n\n"
                combined_content += (
                    f"```{'python' if filename.endswith('.py') else 'markdown'}\n"
                )
                combined_content += content
                combined_content += "\n```\n\n"
            except Exception as e:
                # Keep going if one file is unreadable; report it inline.
                combined_content += (
                    f"## {os.path.basename(file_path)} (Read Failed)\n"
                )
                combined_content += f"Error: {str(e)}\n\n"

        return ToolResponse(content=[TextBlock(type="text", text=combined_content)])

    except Exception as e:
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=f"Error occurred while getting basic files: {str(e)}",
                )
            ]
        )
|
||||
|
||||
async def get_operator_example(
    requirement_description: str, limit: int = 2
) -> ToolResponse:
    """Get example operators based on requirement description using dynamic search.

    Args:
        requirement_description (str): Natural language description of the operator requirement
        limit (int): Maximum number of example operators to return (default: 2)

    Returns:
        ToolResponse: Example operator code and test files based on the requirement
    """
    # Read-only access; no `global` statement needed.
    if DATA_JUICER_PATH is None:
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text="DATA_JUICER_PATH is not configured. Please ask the user to provide the DATA_JUICER_PATH",
                )
            ]
        )

    try:
        # Imported lazily so this module loads even when the optional
        # retrieval dependencies are missing.
        from .op_manager.op_retrieval import retrieve_ops

        # Query relevant operators using the requirement description.
        # Use retrieval mode from environment variable if set.
        retrieval_mode = os.environ.get("RETRIEVAL_MODE", "auto")
        tool_names = await retrieve_ops(requirement_description, limit=limit, mode=retrieval_mode)

        if not tool_names:
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=f"No relevant operators found for requirement: {requirement_description}\n"
                        f"Please try with more specific keywords or check if DATA_JUICER_PATH is properly configured.",
                    )
                ]
            )

        combined_content = (
            f"# Dynamic Operator Examples for: {requirement_description}\n\n"
        )
        combined_content += (
            f"Found {len(tool_names)} relevant operators (limit: {limit})\n\n"
        )

        # Process each found operator (defensive re-slice in case the
        # retriever returned more than `limit` entries).
        for i, tool_name in enumerate(tool_names[:limit]):
            combined_content += f"## {i+1}. {tool_name}\n\n"

            # Operators live under data_juicer/ops/<type>/, where <type> is
            # the trailing word of the operator name (filter/mapper/...).
            op_type = tool_name.split("_")[-1]

            operator_path = f"data_juicer/ops/{op_type}/{tool_name}.py"

            # Try to find operator source file
            full_path = os.path.join(DATA_JUICER_PATH, operator_path)
            if os.path.exists(full_path):
                with open(full_path, "r", encoding="utf-8") as f:
                    operator_code = f.read()

                # Fixed: dropped extraneous f-prefixes on literals with no
                # placeholders (ruff F541).
                combined_content += "### Source Code\n"
                combined_content += "```python\n"
                combined_content += operator_code
                combined_content += "\n```\n\n"
            else:
                combined_content += (
                    f"**Note:** Source code file not found for `{tool_name}`.\n\n"
                )

            test_path = f"tests/ops/{op_type}/test_{tool_name}.py"

            full_test_path = os.path.join(DATA_JUICER_PATH, test_path)
            if os.path.exists(full_test_path):
                with open(full_test_path, "r", encoding="utf-8") as f:
                    test_code = f.read()

                combined_content += "### Test Code\n"
                combined_content += f"**File Path:** `{test_path}`\n\n"
                combined_content += "```python\n"
                combined_content += test_code
                combined_content += "\n```\n\n"
            else:
                combined_content += (
                    f"**Note:** Test file not found for `{tool_name}`.\n\n"
                )

            combined_content += "---\n\n"

        return ToolResponse(content=[TextBlock(type="text", text=combined_content)])

    except Exception as e:
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=f"Error occurred while getting operator examples: {str(e)}\n"
                    f"Please check the requirement description and try again.",
                )
            ]
        )
|
||||
|
||||
def configure_data_juicer_path(data_juicer_path: str) -> ToolResponse:
    """Configure DataJuicer path.
    If the user provides the data_juicer_path, please use this method to configure it.

    Args:
        data_juicer_path (str): Path to DataJuicer installation

    Returns:
        ToolResponse: Configuration result
    """
    global DATA_JUICER_PATH

    # Expand "~" so callers may pass home-relative paths.
    data_juicer_path = os.path.expanduser(data_juicer_path)

    try:
        if not os.path.exists(data_juicer_path):
            message = f"Specified DataJuicer path does not exist: {data_juicer_path}"
        else:
            # Update global DATA_JUICER_PATH
            DATA_JUICER_PATH = data_juicer_path
            message = f"DataJuicer path has been updated to: {DATA_JUICER_PATH}"
    except Exception as e:
        message = f"Error occurred while configuring DataJuicer path: {str(e)}"

    return ToolResponse(content=[TextBlock(type="text", text=message)])
||||
224
data_juicer_agent/tools/dj_tools.py
Normal file
224
data_juicer_agent/tools/dj_tools.py
Normal file
@@ -0,0 +1,224 @@
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from agentscope.message import TextBlock
|
||||
from agentscope.tool import ToolResponse
|
||||
from .op_manager.op_retrieval import retrieve_ops
|
||||
|
||||
# Load tool information for formatting
|
||||
TOOLS_INFO_PATH = osp.join(osp.dirname(__file__), "op_manager", "dj_funcs_all.json")
|
||||
|
||||
def _load_tools_info():
    """Load tools information from the JSON file, generating it on first use."""
    if not osp.exists(TOOLS_INFO_PATH):
        # First run: build the operator info and persist it for later calls.
        from .op_manager.create_dj_func_info import dj_func_info
        with open(TOOLS_INFO_PATH, "w", encoding="utf-8") as f:
            json.dump(dj_func_info, f)
        return dj_func_info

    with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
        return json.loads(f.read())
|
||||
def _format_tool_names_to_class_entries(tool_names):
    """Convert tool names list to formatted class entries string."""
    if not tool_names:
        return ""

    tools_info = _load_tools_info()

    # Map class_name -> tool info for O(1) lookup.
    tools_map = {info['class_name']: info for info in tools_info}

    entries = []
    for idx, name in enumerate(tool_names):
        info = tools_map.get(name)
        if info is None:
            # Unknown names are skipped; the displayed rank still advances,
            # matching the original numbering behavior.
            continue
        entries.append(
            f"{idx+1}. {info['class_name']}: {info['class_desc']}"
            + "\n" + info["arguments"]
        )

    return "\n".join(entries)
||||
|
||||
async def query_dj_operators(query: str, limit: int = 20) -> ToolResponse:
    """Query DataJuicer operators by natural language description.

    Retrieves relevant operators from DataJuicer library based on user query.
    Supports matching by functionality, data type, and processing scenarios.

    Args:
        query (str): Natural language operator query
        limit (int): Maximum number of operators to return (default: 20)

    Returns:
        ToolResponse: Tool response containing matched operators with names, descriptions, and parameters
    """

    def _text_response(text: str) -> ToolResponse:
        # Wrap plain text in the standard ToolResponse envelope.
        return ToolResponse(content=[TextBlock(type="text", text=text)])

    try:
        # Retrieval mode may be overridden via the RETRIEVAL_MODE env var.
        retrieval_mode = os.environ.get("RETRIEVAL_MODE", "auto")
        tool_names = await retrieve_ops(query, limit=limit, mode=retrieval_mode)

        if not tool_names:
            return _text_response(
                f"No matching DataJuicer operators found for query: {query}\n"
                f"Suggestions:\n"
                f"1. Use more specific keywords like 'text filter', 'image processing'\n"
                f"2. Check spelling and try alternative terms\n"
                f"3. Try English keywords for better matching",
            )

        # Render the matched names with their descriptions and arguments.
        retrieved_operators = _format_tool_names_to_class_entries(tool_names)

        header = (
            f"🔍 DataJuicer Operator Query Results\n"
            f"Query: {query}\n"
            f"Limit: {limit} operators\n"
            f"{'='*50}\n\n"
        )
        return _text_response(header + retrieved_operators)

    except Exception as e:
        return _text_response(
            f"Error querying DataJuicer operators: {str(e)}\n"
            f"Please verify query parameters and retry.",
        )
||||
|
||||
|
||||
async def execute_safe_command(
    command: str,
    timeout: int = 300,
    **kwargs: Any,
) -> ToolResponse:
    """Execute safe commands including DataJuicer commands and other safe system commands.
    Returns the return code, standard output and error within <returncode></returncode>,
    <stdout></stdout> and <stderr></stderr> tags.

    Args:
        command (`str`):
            The command to execute. Allowed commands include:
            - DataJuicer commands: dj-process, dj-analyze
            - File system commands: mkdir, ls, pwd, cat, echo, cp, mv, rm
            - Text processing: grep, head, tail, wc, sort, uniq
            - Archive commands: tar, zip, unzip
            - Other safe commands: which, whoami, date, find
        timeout (`float`, defaults to `300`):
            The maximum time (in seconds) allowed for the command to run.

    Returns:
        `ToolResponse`:
            The tool response containing the return code, standard output, and
            standard error of the executed command.
    """

    # Security check: only allow safe commands
    command_stripped = command.strip()

    # Define allowed command prefixes for security
    allowed_commands = [
        # DataJuicer commands
        'dj-process', 'dj-analyze',
        # File system operations
        'mkdir', 'ls', 'pwd', 'cat', 'echo', 'cp', 'mv', 'rm',
        # Text processing
        'grep', 'head', 'tail', 'wc', 'sort', 'uniq',
        # Archive operations
        'tar', 'zip', 'unzip',
        # Information commands
        'which', 'whoami', 'date', 'find',
        # Python commands
        'python', 'python3', 'pip', 'uv'
    ]

    # Fixed: match the first whitespace-delimited token exactly instead of
    # a raw prefix match — startswith() accepted e.g. "lsof ..." via "ls"
    # or "python_evil ..." via "python".
    # NOTE(review): the command still runs through a shell, so chained
    # commands ("echo x; rm -rf /") can bypass this list — treat it as a
    # guardrail, not a sandbox.
    first_token = command_stripped.split()[0] if command_stripped else ""
    command_allowed = False
    for allowed_cmd in allowed_commands:
        if first_token == allowed_cmd:
            # Additional security checks for potentially dangerous commands
            if allowed_cmd in ['rm', 'mv'] and ('/' in command_stripped or '..' in command_stripped):
                # Prevent dangerous path operations
                continue
            command_allowed = True
            break

    if not command_allowed:
        error_msg = f"Error: Command not allowed for security reasons. Allowed commands: {', '.join(allowed_commands)}. Received command: {command}"
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=(
                        f"<returncode>-1</returncode>"
                        f"<stdout></stdout>"
                        f"<stderr>{error_msg}</stderr>"
                    ),
                ),
            ],
        )

    proc = await asyncio.create_subprocess_shell(
        command,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        bufsize=0,
    )

    try:
        # Fixed: wait on communicate() rather than wait().  Waiting on
        # wait() while stdout/stderr are PIPEs can deadlock once the OS
        # pipe buffer fills, because nothing drains the child's output
        # (documented asyncio/subprocess pitfall).
        stdout, stderr = await asyncio.wait_for(
            proc.communicate(), timeout=timeout
        )
        stdout_str = stdout.decode("utf-8")
        stderr_str = stderr.decode("utf-8")
        returncode = proc.returncode

    except asyncio.TimeoutError:
        stderr_suffix = (
            f"TimeoutError: The command execution exceeded "
            f"the timeout of {timeout} seconds."
        )
        returncode = -1
        try:
            proc.terminate()
            # Collect whatever partial output the child produced.
            stdout, stderr = await proc.communicate()
            stdout_str = stdout.decode("utf-8")
            stderr_str = stderr.decode("utf-8")
            if stderr_str:
                stderr_str += f"\n{stderr_suffix}"
            else:
                stderr_str = stderr_suffix
        except ProcessLookupError:
            # Process already exited between the timeout and terminate().
            stdout_str = ""
            stderr_str = stderr_suffix

    return ToolResponse(
        content=[
            TextBlock(
                type="text",
                text=(
                    f"<returncode>{returncode}</returncode>"
                    f"<stdout>{stdout_str}</stdout>"
                    f"<stderr>{stderr_str}</stderr>"
                ),
            ),
        ],
    )
||||
120
data_juicer_agent/tools/mcp_tools.py
Normal file
120
data_juicer_agent/tools/mcp_tools.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
import string
|
||||
|
||||
from agentscope.tool import Toolkit
|
||||
from agentscope.mcp import HttpStatefulClient, HttpStatelessClient, StdIOStatefulClient
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
root_path = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
def _load_config(config_path: str) -> dict:
    """Load MCP configuration from file"""
    try:
        if not os.path.exists(config_path):
            logger.warning(
                f"Configuration file {config_path} not found, using default settings"
            )
            return _create_default_config()
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        logger.info(f"Loaded MCP configuration from {config_path}")
        return config
    except Exception as e:
        # Unreadable/invalid JSON falls back to the built-in default.
        logger.error(f"Error loading configuration: {e}")
        return _create_default_config()
||||
|
||||
def _create_default_config() -> dict:
|
||||
"""Create default configuration"""
|
||||
return {
|
||||
"mcpServers": {
|
||||
"dj_recipe_flow": {
|
||||
"command": "python",
|
||||
"args": ["/home/test/data_juicer/tools/DJ_mcp_recipe_flow.py"],
|
||||
"env": {"SERVER_TRANSPORT": "stdio"},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def _expand_env_vars(value: str) -> str:
|
||||
"""Expand environment variables in configuration values"""
|
||||
if isinstance(value, str):
|
||||
template = string.Template(value)
|
||||
try:
|
||||
return template.substitute(os.environ)
|
||||
except KeyError as e:
|
||||
logger.warning(f"Environment variable not found: {e}")
|
||||
return value
|
||||
return value
|
||||
|
||||
async def _create_clients(config: dict, toolkit: Toolkit) -> list:
    """Create MCP clients based on configuration.

    For each entry in ``config["mcpServers"]``, builds a StdIO client
    (when "command" is present) or an HTTP client (when "url" is present),
    connects stateful clients, and registers each with ``toolkit``.

    Returns:
        list: Successfully created client objects.  Per-server failures
        are logged and skipped, EXCEPT a missing command/url, which is
        re-raised as ``ValueError("Invalid server configuration")``.
    """
    server_configs = config.get("mcpServers", {})
    clients = []

    for server_name, server_config in server_configs.items():
        try:
            # Handle StdIO client
            if "command" in server_config:
                command = server_config["command"]
                args = server_config.get("args", [])
                env = server_config.get("env", {})

                # Expand environment variables ($VAR / ${VAR}) in args/env
                expanded_args = [_expand_env_vars(arg) for arg in args]
                expanded_env = {k: _expand_env_vars(v) for k, v in env.items()}

                client = StdIOStatefulClient(
                    name=server_name,
                    command=command,
                    args=expanded_args,
                    env=expanded_env,
                )

                # Stateful clients must be connected before registration.
                await client.connect()
                await toolkit.register_mcp_client(client)

            # Handle HTTP clients
            elif "url" in server_config:
                url = _expand_env_vars(server_config["url"])
                transport = server_config.get("transport", "sse")
                stateful = server_config.get("stateful", True)

                if stateful:
                    client = HttpStatefulClient(
                        name=server_name, transport=transport, url=url
                    )
                    await client.connect()
                    await toolkit.register_mcp_client(client)
                else:
                    # Stateless clients need no explicit connect step.
                    client = HttpStatelessClient(
                        name=server_name, transport=transport, url=url
                    )
                    await toolkit.register_mcp_client(client)

            else:
                raise ValueError("Invalid server configuration")

            clients.append(client)
        except Exception as e:
            # NOTE(review): re-raising by matching the message string is
            # fragile — any other error containing this text would also be
            # re-raised.  Kept as-is to preserve behavior.
            if "Invalid server configuration" in str(e):
                raise e
            # Connection/registration failures skip this server only.
            logger.error(f"Failed to create client {server_name}: {e}")

    return clients
||||
|
||||
async def get_mcp_toolkit(config_path: Optional[str] = None):
    """Get toolkit with all MCP tools registered.

    Args:
        config_path (Optional[str]): Path to the MCP JSON config; defaults
            to ``<package root>/configs/mcp_config.json``.

    Returns:
        tuple: ``(toolkit, clients)`` — the populated Toolkit plus the
        list of connected MCP clients (callers own the clients' lifetime
        and should close them when done).
    """
    # Fixed: the annotation previously claimed `-> Toolkit`, but the
    # function has always returned a (toolkit, clients) tuple.
    config_path = config_path or os.path.join(root_path, "configs", "mcp_config.json")
    config = _load_config(config_path)
    toolkit = Toolkit()

    clients = await _create_clients(config, toolkit)

    return toolkit, clients
||||
0
data_juicer_agent/tools/op_manager/__init__.py
Normal file
0
data_juicer_agent/tools/op_manager/__init__.py
Normal file
34
data_juicer_agent/tools/op_manager/create_dj_func_info.py
Normal file
34
data_juicer_agent/tools/op_manager/create_dj_func_info.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import inspect
from data_juicer.tools.op_search import OPSearcher

# Collect every registered DataJuicer operator (formatters excluded).
searcher = OPSearcher(include_formatter=False)

all_ops = searcher.search()

# Build one serializable summary entry per operator:
#   {"index", "class_name", "class_desc", "arguments"}
# `arguments` is a pre-rendered, human-readable parameter listing.
dj_func_info = []
for i, op in enumerate(all_ops):
    class_entry = {"index": i, "class_name": op["name"], "class_desc": op["desc"]}
    param_desc = op["param_desc"]
    # Parse reST-style ":param name: description" entries into a map.
    param_desc_map = {}
    args = ""
    for item in param_desc.split(":param"):
        _item = item.split(":")
        if len(_item) < 2:
            # Chunk without a "name: description" pair (e.g. leading text).
            continue
        # Re-join the tail in case the description itself contains colons.
        param_desc_map[_item[0].strip()] = ":".join(_item[1:]).strip()

    if op["sig"]:
        # Walk the constructor signature, skipping self and var-args.
        for param_name, param in op["sig"].parameters.items():
            if param_name in ["self", "args", "kwargs"]:
                continue
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue
            if param_name in param_desc_map:
                args += f"    {param_name} ({param.annotation}): {param_desc_map[param_name]}\n"
            else:
                args += f"    {param_name} ({param.annotation})\n"
    class_entry["arguments"] = args
    dj_func_info.append(class_entry)
|
||||
1
data_juicer_agent/tools/op_manager/dj_funcs_all.json
Normal file
1
data_juicer_agent/tools/op_manager/dj_funcs_all.json
Normal file
File diff suppressed because one or more lines are too long
380
data_juicer_agent/tools/op_manager/op_retrieval.py
Normal file
380
data_juicer_agent/tools/op_manager/op_retrieval.py
Normal file
@@ -0,0 +1,380 @@
|
||||
import os
|
||||
import os.path as osp
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from langchain_community.vectorstores import FAISS
|
||||
|
||||
TOOLS_INFO_PATH = osp.join(osp.dirname(__file__), "dj_funcs_all.json")
|
||||
CACHE_RETRIEVED_TOOLS_PATH = osp.join(osp.dirname(__file__), "cache_retrieve")
|
||||
VECTOR_INDEX_CACHE_PATH = osp.join(osp.dirname(__file__), "vector_index_cache")
|
||||
|
||||
# Global variable to cache the vector store
|
||||
_cached_vector_store: Optional[FAISS] = None
|
||||
_cached_tools_info: Optional[list] = None
|
||||
_cached_file_hash: Optional[str] = None
|
||||
|
||||
RETRIEVAL_PROMPT = """You are a professional tool retrieval assistant responsible for filtering the top {limit} most relevant tools from a large tool library based on user requirements. Execute the following steps:
|
||||
|
||||
# Requirement Analysis
|
||||
Carefully read the user's [requirement description], extract core keywords, functional objectives, usage scenarios, and technical requirements (such as real-time performance, data types, industry domains, etc.).
|
||||
|
||||
# Tool Matching
|
||||
Perform multi-dimensional matching based on the following tool attributes:
|
||||
- Tool name and functional description
|
||||
- Supported input/output formats
|
||||
- Applicable industry or scenario tags
|
||||
- Technical implementation principles (API, local deployment, AI model types)
|
||||
- Relevance ranking
|
||||
|
||||
# Use weighted scoring mechanism (example weights):
|
||||
- Functional match (40%)
|
||||
- Scenario compatibility (30%)
|
||||
- Technical compatibility (20%)
|
||||
- User rating/usage rate (10%)
|
||||
|
||||
# Deduplication and Optimization
|
||||
Exclude the following low-quality results:
|
||||
- Tools with duplicate functionality (keep only the best one)
|
||||
- Tools that cannot meet basic requirements
|
||||
- Tools missing critical parameter descriptions
|
||||
|
||||
# Constraints
|
||||
- Strictly control output to a maximum of {limit} tools
|
||||
- Refuse to speculate on unknown tool attributes
|
||||
- Maintain accuracy of domain expertise
|
||||
|
||||
# Output Format
|
||||
Return a JSON format TOP{limit} tool list containing:
|
||||
[
|
||||
{{
|
||||
"rank": 1,
|
||||
"tool_name": "Tool Name",
|
||||
"description": "Core functionality summary",
|
||||
"relevance_score": 98.7,
|
||||
"key_match": ["Matching keywords/features"]
|
||||
}}
|
||||
]
|
||||
Output strictly in JSON array format, and only output the JSON array format tool list.
|
||||
"""
|
||||
|
||||
|
||||
def fast_text_encoder(text: str) -> str:
    """Fast encoding using xxHash algorithm"""
    # Imported lazily; xxhash is only needed by the caching layer.
    import xxhash

    # xxh64 with a fixed seed gives a stable, fast, non-cryptographic
    # 16-char hex digest suitable for cache-file names.
    digest = xxhash.xxh64(text.encode("utf-8"), seed=0)
    return digest.hexdigest()
|
||||
|
||||
async def retrieve_ops_lm(user_query, limit=20):
    """Tool retrieval using language model - returns list of tool names.

    Results are cached on disk keyed by a hash of (query, limit), so a
    repeated query skips the model call entirely.

    Args:
        user_query: Natural-language requirement description.
        limit (int): Maximum number of tool names to return.

    Returns:
        list[str]: Validated operator class names.
    """
    hash_id = fast_text_encoder(user_query + str(limit))

    # Ensure cache directory exists
    os.makedirs(CACHE_RETRIEVED_TOOLS_PATH, exist_ok=True)

    cache_tools_path = osp.join(CACHE_RETRIEVED_TOOLS_PATH, f"{hash_id}.json")
    if osp.exists(cache_tools_path):
        with open(cache_tools_path, "r", encoding="utf-8") as f:
            return json.loads(f.read())

    if osp.exists(TOOLS_INFO_PATH):
        with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
            dj_func_info = json.loads(f.read())
    else:
        # Fixed: use a package-relative import — the previous absolute
        # `from create_dj_func_info import ...` only worked when this
        # directory happened to be on sys.path (dj_tools.py imports the
        # same module relatively).
        from .create_dj_func_info import dj_func_info

        # Fixed: write directly to TOOLS_INFO_PATH; it is already an
        # absolute path, so joining it onto a re-derived project root was
        # a no-op at best.
        with open(TOOLS_INFO_PATH, "w", encoding="utf-8") as f:
            f.write(json.dumps(dj_func_info))

    # One "ClassName: description" line per operator for the prompt.
    tool_descriptions = [
        f"{t['class_name']}: {t['class_desc']}" for t in dj_func_info
    ]
    tools_string = "\n".join(tool_descriptions)

    # Imported lazily so the vector-only path works without agentscope.
    from agentscope.model import DashScopeChatModel
    from agentscope.message import Msg
    from agentscope.formatter import DashScopeChatFormatter

    model = DashScopeChatModel(
        model_name="qwen-turbo",
        api_key=os.environ.get("DASHSCOPE_API_KEY"),
        stream=False,
    )

    formatter = DashScopeChatFormatter()

    # Update retrieval prompt to use the specified limit
    retrieval_prompt_with_limit = RETRIEVAL_PROMPT.format(limit=limit)

    user_prompt = (
        retrieval_prompt_with_limit
        + """
User requirement description:
{user_query}

Available tools:
{tools_string}
""".format(
            user_query=user_query, tools_string=tools_string
        )
    )

    msgs = [
        Msg(name="user", role="user", content=user_prompt),
    ]

    formatted_msgs = await formatter.format(msgs)

    response = await model(formatted_msgs)

    # The model is instructed to answer with a bare JSON array.
    msg = Msg(name="assistant", role="assistant", content=response.content)
    retrieved_tools_text = msg.get_text_content()
    retrieved_tools = json.loads(retrieved_tools_text)

    # Extract tool names and validate they exist
    tool_names = []
    for tool_info in retrieved_tools:
        if not isinstance(tool_info, dict) or "tool_name" not in tool_info:
            logging.warning(f"Invalid tool info format: {tool_info}")
            continue

        tool_name = tool_info["tool_name"]

        # Verify tool exists in dj_func_info (guards against hallucinated
        # operator names from the model).
        tool_exists = any(t["class_name"] == tool_name for t in dj_func_info)
        if not tool_exists:
            logging.error(f"Tool not found: `{tool_name}`, skipping!")
            continue

        tool_names.append(tool_name)

    # Cache the result
    with open(cache_tools_path, "w", encoding="utf-8") as f:
        json.dump(tool_names, f)

    return tool_names
||||
|
||||
|
||||
def _get_file_hash(file_path: str) -> str:
|
||||
"""Get file content hash using SHA256"""
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
return hashlib.sha256(file_content).hexdigest()
|
||||
except (OSError, IOError):
|
||||
return ""
|
||||
|
||||
|
||||
def _load_cached_index() -> bool:
    """Load cached vector index from disk.

    Populates the module-level ``_cached_vector_store`` and
    ``_cached_file_hash`` globals on success.

    Returns:
        bool: True if a valid, up-to-date cached index was loaded.
    """
    global _cached_vector_store, _cached_tools_info, _cached_file_hash

    try:
        # Ensure cache directory exists
        os.makedirs(VECTOR_INDEX_CACHE_PATH, exist_ok=True)

        index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
        metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")

        # Both the index directory and its metadata must be present.
        if not all(
            os.path.exists(p) for p in [index_path, metadata_path]
        ):
            return False

        # Check if cached index matches current tools info file
        with open(metadata_path, "r") as f:
            metadata = json.load(f)

        cached_hash = metadata.get("tools_info_hash", "")
        current_hash = _get_file_hash(TOOLS_INFO_PATH)

        # A hash mismatch means the tools-info file changed since the
        # index was built — treat the cache as stale.
        if current_hash != cached_hash:
            return False

        # Load cached data
        from langchain_community.embeddings import DashScopeEmbeddings

        embeddings = DashScopeEmbeddings(
            dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"),
            model="text-embedding-v1",
        )

        # allow_dangerous_deserialization is required by FAISS.load_local
        # for pickled indexes; safe here because the cache is self-written.
        _cached_vector_store = FAISS.load_local(
            index_path, embeddings, allow_dangerous_deserialization=True
        )

        _cached_file_hash = cached_hash

        logging.info("Successfully loaded cached vector index")
        return True

    except Exception as e:
        # Any failure simply falls back to rebuilding the index.
        logging.warning(f"Failed to load cached index: {e}")
        return False
||||
|
||||
|
||||
def _save_cached_index():
    """Persist the in-memory FAISS index and its metadata to disk.

    No-op (with a warning) when no index has been built yet. Failures are
    logged but never raised, since disk caching is best-effort.
    """
    global _cached_vector_store, _cached_file_hash

    try:
        # Ensure cache directory exists
        os.makedirs(VECTOR_INDEX_CACHE_PATH, exist_ok=True)

        index_path = osp.join(VECTOR_INDEX_CACHE_PATH, "faiss_index")
        metadata_path = osp.join(VECTOR_INDEX_CACHE_PATH, "metadata.json")

        # Without an index there is nothing to cache; writing metadata
        # alone would leave a stale marker on disk that claims a cache
        # exists (fixed: previously metadata was written unconditionally).
        if not _cached_vector_store:
            logging.warning("No vector index to save; skipping cache write")
            return

        _cached_vector_store.save_local(index_path)

        # Record the hash of the tools-info file the index was built from,
        # so _load_cached_index can detect staleness later.
        metadata = {"tools_info_hash": _cached_file_hash, "created_at": time.time()}
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f)

        logging.info("Successfully saved vector index to cache")

    except Exception as e:
        logging.error(f"Failed to save cached index: {e}")
|
||||
|
||||
|
||||
def _build_vector_index():
    """Build the FAISS vector index from the tools-info file and cache it."""
    global _cached_vector_store, _cached_file_hash

    with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
        tools_info = json.load(f)

    # One searchable description per operator: "<class_name>: <class_desc>".
    descriptions = [
        "{}: {}".format(info["class_name"], info["class_desc"])
        for info in tools_info
    ]

    # Imported lazily so the dependency is only needed on this path.
    from langchain_community.embeddings import DashScopeEmbeddings

    embeddings = DashScopeEmbeddings(
        dashscope_api_key=os.environ.get("DASHSCOPE_API_KEY"),
        model="text-embedding-v1",
    )

    # Store each description's position so search hits can be mapped back
    # to entries in the tools-info list.
    store = FAISS.from_texts(
        descriptions,
        embeddings,
        metadatas=[{"index": pos} for pos in range(len(descriptions))],
    )

    # Cache in memory first, then persist to disk.
    _cached_vector_store = store
    _cached_file_hash = _get_file_hash(TOOLS_INFO_PATH)
    _save_cached_index()

    logging.info("Successfully built and cached vector index")
|
||||
|
||||
|
||||
def retrieve_ops_vector(user_query, limit=20):
    """Tool retrieval using vector search with caching - returns list of tool names"""
    global _cached_vector_store

    # Reuse the on-disk cache when it is fresh; otherwise rebuild from scratch.
    if not _load_cached_index():
        logging.info("Building new vector index...")
        _build_vector_index()

    # Nearest-neighbour search over the operator descriptions.
    hits = _cached_vector_store.similarity_search(user_query, k=limit)

    with open(TOOLS_INFO_PATH, "r", encoding="utf-8") as f:
        tools_info = json.loads(f.read())

    # Map each hit back to its operator name via the stored position index.
    return [tools_info[doc.metadata["index"]]["class_name"] for doc in hits]
|
||||
|
||||
|
||||
async def retrieve_ops(user_query: str, limit: int = 20, mode: str = "auto") -> list:
    """
    Tool retrieval with configurable mode

    Args:
        user_query: User query string
        limit: Maximum number of tools to retrieve
        mode: Retrieval mode - "llm", "vector", or "auto" (default: "auto")
            - "llm": Use language model only
            - "vector": Use vector search only
            - "auto": Try LLM first, fallback to vector search on failure

    Returns:
        List of tool names (empty list when retrieval fails)

    Raises:
        ValueError: If mode is not one of "llm", "vector", or "auto".
    """
    if mode == "llm":
        try:
            return await retrieve_ops_lm(user_query, limit=limit)
        except Exception as e:
            logging.error(f"LLM retrieval failed: {str(e)}")
            return []

    elif mode == "vector":
        try:
            return retrieve_ops_vector(user_query, limit=limit)
        except Exception as e:
            logging.error(f"Vector retrieval failed: {str(e)}")
            return []

    elif mode == "auto":
        try:
            return await retrieve_ops_lm(user_query, limit=limit)
        except Exception as e:
            # Log the full traceback via logging (the original printed it to
            # stdout, inconsistent with error handling elsewhere) and fall
            # back to vector search.
            logging.warning(
                f"LLM retrieval failed, falling back to vector search: {e}",
                exc_info=True,
            )
            try:
                return retrieve_ops_vector(user_query, limit=limit)
            except Exception as fallback_e:
                logging.error(
                    f"Tool retrieval failed: {str(e)}, fallback retrieval also failed: {str(fallback_e)}"
                )
                return []

    else:
        raise ValueError(f"Invalid mode: {mode}. Must be 'llm', 'vector', or 'auto'")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio

    user_query = (
        "Clean special characters from text and filter samples with excessive length. Mask sensitive information and filter unsafe content including adult/terror-related terms."
        + "Additionally, filter out small images, perform image tagging, and remove duplicate images."
    )

    # Exercise every retrieval mode with the same query.
    cases = [
        ("=== Testing LLM mode ===", "llm", "LLM"),
        ("\n=== Testing Vector mode ===", "vector", "Vector"),
        ("\n=== Testing Auto mode (default) ===", "auto", "Auto"),
    ]
    for header, mode, label in cases:
        print(header)
        names = asyncio.run(retrieve_ops(user_query, limit=10, mode=mode))
        print(f"Retrieved tool names ({label}):")
        print(names)
|
||||
62
data_juicer_agent/tools/router_tools.py
Normal file
62
data_juicer_agent/tools/router_tools.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# -*- coding: utf-8 -*-
"""Router agent using implicit routing"""
from typing import Callable, Optional

from agentscope.agent import AgentBase
from agentscope.message import Msg
from agentscope.tool import ToolResponse
||||
|
||||
|
||||
def agent_to_tool(
    agent: AgentBase,
    tool_name: Optional[str] = None,
    description: Optional[str] = None,
) -> Callable:
    """
    Convert any agent to a tool function that can be registered in toolkit.

    Args:
        agent: The agent instance to convert
        tool_name: Optional custom tool name (defaults to agent.name)
        description: Optional tool description (defaults to agent's docstring or sys_prompt)

    Returns:
        A tool function that can be registered with toolkit.register_tool_function()
    """
    if tool_name is None:
        tool_name = getattr(agent, "name", "agent_tool")

    if description is None:
        # Prefer the agent's class docstring, then its (public or private)
        # system prompt, then a generic fallback. Note: every object has a
        # __doc__ attribute, so a hasattr() check would be redundant.
        if agent.__doc__:
            description = agent.__doc__.strip()
        elif hasattr(agent, "sys_prompt"):
            description = f"Agent: {agent.sys_prompt[:100]}..."
        elif hasattr(agent, "_sys_prompt"):
            description = f"Agent: {agent._sys_prompt[:100]}..."
        else:
            description = f"Tool function for {tool_name}"

    async def tool_function(task: str) -> ToolResponse:
        """Forward *task* to the wrapped agent and wrap its reply."""
        msg = Msg("user", task, "user")
        result = await agent(msg)

        # The agent must reply with a Msg-like object exposing text blocks.
        if not hasattr(result, "get_content_blocks"):
            raise ValueError(f"Not a valid Msg object: {result}")

        return ToolResponse(
            content=result.get_content_blocks("text"),
            metadata={
                "agent_name": getattr(agent, "name", "unknown"),
                "task": task,
            },
        )

    # Expose a descriptive name/docstring so the toolkit's schema
    # generation picks them up.
    tool_function.__name__ = f"call_{tool_name.lower().replace(' ', '_')}"
    tool_function.__doc__ = (
        f"{description}\n\nArgs:\n    task (str): The task for {tool_name} to handle"
    )

    return tool_function
|
||||
Reference in New Issue
Block a user